def train(epochs, batchsize, interval, c_path, s_path, modeldir):
    # Dataset definition
    dataset = CRDataset(c_path, s_path)
    collator = CollateFn()

    # Model definition
    generator = CartoonRenderer()
    generator.cuda()
    generator.train()
    gen_opt = torch.optim.Adam(generator.parameters(), lr=0.0001)

    discriminator = Discriminator()
    discriminator.cuda()
    discriminator.train()
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                drop_last=True,
                                collate_fn=collator)
        dataloader = tqdm(dataloader)

        for i, data in enumerate(dataloader):
            iterations += 1
            c, s = data

            y, _, _, _ = generator(c, s)
            dis_loss = adversarial_loss_dis(discriminator, y, s)

            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            y, c_feat, sa_list, y_feat = generator(c, s)
            y_c, _, _, _ = generator(c, c)
            y_s, _, _, _ = generator(s, s)

            gen_loss = adversarial_loss_gen(discriminator, y)
            gen_loss += reconstruction_loss(y_c, c)
            gen_loss += reconstruction_loss(y_s, s)
            gen_loss += content_loss(sa_list, y_feat)
            gen_loss += style_loss(c_feat, y_feat)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            if iterations % interval == 1:
                torch.save(generator.state_dict(),
                           f"{modeldir}/model_{iterations}.pt")

            print(f"iter: {iterations} dis loss: {dis_loss.data} gen loss: {gen_loss.data}")
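

# A minimal sketch of the loss helpers the loop above assumes
# (adversarial_loss_dis, adversarial_loss_gen, reconstruction_loss).
# Hinge adversarial terms plus an L1 reconstruction term are one plausible
# choice; the repository's own definitions may differ.
import torch
import torch.nn.functional as F


def adversarial_loss_dis(discriminator, fake, real):
    # Hinge loss for the discriminator: real scores pushed above +1,
    # fake scores below -1. Fakes are detached so only D receives gradients.
    real_out = discriminator(real)
    fake_out = discriminator(fake.detach())
    return F.relu(1.0 - real_out).mean() + F.relu(1.0 + fake_out).mean()


def adversarial_loss_gen(discriminator, fake):
    # Generator maximizes the discriminator score on generated images.
    return -discriminator(fake).mean()


def reconstruction_loss(output, target):
    # Plain L1 reconstruction term.
    return torch.mean(torch.abs(output - target))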
def train(epochs, batchsize, interval, c_path, s_path):
    # Dataset definition
    dataset = HairDataset(c_path, s_path)
    collator = CollateFn()

    # Model & Optimizer Definition
    munit = MUNIT()
    munit.cuda()
    munit.train()
    m_opt = torch.optim.Adam(munit.parameters(),
                             lr=0.0001,
                             betas=(0.5, 0.999),
                             weight_decay=0.0001)

    discriminator_a = Discriminator()
    discriminator_a.cuda()
    discriminator_a.train()
    da_opt = torch.optim.Adam(discriminator_a.parameters(),
                              lr=0.0001,
                              betas=(0.5, 0.999),
                              weight_decay=0.0001)

    discriminator_b = Discriminator()
    discriminator_b.cuda()
    discriminator_b.train()
    db_opt = torch.optim.Adam(discriminator_b.parameters(),
                              lr=0.0001,
                              betas=(0.5, 0.999),
                              weight_decay=0.0001)

    vgg = Vgg19Norm()
    vgg.cuda()
    vgg.train()

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                drop_last=True,
                                collate_fn=collator)
        dataloader = tqdm(dataloader)

        for i, data in enumerate(dataloader):
            iterations += 1
            a, b = data

            # Discriminator update
            _, _, _, _, _, _, ba, ab, _, _, _, _, _, _ = munit(a, b)

            loss = adversarial_dis_loss(discriminator_a, ba, a)
            loss += adversarial_dis_loss(discriminator_b, ab, b)

            da_opt.zero_grad()
            db_opt.zero_grad()
            loss.backward()
            da_opt.step()
            db_opt.step()

            # Generator update
            c_a, s_a, c_b, s_b, a_recon, b_recon, ba, ab, c_b_recon, \
                s_a_recon, c_a_recon, s_b_recon, aba, bab = munit(a, b)

            loss = adversarial_gen_loss(discriminator_a, ba)
            loss += adversarial_gen_loss(discriminator_b, ab)
            loss += 10 * reconstruction_loss(a_recon, a)
            loss += 10 * reconstruction_loss(b_recon, b)
            loss += reconstruction_loss(c_a, c_a_recon)
            loss += reconstruction_loss(c_b, c_b_recon)
            loss += reconstruction_loss(s_a, s_a_recon)
            loss += reconstruction_loss(s_b, s_b_recon)
            loss += 10 * reconstruction_loss(aba, a)
            loss += 10 * reconstruction_loss(bab, b)
            loss += perceptual_loss(vgg, ba, b)
            loss += perceptual_loss(vgg, ab, a)

            m_opt.zero_grad()
            loss.backward()
            m_opt.step()

            if iterations % interval == 1:
                # Save the current weights (state_dict()), not the bound method.
                torch.save(munit.state_dict(),
                           f"./modeldir/model_{iterations}.pt")

                pylab.rcParams['figure.figsize'] = (16.0, 16.0)
                pylab.clf()

                munit.eval()
                with torch.no_grad():
                    _, _, _, _, _, _, _, ab, _, _, _, _, _, _ = munit(a, b)

                fake = ab.detach().cpu().numpy()
                real = a.detach().cpu().numpy()

                for i in range(batchsize):
                    tmp = (np.clip(real[i] * 127.5 + 127.5, 0, 255)).transpose(1, 2, 0).astype(np.uint8)
                    pylab.subplot(4, 4, 2 * i + 1)
                    pylab.imshow(tmp)
                    pylab.axis("off")
                    pylab.savefig("outdir/visualize_{}.png".format(iterations))
                    tmp = (np.clip(fake[i] * 127.5 + 127.5, 0, 255)).transpose(1, 2, 0).astype(np.uint8)
                    pylab.subplot(4, 4, 2 * i + 2)
                    pylab.imshow(tmp)
                    pylab.axis("off")
                    pylab.savefig("outdir/visualize_{}.png".format(iterations))

                munit.train()

            print(f"iter: {iterations} loss: {loss.data}")
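

# A hedged sketch of the perceptual_loss used above, assuming Vgg19Norm maps a
# normalized image batch to one feature map (or a list of feature maps);
# the repository's own implementation may weight or select layers differently.
import torch


def perceptual_loss(vgg, fake, real):
    fake_feat = vgg(fake)
    real_feat = vgg(real)
    if isinstance(fake_feat, (list, tuple)):
        # Sum L1 distances over all returned feature maps.
        return sum(torch.mean(torch.abs(f - r))
                   for f, r in zip(fake_feat, real_feat))
    return torch.mean(torch.abs(fake_feat - real_feat))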
# Hyperparameters
num_classes = 10
learning_rate = 1e-4
batch_size = 50
num_epochs = 5

# Load dataset
train_dataset = datasets.MNIST(root='dataset/',
                               train=True,
                               transform=GraphTransform(device),
                               download=False)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=CollateFn(device))
test_dataset = datasets.MNIST(root='dataset/',
                              train=False,
                              transform=GraphTransform(device),
                              download=False)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         collate_fn=CollateFn(device))


# To create A + I; please see MNIST.py, this is just a reference.
def adj_head(m):
    M = m**2
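

# A hedged sketch of the A + I construction that adj_head refers to: the
# 4-neighbour adjacency of an m x m pixel grid with self-loops added.
# The authoritative version is in MNIST.py; this is only illustrative.
import numpy as np


def grid_adjacency_with_self_loops(m):
    M = m ** 2  # one node per pixel
    A = np.zeros((M, M), dtype=np.float32)
    for r in range(m):
        for c in range(m):
            i = r * m + c
            if c + 1 < m:  # right neighbour
                A[i, i + 1] = A[i + 1, i] = 1.0
            if r + 1 < m:  # bottom neighbour
                A[i, i + m] = A[i + m, i] = 1.0
    return A + np.eye(M, dtype=np.float32)  # A + I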
def train(epochs, batchsize, s_interval, c_weight, kl_weight, x_path, y_path):
    generator = Generator()
    generator.cuda()
    generator.train()

    content_discriminator = ContentDiscriminator()
    content_discriminator.cuda()
    content_discriminator.train()

    domain_x_discriminator = DomainDiscriminator()
    domain_x_discriminator.cuda()
    domain_x_discriminator.train()

    domain_y_discriminator = DomainDiscriminator()
    domain_y_discriminator.cuda()
    domain_y_discriminator.train()

    g_optim = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    cdis_optim = torch.optim.Adam(content_discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    ddis_x_optim = torch.optim.Adam(domain_x_discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    ddis_y_optim = torch.optim.Adam(domain_y_discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

    dataset = HairDataset(medium_path=x_path, twin_path=y_path)
    collator = CollateFn()

    iterations = 0

    for epoch in range(epochs):
        dataloader = DataLoader(dataset,
                                batch_size=batchsize,
                                shuffle=True,
                                collate_fn=collator.train,
                                drop_last=True,
                                num_workers=0)
        progress_bar = tqdm(dataloader)

        for index, data in enumerate(progress_bar):
            iterations += 1
            x, y = data

            # discriminator update
            enc_x, enc_y, _, _, fake_x, fake_y, _, _, infers_x, infers_y = generator.forward(x, y)
            _, infer_x, _ = infers_x
            _, infer_y, _ = infers_y

            dis_loss = adversarial_content_D(content_discriminator, enc_x, enc_y)
            dis_loss += adversarial_domain_D(domain_x_discriminator, fake_x, x)
            dis_loss += adversarial_domain_D(domain_y_discriminator, fake_y, y)
            dis_loss += adversarial_domain_D(domain_x_discriminator, infer_x, x)
            dis_loss += adversarial_domain_D(domain_y_discriminator, infer_y, y)

            cdis_optim.zero_grad()
            ddis_x_optim.zero_grad()
            ddis_y_optim.zero_grad()
            dis_loss.backward()
            cdis_optim.step()
            ddis_x_optim.step()
            ddis_y_optim.step()

            # generator update
            enc_x, enc_y, attr_x, attr_y, fake_x, fake_y, recon_x, recon_y, infers_x, infers_y = generator.forward(x, y)
            latent_x, infer_x, infer_attr_x = infers_x
            latent_y, infer_y, infer_attr_y = infers_y
            _, _, _, _, fake_xyx, fake_yxy, _, _, _, _ = generator.forward(fake_x, fake_y)

            gen_loss = adversarial_content_G(content_discriminator, enc_x, enc_y)
            gen_loss += adversarial_domain_G(domain_x_discriminator, fake_x)
            gen_loss += adversarial_domain_G(domain_y_discriminator, fake_y)
            gen_loss += adversarial_domain_G(domain_x_discriminator, infer_x)
            gen_loss += adversarial_domain_G(domain_y_discriminator, infer_y)
            gen_loss += c_weight * cross_cycle_consistency_loss(x, y, fake_xyx, fake_yxy)
            gen_loss += c_weight * cross_cycle_consistency_loss(x, y, recon_x, recon_y)
            gen_loss += c_weight * cross_cycle_consistency_loss(latent_x, latent_y, infer_attr_x, infer_attr_y)
            #gen_loss += kl_weight * (l2_regularize(attr_x) + l2_regularize(attr_y))

            g_optim.zero_grad()
            gen_loss.backward()
            g_optim.step()

            if iterations % s_interval == 1:
                torch.save(generator.state_dict(), './model/model_{}.pt'.format(iterations))

                pylab.rcParams['figure.figsize'] = (16.0, 16.0)
                pylab.clf()

                with torch.no_grad():
                    _, _, _, _, _, fake_y, _, _, _, _ = generator.forward(x, y)

                fake_y = fake_y[:2].detach().cpu().numpy()
                real_x = x[:2].detach().cpu().numpy()

                for i in range(1):
                    tmp = (np.clip(real_x[i] * 127.5 + 127.5, 0, 255)).transpose(1, 2, 0).astype(np.uint8)
                    pylab.subplot(2, 2, 2 * i + 1)
                    pylab.imshow(tmp)
                    pylab.axis("off")
                    pylab.savefig("outdir/visualize_{}.png".format(iterations))
                    tmp = (np.clip(fake_y[i] * 127.5 + 127.5, 0, 255)).transpose(1, 2, 0).astype(np.uint8)
                    pylab.subplot(2, 2, 2 * i + 2)
                    pylab.imshow(tmp)
                    pylab.axis("off")
                    pylab.savefig("outdir/visualize_{}.png".format(iterations))

            print('iteration: {} dis loss: {} gen loss: {}'.format(iterations, dis_loss, gen_loss))
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PGL')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN architecture: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=1,
                        help='number of workers (default: 1)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    random.seed(42)
    np.random.seed(42)
    paddle.seed(42)

    if not args.use_cuda:
        paddle.set_device("cpu")

    ### automatic dataloading and splitting
    class Config():
        def __init__(self):
            self.base_data_path = "./dataset"

    config = Config()
    ds = MolDataset(config)

    split_idx = ds.get_idx_split()
    test_ds = Subset(ds, split_idx['test'])

    print("Test examples: ", len(test_ds))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=CollateFn())

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pdparams')
    if not os.path.exists(checkpoint_path):
        raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}')
    model.set_state_dict(paddle.load(checkpoint_path))

    print('Predicting on test data...')
    y_pred = test(model, test_loader)
    print('Saving test submission file...')
    evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)
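

# A hedged sketch of the test() helper called above. It assumes the loader
# yields (graph, labels) pairs and that graphs expose a .tensor() method, as is
# common in PGL collate functions; the actual batch layout depends on CollateFn.
import numpy as np
import paddle


def test(model, loader):
    model.eval()
    preds = []
    with paddle.no_grad():
        for graphs, labels in loader:
            graphs = graphs.tensor()
            pred = model(graphs)
            preds.append(pred.reshape([-1]).numpy())
    return np.concatenate(preds)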
def main():
    utils.writer = SummaryWriter()

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training.')
    parser.add_argument('--epochs',
                        type=int,
                        default=700,
                        help='Number of epochs to train.')
    parser.add_argument('--link-pred',
                        action='store_true',
                        default=False,
                        help='Enable Link Prediction Loss')
    parser.add_argument('--dataset',
                        default='ENZYMES',
                        help="Choose dataset: ENZYMES, DD")
    parser.add_argument('--batch-size',
                        default=256,
                        type=int,
                        help="Batch size")
    parser.add_argument('--train-ratio',
                        default=0.9,
                        type=float,
                        help="Train/Val split ratio")
    parser.add_argument('--pool-ratio',
                        default=0.25,
                        type=float,
                        help="Pooling ratio (fraction of the max node count)")
    args = parser.parse_args()
    utils.writer.add_text("args", str(args))

    device = "cuda" if not args.no_cuda and torch.cuda.is_available() else "cpu"

    dataset = TUDataset(args.dataset)
    # dataset = MNIST(root="~/.torch/data/", transform=GraphTransform(device), download=True)
    dataset_size = len(dataset)
    train_size = int(dataset_size * args.train_ratio)
    test_size = dataset_size - train_size
    max_num_nodes = max([item[0][0].shape[0] for item in dataset])
    n_classes = int(max([item[1] for item in dataset])) + 1
    train_data, test_data = random_split(dataset, (train_size, test_size))
    input_shape = int(dataset[0][0][1].shape[-1])

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=CollateFn(device))
    test_loader = DataLoader(test_data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             collate_fn=CollateFn(device))

    model = BatchedModel(pool_size=int(max_num_nodes * args.pool_ratio),
                         device=device,
                         link_pred=args.link_pred,
                         input_shape=input_shape,
                         n_classes=n_classes).to(device)
    model.train()
    optimizer = optim.Adam(model.parameters())

    for e in tqdm(range(args.epochs)):
        utils.e = e
        epoch_losses_list = []
        true_sample = 0
        model.train()
        for i, (adj, features, masks, batch_labels) in enumerate(train_loader):
            utils.train_iter += 1
            graph_feat = model(features, adj, masks)
            output = model.classifier(graph_feat)
            loss = model.loss(output, batch_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses_list.append(loss.item())

            iter_true_sample = (output.argmax(dim=1).long() ==
                                batch_labels.long()).float().sum().item()
            iter_acc = float(iter_true_sample) / output.shape[0]
            utils.writer.add_scalar("iter train acc", iter_acc, utils.train_iter)
            print(f"{utils.train_iter} iter train acc: {iter_acc}")
            true_sample += iter_true_sample

        acc = true_sample / train_size
        utils.writer.add_scalar("Epoch Acc", acc, e)
        tqdm.write(f"Epoch:{e} \t train_acc:{acc:.2f}")

        test_loss_list = []
        true_sample = 0
        model.eval()
        with torch.no_grad():
            for i, (adj, features, masks, batch_labels) in enumerate(test_loader):
                utils.test_iter += 1
                graph_feat = model(features, adj, masks)
                output = model.classifier(graph_feat)
                loss = model.loss(output, batch_labels)
                test_loss_list.append(loss.item())

                iter_true_sample = (output.argmax(dim=1).long() ==
                                    batch_labels.long()).float().sum().item()
                iter_acc = float(iter_true_sample) / output.shape[0]
                utils.writer.add_scalar("iter test acc", iter_acc, utils.test_iter)
                print(f"{utils.test_iter} iter test acc: {iter_acc}")
                true_sample += iter_true_sample

        acc = true_sample / test_size
        utils.writer.add_scalar("Epoch Acc", acc, e)
        tqdm.write(f"Epoch:{e} \t val_acc:{acc:.2f}")
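

# A hedged sketch of the CollateFn this script assumes: each dataset item is
# ((adj, features), label), and every graph is zero-padded to the largest graph
# in the batch so BatchedModel can consume dense (adj, features, masks, labels)
# tensors. The repository's own collate function may differ in detail.
import torch


class CollateFn:
    def __init__(self, device):
        self.device = device

    def __call__(self, samples):
        graphs, labels = zip(*samples)
        max_nodes = max(adj.shape[0] for adj, _ in graphs)
        feat_dim = graphs[0][1].shape[-1]
        adj_b = torch.zeros(len(graphs), max_nodes, max_nodes)
        feat_b = torch.zeros(len(graphs), max_nodes, feat_dim)
        mask_b = torch.zeros(len(graphs), max_nodes)
        for i, (adj, feat) in enumerate(graphs):
            n = adj.shape[0]
            adj_b[i, :n, :n] = torch.as_tensor(adj, dtype=torch.float32)
            feat_b[i, :n] = torch.as_tensor(feat, dtype=torch.float32)
            mask_b[i, :n] = 1.0
        labels = torch.as_tensor(labels, dtype=torch.long)
        return (adj_b.to(self.device), feat_b.to(self.device),
                mask_b.to(self.device), labels.to(self.device))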
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PGL')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN architecture: gin, gin-virtual, gcn, or gcn-virtual (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=1,
                        help='number of workers (default: 1)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    random.seed(42)
    np.random.seed(42)
    paddle.seed(42)

    if not args.use_cuda:
        paddle.set_device("cpu")

    ### automatic dataloading and splitting
    class Config():
        def __init__(self):
            self.base_data_path = "./dataset"

    config = Config()
    ds = MolDataset(config)

    split_idx = ds.get_idx_split()
    train_ds = Subset(ds, split_idx['train'])
    valid_ds = Subset(ds, split_idx['valid'])
    test_ds = Subset(ds, split_idx['test'])

    print("Train examples: ", len(train_ds))
    print("Valid examples: ", len(valid_ds))
    print("Test examples: ", len(test_ds))

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    train_loader = Dataloader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=CollateFn())
    valid_loader = Dataloader(valid_ds,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=CollateFn())

    if args.save_test_dir != '':
        test_loader = Dataloader(test_ds,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=CollateFn())

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False, **shared_params)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True, **shared_params)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False, **shared_params)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True, **shared_params)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.001,
                                              step_size=300,
                                              gamma=0.25)
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                      parameters=model.parameters())

    msg = "ogbg_lsc_paddle_baseline\n"
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, train_loader, optimizer)

        print('Evaluating...')
        valid_mae = eval(model, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                paddle.save(
                    model.state_dict(),
                    os.path.join(args.checkpoint_dir, 'checkpoint.pdparams'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

        try:
            msg += "Epoch: %d | Train: %.6f | Valid: %.6f | Best Valid: %.6f\n" \
                % (epoch, train_mae, valid_mae, best_valid_mae)
            print(msg)
        except:
            continue

    if args.log_dir != '':
        writer.close()
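

# Hedged sketches of the train() and eval() helpers used above, assuming each
# batch is a (graph, labels) pair with a PGL-style .tensor() graph as produced
# by CollateFn, and that PCQM4MEvaluator follows the OGB-LSC interface
# (eval({'y_true', 'y_pred'}) returning {'mae': ...}). The repository's own
# helpers may unpack batches differently.
import numpy as np
import paddle


def train(model, loader, optimizer):
    model.train()
    loss_fn = paddle.nn.L1Loss()
    losses = []
    for graphs, labels in loader:
        graphs = graphs.tensor()
        labels = paddle.to_tensor(labels, dtype='float32')
        pred = model(graphs).reshape([-1])
        loss = loss_fn(pred, labels)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        losses.append(loss.numpy())
    return float(np.mean(losses))


def eval(model, loader, evaluator):
    model.eval()
    y_true, y_pred = [], []
    with paddle.no_grad():
        for graphs, labels in loader:
            graphs = graphs.tensor()
            pred = model(graphs).reshape([-1])
            y_pred.append(pred.numpy())
            y_true.append(np.asarray(labels).reshape(-1))
    result = evaluator.eval({'y_true': np.concatenate(y_true),
                             'y_pred': np.concatenate(y_pred)})
    return result['mae']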
def main(config):
    if dist.get_world_size() > 1:
        dist.init_parallel_env()

    if dist.get_rank() == 0:
        timestamp = datetime.now().strftime("%Hh%Mm%Ss")
        log_path = os.path.join(config.log_dir,
                                "tensorboard_log_%s" % timestamp)
        writer = SummaryWriter(log_path)

    log.info("loading data")
    raw_dataset = GraphPropPredDataset(name=config.dataset_name)
    config.num_class = raw_dataset.num_tasks
    config.eval_metric = raw_dataset.eval_metric
    config.task_type = raw_dataset.task_type

    mol_dataset = MolDataset(config,
                             raw_dataset,
                             transform=make_multihop_edges)
    splitted_index = raw_dataset.get_idx_split()
    train_ds = Subset(mol_dataset, splitted_index['train'], mode='train')
    valid_ds = Subset(mol_dataset, splitted_index['valid'], mode="valid")
    test_ds = Subset(mol_dataset, splitted_index['test'], mode="test")

    log.info("Train Examples: %s" % len(train_ds))
    log.info("Val Examples: %s" % len(valid_ds))
    log.info("Test Examples: %s" % len(test_ds))

    fn = CollateFn(config)

    train_loader = Dataloader(train_ds,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.num_workers,
                              collate_fn=fn)
    valid_loader = Dataloader(valid_ds,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              collate_fn=fn)
    test_loader = Dataloader(test_ds,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers,
                             collate_fn=fn)

    model = ClassifierNetwork(config.hidden_size, config.out_dim,
                              config.num_layers, config.dropout_prob,
                              config.virt_node, config.K, config.conv_type,
                              config.appnp_hop, config.alpha)
    model = paddle.DataParallel(model)

    optim = Adam(learning_rate=config.lr, parameters=model.parameters())
    criterion = nn.loss.BCEWithLogitsLoss()
    evaluator = Evaluator(config.dataset_name)

    best_valid = 0

    global_step = 0
    for epoch in range(1, config.epochs + 1):
        model.train()
        for idx, batch_data in enumerate(train_loader):
            g, mh_graphs, labels, unmask = batch_data
            g = g.tensor()
            multihop_graphs = []
            for item in mh_graphs:
                multihop_graphs.append(item.tensor())
            g.multi_hop_graphs = multihop_graphs
            labels = paddle.to_tensor(labels)
            unmask = paddle.to_tensor(unmask)

            pred = model(g)
            pred = paddle.masked_select(pred, unmask)
            labels = paddle.masked_select(labels, unmask)
            train_loss = criterion(pred, labels)
            train_loss.backward()
            optim.step()
            optim.clear_grad()

            if global_step % 80 == 0:
                message = "train: epoch %d | step %d | " % (epoch, global_step)
                message += "loss %.6f" % (train_loss.numpy())
                log.info(message)
                if dist.get_rank() == 0:
                    writer.add_scalar("loss", train_loss.numpy(), global_step)
            global_step += 1

        valid_result = evaluate(model, valid_loader, criterion, evaluator)
        message = "valid: epoch %d | step %d | " % (epoch, global_step)
        for key, value in valid_result.items():
            message += " | %s %.6f" % (key, value)
            if dist.get_rank() == 0:
                writer.add_scalar("valid_%s" % key, value, global_step)
        log.info(message)

        test_result = evaluate(model, test_loader, criterion, evaluator)
        message = "test: epoch %d | step %d | " % (epoch, global_step)
        for key, value in test_result.items():
            message += " | %s %.6f" % (key, value)
            if dist.get_rank() == 0:
                writer.add_scalar("test_%s" % key, value, global_step)
        log.info(message)

        if best_valid < valid_result[config.metrics]:
            best_valid = valid_result[config.metrics]
            best_valid_result = valid_result
            best_test_result = test_result

        message = "best result: epoch %d | " % (epoch)
        message += "valid %s: %.6f | " % (config.metrics,
                                          best_valid_result[config.metrics])
        message += "test %s: %.6f | " % (config.metrics,
                                         best_test_result[config.metrics])
        log.info(message)

    message = "final eval best result:%.6f" % best_valid_result[config.metrics]
    log.info(message)
    message = "final test best result:%.6f" % best_test_result[config.metrics]
    log.info(message)
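

# A hedged sketch of the evaluate() helper called above, mirroring the batch
# unpacking used in the training loop. It assumes single-task labels and the
# standard OGB Evaluator interface; the repository's version may apply the
# unmask differently or report additional metrics.
import numpy as np
import paddle


def evaluate(model, loader, criterion, evaluator):
    model.eval()
    y_true, y_pred, losses = [], [], []
    with paddle.no_grad():
        for g, mh_graphs, labels, unmask in loader:
            g = g.tensor()
            g.multi_hop_graphs = [item.tensor() for item in mh_graphs]
            labels = paddle.to_tensor(labels)
            unmask = paddle.to_tensor(unmask)
            pred = model(g)
            pred = paddle.masked_select(pred, unmask)
            labels = paddle.masked_select(labels, unmask)
            losses.append(criterion(pred, labels).numpy())
            y_pred.append(pred.numpy().reshape(-1, 1))
            y_true.append(labels.numpy().reshape(-1, 1))
    result = evaluator.eval({'y_true': np.concatenate(y_true),
                             'y_pred': np.concatenate(y_pred)})
    result['loss'] = float(np.mean(losses))
    return result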