# Assumed imports for this example script: DGLJTNNVAE, load_pretrained and the
# JTVAE dataset/collator come from DGL-LifeSci; get_vocab_file and
# worker_init_fn are assumed to live in the example's local utils module.
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dgllife.data import JTVAEDataset, JTVAECollator
from dgllife.model import DGLJTNNVAE, load_pretrained

from utils import get_vocab_file, worker_init_fn


def main(args):
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.model_path is not None:
        model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                           hidden_size=args.hidden_size,
                           latent_size=args.latent_size,
                           depth=args.depth)
        model.load_state_dict(torch.load(args.model_path))
    else:
        model = load_pretrained("JTNN_ZINC")
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTVAECollator(False),
        worker_init_fn=worker_init_fn)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    # Evaluate reconstruction accuracy: encode each molecule, draw a latent
    # code with the reparameterization trick, decode it back to SMILES and
    # compare against the ground truth.
    acc = 0.0
    tot = 0
    model = model.to(device)
    model.eval()
    for it, batch in enumerate(tqdm(dataloader)):
        gt_smiles = batch['mol_trees'][0].smiles
        batch = dataset.move_to_device(batch, device)
        try:
            _, tree_vec, mol_vec = model.encode(batch)

            tree_mean = model.T_mean(tree_vec)
            # Following Mueller et al., bound the log variance to be
            # non-positive.
            tree_log_var = -torch.abs(model.T_var(tree_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            # Reparameterization: z = mu + exp(log_var / 2) * eps
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon

            mol_mean = model.G_mean(mol_vec)
            # Following Mueller et al.
            mol_log_var = -torch.abs(model.G_var(mol_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

            dec_smiles = model.decode(tree_vec, mol_vec)

            if dec_smiles == gt_smiles:
                acc += 1
            tot += 1
        except Exception as e:
            print("Failed to encode: {}".format(gt_smiles))
            print(e)

        if it % 20 == 1:
            print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
                it, len(dataloader), acc / tot))
    print("Reconstruction Accuracy: {}".format(acc / tot))
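
# A minimal sketch of the sampling mentioned in the comment above: instead of
# encoding an input molecule, draw the tree and molecule vectors directly from
# the standard-normal prior and decode them into (possibly novel) SMILES.
# `sample_prior` is a hypothetical helper, not part of the original script; it
# only assumes, as in the loop above, that `model.decode` accepts two vectors
# of size `model.latent_size // 2`.
def sample_prior(model, device, num_samples=10):
    model.eval()
    smiles_list = []
    with torch.no_grad():
        for _ in range(num_samples):
            # Both halves of the latent code are drawn from N(0, I).
            tree_vec = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = torch.randn(1, model.latent_size // 2).to(device)
            smiles_list.append(model.decode(tree_vec, mol_vec))
    return smiles_list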
# Older module-level variant of the same setup; `vocab_file`, `dataset`,
# `cuda` and `JTNNCollator` are assumed to be defined elsewhere in that script.
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)

model = DGLJTNNVAE(vocab_file=vocab_file,
                   hidden_size=hidden_size,
                   latent_size=latent_size,
                   depth=depth)
if args.model_path is not None:
    model.load_state_dict(torch.load(args.model_path))
else:
    model = load_pretrained("JTNN_ZINC")
model = cuda(model)
model.eval()
print("Model #Params: %dK" % (
    sum([x.nelement() for x in model.parameters()]) // 1000,))

MAX_EPOCH = 100
PRINT_ITER = 20


def reconstruct():
    dataset.training = False
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTNNCollator(dataset.vocab, False),
        drop_last=True,
        worker_init_fn=worker_init_fn)
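
# The `cuda` call above is assumed to be a small helper from this example's
# utilities that moves a module or tensor to the GPU when one is available; a
# minimal sketch under that assumption, not necessarily the script's actual
# definition:
def cuda(x):
    # Keep everything on CPU when no CUDA device is present.
    return x.cuda() if torch.cuda.is_available() else x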