def run(self, model, training_set, epoch):
        """Execute one epoch step of the training loop.

        Optionally collects model statistics, advances the LR scheduler,
        saves a checkpoint when due, and tells the caller whether training
        should continue.

        :param model: Model being trained (must expose ``save``).
        :param training_set: Training data passed through to stats collection.
        :param epoch: Current epoch number (compared against ``self.epochs``
                      to detect the final epoch).
        :return: True while the learning rate is still >= the configured
                 minimum; False once it has dropped below it.
        """
        collect_now = (self.collect_stats_frequency > 0
                       and epoch % self.collect_stats_frequency == 0)
        if collect_now:
            # Each stats run consumes the next validation set from the
            # rotating iterator and snapshots the current learning rate.
            ma.CollectStatsFromModel(
                model=model,
                epoch=epoch,
                training_set=training_set,
                validation_set=next(self.validation_sets),
                writer=self._writer,
                other_values={"lr": self.get_lr()},
                logger=self.logger,
                sample_size=self.collect_stats_params["sample_size"]).run()

        self.lr_scheduler.step(epoch=epoch)

        lr_reached_min = self.get_lr() < self.lr_params["min"]
        is_final_epoch = (epoch == self.epochs)
        save_due = (self.save_frequency > 0
                    and epoch % self.save_frequency == 0)
        if lr_reached_min or is_final_epoch or save_due:
            model.save(self._model_path(epoch))

        # Periodically recreate the writer to bound its in-memory cache.
        if self._writer and epoch % self.WRITER_CACHE_EPOCHS == 0:
            self._reset_writer()

        return not lr_reached_min
예제 #2
0
    def run(self, model, training_set, epoch):
        """Execute one epoch step of the training loop.

        Like the basic variant, but records the joined-JSD metric from each
        stats run and feeds a moving average of it to ReduceLROnPlateau
        schedulers; other schedulers are stepped by epoch only.

        :param model: Model being trained (must expose ``save``).
        :param training_set: Training data passed through to stats collection.
        :param epoch: Current epoch number (compared against ``self.epochs``
                      to detect the final epoch).
        :return: True while the learning rate is still >= the configured
                 minimum; False once it has dropped below it.
        """
        collect_now = (self.collect_stats_frequency > 0
                       and epoch % self.collect_stats_frequency == 0)
        if collect_now:
            stats = ma.CollectStatsFromModel(
                model=model,
                epoch=epoch,
                training_set=training_set,
                validation_set=next(self.validation_sets),
                writer=self._writer,
                other_values={"lr": self.get_lr()},
                logger=self.logger,
                sample_size=self.collect_stats_params["sample_size"],
                to_mol_func=uc.get_mol_func(
                    self.collect_stats_params["smiles_type"])).run()
            self._metric_epochs.append(stats["nll_plot/jsd_joined"])

        plateau_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
        if isinstance(self.lr_scheduler, plateau_cls):
            # Smooth the plateau metric over the last N stats runs.
            # NOTE(review): np.mean of an empty slice yields nan if no stats
            # have been collected yet — presumably the config guarantees
            # collect_stats_frequency > 0 here; verify against callers.
            window = self.lr_params["average_steps"]
            metric = np.mean(self._metric_epochs[-window:])
            self.lr_scheduler.step(metric, epoch=epoch)
        else:
            self.lr_scheduler.step(epoch=epoch)

        lr_reached_min = self.get_lr() < self.lr_params["min"]
        save_now = (lr_reached_min
                    or epoch == self.epochs
                    or (self.save_frequency > 0
                        and epoch % self.save_frequency == 0))
        if save_now:
            model.save(self._model_path(epoch))

        # Periodically recreate the writer to bound its in-memory cache.
        if self._writer and epoch % self.WRITER_CACHE_EPOCHS == 0:
            self._reset_writer()

        return not lr_reached_min
예제 #3
0
def main():
    """Collect statistics from a trained SMILES model.

    Loads the model in sampling mode, reads the training and validation
    SMILES files fully into memory, runs one stats-collection pass for the
    given epoch, and writes the results to a TensorBoard log directory.
    """
    args = parse_args()

    model = mm.Model.load_from_file(args.model_path, mode="sampling")
    training_set = list(uc.read_smi_file(args.training_set_path))
    validation_set = list(uc.read_smi_file(args.validation_set_path))

    writer = tbx.SummaryWriter(log_dir=args.log_path)
    try:
        ma.CollectStatsFromModel(model,
                                 args.epoch,
                                 training_set,
                                 validation_set,
                                 writer,
                                 sample_size=args.sample_size,
                                 with_weights=args.with_weights,
                                 to_mol_func=uc.get_mol_func(args.smiles_type),
                                 logger=LOG).run()
    finally:
        # Always flush/close the event file, even if stats collection raises;
        # otherwise the TensorBoard log can be left truncated.
        writer.close()
예제 #4
0
def main():
    """Collect statistics from a trained decorator model.

    Loads the decorator model in sampling mode, reads the two-field CSV
    training and validation sets fully into memory, runs one stats-collection
    pass for the given epoch, and writes the results to a TensorBoard log
    directory.
    """
    args = parse_args()

    model = mm.DecoratorModel.load_from_file(args.model_path, mode="sampling")
    training_set = list(uc.read_csv_file(args.training_set_path, num_fields=2))
    validation_set = list(
        uc.read_csv_file(args.validation_set_path, num_fields=2))

    writer = tbx.SummaryWriter(log_dir=args.log_path)
    try:
        ma.CollectStatsFromModel(model,
                                 args.epoch,
                                 training_set,
                                 validation_set,
                                 writer,
                                 sample_size=args.sample_size,
                                 decoration_type=args.decoration_type,
                                 with_weights=args.with_weights,
                                 logger=LOG).run()
    finally:
        # Always flush/close the event file, even if stats collection raises;
        # otherwise the TensorBoard log can be left truncated.
        writer.close()