import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

# JTVAEDataset, JTVAECollator, DGLJTNNVAE and load_pretrained are assumed to
# come from DGL-LifeSci; get_vocab_file and worker_init_fn are project-local
# helpers defined alongside these scripts, so the import path below may vary.
from dgllife.data import JTVAEDataset, JTVAECollator
from dgllife.model import DGLJTNNVAE, load_pretrained
from utils import get_vocab_file, worker_init_fn


# Evaluation entry point: measure reconstruction accuracy of a (pre)trained
# JT-VAE by encoding each molecule, sampling a latent code, and decoding.
def main(args):
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.model_path is not None:
        model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                           hidden_size=args.hidden_size,
                           latent_size=args.latent_size,
                           depth=args.depth)
        model.load_state_dict(torch.load(args.model_path))
    else:
        model = load_pretrained("JTNN_ZINC")
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=False)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            collate_fn=JTVAECollator(False),
                            worker_init_fn=worker_init_fn)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors (see the sketch after this function).
    acc = 0.0
    tot = 0
    model = model.to(device)
    model.eval()
    for it, batch in enumerate(tqdm(dataloader)):
        gt_smiles = batch['mol_trees'][0].smiles
        batch = dataset.move_to_device(batch, device)
        try:
            _, tree_vec, mol_vec = model.encode(batch)

            tree_mean = model.T_mean(tree_vec)
            # Following Mueller et al.
            tree_log_var = -torch.abs(model.T_var(tree_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            # Reparameterization trick: std = exp(log_var / 2).
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon

            mol_mean = model.G_mean(mol_vec)
            # Following Mueller et al.
            mol_log_var = -torch.abs(model.G_var(mol_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

            dec_smiles = model.decode(tree_vec, mol_vec)

            if dec_smiles == gt_smiles:
                acc += 1
            tot += 1
        except Exception as e:
            print("Failed to encode: {}".format(gt_smiles))
            print(e)

        # Guard against division by zero when no molecule has been
        # successfully encoded yet.
        if it % 20 == 1 and tot > 0:
            print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
                it, len(dataloader), acc / tot))
    print("Reconstruction Accuracy: {}".format(acc / tot))
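
# A minimal sketch of sampling novel molecules from the prior instead of
# reconstructing inputs, as the comment in the loop above suggests. It assumes
# only the API already used above (model.latent_size and model.decode taking a
# tree vector and a graph vector, each of size latent_size // 2); the helper
# name sample_from_prior is hypothetical, not part of the library.
def sample_from_prior(model, device, num_samples=10):
    model.eval()
    smiles_list = []
    with torch.no_grad():
        for _ in range(num_samples):
            # Under the VAE prior, both halves of the latent code are
            # standard normal.
            tree_vec = torch.randn(1, model.latent_size // 2, device=device)
            mol_vec = torch.randn(1, model.latent_size // 2, device=device)
            smiles_list.append(model.decode(tree_vec, mol_vec))
    return smiles_list
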
# Training entry point (in practice this lives in a separate script from the
# evaluation entry point above).
def main(args):
    torch.multiprocessing.set_sharing_strategy('file_system')
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                       hidden_size=args.hidden_size,
                       latent_size=args.latent_size,
                       depth=args.depth)
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    if args.model_path is not None:
        model.load_state_dict(torch.load(args.model_path))
    else:
        # Xavier initialization for weight matrices, zeros for biases.
        for param in model.parameters():
            if param.dim() == 1:
                nn.init.constant_(param, 0)
            else:
                nn.init.xavier_normal_(param)

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=True)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4,
                            collate_fn=JTVAECollator(True),
                            drop_last=True,
                            worker_init_fn=worker_init_fn)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)

    model = model.to(device)
    model.train()
    for epoch in range(args.max_epoch):
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0

        for it, batch in enumerate(dataloader):
            batch = dataset.move_to_device(batch, device)
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, args.beta)
            except Exception:
                # Report the offending molecules before re-raising.
                print([t.smiles for t in batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            if (it + 1) % args.print_iter == 0:
                word_acc = word_acc / args.print_iter * 100
                topo_acc = topo_acc / args.print_iter * 100
                assm_acc = assm_acc / args.print_iter * 100
                steo_acc = steo_acc / args.print_iter * 100

                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_last_lr()[0])
                torch.save(model.state_dict(),
                           args.save_path + "/model.iter-%d-%d" % (epoch, it + 1))

        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_last_lr()[0])
        torch.save(model.state_dict(),
                   args.save_path + "/model.iter-" + str(epoch))
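
# A minimal sketch of the command-line wiring these entry points assume. The
# flag names are inferred from the `args` attributes accessed above; the
# defaults are placeholders, not the project's actual settings. Each script
# would carry its own copy of this block for its own main().
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True,
                        help='path to the dataset of SMILES strings')
    parser.add_argument('--vocab', default='zinc',
                        help='vocabulary identifier passed to get_vocab_file')
    parser.add_argument('--model-path', default=None,
                        help='checkpoint to load; if omitted, the evaluation script '
                             'falls back to the pretrained model and the training '
                             'script to random initialization')
    parser.add_argument('--save-path', default='checkpoints',
                        help='directory for saved checkpoints (training only)')
    parser.add_argument('--hidden-size', type=int, default=450)
    parser.add_argument('--latent-size', type=int, default=56)
    parser.add_argument('--depth', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=40)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--beta', type=float, default=0.0, help='KL weight')
    parser.add_argument('--max-epoch', type=int, default=100)
    parser.add_argument('--print-iter', type=int, default=20)
    main(parser.parse_args())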