# Sinabs
from typing import Tuple

import sinabs.activation as sa
import sinabs.layers as sl
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
# Widths of the two spiking hidden layers of the SNN model defined next.
hidden_dim1, hidden_dim2 = 256, 128
class SNN(nn.Module):
    """Spiking MLP: two LIF hidden layers followed by a non-spiking linear readout.

    The hidden widths are read from the module-level ``hidden_dim1`` and
    ``hidden_dim2`` at construction time.

    Args:
        in_features: Size of the input's last dimension. The default 28
            presumably corresponds to one MNIST image row, as fed by the
            training loop in this notebook.
        num_classes: Number of readout units.
        tau_mem: Membrane time constant given to both LIF layers.
    """

    def __init__(
        self, in_features: int = 28, num_classes: int = 10, tau_mem: float = 10.0
    ):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden_dim1)
        self.spike1 = sl.LIF(tau_mem=tau_mem, spike_fn=sa.SingleSpike)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.spike2 = sl.LIF(tau_mem=tau_mem, spike_fn=sa.SingleSpike)
        self.linear3 = nn.Linear(hidden_dim2, num_classes)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the network and also expose the hidden-layer spike tensors.

        ``x`` is assumed to be (batch, time, in_features); callers in this
        notebook sum the readout over dim 1 as the time axis. TODO confirm the
        layout sinabs ``LIF`` expects.

        Returns:
            ``(readout, (spikes1, spikes2))``: the linear readout, plus the
            outputs of the two LIF layers (used for synaptic-operation counts).
        """
        out1 = self.spike1(self.linear1(x))
        out2 = self.spike2(self.linear2(out1))
        out3 = self.linear3(out2)
        return out3, (out1, out2)
/home/docs/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
import os

from torchvision import datasets, transforms

batch_size = 128

# Requesting more DataLoader workers than CPUs available to this process makes
# torch warn that loading may become slow or freeze, so cap the requested 4.
try:
    _available_cpus = len(os.sched_getaffinity(0))
except AttributeError:  # os.sched_getaffinity is not available on every platform
    _available_cpus = os.cpu_count() or 1
num_workers = min(4, _available_cpus)

trainset = datasets.MNIST(
    root="../data/", train=True, transform=transforms.ToTensor(), download=True
)
# Reshuffle every epoch so training does not see the samples in a fixed order;
# the ragged final batch is dropped so every training step uses a full batch.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)
testset = datasets.MNIST(
    root="../data/", train=False, transform=transforms.ToTensor(), download=True
)
# Keep the final partial batch: dropping it would silently exclude the last
# len(testset) % batch_size samples from the reported test metrics.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    drop_last=False,
    num_workers=num_workers,
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00<?, ?it/s]
62%|██████▏ | 6193152/9912422 [00:00<00:00, 61761944.65it/s]
100%|██████████| 9912422/9912422 [00:00<00:00, 77743869.88it/s]
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s]
100%|██████████| 28881/28881 [00:00<00:00, 120653081.50it/s]
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s]
100%|██████████| 1648877/1648877 [00:00<00:00, 28011581.47it/s]
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s]
100%|██████████| 4542/4542 [00:00<00:00, 6955286.15it/s]
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
/home/docs/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
import sinabs
import torchmetrics
# tqdm.notebook raises ImportError when ipywidgets is not installed; tqdm.auto
# falls back to the plain console progress bar in that case.
from tqdm.auto import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
snn = SNN().to(device)
optim = torch.optim.Adam(snn.parameters())
criterion = torch.nn.functional.cross_entropy
accuracy = torchmetrics.Accuracy("multiclass", num_classes=10).to(device)

for epoch in range(3):
    # Re-enter training mode in case an evaluation cell called snn.eval().
    snn.train()
    losses = []
    accuracy.reset()
    for x, y in tqdm(trainloader):
        x, y = x.to(device), y.to(device)
        # Clear the state of the stateful sinabs layers before each new batch.
        sinabs.reset_states(snn)
        optim.zero_grad()
        # Squeeze only the image channel dimension (dim 1) so that a batch of
        # size 1 keeps its batch axis.
        output, _ = snn(x.squeeze(1))
        y_hat = output.sum(1)  # we sum over time
        loss = criterion(y_hat, y)
        loss.backward()
        optim.step()
        # Store a detached copy: keeping `loss` itself would keep every batch's
        # autograd graph alive until the end of the epoch.
        losses.append(loss.detach())
        accuracy.update(y_hat, y)
    print(
        f"Epoch {epoch}: loss {torch.stack(losses).mean().item()} training accuracy {accuracy.compute()}"
    )
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
/tmp/ipykernel_254/1006095978.py in <module>
13 losses = []
14 accuracy.reset()
---> 15 for x, y in tqdm(trainloader):
16 x, y = x.to(device), y.to(device)
17 sinabs.reset_states(snn)
~/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/tqdm/notebook.py in __init__(self, *args, **kwargs)
236 unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
237 total = self.total * unit_scale if self.total else self.total
--> 238 self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
239 self.container.pbar = proxy(self)
240 self.displayed = False
~/checkouts/readthedocs.org/user_builds/snnmetrics/envs/latest/lib/python3.7/site-packages/tqdm/notebook.py in status_printer(_, total, desc, ncols)
111 # Prepare IPython progress bar
112 if IProgress is None: # #187 #451 #558 #872
--> 113 raise ImportError(WARN_NOIPYW)
114 if total:
115 pbar = IProgress(min=0, max=total)
ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
import snnmetrics as sm
# tqdm.auto falls back to the console progress bar when ipywidgets is missing,
# where tqdm.notebook would raise ImportError.
from tqdm.auto import tqdm

snn.eval()
# A spike from a layer-1 neuron is multiplied into every output unit of
# linear2, and a layer-2 spike into every output unit of linear3. Those
# out_features are therefore the per-neuron fanouts, not the layers' own widths.
fanout1 = snn.linear2.out_features
fanout2 = snn.linear3.out_features
synops1 = sm.SynOps(fanout=fanout1)
synops2 = sm.SynOps(fanout=fanout2)
accuracy.reset()
# Inference only: skip building autograd graphs.
with torch.no_grad():
    for x, y in tqdm(testloader):
        x, y = x.to(device), y.to(device)
        sinabs.reset_states(snn)
        output, (out1, out2) = snn(x.squeeze(1))  # squeeze only the channel dimension
        y_hat = output.sum(1)  # we sum over time
        # Spike counts per sample and neuron, summed over the time axis.
        batch_syn1 = synops1(out1.sum(1))
        batch_syn2 = synops2(out2.sum(1))
        batch_acc = accuracy(y_hat, y)
print(
    f"Test accuracy {accuracy.compute()}, synops layer 1 {synops1.compute()['synops']}, layer 2 {synops2.compute()['synops']}"
)
# Sanity checks, displayed as notebook cell outputs. Presumably
# "synops_per_neuron" is the spike count times fanout per neuron — confirm
# against the snnmetrics docs.
synops1.compute()["synops_per_neuron"] / fanout1
# Manual estimate for the last test batch: mean spike count per layer-1 neuron
# times its fanout ...
out1.sum(1).mean(0) * fanout1
# ... which should agree with what the metric returned for that batch.
batch_syn1["synops_per_neuron"]