コード例 #1
0
ファイル: decoder_v1.py プロジェクト: geoffder/retina-decoder
    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            print_every=40):
        """Train the decoder on train_set, periodically scoring test_set.

        Args:
            train_set: Dataset yielding dicts with 'net' (activity) and
                'stim' (target stimulus) tensors.
            test_set: Validation dataset with the same sample format.
            lr: Adam learning rate.
            epochs: Number of full passes over train_set.
            batch_sz: Mini-batch size for both loaders.
            print_every: Run a validation pass every `print_every` batches.

        Plots training/validation cost curves when finished.
        """
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
        N = len(train_set)  # number of samples

        # DecoderLoss with alpha=0 is equivalent to plain MSE
        self.loss = DecoderLoss(alpha=10).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

        n_batches = N // batch_sz
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                # time to first dimension for the recurrent net:
                # (N, T, ...) -> (T, N, ...)
                cost += self.train_step(
                    batch['net'].transpose(0, 1).to(device),
                    batch['stim'].to(device))
                del batch  # release GPU memory held by the batch promptly

                if j % print_every == 0:
                    # average cost over the full test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        testB_cost = self.get_cost(
                            testB['net'].transpose(0, 1).to(device),
                            testB['stim'].to(device))
                        test_cost += testB_cost
                        del testB  # free each batch, not just the last one
                    # enumerate started at 1, so t already equals the batch
                    # count (dividing by t + 1 undercounted the average)
                    test_cost /= t

                    print("cost: %f" % (test_cost))

            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()
コード例 #2
0
    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            loss_alpha=10,
            print_every=40):
        """Train the decoder on train_set, periodically scoring test_set.

        Args:
            train_set: Dataset yielding dicts with 'net' (activity) and
                'stim' (target stimulus) tensors.
            test_set: Validation dataset with the same sample format.
            lr: Adam learning rate.
            epochs: Number of full passes over train_set.
            batch_sz: Mini-batch size for both loaders.
            loss_alpha: Sparsity penalty weight for DecoderLoss
                (0 -> equivalent to plain MSE).
            print_every: Run a validation pass every `print_every` batches.

        Plots training/validation cost curves when finished.
        """
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
        N = len(train_set)  # number of samples

        # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
        self.loss = DecoderLoss(alpha=loss_alpha).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)

        n_batches = N // batch_sz
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                net, stim = batch['net'].to(device), batch['stim'].to(device)
                # time to first dimension: (N, T, ...) -> (T, N, ...)
                cost += self.train_step(net.transpose(0, 1), stim)
                del net, stim, batch  # release GPU memory promptly

                if j % print_every == 0:
                    # average cost over the full test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        net = testB['net'].to(device)
                        stim = testB['stim'].to(device)
                        testB_cost = self.get_cost(net.transpose(0, 1), stim)
                        del net, stim, testB
                        test_cost += testB_cost
                    # enumerate started at 1, so t already equals the batch
                    # count (dividing by t + 1 undercounted the average)
                    test_cost /= t

                    print("cost: %f" % (test_cost))

            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()
コード例 #3
0
ファイル: decoder_v5.py プロジェクト: geoffder/retina-decoder
class RetinaDecoder(nn.Module):
    """Encoder-decoder network reconstructing a stimulus movie from
    retinal network activity.

    The encoder stacks (optional) pooling, grouped temporal convolutions,
    spatial convolutions, convolutional RNN cells, and temporal
    convolutions; the decoder upsamples with (causal) transpose
    convolutions. Data flows as 5D video tensors (N, C, T, H, W), with a
    temporary permute to (T, N, C, H, W) around the CRNN stack.
    """

    def __init__(self, pre_pool, grp_tempo_params, conv_params,
                 crnn_cell_params, temp3d_stack_params, decode_params):
        # Each *_params argument is a param dict (pre_pool) or a list of
        # per-layer param dicts consumed by build() below.
        super(RetinaDecoder, self).__init__()
        # layer parameters
        self.pre_pool = pre_pool
        self.grp_tempo_params = grp_tempo_params
        self.conv_params = conv_params
        self.crnn_cell_params = crnn_cell_params
        self.temp3d_stack_params = temp3d_stack_params
        self.decode_params = decode_params
        # create model and send to correct device (GPU if available)
        self.build()
        self.dv = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.dv)

    def build(self):
        """Assemble encoder_net and decoder_net Sequential modules from the
        stored layer-parameter dicts."""
        # # # # # # # # # # ENCODER NETWORK # # # # # # # # # #
        encoder_mods = []

        # pooling operation before any processing
        if 'op' in self.pre_pool:  # skip by leaving param dict empty
            encoder_mods.append(make_pool3d_layer(self.pre_pool))

        # Grouped Temporal CNN, operating on each cluster channel separately
        for p in self.grp_tempo_params:
            encoder_mods.append(
                TemporalConv3dStack(p['in'], p['out'],
                                    p.get('kernel', (2, 1, 1)),
                                    p.get('space_dilation', 1),
                                    p.get('groups', 1), p.get('dropout', 0),
                                    p.get('activation', nn.ReLU)))
            if 'pool' in p:
                encoder_mods.append(make_pool3d_layer(p['pool']))

        # Spatial Only (non-causal) convolutional layers
        for p in self.conv_params:
            d, h, w = p.get('kernel', (1, 3, 3))
            # "same"-style padding for odd kernel sizes
            pad = (d // 2, h // 2, w // 2)
            encoder_mods.append(
                nn.Conv3d(p['in'], p['out'], (d, h, w), p.get('stride', 1),
                          pad, p.get('dilation', 1), p.get('groups', 1),
                          p.get('bias', True)))
            encoder_mods.append(nn.BatchNorm3d(p['out']))
            encoder_mods.append(p.get('activation', nn.ReLU)())
            if 'pool' in p:
                encoder_mods.append(make_pool3d_layer(p['pool']))

        # Stack of Convolutional Recurrent Network(s)
        if len(self.crnn_cell_params) > 0:
            # swap time from depth dimension to first dimension for CRNN(s)
            # (N, C, T, H, W) -> (T, N, C, H, W)
            encoder_mods.append(Permuter((2, 0, 1, 3, 4)))
        for p in self.crnn_cell_params:
            # recurrenct convolutional cells (GRU or LSTM)
            encoder_mods.append(
                p.get('crnn_cell',
                      crnns.ConvGRUCell_wnorm)(p['dims'], p['in_kernel'],
                                               p['out_kernel'], p['in'],
                                               p['out'],
                                               p.get('learn_initial', False),
                                               p.get('return_hidden', False)))
            if 'post_activation' in p:
                encoder_mods.append(p['post_activation']())
        if len(self.crnn_cell_params) > 0:
            # swap time back to depth dimension following CRNN(s)
            # (T, N, C, H, W) -> (N, C, T, H, W)
            encoder_mods.append(Permuter((1, 2, 0, 3, 4)))

        # Temporal CNN
        for p in self.temp3d_stack_params:
            encoder_mods.append(
                TemporalConv3dStack(p['in'], p['out'],
                                    p.get('kernel', (2, 3, 3)),
                                    p.get('space_dilation', 1),
                                    p.get('groups', 1), p.get('dropout', 0),
                                    p.get('activation', nn.ReLU)))

        # package encoding layers as a Sequential network
        self.encoder_net = nn.Sequential(*encoder_mods)

        # # # # # # # # # # DECODER NETWORK # # # # # # # # # #
        decoder_mods = []

        # Transpose Convolutional layers (upsampling)
        for p in self.decode_params:
            # unpack kernel etc dimensions
            d, h, w = p.get('kernel', (1, 3, 3))
            st_d, st_h, st_w = p.get('stride', (1, 1, 1))
            dil_d, dil_h, dil_w = p.get('dilation', (1, 1, 1))

            # causal transpose
            if p.get('type', 'causal') == 'causal':
                # NOTE(review): key here is 'dilations' (plural) while the
                # unpack above reads 'dilation' — confirm which key callers
                # actually supply.
                decoder_mods.append(
                    CausalTranspose3d(p['in'], p['out'], p['kernel'],
                                      p['stride'], p.get('groups', 1),
                                      p.get('bias', True),
                                      p.get('dilations', (1, 1, 1))))
            # non-causal transpose
            elif p['type'] == 'trans':
                pad = (d // 2, h // 2, w // 2)
                decoder_mods.append(
                    nn.ConvTranspose3d(p['in'], p['out'],
                                       p['kernel'], p['stride'], pad, pad,
                                       p.get('groups', 1), p.get('bias', True),
                                       p.get('dilations', (1, 1, 1))))
            # plain convolution (spatial only) -> p['type'] == 'conv'
            else:
                pad = (d // 2, h // 2, w // 2)
                decoder_mods.append(
                    nn.Conv3d(p['in'], p['out'], (d, h, w), p.get('stride', 1),
                              pad, p.get('dilation', 1), p.get('groups', 1),
                              p.get('bias', True)))

            decoder_mods.append(nn.BatchNorm3d(p['out']))
            decoder_mods.append(p.get('activation', nn.Tanh)())

        # package decoding layers as a Sequential network
        self.decoder_net = nn.Sequential(*decoder_mods)

    def forward(self, X):
        """Encode then decode a (N, C, T, H, W) activity tensor into a
        reconstructed stimulus of the same layout."""
        X = self.encoder_net(X)
        X = self.decoder_net(X)
        return X

    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            loss_alpha=10,
            loss_decay=1,
            print_every=0,
            peons=2):
        """Train on train_set, validating on test_set every `print_every`
        batches (print_every < 1 -> once per epoch). `peons` is the number
        of DataLoader worker processes. Plots cost curves when finished."""
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=peons)
        test_loader = DataLoader(test_set,
                                 batch_size=batch_sz,
                                 num_workers=peons)
        N = train_set.__len__()  # number of samples

        # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
        self.loss = DecoderLoss(alpha=loss_alpha, decay=loss_decay).to(self.dv)
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)

        n_batches = np.ceil(N / batch_sz).astype('int')
        print_every = n_batches if print_every < 1 else print_every
        train_prog = None
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            # start = 0
            for j, batch in enumerate(train_loader):
                # print('time to load batch', timer.time()-start)
                # start = timer.time()
                net, stim = batch['net'].to(self.dv), batch['stim'].to(self.dv)
                cost += self.train_step(net, stim)
                del net, stim, batch
                # print('time to train', timer.time()-start)
                train_prog.step() if train_prog is not None else 0
                if j % print_every == 0:
                    test_prog = ProgressBar(
                        np.ceil(test_set.__len__() / batch_sz).astype('int'),
                        size=np.ceil(test_set.__len__() /
                                     batch_sz).astype('int'),
                        label='validating: ')
                    # costs and accuracies for test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        net = testB['net'].to(self.dv)
                        stim = testB['stim'].to(self.dv)
                        testB_cost = self.get_cost(net, stim)
                        del net, stim, testB
                        test_cost += testB_cost
                        test_prog.step()
                    # NOTE(review): enumerate started at 1, so t is already
                    # the batch count — dividing by t + 1 undercounts the
                    # average; likely should be `test_cost /= t`.
                    test_cost /= t + 1
                    print("validation cost: %f" % (test_cost))

                    train_prog = ProgressBar(print_every,
                                             size=test_set.__len__() * 2 //
                                             batch_sz,
                                             label='training:   ')
                    train_prog.step() if j == 0 else 0  # hack, skipped batch
                # start = timer.time()

            # Decay DecoderLoss sparsity penalty
            self.loss.decay()

            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step; returns the batch cost as a float."""
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient

        # Forward
        decoded = self.forward(inputs)  # (N, C, T, H, W)
        output = self.loss.forward(
            # swap time to second dimension -> (N, T, C, H, W)
            decoded.transpose(1, 2),
            targets)

        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters

        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the loss for a batch without updating parameters."""
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)  # (N, C, T, H, W)
            output = self.loss.forward(
                # swap time to second dimension -> (N, T, C, H, W)
                decoded.transpose(1, 2),
                targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings: for each sample, show synced
        scrollable videos of the network activity, the decoding, and the
        true stimulus until the user enters 'n'."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set,
                                   batch_size=1,
                                   shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                net = sample['net'].to(self.dv)
                decoded = self.forward(net)
                del net

            # Reduce out batch and channel dims, then put time last
            # (N, C, T, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=0)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)

            # synced scrollable videos of cell activity, decoding, and stimulus
            fig, ax = plt.subplots(1, 3, figsize=(17, 6))
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1, vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()

            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break

    def save_decodings(self, sample_set, name=None):
        """Decode every sample in sample_set and save each result as .npy
        under <root_dir>/outputs/<name>/<net folder>/<stim file>.

        Prompts for a run name if `name` is None or the folder exists.
        """
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, num_workers=2)

        # make a parent output folder for this dataset if it doesn't exist
        outfold = os.path.join(sample_set.root_dir, 'outputs')
        if not os.path.isdir(outfold):
            os.mkdir(outfold)
        # prompt for name of and create this particular runs output folder
        while True:
            nametag = input("Decoding set name: ") if name is None else name
            name = None  # if parameter name fails, get input next loop
            basefold = os.path.join(outfold, nametag)
            if not os.path.isdir(basefold):
                os.mkdir(basefold)
                break
            else:
                print('Folder exists, provide another name...')

        # generate decoding of every sample in given dataset
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                net = sample['net'].to(self.dv)
                decoded = self.forward(net)
                del sample, net

            # Reduce out batch and channel dims
            # (T, N, C, H, W) -> (T, H, W)
            decoded = decoded.squeeze().cpu().numpy()

            # save into subfolder corresponding to originating network
            decofold = os.path.join(
                basefold,
                sample_set.rec_frame.iloc[i, 0],  # net folder name
            )
            if not os.path.isdir(decofold):
                os.mkdir(decofold)
            # .npy format
            np.save(
                # file name corresponding to stimulus
                os.path.join(decofold, sample_set.rec_frame.iloc[i, 1]),
                decoded)
コード例 #4
0
ファイル: decoder_v5.py プロジェクト: geoffder/retina-decoder
    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            loss_alpha=10,
            loss_decay=1,
            print_every=0,
            peons=2):
        """Train the decoder, validating on test_set periodically.

        Args:
            train_set: Dataset yielding dicts with 'net' (activity) and
                'stim' (target stimulus) tensors.
            test_set: Validation dataset with the same sample format.
            lr: Adam learning rate.
            epochs: Number of full passes over train_set.
            batch_sz: Mini-batch size for both loaders.
            loss_alpha: Sparsity penalty weight for DecoderLoss
                (0 -> equivalent to plain MSE).
            loss_decay: Per-epoch decay applied to the sparsity penalty.
            print_every: Batches between validation passes
                (< 1 -> once per epoch).
            peons: Number of DataLoader worker processes.

        Plots training/validation cost curves when finished.
        """
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=peons)
        test_loader = DataLoader(test_set,
                                 batch_size=batch_sz,
                                 num_workers=peons)
        N = len(train_set)  # number of samples

        # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
        self.loss = DecoderLoss(alpha=loss_alpha, decay=loss_decay).to(self.dv)
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)

        n_batches = np.ceil(N / batch_sz).astype('int')
        n_test_batches = np.ceil(len(test_set) / batch_sz).astype('int')
        print_every = n_batches if print_every < 1 else print_every
        train_prog = None
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                net, stim = batch['net'].to(self.dv), batch['stim'].to(self.dv)
                cost += self.train_step(net, stim)
                del net, stim, batch  # release GPU memory promptly
                if train_prog is not None:
                    train_prog.step()
                if j % print_every == 0:
                    test_prog = ProgressBar(n_test_batches,
                                            size=n_test_batches,
                                            label='validating: ')
                    # average cost over the full test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        net = testB['net'].to(self.dv)
                        stim = testB['stim'].to(self.dv)
                        testB_cost = self.get_cost(net, stim)
                        del net, stim, testB
                        test_cost += testB_cost
                        test_prog.step()
                    # enumerate started at 1, so t already equals the batch
                    # count (dividing by t + 1 undercounted the average)
                    test_cost /= t
                    print("validation cost: %f" % (test_cost))

                    train_prog = ProgressBar(print_every,
                                             size=len(test_set) * 2 //
                                             batch_sz,
                                             label='training:   ')
                    if j == 0:
                        train_prog.step()  # hack, skipped batch

            # Decay DecoderLoss sparsity penalty
            self.loss.decay()

            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()
コード例 #5
0
ファイル: decoder_v1.py プロジェクト: geoffder/retina-decoder
class RetinaDecoder(nn.Module):
    """Stack of convolutional recurrent cells decoding a stimulus movie
    from retinal network activity.

    Works on 5D video tensors with time as the first dimension:
    (T, N, C, H, W).
    """

    def __init__(self,
                 crnn_cell_params,
                 crnn_cell=crnns.ConvGRUCell,
                 learn_initial=False):
        """
        Args:
            crnn_cell_params: iterable of positional-argument tuples, one
                per recurrent cell in the stack.
            crnn_cell: recurrent convolutional cell class (GRU or LSTM).
            learn_initial: whether cells learn their initial hidden state.
        """
        super(RetinaDecoder, self).__init__()
        self.crnn_cell_params = crnn_cell_params
        self.crnn_cell = crnn_cell
        self.learn_initial = learn_initial
        # create model and send to GPU (if available)
        self.build()
        self.to(device)

    def build(self):
        """Construct the recurrent stack and the 1x1 channel-reduction head."""
        self.crnn_stack = nn.ModuleList()
        for i, params in enumerate(self.crnn_cell_params):
            # recurrent convolutional cells (GRU or LSTM)
            self.crnn_stack.append(
                self.crnn_cell(*params, learn_initial=self.learn_initial))
        # assumes the last element of the final params tuple is that cell's
        # output channel count -- TODO confirm against the cell signature
        self.reduce_conv = nn.Conv2d(params[-1], 1, (1, 1))
        self.reduce_bnorm = nn.BatchNorm2d(1)

    def forward(self, X):
        """Map activity (T, N, C, H, W) to a single-channel decoding in
        [-1, 1] via the CRNN stack, per-frame 1x1 conv + batchnorm, tanh."""
        # stacked convolutional recurrent cells
        for cell in self.crnn_stack:
            X, _ = cell(X)

        # reduce channel dimensionality to 1, frame by frame.
        frames = []
        for frame in X:
            frames.append(self.reduce_bnorm(self.reduce_conv(frame)))
        X = torch.stack(frames, dim=0)
        del frames

        X = torch.tanh(X)
        return X

    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            print_every=40):
        """Train on train_set, scoring test_set every `print_every` batches.

        Args:
            train_set: Dataset yielding dicts with 'net' and 'stim' tensors.
            test_set: Validation dataset with the same sample format.
            lr: Adam learning rate.
            epochs: Number of full passes over train_set.
            batch_sz: Mini-batch size for both loaders.
            print_every: Run a validation pass every `print_every` batches.

        Plots training/validation cost curves when finished.
        """
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
        N = len(train_set)  # number of samples

        # DecoderLoss with alpha=0 is equivalent to plain MSE
        self.loss = DecoderLoss(alpha=10).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

        n_batches = N // batch_sz
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                # time to first dimension: (N, T, ...) -> (T, N, ...)
                cost += self.train_step(
                    batch['net'].transpose(0, 1).to(device),
                    batch['stim'].to(device))
                del batch  # release GPU memory held by the batch promptly

                if j % print_every == 0:
                    # average cost over the full test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        testB_cost = self.get_cost(
                            testB['net'].transpose(0, 1).to(device),
                            testB['stim'].to(device))
                        test_cost += testB_cost
                        del testB  # free each batch, not just the last one
                    # enumerate started at 1, so t already equals the batch
                    # count (dividing by t + 1 undercounted the average)
                    test_cost /= t

                    print("cost: %f" % (test_cost))

            # for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost and accuracy progression
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step; returns the batch cost as a float."""
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient

        # Forward
        decoded = self.forward(inputs)
        output = self.loss.forward(
            # swap batch to first dimension
            decoded.transpose(0, 1),
            targets)

        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters

        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the loss for a batch without updating parameters."""
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)
            output = self.loss.forward(
                # swap batch to first dimension
                decoded.transpose(0, 1),
                targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings: for each sample, show synced
        scrollable videos of the network activity, the decoding, and the
        true stimulus until the user enters 'n'."""
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set,
                                   batch_size=1,
                                   shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                decoded = self.forward(sample['net'].to(device))

            # Reduce out batch and channel dims, then put time last
            # (T, N, C, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=1)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)

            # synced scrollable videos of cell activity, decoding, and stimulus
            fig, ax = plt.subplots(1, 3)
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1, vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()

            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break
コード例 #6
0
class RetinaDecoder(nn.Module):
    """Decode a visual stimulus movie from retinal network activity.

    Pipeline: grouped Conv3d stack -> space-only Conv3d stack ->
    stacked convolutional recurrent cells -> ConvTranspose3d upsampling
    -> post-upsampling space-only Conv3d clean-up.
    Model input and output tensors are time-first: (T, N, C, H, W).
    """

    def __init__(self,
                 grp_conv_params,
                 conv_params,
                 crnn_cell_params,
                 trans_params,
                 post_conv_params,
                 crnn_cell=crnns.ConvGRUCell,
                 learn_initial=False):
        """Store per-stack layer specs, build the network, move to device.

        Args:
            grp_conv_params: per-layer [in, out, (kernel), (stride),
                (dilation), groups] lists for the grouped Conv3d stack.
            conv_params: per-layer [in, out, (kernel), (stride)] lists.
            crnn_cell_params: per-cell [(dims), (in_kernel), (out_kernel),
                in_C, out_C] lists for the recurrent stack.
            trans_params: ConvTranspose3d parameter lists (same layout
                as conv_params).
            post_conv_params: per-layer [in, out, (kernel), (stride)]
                lists for the post-upsampling stack.
            crnn_cell: recurrent conv cell class (GRU or LSTM variant).
            learn_initial: whether cells learn their initial hidden state.
        """
        super(RetinaDecoder, self).__init__()
        # layer parameters
        self.grp_conv_params = grp_conv_params
        self.conv_params = conv_params
        self.crnn_cell_params = crnn_cell_params
        self.trans_params = trans_params
        self.post_conv_params = post_conv_params
        # ConvRNN settings
        self.crnn_cell = crnn_cell
        self.learn_initial = learn_initial
        # create model and send to GPU
        self.build()
        self.to(device)

    def build(self):
        """Construct all layer stacks from the stored parameter lists."""
        # grouped convolutions
        self.grp_conv_layers = nn.ModuleList()
        self.grp_conv_bnorms = nn.ModuleList()
        # convolutions
        self.conv_layers = nn.ModuleList()
        self.conv_bnorms = nn.ModuleList()
        # recurrent convolutions
        self.crnn_stack = nn.ModuleList()
        # transpose convolutions
        self.trans_layers = nn.ModuleList()
        self.trans_bnorms = nn.ModuleList()
        # post-upsampling convolutions
        self.post_conv_layers = nn.ModuleList()
        self.post_conv_bnorms = nn.ModuleList()

        for params in self.grp_conv_params:
            # params: [in, out, (kernel), (stride), (dilation), groups]
            # NOTE(review): 'same'-style padding for a dilated kernel is
            # usually (dilation * (kernel - 1)) // 2; this computes
            # (kernel * dilation - 1) // 2, which coincides only for
            # small kernels/dilations -- confirm intent for larger values.
            pad = ((params[2][0] * params[4][0] - 1) // 2,
                   (params[2][1] * params[4][1] - 1) // 2,
                   (params[2][2] * params[4][2] - 1) // 2)
            self.grp_conv_layers.append(
                nn.Conv3d(*params[:4], pad, *params[4:]))
            self.grp_conv_bnorms.append(nn.BatchNorm3d(params[1]))

        for params in self.conv_params:
            # params: [in, out, (kernel), (stride)]; pad preserves dims
            pad = (params[2][0] // 2, params[2][1] // 2, params[2][2] // 2)
            self.conv_layers.append(nn.Conv3d(*params, pad))
            self.conv_bnorms.append(nn.BatchNorm3d(params[1]))

        for params in self.crnn_cell_params:
            # params: [(dims), (in_kernel), (out_kernel), in_C, out_C]
            # recurrent convolutional cells (GRU or LSTM)
            self.crnn_stack.append(
                self.crnn_cell(*params, learn_initial=self.learn_initial))

        for params in self.trans_params:
            pad = (params[2][0] // 2, params[2][1] // 2, params[2][2] // 2)
            # NOTE(review): ConvTranspose3d requires output_padding to be
            # smaller than stride (or dilation) element-wise; reusing
            # `pad` for it assumes stride > kernel // 2 -- confirm.
            self.trans_layers.append(
                nn.ConvTranspose3d(*params, padding=pad, output_padding=pad))
            self.trans_bnorms.append(nn.BatchNorm3d(params[1]))

        for params in self.post_conv_params:
            # params: [in, out, (kernel), (stride)]
            pad = (params[2][0] // 2, params[2][1] // 2, params[2][2] // 2)
            self.post_conv_layers.append(nn.Conv3d(*params, pad))
            self.post_conv_bnorms.append(nn.BatchNorm3d(params[1]))

    def forward(self, X):
        """Map network activity (T, N, C, H, W) to a decoded stimulus
        tensor in the same time-first layout."""
        # time to 'depth' dimension
        X = X.permute(1, 2, 0, 3, 4)  # to (N, C, T, H, W)

        # reduce spatial dimensionality (collate somatic information)
        X = F.avg_pool3d(X, (1, 2, 2))

        # grouped (cluster siloed) convolutions
        for conv, bnorm in zip(self.grp_conv_layers, self.grp_conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))

        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.conv_layers, self.conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        X = F.avg_pool3d(X, (1, 2, 2))

        # return to time dimension first for operations over time
        X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, W)

        # stacked convolutional recurrent cells
        # (each cell returns (output sequence, final hidden state))
        for cell in self.crnn_stack:
            X, _ = cell(X)

        # expand back out in space and reduce channels
        X = X.permute(1, 2, 0, 3, 4)  # time to 'depth' dimension
        for trans, bnorm in zip(self.trans_layers, self.trans_bnorms):
            X = torch.tanh(bnorm(trans(X)))

        # clean up with more spatial convs (try interleaving with trans next)
        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.post_conv_layers, self.post_conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))

        X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, W)

        return X

    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            print_every=40):
        """Train the decoder, periodically reporting test-set cost.

        Args:
            train_set: Dataset yielding dicts with 'net' and 'stim'
                tensors (batch-first from the DataLoader).
            test_set: Dataset with the same layout, used for validation.
            lr: Adam learning rate.
            epochs: number of passes over train_set.
            batch_sz: mini-batch size for both loaders.
            print_every: evaluate/print test cost every this many batches.

        Side effects: creates self.loss and self.optimizer, and shows a
        matplotlib plot of train/validation cost per epoch when done.
        """
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
        N = len(train_set)  # number of samples

        # DecoderLoss equivalent to MSE when alpha=0
        self.loss = DecoderLoss(alpha=10).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

        n_batches = N // batch_sz
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                # model expects time-first input: (N, T, ...) -> (T, N, ...)
                cost += self.train_step(
                    batch['net'].transpose(0, 1).to(device),
                    batch['stim'].to(device))
                del batch  # free GPU-bound references promptly

                if j % print_every == 0:
                    # average cost over the full test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        testB_cost = self.get_cost(
                            testB['net'].transpose(0, 1).to(device),
                            testB['stim'].to(device))
                        test_cost += testB_cost
                    # BUG FIX: enumerate starts at 1, so `t` already equals
                    # the number of test batches; dividing by t + 1
                    # under-reported the validation cost.
                    test_cost /= t
                    del testB

                    print("cost: %f" % (test_cost))

            # per-epoch averages for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost progression over epochs
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        """Run one optimization step on a single batch.

        Args:
            inputs: time-first activity tensor (T, N, C, H, W), on device.
            targets: batch-first stimulus tensor, on device.

        Returns:
            float: the batch cost.
        """
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient

        # Forward
        decoded = self.forward(inputs)
        output = self.loss.forward(
            # swap batch to first dimension (loss expects batch-first)
            decoded.transpose(0, 1),
            targets)

        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters

        return output.item()  # cost

    def get_cost(self, inputs, targets):
        """Compute the batch cost without updating parameters.

        Mirrors train_step but runs in eval mode with gradients disabled.
        """
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)
            output = self.loss.forward(
                # swap batch to first dimension
                decoded.transpose(0, 1),
                targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings next to the source recording
        and the true stimulus; loop until the user enters 'n'.
        """
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set,
                                   batch_size=1,
                                   shuffle=True,
                                   num_workers=2)
        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                # BUG FIX: forward() expects time-first (T, N, C, H, W);
                # fit() transposes the loader's (N, T, ...) batches the
                # same way, but this call did not.
                decoded = self.forward(
                    sample['net'].to(device).transpose(0, 1))

            # Reduce out batch and channel dims, then put time last
            # (T, N, C, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=1)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)

            # synced scrollable videos of cell activity, decoding, stimulus
            fig, ax = plt.subplots(1, 3, figsize=(17, 6))
            net_stack = StackPlotter(ax[0], net, delta=1, vmin=0)
            deco_stack = StackPlotter(ax[1], decoded, delta=1, vmin=-1, vmax=1)
            stim_stack = StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1)
            fig.canvas.mpl_connect('scroll_event', net_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', deco_stack.onscroll)
            fig.canvas.mpl_connect('scroll_event', stim_stack.onscroll)
            ax[0].set_title('Network Recording')
            ax[1].set_title('Decoding')
            ax[2].set_title('Stimulus')
            fig.tight_layout()
            plt.show()

            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break
# Code example #7 (scraped marker; original text: "コード例 #7", votes: 0)
class RetinaDecoder(nn.Module):
    """Stimulus decoder: maps retinal network activity movies onto a
    reconstruction of the visual stimulus that drove them.

    Layer stacks are specified by lists of parameter dicts and
    assembled in build(); model tensors are time-first (T, N, C, H, W).
    """

    def __init__(self, grp_tempo_params, conv_params, crnn_cell_params,
                 temp3d_stack_params, trans_params, post_conv_params):
        """Store layer specs, assemble the network, and move it to the
        global `device`.

        Each argument is a list of per-layer parameter dicts consumed by
        build() (see that method for the expected keys).
        """
        super(RetinaDecoder, self).__init__()
        # stash the per-stack specifications for build()
        self.grp_tempo_params = grp_tempo_params
        self.conv_params = conv_params
        self.crnn_cell_params = crnn_cell_params
        self.temp3d_stack_params = temp3d_stack_params
        self.trans_params = trans_params
        self.post_conv_params = post_conv_params
        # assemble the layer stacks, then ship parameters to `device`
        self.build()
        self.to(device)

    def build(self):
        """Assemble all layer stacks from the stored parameter-dict lists.

        Stacks (in forward order): grouped temporal conv stacks,
        space-only convs (+ batchnorm), recurrent conv cells, temporal
        3d conv stacks, causal transpose convs (+ batchnorm), and
        post-upsampling space-only convs (+ batchnorm).
        """
        # one ModuleList per stack so parameters register with the module
        self.grp_tempo_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        self.conv_bnorms = nn.ModuleList()
        self.crnn_stack = nn.ModuleList()
        self.tempo3d_layers = nn.ModuleList()
        self.trans_layers = nn.ModuleList()
        self.trans_bnorms = nn.ModuleList()
        self.post_conv_layers = nn.ModuleList()
        self.post_conv_bnorms = nn.ModuleList()

        for spec in self.grp_tempo_params:
            # grouped (cluster-siloed) temporal convolution blocks
            self.grp_tempo_layers.append(
                TemporalConv3dStack(spec['in'], spec['out'],
                                    spec.get('kernel', (2, 1, 1)),
                                    spec.get('space_dilation', 1),
                                    spec.get('groups', 1),
                                    spec.get('dropout', 0),
                                    spec.get('activation', nn.ReLU)))

        for spec in self.conv_params:
            # space-only convs; pad to preserve dims given odd kernels
            depth, height, width = spec.get('kernel', (1, 3, 3))
            same_pad = (depth // 2, height // 2, width // 2)
            self.conv_layers.append(
                nn.Conv3d(spec['in'], spec['out'], (depth, height, width),
                          spec.get('stride', 1), same_pad,
                          spec.get('dilation', 1), spec.get('groups', 1),
                          spec.get('bias', True)))
            self.conv_bnorms.append(nn.BatchNorm3d(spec['out']))

        for spec in self.crnn_cell_params:
            # recurrent convolutional cells (GRU or LSTM variants)
            cell_cls = spec.get('crnn_cell', crnns.ConvGRUCell_wnorm)
            self.crnn_stack.append(
                cell_cls(spec['dims'], spec['in_kernel'], spec['out_kernel'],
                         spec['in'], spec['out'],
                         spec.get('learn_initial', False)))

        for spec in self.temp3d_stack_params:
            # ungrouped temporal convolution blocks
            self.tempo3d_layers.append(
                TemporalConv3dStack(spec['in'], spec['out'],
                                    spec.get('kernel', (2, 3, 3)),
                                    spec.get('space_dilation', 1),
                                    spec.get('groups', 1),
                                    spec.get('dropout', 0),
                                    spec.get('activation', nn.ReLU)))

        for spec in self.trans_params:
            # upsampling (causal) transpose convolutions
            self.trans_layers.append(
                CausalTranspose3d(spec['in'], spec['out'], spec['kernel'],
                                  spec['stride'], spec.get('groups', 1),
                                  spec.get('bias', True),
                                  spec.get('dilations', (1, 1, 1))))
            self.trans_bnorms.append(nn.BatchNorm3d(spec['out']))

        for spec in self.post_conv_params:
            # clean-up convs applied after upsampling
            depth, height, width = spec.get('kernel', (1, 3, 3))
            same_pad = (depth // 2, height // 2, width // 2)
            self.post_conv_layers.append(
                nn.Conv3d(spec['in'], spec['out'], (depth, height, width),
                          spec.get('stride', 1), same_pad,
                          spec.get('dilation', 1), spec.get('groups', 1),
                          spec.get('bias', True)))
            self.post_conv_bnorms.append(nn.BatchNorm3d(spec['out']))

    def forward(self, X):
        # time to 'depth' dimension
        X = X.permute(1, 2, 0, 3, 4)  # to (N, C, T, H, W)

        # reduce spatial dimensionality (collate somatic information)
        X = F.avg_pool3d(X, (1, 2, 2))

        # grouped (cluster siloed) temporal convolutions
        for tempo_conv in self.grp_tempo_layers:
            X = tempo_conv(X)

        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.conv_layers, self.conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))
        X = F.avg_pool3d(X, (2, 2, 2))
        # testing! (try max if using ReLU at the start)
        # X = F.max_pool3d(X, (1, 2, 2))

        if len(self.crnn_stack) > 0:
            # return to time dimension first for operations over time
            X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, w)

            # stacked convolutional recurrent cells
            for cell in self.crnn_stack:
                X = cell(X)

            X = F.relu(X)  # test

            # expand back out in space and reduce channels
            X = X.permute(1, 2, 0, 3, 4)  # time to 'depth' dimension

        for tempo_conv in self.tempo3d_layers:
            X = tempo_conv(X)

        for trans, bnorm in zip(self.trans_layers, self.trans_bnorms):
            X = torch.tanh(bnorm(trans(X)))

        # clean up with more spatial convs (try interleaving with trans next)
        # frame-by-frame (space only) convolutions
        for conv, bnorm in zip(self.post_conv_layers, self.post_conv_bnorms):
            X = torch.tanh(bnorm(conv(X)))

        X = X.permute(2, 0, 1, 3, 4)  # back to (T, N, C, H, w)

        return X

    def fit(self,
            train_set,
            test_set,
            lr=1e-4,
            epochs=10,
            batch_sz=1,
            loss_alpha=10,
            print_every=40):
        """Train the decoder, periodically reporting test-set cost.

        Args:
            train_set: Dataset yielding dicts with 'net' and 'stim'
                tensors (batch-first from the DataLoader).
            test_set: Dataset with the same layout, used for validation.
            lr: Adam learning rate.
            epochs: number of passes over train_set.
            batch_sz: mini-batch size for both loaders.
            loss_alpha: alpha for DecoderLoss (0 -> plain MSE).
            print_every: evaluate/print test cost every this many batches.

        Side effects: creates self.loss and self.optimizer, and shows a
        matplotlib plot of train/validation cost per epoch when done.
        """
        train_loader = DataLoader(train_set,
                                  batch_size=batch_sz,
                                  shuffle=True,
                                  num_workers=2)
        test_loader = DataLoader(test_set, batch_size=batch_sz, num_workers=2)
        N = len(train_set)  # number of samples

        # DecoderLoss equivalent to MSE when alpha=0 (original default: 10)
        self.loss = DecoderLoss(alpha=loss_alpha).to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr, eps=1e-8)

        n_batches = N // batch_sz
        train_costs, test_costs = [], []
        for i in range(epochs):
            cost = 0
            print("epoch:", i, "n_batches:", n_batches)
            for j, batch in enumerate(train_loader):
                # model expects time-first input: (N, T, ...) -> (T, N, ...)
                net, stim = batch['net'].to(device), batch['stim'].to(device)
                cost += self.train_step(net.transpose(0, 1), stim)
                del net, stim, batch  # free GPU references promptly

                if j % print_every == 0:
                    # average cost over the full test set
                    test_cost = 0
                    for t, testB in enumerate(test_loader, 1):
                        net = testB['net'].to(device)
                        stim = testB['stim'].to(device)
                        testB_cost = self.get_cost(net.transpose(0, 1), stim)
                        del net, stim, testB
                        test_cost += testB_cost
                    # BUG FIX: enumerate starts at 1, so `t` already equals
                    # the number of test batches; dividing by t + 1
                    # under-reported the validation cost.
                    test_cost /= t

                    print("cost: %f" % (test_cost))

            # per-epoch averages for plotting
            train_costs.append(cost / n_batches)
            test_costs.append(test_cost)

        # plot cost progression over epochs
        fig, axes = plt.subplots(1)
        axes.plot(train_costs, label='training')
        axes.plot(test_costs, label='validation')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Cost')
        plt.legend()
        plt.show()

    def train_step(self, inputs, targets):
        self.train()  # set the model to training mode
        self.optimizer.zero_grad()  # Reset gradient

        # Forward
        decoded = self.forward(inputs)
        output = self.loss.forward(
            # swap batch to first dimension
            decoded.transpose(0, 1),
            targets)

        # Backward
        output.backward()  # compute gradients
        self.optimizer.step()  # Update parameters

        return output.item()  # cost

    def get_cost(self, inputs, targets):
        self.eval()  # set the model to testing mode
        self.optimizer.zero_grad()  # Reset gradient
        with torch.no_grad():
            # Forward
            decoded = self.forward(inputs)
            output = self.loss.forward(
                # swap batch to first dimension
                decoded.transpose(0, 1),
                targets)
        return output.item()

    def decode(self, sample_set):
        """Interactively display decodings against recordings and stimuli.

        Shows three scroll-synchronized movie panes (network recording,
        decoding, true stimulus) per random sample, looping until the
        user enters 'n' at the prompt.
        """
        self.eval()  # inference mode (batchnorm uses running stats)
        sample_loader = DataLoader(sample_set,
                                   batch_size=1,
                                   shuffle=True,
                                   num_workers=2)
        for sample in sample_loader:
            with torch.no_grad():
                # model wants time-first input: (N, T, ...) -> (T, N, ...)
                activity = sample['net'].to(device).transpose(0, 1)
                decoded = self.forward(activity)
                del activity

            # drop batch/channel dims and put time last for display:
            # (T, N, C, H, W) -> (H, W, T)
            decoded = decoded.squeeze().cpu().numpy().transpose(1, 2, 0)
            net = sample['net'].squeeze().numpy().sum(axis=1)
            net = net.transpose(1, 2, 0)
            stim = sample['stim'].squeeze().numpy().transpose(1, 2, 0)

            # three scroll-synchronized movie panes
            fig, ax = plt.subplots(1, 3, figsize=(17, 6))
            stacks = (StackPlotter(ax[0], net, delta=1, vmin=0),
                      StackPlotter(ax[1], decoded, delta=1, vmin=-1, vmax=1),
                      StackPlotter(ax[2], stim, delta=1, vmin=-1, vmax=1))
            for stack in stacks:
                fig.canvas.mpl_connect('scroll_event', stack.onscroll)
            for axis, title in zip(
                    ax, ('Network Recording', 'Decoding', 'Stimulus')):
                axis.set_title(title)
            fig.tight_layout()
            plt.show()

            again = input("Show another reconstruction? Enter 'n' to quit\n")
            if again == 'n':
                break

    def save_decodings(self, sample_set):
        """Decode every sample in `sample_set` and save each result as a
        .npy file under a new user-named folder in the dataset root.

        Layout: <root_dir>/<nametag>/<net folder>/<stimulus name>,
        taken from sample_set.rec_frame (presumably a pandas frame of
        (net folder, stimulus file) rows -- confirm against the dataset
        class, which is outside this view).
        """
        self.eval()  # set the model to testing mode
        sample_loader = DataLoader(sample_set, batch_size=1, num_workers=2)

        # prompt until an unused folder name is given (avoids clobbering
        # a previous decoding set)
        while True:
            nametag = input("Decoding set name: ")
            basefold = os.path.join(sample_set.root_dir, nametag)
            if not os.path.isdir(basefold):
                os.mkdir(basefold)
                break
            else:
                print('Folder exists, provide another name...')

        for i, sample in enumerate(sample_loader):
            with torch.no_grad():
                # get stimulus prediction from network activity
                # (N, T, ...) -> (T, N, ...): the layout forward() expects
                net = sample['net'].to(device).transpose(0, 1)
                decoded = self.forward(net)
                del sample, net  # release GPU references promptly

            # Reduce out batch and channel dims
            # (T, N, C, H, W) -> (T, H, W)
            decoded = decoded.squeeze().cpu().numpy()

            # save into subfolder corresponding to originating network
            decofold = os.path.join(
                basefold,
                sample_set.rec_frame.iloc[i, 0],  # net folder name
            )
            if not os.path.isdir(decofold):
                os.mkdir(decofold)
            np.save(
                # file name corresponding to stimulus
                os.path.join(decofold, sample_set.rec_frame.iloc[i, 1]),
                decoded)