import logging
import os
import time

import numpy as np
import torch
import torch.nn as nn
import yaml
from easydict import EasyDict as edict

# Project-local helpers (LangevinMCSampler, gen_list_of_data, loss_func,
# edict2dict, eval_sample_batch, plot_graphs_adj, eval_torch_batch) are
# imported from this repo's own utility modules.


def snapshot(model, optimizer, config, step, gpus=[0], tag=None, scheduler=None):
    # Bundle everything needed to resume training; include the scheduler
    # state only when a scheduler is in use.
    if scheduler is not None:
        model_snapshot = {"model": model.state_dict(),
                          "optimizer": optimizer.state_dict(),
                          "scheduler": scheduler.state_dict(),
                          "step": step}
    else:
        model_snapshot = {"model": model.state_dict(),
                          "optimizer": optimizer.state_dict(),
                          "step": step}

    torch.save(model_snapshot,
               os.path.join(config.save_dir,
                            "model_snapshot_{}.pth".format(tag) if tag is not None
                            else "model_snapshot_{:07d}.pth".format(step)))

    # Update the config file's test path so evaluation picks up this snapshot.
    # PyYAML >= 5.1 requires an explicit Loader.
    save_name = os.path.join(config.save_dir, 'config.yaml')
    config_save = edict(yaml.load(open(save_name, 'r'), Loader=yaml.FullLoader))
    config_save.test.test_model_dir = config.save_dir
    config_save.test.test_model_name = "model_snapshot_{}.pth".format(tag) \
        if tag is not None else "model_snapshot_{:07d}.pth".format(step)
    yaml.dump(edict2dict(config_save), open(save_name, 'w'),
              default_flow_style=False)
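
# ---------------------------------------------------------------------------
# Usage sketch for snapshot() (illustrative, not part of the training
# pipeline): the tiny linear model and the edict contents below are
# assumptions, and config.save_dir must already contain a config.yaml with a
# `test` section, since snapshot() rewrites its test_model_* fields.
# ---------------------------------------------------------------------------
def _snapshot_usage_example():
    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    config = edict({"save_dir": "exp/demo"})  # hypothetical run directory
    # Writes exp/demo/model_snapshot_best.pth and points config.yaml's
    # test.test_model_dir / test.test_model_name fields at it.
    snapshot(model, optimizer, config, step=1000, tag="best")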
def fit(model, optimizer, mcmc_sampler, train_dl, max_node_number, max_epoch=20,
        config=None, save_interval=50, sample_interval=1, sigma_list=None,
        sample_from_sigma_delta=0.0, test_dl=None):
    logging.info(f"{sigma_list}, {sample_from_sigma_delta}")
    assert isinstance(mcmc_sampler, LangevinMCSampler)

    optimizer.zero_grad()
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    # `lr_dacey` (sic) mirrors the key used in the config files.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=config.train.lr_dacey)
    for epoch in range(max_epoch):
        train_losses = []
        train_loss_items = []
        test_losses = []
        test_loss_items = []
        t_start = time.time()

        model.train()
        for train_adj_b, train_x_b in train_dl:
            # train_adj_b: [batch_size, N, N]; train_x_b: [batch_size, N, F_i]
            train_adj_b = train_adj_b.to(config.dev)
            train_x_b = train_x_b.to(config.dev)
            # A node is "real" iff its adjacency row has any mass.
            train_node_flag_b = train_adj_b.sum(-1).gt(1e-5).to(
                dtype=torch.float32)
            if isinstance(sigma_list, float):
                sigma_list = [sigma_list]
            train_x_b, train_noise_adj_b, \
                train_node_flag_b, grad_log_q_noise_list = \
                gen_list_of_data(train_x_b, train_adj_b,
                                 train_node_flag_b, sigma_list)
            # Thereafter:
            # train_noise_adj_b: [len(sigma_list) * batch_size, N, N]
            # train_x_b: [len(sigma_list) * batch_size, N, F_i]
            optimizer.zero_grad()
            score = model(x=train_x_b, adjs=train_noise_adj_b,
                          node_flags=train_node_flag_b)
            loss, loss_items = loss_func(score.chunk(len(sigma_list), dim=0),
                                         grad_log_q_noise_list, sigma_list)
            train_loss_items.append(loss_items)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.detach().cpu().item())
        scheduler.step()  # decay the learning rate once per epoch

        assert isinstance(model, nn.Module)
        model.eval()
        for test_adj_b, test_x_b in test_dl:
            test_adj_b = test_adj_b.to(config.dev)
            test_x_b = test_x_b.to(config.dev)
            test_node_flag_b = test_adj_b.sum(-1).gt(1e-5).to(
                dtype=torch.float32)
            test_x_b, test_noise_adj_b, test_node_flag_b, grad_log_q_noise_list = \
                gen_list_of_data(test_x_b, test_adj_b,
                                 test_node_flag_b, sigma_list)
            with torch.no_grad():
                score = model(x=test_x_b, adjs=test_noise_adj_b,
                              node_flags=test_node_flag_b)
                loss, loss_items = loss_func(score.chunk(len(sigma_list), dim=0),
                                             grad_log_q_noise_list, sigma_list)
                test_loss_items.append(loss_items)
                test_losses.append(loss.detach().cpu().item())

        mean_train_loss = np.mean(train_losses)
        mean_test_loss = np.mean(test_losses)
        mean_train_loss_item = np.mean(train_loss_items, axis=0)
        mean_train_loss_item_str = np.array2string(mean_train_loss_item,
                                                   precision=2, separator="\t",
                                                   prefix="\t")
        mean_test_loss_item = np.mean(test_loss_items, axis=0)
        mean_test_loss_item_str = np.array2string(mean_test_loss_item,
                                                  precision=2, separator="\t",
                                                  prefix="\t")

        logging.info(f'epoch: {epoch:03d}| time: {time.time() - t_start:.1f}s| '
                     f'train loss: {mean_train_loss:+.3e} | '
                     f'test loss: {mean_test_loss:+.3e} | ')
        logging.info(f'epoch: {epoch:03d}| '
                     f'train loss i: {mean_train_loss_item_str} '
                     f'test loss i: {mean_test_loss_item_str} | ')

        if epoch % save_interval == save_interval - 1:
            to_save = {
                'model': model.state_dict(),
                'sigma_list': sigma_list,
                'config': edict2dict(config),
                'epoch': epoch,
                'train_loss': mean_train_loss,
                'test_loss': mean_test_loss,
                'train_loss_item': mean_train_loss_item,
                'test_loss_item': mean_test_loss_item,
            }
            torch.save(to_save, os.path.join(
                config.model_save_dir,
                f"{config.dataset.name}_{sigma_list}.pth"))
            # torch.save(to_save, os.path.join(config.save_dir, "model.pth"))

        if epoch % sample_interval == sample_interval - 1:
            model.eval()
            test_adj_b, test_x_b = next(iter(test_dl))
            test_adj_b = test_adj_b.to(config.dev)
            test_x_b = test_x_b.to(config.dev)
            if isinstance(config.mcmc.grad_step_size, (list, tuple)):
                grad_step_size = config.mcmc.grad_step_size[0]
            else:
                grad_step_size = config.mcmc.grad_step_size
            # One Langevin step size per noise level, scaled by sigma^2.
            step_size = grad_step_size * \
                torch.tensor(sigma_list).to(test_x_b) \
                .repeat_interleave(test_adj_b.size(0), dim=0)[..., None, None] ** 2
            test_node_flag_b = test_adj_b.sum(-1).gt(1e-5).to(
                dtype=torch.float32)
            test_x_b, test_noise_adj_b, test_node_flag_b, grad_log_q_noise_list = \
                gen_list_of_data(test_x_b, test_adj_b,
                                 test_node_flag_b, sigma_list)
            init_adjs = test_noise_adj_b
            with torch.no_grad():
                sample_b, _ = mcmc_sampler.sample(
                    config.sample.batch_size,
                    lambda x, y: model(test_x_b, x, y),
                    max_node_num=max_node_number,
                    step_num=None,
                    init_adjs=init_adjs,
                    init_flags=test_node_flag_b,
                    is_final=True,
                    step_size=step_size)
            sample_b_list = sample_b.chunk(len(sigma_list), dim=0)
            init_adjs_list = init_adjs.chunk(len(sigma_list), dim=0)
            for sigma, sample_b, init_adjs in zip(sigma_list, sample_b_list,
                                                  init_adjs_list):
                sample_from_sigma = sigma + sample_from_sigma_delta
                eval_sample_batch(
                    sample_b,
                    mcmc_sampler.end_sample(test_adj_b, to_int=True)[0],
                    init_adjs, config.save_dir,
                    title=f'epoch_{epoch}_{sample_from_sigma}.pdf')
                if init_adjs is not None:
                    plot_graphs_adj(
                        mcmc_sampler.end_sample(init_adjs, to_int=True)[0],
                        node_num=test_node_flag_b.sum(-1).cpu().numpy(),
                        title=f'epoch_{epoch}_{sample_from_sigma}_init.pdf',
                        save_dir=config.save_dir)
                result_dict = eval_torch_batch(
                    mcmc_sampler.end_sample(test_adj_b, to_int=True)[0],
                    sample_b, methods=None)
                logging.info(f'MMD {epoch} {sample_from_sigma}: {result_dict}')
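
# ---------------------------------------------------------------------------
# For reference: the score targets consumed by loss_func above come from the
# Gaussian noising kernel q(A_noisy | A) = N(A, sigma^2 I), whose score is
# grad log q = -(A_noisy - A) / sigma**2. The helper below is a plausible
# single-sigma reading of what gen_list_of_data produces per noise level; it
# is a sketch under that assumption, not this repo's exact implementation
# (which also handles node flags and keeps the adjacency noise symmetric).
# ---------------------------------------------------------------------------
def _perturb_adjs_sketch(adjs: torch.Tensor, sigma: float):
    noise = torch.randn_like(adjs) * sigma
    noisy_adjs = adjs + noise
    grad_log_q_noise = -noise / sigma ** 2  # score of the noising kernel
    return noisy_adjs, grad_log_q_noise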