def add_loss_hooks(hooks, loss_cfg, cfg):
    """
    Add the hooks that the configured loss depends on to ``hooks``.

    The loss is identified by ``cfg.LOSS.name``. Loss names without a
    dedicated branch below leave ``hooks`` untouched.

    Args:
        hooks: hook container to grow; it is modified in place.
        loss_cfg: accepted as part of the signature but not read here.
        cfg: configuration exposing ``LOSS`` and, for the moco loss,
            ``MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN``.

    Returns:
        The same ``hooks`` object that was passed in.
    """
    loss_name = cfg.LOSS.name

    if loss_name == "swav_loss":
        hooks.extend([SwAVUpdateQueueScoresHook(), NormalizePrototypesHook()])
    elif loss_name == "swav_momentum_loss":
        momentum_cfg = cfg.LOSS["swav_momentum_loss"]
        momentum_hook = SwAVMomentumHook(
            momentum_cfg["momentum"],
            momentum_cfg["momentum_eval_mode_iter_start"],
            momentum_cfg["crops_for_assign"],
        )
        hooks.extend([momentum_hook, SwAVMomentumNormalizePrototypesHook()])
    elif loss_name == "dino_loss":
        hooks.append(DINOHook())
    elif loss_name == "deepclusterv2_loss":
        hooks.extend([InitMemoryHook(), ClusterMemoryHook()])
    elif loss_name == "moco_loss":
        moco_momentum = cfg.LOSS["moco_loss"]["momentum"]
        # Batch shuffling is turned on exactly when BN is NOT converted to
        # SyncBN.
        converts_to_sync_bn = cfg.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN
        hooks.extend(
            [MoCoHook(moco_momentum, shuffle_batch=(not converts_to_sync_bn))]
        )

    return hooks
def swav(head: str = "swav_head", **kwargs):
    """
    Assemble the loss, head and hooks for SwAV.

    Args:
        head (str): key used to look up the head factory in
            ``IMAGE_EMBEDDER_HEADS``.
        **kwargs: forwarded unchanged to both the loss factory and the head
            factory.

    Returns:
        tuple: ``(loss_fn, head, hooks)`` where ``hooks`` is a list holding a
        ``SwAVUpdateQueueScoresHook``, a ``NormalizePrototypesHook`` and a
        ``TrainingSetupHook`` instance, in that order.
    """
    # The loss is always the entry stored under the "swav_loss" key.
    loss_factory = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get("swav_loss")
    loss_fn = loss_factory(**kwargs)

    head_factory = IMAGE_EMBEDDER_HEADS.get(head)
    built_head = head_factory(**kwargs)

    swav_hooks = [
        SwAVUpdateQueueScoresHook(),
        NormalizePrototypesHook(),
        TrainingSetupHook(),
    ]
    return loss_fn, built_head, swav_hooks
def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
    """
    Build the list of hooks used during training from the user configuration.

    Optional hooks, added only when the matching option is set:
        - perf stats logging (HOOKS.PERF_STATS.MONITOR_PERF_STATS)
        - loss specific hooks (swav, swav momentum, deepclusterv2, moco),
          selected by LOSS.name
        - model complexity hook (HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY)
        - GPU stats logging (HOOKS.LOG_GPU_STATS)
        - GPU memory logging (HOOKS.MEMORY_SUMMARY.PRINT_MEMORY_SUMMARY)
        - Tensorboard hook (HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD)
        - gradient clipping hook (MODEL.GRAD_CLIP.USE_GRAD_CLIP)
        - profiling hook (when ProfilingHook.is_enabled(cfg.PROFILING))

    A fixed set of bookkeeping hooks is always appended after the optional
    ones.

    Args:
        cfg (AttrDict): the configuration the hooks are derived from.

    Returns:
        hooks (List[ClassyHook]): the hooks, in the order they were added.

    Raises:
        AssertionError: if Tensorboard is requested but not available.
    """
    hooks = []
    perf_cfg = cfg.HOOKS.PERF_STATS

    # ---- hooks that depend on the use-case ----
    if perf_cfg.MONITOR_PERF_STATS:
        # A non-positive frequency is passed on as None.
        if perf_cfg.PERF_STAT_FREQUENCY > 0:
            perf_stat_freq = perf_cfg.PERF_STAT_FREQUENCY
        else:
            perf_stat_freq = None
        hooks.append(LogPerfTimeMetricsHook(perf_stat_freq))

    loss_name = cfg.LOSS.name
    if loss_name == "swav_loss":
        hooks.append(SwAVUpdateQueueScoresHook())
        hooks.append(NormalizePrototypesHook())
    elif loss_name == "swav_momentum_loss":
        momentum_cfg = cfg.LOSS["swav_momentum_loss"]
        hooks.append(
            SwAVMomentumHook(
                momentum_cfg["momentum"],
                momentum_cfg["momentum_eval_mode_iter_start"],
                momentum_cfg["crops_for_assign"],
            )
        )
        hooks.append(SwAVMomentumNormalizePrototypesHook())
    elif loss_name == "deepclusterv2_loss":
        hooks.append(InitMemoryHook())
        hooks.append(ClusterMemoryHook())
    elif loss_name == "moco_loss":
        moco_momentum = cfg.LOSS["moco_loss"]["momentum"]
        # Batch shuffling is turned on exactly when BN is NOT converted to
        # SyncBN.
        converts_to_sync_bn = cfg.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN
        hooks.append(
            MoCoHook(moco_momentum, shuffle_batch=(not converts_to_sync_bn))
        )

    if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
        hooks.append(SSLModelComplexityHook())

    if cfg.HOOKS.LOG_GPU_STATS:
        hooks.append(LogGpuStatsHook())

    if cfg.HOOKS.MEMORY_SUMMARY.PRINT_MEMORY_SUMMARY:
        hooks.append(LogGpuMemoryHook(cfg.HOOKS.MEMORY_SUMMARY.LOG_ITERATION_NUM))

    if cfg.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD:
        tensorboard_missing_msg = (
            "Tensorboard must be installed to use it. Please install tensorboard using:"
            "If pip environment: `pip install tensorboard` "
            "If using conda and you prefer conda install of tensorboard: "
            "`conda install -c conda-forge tensorboard`"
        )
        assert is_tensorboard_available(), tensorboard_missing_msg
        hooks.append(get_tensorboard_hook(cfg))

    if cfg.MODEL.GRAD_CLIP.USE_GRAD_CLIP:
        hooks.append(
            GradClipHook(
                norm_type=cfg.MODEL.GRAD_CLIP.NORM_TYPE,
                max_norm=cfg.MODEL.GRAD_CLIP.MAX_NORM,
            )
        )

    # ---- hooks that are used irrespective of workflow type ----
    # A non-positive rolling frequency is passed on as None.
    if perf_cfg.ROLLING_BTIME_FREQ > 0:
        rolling_btime_freq = perf_cfg.ROLLING_BTIME_FREQ
    else:
        rolling_btime_freq = None

    if ProfilingHook.is_enabled(cfg.PROFILING):
        hooks.append(ProfilingHook(profiling_config=cfg.PROFILING))

    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    checkpoint_folder = get_checkpoint_folder(cfg)

    hooks += [
        CheckNanLossHook(),
        SetDataSamplerEpochHook(),
        FreezeParametersHook(),
        UpdateBatchesSeenHook(),
        UpdateTrainBatchTimeHook(),
        UpdateTestBatchTimeHook(),
        UpdateTrainIterationNumHook(),
        LogLossMetricsCheckpointHook(world_size),
        LogLossLrEtaHook(checkpoint_folder, rolling_btime_freq),
    ]
    return hooks
def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
    """
    Build the list of hooks used during training from the user configuration.

    Optional hooks, added only when the matching option is set:
        - perf stats logging (MONITOR_PERF_STATS)
        - loss specific hooks (swav, swav momentum, deepclusterv2, moco),
          selected by LOSS.name
        - model complexity hook (MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY)
        - Tensorboard hook (TENSORBOARD_SETUP.USE_TENSORBOARD)

    A fixed set of bookkeeping hooks, ending with the GPU stats hook, is
    always appended after the optional ones.

    Args:
        cfg (AttrDict): the configuration the hooks are derived from.

    Returns:
        hooks (List[ClassyHook]): the hooks, in the order they were added.

    Raises:
        AssertionError: if Tensorboard is requested but not available.
    """
    hooks = []

    # ---- hooks that depend on the use-case ----
    if cfg.MONITOR_PERF_STATS:
        # A non-positive frequency is passed on as None.
        if cfg.PERF_STAT_FREQUENCY > 0:
            perf_stat_freq = cfg.PERF_STAT_FREQUENCY
        else:
            perf_stat_freq = None
        hooks.append(LogPerfTimeMetricsHook(perf_stat_freq))

    loss_name = cfg.LOSS.name
    if loss_name == "swav_loss":
        hooks.append(SwAVUpdateQueueScoresHook())
        hooks.append(NormalizePrototypesHook())
    elif loss_name == "swav_momentum_loss":
        momentum_cfg = cfg.LOSS["swav_momentum_loss"]
        hooks.append(
            SwAVMomentumHook(
                momentum_cfg["momentum"],
                momentum_cfg["momentum_eval_mode_iter_start"],
                momentum_cfg["crops_for_assign"],
            )
        )
        hooks.append(SwAVMomentumNormalizePrototypesHook())
    elif loss_name == "deepclusterv2_loss":
        hooks.append(InitMemoryHook())
        hooks.append(ClusterMemoryHook())
    elif loss_name == "moco_loss":
        moco_momentum = cfg.LOSS["moco_loss"]["momentum"]
        # Batch shuffling is turned on exactly when BN is NOT converted to
        # SyncBN.
        converts_to_sync_bn = cfg.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN
        hooks.append(
            MoCoHook(moco_momentum, shuffle_batch=(not converts_to_sync_bn))
        )

    if cfg.MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
        hooks.append(SSLModelComplexityHook())

    if cfg.TENSORBOARD_SETUP.USE_TENSORBOARD:
        assert is_tensorboard_available(), "Tensorboard must be installed to use it."
        hooks.append(get_tensorboard_hook(cfg))

    # ---- hooks that are used irrespective of workflow type ----
    # A non-positive rolling frequency is passed on as None.
    if cfg.ROLLING_BTIME_FREQ > 0:
        rolling_btime_freq = cfg.ROLLING_BTIME_FREQ
    else:
        rolling_btime_freq = None

    hooks += [
        CheckNanLossHook(),
        SetDataSamplerEpochHook(),
        FreezeParametersHook(),
        UpdateBatchesSeenHook(),
        UpdateTrainBatchTimeHook(),
        UpdateTestBatchTimeHook(),
        UpdateTrainIterationNumHook(),
        LogLossMetricsCheckpointHook(),
        LogLossLrEtaHook(rolling_btime_freq),
        LogGpuStatsHook(),
    ]
    return hooks