# Sinabs

from typing import Tuple

import sinabs.activation as sa
import sinabs.layers as sl
import torch
import torch.nn as nn
from tqdm.notebook import tqdm


hidden_dim1 = 256
hidden_dim2 = 128


class SNN(nn.Module):
    """Fully connected spiking network with two hidden spiking stages.

    The model chains ``nn.Linear`` -> ``sl.LIF`` twice and finishes with a
    linear readout producing 10 values per input row. Both ``sl.LIF`` layers
    are created with ``tau_mem=10.0`` and ``sa.SingleSpike`` as spike function.
    """

    def __init__(self):
        super().__init__()
        # The first linear layer consumes vectors of length 28.
        self.linear1 = nn.Linear(28, hidden_dim1)
        self.spike1 = sl.LIF(tau_mem=10.0, spike_fn=sa.SingleSpike)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.spike2 = sl.LIF(tau_mem=10.0, spike_fn=sa.SingleSpike)
        # Readout layer: no spiking non-linearity is applied after it.
        self.linear3 = nn.Linear(hidden_dim2, 10)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the network and also expose the hidden spiking activity.

        Args:
            x: Input tensor whose last dimension is 28 (required by
                ``linear1``).

        Returns:
            A pair ``(out3, (out1, out2))`` where ``out3`` is the output of
            the readout layer and ``out1`` / ``out2`` are the outputs of the
            first and second spiking layers. The hidden outputs are returned
            so callers can compute activity-based metrics on them.
        """
        out1 = self.spike1(self.linear1(x))
        out2 = self.spike2(self.linear2(out1))
        out3 = self.linear3(out2)
        return out3, (out1, out2)
/home/docs/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
from torchvision import datasets, transforms

batch_size = 128

# Arguments shared by the train and test splits: same storage location,
# same tensor conversion, download on demand.
_mnist_kwargs = dict(root="../data/", transform=transforms.ToTensor(), download=True)
# Arguments shared by both loaders; incomplete final batches are dropped.
_loader_kwargs = dict(batch_size=batch_size, drop_last=True, num_workers=4)

trainset = datasets.MNIST(train=True, **_mnist_kwargs)
trainloader = torch.utils.data.DataLoader(trainset, **_loader_kwargs)

testset = datasets.MNIST(train=False, **_mnist_kwargs)
testloader = torch.utils.data.DataLoader(testset, **_loader_kwargs)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/9912422 [00:00<?, ?it/s]
 62%|██████▏   | 6193152/9912422 [00:00<00:00, 61761944.65it/s]
100%|██████████| 9912422/9912422 [00:00<00:00, 77743869.88it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/28881 [00:00<?, ?it/s]
100%|██████████| 28881/28881 [00:00<00:00, 120653081.50it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/1648877 [00:00<?, ?it/s]
100%|██████████| 1648877/1648877 [00:00<00:00, 28011581.47it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/4542 [00:00<?, ?it/s]
100%|██████████| 4542/4542 [00:00<00:00, 6955286.15it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
/home/docs/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
import sinabs
import torchmetrics

# tqdm.auto picks the widget progress bar when ipywidgets is available and
# falls back to the plain text bar otherwise, instead of raising ImportError
# the way tqdm.notebook does when IProgress is missing.
from tqdm.auto import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

snn = SNN().to(device)
optim = torch.optim.Adam(snn.parameters())
criterion = torch.nn.functional.cross_entropy
accuracy = torchmetrics.Accuracy("multiclass", num_classes=10).to(device)

# Train for three passes over the training loader, reporting the mean loss
# and the accumulated training accuracy after each pass.
for epoch in range(3):
    losses = []
    accuracy.reset()
    for x, y in tqdm(trainloader):
        x, y = x.to(device), y.to(device)
        # Clear the spiking layers' internal state before every batch.
        sinabs.reset_states(snn)
        optim.zero_grad()
        # we squeeze the image channel dimension; naming dim 1 explicitly
        # keeps a batch dimension of size 1 from being squeezed away as well
        output, (out1, out2) = snn(x.squeeze(1))
        y_hat = output.sum(1)  # we sum over time
        loss = criterion(y_hat, y)
        loss.backward()
        optim.step()
        # Store a detached copy so the logged losses do not keep references
        # to the autograd graph of every batch.
        losses.append(loss.detach())
        # Only the accumulated accuracy is reported, so update the metric
        # state without computing a per-batch value.
        accuracy.update(y_hat, y)
    print(
        f"Epoch {epoch}: loss {torch.stack(losses).mean()} training accuracy {accuracy.compute()}"
    )
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
/tmp/ipykernel_254/1006095978.py in <module>
     13     losses = []
     14     accuracy.reset()
---> 15     for x, y in tqdm(trainloader):
     16         x, y = x.to(device), y.to(device)
     17         sinabs.reset_states(snn)

~/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/tqdm/notebook.py in __init__(self, *args, **kwargs)
    236         unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
    237         total = self.total * unit_scale if self.total else self.total
--> 238         self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
    239         self.container.pbar = proxy(self)
    240         self.displayed = False

~/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/tqdm/notebook.py in status_printer(_, total, desc, ncols)
    111         # Prepare IPython progress bar
    112         if IProgress is None:  # #187 #451 #558 #872
--> 113             raise ImportError(WARN_NOIPYW)
    114         if total:
    115             pbar = IProgress(min=0, max=total)

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
import snnmetrics as sm

# Imported here so this cell does not depend on tqdm.notebook, which raises
# ImportError when ipywidgets/IProgress is not installed; tqdm.auto falls
# back to the plain text progress bar in that case.
from tqdm.auto import tqdm

# NOTE(review): fanout is set to each layer's own width here. If SynOps
# expects the number of outgoing connections per neuron instead, these would
# be hidden_dim2 and 10 -- confirm against the snnmetrics documentation.
synops1 = sm.SynOps(fanout=hidden_dim1)
synops2 = sm.SynOps(fanout=hidden_dim2)

snn.eval()
accuracy.reset()
# Evaluation needs no gradients; disabling autograd avoids building a graph
# for every test batch.
with torch.no_grad():
    for x, y in tqdm(testloader):
        x, y = x.to(device), y.to(device)
        # Clear the spiking layers' internal state before every batch.
        sinabs.reset_states(snn)
        # we squeeze the single channel dimension (dim 1 only, so a batch
        # dimension of size 1 would be preserved)
        output, (out1, out2) = snn(x.squeeze(1))
        y_hat = output.sum(1)  # we sum over time
        # Feed the per-neuron spike counts (summed over dim 1) of each hidden
        # layer to its SynOps metric; the per-batch results are kept for
        # inspection after the loop.
        batch_syn1 = synops1(out1.sum(1))
        batch_syn2 = synops2(out2.sum(1))
        accuracy.update(y_hat, y)
print(
    f"Test accuracy {accuracy.compute()}, synops layer 1 {synops1.compute()['synops']}, layer 2 {synops2.compute()['synops']}"
)
# The three lines below are bare expressions: in a notebook the value of a
# cell's last expression is displayed, but run as a plain script they have no
# visible effect.

# Accumulated "synops_per_neuron" value of the layer-1 metric, divided by hidden_dim1.
synops1.compute()["synops_per_neuron"] / hidden_dim1
# Layer-1 output of the last evaluated batch: summed over dim 1, averaged
# over dim 0, then multiplied by hidden_dim1.
out1.sum(1).mean(0) * hidden_dim1
# "synops_per_neuron" entry returned by synops1 for the last evaluated batch.
# NOTE(review): presumably meant to be compared with the expression above -- confirm.
batch_syn1["synops_per_neuron"]