コード例 #1
0
def _sample_latent(mean_layer, var_layer, vec, latent_dim, device):
    """Draw one reparameterized latent sample for an encoded vector.

    Args:
        mean_layer: Callable mapping ``vec`` to the latent mean.
        var_layer: Callable mapping ``vec`` to the raw (unsigned) log-variance.
        vec: Encoded vector produced by the model's encoder.
        latent_dim: Width of the noise vector to draw.
        device: Device on which the noise must live.

    Returns:
        ``mean + exp(log_var / 2) * epsilon`` with ``epsilon ~ N(0, I)``.
    """
    mean = mean_layer(vec)
    # Following Mueller et al.: negate the absolute value so that the
    # log-variance is never positive.
    log_var = -torch.abs(var_layer(vec))
    # Noise has a leading dimension of 1, matching the batch size used by
    # the evaluation data loader.
    epsilon = torch.randn(1, latent_dim).to(device)
    # The standard deviation is exp(log_var / 2). This must be a true
    # division: floor division would round the exponent down to an integer
    # and distort the sampled noise scale.
    return mean + torch.exp(log_var / 2) * epsilon


def main(args):
    """Encode and decode every molecule in ``args.data`` and report accuracy.

    A molecule counts as reconstructed when the decoded SMILES string is
    equal to the ground-truth SMILES string of the input. Samples for which
    encoding or decoding raises are reported and excluded from the total.

    Args:
        args: Namespace providing ``model_path``, ``vocab``, ``hidden_size``,
            ``latent_size``, ``depth`` and ``data``. When ``model_path`` is
            ``None`` the pre-trained ``"JTNN_ZINC"`` model is used instead.
    """
    # NOTE(review): presumably seeds the RNGs of the main process as well;
    # confirm against the definition of worker_init_fn.
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.model_path is not None:
        model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                           hidden_size=args.hidden_size,
                           latent_size=args.latent_size,
                           depth=args.depth)
        # Deserialize onto the CPU so that a checkpoint saved from a GPU can
        # still be restored on a CPU-only machine; the model is moved to
        # `device` further down.
        model.load_state_dict(
            torch.load(args.model_path, map_location='cpu'))
    else:
        model = load_pretrained("JTNN_ZINC")
    print("# model parameters: {:d}K".format(
        sum(x.nelement() for x in model.parameters()) // 1000))

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTVAECollator(False),
        worker_init_fn=worker_init_fn)
    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    acc = 0.0
    tot = 0
    model = model.to(device)
    model.eval()
    # The latent vector is split evenly between the tree and molecule parts.
    half_latent = model.latent_size // 2
    for it, batch in enumerate(tqdm(dataloader)):
        # Read the reference SMILES before the batch is moved to the device.
        gt_smiles = batch['mol_trees'][0].smiles
        batch = dataset.move_to_device(batch, device)
        try:
            _, tree_vec, mol_vec = model.encode(batch)

            tree_vec = _sample_latent(
                model.T_mean, model.T_var, tree_vec, half_latent, device)
            mol_vec = _sample_latent(
                model.G_mean, model.G_var, mol_vec, half_latent, device)

            dec_smiles = model.decode(tree_vec, mol_vec)

            if dec_smiles == gt_smiles:
                acc += 1
            tot += 1
        except Exception as e:
            # Best effort: report the failing molecule and keep evaluating.
            print("Failed to encode: {}".format(gt_smiles))
            print(e)

        if it % 20 == 1:
            # `tot` is still zero if every sample so far has failed.
            running = acc / tot if tot else float('nan')
            print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
                it, len(dataloader), running))

    # Guard against an empty dataset or a run in which every sample failed.
    final = acc / tot if tot else float('nan')
    print("Reconstruction Accuracy: {}".format(final))
コード例 #2
0
ファイル: reconstruct_eval.py プロジェクト: jjhu94/dgl-1
# Coerce the hyperparameters read from `args` to integers.
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)

# Build a fresh model only when a checkpoint is supplied; otherwise the
# pre-trained model is used and a freshly built one would be thrown away.
if args.model_path is not None:
    model = DGLJTNNVAE(vocab_file=vocab_file,
                       hidden_size=hidden_size,
                       latent_size=latent_size,
                       depth=depth)
    # Deserialize onto the CPU so that a checkpoint saved from a GPU can
    # still be restored on a machine without one; `cuda(model)` below is
    # responsible for any device placement.
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
else:
    model = load_pretrained("JTNN_ZINC")

# NOTE(review): `cuda` is defined outside this view; presumably it moves the
# model to the GPU when one is available. Confirm.
model = cuda(model)
model.eval()
print("Model #Params: %dK" %
      (sum(x.nelement() for x in model.parameters()) // 1000,))

MAX_EPOCH = 100
PRINT_ITER = 20

def reconstruct():
    dataset.training = False
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTNNCollator(dataset.vocab, False),
        drop_last=True,