# Generate the dataset, then load it as training graphs (g_list) and test
# graphs (test_glist).
output = Generate_dataset(GLOBAL_NUM_GRAPHS)
g_list, test_glist = load_graphs(output, GLOBAL_NUM_GRAPHS)

# Hyper-parameters for the graph classifier the environment wraps.
# NOTE(review): the classifier is constructed fresh here and no weights are
# loaded (a previous revision called load_base_model(label_map, g_list)
# instead) -- confirm that running against an untrained model is intended.
base_args = {
    'gm': 'mean_field',
    'feat_dim': 2,
    'latent_dim': 10,
    'out_dim': 20,
    'max_lv': 2,
    'hidden': 32,
}
base_classifier = GraphClassifier(num_classes=20, **base_args)
env = GraphEdgeEnv(base_classifier)

print("len g_list:", len(g_list))

if cmd_args.frac_meta > 0:
    # Pass the trailing `frac_meta` fraction of the test graphs to the agent
    # as its second graph list. The cut index is derived from
    # len(test_glist) -- the list being sliced -- rather than len(g_list);
    # mixing the two lengths produces an empty or mis-sized slice whenever
    # the training and test lists differ in size.
    meta_start = int(len(test_glist) * (1 - cmd_args.frac_meta))
    agent = Agent(g_list, test_glist[meta_start:], env)
else:
    agent = Agent(g_list, None, env)

if GLOBAL_PHASE == 'train':
    print("\n\nStarting Training Loop\n\n")
    agent.train()
loss = F.mse_loss(q_sa, list_target) optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description('exp: %.5f, loss: %0.5f' % (self.eps, loss)) log_out.close() if __name__ == '__main__': random.seed(cmd_args.seed) np.random.seed(cmd_args.seed) torch.manual_seed(cmd_args.seed) label_map, _, g_list = load_graphs() random.shuffle(g_list) base_classifier = load_base_model(label_map, g_list) env = GraphEdgeEnv(base_classifier, n_edges=1) if cmd_args.frac_meta > 0: num_train = int(len(g_list) * (1 - cmd_args.frac_meta)) agent = Agent(g_list[:num_train], g_list[num_train:], env) else: agent = Agent(g_list, None, env) agent.train() agent.net.load_state_dict( torch.load(cmd_args.save_dir + '/epoch-best.model')) agent.eval()