def build_model(input_var, ExptDict):
    """Build the recurrent network selected by the experiment dictionary.

    Parameters
    ----------
    input_var :
        Symbolic input variable passed straight through to the model builder.
    ExptDict : dict
        Experiment configuration. Reads:
          - ExptDict["model"]: "model_id" plus model-specific hyperparameters
            (diag_val/offdiag_val, gamma, or init_val depending on the model),
          - ExptDict["task"]: "n_loc", "n_out", "out_nonlin",
          - ExptDict["batch_size"], ExptDict["n_in"], ExptDict["n_hid"].

    Returns
    -------
    (l_out, l_rec) :
        Output layer and recurrent layer returned by the chosen builder
        in ``models``.

    Raises
    ------
    ValueError
        If ``ExptDict["model"]["model_id"]`` is not a supported model.
    """
    # Unpack necessary variables
    model = ExptDict["model"]["model_id"]
    n_loc = ExptDict["task"]["n_loc"]
    n_out = ExptDict["task"]["n_out"]
    batch_size = ExptDict["batch_size"]
    n_in = ExptDict["n_in"]
    n_hid = ExptDict["n_hid"]
    out_nonlin = ExptDict["task"]["out_nonlin"]

    if model == 'LeInitRecurrent':
        diag_val = ExptDict["model"]["diag_val"]
        offdiag_val = ExptDict["model"]["offdiag_val"]
        l_out, l_rec = models.LeInitRecurrent(input_var,
                                              batch_size=batch_size,
                                              n_in=n_loc * n_in,
                                              n_out=n_out,
                                              n_hid=n_hid,
                                              diag_val=diag_val,
                                              offdiag_val=offdiag_val,
                                              out_nlin=out_nonlin)
    elif model == 'LeInitRecurrentWithFastWeights':
        diag_val = ExptDict["model"]["diag_val"]
        offdiag_val = ExptDict["model"]["offdiag_val"]
        gamma = ExptDict["model"]["gamma"]
        l_out, l_rec = models.LeInitRecurrentWithFastWeights(
            input_var,
            batch_size=batch_size,
            n_in=n_loc * n_in,
            n_out=n_out,
            n_hid=n_hid,
            diag_val=diag_val,
            offdiag_val=offdiag_val,
            out_nlin=out_nonlin,
            gamma=gamma)
    elif model == 'OrthoInitRecurrent':
        init_val = ExptDict["model"]["init_val"]
        l_out, l_rec = models.OrthoInitRecurrent(input_var,
                                                 batch_size=batch_size,
                                                 n_in=n_loc * n_in,
                                                 n_out=n_out,
                                                 n_hid=n_hid,
                                                 init_val=init_val,
                                                 out_nlin=out_nonlin)
    elif model == 'GRURecurrent':
        diag_val = ExptDict["model"]["diag_val"]
        offdiag_val = ExptDict["model"]["offdiag_val"]
        l_out, l_rec = models.GRURecurrent(input_var,
                                           batch_size=batch_size,
                                           n_in=n_loc * n_in,
                                           n_out=n_out,
                                           n_hid=n_hid,
                                           diag_val=diag_val,
                                           offdiag_val=offdiag_val,
                                           out_nlin=out_nonlin)
    else:
        # Previously an unrecognized model_id fell through every branch and
        # raised an opaque NameError at the return; fail fast instead.
        raise ValueError("Unknown model_id: %r" % (model,))

    return l_out, l_rec
def build_model(input_var, ExptDict):
    """Build the recurrent network selected by the experiment dictionary.

    NOTE(review): this redefines ``build_model`` from earlier in the file
    (this variant uses input width ``(n_loc + 1) * n_in`` and offers
    ResidualRecurrent instead of LeInitRecurrentWithFastWeights) — confirm
    which definition is intended to be live, since the later one shadows
    the earlier one at import time.

    Parameters
    ----------
    input_var :
        Symbolic input variable passed straight through to the model builder.
    ExptDict : dict
        Experiment configuration. Reads:
          - ExptDict["model"]: "model_id" plus model-specific hyperparameters
            (diag_val/offdiag_val, init_val, or leak_inp/leak_hid),
          - ExptDict["task"]: "n_loc", "n_out", "out_nonlin",
          - ExptDict["batch_size"], ExptDict["n_in"], ExptDict["n_hid"].

    Returns
    -------
    (l_out, l_rec) :
        Output layer and recurrent layer returned by the chosen builder
        in ``models``.

    Raises
    ------
    ValueError
        If ``ExptDict["model"]["model_id"]`` is not a supported model.
    """
    # Unpack necessary variables
    model = ExptDict["model"]["model_id"]
    n_loc = ExptDict["task"]["n_loc"]
    n_out = ExptDict["task"]["n_out"]
    batch_size = ExptDict["batch_size"]
    n_in = ExptDict["n_in"]
    n_hid = ExptDict["n_hid"]
    out_nonlin = ExptDict["task"]["out_nonlin"]

    if model == 'LeInitRecurrent':
        diag_val = ExptDict["model"]["diag_val"]
        offdiag_val = ExptDict["model"]["offdiag_val"]
        l_out, l_rec = models.LeInitRecurrent(input_var,
                                              batch_size=batch_size,
                                              n_in=(n_loc + 1) * n_in,
                                              n_out=n_out,
                                              n_hid=n_hid,
                                              diag_val=diag_val,
                                              offdiag_val=offdiag_val,
                                              out_nlin=out_nonlin)
    elif model == 'OrthoInitRecurrent':
        init_val = ExptDict["model"]["init_val"]
        l_out, l_rec = models.OrthoInitRecurrent(input_var,
                                                 batch_size=batch_size,
                                                 n_in=(n_loc + 1) * n_in,
                                                 n_out=n_out,
                                                 n_hid=n_hid,
                                                 init_val=init_val,
                                                 out_nlin=out_nonlin)
    elif model == 'ResidualRecurrent':
        leak_inp = ExptDict["model"]["leak_inp"]
        leak_hid = ExptDict["model"]["leak_hid"]
        l_out, l_rec = models.ResidualRecurrent(input_var,
                                                batch_size=batch_size,
                                                n_in=(n_loc + 1) * n_in,
                                                n_out=n_out,
                                                n_hid=n_hid,
                                                leak_inp=leak_inp,
                                                leak_hid=leak_hid,
                                                out_nlin=out_nonlin)
    elif model == 'GRURecurrent':
        diag_val = ExptDict["model"]["diag_val"]
        offdiag_val = ExptDict["model"]["offdiag_val"]
        l_out, l_rec = models.GRURecurrent(input_var,
                                           batch_size=batch_size,
                                           n_in=(n_loc + 1) * n_in,
                                           n_out=n_out,
                                           n_hid=n_hid,
                                           diag_val=diag_val,
                                           offdiag_val=offdiag_val,
                                           out_nlin=out_nonlin)
    else:
        # Previously an unrecognized model_id fell through every branch and
        # raised an opaque NameError at the return; fail fast instead.
        raise ValueError("Unknown model_id: %r" % (model,))

    return l_out, l_rec
generator, test_generator = build_generators(t_ind) # Define the input and expected output variable input_var = None mask_var = None if torch.cuda.is_available(): device='cuda:0' else: device='cpu' if model == 'LeInitRecurrent': model = models.LeInitRecurrent(input_var, mask_var=mask_var, batch_size=generator.batch_size, n_in=generator.n_in, n_out=generator.n_out, n_hid=n_hid, diag_val=diag_val, offdiag_val=offdiag_val, out_nlin='sigmoid') elif model == 'GRURecurrent': model = models.GRURecurrent(input_var, mask_var=mask_var, batch_size=generator.batch_size, n_in=generator.n_in, n_out=generator.n_out, n_hid=n_hid) # Build the model model.train() optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=wdecay_coeff) def l2_activation_regularization(activations): loss = 0 for i in range(5): loss = loss + 1e-4 * torch.pow(activations[int(-1*i)], 2).mean() return loss # TRAINING s_vec, opt_vec, net_vec, frac_rmse_vec = [], [], [], [] for i, (_, example_input, example_output, example_mask, s, opt_s) in generator: