def evaluate_generated_graphs(generated_graphs, termination, nlls, start_time,
                              ts_properties, generation_batch_idx):
    """
    Evaluates one batch of generated graphs.

    For every batch, the final NLLs, the run time and the validity statistics
    are collected, and the molecules are written to disk via
    `util.write_molecules`. Molecular properties are expensive to calculate,
    so they are only computed, written to `generation.csv` and plotted for
    the first batch (i.e. when `generation_batch_idx` == 0).

    Args:
    ----
        generated_graphs (list) : Contains `GenerationGraph`s.
        termination (torch.Tensor) : Molecular termination details; contains
          1 at index if the corresponding graph in `generated_graphs` was
          "properly" terminated, 0 otherwise.
        nlls (torch.Tensor) : Contains the final NLL of each item in
          `generated_graphs`.
        start_time (float) : Program start time (presumably a `time.time()`
          timestamp, since `time.time()` is subtracted from it).
        ts_properties (dict) : Contains training set properties; merged with
          the generated-set properties for plotting.
        generation_batch_idx (int) : Generation batch index.
    """
    epoch_key = util.get_last_epoch()
    first_batch = generation_batch_idx == 0

    # the expensive molecular properties are only computed for the first batch
    prop_dict = (get_molecular_properties(molecules=generated_graphs,
                                          epoch_key=epoch_key,
                                          termination=termination)
                 if first_batch else {})

    # properties recorded for every batch
    prop_dict[(epoch_key, "final_nll")] = nlls
    prop_dict[(epoch_key, "run_time")] = round(time.time() - start_time, 2)

    job_dir = C.job_dir

    # write the molecules for this batch and collect their validity statistics;
    # `epoch_key[6:]` strips a 6-character prefix (presumably "Epoch ") --
    # TODO confirm against util.get_last_epoch
    batch_tag = "_".join((epoch_key[6:], str(generation_batch_idx)))
    fraction_valid, validity_tensor = util.write_molecules(
        molecules=generated_graphs, final_nlls=nlls, epoch=batch_tag)
    prop_dict[(epoch_key, "fraction_valid")] = fraction_valid
    prop_dict[(epoch_key, "validity_tensor")] = validity_tensor

    if not first_batch:
        return

    # only for the first batch: write the properties to CSV, then plot them
    # alongside the training set properties
    util.properties_to_csv(prop_dict=prop_dict,
                           csv_filename=f"{job_dir}generation.csv",
                           epoch_key=epoch_key,
                           append=True)
    merged_properties = {**prop_dict, **ts_properties}
    plot_molecular_properties(
        properties_dict=merged_properties,
        plot_filename=f"{job_dir}generation/features{epoch_key[6:]}.png")
def evaluate_model(valid_dataloader, train_dataloader, nll_per_action, model):
    """
    Scores the model with the UC-JSD and collects mean-NLL statistics.

    The NLLs of the validation and training sets are computed with `model`
    and compared against the NLLs of the generated set. The resulting scores
    are written to `validation.csv` via `util.write_validation_scores`, and
    the UC-JSD is also passed to `util.write_model_status`.

    Args:
    ----
        valid_dataloader (torch.utils.data.dataloader.DataLoader) : Validation
          set data.
        train_dataloader (torch.utils.data.dataloader.DataLoader) : Training
          set data.
        nll_per_action (torch.Tensor) : Contains NLLs per action for one batch
          of generated structures.
        model (module.SummationMPNN) : Neural net model to evaluate.
    """
    epoch_key = util.get_last_epoch()

    print("-- Calculating NLL statistics for validation set.", flush=True)
    nll_val, avg_nll_val = get_validation_nll(dataloader=valid_dataloader,
                                              model=model)
    print("-- Calculating NLL statistics for training set.", flush=True)
    nll_train, avg_nll_train = get_validation_nll(dataloader=train_dataloader,
                                                  model=model)

    # NOTE(review): this divides by `C.n_samples`, not by the number of
    # entries in `nll_per_action`; confirm the two agree, otherwise this is
    # not a true mean
    avg_nll_gen = torch.sum(nll_per_action) / C.n_samples

    # sum of the pairwise absolute differences between the three mean NLLs
    pairs = ((avg_nll_val, avg_nll_gen),
             (avg_nll_train, avg_nll_gen),
             (avg_nll_train, avg_nll_val))
    abs_nll_diff = sum(abs(lhs - rhs) for lhs, rhs in pairs)

    model_scores = dict(
        nll_val=nll_val,
        avg_nll_val=avg_nll_val,
        nll_train=nll_train,
        avg_nll_train=avg_nll_train,
        nll_gen=nll_per_action,
        avg_nll_gen=avg_nll_gen,
        abs_nll_diff=abs_nll_diff,
    )
    model_scores["UC-JSD"] = uc_jsd(nll_valid=nll_val,
                                    nll_train=nll_train,
                                    nll_sampled=nll_per_action)

    # results are appended only when the epoch key contains "Epoch"
    util.write_validation_scores(output_dir=C.job_dir,
                                 epoch_key=epoch_key,
                                 model_scores=model_scores,
                                 append="Epoch" in epoch_key)
    util.write_model_status(score=model_scores["UC-JSD"])
def evaluate_model(self, nll_per_action: torch.Tensor) -> None:
    """
    Calculates the model score, which is the UC-JSD. Also calculates the mean
    NLL per action of the validation, training, and generated sets. Writes the
    scores to disk via `util.write_validation_scores` and passes the UC-JSD to
    `util.write_model_status`.

    Args:
    ----
        nll_per_action (torch.Tensor) : Contains NLLs per action for a batch
          of generated graphs.
    """
    def _uc_jsd(nll_valid: torch.Tensor, nll_train: torch.Tensor,
                nll_sampled: torch.Tensor) -> float:
        """
        Computes the UC-JSD (metric used for the benchmark of generative
        models in Arús-Pous, J. et al., J. Chem. Inf., 2019, 1-13).

        Each NLL tensor is truncated to the length of the shortest one and
        normalized to sum to 1, so it can be treated as a discrete probability
        distribution. The UC-JSD is the Jensen-Shannon divergence of the three
        distributions: the mean KL divergence of each distribution from their
        average.

        Args:
        ----
            nll_valid (torch.Tensor) : NLLs for correct actions in validation
              set.
            nll_train (torch.Tensor) : NLLs for correct actions in training
              set.
            nll_sampled (torch.Tensor) : NLLs for sampled actions in the
              generated set.

        Returns:
        -------
            uc_jsd (float) : The UC-JSD (natural-log units); NaN if any of
              the inputs is empty.
        """
        min_len = min(len(nll_valid), len(nll_sampled), len(nll_train))
        if min_len == 0:
            # no common support -> the divergence is undefined
            return float("nan")

        # make all the distributions the same length (dim=0), then normalize
        distributions = [nll[:min_len] / torch.sum(nll[:min_len])
                         for nll in (nll_valid, nll_train, nll_sampled)]
        nll_mean = sum(distributions) / len(distributions)

        def _kl_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
            # KL(p || q) = sum p * log(p / q), with 0 * log(0 / q) := 0.
            # `torch.nn.functional.kl_div` is deliberately not used: it
            # expects log-probabilities as its `input` argument.
            # Where p > 0, q (the mean of p and two non-negative
            # distributions) is also > 0, so the log is finite.
            support = p > 0
            return torch.sum(p[support] * torch.log(p[support] / q[support]))

        uc_jsd = sum(_kl_divergence(dist, nll_mean)
                     for dist in distributions) / len(distributions)
        return float(uc_jsd)

    epoch_key = util.get_last_epoch()

    print("-- Calculating NLL statistics for validation set.", flush=True)
    valid_nll_list, avg_valid_nll = self.get_validation_nll(
        dataset="validation")
    print("-- Calculating NLL statistics for training set.", flush=True)
    train_nll_list, avg_train_nll = self.get_validation_nll(
        dataset="training")

    # get average final NLL for the generation set
    # NOTE(review): this divides by `constants.n_samples`, not by the number
    # of entries in `nll_per_action`; confirm the two agree
    avg_gen_nll = torch.sum(nll_per_action) / constants.n_samples

    # initialize dictionary with NLL statistics
    model_scores = {
        "nll_val": valid_nll_list,
        "avg_nll_val": avg_valid_nll,
        "nll_train": train_nll_list,
        "avg_nll_train": avg_train_nll,
        "nll_gen": nll_per_action,
        "avg_nll_gen": avg_gen_nll,
    }

    # get the UC-JSD and add it to the dictionary
    model_scores["UC-JSD"] = _uc_jsd(nll_valid=model_scores["nll_val"],
                                     nll_train=model_scores["nll_train"],
                                     nll_sampled=model_scores["nll_gen"])

    # write results to disk; append=False only for the epoch key
    # f"Epoch {constants.sample_every}", True otherwise
    util.write_validation_scores(
        output_dir=constants.job_dir,
        epoch_key=epoch_key,
        model_scores=model_scores,
        append=epoch_key != f"Epoch {constants.sample_every}")
    util.write_model_status(score=model_scores["UC-JSD"])