Example #1
def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

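    # A smoke test: tiny hyperparameters are enough to exercise model construction
    # and device placement.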
    model = DGLJTNNVAE(hidden_size=1,
                       latent_size=2,
                       depth=1).to(device)
Example #2
def main(args):
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.model_path is not None:
        model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                           hidden_size=args.hidden_size,
                           latent_size=args.latent_size,
                           depth=args.depth)
        model.load_state_dict(torch.load(args.model_path))
    else:
        model = load_pretrained("JTNN_ZINC")
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTVAECollator(False),
        worker_init_fn=worker_init_fn)
    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
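    # Reconstruction evaluation: encode each molecule, sample latent vectors with
    # the reparameterization trick, decode back to SMILES and count exact matches.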
    acc = 0.0
    tot = 0
    model = model.to(device)
    model.eval()
    for it, batch in enumerate(tqdm(dataloader)):
        gt_smiles = batch['mol_trees'][0].smiles
        batch = dataset.move_to_device(batch, device)
        try:
            _, tree_vec, mol_vec = model.encode(batch)

            tree_mean = model.T_mean(tree_vec)
            # Following Mueller et al.
            tree_log_var = -torch.abs(model.T_var(tree_vec))
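            # -|x| keeps the log-variance non-positive, so the sampled standard
            # deviation exp(log_var / 2) never exceeds 1.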
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon

            mol_mean = model.G_mean(mol_vec)
            # Following Mueller et al.
            mol_log_var = -torch.abs(model.G_var(mol_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

            dec_smiles = model.decode(tree_vec, mol_vec)

            if dec_smiles == gt_smiles:
                acc += 1
            tot += 1
        except Exception as e:
            print("Failed to reconstruct: {}".format(gt_smiles))
            print(e)

        if it % 20 == 1:
            print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
                it, len(dataloader), acc / tot))

    print("Reconstruction Accuracy: {}".format(acc / tot))
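
A minimal sketch of the prior sampling hinted at by the comment above ("in reality you may want to sample tree and molecule vectors"): draw both halves of the latent code from a standard normal instead of encoding a ground-truth molecule. Only model.latent_size, model.decode and device come from the snippet; the helper name and sample count are illustrative.

def sample_from_prior(model, device, n_samples=5):
    # Each latent half (junction tree and molecular graph) has size
    # latent_size // 2, matching the epsilon tensors used above.
    model.eval()
    samples = []
    with torch.no_grad():
        for _ in range(n_samples):
            tree_vec = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = torch.randn(1, model.latent_size // 2).to(device)
            samples.append(model.decode(tree_vec, mol_vec))
    return samples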
Example #3
def main(args):
    torch.multiprocessing.set_sharing_strategy('file_system')

    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                       hidden_size=args.hidden_size,
                       latent_size=args.latent_size,
                       depth=args.depth)
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    if args.model_path is not None:
        model.load_state_dict(torch.load(args.model_path))
    else:
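        # No checkpoint: initialize biases (1-D parameters) to zero and weight
        # matrices with Xavier normal initialization.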
        for param in model.parameters():
            if param.dim() == 1:
                nn.init.constant_(param, 0)
            else:
                nn.init.xavier_normal_(param)

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=True)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4,
                            collate_fn=JTVAECollator(True),
                            drop_last=True,
                            worker_init_fn=worker_init_fn)

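    # Adam with exponential learning-rate decay (gamma=0.9); the scheduler is stepped
    # every 1500 iterations ("fast annealing") and again at the end of each epoch.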
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)

    model = model.to(device)
    model.train()
    for epoch in range(args.max_epoch):
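        # Running accuracies for word (node label), topology, assembly and
        # stereochemistry predictions, reset after every print interval.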
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0

        for it, batch in enumerate(dataloader):
            batch = dataset.move_to_device(batch, device)
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, args.beta)
            except Exception:
                print([t.smiles for t in batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            if (it + 1) % args.print_iter == 0:
                word_acc = word_acc / args.print_iter * 100
                topo_acc = topo_acc / args.print_iter * 100
                assm_acc = assm_acc / args.print_iter * 100
                steo_acc = steo_acc / args.print_iter * 100

                print(
                    "KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f"
                    % (kl_div, word_acc, topo_acc, assm_acc, steo_acc,
                       loss.item()))
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_last_lr()[0])
                torch.save(
                    model.state_dict(),
                    args.save_path + "/model.iter-%d-%d" % (epoch, it + 1))

        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_last_lr()[0])
        torch.save(model.state_dict(),
                   args.save_path + "/model.iter-" + str(epoch))
Example #4
                    default=1e-3,
                    help="Learning Rate")
args = parser.parse_args()

dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=True)
vocab_file = dataset.vocab_file

batch_size = int(args.batch_size)
hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)
beta = float(args.beta)
lr = float(args.lr)

model = DGLJTNNVAE(vocab_file=vocab_file,
                   hidden_size=hidden_size,
                   latent_size=latent_size,
                   depth=depth)

if args.model_path is not None:
    model.load_state_dict(torch.load(args.model_path))
else:
    for param in model.parameters():
        if param.dim() == 1:
            nn.init.constant_(param, 0)
        else:
            nn.init.xavier_normal_(param)

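# cuda() is assumed to be a helper defined elsewhere in this example that moves
# the model to the GPU when one is available.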
model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement()
                                   for x in model.parameters()]) / 1000, ))
Example #5
                    help="Latent size of node (atom) and edge (bond) features; "
                         "should be consistent with the pre-trained model")
parser.add_argument("-d", "--depth", dest="depth", default=3,
                    help="Number of message passing hops; "
                         "should be consistent with the pre-trained model")
args = parser.parse_args()

dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file

hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)

model = DGLJTNNVAE(vocab_file=vocab_file,
                   hidden_size=hidden_size,
                   latent_size=latent_size,
                   depth=depth)

if args.model_path is not None:
    model.load_state_dict(torch.load(args.model_path))
else:
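    # No checkpoint given: fall back to pre-trained weights, assumed to come from
    # the library's model zoo via load_pretrained.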
    model = load_pretrained("JTNN_ZINC")

model = cuda(model)
model.eval()
print("Model #Params: %dK" %
      (sum([x.nelement() for x in model.parameters()]) / 1000,))

MAX_EPOCH = 100
PRINT_ITER = 20