    def __call__(self, dataset_path):
        """
        :param dataset_path: path to a text file specifying a selection of reactant bags.
        :return: the z embeddings given by the Molecule Chef for the given reactant bags.
        """
        # == Read in the data and set up a Dataloader ==
        trsfm = symbol_sequence_data.TrsfmSeqStrToArray(symbol_sequence_data.StopSymbolDetails(True, self.stop_symbol_indx),
                                                        shuffle_seq_flag=True,
                                                        rng=self.rng)
        trsfm_and_tuple_indx = lambda x: trsfm(x)[0]
        reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(dataset_path, trsfm_and_tuple_indx)
        dataloader = DataLoader(reaction_bags_dataset, batch_size=500, shuffle=False,
                                collate_fn=symbol_sequence_data.reorder_and_pad_then_collate)

        # == Now go through this data in batches and calculate the z embeddings. ==
        # A subtle point is that our collate function above reorders the items of the batch -- so that it is compatible
        # with the packed/padded sequences of PyTorch (where the longest sequences have to come first). So we need to
        # flip this order back to give the zs in the correct order.
        results = []
        with tqdm.tqdm(dataloader, total=len(dataloader)) as t:
            for padded_seq_batch_first, lengths, order in t:
                packed_seq = rnn.pack_padded_sequence(padded_seq_batch_first, lengths, batch_first=True)
                packed_seq = self.cuda_details.return_cudafied(packed_seq)

                zs = self.ae._run_through_to_z(packed_seq)
                zs_np = zs.detach().cpu().numpy()
                order_np = order.cpu().numpy()

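                # np.argsort of the permutation gives its inverse, restoring the original dataset order.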
                reverse_order = np.argsort(order_np)
                zs_np = zs_np[reverse_order]
                results.append(zs_np)
        results = np.concatenate(results)
        results_tuples = [(row,) for row in results]
        return results_tuples
def train_molecule_chef_qed_hiv(property_predictor,
                                predictor_label_to_optimize):
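    """
    Train the Molecule Chef WAE jointly with a property predictor.

    Property datasets are built from `property_predictor` and `predictor_label_to_optimize`;
    reactant-bag sequences and molecule graphs are loaded from the paths given in `Params`.
    After every epoch the model is validated and a checkpoint is saved, with the best
    validation objective seen so far flagged as best.
    """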
    params = Params()

    # Set the random seeds.
    rng = np.random.RandomState(5156416)
    torch.manual_seed(rng.choice(1000000))

    # Set up data
    # == The property data
    train_prop_dataset, val_prop_dataset = (
        get_train_and_val_product_property_datasets(
            params, property_predictor, predictor_label_to_optimize))
    print("Created property datasets!")

    # == The sequence data
    stop_symbol_idx = mchef_config.get_num_graphs()  # comes after all the graphs
    trsfm = symbol_sequence_data.TrsfmSeqStrToArray(
        symbol_sequence_data.StopSymbolDetails(True, stop_symbol_idx),
        shuffle_seq_flag=True,
        rng=rng)

    reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_train, trsfm)
    reaction_train_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset, train_prop_dataset)
    train_dataloader = DataLoader(reaction_train_dataset,
                                  batch_size=params.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_datasets_func)

    reaction_bags_dataset_val = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_val, trsfm)
    reaction_val_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset_val, val_prop_dataset)
    val_dataloader = DataLoader(reaction_val_dataset,
                                batch_size=500,
                                shuffle=False,
                                collate_fn=collate_datasets_func)

    # == The graph data
    indices_to_graphs = atom_features_dataset.PickledGraphDataset(
        params.path_mol_details, params.cuda_details)
    assert stop_symbol_idx == len(
        indices_to_graphs), "stop symbol index should be after graphs"

    # Set up Model
    mol_chef_params = get_mchef.MChefParams(params.cuda_details,
                                            indices_to_graphs,
                                            len(indices_to_graphs),
                                            stop_symbol_idx, params.latent_dim)
    mc_wae = get_mchef.get_mol_chef(mol_chef_params)
    mc_wae = params.cuda_details.return_cudafied(mc_wae)

    # set up trainer
    optimizer = optim.Adam(mc_wae.parameters(), lr=params.learning_rate)
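    # ExponentialLR multiplies the learning rate by gamma each time lr_scheduler.step() is called
    # (see the end of the epoch loop below).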
    lr_scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=params.lr_reduction_factor)

    # Set up some loggers
    tb_writer_train = tb_.get_tb_writer(
        f"{TB_LOGS_FILE}/{params.run_name}_train")
    tb_writer_val = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_val")

    def add_details_to_train(dict_to_add):
        for name, value in dict_to_add.items():
            tb_writer_train.add_scalar(name, value)

    train_log_helper = logging_tools.LogHelper([add_details_to_train])
    tb_writer_train.global_step = 0

    # Set up steps and setup funcs.
    def optimizer_step():
        optimizer.step()
        tb_writer_train.global_step += 1

    def setup_for_train():
        mc_wae._logger_manager = train_log_helper
        mc_wae.train()  # put in train mode

    def setup_for_val():
        tb_writer_val.global_step = tb_writer_train.global_step
        mc_wae._tb_logger = None  # turn off the more concise logging
        mc_wae.eval()

    # Run an initial validation
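    # This gives the baseline objective that later epochs must beat before a checkpoint is flagged as best.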
    setup_for_val()
    best_ae_obj_sofar = validation(val_dataloader, mc_wae, tb_writer_val,
                                   params.cuda_details,
                                   params.property_pred_factor,
                                   params.lambda_value)

    # Train!
    for epoch_num in range(params.num_epochs):
        print(f"We are starting epoch {epoch_num}")
        tb_writer_train.add_scalar("epoch_num", epoch_num)
        setup_for_train()
        train(train_dataloader, mc_wae, optimizer, optimizer_step,
              params.cuda_details, tb_writer_train, params.lambda_value,
              params.property_pred_factor)

        print("Switching to eval.")
        setup_for_val()
        ae_obj = validation(val_dataloader, mc_wae, tb_writer_val,
                            params.cuda_details, params.property_pred_factor,
                            params.lambda_value)

        if ae_obj >= best_ae_obj_sofar:
            print("** Best LL found so far! :-) **")
            best_ae_obj_sofar = ae_obj
            best_flag = True
        else:
            best_flag = False

        save_checkpoint(
            dict(epochs_completed=epoch_num + 1,
                 wae_state_dict=mc_wae.state_dict(),
                 optimizer=optimizer.state_dict(),
                 learning_rate_scheduler=lr_scheduler.state_dict(),
                 ll_from_val=ae_obj,
                 wae_lambda_value=params.property_pred_factor,
                 stop_symbol_idx=stop_symbol_idx),
            is_best=best_flag,
            filename=path.join(
                CHKPT_FOLDER,
                f"{params.run_name}-{datetime.datetime.now().isoformat()}.pth.pick"
            ))

        # See https://github.com/pytorch/pytorch/pull/7889 -- from PyTorch 1.1 the scheduler has to be called after the optimizer step:
        if (epoch_num % params.lr_reduction_interval == 0
                and epoch_num / params.lr_reduction_interval > 0.9):
            print("Running the learning rate scheduler. Optimizer is:")
            lr_scheduler.step()
            print(optimizer)
        print("==========================================================================================")
def main(params: Params):
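    """
    Load a trained Molecule Chef model, embed a selection of training reactant bags into the latent space,
    run a local search from each latent point to optimize the predicted property, and write out the results
    (the pickled search traces plus a file of tokenized reactant bags found along the way).
    """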
    rng = np.random.RandomState(1001)
    torch.manual_seed(rng.choice(10000000))

    # == Let's get the model! ==
    molchef_wae, latent_dim, stop_symbol_idx = load_in_mchef(params.weights_to_use,
                                                             cuda_details=params.cuda_details,
                                                             path_molecule_details=params.path_mol_details)

    # == Let's get the mapping from id to SMILES ==
    seq_to_smi_list = mt.MapSeqsToReactants()

    # == Get dataset  ==
    trsfm = symbol_sequence_data.TrsfmSeqStrToArray(symbol_sequence_data.StopSymbolDetails(True, stop_symbol_idx),
                                                    shuffle_seq_flag=True,
                                                    rng=rng)
    reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(params.path_react_bags_train, trsfm)

    # == Let's get a set of initial z values ==
    # Start from training examples.
    print("Starting from training examples.")
    zs_to_start_from_train_data = []
    indices_to_use = list(range(10)) + rng.permutation(len(reaction_bags_dataset))[:params.num_molecules_to_optimize - 10].tolist()
    # ^ use the first ten as well as random ones, since the first ten are easy to inspect

    # We'll embed each molecule into the latent space one by one.
    # We could batch this, but then we would have to deal with ordering by length for the packed padded sequence,
    # and for the number of molecules we run on, it seems fast enough as is.
    for i in tqdm.tqdm(indices_to_use, desc="creating initial starting locations"):
        sequence_batch_first = reaction_bags_dataset[i][0]
        sequence_batch_first = torch.from_numpy(sequence_batch_first).view(1, -1)
        lengths = torch.tensor([sequence_batch_first.shape[1]])
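        # Pack the single-item "batch" so it matches the packed-sequence input that the encoder expects.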
        packed_seq = rnn.pack_padded_sequence(sequence_batch_first, lengths, batch_first=True)
        packed_seq = packed_seq.to(params.cuda_details.device_str)
        z_sample = molchef_wae._run_through_to_z(packed_seq)
        zs_to_start_from_train_data.append(z_sample)

    # == Now we shall run the optimization for each of these z samples  ==
    results = collections.defaultdict(list)

    searches = [('prop_opt',
                 LocalSearchRunner(False, molchef_wae.prop_predictor_, molchef_wae, seq_to_smi_list, params))]

    for search_name, searcher in searches:
        print(f"Doing {search_name}")

        init_points = zs_to_start_from_train_data
        for initial_z in tqdm.tqdm(init_points):
            results[search_name].append(searcher.optimize_z(initial_z, params.num_distinct_molecule_steps, params.epsilon))

    # == Now we can write out the files to store the reactants found on the trace. ==
    with open('local_search_results.pick', 'wb') as fo:
        pickle.dump(results, fo)

    # we shall also write out tokenized reactant bags that we need predicting.
    all_reactant_bags = set()
    for results_for_search_type in results.values():
        for individual_run_results in results_for_search_type:
            reactant_strs = individual_run_results[2]
            all_reactant_bags.update(reactant_strs)  # as defined, the reactants should already be in canonical form.

    tokenized_sampled_reactants = [mt.tokenization(smi_str) for smi_str in all_reactant_bags if len(smi_str)]
    with open(params.output_path, 'w') as fo:
        fo.writelines('\n'.join(tokenized_sampled_reactants))