def build_model(input_var, ExptDict):
    """Construct the recurrent network named by ``ExptDict["model"]["model_id"]``.

    Reads ``batch_size``, ``n_in`` and ``n_hid`` from the top level of
    ``ExptDict``, ``n_loc``, ``n_out`` and ``out_nonlin`` from
    ``ExptDict["task"]``, and the model-specific hyperparameters from
    ``ExptDict["model"]``, then delegates to the matching constructor in the
    ``models`` module with an input width of ``n_loc * n_in``.

    NOTE(review): SOURCE also contains a later ``build_model`` that sizes the
    input as ``(n_loc + 1) * n_in``; if both live in the same module the later
    one replaces this one -- confirm which definition is intended.

    Args:
        input_var: Passed unchanged as the first positional argument of the
            selected ``models`` constructor (presumably a symbolic input
            variable -- TODO confirm).
        ExptDict: Nested experiment configuration dict.

    Returns:
        The ``(l_out, l_rec)`` pair returned by the selected ``models``
        constructor (presumably the output and recurrent layers -- confirm).

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If ``model_id`` is not one of the supported models.
    """
    # Unpack necessary variables.
    model_cfg = ExptDict["model"]
    model = model_cfg["model_id"]
    n_loc = ExptDict["task"]["n_loc"]
    n_out = ExptDict["task"]["n_out"]
    batch_size = ExptDict["batch_size"]
    n_in = ExptDict["n_in"]
    n_hid = ExptDict["n_hid"]
    out_nonlin = ExptDict["task"]["out_nonlin"]

    # Keyword arguments shared by every supported constructor.
    common = {
        "batch_size": batch_size,
        "n_in": n_loc * n_in,
        "n_out": n_out,
        "n_hid": n_hid,
        "out_nlin": out_nonlin,
    }

    if model == 'LeInitRecurrent':
        l_out, l_rec = models.LeInitRecurrent(
            input_var,
            diag_val=model_cfg["diag_val"],
            offdiag_val=model_cfg["offdiag_val"],
            **common)
    elif model == 'LeInitRecurrentWithFastWeights':
        l_out, l_rec = models.LeInitRecurrentWithFastWeights(
            input_var,
            diag_val=model_cfg["diag_val"],
            offdiag_val=model_cfg["offdiag_val"],
            gamma=model_cfg["gamma"],
            **common)
    elif model == 'OrthoInitRecurrent':
        l_out, l_rec = models.OrthoInitRecurrent(
            input_var,
            init_val=model_cfg["init_val"],
            **common)
    elif model == 'GRURecurrent':
        l_out, l_rec = models.GRURecurrent(
            input_var,
            diag_val=model_cfg["diag_val"],
            offdiag_val=model_cfg["offdiag_val"],
            **common)
    else:
        # Without this branch an unrecognised id leaves l_out/l_rec unbound
        # and the return below fails with an unhelpful UnboundLocalError.
        raise ValueError(
            "Unknown model_id {!r}; expected one of 'LeInitRecurrent', "
            "'LeInitRecurrentWithFastWeights', 'OrthoInitRecurrent', "
            "'GRURecurrent'".format(model))
    return l_out, l_rec
def build_model(input_var, ExptDict):
    """Construct the recurrent network named by ``ExptDict["model"]["model_id"]``.

    Reads ``batch_size``, ``n_in`` and ``n_hid`` from the top level of
    ``ExptDict``, ``n_loc``, ``n_out`` and ``out_nonlin`` from
    ``ExptDict["task"]``, and the model-specific hyperparameters from
    ``ExptDict["model"]``, then delegates to the matching constructor in the
    ``models`` module with an input width of ``(n_loc + 1) * n_in``.

    NOTE(review): SOURCE also contains an earlier ``build_model`` that sizes
    the input as ``n_loc * n_in``; if both live in the same module this
    definition replaces that one -- confirm which is intended.

    Args:
        input_var: Passed unchanged as the first positional argument of the
            selected ``models`` constructor (presumably a symbolic input
            variable -- TODO confirm).
        ExptDict: Nested experiment configuration dict.

    Returns:
        The ``(l_out, l_rec)`` pair returned by the selected ``models``
        constructor (presumably the output and recurrent layers -- confirm).

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If ``model_id`` is not one of the supported models.
    """
    # Unpack necessary variables.
    model_cfg = ExptDict["model"]
    model = model_cfg["model_id"]
    n_loc = ExptDict["task"]["n_loc"]
    n_out = ExptDict["task"]["n_out"]
    batch_size = ExptDict["batch_size"]
    n_in = ExptDict["n_in"]
    n_hid = ExptDict["n_hid"]
    out_nonlin = ExptDict["task"]["out_nonlin"]

    # Keyword arguments shared by every supported constructor. The input has
    # one more group of n_in units than there are locations (presumably an
    # extra cue/context channel -- TODO confirm against the task code).
    common = {
        "batch_size": batch_size,
        "n_in": (n_loc + 1) * n_in,
        "n_out": n_out,
        "n_hid": n_hid,
        "out_nlin": out_nonlin,
    }

    if model == 'LeInitRecurrent':
        l_out, l_rec = models.LeInitRecurrent(
            input_var,
            diag_val=model_cfg["diag_val"],
            offdiag_val=model_cfg["offdiag_val"],
            **common)
    elif model == 'OrthoInitRecurrent':
        l_out, l_rec = models.OrthoInitRecurrent(
            input_var,
            init_val=model_cfg["init_val"],
            **common)
    elif model == 'ResidualRecurrent':
        l_out, l_rec = models.ResidualRecurrent(
            input_var,
            leak_inp=model_cfg["leak_inp"],
            leak_hid=model_cfg["leak_hid"],
            **common)
    elif model == 'GRURecurrent':
        l_out, l_rec = models.GRURecurrent(
            input_var,
            diag_val=model_cfg["diag_val"],
            offdiag_val=model_cfg["offdiag_val"],
            **common)
    else:
        # Without this branch an unrecognised id leaves l_out/l_rec unbound
        # and the return below fails with an unhelpful UnboundLocalError.
        raise ValueError(
            "Unknown model_id {!r}; expected one of 'LeInitRecurrent', "
            "'OrthoInitRecurrent', 'ResidualRecurrent', "
            "'GRURecurrent'".format(model))
    return l_out, l_rec