def train_molecule_chef_qed_hiv(property_predictor,
                                predictor_label_to_optimize):
    params = Params()

    # Set the random seeds.
    rng = np.random.RandomState(5156416)
    torch.manual_seed(rng.choice(1000000))
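    # The NumPy RNG drives everything downstream (the sequence shuffling below)
    # and also seeds PyTorch's global RNG, so one seed (5156416) fixes the run.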

    # Set up data
    # == The property data
    train_prop_dataset, val_prop_dataset = (
        get_train_and_val_product_property_datasets(
            params, property_predictor, predictor_label_to_optimize))
    print("Created property datasets!")

    # == The sequence data
    stop_symbol_idx = mchef_config.get_num_graphs()  # comes after all the graphs
    trsfm = symbol_sequence_data.TrsfmSeqStrToArray(
        symbol_sequence_data.StopSymbolDetails(True, stop_symbol_idx),
        shuffle_seq_flag=True,
        rng=rng)
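    # TrsfmSeqStrToArray (presumably) turns each stored reactant-bag string into
    # an integer array, appending the stop symbol and, with shuffle_seq_flag=True,
    # shuffling the order within each bag via the shared `rng`.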

    reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_train, trsfm)
    reaction_train_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset, train_prop_dataset)
    train_dataloader = DataLoader(reaction_train_dataset,
                                  batch_size=params.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_datasets_func)

    reaction_bags_dataset_val = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_val, trsfm)
    reaction_val_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset_val, val_prop_dataset)
    val_dataloader = DataLoader(reaction_val_dataset,
                                batch_size=500,
                                shuffle=False,
                                collate_fn=collate_datasets_func)

    # == The graph data
    indices_to_graphs = atom_features_dataset.PickledGraphDataset(
        params.path_mol_details, params.cuda_details)
    assert stop_symbol_idx == len(
        indices_to_graphs), "stop symbol index should be after graphs"
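    # The molecule graphs occupy indices 0..len-1, so the stop symbol sits at
    # index len(indices_to_graphs); the assert guards against the pickled graph
    # dataset and mchef_config disagreeing about how many molecules there are.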

    # Set up Model
    mol_chef_params = get_mchef.MChefParams(params.cuda_details,
                                            indices_to_graphs,
                                            len(indices_to_graphs),
                                            stop_symbol_idx, params.latent_dim)
    mc_wae = get_mchef.get_mol_chef(mol_chef_params)
    mc_wae = params.cuda_details.return_cudafied(mc_wae)

    # set up trainer
    optimizer = optim.Adam(mc_wae.parameters(), lr=params.learning_rate)
    lr_scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=params.lr_reduction_factor)
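    # ExponentialLR multiplies the learning rate by lr_reduction_factor each time
    # lr_scheduler.step() is called (every lr_reduction_interval epochs, below).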

    # Set up some loggers
    tb_writer_train = tb_.get_tb_writer(
        f"{TB_LOGS_FILE}/{params.run_name}_train")
    tb_writer_val = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_val")

    def add_details_to_train(dict_to_add):
        for name, value in dict_to_add.items():
            tb_writer_train.add_scalar(name, value)

    train_log_helper = logging_tools.LogHelper([add_details_to_train])
    tb_writer_train.global_step = 0
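    # global_step is advanced manually in optimizer_step below, so scalars are
    # logged against the number of optimizer updates rather than epochs.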

    # Set up steps and setup funcs.
    def optimizer_step():
        optimizer.step()
        tb_writer_train.global_step += 1

    def setup_for_train():
        mc_wae._logger_manager = train_log_helper
        mc_wae.train()  # put in train mode

    def setup_for_val():
        tb_writer_val.global_step = tb_writer_train.global_step
        mc_wae._tb_logger = None  # turn off the more concise logging
        mc_wae.eval()

    # Run an initial validation
    setup_for_val()
    best_ae_obj_sofar = validation(val_dataloader, mc_wae, tb_writer_val,
                                   params.cuda_details,
                                   params.property_pred_factor,
                                   params.lambda_value)
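    # This pre-training validation gives the baseline objective that the first
    # epoch has to beat before its checkpoint can be flagged as the best so far.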

    # Train!
    for epoch_num in range(params.num_epochs):
        print(f"We are starting epoch {epoch_num}")
        tb_writer_train.add_scalar("epoch_num", epoch_num)
        setup_for_train()
        train(train_dataloader, mc_wae, optimizer, optimizer_step,
              params.cuda_details, tb_writer_train, params.lambda_value,
              params.property_pred_factor)

        print("Switching to eval.")
        setup_for_val()
        ae_obj = validation(val_dataloader, mc_wae, tb_writer_val,
                            params.cuda_details, params.property_pred_factor,
                            params.lambda_value)

        if ae_obj >= best_ae_obj_sofar:
            print("** Best LL found so far! :-) **")
            best_ae_obj_sofar = ae_obj
            best_flag = True
        else:
            best_flag = False

        save_checkpoint(
            dict(epochs_completed=epoch_num + 1,
                 wae_state_dict=mc_wae.state_dict(),
                 optimizer=optimizer.state_dict(),
                 learning_rate_scheduler=lr_scheduler.state_dict(),
                 ll_from_val=ae_obj,
                 wae_lambda_value=params.property_pred_factor,
                 stop_symbol_idx=stop_symbol_idx),
            is_best=best_flag,
            filename=path.join(
                CHKPT_FOLDER,
                f"{params.run_name}-{datetime.datetime.now().isoformat()}.pth.pick"
            ))

        # See https://github.com/pytorch/pytorch/pull/7889, in PyTorch 1.1 you have to call scheduler after:
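        # The second clause just skips epoch 0, so the decay fires once every
        # lr_reduction_interval epochs starting from the first full interval.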
        if (epoch_num % params.lr_reduction_interval == 0
                and epoch_num / params.lr_reduction_interval > 0.9):
            print("Running the learning rate scheduler. Optimizer is:")
            lr_scheduler.step()
            print(optimizer)
        print("==========================================================================================")
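# A minimal usage sketch (an assumption, not part of the original listing): the
# predictor and label below are hypothetical placeholders for whatever QED/HIV
# property model the surrounding project provides.
#
# if __name__ == '__main__':
#     qed_predictor = ...  # e.g. a model mapping product molecules to QED scores
#     train_molecule_chef_qed_hiv(qed_predictor, predictor_label_to_optimize='qed')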
Example 2
def main(params: Params):
    # == Data ==
    print("\nCreating the embeddings -- ie our y variable")
    latent_dataset_creator = LatentsDatasetCreator(params)
    train_latents_dataset = latent_dataset_creator(params.path_react_bags_train)
    val_latents_dataset = latent_dataset_creator(params.path_react_bags_val)

    print("\nCreating the graphs -- ie our x variable")
    graph_dataset_creator = GraphDatasetCreator()
    train_graphs_dataset = graph_dataset_creator(params.path_products_train)
    val_graphs_dataset = graph_dataset_creator(params.path_products_val)

    train_dataset = merged_dataset.MergedDataset(train_graphs_dataset, train_latents_dataset)
    val_dataset = merged_dataset.MergedDataset(val_graphs_dataset, val_latents_dataset)

    train_dataloader = data.DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True,
                                       collate_fn=_graph_and_latents_collate_func)
    val_dataloader = data.DataLoader(val_dataset, batch_size=500, shuffle=False,
                                     collate_fn=_graph_and_latents_collate_func)

    # == Model & Optimizer ==
    latent_dim = latent_dataset_creator.latent_dim
    print(f"Latent dim is {latent_dim}")
    model = graph_regressors.GNNThenMLP(latent_dim, params.gnn_hidden_size, params.edge_names, params.gnn_embedding_dim,
                                        params.cuda_details, params.gnn_time_steps)
    model = params.cuda_details.return_cudafied(model)
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-4)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 25], gamma=0.2)
    loss = _L2Loss()
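    # MultiStepLR cuts the learning rate by gamma=0.2 once the epoch count passes
    # each milestone (20 and 25); it is stepped once per epoch in
    # log_validation_results below. _L2Loss is defined elsewhere in this module;
    # a minimal sketch of what it could look like (an assumption, not the actual
    # implementation):
    #
    #     class _L2Loss(nn.Module):
    #         def forward(self, y_pred, y):
    #             return ((y_pred - y) ** 2).sum(dim=1).mean()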

    # == Create Training hooks for Pytorch-Ignite ==
    def _prepare_batch(batch, device, non_blocking):
        graphs, output = batch
        graphs = graphs.to_torch(params.cuda_details)
        output = params.cuda_details.return_cudafied(output)
        return graphs, output
    trainer = create_supervised_trainer(model, optimizer, loss, prepare_batch=_prepare_batch)
    evaluator = create_supervised_evaluator(model, metrics={'loss': Loss(loss)},  prepare_batch=_prepare_batch)
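    # Ignite's create_supervised_trainer/evaluator wrap the usual forward /
    # backward / optimizer.step() loop; _prepare_batch is the hook that moves
    # each (graphs, latents) batch onto the configured device before the model
    # sees it.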

    desc = "ITERATION - loss: {:.5f}"
    pbar = tqdm.tqdm(initial=0, leave=False, total=len(train_dataloader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter_num = (engine.state.iteration - 1) % len(train_dataloader) + 1

        if iter_num % params.log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(params.log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_dataloader)
        metrics = evaluator.state.metrics
        avg_loss = metrics['loss']
        tqdm.tqdm.write(
            "Training Results - Epoch: {}  Avg loss: {:.5f}"
                .format(engine.state.epoch, avg_loss)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_dataloader)
        metrics = evaluator.state.metrics
        avg_loss = metrics['loss']
        tqdm.tqdm.write(
            "Validation Results - Epoch: {}  Avg loss: {:.5f}"
                .format(engine.state.epoch, avg_loss)
        )
        lr_scheduler.step()

        # save a checkpoint
        torch.save(dict(engine_epochs=engine.state.epoch,
                        model_state_dict=model.state_dict()),
                   path.join(CHKPT_FOLDER,
                             f"latents_to-{datetime.datetime.now().isoformat()}.pth.pick"))

        # reset progress bar
        pbar.n = pbar.last_print_n = 0

    # == Train! ==
    print("\nTraining!")
    trainer.run(train_dataloader, max_epochs=params.num_epochs)
    pbar.close()
    torch.save(dict(engine_epochs=params.num_epochs,
                    model_state_dict=model.state_dict()),
               path.join(CHKPT_FOLDER, "final-weights.pth.pick"))