# Example no. 1 (score: 0)
def main(config):
    """Build a trainer from ``config`` and run the requested mode.

    Supported modes (``config["mode"]``): ``"train"``, ``"predict"``, and
    ``"run-relaxations"``.  When ``args.distributed`` is set, the distributed
    process group is initialized before the run and always torn down
    afterwards, even if the run raises.

    Args:
        config (dict): run configuration; must contain at least ``task``,
            ``model``, ``dataset``, ``optim``, ``identifier``, ``local_rank``,
            and ``mode``.  Other keys have sensible defaults.

    Raises:
        ValueError: if prerequisites for the selected mode are missing
            (test dataset, relaxation dataset, or checkpoint).
        TypeError: if relaxations are requested for a non-ForcesTrainer.
    """
    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()
        trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
        )
        # .get() instead of [] so a config without a "checkpoint" key
        # does not raise KeyError.
        checkpoint = config.get("checkpoint")
        if checkpoint is not None:
            trainer.load_pretrained(checkpoint)

        start_time = time.time()
        mode = config["mode"]

        if mode == "train":
            trainer.train()

        elif mode == "predict":
            # Explicit raises instead of asserts: asserts are stripped
            # under ``python -O`` and would silently skip validation.
            if trainer.test_loader is None:
                raise ValueError(
                    "Test dataset is required for making predictions"
                )
            if not checkpoint:
                raise ValueError("A checkpoint is required to predict")
            trainer.predict(
                trainer.test_loader,
                results_file="predictions",
                disable_tqdm=False,
            )

        elif mode == "run-relaxations":
            if not isinstance(trainer, ForcesTrainer):
                raise TypeError(
                    "Relaxations are only possible for ForcesTrainer"
                )
            if trainer.relax_dataset is None:
                raise ValueError(
                    "Relax dataset is required for making predictions"
                )
            if not checkpoint:
                raise ValueError("A checkpoint is required for relaxations")
            trainer.run_relaxations()

        distutils.synchronize()

        if distutils.is_master():
            print("Total time taken = ", time.time() - start_time)

    finally:
        # Tear down the process group even on failure.
        if args.distributed:
            distutils.cleanup()
# Example no. 2 (score: 0) — File: main.py, Project: wood-b/ocp
    def __call__(self, config):
        """Execute the task described by ``config``.

        Stores a deep copy of the configuration on ``self.config``,
        instantiates the registered trainer and task (``self.trainer``,
        ``self.task``), runs the task, and — when ``args.distributed`` is
        set — guarantees the process group is cleaned up afterwards.
        """
        setup_logging()
        self.config = copy.deepcopy(config)

        if args.distributed:
            distutils.setup(config)

        try:
            setup_imports()

            # Resolve the trainer class from the registry, then build it.
            trainer_cls = registry.get_trainer_class(
                config.get("trainer", "simple")
            )
            self.trainer = trainer_cls(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier=config["identifier"],
                timestamp_id=config.get("timestamp_id", None),
                run_dir=config.get("run_dir", "./"),
                is_debug=config.get("is_debug", False),
                is_vis=config.get("is_vis", False),
                print_every=config.get("print_every", 10),
                seed=config.get("seed", 0),
                logger=config.get("logger", "tensorboard"),
                local_rank=config["local_rank"],
                amp=config.get("amp", False),
                cpu=config.get("cpu", False),
                slurm=config.get("slurm", {}),
            )

            # The task class is keyed by the run mode (e.g. "train").
            task_cls = registry.get_task_class(config["mode"])
            self.task = task_cls(self.config)
            self.task.setup(self.trainer)

            start_time = time.time()
            self.task.run()

            distutils.synchronize()
            if distutils.is_master():
                logging.info(f"Total time taken: {time.time() - start_time}")
        finally:
            if args.distributed:
                distutils.cleanup()
# Example no. 3 (score: 0)
def distributed_main(config):
    """Distributed worker entry point: set up the process group, run
    :func:`main`, and always tear the process group down.

    Args:
        config (dict): run configuration passed through to ``main``.
    """
    distutils.setup(config)
    try:
        main(config)
    finally:
        # Without try/finally a failure inside main() would leak the
        # process group; mirrors the cleanup pattern used in main().
        distutils.cleanup()