# Source code for snnmetrics.synops
from typing import Optional, Union
import torch
from torchmetrics.metric import Metric
class SynOps(Metric):
    """Count synaptic operations (SynOps) of a spiking layer.

    The number of synaptic operations is defined as the number of spikes times the fanout,
    i.e. the number of connections each neuron has to the next layer. The metric reports
    this count both for every neuron of the layer and summed over all neurons.

    With fully-connected connectivity the fanout equals the number of neurons (features) in
    the next layer. For convolutional layers the situation is more complex: stride, kernel
    size, grouping and other parameters all influence the fanout. A convolutional kernel is
    applied to every receptive field, so (with zero padding) neurons at the edge of the input
    are seen less often than neurons in the middle. The convolutional fanout can therefore
    only be approximated, and the approximation is good when the spatial input size is large
    enough.

    Parameters:
        fanout: Either a scalar (float) fanout applied to every neuron, or a tensor of
            per-neuron fanouts of shape (C, H, W).
        sample_time: Optional duration (presumably in seconds, given the ``"synops/s"`` key)
            used to additionally report ``"synops/s"`` from :meth:`compute`. When ``None``
            that entry is omitted.
    """

    is_differentiable: bool = True
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    def __init__(
        self, fanout: Union[float, torch.Tensor], sample_time: Optional[float] = None
    ):
        super().__init__()
        self.fanout = torch.as_tensor(fanout)
        self.sample_time = sample_time
        # Running per-neuron sum of ``spikes * fanout`` over all samples seen so far.
        # For a scalar fanout the per-neuron shape is only known once the first output
        # arrives, so start from a 0-dim zero that broadcasts to that shape on the first
        # update (instead of keeping an ever-growing list of per-batch tensors). An integer
        # 0-dim zero is used so type promotion keeps the dtype of the accumulated products.
        if self.fanout.dim() == 0:
            default = torch.tensor(0)
        else:
            default = torch.zeros(self.fanout.shape)
        self.add_state("synops_per_neuron", default=default, dist_reduce_fx="sum")
        # Number of samples accumulated, i.e. the sum of ``output.shape[0]`` over updates.
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, output: torch.Tensor):
        """Accumulate the synaptic operations caused by one batch of layer output.

        Args:
            output: Spike tensor of the layer. Dim 0 is summed over and its size is added to
                the sample count. ``output.sum(0)`` must broadcast against ``fanout``; for a
                tensor fanout the product must be broadcastable to ``fanout``'s shape.
        """
        # ``fanout`` is a plain attribute, not a registered metric state, so it is not moved
        # by ``Metric.to(...)``. Move it to the output's device so a (C, H, W) fanout does
        # not cause a device mismatch when the metric runs on an accelerator.
        fanout = self.fanout.to(output.device)
        batch_synops = output.sum(0) * fanout
        if self.fanout.dim() == 0:
            # Out-of-place add so the 0-dim default can broadcast to the output shape.
            self.synops_per_neuron = self.synops_per_neuron + batch_synops
        else:
            self.synops_per_neuron += batch_synops
        self.total += output.shape[0]

    def compute(self):
        """Return the accumulated synaptic operations averaged over all samples.

        Returns:
            A dict with
            ``"synops_per_neuron"``: per-neuron SynOps per sample,
            ``"synops"``: the sum of those values over all neurons, and, only when
            ``sample_time`` was given,
            ``"synops/s"``: the mean of the per-neuron values divided by ``sample_time``.
        """
        synops = self.synops_per_neuron / self.total
        result_dict = {"synops_per_neuron": synops, "synops": synops.sum()}
        if self.sample_time is not None:
            # NOTE(review): this divides the per-neuron *mean* by sample_time rather than the
            # layer total ``synops.sum()``; confirm which quantity "synops/s" should report.
            result_dict["synops/s"] = synops.mean() / self.sample_time
        return result_dict