    def test_model_complexity_hook(self) -> None:
        model_configs = get_test_model_configs()

        task = get_test_classy_task()
        task.prepare()

        # create a model complexity hook
        model_complexity_hook = ModelComplexityHook()

        for model_config in model_configs:
            model = build_model(model_config)

            task.base_model = model

            with self.assertLogs():
                model_complexity_hook.on_start(task)
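
# For reference, the parameter count that the hook logs can be reproduced with
# plain PyTorch. A minimal sketch (the helper name below is ours, not the
# hook's internals):

import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    # Sum the element counts of every parameter tensor in the model.
    return sum(p.numel() for p in model.parameters())

# e.g. nn.Linear(10, 2) holds 10 * 2 weights + 2 biases = 22 parameters
assert count_parameters(nn.Linear(10, 2)) == 22
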
def main(local_rank, c10d_backend, rdzv_init_url, max_world_size, classy_args):
    torch.manual_seed(0)
    set_video_backend(classy_args.video_backend)

    # Loads config, sets up task
    config = load_json(classy_args.config_file)

    task = build_task(config)

    # Load checkpoint, if available
    checkpoint = load_checkpoint(classy_args.checkpoint_folder)
    task.set_checkpoint(checkpoint)

    pretrained_checkpoint = load_checkpoint(classy_args.pretrained_checkpoint_folder)
    if pretrained_checkpoint is not None:
        assert isinstance(
            task, FineTuningTask
        ), "Can only use a pretrained checkpoint for fine tuning tasks"
        task.set_pretrained_checkpoint(pretrained_checkpoint)

    hooks = [
        LossLrMeterLoggingHook(classy_args.log_freq),
        ModelComplexityHook(),
        TimeMetricsHook(),
    ]

    if classy_args.checkpoint_folder != "":
        args_dict = vars(classy_args)
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                classy_args.checkpoint_folder,
                args_dict,
                checkpoint_period=classy_args.checkpoint_period,
            )
        )
    if classy_args.profiler:
        hooks.append(ProfilerHook())

    task.set_hooks(hooks)

    assert c10d_backend in (Backend.NCCL, Backend.GLOO)
    if c10d_backend == Backend.NCCL:
        # needed to enable NCCL error handling
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

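    # Elastic setup: the coordinator handles rendezvous between workers via
    # rdzv_init_url and allows membership changes up to max_world_size; the
    # ElasticTrainer below drives the actual training loop.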
    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )
    trainer = ElasticTrainer(
        use_gpu=classy_args.device == "gpu",
        num_dataloader_workers=classy_args.num_workers,
        local_rank=local_rank,
        elastic_coordinator=coordinator,
        input_args={},
    )
    trainer.train(task)
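
# Conceptually, the trainer fires each hook's callbacks around the training
# loop. A toy sketch of that dispatch pattern (not the real trainer code):

class _ToyHook:
    def on_start(self, task):
        print("on_start fired for", task)

def _run(task, hooks):
    # Invoke on_start for every registered hook before training begins.
    for hook in hooks:
        hook.on_start(task)

_run(task="toy_task", hooks=[_ToyHook()])
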
    def test_model_complexity(self) -> None:
        """
        Test that the number of parameters and the FLOPs are calculated correctly.
        """
        model_configs = get_test_model_configs()
        expected_mega_flops = [4122, 4274, 106152]
        expected_params = [25557032, 25028904, 43009448]
        local_variables = {}

        task = get_test_classy_task()
        task.prepare()

        # create a model complexity hook
        model_complexity_hook = ModelComplexityHook()

        for model_config, mega_flops, params in zip(model_configs,
                                                    expected_mega_flops,
                                                    expected_params):
            model = build_model(model_config)

            task.base_model = model

            with self.assertLogs() as log_watcher:
                model_complexity_hook.on_start(task, local_variables)

            # there should be 2 log statements generated
            self.assertEqual(len(log_watcher.output), 2)

            # first statement - either the MFLOPs or a warning
            if mega_flops is not None:
                match = re.search(
                    r"FLOPs for forward pass: (?P<mega_flops>[-+]?\d*\.\d+|\d+) MFLOPs",
                    log_watcher.output[0],
                )
                self.assertIsNotNone(match)
                self.assertEqual(mega_flops, float(match.group("mega_flops")))
            else:
                self.assertIn("Model contains unsupported modules",
                              log_watcher.output[0])

            # second statement
            match = re.search(
                r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)",
                log_watcher.output[1],
            )
            self.assertIsNotNone(match)
            self.assertEqual(params, float(match.group("params")))
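
# The log-parsing regexes above can be exercised standalone; a quick sketch
# with a made-up log line matching the format asserted in the test:

import re

sample = "Number of parameters in model: 25557032"
m = re.search(r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)", sample)
assert m is not None and float(m.group("params")) == 25557032
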
def configure_hooks(args, config):
    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = f"{Path(__file__).parent}/output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder + "/checkpoints"
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info(f"Logging outputs to {base_folder}")
    logging.info(f"Logging checkpoints to {args.checkpoint_folder}")

    if not args.skip_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(Path(base_folder) / "tensorboard", exist_ok=True)
            tb_writer = SummaryWriter(log_dir=Path(base_folder) /
                                      "tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            logging.warning(
                "tensorboard not installed, skipping tensorboard hooks")

    args_dict = vars(args)
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(args.checkpoint_folder,
                       args_dict,
                       checkpoint_period=args.checkpoint_period))

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
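
# Typical usage, sketched as comments (parse_args is hypothetical; the
# attribute names mirror the ones configure_hooks reads above):
#
#     args = parse_args()                 # hypothetical CLI parser
#     config = load_json(args.config_file)
#     task = build_task(config)
#     task.set_hooks(configure_hooks(args, config))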