Example 1

import os

import torch

# Helpers such as set_random_seed, mkdir_p, MoleculeDataset, DGMG and
# load_pretrained are assumed to come from the example's accompanying modules.

def prepare_for_evaluation(rank, args):
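    # Give each worker a distinct seed so processes generate different samples.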
    worker_seed = args['seed'] + rank * 10000
    set_random_seed(worker_seed)
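    # One intra-op thread per process avoids CPU oversubscription across workers.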
    torch.set_num_threads(1)

    # Setup dataset and data loader
    dataset = MoleculeDataset(args['dataset'],
                              subset_id=rank,
                              n_subsets=args['num_processes'])

    # Initialize model
    if not args['pretrained']:
        model = DGMG(atom_types=dataset.atom_types,
                     bond_types=dataset.bond_types,
                     node_hidden_size=args['node_hidden_size'],
                     num_prop_rounds=args['num_propagation_rounds'],
                     dropout=args['dropout'])
        model.load_state_dict(
            torch.load(args['model_path'])['model_state_dict'])
    else:
        model = load_pretrained(
            '_'.join(['DGMG', args['dataset'], args['order']]), log=False)
    model.eval()

    worker_num_samples = args['num_samples'] // args['num_processes']
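    # The last worker also takes the remainder so the requested total is generated exactly.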
    if rank == args['num_processes'] - 1:
        worker_num_samples += args['num_samples'] % args['num_processes']

    worker_log_dir = os.path.join(args['log_dir'], str(rank))
    mkdir_p(worker_log_dir, log=False)
    generate_and_save(worker_log_dir, worker_num_samples,
                      args['max_num_steps'], model)
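
A worker function like this runs once per process. Below is a minimal, hypothetical launcher sketch; the use of torch.multiprocessing and the launch_evaluation wrapper are assumptions here, not part of the original example:

import torch.multiprocessing as mp

def launch_evaluation(args):
    # One process per rank; each worker samples its share of molecules into
    # its own log subdirectory.
    procs = [mp.Process(target=prepare_for_evaluation, args=(rank, args))
             for rank in range(args['num_processes'])]
    for p in procs:
        p.start()
    for p in procs:
        p.join()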
Example 2

from rdkit import Chem

# DGMG is assumed to be imported from the model code under test.

def test_dgmg():
    model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
                 bond_types=[Chem.rdchem.BondType.SINGLE,
                             Chem.rdchem.BondType.DOUBLE,
                             Chem.rdchem.BondType.TRIPLE],
                 node_hidden_size=1,
                 num_prop_rounds=1,
                 dropout=0.2)
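    # Each action is a (decision, choice) pair in DGMG's sequential decision
    # process: (0, i) adds a node of atom type i (i == len(atom_types) means
    # stop adding nodes), (1, j) adds an edge of bond type j (j == len(bond_types)
    # means stop adding edges), and (2, k) picks existing node k as the edge's
    # destination. The sequence below builds a single C-O bond, i.e. 'CO'.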
    assert model(actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
                 rdkit_mol=True) == 'CO'
    assert model(rdkit_mol=False) is None
    model.eval()
    assert model(rdkit_mol=True) is not None

    model = DGMG(atom_types=['O', 'Cl', 'C', 'S', 'F', 'Br', 'N'],
                 bond_types=[Chem.rdchem.BondType.SINGLE,
                             Chem.rdchem.BondType.DOUBLE,
                             Chem.rdchem.BondType.TRIPLE])
    assert model(actions=[(0, 2), (1, 3), (0, 0), (1, 0), (2, 0), (1, 3), (0, 7)],
                 rdkit_mol=True) == 'CO'
    assert model(rdkit_mol=False) is None
    model.eval()
    assert model(rdkit_mol=True) is not None
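
Two behaviors are exercised above: given an explicit action sequence with rdkit_mol=True, the model replays those decisions and returns the resulting SMILES string ('CO', i.e. methanol); called without actions, it samples a generation trace on its own, returning None when rdkit_mol=False and a molecule otherwise.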
Example 3

import datetime
import time

import torch
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data import DataLoader

# Helpers such as set_random_seed, synchronize, evaluate, Printer,
# MoleculeDataset and DGMG are assumed to come from the example's
# accompanying modules.

def main(rank, args):
    """
    Parameters
    ----------
    rank : int
        Subprocess id
    args : dict
        Configuration
    """
    if rank == 0:
        t1 = time.time()

    set_random_seed(args['seed'])
    # Removing the line below will cause problems for multiprocessing.
    torch.set_num_threads(1)

    # Setup dataset and data loader
    dataset = MoleculeDataset(args['dataset'],
                              args['order'], ['train', 'val'],
                              subset_id=rank,
                              n_subsets=args['num_processes'])

    # Note that the batch size for the loaders currently has to be 1.
    train_loader = DataLoader(dataset.train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=dataset.collate)
    val_loader = DataLoader(dataset.val_set,
                            batch_size=args['batch_size'],
                            shuffle=True,
                            collate_fn=dataset.collate)

    if rank == 0:
        try:
            from tensorboardX import SummaryWriter
            writer = SummaryWriter(args['log_dir'])
        except ImportError:
            print(
                'If you want to use tensorboard, install tensorboardX with pip.'
            )
            writer = None
        train_printer = Printer(args['nepochs'], len(dataset.train_set),
                                args['batch_size'], writer)
        val_printer = Printer(args['nepochs'], len(dataset.val_set),
                              args['batch_size'])
    else:
        val_printer = None

    # Initialize model
    model = DGMG(atom_types=dataset.atom_types,
                 bond_types=dataset.bond_types,
                 node_hidden_size=args['node_hidden_size'],
                 num_prop_rounds=args['num_propagation_rounds'],
                 dropout=args['dropout'])

    if args['num_processes'] == 1:
        from utils import Optimizer
        optimizer = Optimizer(args['lr'],
                              Adam(model.parameters(), lr=args['lr']))
    else:
        from utils import MultiProcessOptimizer
        optimizer = MultiProcessOptimizer(
            args['num_processes'], args['lr'],
            Adam(model.parameters(), lr=args['lr']))
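        # MultiProcessOptimizer (from the example's utils) presumably averages
        # gradients across worker processes before each update.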

    if rank == 0:
        t2 = time.time()
    best_val_prob = 0

    # Training
    for epoch in range(args['nepochs']):
        model.train()
        if rank == 0:
            print('Training')

        for i, data in enumerate(train_loader):
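            # Teacher forcing: score the ground-truth decision sequence and
            # maximize its log-likelihood.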
            log_prob = model(actions=data, compute_log_prob=True)
            prob = log_prob.detach().exp()

            loss_averaged = -log_prob
            prob_averaged = prob
            optimizer.backward_and_step(loss_averaged)
            if rank == 0:
                train_printer.update(epoch + 1, loss_averaged.item(),
                                     prob_averaged.item())

        synchronize(args['num_processes'])

        # Validation
        val_log_prob = evaluate(epoch, model, val_loader, val_printer)
        if args['num_processes'] > 1:
            dist.all_reduce(val_log_prob, op=dist.ReduceOp.SUM)
        val_log_prob /= args['num_processes']
        # Strictly speaking, the probability computed here differs from the one on
        # the training set: we average the log-likelihoods first and then
        # exponentiate, which yields the geometric mean of the per-molecule
        # probabilities. By Jensen's inequality, exp(mean(log p)) <= mean(p), so
        # this is a lower bound on the average of the real probabilities.
        val_prob = (-val_log_prob).exp().item()
        val_log_prob = val_log_prob.item()
        if val_prob >= best_val_prob:
            if rank == 0:
                torch.save({'model_state_dict': model.state_dict()},
                           args['checkpoint_dir'])
                print(
                    'Old val prob {:.10f} | new val prob {:.10f} | model saved'
                    .format(best_val_prob, val_prob))
            best_val_prob = val_prob
        elif epoch >= args['warmup_epochs']:
            optimizer.decay_lr()

        if rank == 0:
            print('Validation')
            if writer is not None:
                writer.add_scalar('validation_log_prob', val_log_prob, epoch)
                writer.add_scalar('validation_prob', val_prob, epoch)
                writer.add_scalar('lr', optimizer.lr, epoch)
            print('Validation log prob {:.4f} | prob {:.10f}'.format(
                val_log_prob, val_prob))

        synchronize(args['num_processes'])

    if rank == 0:
        t3 = time.time()
        print('It took {} to set up.'.format(
            datetime.timedelta(seconds=t2 - t1)))
        print('It took {} to finish training.'.format(
            datetime.timedelta(seconds=t3 - t2)))
        print(
            '--------------------------------------------------------------------------'
        )
        print('On average, an epoch takes {}.'.format(
            datetime.timedelta(seconds=(t3 - t2) / args['nepochs'])))
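
A hypothetical launcher for this entry point (a sketch only; in the original repository, process spawning and, for num_processes > 1, torch.distributed initialization live in the surrounding script):

import torch.multiprocessing as mp

if __name__ == '__main__':
    # args is the configuration dict built elsewhere (e.g. via argparse).
    if args['num_processes'] == 1:
        main(0, args)
    else:
        # mp.spawn passes the process index as the first positional
        # argument, matching main(rank, args).
        mp.spawn(main, args=(args,), nprocs=args['num_processes'])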