def test_model_complexity_hook(self) -> None:
    model_configs = get_test_model_configs()

    task = get_test_classy_task()
    task.prepare()

    # create a model complexity hook
    model_complexity_hook = ModelComplexityHook()

    for model_config in model_configs:
        model = build_model(model_config)
        task.base_model = model
        with self.assertLogs():
            model_complexity_hook.on_start(task)
def main(local_rank, c10d_backend, rdzv_init_url, max_world_size, classy_args):
    torch.manual_seed(0)
    set_video_backend(classy_args.video_backend)

    # Loads config, sets up task
    config = load_json(classy_args.config_file)

    task = build_task(config)

    # Load checkpoint, if available
    checkpoint = load_checkpoint(classy_args.checkpoint_folder)
    task.set_checkpoint(checkpoint)

    pretrained_checkpoint = load_checkpoint(classy_args.pretrained_checkpoint_folder)
    if pretrained_checkpoint is not None:
        assert isinstance(
            task, FineTuningTask
        ), "Can only use a pretrained checkpoint for fine tuning tasks"
        task.set_pretrained_checkpoint(pretrained_checkpoint)

    hooks = [
        LossLrMeterLoggingHook(classy_args.log_freq),
        ModelComplexityHook(),
        TimeMetricsHook(),
    ]

    if classy_args.checkpoint_folder != "":
        args_dict = vars(classy_args)
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                classy_args.checkpoint_folder,
                args_dict,
                checkpoint_period=classy_args.checkpoint_period,
            )
        )
    if classy_args.profiler:
        hooks.append(ProfilerHook())

    task.set_hooks(hooks)

    assert c10d_backend == Backend.NCCL or c10d_backend == Backend.GLOO
    if c10d_backend == torch.distributed.Backend.NCCL:
        # needed to enable NCCL error handling
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )
    trainer = ElasticTrainer(
        use_gpu=classy_args.device == "gpu",
        num_dataloader_workers=classy_args.num_workers,
        local_rank=local_rank,
        elastic_coordinator=coordinator,
        input_args={},
    )
    trainer.train(task)
def test_model_complexity(self) -> None:
    """
    Test that the number of parameters and the FLOPs are calculated correctly.
    """
    model_configs = get_test_model_configs()
    expected_mega_flops = [4122, 4274, 106152]
    expected_params = [25557032, 25028904, 43009448]
    local_variables = {}

    task = get_test_classy_task()
    task.prepare()

    # create a model complexity hook
    model_complexity_hook = ModelComplexityHook()

    for model_config, mega_flops, params in zip(
        model_configs, expected_mega_flops, expected_params
    ):
        model = build_model(model_config)
        task.base_model = model

        with self.assertLogs() as log_watcher:
            model_complexity_hook.on_start(task, local_variables)

        # there should be 2 log statements generated
        self.assertEqual(len(log_watcher.output), 2)

        # first statement - either the MFLOPs or a warning
        if mega_flops is not None:
            match = re.search(
                r"FLOPs for forward pass: (?P<mega_flops>[-+]?\d*\.\d+|\d+) MFLOPs",
                log_watcher.output[0],
            )
            self.assertIsNotNone(match)
            self.assertEqual(mega_flops, float(match.group("mega_flops")))
        else:
            self.assertIn(
                "Model contains unsupported modules", log_watcher.output[0]
            )

        # second statement
        match = re.search(
            r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)",
            log_watcher.output[1],
        )
        self.assertIsNotNone(match)
        self.assertEqual(params, float(match.group("params")))
def configure_hooks(args, config):
    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = f"{Path(__file__).parent}/output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder + "/checkpoints"
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info(f"Logging outputs to {base_folder}")
    logging.info(f"Logging checkpoints to {args.checkpoint_folder}")

    if not args.skip_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(Path(base_folder) / "tensorboard", exist_ok=True)
            tb_writer = SummaryWriter(log_dir=Path(base_folder) / "tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            logging.warning("tensorboard not installed, skipping tensorboard hooks")

    args_dict = vars(args)
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(
            args.checkpoint_folder, args_dict, checkpoint_period=args.checkpoint_period
        )
    )

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
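# Minimal usage sketch (an assumption for illustration, not the repo's actual CLI
# wiring): configure_hooks only reads the attributes below off `args`, so a bare
# argparse.Namespace carrying those fields is enough to exercise it. The default
# values are placeholders; `config` and `task` are assumed to come from the usual
# build_task(config) setup shown in main() above.
import argparse

example_args = argparse.Namespace(
    log_freq=100,
    checkpoint_folder="",
    checkpoint_period=1,
    skip_tensorboard=False,
    profiler=False,
    show_progress=False,
    visdom_server="",
    visdom_port=8097,
)
hooks = configure_hooks(example_args, config)
task.set_hooks(hooks)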