def fit_ln(cells, train_stimuli, exptdate, stim_shape, l2=1e-3, readme=None,
           nskip=6000):
    """Fit an LN (linear-nonlinear) model using keras.

    NOTE(review): a second, differently-signatured `fit_ln` is defined later
    in this file and shadows this one at import time — rename one of them.

    Parameters
    ----------
    cells : sequence
        Indices of the cells to fit; the model gets one output per cell.
    train_stimuli : list
        Stimulus names used for training.
    exptdate : str
        Experiment date identifier passed to `Experiment`.
    stim_shape : tuple
        Stimulus dimensions; `stim_shape[0]` is the temporal history length.
    l2 : float, optional
        L2 regularization strength on the LN weights (default 1e-3).
    readme : optional
        Experiment log handle forwarded to the monitor.
    nskip : int, optional
        Number of initial samples to skip (default 6000, the previously
        hard-coded value; exposed as a parameter for consistency with
        `fit_convnet`'s `nclip`).

    Returns
    -------
    model : the trained keras model
    """
    ncells = len(cells)
    batchsize = 5000

    # get the layers
    layers = ln(stim_shape, ncells, weight_init='normal', l2_reg=l2)

    # compile it
    model = sequential(layers, RMSprop(lr=1e-4), loss='poisson')

    # load experiment data
    test_stimuli = ['whitenoise', 'naturalscene']
    data = Experiment(exptdate, cells, train_stimuli, test_stimuli,
                      stim_shape[0], batchsize, nskip=nskip)

    # create a monitor that keeps track of progress
    monitor = KerasMonitor('ln', model, data, readme, save_every=20)

    # train
    train(model, data, monitor, num_epochs=30)

    return model
def fit_convnet(cells, train_stimuli, exptdate, nclip=0, readme=None):
    """Fit a convolutional network model to the given cells.

    author: Niru Maheswaranathan

    Parameters
    ----------
    cells : sequence
        Cell indices; the network has one output unit per cell.
    train_stimuli : list
        Stimulus names used for training.
    exptdate : str
        Experiment date identifier passed to `Experiment`.
    nclip : int, optional
        Number of initial samples to skip (default 0).
    readme : optional
        Experiment log handle forwarded to the monitor.

    Returns
    -------
    model : the trained keras model
    """
    stim_shape = (40, 50, 50)
    batchsize = 5000
    test_stimuli = ['whitenoise', 'naturalscene']

    # build the convnet layer stack
    layers = convnet(stim_shape, len(cells),
                     num_filters=(8, 16),
                     filter_size=(15, 7),
                     weight_init='normal',
                     l2_reg_weights=(0.01, 0.01, 0.01),
                     l1_reg_activity=(0.0, 0.0, 0.001),
                     dropout=(0.1, 0.0))

    # compile with a Poisson loss (spike-count likelihood)
    model = sequential(layers, 'adam', loss='poisson')

    # assemble the experiment data
    data = Experiment(exptdate, cells, train_stimuli, test_stimuli,
                      stim_shape[0], batchsize, nskip=nclip)

    # monitor saves progress every 20 iterations
    monitor = KerasMonitor('convnet', model, data, readme, save_every=20)

    train(model, data, monitor, num_epochs=50)
    return model
def fit_fc_rnn(expt, stim):
    """Train the fully-connected RNN model on one experiment/stimulus pair."""
    train(
        fc_rnn,
        expt,
        stim,
        model_args=("flatten", "mse"),
        lr=1e-3,
        nb_epochs=250,
        val_split=0.05,
    )
def fit_bn_spat_cnn(expt, stim):
    """Train the batchnorm spatial CNN on one experiment/stimulus pair.

    Bug fix: the original passed ``model_args=("spatial")``, which is just the
    string ``"spatial"`` — parentheses alone do not make a tuple.  Every
    sibling ``fit_*`` helper passes ``model_args`` as a tuple, so the missing
    trailing comma is restored here.
    """
    train(bn_spat_cnn, expt, stim, model_args=("spatial",),
          lr=1e-2, nb_epochs=250, val_split=0.05)
def fit_copy_cnn(expt, stim):
    """Train the copy-CNN model on one experiment/stimulus pair."""
    train(
        copy_cnn,
        expt,
        stim,
        model_args=(),
        lr=1e-3,
        nb_epochs=250,
        val_split=0.05,
    )
def fit_bn_rnn(expt, stim):
    """Train the batchnorm RNN model on one experiment/stimulus pair."""
    train(
        bn_rnn,
        expt,
        stim,
        model_args=("add_dim", "mse"),
        lr=5e-3,
        nb_epochs=250,
        val_split=0.05,
        bz=1024,
    )
def fit_cn_tcn(expt, stim):
    """Train the CN-TCN model on one experiment/stimulus pair."""
    train(
        cn_tcn,
        expt,
        stim,
        model_args=("add_dim", "mse"),
        lr=1e-3,
        nb_epochs=250,
        val_split=0.05,
        bz=2048,
    )
def fit_conv_to_rnn(expt, stim):
    """Train the conv-to-RNN model on one experiment/stimulus pair."""
    train(
        conv_to_rnn,
        expt,
        stim,
        model_args=("add_dim", "mse"),
        lr=1e-4,
        nb_epochs=250,
        val_split=0.05,
        bz=1024,
    )
def fit_conv_to_lstm(expt, stim):
    """Train the conv-to-LSTM model on one experiment/stimulus pair.

    Bug fix: the original passed ``model_args=("add_dim")``, which is the bare
    string ``"add_dim"`` rather than a 1-tuple (a trailing comma is required).
    The sibling ``fit_*`` helpers all pass tuples, so the comma is restored.
    """
    # NOTE(review): lr=0.1 is much larger than the 1e-2..1e-4 used by the
    # other fitters — presumably intentional, but worth confirming.
    train(conv_to_lstm, expt, stim, model_args=("add_dim",),
          lr=0.1, nb_epochs=250, val_split=0.05, bz=2048)
def fit_conv_lstm(expt, stim):
    """Train the conv-LSTM model on one experiment/stimulus pair.

    Bug fix: the original passed ``model_args=("add_dim")``, which is the bare
    string ``"add_dim"`` rather than a 1-tuple (a trailing comma is required).
    The sibling ``fit_*`` helpers all pass tuples, so the comma is restored.
    """
    train(conv_lstm, expt, stim, model_args=("add_dim",),
          lr=1e-3, bz=128, nb_epochs=250, val_split=0.05)
def fit_ln(expt, ci, stim, activation, l2_reg=0.1):
    """Fit a single-cell linear-nonlinear model with the given activation.

    NOTE(review): this shadows the earlier `fit_ln` definition in this file
    (different signature) — one of the two should be renamed.

    Parameters
    ----------
    expt : experiment identifier
    ci : int
        Index of the single cell to fit (trained via ``cells=[ci]``).
    activation : str
        Nonlinearity name; ``'rbf'`` (case-insensitive) gets extra
        constructor arguments ``(30, 6)`` — presumably basis count and
        width, verify against `linear_nonlinear`.
    l2_reg : float, optional
        L2 regularization strength (default 0.1).
    """
    model_args = (30, 6) if activation.lower() == 'rbf' else ()
    build_model = functools.partial(linear_nonlinear,
                                    activation=activation,
                                    l2_reg=l2_reg)
    train(build_model, expt, stim,
          model_args=model_args,
          lr=1e-2,
          nb_epochs=500,
          val_split=0.05,
          cells=[ci])
def fit_bn_cnn(expt, stim):
    """Train the batchnorm CNN model on one experiment/stimulus pair."""
    train(
        bn_cnn,
        expt,
        stim,
        lr=1e-2,
        nb_epochs=250,
        val_split=0.05,
    )