def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
        print_every=40):
    """Train the decoder on ``train_set``, validating on ``test_set``.

    NOTE(review): this is an early revision; it is shadowed by the later
    module-level ``fit`` definitions in this file.

    Args:
        train_set: Dataset yielding dicts with 'net' (recording) and
            'stim' (stimulus) tensors.
        test_set: Dataset used for periodic validation costs.
        lr: Adam learning rate.
        epochs: number of passes over the training data.
        batch_sz: DataLoader batch size.
        print_every: run a full validation pass every N training batches.
    """
    train_loader = DataLoader(train_set, batch_size=batch_sz, shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
    N = train_set.__len__()  # number of samples
    self.loss = DecoderLoss(alpha=10).to(device)
    self.optimizer = optim.Adam(self.parameters(), lr=lr)
    # FIX: ceil instead of floor so a final partial batch is counted and
    # n_batches cannot be zero when N < batch_sz.
    n_batches = -(-N // batch_sz)
    train_costs, test_costs = [], []
    for i in range(epochs):
        cost = 0
        print("epoch:", i, "n_batches:", n_batches)
        for j, batch in enumerate(train_loader):
            # network activity goes in time-major: (N, T, ...) -> (T, N, ...)
            cost += self.train_step(
                batch['net'].transpose(0, 1).to(device),
                batch['stim'].to(device))
            # try sending batch to GPU, then passing (then delete)
            del batch  # test whether useful for clearing off GPU
            if j % print_every == 0:
                # costs and accuracies for test set
                test_cost = 0
                for t, testB in enumerate(test_loader, 1):
                    testB_cost = self.get_cost(
                        testB['net'].transpose(0, 1).to(device),
                        testB['stim'].to(device))
                    test_cost += testB_cost
                # FIX: enumerate starts at 1, so t already equals the
                # number of validation batches; the old `/= t + 1`
                # under-reported the mean validation cost.
                test_cost /= t
                del testB
                print("cost: %f" % (test_cost))
        # for plotting
        train_costs.append(cost / n_batches)
        test_costs.append(test_cost)
    # plot cost and accuracy progression
    fig, axes = plt.subplots(1)
    axes.plot(train_costs, label='training')
    axes.plot(test_costs, label='validation')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Cost')
    plt.legend()
    plt.show()
def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
        loss_alpha=10, print_every=40):
    """Train the decoder on ``train_set``, validating on ``test_set``.

    NOTE(review): intermediate revision; shadowed by the later
    module-level ``fit`` definition in this file.

    Args:
        train_set: Dataset yielding dicts with 'net' and 'stim' tensors.
        test_set: Dataset used for periodic validation costs.
        lr: Adam learning rate.
        epochs: number of passes over the training data.
        batch_sz: DataLoader batch size.
        loss_alpha: DecoderLoss sparsity weight (alpha=0 behaves like MSE).
        print_every: run a full validation pass every N training batches.
    """
    train_loader = DataLoader(train_set, batch_size=batch_sz, shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
    N = train_set.__len__()  # number of samples
    # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
    self.loss = DecoderLoss(alpha=loss_alpha).to(device)
    self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)
    # FIX: ceil instead of floor so a final partial batch is counted and
    # n_batches cannot be zero when N < batch_sz.
    n_batches = -(-N // batch_sz)
    train_costs, test_costs = [], []
    for i in range(epochs):
        cost = 0
        print("epoch:", i, "n_batches:", n_batches)
        for j, batch in enumerate(train_loader):
            net, stim = batch['net'].to(device), batch['stim'].to(device)
            # network activity goes in time-major: (T, N, ...)
            cost += self.train_step(net.transpose(0, 1), stim)
            del net, stim, batch  # release GPU memory promptly
            if j % print_every == 0:
                # costs and accuracies for test set
                test_cost = 0
                for t, testB in enumerate(test_loader, 1):
                    net = testB['net'].to(device)
                    stim = testB['stim'].to(device)
                    testB_cost = self.get_cost(net.transpose(0, 1), stim)
                    del net, stim, testB
                    test_cost += testB_cost
                # FIX: enumerate starts at 1, so t already equals the
                # number of validation batches; the old `/= t + 1`
                # under-reported the mean validation cost.
                test_cost /= t
                print("cost: %f" % (test_cost))
        # for plotting
        train_costs.append(cost / n_batches)
        test_costs.append(test_cost)
    # plot cost and accuracy progression
    fig, axes = plt.subplots(1)
    axes.plot(train_costs, label='training')
    axes.plot(test_costs, label='validation')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Cost')
    plt.legend()
    plt.show()
class RetinaDecoder(nn.Module):
    """Encoder/decoder network reconstructing a visual stimulus from
    retinal network activity.

    The encoder (optional pre-pool -> grouped temporal convs -> spatial
    convs -> optional convolutional RNN stack -> temporal convs)
    compresses the recording; the decoder (causal / standard transpose
    convs or spatial convs) upsamples back to stimulus space. Every stage
    is configured by a list of parameter dicts, so any stage can be
    skipped by passing an empty list.

    NOTE(review): this file holds several RetinaDecoder revisions; the
    last class definition in the file is the one that stays bound.
    """

    def __init__(self, pre_pool, grp_tempo_params, conv_params,
                 crnn_cell_params, temp3d_stack_params, decode_params):
        """Store layer configuration, build the network, and move it to
        the best available device (kept on ``self.dv``).

        Args:
            pre_pool: pooling config dict applied before any processing
                (skipped unless it contains an 'op' key).
            grp_tempo_params: configs for grouped TemporalConv3dStack layers.
            conv_params: configs for spatial-only Conv3d layers.
            crnn_cell_params: configs for convolutional recurrent cells.
            temp3d_stack_params: configs for trailing TemporalConv3dStack
                layers.
            decode_params: configs for the upsampling decoder layers.
        """
        super(RetinaDecoder, self).__init__()
        # layer parameters
        self.pre_pool = pre_pool
        self.grp_tempo_params = grp_tempo_params
        self.conv_params = conv_params
        self.crnn_cell_params = crnn_cell_params
        self.temp3d_stack_params = temp3d_stack_params
        self.decode_params = decode_params
        # create model and send to correct device (GPU if available)
        self.build()
        self.dv = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.dv)

    def build(self):
        """Assemble ``self.encoder_net`` and ``self.decoder_net`` as
        Sequential modules from the stored parameter dict lists."""
        # # # # # # # # # # ENCODER NETWORK # # # # # # # # # #
        encoder_mods = []
        # pooling operation before any processing
        if 'op' in self.pre_pool:  # skip by leaving param dict empty
            encoder_mods.append(make_pool3d_layer(self.pre_pool))
        # Grouped Temporal CNN, operating on each cluster channel separately
        for p in self.grp_tempo_params:
            encoder_mods.append(
                TemporalConv3dStack(
                    p['in'], p['out'], p.get('kernel', (2, 1, 1)),
                    p.get('space_dilation', 1), p.get('groups', 1),
                    p.get('dropout', 0), p.get('activation', nn.ReLU)))
            if 'pool' in p:
                encoder_mods.append(make_pool3d_layer(p['pool']))
        # Spatial Only (non-causal) convolutional layers
        for p in self.conv_params:
            d, h, w = p.get('kernel', (1, 3, 3))
            pad = (d // 2, h // 2, w // 2)  # 'same'-style padding
            encoder_mods.append(
                nn.Conv3d(p['in'], p['out'], (d, h, w), p.get('stride', 1),
                          pad, p.get('dilation', 1), p.get('groups', 1),
                          p.get('bias', True)))
            encoder_mods.append(nn.BatchNorm3d(p['out']))
            encoder_mods.append(p.get('activation', nn.ReLU)())
            if 'pool' in p:
                encoder_mods.append(make_pool3d_layer(p['pool']))
        # Stack of Convolutional Recurrent Network(s)
        if len(self.crnn_cell_params) > 0:
            # swap time from depth dimension to first dimension for CRNN(s)
            # (N, C, T, H, W) -> (T, N, C, H, W)
            encoder_mods.append(Permuter((2, 0, 1, 3, 4)))
        for p in self.crnn_cell_params:
            # recurrent convolutional cells (GRU or LSTM)
            encoder_mods.append(
                p.get('crnn_cell', crnns.ConvGRUCell_wnorm)(
                    p['dims'], p['in_kernel'], p['out_kernel'],
                    p['in'], p['out'], p.get('learn_initial', False),
                    p.get('return_hidden', False)))
            if 'post_activation' in p:
                encoder_mods.append(p['post_activation']())
        if len(self.crnn_cell_params) > 0:
            # swap time back to depth dimension following CRNN(s)
            # (T, N, C, H, W) -> (N, C, T, H, W)
            encoder_mods.append(Permuter((1, 2, 0, 3, 4)))
        # Temporal CNN
        for p in self.temp3d_stack_params:
            encoder_mods.append(
                TemporalConv3dStack(
                    p['in'], p['out'], p.get('kernel', (2, 3, 3)),
                    p.get('space_dilation', 1), p.get('groups', 1),
                    p.get('dropout', 0), p.get('activation', nn.ReLU)))
        # package encoding layers as a Sequential network
        self.encoder_net = nn.Sequential(*encoder_mods)

        # # # # # # # # # # DECODER NETWORK # # # # # # # # # #
        decoder_mods = []
        # Transpose Convolutional layers (upsampling)
        for p in self.decode_params:
            # unpack kernel dimensions for padding computation
            # (stride/dilation previously unpacked here were never used)
            d, h, w = p.get('kernel', (1, 3, 3))
            if p.get('type', 'causal') == 'causal':
                # causal transpose
                decoder_mods.append(
                    CausalTranspose3d(p['in'], p['out'], p['kernel'],
                                      p['stride'], p.get('groups', 1),
                                      p.get('bias', True),
                                      p.get('dilations', (1, 1, 1))))
            elif p['type'] == 'trans':
                # non-causal transpose
                pad = (d // 2, h // 2, w // 2)
                decoder_mods.append(
                    nn.ConvTranspose3d(p['in'], p['out'], p['kernel'],
                                       p['stride'], pad, pad,
                                       p.get('groups', 1),
                                       p.get('bias', True),
                                       p.get('dilations', (1, 1, 1))))
            else:
                # plain convolution (spatial only) -> p['type'] == 'conv'
                pad = (d // 2, h // 2, w // 2)
                decoder_mods.append(
                    nn.Conv3d(p['in'], p['out'], (d, h, w),
                              p.get('stride', 1), pad, p.get('dilation', 1),
                              p.get('groups', 1), p.get('bias', True)))
            decoder_mods.append(nn.BatchNorm3d(p['out']))
            decoder_mods.append(p.get('activation', nn.Tanh)())
        # package decoding layers as a Sequential network
        self.decoder_net = nn.Sequential(*decoder_mods)

    def forward(self, X):
        """Encode then decode a (N, C, T, H, W) batch; returns the same
        layout with the decoder's output channels."""
        X = self.encoder_net(X)
        X = self.decoder_net(X)
        return X

    def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
            loss_alpha=10, loss_decay=1, print_every=0, peons=2):
        """Train on ``train_set`` with periodic validation on
        ``test_set``, then plot the cost curves.

        Args:
            train_set: Dataset yielding dicts with 'net'/'stim' tensors.
            test_set: validation Dataset.
            lr: Adam learning rate.
            epochs: passes over the training data.
            batch_sz: DataLoader batch size.
            loss_alpha: DecoderLoss sparsity weight (alpha=0 ~ MSE).
            loss_decay: per-epoch decay of the sparsity penalty.
            print_every: validate every N batches (<1 -> once per epoch).
            peons: DataLoader worker process count.
        """
        train_loader = DataLoader(train_set, batch_size=batch_sz,
                                  shuffle=True, num_workers=peons)
        test_loader = DataLoader(test_set, batch_size=batch_sz,
                                 num_workers=peons)
        N = train_set.__len__()  # number of samples
        # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
        self.loss = DecoderLoss(alpha=loss_alpha,
                                decay=loss_decay).to(self.dv)
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)
        n_batches = np.ceil(N / batch_sz).astype('int')
        print_every = n_batches if print_every < 1 else print_every
        # hoisted: validation batch count is loop-invariant
        n_test_batches = np.ceil(test_set.__len__() / batch_sz).astype('int')
        train_prog = None
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                net = batch['net'].to(self.dv)
                stim = batch['stim'].to(self.dv)
                cost += self.train_step(net, stim)
                del net, stim, batch  # release GPU memory promptly
                if train_prog is not None:
                    train_prog.step()
                if j % print_every == 0:
                    test_prog = ProgressBar(n_test_batches,
                                            size=n_test_batches,
                                            label='validating: ')
                    # costs and accuracies for test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        net = testB['net'].to(self.dv)
                        stim = testB['stim'].to(self.dv)
                        testB_cost = self.get_cost(net, stim)
                        del net, stim, testB
                        test_cost += testB_cost
                        test_prog.step()
                    # FIX: enumerate starts at 1, so t already equals the
                    # number of validation batches; the old `/= t + 1`
                    # under-reported the mean validation cost.
                    test_cost /= t
                    print("validation cost: %f" % (test_cost))
                    train_prog = ProgressBar(
                        print_every,
                        size=test_set.__len__() * 2 // batch_sz,
                        label='training: ')
                    if j == 0:  # hack, skipped batch
                        train_prog.step()
            # Decay DecoderLoss sparsity penalty
            self.loss.decay()
            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)
        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step; returns the scalar batch cost."""
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient
        # Forward
        decoded = self.forward(inputs)  # (N, C, T, H, W)
        output = self.loss.forward(
            # swap time to second dimension -> (N, T, C, H, W)
            decoded.transpose(1, 2), targets)
        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters
        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the loss on a batch without updating parameters."""
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)  # (N, C, T, H, W)
            output = self.loss.forward(
                # swap time to second dimension -> (N, T, C, H, W)
                decoded.transpose(1, 2), targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings: network activity, prediction,
        and true stimulus as synced scrollable videos, one sample at a
        time until the user quits."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                net = sample['net'].to(self.dv)
                decoded = self.forward(net)
                del net
            # Reduce out batch and channel dims, then put time last
            # (N, C, T, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=0)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)
            # synced scrollable videos of cell activity, decoding, stimulus
            fig, ax = plt.subplots(1, 3, figsize=(17, 6))
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1,
                                      vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()
            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break

    def save_decodings(self, sample_set, name=None):
        """Decode every sample in ``sample_set`` and save each result as a
        .npy file under ``<root_dir>/outputs/<name>/<net_folder>/``.

        Args:
            sample_set: Dataset with ``root_dir`` and ``rec_frame``
                attributes describing the recordings.
            name: output folder name; prompts interactively if None or if
                the folder already exists.
        """
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, num_workers=2)
        # make a parent output folder for this dataset if it doesn't exist
        outfold = os.path.join(sample_set.root_dir, 'outputs')
        if not os.path.isdir(outfold):
            os.mkdir(outfold)
        # prompt for name of and create this particular runs output folder
        while True:
            nametag = input("Decoding set name: ") if name is None else name
            name = None  # if parameter name fails, get input next loop
            basefold = os.path.join(outfold, nametag)
            if not os.path.isdir(basefold):
                os.mkdir(basefold)
                break
            else:
                print('Folder exists, provide another name...')
        # generate decoding of every sample in given dataset
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                net = sample['net'].to(self.dv)
                decoded = self.forward(net)
                del sample, net
            # Reduce out batch and channel dims
            # (N, C, T, H, W) -> (T, H, W)  [comment fixed: forward is
            # batch-first in this revision]
            decoded = decoded.squeeze().cpu().numpy()
            # save into subfolder corresponding to originating network
            decofold = os.path.join(
                basefold,
                sample_set.rec_frame.iloc[i, 0],  # net folder name
            )
            if not os.path.isdir(decofold):
                os.mkdir(decofold)
            # .npy format
            np.save(
                # file name corresponding to stimulus
                os.path.join(decofold, sample_set.rec_frame.iloc[i, 1]),
                decoded)
def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
        loss_alpha=10, loss_decay=1, print_every=0, peons=2):
    """Train the decoder on ``train_set`` with periodic validation, then
    plot the cost curves.

    NOTE(review): module-level duplicate of RetinaDecoder.fit above;
    expects ``self`` to provide ``dv``, ``parameters``, ``train_step``,
    and ``get_cost``.

    Args:
        train_set: Dataset yielding dicts with 'net'/'stim' tensors.
        test_set: validation Dataset.
        lr: Adam learning rate.
        epochs: passes over the training data.
        batch_sz: DataLoader batch size.
        loss_alpha: DecoderLoss sparsity weight (alpha=0 ~ MSE).
        loss_decay: per-epoch decay of the sparsity penalty.
        print_every: validate every N batches (<1 -> once per epoch).
        peons: DataLoader worker process count.
    """
    train_loader = DataLoader(train_set, batch_size=batch_sz, shuffle=True,
                              num_workers=peons)
    test_loader = DataLoader(test_set, batch_size=batch_sz,
                             num_workers=peons)
    N = train_set.__len__()  # number of samples
    # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
    self.loss = DecoderLoss(alpha=loss_alpha, decay=loss_decay).to(self.dv)
    self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)
    n_batches = np.ceil(N / batch_sz).astype('int')
    print_every = n_batches if print_every < 1 else print_every
    # hoisted: validation batch count is loop-invariant
    n_test_batches = np.ceil(test_set.__len__() / batch_sz).astype('int')
    train_prog = None
    train_costs, test_costs = [], []
    for i in range(epochs):
        cost = 0
        print("epoch:", i, "n_batches:", n_batches)
        for j, batch in enumerate(train_loader):
            net, stim = batch['net'].to(self.dv), batch['stim'].to(self.dv)
            cost += self.train_step(net, stim)
            del net, stim, batch  # release GPU memory promptly
            if train_prog is not None:
                train_prog.step()
            if j % print_every == 0:
                test_prog = ProgressBar(n_test_batches,
                                        size=n_test_batches,
                                        label='validating: ')
                # costs and accuracies for test set
                test_cost = 0
                for t, testB in enumerate(test_loader, 1):
                    net = testB['net'].to(self.dv)
                    stim = testB['stim'].to(self.dv)
                    testB_cost = self.get_cost(net, stim)
                    del net, stim, testB
                    test_cost += testB_cost
                    test_prog.step()
                # FIX: enumerate starts at 1, so t already equals the
                # number of validation batches; the old `/= t + 1`
                # under-reported the mean validation cost.
                test_cost /= t
                print("validation cost: %f" % (test_cost))
                train_prog = ProgressBar(
                    print_every,
                    size=test_set.__len__() * 2 // batch_sz,
                    label='training: ')
                if j == 0:  # hack, skipped batch
                    train_prog.step()
        # Decay DecoderLoss sparsity penalty
        self.loss.decay()
        # for plotting
        train_costs.append(cost / n_batches)
        test_costs.append(test_cost)
    # plot cost and accuracy progression
    fig, axes = plt.subplots(1)
    axes.plot(train_costs, label='training')
    axes.plot(test_costs, label='validation')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Cost')
    plt.legend()
    plt.show()
class RetinaDecoder(nn.Module):
    """Decode a stimulus from retinal activity with a stack of
    convolutional recurrent cells followed by a frame-wise 1x1 conv that
    reduces channels to a single decoded frame.

    NOTE(review): earliest RetinaDecoder revision in this file; shadowed
    by the later class definitions below.
    """

    def __init__(self, crnn_cell_params, crnn_cell=crnns.ConvGRUCell,
                 learn_initial=False):
        """Args:
            crnn_cell_params: one positional-argument sequence per
                recurrent cell; the last element of each is that cell's
                output channel count. Must be non-empty (see build()).
            crnn_cell: recurrent convolutional cell class (GRU or LSTM).
            learn_initial: whether cells learn their initial hidden state.
        """
        super(RetinaDecoder, self).__init__()
        self.crnn_cell_params = crnn_cell_params
        self.crnn_cell = crnn_cell
        self.learn_initial = learn_initial
        self.build()
        self.to(device)

    def build(self):
        """Create the recurrent stack and the channel-reduction layers."""
        self.crnn_stack = nn.ModuleList()
        for params in self.crnn_cell_params:
            # recurrent convolutional cells (GRU or LSTM)
            self.crnn_stack.append(
                self.crnn_cell(*params, learn_initial=self.learn_initial))
        # NOTE(review): `params` deliberately leaks from the loop above to
        # read the last cell's output channels — build() raises if
        # crnn_cell_params is empty.
        self.reduce_conv = nn.Conv2d(params[-1], 1, (1, 1))
        self.reduce_bnorm = nn.BatchNorm2d(1)

    def forward(self, X):
        """Map time-major activity (T, N, C, H, W) to a decoded stimulus
        (T, N, 1, H, W) squashed to [-1, 1] by tanh."""
        # stacked convolutional recurrent cells
        for cell in self.crnn_stack:
            X, _ = cell(X)
        # reduce channel dimensionality to 1, frame by frame.
        frames = [self.reduce_bnorm(self.reduce_conv(frame)) for frame in X]
        X = torch.stack(frames, dim=0)
        del frames
        X = torch.tanh(X)
        return X

    def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
            print_every=40):
        """Train on ``train_set`` with periodic validation on
        ``test_set``, then plot the cost curves.

        Args:
            train_set: Dataset yielding dicts with 'net'/'stim' tensors.
            test_set: validation Dataset.
            lr: Adam learning rate.
            epochs: passes over the training data.
            batch_sz: DataLoader batch size.
            print_every: validate every N training batches.
        """
        train_loader = DataLoader(train_set, batch_size=batch_sz,
                                  shuffle=True, num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz,
                                 num_workers=2)
        N = train_set.__len__()  # number of samples
        self.loss = DecoderLoss(alpha=10).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        # FIX: ceil instead of floor so a final partial batch is counted
        # and n_batches cannot be zero when N < batch_sz.
        n_batches = -(-N // batch_sz)
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                # network activity goes in time-major: (T, N, ...)
                cost += self.train_step(
                    batch['net'].transpose(0, 1).to(device),
                    batch['stim'].to(device))
                # try sending batch to GPU, then passing (then delete)
                del batch  # test whether useful for clearing off GPU
                if j % print_every == 0:
                    # costs and accuracies for test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        testB_cost = self.get_cost(
                            testB['net'].transpose(0, 1).to(device),
                            testB['stim'].to(device))
                        test_cost += testB_cost
                    # FIX: enumerate starts at 1, so t already equals the
                    # number of validation batches; the old `/= t + 1`
                    # under-reported the mean validation cost.
                    test_cost /= t
                    del testB
                    print("cost: %f" % (test_cost))
            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)
        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step; returns the scalar batch cost."""
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient
        # Forward
        decoded = self.forward(inputs)
        output = self.loss.forward(
            # swap batch to first dimension
            decoded.transpose(0, 1), targets)
        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters
        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the loss on a batch without updating parameters."""
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)
            output = self.loss.forward(
                # swap batch to first dimension
                decoded.transpose(0, 1), targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings: network activity, prediction,
        and true stimulus as synced scrollable videos, one sample at a
        time until the user quits."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                decoded = self.forward(sample['net'].to(device))
            # Reduce out batch and channel dims, then put time last
            # (T, N, C, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=1)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)
            # synced scrollable videos of cell activity, decoding, stimulus
            fig, ax = plt.subplots(1, 3)
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1,
                                      vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()
            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break
class RetinaDecoder(nn.Module):
    """Decode a stimulus from retinal activity: grouped convs -> spatial
    convs -> convolutional RNN stack -> transpose convs (upsampling) ->
    spatial cleanup convs.

    NOTE(review): intermediate RetinaDecoder revision; shadowed by the
    later class definition below. Layer configs here are positional-arg
    sequences, not dicts.
    """

    def __init__(self, grp_conv_params, conv_params, crnn_cell_params,
                 trans_params, post_conv_params,
                 crnn_cell=crnns.ConvGRUCell, learn_initial=False):
        """Args:
            grp_conv_params: [in, out, (kernel), (stride), (dilation),
                groups] per grouped conv layer.
            conv_params: [in, out, (kernel), (stride)] per spatial conv.
            crnn_cell_params: [(dims), (in_kernel), (out_kernel), in_C,
                out_C] per recurrent cell.
            trans_params: positional args per ConvTranspose3d layer.
            post_conv_params: [in, out, (kernel), (stride)] per
                post-upsampling conv.
            crnn_cell: recurrent convolutional cell class (GRU or LSTM).
            learn_initial: whether cells learn their initial hidden state.
        """
        super(RetinaDecoder, self).__init__()
        # layer parameters
        self.grp_conv_params = grp_conv_params
        self.conv_params = conv_params
        self.crnn_cell_params = crnn_cell_params
        self.trans_params = trans_params
        self.post_conv_params = post_conv_params
        # ConvRNN settings
        self.crnn_cell = crnn_cell
        self.learn_initial = learn_initial
        # create model and send to GPU
        self.build()
        self.to(device)

    def build(self):
        """Create the layer ModuleLists from the stored parameter lists."""
        # grouped convolutions
        self.grp_conv_layers = nn.ModuleList()
        self.grp_conv_bnorms = nn.ModuleList()
        # convolutions
        self.conv_layers = nn.ModuleList()
        self.conv_bnorms = nn.ModuleList()
        # recurrent convolutions
        self.crnn_stack = nn.ModuleList()
        # transpose convolutions
        self.trans_layers = nn.ModuleList()
        self.trans_bnorms = nn.ModuleList()
        # post-upsampling convolutions
        self.post_conv_layers = nn.ModuleList()
        self.post_conv_bnorms = nn.ModuleList()
        for params in self.grp_conv_params:
            # params: [in, out, (kernel), (stride), (dilation), groups]
            # 'same'-style padding for a dilated kernel
            pad = ((params[2][0] * params[4][0] - 1) // 2,
                   (params[2][1] * params[4][1] - 1) // 2,
                   (params[2][2] * params[4][2] - 1) // 2)
            self.grp_conv_layers.append(
                nn.Conv3d(*params[:4], pad, *params[4:]))
            self.grp_conv_bnorms.append(nn.BatchNorm3d(params[1]))
        for params in self.conv_params:
            # params: [in, out, (kernel), (stride)]
            pad = (params[2][0] // 2, params[2][1] // 2, params[2][2] // 2)
            self.conv_layers.append(nn.Conv3d(*params, pad))
            self.conv_bnorms.append(nn.BatchNorm3d(params[1]))
        for params in self.crnn_cell_params:
            # params: [(dims), (in_kernel), (out_kernel), in_C, out_C]
            # recurrent convolutional cells (GRU or LSTM)
            self.crnn_stack.append(
                self.crnn_cell(*params, learn_initial=self.learn_initial))
        for params in self.trans_params:
            pad = (params[2][0] // 2, params[2][1] // 2, params[2][2] // 2)
            self.trans_layers.append(
                nn.ConvTranspose3d(*params, padding=pad,
                                   output_padding=pad))
            self.trans_bnorms.append(nn.BatchNorm3d(params[1]))
        for params in self.post_conv_params:
            # params: [in, out, (kernel), (stride)]
            pad = (params[2][0] // 2, params[2][1] // 2, params[2][2] // 2)
            self.post_conv_layers.append(nn.Conv3d(*params, pad))
            self.post_conv_bnorms.append(nn.BatchNorm3d(params[1]))

    def forward(self, X):
        """Decode time-major activity (T, N, C, H, W) to a time-major
        stimulus prediction."""
        # time to 'depth' dimension
        X = X.permute(1, 2, 0, 3, 4)  # to (N, C, T, H, W)
        # reduce spatial dimensionality (collate somatic information)
        X = F.avg_pool3d(X, (1, 2, 2))
        # grouped (cluster siloed) convolutions
        for conv, bnorm in zip(self.grp_conv_layers, self.grp_conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.conv_layers, self.conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        X = F.avg_pool3d(X, (1, 2, 2))
        # return to time dimension first for operations over time
        X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, W)
        # stacked convolutional recurrent cells
        for cell in self.crnn_stack:
            X, _ = cell(X)
        # expand back out in space and reduce channels
        X = X.permute(1, 2, 0, 3, 4)  # time to 'depth' dimension
        for trans, bnorm in zip(self.trans_layers, self.trans_bnorms):
            X = torch.tanh(bnorm(trans(X)))
        # clean up with more spatial convs (try interleaving with trans next)
        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.post_conv_layers, self.post_conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, W)
        return X

    def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
            print_every=40):
        """Train on ``train_set`` with periodic validation on
        ``test_set``, then plot the cost curves.

        Args:
            train_set: Dataset yielding dicts with 'net'/'stim' tensors.
            test_set: validation Dataset.
            lr: Adam learning rate.
            epochs: passes over the training data.
            batch_sz: DataLoader batch size.
            print_every: validate every N training batches.
        """
        train_loader = DataLoader(train_set, batch_size=batch_sz,
                                  shuffle=True, num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz,
                                 num_workers=2)
        N = train_set.__len__()  # number of samples
        # DecoderLoss equivalent to MSE when alpha=0
        self.loss = DecoderLoss(alpha=10).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        # FIX: ceil instead of floor so a final partial batch is counted
        # and n_batches cannot be zero when N < batch_sz.
        n_batches = -(-N // batch_sz)
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                cost += self.train_step(
                    batch['net'].transpose(0, 1).to(device),
                    batch['stim'].to(device))
                # try sending batch to GPU, then passing (then delete)
                del batch  # test whether useful for clearing off GPU
                if j % print_every == 0:
                    # costs and accuracies for test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        testB_cost = self.get_cost(
                            testB['net'].transpose(0, 1).to(device),
                            testB['stim'].to(device))
                        test_cost += testB_cost
                    # FIX: enumerate starts at 1, so t already equals the
                    # number of validation batches; the old `/= t + 1`
                    # under-reported the mean validation cost.
                    test_cost /= t
                    del testB
                    print("cost: %f" % (test_cost))
            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)
        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step; returns the scalar batch cost."""
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient
        # Forward
        decoded = self.forward(inputs)
        output = self.loss.forward(
            # swap batch to first dimension
            decoded.transpose(0, 1), targets)
        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters
        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the loss on a batch without updating parameters."""
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)
            output = self.loss.forward(
                # swap batch to first dimension
                decoded.transpose(0, 1), targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings: network activity, prediction,
        and true stimulus as synced scrollable videos, one sample at a
        time until the user quits."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                decoded = self.forward(sample['net'].to(device))
            # Reduce out batch and channel dims, then put time last
            # (T, N, C, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=1)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)
            # synced scrollable videos of cell activity, decoding, stimulus
            fig, ax = plt.subplots(1, 3, figsize=(17, 6))
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1,
                                      vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()
            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break
class RetinaDecoder(nn.Module):
    """Decode a stimulus from retinal activity: grouped temporal convs ->
    spatial convs -> optional convolutional RNN stack -> temporal convs ->
    causal transpose convs (upsampling) -> spatial cleanup convs.

    NOTE(review): last RetinaDecoder revision before the dict-configured
    encoder/decoder version earlier in this file; this definition is the
    one left bound at module level.
    """

    def __init__(self, grp_tempo_params, conv_params, crnn_cell_params,
                 temp3d_stack_params, trans_params, post_conv_params):
        """Store layer configuration (lists of parameter dicts; any stage
        may be skipped with an empty list), build, and move to `device`.

        Args:
            grp_tempo_params: configs for grouped TemporalConv3dStack layers.
            conv_params: configs for spatial-only Conv3d layers.
            crnn_cell_params: configs for convolutional recurrent cells.
            temp3d_stack_params: configs for trailing TemporalConv3dStack
                layers.
            trans_params: configs for CausalTranspose3d upsampling layers.
            post_conv_params: configs for post-upsampling Conv3d layers.
        """
        super(RetinaDecoder, self).__init__()
        # layer parameters
        self.grp_tempo_params = grp_tempo_params
        self.conv_params = conv_params
        self.crnn_cell_params = crnn_cell_params
        self.temp3d_stack_params = temp3d_stack_params
        self.trans_params = trans_params
        self.post_conv_params = post_conv_params
        # create model and send to GPU
        self.build()
        self.to(device)

    def build(self):
        """Create the layer ModuleLists from the stored parameter dicts."""
        # grouped convolutions
        self.grp_tempo_layers = nn.ModuleList()
        # convolutions
        self.conv_layers = nn.ModuleList()
        self.conv_bnorms = nn.ModuleList()
        # recurrent convolutions
        self.crnn_stack = nn.ModuleList()
        # 3d temporal convolutions
        self.tempo3d_layers = nn.ModuleList()
        # transpose convolutions
        self.trans_layers = nn.ModuleList()
        self.trans_bnorms = nn.ModuleList()
        # post-upsampling convolutions
        self.post_conv_layers = nn.ModuleList()
        self.post_conv_bnorms = nn.ModuleList()
        for p in self.grp_tempo_params:
            self.grp_tempo_layers.append(
                TemporalConv3dStack(
                    p['in'], p['out'], p.get('kernel', (2, 1, 1)),
                    p.get('space_dilation', 1), p.get('groups', 1),
                    p.get('dropout', 0), p.get('activation', nn.ReLU)))
        for p in self.conv_params:
            d, h, w = p.get('kernel', (1, 3, 3))
            pad = (d // 2, h // 2, w // 2)  # 'same'-style padding
            self.conv_layers.append(
                nn.Conv3d(p['in'], p['out'], (d, h, w), p.get('stride', 1),
                          pad, p.get('dilation', 1), p.get('groups', 1),
                          p.get('bias', True)))
            self.conv_bnorms.append(nn.BatchNorm3d(p['out']))
        for p in self.crnn_cell_params:
            # recurrent convolutional cells (GRU or LSTM)
            self.crnn_stack.append(
                p.get('crnn_cell', crnns.ConvGRUCell_wnorm)(
                    p['dims'], p['in_kernel'], p['out_kernel'],
                    p['in'], p['out'], p.get('learn_initial', False)))
        for p in self.temp3d_stack_params:
            self.tempo3d_layers.append(
                TemporalConv3dStack(
                    p['in'], p['out'], p.get('kernel', (2, 3, 3)),
                    p.get('space_dilation', 1), p.get('groups', 1),
                    p.get('dropout', 0), p.get('activation', nn.ReLU)))
        for p in self.trans_params:
            self.trans_layers.append(
                CausalTranspose3d(p['in'], p['out'], p['kernel'],
                                  p['stride'], p.get('groups', 1),
                                  p.get('bias', True),
                                  p.get('dilations', (1, 1, 1))))
            self.trans_bnorms.append(nn.BatchNorm3d(p['out']))
        for p in self.post_conv_params:
            d, h, w = p.get('kernel', (1, 3, 3))
            pad = (d // 2, h // 2, w // 2)
            self.post_conv_layers.append(
                nn.Conv3d(p['in'], p['out'], (d, h, w), p.get('stride', 1),
                          pad, p.get('dilation', 1), p.get('groups', 1),
                          p.get('bias', True)))
            self.post_conv_bnorms.append(nn.BatchNorm3d(p['out']))

    def forward(self, X):
        """Decode time-major activity (T, N, C, H, W) to a time-major
        stimulus prediction."""
        # time to 'depth' dimension
        X = X.permute(1, 2, 0, 3, 4)  # to (N, C, T, H, W)
        # reduce spatial dimensionality (collate somatic information)
        X = F.avg_pool3d(X, (1, 2, 2))
        # grouped (cluster siloed) temporal convolutions
        for tempo_conv in self.grp_tempo_layers:
            X = tempo_conv(X)
        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.conv_layers, self.conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        X = F.avg_pool3d(X, (2, 2, 2))  # testing! (try max if using ReLU
        # at the start)
        # X = F.max_pool3d(X, (1, 2, 2))
        if len(self.crnn_stack) > 0:
            # return to time dimension first for operations over time
            X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, W)
            # stacked convolutional recurrent cells
            for cell in self.crnn_stack:
                X = cell(X)
            X = F.relu(X)  # test
            # expand back out in space and reduce channels
            X = X.permute(1, 2, 0, 3, 4)  # time to 'depth' dimension
        for tempo_conv in self.tempo3d_layers:
            X = tempo_conv(X)
        for trans, bnorm in zip(self.trans_layers, self.trans_bnorms):
            X = torch.tanh(bnorm(trans(X)))
        # clean up with more spatial convs (try interleaving with trans next)
        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.post_conv_layers, self.post_conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, W)
        return X

    def fit(self, train_set, test_set, lr=1e-4, epochs=10, batch_sz=1,
            loss_alpha=10, print_every=40):
        """Train on ``train_set`` with periodic validation on
        ``test_set``, then plot the cost curves.

        Args:
            train_set: Dataset yielding dicts with 'net'/'stim' tensors.
            test_set: validation Dataset.
            lr: Adam learning rate.
            epochs: passes over the training data.
            batch_sz: DataLoader batch size.
            loss_alpha: DecoderLoss sparsity weight (alpha=0 ~ MSE).
            print_every: validate every N training batches.
        """
        train_loader = DataLoader(train_set, batch_size=batch_sz,
                                  shuffle=True, num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz,
                                 num_workers=2)
        N = train_set.__len__()  # number of samples
        # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
        self.loss = DecoderLoss(alpha=loss_alpha).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)
        # FIX: ceil instead of floor so a final partial batch is counted
        # and n_batches cannot be zero when N < batch_sz.
        n_batches = -(-N // batch_sz)
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                net, stim = batch['net'].to(device), batch['stim'].to(device)
                cost += self.train_step(net.transpose(0, 1), stim)
                del net, stim, batch  # release GPU memory promptly
                if j % print_every == 0:
                    # costs and accuracies for test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        net = testB['net'].to(device)
                        stim = testB['stim'].to(device)
                        testB_cost = self.get_cost(net.transpose(0, 1), stim)
                        del net, stim, testB
                        test_cost += testB_cost
                    # FIX: enumerate starts at 1, so t already equals the
                    # number of validation batches; the old `/= t + 1`
                    # under-reported the mean validation cost.
                    test_cost /= t
                    print("cost: %f" % (test_cost))
            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)
        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step; returns the scalar batch cost."""
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient
        # Forward
        decoded = self.forward(inputs)
        output = self.loss.forward(
            # swap batch to first dimension
            decoded.transpose(0, 1), targets)
        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters
        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the loss on a batch without updating parameters."""
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)
            output = self.loss.forward(
                # swap batch to first dimension
                decoded.transpose(0, 1), targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings: network activity, prediction,
        and true stimulus as synced scrollable videos, one sample at a
        time until the user quits."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                net = sample['net'].to(device).transpose(0, 1)
                decoded = self.forward(net)
                del net
            # Reduce out batch and channel dims, then put time last
            # (T, N, C, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=1)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)
            # synced scrollable videos of cell activity, decoding, stimulus
            fig, ax = plt.subplots(1, 3, figsize=(17, 6))
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1,
                                      vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()
            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break

    def save_decodings(self, sample_set):
        """Decode every sample in ``sample_set`` and save each result as a
        .npy file under ``<root_dir>/<name>/<net_folder>/``, prompting for
        an unused output folder name."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, num_workers=2)
        while True:
            nametag = input("Decoding set name: ")
            basefold = os.path.join(sample_set.root_dir, nametag)
            if not os.path.isdir(basefold):
                os.mkdir(basefold)
                break
            else:
                print('Folder exists, provide another name...')
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                net = sample['net'].to(device).transpose(0, 1)
                decoded = self.forward(net)
                del sample, net
            # Reduce out batch and channel dims
            # (T, N, C, H, W) -> (T, H, W)
            decoded = decoded.squeeze().cpu().numpy()
            # save into subfolder corresponding to originating network
            decofold = os.path.join(
                basefold,
                sample_set.rec_frame.iloc[i, 0],  # net folder name
            )
            if not os.path.isdir(decofold):
                os.mkdir(decofold)
            np.save(
                # file name corresponding to stimulus
                os.path.join(decofold, sample_set.rec_frame.iloc[i, 1]),
                decoded)