Esempio n. 1
0
    def __init__(self, params: Params):
        """Rebuild a trained model from a saved checkpoint and keep it on the instance.

        The graph dataset is read from ``params.path_mol_details`` and the
        checkpoint from ``params.weights_to_use``. The latent size is not taken
        from ``params``; it is derived from the shape of a tensor stored in the
        checkpoint.

        Args:
            params: supplies ``path_mol_details``, ``weights_to_use`` and
                ``cuda_details`` (the latter provides ``device_str`` and
                ``return_cudafied``).

        Raises:
            AssertionError: if the stop symbol index stored in the checkpoint
                differs from ``mchef_config.get_num_graphs()``.
        """
        cuda_details = params.cuda_details
        graph_dataset = atom_features_dataset.PickledGraphDataset(params.path_mol_details, cuda_details)

        # The checkpoint is read before the model is constructed because two of
        # the constructor arguments (latent size, stop symbol) come out of it.
        checkpoint = torch.load(params.weights_to_use, map_location=cuda_details.device_str)
        saved_weights = checkpoint['wae_state_dict']

        # Half of the second dimension of the prior's parameter tensor.
        # NOTE(review): presumably the tensor packs two equally sized parameter
        # groups per latent dimension -- confirm against the prior's definition.
        latent_dim = saved_weights['latent_prior._params'].shape[1] // 2
        print(f"Inferred a latent dimensionality of {latent_dim}")

        stop_symbol_idx = checkpoint['stop_symbol_idx']
        assert stop_symbol_idx == mchef_config.get_num_graphs()

        # Build an untrained model of the right shape, move it to the target
        # device, then overwrite its weights with the saved ones.
        model_spec = get_mchef.MChefParams(
            cuda_details,
            graph_dataset,
            len(graph_dataset),
            stop_symbol_idx,
            latent_dim,
        )
        model = cuda_details.return_cudafied(get_mchef.get_mol_chef(model_spec))
        model.load_state_dict(saved_weights)

        self.ae = model
        self.cuda_details = cuda_details
        self.rng = np.random.RandomState(1001)
        self.latent_dim = latent_dim
        self.stop_symbol_indx = stop_symbol_idx
def train_molecule_chef_qed_hiv(property_predictor,
                                predictor_label_to_optimize):
    """Train the Molecule Chef model together with a property objective.

    Builds the train/validation data loaders (reaction bags merged with
    property data), constructs the model via ``get_mchef.get_mol_chef``, runs
    one validation pass, then trains for ``params.num_epochs`` epochs. After
    every epoch the model is validated and a checkpoint is written to
    ``CHKPT_FOLDER``; the checkpoint is flagged as best when the validation
    objective is at least as large as the best seen so far.

    All hyperparameters and paths come from a freshly constructed ``Params()``.

    Args:
        property_predictor: forwarded unchanged to
            ``get_train_and_val_product_property_datasets``.
        predictor_label_to_optimize: forwarded unchanged to
            ``get_train_and_val_product_property_datasets``.

    Returns:
        None. Results are the checkpoints and TensorBoard logs written along
        the way.

    Raises:
        AssertionError: if the stop symbol index does not equal the number of
            graphs in the pickled graph dataset.
    """
    params = Params()

    # Set the random seeds. The torch seed is drawn from the seeded NumPy
    # generator, so both are fixed by the single constant below.
    rng = np.random.RandomState(5156416)
    torch.manual_seed(rng.choice(1000000))

    # Set up data
    # == The property data
    train_prop_dataset, val_prop_dataset = (
        get_train_and_val_product_property_datasets(
            params, property_predictor, predictor_label_to_optimize))
    print("Created property datasets!")

    # == The sequence data
    # The stop symbol takes the index immediately after all the graph indices.
    stop_symbol_idx = mchef_config.get_num_graphs()
    trsfm = symbol_sequence_data.TrsfmSeqStrToArray(
        symbol_sequence_data.StopSymbolDetails(True, stop_symbol_idx),
        shuffle_seq_flag=True,
        rng=rng)

    reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_train, trsfm)
    reaction_train_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset, train_prop_dataset)
    train_dataloader = DataLoader(reaction_train_dataset,
                                  batch_size=params.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_datasets_func)

    reaction_bags_dataset_val = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_val, trsfm)
    reaction_val_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset_val, val_prop_dataset)
    # Validation uses a fixed batch size and keeps the dataset order.
    val_dataloader = DataLoader(reaction_val_dataset,
                                batch_size=500,
                                shuffle=False,
                                collate_fn=collate_datasets_func)

    # == The graph data
    indices_to_graphs = atom_features_dataset.PickledGraphDataset(
        params.path_mol_details, params.cuda_details)
    assert stop_symbol_idx == len(
        indices_to_graphs), "stop symbol index should be after graphs"

    # Set up Model
    mol_chef_params = get_mchef.MChefParams(params.cuda_details,
                                            indices_to_graphs,
                                            len(indices_to_graphs),
                                            stop_symbol_idx, params.latent_dim)
    mc_wae = get_mchef.get_mol_chef(mol_chef_params)
    mc_wae = params.cuda_details.return_cudafied(mc_wae)

    # Remember whatever logger manager the freshly built model carries, so
    # validation can put it back once training has attached its own helper.
    logger_manager_outside_training = getattr(mc_wae, "_logger_manager", None)

    # set up trainer
    optimizer = optim.Adam(mc_wae.parameters(), lr=params.learning_rate)
    lr_scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=params.lr_reduction_factor)

    # Set up some loggers
    tb_writer_train = tb_.get_tb_writer(
        f"{TB_LOGS_FILE}/{params.run_name}_train")
    tb_writer_val = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_val")

    def add_details_to_train(dict_to_add):
        # No step is passed explicitly.
        # NOTE(review): presumably the writer returned by tb_.get_tb_writer
        # reads its own ``global_step`` attribute -- confirm in ``tb_``.
        for name, value in dict_to_add.items():
            tb_writer_train.add_scalar(name, value)

    train_log_helper = logging_tools.LogHelper([add_details_to_train])
    tb_writer_train.global_step = 0

    # Set up steps and setup funcs.
    def optimizer_step():
        # One optimizer update corresponds to one logging step.
        optimizer.step()
        tb_writer_train.global_step += 1

    def setup_for_train():
        mc_wae._logger_manager = train_log_helper
        mc_wae.train()  # put in train mode

    def setup_for_val():
        # Validation points are logged at the step training has reached.
        tb_writer_val.global_step = tb_writer_train.global_step
        mc_wae._tb_logger = None  # turn off the more concise logging
        # Detach the training log helper as well: setup_for_train attaches it
        # through ``_logger_manager``, so clearing only ``_tb_logger`` would
        # leave it attached during every validation pass after the first epoch.
        mc_wae._logger_manager = logger_manager_outside_training
        mc_wae.eval()

    # Run an initial validation
    # NOTE(review): ``validation`` receives (property_pred_factor, lambda_value)
    # whereas ``train`` receives (lambda_value, property_pred_factor); confirm
    # this matches the two signatures.
    setup_for_val()
    best_ae_obj_sofar = validation(val_dataloader, mc_wae, tb_writer_val,
                                   params.cuda_details,
                                   params.property_pred_factor,
                                   params.lambda_value)

    # Train!
    for epoch_num in range(params.num_epochs):
        print(f"We are starting epoch {epoch_num}")
        tb_writer_train.add_scalar("epoch_num", epoch_num)
        setup_for_train()
        train(train_dataloader, mc_wae, optimizer, optimizer_step,
              params.cuda_details, tb_writer_train, params.lambda_value,
              params.property_pred_factor)

        print("Switching to eval.")
        setup_for_val()
        ae_obj = validation(val_dataloader, mc_wae, tb_writer_val,
                            params.cuda_details, params.property_pred_factor,
                            params.lambda_value)

        # Ties count as a new best, so the latest equally good epoch wins.
        best_flag = ae_obj >= best_ae_obj_sofar
        if best_flag:
            print("** Best LL found so far! :-) **")
            best_ae_obj_sofar = ae_obj

        # NOTE(review): the ISO timestamp contains ':' characters, which are
        # not valid in file names on every platform.
        save_checkpoint(
            dict(epochs_completed=epoch_num + 1,
                 wae_state_dict=mc_wae.state_dict(),
                 optimizer=optimizer.state_dict(),
                 learning_rate_scheduler=lr_scheduler.state_dict(),
                 ll_from_val=ae_obj,
                 # Record the lambda actually handed to train/validation; the
                 # property factor is stored under its own key.
                 wae_lambda_value=params.lambda_value,
                 property_pred_factor=params.property_pred_factor,
                 stop_symbol_idx=stop_symbol_idx),
            is_best=best_flag,
            filename=path.join(
                CHKPT_FOLDER,
                f"{params.run_name}-{datetime.datetime.now().isoformat()}.pth.pick"
            ))

        # See https://github.com/pytorch/pytorch/pull/7889, in PyTorch 1.1 you have to call scheduler after:
        # The scheduler steps on every positive multiple of
        # lr_reduction_interval; the ratio test excludes epoch 0.
        if (epoch_num % params.lr_reduction_interval == 0
                and epoch_num / params.lr_reduction_interval > 0.9):
            print("Running the learning rate scheduler. Optimizer is:")
            lr_scheduler.step()
            print(optimizer)
        print(
            "=========================================================================================="
        )