Example #1
0
def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
    """
    Prepare all the hooks that will be used in training based on the user
    configuration. Some basic hooks are always added.

    Optional hooks (added only when enabled in ``cfg``):
        - perf stats hook: HOOKS.PERF_STATS.MONITOR_PERF_STATS = True
        - loss specific hooks (swav loss, swav momentum loss, deepcluster loss,
          moco loss), used only when LOSS.name selects that loss
        - model complexity hook (if user wants to compute model flops,
          activations, params): HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY = True
        - GPU stats hook: HOOKS.LOG_GPU_STATS = True
        - GPU memory summary hook: HOOKS.MEMORY_SUMMARY.PRINT_MEMORY_SUMMARY = True
        - Tensorboard hook: HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD = True
        - gradient clipping hook: MODEL.GRAD_CLIP.USE_GRAD_CLIP = True
        - profiling hook: when ProfilingHook.is_enabled(cfg.PROFILING) is true

    The optional hooks are followed by a fixed set of hooks that are always
    added: CheckNanLossHook, SetDataSamplerEpochHook, FreezeParametersHook,
    UpdateBatchesSeenHook, UpdateTrainBatchTimeHook, UpdateTestBatchTimeHook,
    UpdateTrainIterationNumHook, LogLossMetricsCheckpointHook and
    LogLossLrEtaHook.

    Args:
        cfg (AttrDict): the full configuration; the HOOKS, LOSS, MODEL,
            PROFILING and DISTRIBUTED sections are read.

    Returns:
        hooks (List[ClassyHook]): list containing the hooks that will be used,
            optional hooks first, then the always-on hooks.

    Raises:
        AssertionError: if Tensorboard is requested but not installed (only
            when Python assertions are enabled).
    """
    hooks = []
    perf_stats_cfg = cfg.HOOKS.PERF_STATS

    # conditionally add hooks based on use-case
    if perf_stats_cfg.MONITOR_PERF_STATS:
        # A non-positive frequency is passed to the hook as None.
        perf_stat_freq = (
            perf_stats_cfg.PERF_STAT_FREQUENCY
            if perf_stats_cfg.PERF_STAT_FREQUENCY > 0
            else None
        )
        hooks.append(LogPerfTimeMetricsHook(perf_stat_freq))

    # At most one loss name can match, so the checks are mutually exclusive.
    loss_name = cfg.LOSS.name
    if loss_name == "swav_loss":
        hooks.extend([SwAVUpdateQueueScoresHook(), NormalizePrototypesHook()])
    elif loss_name == "swav_momentum_loss":
        swav_momentum_cfg = cfg.LOSS["swav_momentum_loss"]
        hooks.extend(
            [
                SwAVMomentumHook(
                    swav_momentum_cfg["momentum"],
                    swav_momentum_cfg["momentum_eval_mode_iter_start"],
                    swav_momentum_cfg["crops_for_assign"],
                ),
                SwAVMomentumNormalizePrototypesHook(),
            ]
        )
    elif loss_name == "deepclusterv2_loss":
        hooks.extend([InitMemoryHook(), ClusterMemoryHook()])
    elif loss_name == "moco_loss":
        hooks.append(
            MoCoHook(
                cfg.LOSS["moco_loss"]["momentum"],
                # Batch shuffling is requested only when BN is NOT converted
                # to SyncBN.
                shuffle_batch=(not cfg.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN),
            )
        )

    if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
        hooks.append(SSLModelComplexityHook())
    if cfg.HOOKS.LOG_GPU_STATS:
        hooks.append(LogGpuStatsHook())
    if cfg.HOOKS.MEMORY_SUMMARY.PRINT_MEMORY_SUMMARY:
        hooks.append(LogGpuMemoryHook(cfg.HOOKS.MEMORY_SUMMARY.LOG_ITERATION_NUM))
    if cfg.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD:
        # Check availability up front so the user gets an install hint.
        assert is_tensorboard_available(), (
            "Tensorboard must be installed to use it. "
            "Please install tensorboard using: "
            "If pip environment: `pip install tensorboard`. "
            "If using conda and you prefer conda install of tensorboard: "
            "`conda install -c conda-forge tensorboard`"
        )
        hooks.append(get_tensorboard_hook(cfg))
    if cfg.MODEL.GRAD_CLIP.USE_GRAD_CLIP:
        hooks.append(
            GradClipHook(
                norm_type=cfg.MODEL.GRAD_CLIP.NORM_TYPE,
                max_norm=cfg.MODEL.GRAD_CLIP.MAX_NORM,
            )
        )

    # A non-positive rolling batch-time frequency is passed on as None.
    rolling_btime_freq = (
        perf_stats_cfg.ROLLING_BTIME_FREQ
        if perf_stats_cfg.ROLLING_BTIME_FREQ > 0
        else None
    )

    if ProfilingHook.is_enabled(cfg.PROFILING):
        hooks.append(ProfilingHook(profiling_config=cfg.PROFILING))

    # hooks that are used irrespective of workflow type
    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    checkpoint_folder = get_checkpoint_folder(cfg)
    hooks.extend(
        [
            CheckNanLossHook(),
            SetDataSamplerEpochHook(),
            FreezeParametersHook(),
            UpdateBatchesSeenHook(),
            UpdateTrainBatchTimeHook(),
            UpdateTestBatchTimeHook(),
            UpdateTrainIterationNumHook(),
            LogLossMetricsCheckpointHook(world_size),
            LogLossLrEtaHook(checkpoint_folder, rolling_btime_freq),
        ]
    )
    return hooks
Example #2
0
def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
    """
    Prepare all the hooks that will be used in training based on the user
    configuration. Some basic hooks are always added.

    Optional hooks (added only when enabled in ``cfg``):
        - perf stats hook: enable via MONITOR_PERF_STATS = True
        - loss specific hooks (swav loss, swav momentum loss, deepcluster loss,
          moco loss), used only when LOSS.name selects that loss
        - model complexity hook (if user wants to compute model flops,
          activations, params): enable via
          MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY = True
        - Tensorboard hook: enable via TENSORBOARD_SETUP.USE_TENSORBOARD = True

    The optional hooks are followed by a fixed set of hooks that are always
    added: CheckNanLossHook, SetDataSamplerEpochHook, FreezeParametersHook,
    UpdateBatchesSeenHook, UpdateTrainBatchTimeHook, UpdateTestBatchTimeHook,
    UpdateTrainIterationNumHook, LogLossMetricsCheckpointHook,
    LogLossLrEtaHook and LogGpuStatsHook.

    Args:
        cfg (AttrDict): the full configuration.

    Returns:
        hooks (List[ClassyHook]): list containing the hooks that will be used,
            optional hooks first, then the always-on hooks.

    Raises:
        AssertionError: if Tensorboard is requested but not installed (only
            when Python assertions are enabled).
    """
    hooks = []

    # conditionally add hooks based on use-case
    if cfg.MONITOR_PERF_STATS:
        # A non-positive frequency is passed to the hook as None.
        perf_stat_freq = (
            cfg.PERF_STAT_FREQUENCY if cfg.PERF_STAT_FREQUENCY > 0 else None
        )
        hooks.append(LogPerfTimeMetricsHook(perf_stat_freq))

    # At most one loss name can match, so the checks are mutually exclusive.
    loss_name = cfg.LOSS.name
    if loss_name == "swav_loss":
        hooks.extend([SwAVUpdateQueueScoresHook(), NormalizePrototypesHook()])
    elif loss_name == "swav_momentum_loss":
        swav_momentum_cfg = cfg.LOSS["swav_momentum_loss"]
        hooks.extend(
            [
                SwAVMomentumHook(
                    swav_momentum_cfg["momentum"],
                    swav_momentum_cfg["momentum_eval_mode_iter_start"],
                    swav_momentum_cfg["crops_for_assign"],
                ),
                SwAVMomentumNormalizePrototypesHook(),
            ]
        )
    elif loss_name == "deepclusterv2_loss":
        hooks.extend([InitMemoryHook(), ClusterMemoryHook()])
    elif loss_name == "moco_loss":
        hooks.append(
            MoCoHook(
                cfg.LOSS["moco_loss"]["momentum"],
                # Batch shuffling is requested only when BN is NOT converted
                # to SyncBN.
                shuffle_batch=(not cfg.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN),
            )
        )

    if cfg.MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
        hooks.append(SSLModelComplexityHook())
    if cfg.TENSORBOARD_SETUP.USE_TENSORBOARD:
        # Check availability up front so the user gets a clear message.
        assert is_tensorboard_available(), "Tensorboard must be installed to use it."
        hooks.append(get_tensorboard_hook(cfg))

    # hooks that are used irrespective of workflow type
    # A non-positive rolling batch-time frequency is passed on as None.
    rolling_btime_freq = cfg.ROLLING_BTIME_FREQ if cfg.ROLLING_BTIME_FREQ > 0 else None
    hooks.extend(
        [
            CheckNanLossHook(),
            SetDataSamplerEpochHook(),
            FreezeParametersHook(),
            UpdateBatchesSeenHook(),
            UpdateTrainBatchTimeHook(),
            UpdateTestBatchTimeHook(),
            UpdateTrainIterationNumHook(),
            LogLossMetricsCheckpointHook(),
            LogLossLrEtaHook(rolling_btime_freq),
            LogGpuStatsHook(),
        ]
    )
    return hooks