示例#1
0
    def run(self, model, training_set, epoch):
        """Perform the per-epoch bookkeeping: stats, LR schedule, checkpoint.

        Steps, in order:
          1. Every ``self.collect_stats_frequency`` epochs (when > 0), collect
             statistics against the next validation set and record the
             ``"nll_plot/jsd_joined"`` value in ``self._metric_epochs``.
          2. Advance the learning-rate scheduler.
          3. Save the model when the LR fell below the minimum, on the epoch
             equal to ``self.epochs``, or every ``self.save_frequency`` epochs.
          4. Reset the writer every ``self.WRITER_CACHE_EPOCHS`` epochs.

        Args:
            model: Model under training; must provide ``save(path)``.
            training_set: Training data forwarded to the stats collector.
            epoch: Current epoch number.

        Returns:
            bool: ``False`` once ``self.get_lr()`` is below
            ``self.lr_params["min"]``, ``True`` otherwise (presumably used by
            the caller to decide whether to keep training -- confirm).
        """
        stats_due = (self.collect_stats_frequency > 0
                     and epoch % self.collect_stats_frequency == 0)
        if stats_due:
            validation_set = next(self.validation_sets)
            other_values = {"lr": self.get_lr()}

            stats = ma.CollectStatsFromModel(
                model=model,
                epoch=epoch,
                training_set=training_set,
                validation_set=validation_set,
                writer=self._writer,
                other_values=other_values,
                logger=self.logger,
                sample_size=self.collect_stats_params["sample_size"],
                to_mol_func=uc.get_mol_func(
                    self.collect_stats_params["smiles_type"])).run()
            self._metric_epochs.append(stats["nll_plot/jsd_joined"])

        if isinstance(self.lr_scheduler,
                      torch.optim.lr_scheduler.ReduceLROnPlateau):
            recent_metrics = self._metric_epochs[
                -self.lr_params["average_steps"]:]
            # Without any collected metric np.mean([]) is NaN (plus a
            # RuntimeWarning); feeding NaN to the plateau scheduler counts as
            # a "bad" epoch and can lower the LR for no real reason. Only step
            # once there is something to average.
            if recent_metrics:
                self.lr_scheduler.step(np.mean(recent_metrics), epoch=epoch)
        else:
            self.lr_scheduler.step(epoch=epoch)

        lr_reached_min = (self.get_lr() < self.lr_params["min"])
        periodic_save = (self.save_frequency > 0
                         and epoch % self.save_frequency == 0)
        if lr_reached_min or self.epochs == epoch or periodic_save:
            model.save(self._model_path(epoch))

        if self._writer and (epoch % self.WRITER_CACHE_EPOCHS == 0):
            self._reset_writer()

        return not lr_reached_min
def main():
    """Write ``args.num_files`` files of randomized SMILES for one input set.

    The input file ``args.input_smi_path`` is loaded through ``SC`` (assumed
    to be a Spark context -- confirm), repartitioned, converted to molecules
    and persisted. For every output file the molecules are randomized again
    and written, one SMILES per line, to
    ``<args.output_smi_folder_path>/<NNN>.smi`` where ``NNN`` is the
    zero-padded file index.
    """
    args = parse_args()

    randomize_func = functools.partial(uc.randomize_smiles,
                                       random_type=args.random_type)
    to_mol_func = uc.get_mol_func(args.smiles_type)
    to_smiles_func = uc.get_smi_func(args.smiles_type)

    def randomized_smiles(mol):
        # Randomize the molecule, then render it back to a SMILES string.
        return to_smiles_func(randomize_func(mol))

    # Build the molecule RDD step by step; it is persisted because it is
    # mapped once per output file below.
    lines_rdd = SC.textFile(args.input_smi_path)
    partitioned_rdd = lines_rdd.repartition(args.num_partitions)
    mols_rdd = partitioned_rdd.map(to_mol_func).persist()

    os.makedirs(args.output_smi_folder_path, exist_ok=True)

    for file_index in range(args.num_files):
        target_path = "{}/{:03d}.smi".format(args.output_smi_folder_path,
                                             file_index)
        with open(target_path, "w+") as out_file:
            # The file is opened before the Spark job runs, as before.
            randomized = mols_rdd.map(randomized_smiles).collect()
            out_file.writelines("{}\n".format(smi) for smi in randomized)
示例#3
0
def main():
    """Collect statistics from a saved model and write them to a summary log.

    Loads the model at ``args.model_path`` in ``"sampling"`` mode, reads the
    training and validation SMILES files into lists, and runs
    ``ma.CollectStatsFromModel`` with a summary writer that logs to
    ``args.log_path``. The writer is closed even when the collection fails.
    """
    args = parse_args()

    model = mm.Model.load_from_file(args.model_path, mode="sampling")
    training_set = list(uc.read_smi_file(args.training_set_path))
    validation_set = list(uc.read_smi_file(args.validation_set_path))

    writer = tbx.SummaryWriter(log_dir=args.log_path)
    try:
        ma.CollectStatsFromModel(model,
                                 args.epoch,
                                 training_set,
                                 validation_set,
                                 writer,
                                 sample_size=args.sample_size,
                                 with_weights=args.with_weights,
                                 to_mol_func=uc.get_mol_func(args.smiles_type),
                                 logger=LOG).run()
    finally:
        # Always release the writer so the log file is closed even if the
        # stats run raises.
        writer.close()