def main(config):
    """Top-level entry point: build a trainer from ``config`` and run the
    requested mode (``train`` / ``predict`` / ``run-relaxations``).

    Expects ``config`` to contain at least ``task``, ``model``, ``dataset``,
    ``optim``, ``identifier``, ``local_rank``, ``checkpoint`` and ``mode``;
    the module-level ``args`` (parsed CLI flags) and the project helpers
    (``distutils``, ``registry``, ``setup_imports``, ``ForcesTrainer``) must
    be in scope.

    Raises:
        ValueError: if ``config["mode"]`` is not one of the known modes.
    """
    if args.distributed:
        distutils.setup(config)
    try:
        setup_imports()
        # Resolve the trainer class from the registry; "simple" is the default.
        trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
        )

        if config["checkpoint"] is not None:
            trainer.load_pretrained(config["checkpoint"])

        start_time = time.time()
        mode = config["mode"]
        if mode == "train":
            trainer.train()
        elif mode == "predict":
            assert (
                trainer.test_loader is not None
            ), "Test dataset is required for making predictions"
            assert config["checkpoint"]
            results_file = "predictions"
            trainer.predict(
                trainer.test_loader,
                results_file=results_file,
                disable_tqdm=False,
            )
        elif mode == "run-relaxations":
            assert isinstance(
                trainer, ForcesTrainer
            ), "Relaxations are only possible for ForcesTrainer"
            assert (
                trainer.relax_dataset is not None
            ), "Relax dataset is required for making predictions"
            assert config["checkpoint"]
            trainer.run_relaxations()
        else:
            # Fix: previously an unknown mode fell through silently and the
            # function still printed a total time; fail loudly instead.
            raise ValueError(f"Unknown mode: {mode}")

        distutils.synchronize()
        if distutils.is_master():
            print("Total time taken = ", time.time() - start_time)
    finally:
        # Always tear down the process group, even if the run raised.
        if args.distributed:
            distutils.cleanup()
def __call__(self, config):
    """Build a trainer and task from ``config``, run the task, and log the
    elapsed wall-clock time on the master process.

    A deep copy of ``config`` is stashed on ``self.config`` so later
    mutations by the trainer do not leak back to the caller. Distributed
    setup happens only when ``args.distributed`` is set; teardown always
    runs via ``finally``.
    """
    setup_logging()
    self.config = copy.deepcopy(config)

    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()

        # Look up the trainer class first ("simple" is the default),
        # then instantiate it with everything the config provides.
        trainer_cls = registry.get_trainer_class(
            config.get("trainer", "simple")
        )
        self.trainer = trainer_cls(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            timestamp_id=config.get("timestamp_id", None),
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
            slurm=config.get("slurm", {}),
        )

        # The mode string doubles as the task registry key.
        task_cls = registry.get_task_class(config["mode"])
        self.task = task_cls(self.config)
        self.task.setup(self.trainer)

        start_time = time.time()
        self.task.run()
        distutils.synchronize()

        if distutils.is_master():
            logging.info(f"Total time taken: {time.time() - start_time}")
    finally:
        if args.distributed:
            distutils.cleanup()
def distributed_main(config):
    # Worker entry point for a spawned distributed process: initialize the
    # process group, run the regular main(), then tear the group down.
    # NOTE(review): cleanup() is skipped if main() raises — presumably the
    # worker process exits anyway, but confirm whether a try/finally (as in
    # main() itself) is wanted here.
    distutils.setup(config)
    main(config)
    distutils.cleanup()