示例#1
0
class PlaceCellsTest(unittest.TestCase):
    """Unit tests for ``PlaceCells``: centre layout, activation, and decoding."""

    def setUp(self):
        # Seed NumPy's global RNG before construction so any random
        # initialization inside PlaceCells is reproducible across runs.
        np.random.seed(1)
        self.place_cells = PlaceCells()

    def test_us_range(self):
        """``us`` is a (256, 2) array with every coordinate in [-4.5, 4.5]."""
        us = self.place_cells.us
        self.assertEqual(us.shape, (256, 2))

        self.assertLessEqual(np.max(us), 4.5)
        self.assertGreaterEqual(np.min(us), -4.5)

    # NOTE: the misspelled method name is kept so existing test selectors
    # (e.g. ``-k test_get_activatoin``) keep working.
    def test_get_activatoin(self):
        """Activation at the origin has shape (256,) and lies in [0, 1] with sum 1."""
        pos = (0.0, 0.0)
        c = self.place_cells.get_activation(pos)

        # One activation value per cell.
        self.assertEqual(c.shape, (256, ))

        # Activations sum to 1 (up to float rounding).
        self.assertAlmostEqual(np.sum(c), 1.0, places=5)

        self.assertLessEqual(np.max(c), 1.0)
        self.assertGreaterEqual(np.min(c), 0.0)

    def test_get_nearest_cell_pos(self):
        """Decoding the activation at a cell centre returns that centre."""
        pos = np.copy(self.place_cells.us[0])
        c = self.place_cells.get_activation(pos)
        nearest_cell_pos = self.place_cells.get_nearest_cell_pos(c)

        # assert_allclose reports the mismatching values on failure, unlike
        # assertTrue(np.allclose(...)). rtol/atol are set to np.allclose's
        # defaults so the tolerance is the same as before.
        np.testing.assert_allclose(pos, nearest_cell_pos, rtol=1e-5, atol=1e-8)
示例#2
0
    def __init__(self, options):
        super().__init__()
        self.options = options

        # TODO: set-up checkpointing so that saved model along with saved
        # place cell centers can be loaded.
        # NOTE: gpu is True whenever CUDA is present, even if the model itself
        # is not run on the GPU; this may fail in that configuration.
        cuda_present = torch.cuda.is_available()
        self.pc = PlaceCells(options, gpu=cuda_present)

        # 'CE' selects cross-entropy; any other value falls back to MSE.
        if self.options.loss == 'CE':
            self.criterion = cross_entropy
        else:
            self.criterion = torch.nn.MSELoss()
示例#3
0
    def test_get_confirm_batch(self):
        """Confirm batches have the expected shapes at the first and last index."""
        np.random.seed(1)
        place_cells = PlaceCells()
        hd_cells = HDCells()

        self.data_manager.prepare(place_cells, hd_cells)

        batch_size = 10
        sequence_length = 100

        # 50000 is presumably the raw sample count of the dataset; with these
        # settings the expected number of confirm batches is 49.
        n_batches = self.data_manager.get_confirm_index_size(batch_size, sequence_length)
        self.assertEqual(n_batches, (50000 // (sequence_length * batch_size)) - 1)

        expected_shapes = (
            (batch_size, sequence_length, 3),  # inputs
            (batch_size, 256),                 # place-cell initial state
            (batch_size, 12),                  # hd-cell initial state
            (batch_size, sequence_length, 2),  # positions
        )

        # Check both ends of the valid index range.
        for batch_index in (0, n_batches - 1):
            inputs_batch, place_init_batch, hd_init_batch, place_pos_batch = \
                self.data_manager.get_confirm_batch(batch_size, sequence_length, batch_index)
            arrays = (inputs_batch, place_init_batch, hd_init_batch, place_pos_batch)
            for array, shape in zip(arrays, expected_shapes):
                self.assertEqual(array.shape, shape)
示例#4
0
def main(argv):
    """Build data, cell populations, model and trainer, restore checkpoints, then train.

    Args:
        argv: command-line arguments handed over by the app runner; not used
            here (configuration is read from the module-level ``flags``).
    """
    np.random.seed(1)

    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the race between an existence check and a separate mkdir call.
    os.makedirs(flags.save_dir, exist_ok=True)

    data_manager = DataManager()

    place_cells = PlaceCells()
    hd_cells = HDCells()

    data_manager.prepare(place_cells, hd_cells)

    model = Model(place_cell_size=place_cells.cell_size,
                  hd_cell_size=hd_cells.cell_size,
                  sequence_length=flags.sequence_length)

    trainer = Trainer(data_manager, model, flags)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # TensorBoard event files are written under <save_dir>/log.
    log_dir = os.path.join(flags.save_dir, "log")
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    # Restore checkpoints; start_step is presumably the step training
    # resumes from (TODO confirm against load_checkpoints).
    saver, start_step = load_checkpoints(sess)

    # Train
    train(sess, trainer, saver, summary_writer, start_step)
示例#5
0
    def test_prepare(self):
        """prepare() fills inputs and both output arrays with one row per sample."""
        np.random.seed(1)
        place_cells, hd_cells = PlaceCells(), HDCells()

        manager = self.data_manager
        manager.prepare(place_cells, hd_cells)

        n_rows = 49999

        # Inputs: 3 columns per row.
        self.assertEqual(manager.inputs.shape, (n_rows, 3))

        # Outputs: 256 place-cell columns and 12 hd-cell columns per row.
        self.assertEqual(manager.place_outputs.shape, (n_rows, 256))
        self.assertEqual(manager.hd_outputs.shape, (n_rows, 12))
示例#6
0
    def test_init(self):
        """Smoke test: a Trainer can be built from prepared data and a model.

        There are no explicit assertions; any exception raised during setup
        or construction fails the test.
        """
        np.random.seed(1)

        manager = DataManager()
        place_cells = PlaceCells()
        hd_cells = HDCells()
        manager.prepare(place_cells, hd_cells)

        model = Model(place_cell_size=place_cells.cell_size,
                      hd_cell_size=hd_cells.cell_size,
                      sequence_length=100)

        options = get_options()
        _trainer = Trainer(manager, model, options)
示例#7
0
    def test_get_train_batch(self):
        """Training batches have per-step sequence arrays plus per-sequence init arrays."""
        np.random.seed(1)
        place_cells = PlaceCells()
        hd_cells = HDCells()

        self.data_manager.prepare(place_cells, hd_cells)

        batch_size, sequence_length = 10, 100
        (inputs_batch, place_outputs_batch, hd_outputs_batch,
         place_init_batch, hd_init_batch) = self.data_manager.get_train_batch(
             batch_size, sequence_length)

        expectations = [
            # Sequence arrays: (batch, step, features).
            (inputs_batch, (batch_size, sequence_length, 3)),
            (place_outputs_batch, (batch_size, sequence_length, 256)),
            (hd_outputs_batch, (batch_size, sequence_length, 12)),
            # Initial-state arrays: one vector per sequence.
            (place_init_batch, (batch_size, 256)),
            (hd_init_batch, (batch_size, 12)),
        ]
        for array, expected_shape in expectations:
            self.assertEqual(array.shape, expected_shape)
示例#8
0
    ret = np.vstack(row_images_with_spacers)
    return ret


def convert_to_colomap(im, cmap):
    """Pass ``im`` through the colormap ``cmap`` and return 8-bit colour values.

    Args:
        im: scalar or array of values to colourize (for a matplotlib
            colormap, floats are normally expected in [0, 1]).
        cmap: callable colormap, e.g. a matplotlib ``Colormap``.

    Returns:
        ``np.uint8`` array equal to ``cmap(im) * 255`` with the fractional
        part truncated (not rounded).
    """
    # matplotlib colormaps return a plain tuple for scalar input; without
    # asarray, ``tuple * 255`` would repeat the tuple 255 times instead of
    # scaling its values. For ndarray results asarray is a no-op.
    colors = np.asarray(cmap(im))
    return np.uint8(colors * 255)


# Script setup: rebuild the data, cell populations and model, then restore
# saved weights into a fresh session. The statement order matters: the seed
# precedes any construction, and variables are initialized before restoring.

# Seed NumPy's global RNG before building anything, for reproducibility.
np.random.seed(1)

data_manager = DataManager()

# Cell populations (HDCells presumably = head-direction cells); both are
# handed to prepare() below.
place_cells = PlaceCells()
hd_cells = HDCells()

data_manager.prepare(place_cells, hd_cells)

# Model dimensions come from the cell populations; sequence length is fixed.
model = Model(place_cell_size=place_cells.cell_size,
              hd_cell_size=hd_cells.cell_size,
              sequence_length=100)

sess = tf.Session()
# Initialize all variables first; load_checkpoints() is presumably expected
# to overwrite them with saved values when a checkpoint exists (TODO confirm).
sess.run(tf.global_variables_initializer())

# Load checkpoints (its return values, e.g. saver/start step, are unused here)
load_checkpoints(sess)

batch_size = 10
示例#9
0
    options.MODEL_type = 'PATFORM'
    options.lr = 1e-4
    options.T = 1000
    options.activation = 'relu'  # recurrent nonlinearity
    options.DoG = True  # use difference of gaussians tuning curves
    options.periodic = False  # trajectories with periodic boundary conditions
    options.norm_cov = False  # normalize translated/averaged spatial covariance (dont use with DoG - autocorrelation signal not strong enough)
    options.gauss_norm = False
    options.box_width = 2.2  # width of training environment
    options.box_height = 2.2  # height of training environment
    # options.run_ID = generate_run_ID_PATFORM(options)

    # NON-ORTHOGONAL SINGLE-CELL PATTERN FORMATION DYNAMICS
    options.Ng = 32
    options.lr = 1e-4
    place_cells = PlaceCells(options)

    # Symmetry-preserving nonlinearity (tanh)
    G = grid_pattern_formation(place_cells, options, activation='tanh')
    plot_ratemaps(G, options.Ng)
    # Symmetry-breaking nonlinearities (relu)
    options.lr = 5e-3
    G = grid_pattern_formation(place_cells, options, activation='relu')
    plot_ratemaps(G, options.Ng)

    # ORTHOGONAL POPULATION PATTERN FORMATION DYNAMICS
    # good run
    options.res = 64
    options.DoG = False
    options.norm_cov = True  # True works better with gaussian tuning
    options.gauss_norm = False
示例#10
0
class BaseSiren(LightningModule):
    """Shared LightningModule plumbing for place-cell prediction models.

    Provides data loading, loss / position-error computation, validation
    aggregation, and rate-map export. Subclasses must:

    * implement ``forward(inputs, place_outputs, pos)`` returning
      ``(output, coord, layer_outputs)``;
    * define ``self.siren``, the module whose parameters are optimized.

    NOTE(review): ``_l2_loss`` reads ``self.rnn``, which this class does not
    define (an earlier docstring described RNN subclasses such as VanillaRNN,
    LSTM and GRU), so that helper only works for subclasses providing ``rnn``.
    """

    def __init__(self, options):
        super().__init__()
        self.options = options

        # TODO: set-up checkpointing so that saved model along with saved
        # place cell centers can be loaded.
        # NOTE: gpu is True whenever CUDA is present, even if the model itself
        # is not run on the GPU; this may fail in that configuration.
        self.pc = PlaceCells(options, gpu=torch.cuda.is_available())
        # 'CE' selects cross-entropy; any other value falls back to MSE.
        self.criterion = (cross_entropy if self.options.loss == 'CE'
                          else torch.nn.MSELoss())

    def training_step(self, batch, batch_idx):
        """Return the training loss and mean Euclidean position error for one batch."""
        inputs, place_outputs, pos = batch
        output, coord, layer_outputs = self(inputs, place_outputs, pos)
        loss = self.criterion(output, place_outputs)
        # Weight regularization (currently disabled):
        # loss += self._l2_loss() * self.options.weight_decay

        # Decode a position from the predicted place-cell activity and take
        # the Euclidean distance to the true position (sum over the last axis).
        pred_pos = self.pc.get_nearest_cell_pos(output)
        err = torch.mean(torch.sqrt(torch.sum((pos - pred_pos)**2, dim=-1)))

        tensorboard_logs = {'train_loss': loss, 'train_err': err}
        return {
            'loss': loss,
            'err': err,
            'output': output,
            'log': tensorboard_logs
        }

    def configure_optimizers(self):
        """Build the optimizer named by ``options.optim`` over ``self.siren``'s parameters.

        Raises:
            ValueError: if ``options.optim`` is not one of SGD, Adam, RMSProp,
                LBFGS, HessianFree (previously this silently returned None).
        """
        name = self.options.optim
        lr = self.options.learning_rate
        if name == 'SGD':
            return optim.SGD(self.siren.parameters(), lr=lr, momentum=0.9)
        if name == 'Adam':
            return optim.Adam(self.siren.parameters(), lr=lr)
        if name == 'RMSProp':
            return optim.RMSprop(self.siren.parameters(), lr=lr)
        if name == 'LBFGS':
            return optim.LBFGS(self.siren.parameters(), lr=lr)
        if name == 'HessianFree':
            # Other available options: use_gnm=True, verbose=True
            return HessianFree(self.siren.parameters(), lr=lr)
        raise ValueError(f"Unsupported optimizer: {name!r}")

    def train_dataloader(self):
        dg = DataGenerator(self.options, self.pc, gpu=self.on_gpu, train=True)
        return dg

    # TODO: Average of error per step to a csv
    def validation_step(self, batch, batch_idx):
        """Return validation loss and the position error averaged over the batch.

        ``val_err`` keeps every axis of the per-sample Euclidean distance
        except the leading (batch) one, e.g. one value per sequence step if
        ``pos`` is (batch, step, coord) — TODO confirm the layout of ``pos``.
        """
        inputs, place_outputs, pos = batch
        output, coord, layer_outputs = self(inputs, place_outputs, pos)
        loss = self.criterion(output, place_outputs)

        pred_pos = self.pc.get_nearest_cell_pos(output)
        # Sum the squared differences over the coordinate axis before the sqrt
        # so this is a Euclidean distance, matching training_step. The
        # previous sqrt(d**2) form was just |d| per coordinate.
        dist = torch.sqrt(torch.sum((pos - pred_pos)**2, dim=-1))
        err = torch.mean(dist, dim=0)

        return {
            'val_loss': loss,
            'layer_outputs': layer_outputs,
            'coord': coord,
            'val_err': err,
            'output': output
        }

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics, save per-step error, and export rate maps."""
        import copy  # local import: only needed for per-layer option copies

        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_err = torch.stack([x['val_err'].mean() for x in outputs]).mean()

        # Average position error per sequence step across validation batches.
        # val_err no longer has a coordinate axis, so there is nothing left to
        # average over with mean(dim=-1).
        avg_seq_err = torch.stack([x['val_err'] for x in outputs]).mean(dim=0)
        save_seq_err(avg_seq_err, self.options)

        tensorboard_logs = {'val_loss': avg_loss, 'val_err': avg_err}

        # Concatenate over the whole validation epoch and move to NumPy.
        pos = torch.cat([x['coord'] for x in outputs], dim=0)
        pos = pos.to('cpu').detach().numpy()
        act = torch.cat([x['layer_outputs'] for x in outputs], dim=0)
        act = act.to('cpu').detach().numpy()

        # act is consumed in consecutive chunks of len(pos) along axis 0,
        # presumably one chunk per layer — TODO confirm the layer_outputs layout.
        chunk = pos.shape[0]
        flat_pos = pos.reshape(-1, pos.shape[-1])  # invariant across chunks
        for layer_idx, start in enumerate(range(0, act.shape[0], chunk), start=1):
            layer_act = act[start:start + chunk]
            layer_act = layer_act.reshape(-1, layer_act.shape[-1])
            # Copy so the per-layer run_ID suffix does not leak into
            # self.options. Previously the options were aliased and the suffix
            # expression was discarded, so every layer shared one run_ID.
            layer_options = copy.copy(self.options)
            layer_options.run_ID = f'{self.options.run_ID}_L{layer_idx}'
            compute_ratemaps(flat_pos, layer_act, layer_options,
                             epoch=self.current_epoch)
        # TODO: specify range
        del pos, act
        return {
            'val_loss': avg_loss,
            'val_err': avg_err,
            'log': tensorboard_logs
        }

    def val_dataloader(self):
        # TODO: do a real train/val split
        dg = DataGenerator(self.options, self.pc, gpu=self.on_gpu, train=False)
        return dg

    def _l2_loss(self):
        # NOTE(review): requires a subclass attribute ``rnn`` with
        # ``weight_hh_l0``; not defined by this class.
        return self.rnn.weight_hh_l0.norm(2)
示例#11
0
 def setUp(self):
     """Per-test fixture: seed NumPy's RNG, then build a fresh ``PlaceCells``."""
     # Seed before construction so PlaceCells is built identically every test.
     np.random.seed(1)
     self.place_cells = PlaceCells()