Exemplo n.º 1
0
def train_validity(grammar=True,
                   model=None,
                   EPOCHS=None,
                   BATCH_SIZE=None,
                   lr=2e-4,
                   main_dataset=None,
                   new_datasets=None,
                   plot_ignore_initial=0,
                   save_file=None,
                   plot_prefix='',
                   dashboard='main',
                   preload_weights=False):
    """Train `model` to discriminate valid from invalid molecule sequences.

    The training stream mixes three sources: loaders derived from
    ``main_dataset`` plus loaders from a pair of freshly generated
    valid/invalid datasets.

    Args:
        grammar: use grammar-based settings (forwarded to ``get_settings``).
        model: model to train; must expose ``parameters()`` and ``load()``.
        EPOCHS: optional override for ``settings['EPOCHS']``.
        BATCH_SIZE: optional override for ``settings['BATCH_SIZE']``; also
            used directly for the data loaders.
        lr: Adam learning rate.
        main_dataset: dataset passed to ``train_valid_loaders``.
        new_datasets: tuple ``(valid_smile_ds, invalid_smile_ds)``, each
            exposing ``get_train_valid_loaders(batch_size)``.
        plot_ignore_initial: number of initial points to omit from plots.
        save_file: checkpoint file name under ``pretrained/`` (or None).
        plot_prefix: prefix for plot titles.
        dashboard: visdom dashboard name.
        preload_weights: if True, attempt to load weights from ``save_file``.

    Returns:
        The fitter generator produced by ``fit``.
    """
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    if save_file is not None:
        save_path = root_location + 'pretrained/' + save_file
    else:
        save_path = None
    molecules = True  # checking for validity only makes sense for molecules
    settings = get_settings(molecules=molecules, grammar=grammar)

    # TODO: separate settings for this?
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    if preload_weights:
        # Best-effort preload: a missing/incompatible checkpoint just means
        # training starts from scratch. `except Exception` (not bare except)
        # so KeyboardInterrupt/SystemExit still propagate.
        try:
            model.load(save_path)
        except Exception:
            pass

    nice_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(nice_params, lr=lr)

    # create the composite loaders
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)
    valid_smile_ds, invalid_smile_ds = new_datasets
    valid_train, valid_val = valid_smile_ds.get_train_valid_loaders(BATCH_SIZE)
    # BUG FIX: these loaders were previously (and wrongly) built from
    # valid_smile_ds, so the model never saw any invalid examples.
    invalid_train, invalid_val = invalid_smile_ds.get_train_valid_loaders(
        BATCH_SIZE)
    train_gen = MixedLoader(train_loader, valid_train, invalid_train)
    valid_gen = MixedLoader(valid_loader, valid_val, invalid_val)

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=0.0001,
                                               eps=1e-08)
    #scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    loss_obj = nn.BCELoss(size_average=True)

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 save_path=save_path,
                 dashboard_name=dashboard,
                 plot_ignore_initial=plot_ignore_initial,
                 plot_prefix=plot_prefix)

    return fitter
Exemplo n.º 2
0
def train_vae(molecules = True,
              grammar = True,
              EPOCHS = None,
              BATCH_SIZE = None,
              lr = 2e-4,
              drop_rate = 0.0,
              plot_ignore_initial = 0,
              reg_weight = 1,
              epsilon_std = 0.01,
              sample_z = True,
              save_file = None,
              preload_file = None,
              encoder_type='cnn',
              decoder_type='step',
              plot_prefix = '',
              dashboard = 'main',
              preload_weights=False):
    """Build a (grammar) VAE, wire up data loaders and training machinery.

    Returns a triple ``(model, fitter, dataset)`` where ``fitter`` is the
    training generator produced by ``fit`` and ``dataset`` is the HDF5-backed
    dataset the loaders draw from.
    """
    # Resolve paths relative to this file's parent directory.
    here = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe()))) + '/../'
    weights_path = here + 'pretrained/' + save_file
    preload_path = (weights_path if preload_file is None
                    else here + 'pretrained/' + preload_file)

    cfg = get_settings(molecules=molecules, grammar=grammar)
    if EPOCHS is not None:
        cfg['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        cfg['BATCH_SIZE'] = BATCH_SIZE

    model, _ = get_vae(molecules=molecules,
                       grammar=grammar,
                       drop_rate=drop_rate,
                       sample_z=sample_z,
                       rnn_encoder=encoder_type,
                       decoder_type=decoder_type,
                       weights_file=preload_path if preload_weights else None,
                       epsilon_std=epsilon_std)

    # Only optimize parameters that actually require gradients.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=lr)

    dataset = DatasetFromHDF5(cfg['data_path'], 'data')
    train_loader, valid_loader = train_valid_loaders(dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)
    train_gen = TwinGenerator(train_loader)
    valid_gen = TwinGenerator(valid_loader)

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=min(0.0001, 0.1 * lr),
                                               eps=1e-08)
    criterion = VAELoss(cfg['grammar'], sample_z, reg_weight)

    monitor = MetricPlotter(plot_prefix=plot_prefix,
                            loss_display_cap=4.0,
                            dashboard_name=dashboard,
                            plot_ignore_initial=plot_ignore_initial)
    saver = Checkpointer(valid_batches_to_checkpoint=1,
                         save_path=weights_path)

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 epochs=cfg['EPOCHS'],
                 loss_fn=criterion,
                 metric_monitor=monitor,
                 checkpointer=saver)

    return model, fitter, dataset
Exemplo n.º 3
0
def train_mol_descriptor(grammar=True,
                         EPOCHS=None,
                         BATCH_SIZE=None,
                         lr=2e-4,
                         gradient_clip=5,
                         drop_rate=0.0,
                         plot_ignore_initial=0,
                         save_file=None,
                         preload_file=None,
                         encoder_type='rnn',
                         plot_prefix='',
                         dashboard='properties',
                         aux_dataset=None,
                         preload_weights=False,
                         use_attention_head=False):
    """Train a model to predict molecule property scores from sequences.

    Args:
        grammar: use grammar-based settings (forwarded to ``get_settings``).
        EPOCHS / BATCH_SIZE: optional overrides for the settings defaults.
        lr: Adam learning rate.
        gradient_clip: gradient clipping value forwarded to ``fit``.
        drop_rate: dropout rate for the encoder/head.
        plot_ignore_initial: number of initial points to omit from plots.
        save_file: checkpoint file name under ``pretrained/``.
        preload_file: optional weights file to preload from instead of
            ``save_file``.
        encoder_type: encoder architecture selector.
        plot_prefix: prefix for plot titles.
        dashboard: visdom dashboard name.
        aux_dataset: optional extra dataset; when given, batches from it are
            mixed with the main dataset.
        preload_weights: if True, preload weights (currently only consumed
            via ``preload_path`` construction).
        use_attention_head: if True, use the experimental decoder-based
            attention simulator head instead of the plain encoder head.
            Defaults to False, preserving the previous behavior (this used
            to be an ``if False:`` dead-code toggle).

    Returns:
        Tuple ``(model, fitter, main_dataset)``.
    """
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    save_path = root_location + 'pretrained/' + save_file

    if preload_file is None:
        preload_path = save_path
    else:
        preload_path = root_location + 'pretrained/' + preload_file

    # When an auxiliary dataset is present, batches come from two sources.
    batch_mult = 1 if aux_dataset is None else 2

    settings = get_settings(molecules=True, grammar=grammar)
    max_steps = settings['max_seq_length']

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    if use_attention_head:
        # Experimental path: simulate attention over a decoder's states.
        pre_model, _ = get_decoder(True,
                                   grammar,
                                   z_size=settings['z_size'],
                                   decoder_hidden_n=200,
                                   feature_len=settings['feature_len'],
                                   max_seq_length=max_steps,
                                   drop_rate=drop_rate,
                                   decoder_type=encoder_type,
                                   batch_size=BATCH_SIZE * batch_mult)

        class AttentionSimulator(nn.Module):
            """Feeds target actions into a decoder via a policy, then
            aggregates its outputs with an attention head."""

            def __init__(self, pre_model, drop_rate):
                super().__init__()
                self.pre_model = pre_model
                pre_model_2 = AttentionAggregatingHead(pre_model,
                                                       drop_rate=drop_rate)
                # keep only the aggregated representation
                pre_model_2.model_out_transform = lambda x: x[1]
                self.model = MeanVarianceSkewHead(pre_model_2,
                                                  4,
                                                  drop_rate=drop_rate)

            def forward(self, x):
                # The input drives the decoder indirectly through its policy.
                self.pre_model.policy = PolicyFromTarget(x)
                return self.model(None)

        model = to_gpu(AttentionSimulator(pre_model, drop_rate=drop_rate))
    else:
        # Default path: plain encoder followed by a distribution head.
        pre_model = get_encoder(feature_len=settings['feature_len'],
                                max_seq_length=settings['max_seq_length'],
                                cnn_encoder_params={
                                    'kernel_sizes': (2, 3, 4),
                                    'filters': (2, 3, 4),
                                    'dense_size': 100
                                },
                                drop_rate=drop_rate,
                                encoder_type=encoder_type)

        model = MeanVarianceSkewHead(pre_model, 4, drop_rate=drop_rate)

    nice_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(nice_params, lr=lr)

    main_dataset = MultiDatasetFromHDF5(settings['data_path'],
                                        ['actions', 'smiles'])
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)

    def scoring_fun(x):
        """Turn a raw batch into (truncated action sequences, scores)."""
        if isinstance(x, tuple) or isinstance(x, list):
            x = {'actions': x[0], 'smiles': x[1]}
        out_x = to_gpu(x['actions'])
        end_of_slice = randint(3, out_x.size()[1])
        #TODO inject random slicing back
        out_x = out_x[:, 0:end_of_slice]
        smiles = x['smiles']
        scores = to_gpu(
            torch.from_numpy(property_scorer(smiles).astype(np.float32)))
        return out_x, scores

    train_gen_main = IterableTransform(train_loader, scoring_fun)
    valid_gen_main = IterableTransform(valid_loader, scoring_fun)

    if aux_dataset is not None:
        train_aux, valid_aux = SamplingWrapper(aux_dataset) \
            .get_train_valid_loaders(BATCH_SIZE,
                                     dataset_name=['actions',
                                                   'smiles'])
        train_gen_aux = IterableTransform(train_aux, scoring_fun)
        valid_gen_aux = IterableTransform(valid_aux, scoring_fun)
        train_gen = CombinedLoader([train_gen_main, train_gen_aux],
                                   num_batches=90)
        valid_gen = CombinedLoader([valid_gen_main, valid_gen_aux],
                                   num_batches=10)
    else:
        train_gen = train_gen_main
        valid_gen = valid_gen_main

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=min(0.0001, 0.1 * lr),
                                               eps=1e-08)
    #scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    loss_obj = VariationalLoss(['valid', 'logP', 'SA', 'cyc_sc'])

    metric_monitor = MetricPlotter(plot_prefix=plot_prefix,
                                   loss_display_cap=4.0,
                                   dashboard_name=dashboard,
                                   plot_ignore_initial=plot_ignore_initial)

    checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                save_path=save_path)

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 grad_clip=gradient_clip,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 metric_monitor=metric_monitor,
                 checkpointer=checkpointer)

    return model, fitter, main_dataset
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_off=1e-4,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          save_file=None,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
    """Set up on-policy and off-policy policy-gradient training generators.

    Builds a sequence-generation decoder, optional anchor model, and two
    fitters: one trained on-policy (sampling from the model) and one
    off-policy (replaying action sequences from an existing dataset).

    Args:
        molecules / grammar: forwarded to settings and model constructors.
        EPOCHS / BATCH_SIZE: optional overrides for the settings defaults.
        reward_fun_on: reward function for on-policy rollouts.
        reward_fun_off: reward function for off-policy data; defaults to
            ``reward_fun_on`` when None.
        max_steps: maximum generated sequence length.
        lr_on / lr_off: Adam learning rates for the two fitters.
        drop_rate: decoder dropout rate.
        plot_ignore_initial: number of initial points to omit from plots.
        save_file: checkpoint file name under ``pretrained/``.
        preload_file: optional weights file to preload the model from.
        anchor_file: optional weights for an anchor model used to regularize
            updates; ``anchor_weight`` scales that penalty.
        decoder_type: decoder architecture selector.
        plot_prefix / dashboard: visdom plotting configuration.
        smiles_save_file: HDF5 file where generated SMILES are appended.
        on_policy_loss_type / off_policy_loss_type: PolicyGradientLoss modes.
        sanity_checks: forwarded to the model factory (currently unused by
            ``get_decoder`` here — kept for interface compatibility).

    Returns:
        Tuple ``(model, on_policy_generator, off_policy_generator)``.
    """
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    save_path = root_location + 'pretrained/' + save_file
    smiles_save_path = root_location + 'pretrained/' + smiles_save_file

    settings = get_settings(molecules=molecules, grammar=grammar)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    save_dataset = IncrementingHDF5Dataset(smiles_save_path)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun_on,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    def get_model(sanity_checks=sanity_checks):
        """Construct a fresh decoder wired to the generation task."""
        return get_decoder(
            molecules,
            grammar,
            z_size=settings['z_size'],
            decoder_hidden_n=200,
            feature_len=None,  # made redundant, need to factor it out
            max_seq_length=max_steps,
            drop_rate=drop_rate,
            decoder_type=decoder_type,
            task=task)[0]

    model = get_model()

    if preload_file is not None:
        # Best-effort preload; missing/incompatible weights mean a cold start.
        try:
            preload_path = root_location + 'pretrained/' + preload_file
            model.load_state_dict(torch.load(preload_path))
        except Exception:
            pass

    anchor_model = None
    if anchor_file is not None:
        anchor_model = get_model()
        try:
            anchor_path = root_location + 'pretrained/' + anchor_file
            anchor_model.load_state_dict(torch.load(anchor_path))
        except Exception:
            # No usable anchor weights -> disable anchoring entirely.
            anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        """Plot the best molecule of the batch and its score breakdown."""
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        actions, logits, rewards, terminals, info = model_out
        smiles, valid = info
        total_rewards = rewards.sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'test.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception:
                # Drawing is purely cosmetic; never let it kill training.
                pass
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    def get_fitter(model,
                   loss_obj,
                   fit_plot_prefix='',
                   model_process_fun=None,
                   lr=None,
                   loss_display_cap=float('inf'),
                   anchor_model=None,
                   anchor_weight=0):
        """Build an RL fitter generator around `model` with the given loss."""
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        def my_gen():
            # The decoder only needs latent-size zero vectors as input;
            # actual variation comes from the sampling policy.
            for _ in range(1000):
                yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

        fitter = fit_rl(train_gen=my_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        metric_monitor=metric_monitor,
                        checkpointer=checkpointer)

        return fitter

    # the on-policy fitter
    fitter1 = get_fitter(model,
                         PolicyGradientLoss(on_policy_loss_type),
                         plot_prefix + 'on-policy',
                         model_process_fun=model_process_fun,
                         lr=lr_on,
                         anchor_model=anchor_model,
                         anchor_weight=anchor_weight)
    # get existing molecule data to add training
    main_dataset = DatasetFromHDF5(settings['data_path'], 'actions')

    # TODO change call to a simple DataLoader, no validation
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)

    fitter2 = get_fitter(model,
                         PolicyGradientLoss(off_policy_loss_type),
                         plot_prefix + ' off-policy',
                         lr=lr_off,
                         model_process_fun=model_process_fun,
                         loss_display_cap=125)

    def on_policy_gen(fitter, model):
        """Sample from the model's own policy, then take one fit step."""
        while True:
            model.policy = SoftmaxRandomSamplePolicy()
            yield next(fitter)

    def off_policy_gen(fitter, data_gen, model):
        """Replay dataset action sequences through the policy, one fit step
        per batch, restarting the dataset iterator when exhausted.

        BUG FIX: the iterator used to be re-created on every loop iteration,
        so only the first batch of `data_gen` was ever consumed.
        """
        data_iter = iter(data_gen)
        while True:
            try:
                x_actions = next(data_iter).to(torch.int64)
                model.policy = PolicyFromTarget(x_actions)
                yield next(fitter)
            except StopIteration:
                # Dataset exhausted: start another pass over it.
                data_iter = iter(data_gen)

    return model, on_policy_gen(fitter1,
                                model), off_policy_gen(fitter2, train_loader,
                                                       model)