Пример #1
0
def fit_ln(cells, train_stimuli, exptdate, stim_shape, l2=1e-3, readme=None):
    """Build, compile and train a linear-nonlinear (LN) model with keras.

    Parameters
    ----------
    cells : sequence
        Cells to fit; its length is used as the number of model outputs.
    train_stimuli
        Training stimuli, forwarded unchanged to ``Experiment``.
    exptdate
        Experiment identifier, forwarded unchanged to ``Experiment``.
    stim_shape : sequence
        Stimulus shape handed to ``ln``. Its first entry is also passed to
        ``Experiment`` (presumably the temporal history length -- TODO confirm).
    l2 : float, optional
        Forwarded to ``ln`` as ``l2_reg``.
    readme : optional
        Forwarded unchanged to ``KerasMonitor``.

    Returns
    -------
    The compiled model object that was handed to ``train``.
    """
    n_outputs = len(cells)
    samples_per_batch = 5000
    held_out_stimuli = ['whitenoise', 'naturalscene']

    # Assemble the LN layers and compile them with RMSprop and a poisson loss.
    network = sequential(
        ln(stim_shape, n_outputs, weight_init='normal', l2_reg=l2),
        RMSprop(lr=1e-4),
        loss='poisson',
    )

    # Load the experiment data for the requested cells and stimuli.
    expt_data = Experiment(exptdate, cells, train_stimuli, held_out_stimuli,
                           stim_shape[0], samples_per_batch, nskip=6000)

    # The monitor tracks training progress under the 'ln' label.
    progress = KerasMonitor('ln', network, expt_data, readme, save_every=20)

    train(network, expt_data, progress, num_epochs=30)
    return network
Пример #2
0
def fit_convnet(cells, train_stimuli, exptdate, nclip=0, readme=None):
    """Build, compile and train a convnet on the given experiment.

    author: Niru Maheswaranathan

    Parameters
    ----------
    cells : sequence
        Cells to fit; its length is used as the number of model outputs.
    train_stimuli
        Training stimuli, forwarded unchanged to ``Experiment``.
    exptdate
        Experiment identifier, forwarded unchanged to ``Experiment``.
    nclip : int, optional
        Forwarded to ``Experiment`` as ``nskip``.
    readme : optional
        Forwarded unchanged to ``KerasMonitor``.

    Returns
    -------
    The compiled model object that was handed to ``train``.
    """
    input_shape = (40, 50, 50)
    n_outputs = len(cells)
    samples_per_batch = 5000
    held_out_stimuli = ['whitenoise', 'naturalscene']

    # Fixed architecture and regularisation settings for the convnet layers.
    conv_layers = convnet(
        input_shape,
        n_outputs,
        num_filters=(8, 16),
        filter_size=(15, 7),
        weight_init='normal',
        l2_reg_weights=(0.01, 0.01, 0.01),
        l1_reg_activity=(0.0, 0.0, 0.001),
        dropout=(0.1, 0.0),
    )

    # Compile into a keras model using adam and a poisson loss.
    network = sequential(conv_layers, 'adam', loss='poisson')

    # Load the experiment data for the requested cells and stimuli.
    expt_data = Experiment(
        exptdate,
        cells,
        train_stimuli,
        held_out_stimuli,
        input_shape[0],
        samples_per_batch,
        nskip=nclip,
    )

    # The monitor tracks training progress under the 'convnet' label.
    progress = KerasMonitor('convnet', network, expt_data, readme, save_every=20)

    train(network, expt_data, progress, num_epochs=50)
    return network
Пример #3
0
def fit_fc_rnn(expt, stim):
    """Train the ``fc_rnn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    hyperparams = {
        "model_args": ("flatten", "mse"),
        "lr": 1e-3,
        "nb_epochs": 250,
        "val_split": 0.05,
    }
    train(fc_rnn, expt, stim, **hyperparams)
Пример #4
0
def fit_bn_spat_cnn(expt, stim):
    """Train the ``bn_spat_cnn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    train(bn_spat_cnn,
          expt,
          stim,
          # One-element tuple: the trailing comma is required. Without it the
          # parentheses are mere grouping and a bare string would be passed,
          # unlike the tuples the other fit_* wrappers hand to ``train``.
          model_args=("spatial",),
          lr=1e-2,
          nb_epochs=250,
          val_split=0.05)
Пример #5
0
def fit_copy_cnn(expt, stim):
    """Train the ``copy_cnn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; the model
    receives an empty ``model_args`` tuple. Nothing is returned.
    """
    hyperparams = {
        "model_args": (),
        "lr": 1e-3,
        "nb_epochs": 250,
        "val_split": 0.05,
    }
    train(copy_cnn, expt, stim, **hyperparams)
Пример #6
0
def fit_bn_rnn(expt, stim):
    """Train the ``bn_rnn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    hyperparams = {
        "lr": 5e-3,
        "model_args": ("add_dim", "mse"),
        "nb_epochs": 250,
        "val_split": 0.05,
        "bz": 1024,
    }
    train(bn_rnn, expt, stim, **hyperparams)
Пример #7
0
def fit_cn_tcn(expt, stim):
    """Train the ``cn_tcn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    hyperparams = {
        "model_args": ("add_dim", "mse"),
        "lr": 1e-3,
        "nb_epochs": 250,
        "val_split": 0.05,
        "bz": 2048,
    }
    train(cn_tcn, expt, stim, **hyperparams)
Пример #8
0
def fit_conv_to_rnn(expt, stim):
    """Train the ``conv_to_rnn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    hyperparams = {
        "model_args": ("add_dim", "mse"),
        "lr": 1e-4,
        "nb_epochs": 250,
        "val_split": 0.05,
        "bz": 1024,
    }
    train(conv_to_rnn, expt, stim, **hyperparams)
Пример #9
0
def fit_conv_to_lstm(expt, stim):
    """Train the ``conv_to_lstm`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    train(conv_to_lstm,
          expt,
          stim,
          # One-element tuple: the trailing comma is required. Without it the
          # parentheses are mere grouping and a bare string would be passed,
          # unlike the tuples the other fit_* wrappers hand to ``train``.
          model_args=("add_dim",),
          lr=0.1,
          nb_epochs=250,
          val_split=0.05,
          bz=2048)
Пример #10
0
def fit_conv_lstm(expt, stim):
    """Train the ``conv_lstm`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; every other
    setting of the call is fixed here. Nothing is returned.
    """
    train(conv_lstm,
          expt,
          stim,
          # One-element tuple: the trailing comma is required. Without it the
          # parentheses are mere grouping and a bare string would be passed,
          # unlike the tuples the other fit_* wrappers hand to ``train``.
          model_args=("add_dim",),
          lr=1e-3,
          bz=128,
          nb_epochs=250,
          val_split=0.05)
Пример #11
0
def fit_ln(expt, ci, stim, activation, l2_reg=0.1):
    """Train a ``linear_nonlinear`` model for a single cell.

    Parameters
    ----------
    expt, stim
        Forwarded unchanged to ``train``.
    ci
        Cell identifier; passed to ``train`` as the one-element list
        ``cells=[ci]``.
    activation : str
        Bound onto ``linear_nonlinear`` as ``activation``. When it equals
        'rbf' (case-insensitive), ``model_args`` is ``(30, 6)``; otherwise
        ``model_args`` is empty.
    l2_reg : float, optional
        Bound onto ``linear_nonlinear`` as ``l2_reg``.

    Nothing is returned.
    """
    # Only the 'rbf' activation gets a non-empty model_args tuple.
    extra_args = (30, 6) if activation.lower() == 'rbf' else ()

    # Pre-bind the activation and regularisation so ``train`` receives a
    # ready-to-call model builder.
    ln_builder = functools.partial(
        linear_nonlinear,
        activation=activation,
        l2_reg=l2_reg,
    )

    # Disabled progress banner, kept for reference:
    #tp.banner(f'Training LN-{activation}, expt {args.expt}, {args.stim}, cell {ci+1:02d}')
    train(
        ln_builder,
        expt,
        stim,
        model_args=extra_args,
        lr=1e-2,
        nb_epochs=500,
        val_split=0.05,
        cells=[ci],
    )
Пример #12
0
def fit_bn_cnn(expt, stim):
    """Train the ``bn_cnn`` model on one experiment/stimulus pair.

    ``expt`` and ``stim`` are forwarded unchanged to ``train``; no
    ``model_args`` are supplied. Nothing is returned.
    """
    hyperparams = {
        "lr": 1e-2,
        "nb_epochs": 250,
        "val_split": 0.05,
    }
    train(bn_cnn, expt, stim, **hyperparams)