Example #1
0
    def training_phase(self):
        """ Runs the training loop for `self.model`.

        Builds the training and validation dataloaders, prepares the training
        set properties and the output files, and then iterates over the epoch
        range returned by `self.define_model_and_optimizer()`. The running
        count of processed batches is threaded through `self.train_epoch()`
        from one epoch to the next.
        """
        # one dataloader per HDF file
        self.train_dataloader = self.get_dataloader(hdf_path=self.train_h5_path,
                                                    data_description="training set")
        self.valid_dataloader = self.get_dataloader(hdf_path=self.valid_h5_path,
                                                    data_description="validation set")

        self.get_ts_properties()
        self.initialize_output_files()

        first_epoch, stop_epoch = self.define_model_and_optimizer()

        print("* Beginning training.", flush=True)
        batches_so_far = 0
        for epoch_idx in range(first_epoch, stop_epoch):
            self.current_epoch = epoch_idx
            batches_so_far = self.train_epoch(n_processed_batches=batches_so_far)

            # the model is only evaluated on epochs that are a multiple of
            # `sample_every`; every other epoch gets a placeholder score
            if epoch_idx % self.C.sample_every != 0:
                util.write_model_status(score="NA")
            else:
                self.evaluate_model()

        self.print_time_elapsed()
Example #2
0
    def evaluate_model(self):
        """
        Scores the model on epochs that are a multiple of `sample_every`, then saves the
        model state. On any other epoch only a placeholder score ("NA") is written via
        `util.write_model_status()`.

        On an evaluation epoch the model is switched to eval mode, graphs are generated
        with autograd disabled, the analyzer is handed the current model and
        `self.nll_per_action`, and the model's `state_dict` is written to
        `<job_dir>model_restart_<epoch>.pth`.
        """
        if self.current_epoch % self.constants.sample_every != 0:
            # score not computed on this epoch, so use a placeholder
            util.write_model_status(score="NA")
            return

        # put layers such as norm/dropout into eval mode
        self.model.eval()

        # no gradients are needed for generation or scoring
        with torch.no_grad():
            # graphs needed for the evaluation (results are kept as `self` attributes)
            self.generate_graphs(n_samples=self.constants.n_samples, evaluation=True)

            print("* Evaluating model.", flush=True)
            self.analyzer.model = self.model
            self.analyzer.evaluate_model(nll_per_action=self.nll_per_action)

            print(f"* Saving model state at Epoch {self.current_epoch}.", flush=True)
            checkpoint_name = f"model_restart_{self.current_epoch}.pth"
            checkpoint_path = self.constants.job_dir + checkpoint_name
            # `pickle.HIGHEST_PROTOCOL` is a good fit for large objects
            torch.save(obj=self.model.state_dict(),
                       f=checkpoint_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)
Example #3
0
    def training_phase(self) -> None:
        """
        Runs the training job: sets up the dataloaders, output files and analyzer, then
        trains and validates the model once per epoch over the range returned by
        `self.define_model_and_optimizer()`. After each epoch the learning rate and the
        two average losses are written out and `self.evaluate_model()` is called.
        """
        print("* Setting up training job.", flush=True)

        # one dataloader per HDF file
        self.train_dataloader = self.get_dataloader(
            hdf_path=self.train_h5_path,
            data_description="training set",
        )
        self.valid_dataloader = self.get_dataloader(
            hdf_path=self.valid_h5_path,
            data_description="validation set",
        )

        self.load_training_set_properties()
        self.create_output_files()

        # the analyzer is given both dataloaders and the job start time
        self.analyzer = Analyzer(
            valid_dataloader=self.valid_dataloader,
            train_dataloader=self.train_dataloader,
            start_time=self.start_time,
        )

        first_epoch, stop_epoch = self.define_model_and_optimizer()

        print("* Beginning training.", flush=True)
        for epoch_idx in range(first_epoch, stop_epoch):
            self.current_epoch = epoch_idx

            mean_train_loss = self.train_epoch()
            mean_valid_loss = self.validation_epoch()

            # learning rate is read from the first parameter group
            current_lr = self.optimizer.param_groups[0]["lr"]
            util.write_model_status(
                epoch=self.current_epoch,
                lr=current_lr,
                training_loss=mean_train_loss,
                validation_loss=mean_valid_loss,
            )

            self.evaluate_model()

        self.print_time_elapsed()
Example #4
0
def evaluate_model(valid_dataloader, train_dataloader, nll_per_action, model):
    """ Scores `model` and writes the results to disk.

    Gathers NLL statistics for the validation and training sets, the average
    NLL of the generated set, the summed absolute pairwise difference between
    the three averages, and the UC-JSD. All scores are passed to
    `util.write_validation_scores()`, and the UC-JSD is additionally passed to
    `util.write_model_status()`.

    Args:
      valid_dataloader (torch.utils.data.dataloader.DataLoader) : Validation set
        data.
      train_dataloader (torch.utils.data.dataloader.DataLoader) : Training set
        data.
      nll_per_action (torch.Tensor) : Contains NLLs per action for one batch of
        generated structures.
      model (module.SummationMPNN) : Neural net model to evaluate.
    """
    epoch_key = util.get_last_epoch()

    # NLL statistics for the two reference sets
    print("-- Calculating NLL statistics for validation set.", flush=True)
    nll_valid, mean_nll_valid = get_validation_nll(dataloader=valid_dataloader,
                                                   model=model)
    print("-- Calculating NLL statistics for training set.", flush=True)
    nll_train, mean_nll_train = get_validation_nll(dataloader=train_dataloader,
                                                   model=model)

    # average NLL of the generated set: summed NLLs over `C.n_samples`
    mean_nll_gen = torch.sum(nll_per_action) / C.n_samples

    # summed absolute pairwise differences between the three averages
    valid_vs_gen = abs(mean_nll_valid - mean_nll_gen)
    train_vs_gen = abs(mean_nll_train - mean_nll_gen)
    train_vs_valid = abs(mean_nll_train - mean_nll_valid)
    total_abs_diff = valid_vs_gen + train_vs_gen + train_vs_valid

    # UC-JSD over the three NLL collections
    jsd_score = uc_jsd(nll_valid=nll_valid,
                       nll_train=nll_train,
                       nll_sampled=nll_per_action)

    model_scores = {
        "nll_val": nll_valid,
        "avg_nll_val": mean_nll_valid,
        "nll_train": nll_train,
        "avg_nll_train": mean_nll_train,
        "nll_gen": nll_per_action,
        "avg_nll_gen": mean_nll_gen,
        "abs_nll_diff": total_abs_diff,
        "UC-JSD": jsd_score,
    }

    # write results to disk; append only when the epoch key contains "Epoch"
    util.write_validation_scores(output_dir=C.job_dir,
                                 epoch_key=epoch_key,
                                 model_scores=model_scores,
                                 append="Epoch" in epoch_key)

    util.write_model_status(score=jsd_score)
Example #5
0
    def train_epoch(self, n_processed_batches=0):
        """ Runs a single training epoch over `self.train_dataloader`.

        For every batch: moves the batch to the GPU, runs the forward pass,
        clears gradients, computes the graph generation loss, backpropagates,
        steps the optimizer and updates the learning rate. When the epoch is
        done, the epoch number, current learning rate and mean batch loss are
        passed to `util.write_model_status()`.

        Args:
          n_processed_batches (int) : Number of batches processed before this
            epoch; incremented once per batch.

        Returns:
          n_processed_batches (int) : Updated batch count.
        """
        print(f"* Training epoch {self.current_epoch}.", flush=True)

        batches_in_epoch = len(self.train_dataloader)
        # one slot per batch, filled in as the epoch progresses
        epoch_losses = torch.zeros(batches_in_epoch, device="cuda")

        # make sure the model is in train mode
        self.model.train()

        progress = tqdm(enumerate(self.train_dataloader), total=batches_in_epoch)
        for step, raw_batch in progress:
            n_processed_batches += 1

            # a batch unpacks into node features, edge features and the target
            nodes, edges, target_output = (
                item.cuda(non_blocking=True) for item in raw_batch
            )

            # forward pass
            prediction = self.model(nodes, edges)

            # reset the gradients held by the model and by the optimizer
            self.model.zero_grad()
            self.optimizer.zero_grad()

            step_loss = loss.graph_generation_loss(output=prediction,
                                                   target_output=target_output)
            epoch_losses[step] = step_loss

            # backward pass and parameter update
            step_loss.backward()
            self.optimizer.step()

            # learning rate is updated after every batch
            self.update_learning_rate(n_batches=n_processed_batches)

        current_lr = self.optimizer.param_groups[0]["lr"]
        util.write_model_status(epoch=self.current_epoch,
                                lr=current_lr,
                                loss=torch.mean(epoch_losses))
        return n_processed_batches
Example #6
0
    def create_output_files(self) -> None:
        """
        Prepares the output files and the `generation/` subdirectory for a new job.
        Nothing is done when `self.constants.restart` is set, leaving any existing
        output files untouched.
        """
        if self.constants.restart:
            return

        print("* Touching output files.", flush=True)
        job_dir = self.constants.job_dir

        # start `generation.log` with the training set properties
        util.properties_to_csv(
            prop_dict=self.ts_properties,
            csv_filename=job_dir + "generation.log",
            epoch_key="Training set",
            append=False,
        )

        # start the `convergence.log` file
        util.write_model_status(append=False)

        # generation output is written to its own subdirectory
        os.makedirs(job_dir + "generation/", exist_ok=True)
Example #7
0
    def evaluate_model(self, nll_per_action: torch.Tensor) -> None:
        """
        Calculates the model score, which is the UC-JSD. Also calculates the mean NLL per action of
        the validation, training, and generated sets. Writes the scores to `validation.log`.

        Args:
        ----
            nll_per_action (torch.Tensor) : Contains NLLs per action a batch of generated graphs.
        """
        def _uc_jsd(nll_valid: torch.Tensor, nll_train: torch.Tensor,
                    nll_sampled: torch.Tensor) -> float:
            """
            Computes the UC-JSD (metric used for the benchmark of generative models in
            Arús-Pous, J. et al., J. Chem. Inf., 2019, 1-13), i.e. the Jensen-Shannon
            divergence between the three normalized NLL distributions:

                JSD = ( KL(valid || m) + KL(train || m) + KL(sampled || m) ) / 3

            where `m` is the element-wise average of the three distributions.

            Args:
            ----
                nll_valid (torch.Tensor) : NLLs for correct actions in validation set.
                nll_train (torch.Tensor) : NLLs for correct actions in training set.
                nll_sampled (torch.Tensor) : NLLs for sampled actions in the generated set.

            Returns:
            -------
                float : The UC-JSD.
            """
            min_len = min(len(nll_valid), len(nll_sampled), len(nll_train))

            # truncate to a common length (dim=0) and normalize each to sum to 1
            distributions = [
                nll[:min_len] / torch.sum(nll[:min_len])
                for nll in (nll_valid, nll_train, nll_sampled)
            ]

            # mixture (average) distribution
            mixture = (distributions[0] + distributions[1] + distributions[2]) / 3

            # `kl_div(input, target)` expects `input` in log-space and evaluates
            # KL(target || exp(input)), so the mixture goes in as a log and each
            # distribution as the target. The clamp keeps log() finite where every
            # distribution is zero (those entries contribute nothing to the sum).
            log_mixture = torch.log(
                torch.clamp(mixture, min=torch.finfo(mixture.dtype).tiny))

            # `reduction="sum"` gives the KL divergence itself, not an element-wise mean
            uc_jsd = sum(
                torch.nn.functional.kl_div(log_mixture, dist, reduction="sum")
                for dist in distributions
            ) / 3

            return float(uc_jsd)

        epoch_key = util.get_last_epoch()

        print("-- Calculating NLL statistics for validation set.", flush=True)
        valid_nll_list, avg_valid_nll = self.get_validation_nll(
            dataset="validation")

        print("-- Calculating NLL statistics for training set.", flush=True)
        train_nll_list, avg_train_nll = self.get_validation_nll(
            dataset="training")

        # get average final NLL for the generation set
        avg_gen_nll = torch.sum(nll_per_action) / constants.n_samples

        # initialize dictionary with NLL statistics
        model_scores = {
            "nll_val": valid_nll_list,
            "avg_nll_val": avg_valid_nll,
            "nll_train": train_nll_list,
            "avg_nll_train": avg_train_nll,
            "nll_gen": nll_per_action,
            "avg_nll_gen": avg_gen_nll,
        }

        # get the UC-JSD and add it to the dictionary
        model_scores["UC-JSD"] = _uc_jsd(nll_valid=model_scores["nll_val"],
                                         nll_train=model_scores["nll_train"],
                                         nll_sampled=model_scores["nll_gen"])

        # write results to disk; the first evaluation (at epoch `sample_every`)
        # starts a fresh file, every later one appends
        util.write_validation_scores(
            output_dir=constants.job_dir,
            epoch_key=epoch_key,
            model_scores=model_scores,
            append=epoch_key != f"Epoch {constants.sample_every}")
        util.write_model_status(score=model_scores["UC-JSD"])