def __init__(self, params: Params):
    # == Let's get the model! ==
    indices_to_graphs = atom_features_dataset.PickledGraphDataset(params.path_mol_details,
                                                                  params.cuda_details)
    chkpt = torch.load(params.weights_to_use, map_location=params.cuda_details.device_str)
    # ^ Load the checkpoint as we can get some model details from there.
    latent_dim = chkpt['wae_state_dict']['latent_prior._params'].shape[1] // 2
    print(f"Inferred a latent dimensionality of {latent_dim}")
    assert chkpt['stop_symbol_idx'] == mchef_config.get_num_graphs()

    mol_chef_params = get_mchef.MChefParams(params.cuda_details, indices_to_graphs,
                                            len(indices_to_graphs), chkpt['stop_symbol_idx'],
                                            latent_dim)
    molchef_wae = get_mchef.get_mol_chef(mol_chef_params)
    molchef_wae = params.cuda_details.return_cudafied(molchef_wae)
    molchef_wae.load_state_dict(chkpt['wae_state_dict'])

    self.ae = molchef_wae
    self.cuda_details = params.cuda_details
    self.rng = np.random.RandomState(1001)
    self.latent_dim = latent_dim
    self.stop_symbol_indx = chkpt['stop_symbol_idx']
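
# A minimal sketch (not part of the original class) of why the `// 2` above recovers the
# latent dimensionality: we assume `latent_prior._params` stores the prior's means and
# log-variances concatenated along the feature axis, i.e. a [1, 2 * latent_dim] tensor.
# The names below are hypothetical and for illustration only.
def _example_infer_latent_dim():
    dummy_latent_dim = 25
    dummy_prior_params = torch.zeros(1, 2 * dummy_latent_dim)  # [means | log-variances]
    inferred = dummy_prior_params.shape[1] // 2
    assert inferred == dummy_latent_dim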
def train_molecule_chef_qed_hiv(property_predictor, predictor_label_to_optimize):
    params = Params()

    # Set the random seeds.
    rng = np.random.RandomState(5156416)
    torch.manual_seed(rng.choice(1000000))

    # Set up data
    # == The property data
    train_prop_dataset, val_prop_dataset = (
        get_train_and_val_product_property_datasets(params, property_predictor,
                                                    predictor_label_to_optimize))
    print("Created property datasets!")

    # == The sequence data
    stop_symbol_idx = mchef_config.get_num_graphs()  # comes after all the graphs
    trsfm = symbol_sequence_data.TrsfmSeqStrToArray(
        symbol_sequence_data.StopSymbolDetails(True, stop_symbol_idx),
        shuffle_seq_flag=True, rng=rng)

    reaction_bags_dataset = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_train, trsfm)
    reaction_train_dataset = merged_dataset.MergedDataset(reaction_bags_dataset,
                                                          train_prop_dataset)
    train_dataloader = DataLoader(reaction_train_dataset, batch_size=params.batch_size,
                                  shuffle=True, collate_fn=collate_datasets_func)

    reaction_bags_dataset_val = symbol_sequence_data.SymbolSequenceDataset(
        params.path_react_bags_val, trsfm)
    reaction_val_dataset = merged_dataset.MergedDataset(reaction_bags_dataset_val,
                                                        val_prop_dataset)
    val_dataloader = DataLoader(reaction_val_dataset, batch_size=500, shuffle=False,
                                collate_fn=collate_datasets_func)

    # == The graph data
    indices_to_graphs = atom_features_dataset.PickledGraphDataset(params.path_mol_details,
                                                                  params.cuda_details)
    assert stop_symbol_idx == len(indices_to_graphs), "stop symbol index should be after graphs"

    # Set up model
    mol_chef_params = get_mchef.MChefParams(params.cuda_details, indices_to_graphs,
                                            len(indices_to_graphs), stop_symbol_idx,
                                            params.latent_dim)
    mc_wae = get_mchef.get_mol_chef(mol_chef_params)
    mc_wae = params.cuda_details.return_cudafied(mc_wae)

    # Set up trainer
    optimizer = optim.Adam(mc_wae.parameters(), lr=params.learning_rate)
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                    gamma=params.lr_reduction_factor)

    # Set up some loggers
    tb_writer_train = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_train")
    tb_writer_val = tb_.get_tb_writer(f"{TB_LOGS_FILE}/{params.run_name}_val")

    def add_details_to_train(dict_to_add):
        for name, value in dict_to_add.items():
            tb_writer_train.add_scalar(name, value)

    train_log_helper = logging_tools.LogHelper([add_details_to_train])
    tb_writer_train.global_step = 0

    # Set up steps and setup funcs.
    def optimizer_step():
        optimizer.step()
        tb_writer_train.global_step += 1

    def setup_for_train():
        mc_wae._logger_manager = train_log_helper
        mc_wae.train()  # put in train mode

    def setup_for_val():
        tb_writer_val.global_step = tb_writer_train.global_step
        mc_wae._tb_logger = None  # turn off the more concise logging
        mc_wae.eval()

    # Run an initial validation
    setup_for_val()
    best_ae_obj_sofar = validation(val_dataloader, mc_wae, tb_writer_val, params.cuda_details,
                                   params.property_pred_factor, params.lambda_value)

    # Train!
    for epoch_num in range(params.num_epochs):
        print(f"We are starting epoch {epoch_num}")
        tb_writer_train.add_scalar("epoch_num", epoch_num)

        setup_for_train()
        train(train_dataloader, mc_wae, optimizer, optimizer_step, params.cuda_details,
              tb_writer_train, params.lambda_value, params.property_pred_factor)

        print("Switching to eval.")
        setup_for_val()
        ae_obj = validation(val_dataloader, mc_wae, tb_writer_val, params.cuda_details,
                            params.property_pred_factor, params.lambda_value)

        if ae_obj >= best_ae_obj_sofar:
            print("** Best LL found so far! :-) **")
            best_ae_obj_sofar = ae_obj
            best_flag = True
        else:
            best_flag = False

        save_checkpoint(
            dict(epochs_completed=epoch_num + 1,
                 wae_state_dict=mc_wae.state_dict(),
                 optimizer=optimizer.state_dict(),
                 learning_rate_scheduler=lr_scheduler.state_dict(),
                 ll_from_val=ae_obj,
                 wae_lambda_value=params.property_pred_factor,
                 stop_symbol_idx=stop_symbol_idx),
            is_best=best_flag,
            filename=path.join(
                CHKPT_FOLDER,
                f"{params.run_name}-{datetime.datetime.now().isoformat()}.pth.pick"))

        # See https://github.com/pytorch/pytorch/pull/7889 -- from PyTorch 1.1 onwards the
        # scheduler should be stepped after the optimizer, hence it is called here at the
        # end of the epoch.
        if (epoch_num % params.lr_reduction_interval == 0
                and epoch_num / params.lr_reduction_interval > 0.9):
            print("Running the learning rate scheduler. Optimizer is:")
            lr_scheduler.step()
            print(optimizer)

        print("==========================================================================================")