def train_molecule_chef_qed_hiv(property_predictor,
                                predictor_label_to_optimize):
    params = Params()

    # Set the random seeds.
    rng = np.random.RandomState(5156416)
    torch.manual_seed(rng.choice(1000000))

    # Set up data
    # == The property data
    train_prop_dataset, val_prop_dataset = (
        get_train_and_val_product_property_datasets(
            params, property_predictor, predictor_label_to_optimize))
    print("Created property datasets!")

    # == The sequence data
    stop_symbol_idx = mchef_config.get_num_graphs()  # comes after all the graphs
    trsfm = symbol_sequence_data.TrsfmSeqStrToArray(
        symbol_sequence_data.StopSymbolDetails(True, stop_symbol_idx),
        shuffle_seq_flag=True,
        rng=rng)

    reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_train, trsfm)
    reaction_train_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset, train_prop_dataset)
    train_dataloader = DataLoader(reaction_train_dataset,
                                  batch_size=params.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_datasets_func)

    reaction_bags_dataset_val = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_val, trsfm)
    reaction_val_dataset = merged_dataset.MergedDataset(
        reaction_bags_dataset_val, val_prop_dataset)
    val_dataloader = DataLoader(reaction_val_dataset,
                                batch_size=500,
                                shuffle=False,
                                collate_fn=collate_datasets_func)

    # == The graph data
    indices_to_graphs = atom_features_dataset.PickledGraphDataset(
        params.path_mol_details, params.cuda_details)
    assert stop_symbol_idx == len(
        indices_to_graphs), "stop symbol index should be after graphs"

    # Set up Model
    mol_chef_params = get_mchef.MChefParams(params.cuda_details,
                                            indices_to_graphs,
                                            len(indices_to_graphs),
                                            stop_symbol_idx, params.latent_dim)
    mc_wae = get_mchef.get_mol_chef(mol_chef_params)
    mc_wae = params.cuda_details.return_cudafied(mc_wae)

    # set up trainer
    optimizer = optim.Adam(mc_wae.parameters(), lr=params.learning_rate)
    lr_scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=params.lr_reduction_factor)
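    # ExponentialLR multiplies the learning rate by gamma (here params.lr_reduction_factor)
    # on every call to lr_scheduler.step(), which below happens once every
    # lr_reduction_interval epochs rather than every epoch.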

    # Set up some loggers
    tb_writer_train = tb_.get_tb_writer(
        f"{TB_LOGS_FILE}/{params.run_name}_train")
    tb_writer_val = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_val")

    def add_details_to_train(dict_to_add):
        for name, value in dict_to_add.items():
            tb_writer_train.add_scalar(name, value)

    train_log_helper = logging_tools.LogHelper([add_details_to_train])
    tb_writer_train.global_step = 0

    # Set up steps and setup funcs.
    def optimizer_step():
        optimizer.step()
        tb_writer_train.global_step += 1

    def setup_for_train():
        mc_wae._logger_manager = train_log_helper
        mc_wae.train()  # put in train mode

    def setup_for_val():
        tb_writer_val.global_step = tb_writer_train.global_step
        mc_wae._tb_logger = None  # turn off the concise TensorBoard logging during validation
        mc_wae.eval()

    # Run an initial validation
    setup_for_val()
    best_ae_obj_sofar = validation(val_dataloader, mc_wae, tb_writer_val,
                                   params.cuda_details,
                                   params.property_pred_factor,
                                   params.lambda_value)

    # Train!
    for epoch_num in range(params.num_epochs):
        print(f"We are starting epoch {epoch_num}")
        tb_writer_train.add_scalar("epoch_num", epoch_num)
        setup_for_train()
        train(train_dataloader, mc_wae, optimizer, optimizer_step,
              params.cuda_details, tb_writer_train, params.lambda_value,
              params.property_pred_factor)

        print("Switching to eval.")
        setup_for_val()
        ae_obj = validation(val_dataloader, mc_wae, tb_writer_val,
                            params.cuda_details, params.property_pred_factor,
                            params.lambda_value)

        if ae_obj >= best_ae_obj_sofar:
            print("** Best LL found so far! :-) **")
            best_ae_obj_sofar = ae_obj
            best_flag = True
        else:
            best_flag = False

        save_checkpoint(
            dict(epochs_completed=epoch_num + 1,
                 wae_state_dict=mc_wae.state_dict(),
                 optimizer=optimizer.state_dict(),
                 learning_rate_scheduler=lr_scheduler.state_dict(),
                 ll_from_val=ae_obj,
                 wae_lambda_value=params.property_pred_factor,
                 stop_symbol_idx=stop_symbol_idx),
            is_best=best_flag,
            filename=path.join(
                CHKPT_FOLDER,
                f"{params.run_name}-{datetime.datetime.now().isoformat()}.pth.pick"
            ))

        # See https://github.com/pytorch/pytorch/pull/7889: from PyTorch 1.1 onwards the
        # scheduler should be stepped after the optimizer.
        # Step the scheduler every lr_reduction_interval epochs (the second condition just
        # skips epoch 0, where the ratio is 0).
        if (epoch_num % params.lr_reduction_interval == 0
                and epoch_num / params.lr_reduction_interval > 0.9):
            print("Running the learning rate scheduler. Optimizer is:")
            lr_scheduler.step()
            print(optimizer)
        print(
            f"=========================================================================================="
        )
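
# A hypothetical invocation sketch (not part of the original listing): `property_predictor`
# would be a trained QED / HIV-activity model with whatever interface
# get_train_and_val_product_property_datasets expects, loaded elsewhere in the repository,
# and the label name (e.g. 'qed') picks which predicted property to optimise for.
#
#     property_predictor = ...  # e.g. an unpickled property regressor
#     train_molecule_chef_qed_hiv(property_predictor, predictor_label_to_optimize='qed')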

# Example 2

def main(params: Params):
    # # Random seeds
    rng = np.random.RandomState(4545)
    torch.manual_seed(2562)

    # # Data
    train_trees = train_utils.load_tuple_trees(params.train_tree_path, rng)
    val_trees = train_utils.load_tuple_trees(params.val_tree_path, rng)
    print(f"Number of trees in valid set: {len(val_trees)}")
    starting_reactants = train_utils.load_reactant_vocab(params.reactant_vocab_path)

    # # Model Params
    _dogae_params = {'latent_dim': 25,
                     'mol_graph_embedder_params': {'hidden_layer_size': 80,
                                                   'edge_names': ['single', 'double', 'triple', 'aromatic'],
                                                   'embedding_dim': 50,
                                                   'num_layers': 4},
                     'dag_graph_embedder_gnn_params': {'hlayer_size': 50,
                                                       'edge_names': ['reactions'],
                                                       'num_layers': 7},
                     'dag_embedder_aggr_type_s': 'FINAL_NODE',
                     'decoder_params': {'gru_insize': 50,
                                        'gru_hsize': 200,
                                        'num_layers': 3,
                                        'gru_dropout': 0.1,
                                        'max_steps': 100},
                     }
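    # (Roughly: `latent_dim` is the size of the latent code z; `mol_graph_embedder_params`
    # configures the GNN over the individual molecular graphs; `dag_graph_embedder_gnn_params`
    # configures the GNN over the synthesis DAG; and `decoder_params` configures the GRU
    # decoder -- an interpretation based on the key names, not on the underlying code.)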

    # # Model
    model, collate_func, model_other_parts = dogae_utils.load_dogae_model(
        params.device, params.log_for_reaction_predictor_path,
        dogae_train_details=dogae_utils.DogaeTrainDetails(starting_reactants, _dogae_params))

    # # Dataloaders
    train_dataloader = data.DataLoader(train_trees, batch_size=params.batch_size,
                                       num_workers=params.num_dataloader_workers, collate_fn=collate_func,
                                       shuffle=True)
    val_dataloader = data.DataLoader(val_trees, batch_size=params.val_batch_size, num_workers=params.num_dataloader_workers,
                                     collate_fn=collate_func, shuffle=False)

    # # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.milestones, gamma=params.gamma)
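    # MultiStepLR multiplies the learning rate by `gamma` once at each epoch listed in
    # `milestones`; the scheduler is stepped once per epoch at the end of
    # log_validation_results below.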

    # # Create a folder to store the checkpoints.
    os.makedirs(path.join(CHKPT_FOLDER, params.run_name))

    # # Tensorboard loggers
    tb_writer_train = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_train")
    tb_writer_train.global_step = 0
    tb_writer_train.add_hparams({**misc.unpack_class_into_params_dict(model_other_parts['hparams'], prepender="model:"),
                                **misc.unpack_class_into_params_dict(params, prepender="train:")}, {})
    tb_writer_val = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_val")
    def add_details_to_train(dict_to_add):
        for name, value in dict_to_add.items():
            tb_writer_train.add_scalar(name, value)
    train_log_helper = logging_tools.LogHelper([add_details_to_train])

    # # Create Ignite trainer
    def loss_fn(model, x):
        # Note that the graph embeddings are computed outside the model: they are needed by
        # both the encoder and the decoder, so computing them only once saves compute.
        x, new_order = x
        embedded_graphs = model.mol_embdr(x.molecular_graphs)
        x.molecular_graph_embeddings = embedded_graphs
        new_node_feats_for_dag = x.molecular_graph_embeddings[x.dags_for_inputs.node_features.squeeze(),:]
        x.dags_for_inputs.node_features = new_node_feats_for_dag

        # Then we can run the model forward.
        loss = -model(x, lambda_=params.lambda_value).mean()
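        # (The model presumably returns a per-example objective to be maximised, with lambda_
        #  weighting part of that objective, hence the negation and the mean to get a scalar
        #  loss to minimise -- an assumption based on the call above, not on the model code.)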

        return loss

    def prepare_batch(batch, device):
        x, new_order = batch
        x.inplace_to(device)
        return x, new_order

    def setup_for_val():
        tb_writer_val.global_step = tb_writer_train.global_step  # match the steps between the train and val tensorboards
        model._logger_manager = None  # turn off the more detailed logging while we go through the validation set / sampling

    trainer, timers = ignite_utils.create_unsupervised_trainer_timers(
        model, optimizer, loss_fn, device=params.device, prepare_batch=prepare_batch
    )
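    # (create_unsupervised_trainer_timers is a project helper; presumably, like ignite's
    # built-in trainers, each iteration it prepares the batch, computes loss_fn, backprops,
    # steps the optimizer and exposes the loss as engine.state.output -- which is what
    # log_training_loss reads below -- while `timers` collects per-phase timings.)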

    # # Now create the Ignite callbacks for dealing with the progressbar and performing validation etc...
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0, leave=False, total=len(train_dataloader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)

        tb_writer_train.global_step += 1
        tb_writer_train.add_scalar(TOTAL_LOSS_TB_STRING, engine.state.output)
        if engine.state.iteration % params.log_interval_histograms == 0:
            # Every log_interval_histograms steps, store histograms of the sampled z's to
            # check that we are not suffering from posterior collapse.
            model.encoder.shallow_dist._tb_logger = tb_writer_train  # turn it on for this step
        else:
            model.encoder.shallow_dist._tb_logger = None  # otherwise skip it, for speed

        pbar.update()

    @trainer.on(Events.EPOCH_STARTED)
    def setup_trainer(engine):
        timers.reset()
        model._logger_manager = train_log_helper
        tb_writer_train.add_scalar("epoch_num", engine.state.epoch)
        tqdm.write(f"\n\n# Epoch {engine.state.epoch} starting!")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        """
        This callback does validation at the end of a training epoch.
        
        Note we have two kinds of validation: simple and expensive. The simple runs the same loss calculation we use in
        training (i.e. with teacher forcing), so that it can evaluate the whole sequence at once. This means it runs
         quickly. On the other hand the expensive evaluation runs slower operations. It does greedy reconstruction
        of the sequence, always feeding in the previous chosen action as input to the next time steps. New reactions
        have to be predicted by calling the reaction predictor oracle. We also sample from the model at new places in
        latent space, which also requires the evaluation of one step at a time and calls to the Transformer.
        Given the expense of doing this form of validation it is done less frequently.
        """
        tqdm.write(f"\n\n# Epoch {engine.state.epoch} finished")
        tqdm.write(f"## Timings:\n{str(timers)}")
        tqdm.write(f"## Validation")

        # Setup for validation
        setup_for_val()
        run_expensive_ops_flag = (engine.state.epoch % params.expensive_ops_freq) == 0
        # ^ only run the ops that involve sampling infrequently, to avoid constantly bombarding the server.

        # ## Main validation code

        def val_func():
            # First look at performance on validation dataset
            dogae_utils.validation(val_dataloader, model, tb_writer_val, params.device, {'lambda_': params.lambda_value},
                                   run_expensive_ops_flag)

            # Then create some samples!
            if run_expensive_ops_flag:
                out_tuple_trees = dogae_utils.sample_n_from_prior(model, 10, rng)
                tuple_trees_as_text = ' ;\n'.join(map(str, out_tuple_trees))
                tb_writer_train.add_text("tuple_trees_sampled", f"```{tuple_trees_as_text}```")

        misc.try_but_pass(val_func, requests.exceptions.Timeout, False)
        # ^ the reaction server can time out, so on a timeout we just continue training and skip
        # this round of validation (it can be re-run later from the checkpoints).
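        # (`try_but_pass` is a project helper; presumably it just wraps the call in a
        #  try/except for the given exception type, with the final flag controlling some
        #  detail of the handling -- an assumption based on how it is used here.)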

        # Save a checkpoint
        time_chkpt = strftime("%y-%m-%d_%H:%M:%S", gmtime())
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'mol_to_graph_idx_for_reactants': model_other_parts['mol_to_graph_idx_for_reactants'],
            'run_name': params.run_name,
            'iter': engine.state.iteration,
            'epoch': engine.state.epoch,
            'lambda': params.lambda_value,
            'dogae_params': _dogae_params,
            },
            path.join(CHKPT_FOLDER, params.run_name, f'time-{time_chkpt}_epoch-{engine.state.epoch}.pth.pick'))

        # Reset the progress bar and run the LR scheduler.
        pbar.n = pbar.last_print_n = 0
        pbar.reset()
        lr_scheduler.step()

    @trainer.on(Events.STARTED)
    def initial_validation(engine):
        tqdm.write(f"# Initial Validation")

        # Switch the logger for validation:
        setup_for_val()

        dogae_utils.validation(val_dataloader, model, tb_writer_val, params.device,
                               {'lambda_': params.lambda_value})  # run before start training.

    # # Now we can train!
    print("Beginning Training!")
    trainer.run(train_dataloader, max_epochs=params.num_epochs)
    pbar.close()
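
# A hypothetical entry point (not part of the original listing), assuming `Params` is a
# config class defined elsewhere that exposes the attributes read above (train_tree_path,
# batch_size, run_name, device, num_epochs, ...):
#
#     if __name__ == '__main__':
#         main(Params())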