示例#1
0
def get_vae(molecules=True,
            grammar=True,
            weights_file=None,
            epsilon_std=1,
            decoder_type='step',
            **kwargs):
    """Build a variational autoencoder together with a matching codec.

    Args:
        molecules: forwarded to ``get_model_args``, ``get_decoder``,
            ``get_settings`` and ``get_codec``.
        grammar: forwarded to the same helpers as ``molecules``.
        weights_file: if not None, passed to ``model.load`` once the model
            has been built.
        epsilon_std: forwarded to ``VariationalAutoEncoderHead``.
        decoder_type: forwarded to ``get_decoder``.
        **kwargs: overrides for entries of the dict returned by
            ``get_model_args``; keys not already present there are ignored.

    Returns:
        A ``(model, codec)`` tuple.
    """
    model_args = get_model_args(molecules=molecules, grammar=grammar)
    # Only keys already present in the model config can be overridden;
    # any other keyword argument is silently ignored.
    for key, value in kwargs.items():
        if key in model_args:
            model_args[key] = value
    # 'sample_z' is consumed by the VAE head, not by the encoder/decoder.
    sample_z = model_args.pop('sample_z')

    encoder_args = [
        'feature_len', 'max_seq_length', 'cnn_encoder_params', 'drop_rate',
        'encoder_type', 'rnn_encoder_hidden_n'
    ]
    encoder = get_encoder(**{
        key: value
        for key, value in model_args.items() if key in encoder_args
    })

    decoder_args = [
        'z_size', 'decoder_hidden_n', 'feature_len', 'max_seq_length',
        'drop_rate', 'batch_size'
    ]
    decoder, _ = get_decoder(molecules,
                             grammar,
                             decoder_type=decoder_type,
                             **{
                                 key: value
                                 for key, value in model_args.items()
                                 if key in decoder_args
                             })

    model = generative_playground.models.heads.vae.VariationalAutoEncoderHead(
        encoder=encoder,
        decoder=decoder,
        sample_z=sample_z,
        epsilon_std=epsilon_std,
        z_size=model_args['z_size'])

    if weights_file is not None:
        model.load(weights_file)

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec_max_seq_length = settings['max_seq_length']
    if 'max_seq_length' in kwargs and 'max_seq_length' in model_args:
        # The encoder/decoder were built with the caller's override, so the
        # codec must use the same sequence length to stay consistent.
        codec_max_seq_length = model_args['max_seq_length']
    codec = get_codec(molecules,
                      grammar,
                      max_seq_length=codec_max_seq_length)
    # TODO: codec.set_model(model) is disabled; confirm whether anything needs it.
    return model, codec
示例#2
0
def train_mol_descriptor(grammar=True,
                         EPOCHS=None,
                         BATCH_SIZE=None,
                         lr=2e-4,
                         gradient_clip=5,
                         drop_rate=0.0,
                         plot_ignore_initial=0,
                         save_file=None,
                         preload_file=None,
                         encoder_type='rnn',
                         plot_prefix='',
                         dashboard='properties',
                         aux_dataset=None,
                         preload_weights=False):
    """Set up training of a molecular property-descriptor model.

    The model is an encoder from ``get_encoder`` topped with a
    ``MeanVarianceSkewHead``. It is optimised with Adam against a
    ``VariationalLoss`` over ``['valid', 'logP', 'SA', 'cyc_sc']``. The
    targets are computed per batch by ``property_scorer`` from the SMILES.

    Args:
        grammar: forwarded to ``get_settings``.
        EPOCHS: overrides ``settings['EPOCHS']`` when not None.
        BATCH_SIZE: overrides ``settings['BATCH_SIZE']`` when not None; the
            resolved value is the batch size given to the data loaders.
        lr: Adam learning rate; also bounds the scheduler's ``min_lr``.
        gradient_clip: forwarded to ``fit`` as ``grad_clip``.
        drop_rate: dropout rate for the encoder and the head.
        plot_ignore_initial: forwarded to ``MetricPlotter``.
        save_file: file name under ``<this file's dir>/../pretrained/`` used
            as the ``Checkpointer`` save path. Required despite the ``None``
            default.
        preload_file: accepted for backward compatibility; currently unused.
        encoder_type: forwarded to ``get_encoder``.
        plot_prefix: forwarded to ``MetricPlotter``.
        dashboard: forwarded to ``MetricPlotter`` as ``dashboard_name``.
        aux_dataset: optional extra dataset; when given, its batches are
            combined with the main dataset's via ``CombinedLoader``.
        preload_weights: accepted for backward compatibility; currently
            unused.

    Returns:
        ``(model, fitter, main_dataset)``, where ``fitter`` is whatever
        ``fit`` returns.

    Raises:
        TypeError: if ``save_file`` is None.
    """
    if save_file is None:
        # Fail explicitly instead of with an opaque str + None TypeError
        # during path construction.
        raise TypeError('train_mol_descriptor() requires a save_file name')

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    pretrained_dir = root_location + '/../pretrained/'
    save_path = pretrained_dir + save_file
    # NOTE(review): preload_file and preload_weights are accepted but no
    # weights are loaded here. Confirm whether preloading from
    # pretrained_dir + preload_file should be wired up.

    settings = get_settings(molecules=True, grammar=grammar)
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    # Use the resolved value so the BATCH_SIZE=None default falls back to the
    # settings instead of handing batch_size=None to the loaders.
    batch_size = settings['BATCH_SIZE']

    pre_model = get_encoder(feature_len=settings['feature_len'],
                            max_seq_length=settings['max_seq_length'],
                            cnn_encoder_params={
                                'kernel_sizes': (2, 3, 4),
                                'filters': (2, 3, 4),
                                'dense_size': 100
                            },
                            drop_rate=drop_rate,
                            encoder_type=encoder_type)
    # NOTE(review): 4 presumably matches the four loss targets below — confirm.
    model = MeanVarianceSkewHead(pre_model, 4, drop_rate=drop_rate)

    # Only optimise parameters that are marked trainable.
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(trainable_params, lr=lr)

    main_dataset = MultiDatasetFromHDF5(settings['data_path'],
                                        ['actions', 'smiles'])
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=batch_size,
                                                     pin_memory=use_gpu)

    def scoring_fun(x):
        # Batches may arrive either as an (actions, smiles) sequence or as a
        # dict with those keys; normalise to the dict form.
        if isinstance(x, (tuple, list)):
            x = {'actions': x[0], 'smiles': x[1]}
        out_x = to_gpu(x['actions'])
        # Truncate the whole batch to a random prefix length along dim 1.
        # NOTE(review): whether the upper bound is inclusive depends on which
        # randint is in scope (random vs numpy.random) — confirm.
        end_of_slice = randint(3, out_x.size()[1])
        out_x = out_x[:, 0:end_of_slice]
        smiles = x['smiles']
        scores = to_gpu(
            torch.from_numpy(property_scorer(smiles).astype(np.float32)))
        return out_x, scores

    train_gen_main = IterableTransform(train_loader, scoring_fun)
    valid_gen_main = IterableTransform(valid_loader, scoring_fun)

    if aux_dataset is not None:
        train_aux, valid_aux = SamplingWrapper(aux_dataset) \
            .get_train_valid_loaders(batch_size,
                                     dataset_name=['actions',
                                                   'smiles'])
        train_gen_aux = IterableTransform(train_aux, scoring_fun)
        valid_gen_aux = IterableTransform(valid_aux, scoring_fun)
        train_gen = CombinedLoader([train_gen_main, train_gen_aux],
                                   num_batches=90)
        valid_gen = CombinedLoader([valid_gen_main, valid_gen_aux],
                                   num_batches=10)
    else:
        train_gen = train_gen_main
        valid_gen = valid_gen_main

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=min(0.0001, 0.1 * lr),
                                               eps=1e-08)
    loss_obj = VariationalLoss(['valid', 'logP', 'SA', 'cyc_sc'])

    metric_monitor = MetricPlotter(plot_prefix=plot_prefix,
                                   loss_display_cap=4.0,
                                   dashboard_name=dashboard,
                                   plot_ignore_initial=plot_ignore_initial)

    checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                save_path=save_path)

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 grad_clip=gradient_clip,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 metric_monitor=metric_monitor,
                 checkpointer=checkpointer)

    return model, fitter, main_dataset