# Sinabs: training a spiking MNIST classifier and counting synaptic operations

In [None]:
from typing import Tuple

import sinabs.activation as sa
import sinabs.layers as sl
import torch
import torch.nn as nn
from tqdm.notebook import tqdm


hidden_dim1 = 256  # width of the first spiking hidden layer
hidden_dim2 = 128  # width of the second spiking hidden layer


class SNN(nn.Module):
    """Feed-forward spiking network: two LIF hidden layers plus a linear readout.

    Every ``nn.Linear`` acts on the last input dimension, so leading
    dimensions pass straight through. The training cell feeds
    ``(batch, 28, 28)`` MNIST images and sums the output over dim 1, i.e. the
    28 image rows are presumably used as time steps (sinabs LIF layers expect
    ``(batch, time, ...)`` input — confirm against the sinabs docs).

    Args:
        input_dim: Size of the last input dimension (28 pixels per MNIST row).
        num_classes: Number of readout units.
        tau_mem: Membrane time constant passed to both ``sl.LIF`` layers.
    """

    def __init__(
        self, input_dim: int = 28, num_classes: int = 10, tau_mem: float = 10.0
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim1)
        self.spike1 = sl.LIF(tau_mem=tau_mem, spike_fn=sa.SingleSpike)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.spike2 = sl.LIF(tau_mem=tau_mem, spike_fn=sa.SingleSpike)
        self.linear3 = nn.Linear(hidden_dim2, num_classes)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the network and also expose the hidden-layer spike outputs.

        Returns:
            ``(readout, (out1, out2))``: ``readout`` is the output of
            ``linear3`` (not reduced over time; the caller sums it), and
            ``out1`` / ``out2`` are the outputs of the two LIF layers, used
            downstream for synaptic-operation counting.
        """
        out1 = self.spike1(self.linear1(x))
        out2 = self.spike2(self.linear2(out1))
        out3 = self.linear3(out2)
        return out3, (out1, out2)

In [None]:
from torchvision import datasets, transforms

batch_size = 128

trainset = datasets.MNIST(
    root="../data/", train=True, transform=transforms.ToTensor(), download=True
)
# Shuffle so each epoch visits the training images in a fresh order; the
# ragged final batch is dropped so every training step sees a full batch.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4
)

testset = datasets.MNIST(
    root="../data/", train=False, transform=transforms.ToTensor(), download=True
)
# Keep the final partial batch: with drop_last=True the last
# len(testset) % batch_size images (16 for the 10,000-image MNIST test set)
# would silently be excluded from the reported test accuracy.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4
)

In [None]:
import sinabs
import torchmetrics
from tqdm.notebook import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

snn = SNN().to(device)
optim = torch.optim.Adam(snn.parameters())
criterion = torch.nn.functional.cross_entropy
accuracy = torchmetrics.Accuracy("multiclass", num_classes=10).to(device)

for epoch in range(3):
    # The evaluation cell calls snn.eval(); switch back in case it ran first.
    snn.train()
    losses = []
    accuracy.reset()
    for x, y in tqdm(trainloader):
        x, y = x.to(device), y.to(device)
        # Reset the state of the stateful sinabs layers before each new batch.
        sinabs.reset_states(snn)
        optim.zero_grad()
        # Squeeze only the single image-channel dimension (dim 1); a bare
        # squeeze() would also drop the batch dimension for a batch of size 1.
        output, (out1, out2) = snn(x.squeeze(1))
        y_hat = output.sum(1)  # we sum over time
        loss = criterion(y_hat, y)
        # Store a detached value: appending the graph-attached loss keeps each
        # batch's grad_fn chain referenced until the end of the epoch.
        losses.append(loss.detach())
        accuracy.update(y_hat, y)
        loss.backward()
        optim.step()
    print(
        f"Epoch {epoch}: loss {torch.stack(losses).mean()} training accuracy {accuracy.compute()}"
    )

In [None]:
import snnmetrics as sm

# NOTE(review): out1 feeds linear2 (hidden_dim1 -> hidden_dim2) and out2 feeds
# linear3 (hidden_dim2 -> 10), so the number of postsynaptic targets per spike
# is arguably hidden_dim2 and 10, not hidden_dim1 and hidden_dim2. Confirm which
# convention snnmetrics' `fanout` expects. The sanity-check cells that follow
# divide/multiply by hidden_dim1 and must be kept in sync with this value.
synops1 = sm.SynOps(fanout=hidden_dim1)
synops2 = sm.SynOps(fanout=hidden_dim2)

snn.eval()
accuracy.reset()
# Inference only: without no_grad an autograd graph is built for every batch.
with torch.no_grad():
    for x, y in tqdm(testloader):
        x, y = x.to(device), y.to(device)
        sinabs.reset_states(snn)
        # Squeeze only the single channel dimension so small batches keep dim 0.
        output, (out1, out2) = snn(x.squeeze(1))
        y_hat = output.sum(1)  # we sum over time
        # Spike counts per neuron (summed over time) drive the synops metrics;
        # the last batch's values stay available to the cells that follow.
        batch_syn1 = synops1(out1.sum(1))
        batch_syn2 = synops2(out2.sum(1))
        accuracy.update(y_hat, y)
print(
    f"Test accuracy {accuracy.compute()}, synops layer 1 {synops1.compute()['synops']}, layer 2 {synops2.compute()['synops']}"
)

In [None]:
# Sanity check: divide the accumulated layer-1 per-neuron synops by the fanout
# given to synops1 (hidden_dim1). If snnmetrics defines synops as
# spikes * fanout (assumption — verify), this recovers spikes per neuron.
synops1.compute()["synops_per_neuron"] / hidden_dim1

In [None]:
# Manual per-neuron synops for layer 1: sum spikes over dim 1 (time), average
# over dim 0 (batch), scale by the fanout given to synops1. out1 holds only the
# LAST test batch from the evaluation loop.
out1.sum(1).mean(0) * hidden_dim1

In [None]:
# snnmetrics' per-neuron synops for that same last batch, to compare with the
# manual computation from out1.
batch_syn1["synops_per_neuron"]