def main(local_rank, c10d_backend, rdzv_init_url, max_world_size, classy_args):
    """Build the task described by ``classy_args`` and train it with ElasticTrainer.

    Args:
        local_rank: Forwarded to ``ElasticTrainer`` as ``local_rank``.
        c10d_backend: Process-group backend; must equal ``Backend.NCCL`` or
            ``Backend.GLOO``.
        rdzv_init_url: Forwarded to ``CoordinatorP2P`` as ``init_method``.
        max_world_size: Forwarded to ``CoordinatorP2P`` as
            ``max_num_trainers``.
        classy_args: Parsed argument namespace. The attributes read here are
            ``video_backend``, ``config_file``, ``checkpoint_folder``,
            ``pretrained_checkpoint_folder``, ``log_freq``,
            ``checkpoint_period``, ``profiler``, ``device`` and
            ``num_workers``.

    Raises:
        AssertionError: If ``c10d_backend`` is neither NCCL nor GLOO, or if a
            pretrained checkpoint is found for a task that is not a
            ``FineTuningTask``.
    """
    torch.manual_seed(0)
    set_video_backend(classy_args.video_backend)

    # Loads config, sets up task
    config = load_json(classy_args.config_file)
    task = build_task(config)

    # Load checkpoint, if available
    checkpoint = load_checkpoint(classy_args.checkpoint_folder)
    task.set_checkpoint(checkpoint)

    pretrained_checkpoint = load_checkpoint(classy_args.pretrained_checkpoint_folder)
    if pretrained_checkpoint is not None:
        assert isinstance(
            task, FineTuningTask
        ), "Can only use a pretrained checkpoint for fine tuning tasks"
        task.set_pretrained_checkpoint(pretrained_checkpoint)

    hooks = [
        LossLrMeterLoggingHook(classy_args.log_freq),
        ModelComplexityHook(),
        TimeMetricsHook(),
    ]

    if classy_args.checkpoint_folder != "":
        # vars() returns the namespace's own __dict__; copy it so that adding
        # the "config" entry does not also set an attribute on the caller's
        # classy_args object.
        args_dict = dict(vars(classy_args))
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                classy_args.checkpoint_folder,
                args_dict,
                checkpoint_period=classy_args.checkpoint_period,
            )
        )
    if classy_args.profiler:
        hooks.append(ProfilerHook())

    task.set_hooks(hooks)

    assert c10d_backend == Backend.NCCL or c10d_backend == Backend.GLOO
    if c10d_backend == torch.distributed.Backend.NCCL:
        # needed to enable NCCL error handling
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )
    trainer = ElasticTrainer(
        use_gpu=classy_args.device == "gpu",
        num_dataloader_workers=classy_args.num_workers,
        local_rank=local_rank,
        elastic_coordinator=coordinator,
        input_args={},
    )
    trainer.train(task)
def parse_args():
    """Return the ``(args, config)`` pair for this run.

    When hydra is available (experimental), ``_parse_hydra_args()`` is called
    and the module-level ``args`` and ``config`` are returned. Otherwise the
    arguments come from ``parse_train_arguments()`` and the config is loaded
    with ``load_json`` from ``args.config_file``.
    """
    # Both names are module globals for the whole function, so the
    # assignments in the non-hydra path rebind the globals as well.
    global args, config

    if not hydra_available:
        args = parse_train_arguments()
        config = load_json(args.config_file)
        return args, config

    # Presumably populates the module-level ``args`` / ``config`` — confirm
    # against _parse_hydra_args.
    _parse_hydra_args()
    return args, config
@hydra.main(config_path="hydra_configs", config_name="args")
def hydra_main(cfg):
    """Hydra entry point: check ``cfg`` and pass it on to ``main``.

    ``cfg`` itself is handed to ``main`` as the arguments object, together
    with ``cfg.config`` converted through ``OmegaConf.to_container``.
    """
    check_generic_args(cfg)
    main(cfg, omegaconf.OmegaConf.to_container(cfg.config))


# run all the things:
if __name__ == "__main__":
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info("Classy Vision's default training script.")

    # Import every module that lives next to this script. Classy Vision's
    # registration decorators register a module as a side effect of importing
    # it, so a config may reference a custom module (e.g. my_dataset) and the
    # script will know how to instantiate it.
    file_root = Path(__file__).parent
    import_all_packages_from_directory(file_root)

    if not hydra_available:
        # Plain path: parse the arguments ourselves and load the JSON config.
        args = parse_train_arguments()
        config = load_json(args.config_file)
        main(args, config)
    else:
        hydra_main()
def test_load_config(self):
    """The JSON file at ``self._json_config_file`` loads to ``self._get_config()``."""
    reference = self._get_config()
    self.assertEqual(util.load_json(self._json_config_file), reference)