Example #1
def get_checkpoint_model_state_dict(config: AttrDict, state_dict: Dict[str, Any]):
    """
    Given a specified pre-trained VISSL model (composed of a head and a trunk),
    build the state_dict that VISSL can load by appending the expected prefixes
    to the trunk and head layer names.

    Args:
        config (AttrDict): full config file
        state_dict (Dict): raw state_dict loaded from the checkpoint or weights file

    Returns:
        state_dict (Dict): vissl state_dict with layer names compatible with the
                           vissl model, so it can be loaded directly.
    """
    from vissl.models import is_feature_extractor_model

    classy_state_dict = state_dict["base_model"]["model"]
    trunk_append_prefix, heads_append_prefix = "trunk.", "heads."
    if is_feature_extractor_model(config.MODEL):
        trunk_append_prefix = "trunk.base_model."

    trunk_state_dict = append_module_prefix(
        classy_state_dict["trunk"], trunk_append_prefix
    )
    heads_state_dict = append_module_prefix(
        classy_state_dict["heads"], heads_append_prefix
    )
    state_dict = {}
    state_dict.update(trunk_state_dict)
    state_dict.update(heads_state_dict)
    return state_dict
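
The example above relies on an append_module_prefix helper that is not shown. Below is a minimal sketch of what such a helper could look like, assuming it simply prepends the given prefix to every key of the state dict (the actual VISSL implementation may differ):

from typing import Any, Dict


def append_module_prefix(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    # Prepend `prefix` to every parameter name so the keys line up with the
    # module hierarchy of the VISSL model ("trunk." / "heads.").
    return {f"{prefix}{layername}": param for layername, param in state_dict.items()}
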
Example #2
def check_model_compatibilty(config: AttrDict, state_dict: Dict[str, Any]):
    """
    Given a VISSL model and a state_dict, check whether the state_dict can be loaded
    into the VISSL model (trunk + head) based on the expected trunk and head prefixes.
    If it is not compatible, we raise an exception.

    Prefix checked for head: `heads.`
    Prefix checked for trunk: `trunk._feature_blocks.` or `trunk.base_model._feature_blocks.`
                              depending on the workflow type (training | evaluation).

    Args:
        config (AttrDict): root config
        state_dict (Dict[str, Any]): state dict that should be checked for compatibility
    """
    from vissl.models import is_feature_extractor_model

    trunk_append_prefix, heads_append_prefix = "trunk._feature_blocks.", "heads."
    if is_feature_extractor_model(config.MODEL):
        trunk_append_prefix = "trunk.base_model._feature_blocks."

    is_compatible = True
    for layername in state_dict.keys():
        if not (
            layername.startswith(trunk_append_prefix)
            or layername.startswith(heads_append_prefix)
        ):
            is_compatible = False
            break
    if not is_compatible:
        raise Exception(
            "Model provided in config.MODEL.WEIGHTS_INIT.PARAMS_FILE is not compatible "
            "with VISSL. Please set config.MODEL.WEIGHTS_INIT.APPEND_PREFIX and "
            "config.MODEL.WEIGHTS_INIT.REMOVE_PREFIX for making model compatible. "
            f"Expected trunk prefix: {trunk_append_prefix}"
        )
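
As a hypothetical illustration (the key names below are invented, and config stands in for a real VISSL config), a checkpoint that still carries a DDP-style "module." prefix would fail the check above:

# Hypothetical illustration of the prefix check; the None values stand in for tensors.
compatible = {
    "trunk._feature_blocks.conv1.weight": None,
    "heads.0.clf.0.weight": None,
}
incompatible = {
    "module.trunk._feature_blocks.conv1.weight": None,  # unexpected "module." prefix
}
# check_model_compatibilty(config, compatible)    # passes silently
# check_model_compatibilty(config, incompatible)  # raises Exception
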
Example #3
def get_checkpoint_model_state_dict(config: AttrDict, state_dict: Dict[str, Any]):
    """
    Given a specified pre-trained VISSL model (composed of a head and a trunk),
    build the state_dict that VISSL can load by appending the expected prefixes
    to the trunk and head layer names.

    Args:
        config (AttrDict): full config file
        state_dict (Dict): raw state_dict loaded from the checkpoint or weights file

    Returns:
        state_dict (Dict): vissl state_dict with layer names compatible with the
                           vissl model, so it can be loaded directly.
    """
    from vissl.models import is_feature_extractor_model

    classy_state_dict = state_dict["base_model"]["model"]
    trunk_append_prefix, heads_append_prefix = "trunk.", "heads."
    # if the model is being loaded for feature extraction, check whether the
    # checkpoint trunk is already a feature extractor trunk. If it is not,
    # append the feature extractor prefix.
    if is_feature_extractor_model(config.MODEL) and not is_feature_extractor_state_dict(
        classy_state_dict["trunk"]
    ):
        trunk_append_prefix = "trunk.base_model."
    elif not is_feature_extractor_model(config.MODEL):
        # get rid of the feature extractor prefix if we do not need it
        classy_state_dict["trunk"] = replace_module_prefix(
            classy_state_dict["trunk"], "base_model.", ""
        )

    trunk_state_dict = append_module_prefix(
        classy_state_dict["trunk"], trunk_append_prefix
    )
    heads_state_dict = append_module_prefix(
        classy_state_dict["heads"], heads_append_prefix
    )
    state_dict = {}
    state_dict.update(trunk_state_dict)
    state_dict.update(heads_state_dict)
    return state_dict
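
Example #3 additionally calls replace_module_prefix and is_feature_extractor_state_dict, which are not shown here. A rough sketch under the assumption that a feature-extractor trunk is recognized purely by a "base_model." prefix on its keys (the real VISSL helpers may differ):

from typing import Any, Dict


def replace_module_prefix(
    state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
) -> Dict[str, Any]:
    # Swap `prefix` for `replace_with` at the start of each matching key; keys
    # without the prefix are kept unchanged.
    return {
        (key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val
        for key, val in state_dict.items()
    }


def is_feature_extractor_state_dict(trunk_state_dict: Dict[str, Any]) -> bool:
    # Assume a trunk saved in feature-extraction mode nests all of its keys
    # under "base_model.".
    return all(key.startswith("base_model.") for key in trunk_state_dict)
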
Example #4
def adapt_to_feature_extractor_config(
        config: AttrDict, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Adapt a state dictionary to be compatible with a feature extractor configuration
    by replacing the "trunk." prefix with "trunk.base_model.".
    """
    from vissl.models import is_feature_extractor_model

    if not is_feature_extractor_model(config.MODEL):
        return state_dict

    return {
        k.replace("trunk.", "trunk.base_model."): v
        for k, v in state_dict.items()
    }
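
A small hypothetical example of the renaming performed above, assuming config describes a feature extractor (the string values stand in for tensors):

state_dict = {
    "trunk._feature_blocks.conv1.weight": "...",
    "heads.0.clf.0.weight": "...",
}
# adapt_to_feature_extractor_config(config, state_dict) would return:
# {
#     "trunk.base_model._feature_blocks.conv1.weight": "...",
#     "heads.0.clf.0.weight": "...",
# }
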
Example #5
def assert_hydra_conf(cfg):
    """
    Infer the values of a few parameters in the config file from the values of other
    config parameters:
    1. Infer the losses config.
    2. Auto-scale the learning rate if the user has enabled auto scaling.
    3. Infer meter names (the model layer name being evaluated) since we support list
       meters that have multiple outputs and the same target. This is very common in
       self-supervised learning where we want to evaluate a metric for several layers
       of the model. VISSL supports running evaluation for multiple model layers in a
       single training run.
    4. Support multi-gpu DDP eval models by attaching a dummy parameter. This is
       particularly helpful for multi-gpu feature extraction, especially when the
       dataset being extracted is large.
    5. Infer what kind of labels are being used. If the user has specified a label
       source, we set LABEL_TYPE to "standard" (also the vissl default); otherwise,
       if no label is specified, we set LABEL_TYPE to "sample_index".
    """
    cfg = infer_losses_config(cfg)
    cfg = infer_learning_rate(cfg)

    # in case of linear evaluation, we often evaluate several layers at a time. For each
    # layer, there's a separate accuracy meter. In such a case, we want to output the layer
    # name in the meters output to make it easy to interpret the results. This is
    # currently only supported for cases where we have linear evaluation.
    if cfg.METERS is not None:
        from vissl.models import is_feature_extractor_model

        meter_name = cfg.METERS.get("name", "")
        valid_meters = ["accuracy_list_meter", "mean_ap_list_meter"]
        if meter_name:
            if meter_name in valid_meters and is_feature_extractor_model(cfg.MODEL):
                cfg.METERS[meter_name]["num_meters"] = len(
                    cfg.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP
                )
                cfg.METERS[meter_name]["meter_names"] = [
                    item[0]
                    for item in cfg.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP
                ]

    # in case of feature evaluation mode, we freeze the trunk. The feature evaluation mode
    # is also used for extracting features from the trunk. VISSL supports distributed feature
    # extraction to speed up the extraction time. Since the model needs to be wrapped in DDP
    # for distributed extraction, it needs some trainable parameters, otherwise it can't be
    # converted to DDP. So we attach a dummy head to the model.
    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    if (cfg.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON
            and cfg.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY
            and cfg.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY
            and world_size > 1 and len(cfg.MODEL.HEAD.PARAMS) == 0):
        cfg.MODEL.HEAD.PARAMS = [["mlp", {"dims": [2048, 1000]}]]

    # in SSL pre-training we don't want to use annotated labels, and during feature
    # extraction we may not have annotated labels for some datasets. In such cases, we
    # set the label type to the image index in the dataset, unless the user has
    # specifically provided "zero" as the label type, which is necessary when the
    # CutMixUp collator is being used for self-supervised training.
    if len(cfg.DATA.TRAIN.LABEL_SOURCES) == 0 and cfg.DATA.TRAIN.LABEL_TYPE != "zero":
        cfg.DATA.TRAIN.LABEL_TYPE = "sample_index"
    if len(cfg.DATA.TEST.LABEL_SOURCES) == 0 and cfg.DATA.TEST.LABEL_TYPE != "zero":
        cfg.DATA.TEST.LABEL_TYPE = "sample_index"

    # if the user has specified the model initialization from a params_file, we check if
    # the params_file is a url. If it is, we download the file to a local cache directory
    # and use that instead
    from vissl.utils.checkpoint import get_checkpoint_folder
    from vissl.utils.io import cache_url, is_url

    if is_url(cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE):
        checkpoint_dir = get_checkpoint_folder(cfg)
        cache_dir = f"{checkpoint_dir}/params_file_cache/"
        cached_url_path = cache_url(url=cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE,
                                    cache_dir=cache_dir)
        cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE = cached_url_path

    # if we use a zero optimizer, we nest the optimizer related settings under the
    # base_optimizer.
    if cfg.OPTIMIZER.use_zero:
        cfg.OPTIMIZER["base_optimizer"] = cfg.OPTIMIZER.copy()
        cfg.OPTIMIZER.name = "zero"
        del cfg.OPTIMIZER.base_optimizer["param_schedulers"]
        del cfg.OPTIMIZER.base_optimizer["regularize_bn"]
        del cfg.OPTIMIZER.base_optimizer["regularize_bias"]
        del cfg.OPTIMIZER.base_optimizer["num_epochs"]
        del cfg.OPTIMIZER.base_optimizer["use_zero"]
        del cfg.OPTIMIZER.base_optimizer["head_optimizer_params"]
Example #6
def infer_and_assert_hydra_config(cfg):
    """
    Infer the values of a few parameters in the config file from the values of other
    config parameters:
    1. Infer the losses config.
    2. Auto-scale the learning rate if the user has enabled auto scaling.
    3. Infer meter names (the model layer name being evaluated) since we support list
       meters that have multiple outputs and the same target. This is very common in
       self-supervised learning where we want to evaluate a metric for several layers
       of the model. VISSL supports running evaluation for multiple model layers in a
       single training run.
    4. Support multi-gpu DDP eval models by attaching a dummy parameter. This is
       particularly helpful for multi-gpu feature extraction, especially when the
       dataset being extracted is large.
    5. Infer what kind of labels are being used. If the user has specified a label
       source, we set LABEL_TYPE to "standard" (also the vissl default); otherwise,
       if no label is specified, we set LABEL_TYPE to "sample_index".
    """
    cfg = infer_losses_config(cfg)
    cfg = infer_learning_rate(cfg)

    # pass the seed to cfg["MODEL"] so that model init on different nodes can
    # use the same seed.
    # TODO (Min): once FSDP supports sync'ing weights from rank 0, we don't need
    #             this anymore.
    cfg["MODEL"]["_MODEL_INIT_SEED"] = cfg.SEED_VALUE

    # in case of linear evaluation, we often evaluate several layers at a time. For each
    # layer, there's a separate accuracy meter. In such a case, we want to output the layer
    # name in the meters output to make it easy to interpret the results. This is
    # currently only supported for cases where we have linear evaluation.
    if cfg.METERS is not None:
        from vissl.models import is_feature_extractor_model

        meter_name = cfg.METERS.get("name", "")
        valid_meters = ["accuracy_list_meter", "mean_ap_list_meter"]
        if meter_name:
            if meter_name in valid_meters and is_feature_extractor_model(cfg.MODEL):
                cfg.METERS[meter_name]["num_meters"] = len(
                    cfg.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP
                )
                cfg.METERS[meter_name]["meter_names"] = [
                    item[0]
                    for item in cfg.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP
                ]

    # in case of feature evaluation mode, we freeze the trunk. The feature evaluation mode
    # is also used for extracting features from the trunk. VISSL supports distributed feature
    # extraction to speed up the extraction time. Since the model needs to be wrapped in DDP
    # for distributed extraction, it needs some trainable parameters, otherwise it can't be
    # converted to DDP. So we attach a dummy head to the model.
    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    if (
        cfg.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON
        and cfg.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY
        and cfg.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY
        and world_size > 1
        and len(cfg.MODEL.HEAD.PARAMS) == 0
    ):
        cfg.MODEL.HEAD.PARAMS = [["mlp", {"dims": [2048, 1000]}]]

    # in SSL pre-training we don't want to use annotated labels, and during feature
    # extraction we may not have annotated labels for some datasets. In such cases, we
    # set the label type to the image index in the dataset, unless the user has
    # specifically provided "zero" as the label type, which is necessary when the
    # CutMixUp collator is being used for self-supervised training.
    if len(cfg.DATA.TRAIN.LABEL_SOURCES) == 0 and cfg.DATA.TRAIN.LABEL_TYPE != "zero":
        cfg.DATA.TRAIN.LABEL_TYPE = "sample_index"
    if len(cfg.DATA.TEST.LABEL_SOURCES) == 0 and cfg.DATA.TEST.LABEL_TYPE != "zero":
        cfg.DATA.TEST.LABEL_TYPE = "sample_index"

    # if the user has specified the model initialization from a params_file, we check if
    # the params_file is a url. If it is, we download the file to a local cache directory
    # and use that instead
    from vissl.utils.checkpoint import get_checkpoint_folder
    from vissl.utils.io import cache_url, is_url

    if is_url(cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE):
        checkpoint_dir = get_checkpoint_folder(cfg)
        cache_dir = f"{checkpoint_dir}/params_file_cache/"
        cached_url_path = cache_url(
            url=cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE, cache_dir=cache_dir
        )
        cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE = cached_url_path

    # ZeRO2: Infer the settings for ShardedDDP, which shards the optimizer state
    # and the model weights. For ShardedDDP, we must use the OSS optimizer,
    # set the right task name, and use PyTorch AMP if AMP is used.
    if cfg.MODEL.SHARDED_DDP_SETUP.USE_SDP:
        cfg.OPTIMIZER.use_zero = True
        cfg.TRAINER.TASK_NAME = "self_supervision_sdp_task"
        if cfg.MODEL.AMP_PARAMS.USE_AMP:
            cfg.MODEL.AMP_PARAMS.AMP_TYPE = "pytorch"

    # if we use a zero optimizer, we nest the optimizer related settings under the
    # base_optimizer.
    if cfg.OPTIMIZER.use_zero:
        cfg.OPTIMIZER["base_optimizer"] = cfg.OPTIMIZER.copy()
        cfg.OPTIMIZER.name = "zero"
        del cfg.OPTIMIZER.base_optimizer["param_schedulers"]
        del cfg.OPTIMIZER.base_optimizer["regularize_bn"]
        del cfg.OPTIMIZER.base_optimizer["regularize_bias"]
        del cfg.OPTIMIZER.base_optimizer["num_epochs"]
        del cfg.OPTIMIZER.base_optimizer["use_zero"]
        del cfg.OPTIMIZER.base_optimizer["head_optimizer_params"]

    # inference for the FSDP settings. Conditions are:
    # 1) use the FSDP task
    # 2) use the single param group in the optimizer
    # 3) if AMP is used, it must be PyTorch AMP
    # 4) if training SwAV, we automatically set the head to the SwAV FSDP head
    # 5) infer the FSDP parameters to ensure good convergence
    if cfg.MODEL.FSDP_CONFIG.AUTO_SETUP_FSDP:
        cfg.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
        cfg.OPTIMIZER.construct_single_param_group_only = True

        # safely set flatten_parameters=True for FSDP training.
        cfg["MODEL"]["FSDP_CONFIG"]["flatten_parameters"] = True
        # recommended FSDP settings below for convergence
        cfg["MODEL"]["FSDP_CONFIG"]["compute_dtype"] = "float32"

        # Inference of optimizer configuration
        if cfg["OPTIMIZER"]["use_larc"]:
            cfg["OPTIMIZER"]["name"] = "sgd_fsdp"

        # AMP based inference
        if cfg["MODEL"]["AMP_PARAMS"]["USE_AMP"]:
            cfg["MODEL"]["AMP_PARAMS"]["AMP_TYPE"] = "pytorch"
            cfg["MODEL"]["FSDP_CONFIG"]["mixed_precision"] = True
            cfg["MODEL"]["FSDP_CONFIG"]["fp32_reduce_scatter"] = True
        else:
            # if not using AMP, we can't use mixed_precision as it requires PyTorch AMP
            cfg["MODEL"]["FSDP_CONFIG"]["mixed_precision"] = False
            # if mixed_precision=False, FSDP mandates setting fp32_reduce_scatter=False
            cfg["MODEL"]["FSDP_CONFIG"]["fp32_reduce_scatter"] = False

        # Inference of the head in case of training with FSDP
        for i, head_param in enumerate(cfg.MODEL.HEAD.PARAMS):
            if head_param[0] == "swav_head":
                cfg.MODEL.HEAD.PARAMS[i][0] = "swav_head_fsdp"
            if head_param[0] == "eval_mlp":
                cfg.MODEL.HEAD.PARAMS[i][0] = "eval_mlp_fsdp"
            if head_param[0] == "mlp":
                cfg.MODEL.HEAD.PARAMS[i][0] = "mlp_fsdp"

        # Inference of the trunk in case of training with FSDP
        if cfg.MODEL.TRUNK.NAME == "regnet":
            cfg.MODEL.TRUNK.NAME = "regnet_fsdp"

        # Profiling the communication requires some setup for FSDP
        if cfg.PROFILING.MEMORY_PROFILING.TRACK_BY_LAYER_MEMORY:
            cfg["MODEL"]["FSDP_CONFIG"]["_TRACK_COMMUNICATIONS"] = True

        logging.info(f"Using the FSDP config: {cfg.MODEL.FSDP_CONFIG}")

    # Delete the AUTO_SETUP_FSDP key since we send the FSDP_CONFIG
    # to FSDP from fairscale which doesn't know about AUTO_SETUP_FSDP
    del cfg.MODEL.FSDP_CONFIG["AUTO_SETUP_FSDP"]
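
As a hypothetical illustration of the FSDP head renaming above, applied to a plain list instead of the AttrDict config (the head parameters below are invented):

head_params = [
    ["swav_head", {"dims": [2048, 2048, 128], "num_clusters": [3000]}],
    ["mlp", {"dims": [2048, 1000]}],
]
fsdp_variants = {
    "swav_head": "swav_head_fsdp",
    "eval_mlp": "eval_mlp_fsdp",
    "mlp": "mlp_fsdp",
}
for i, (head_name, _head_kwargs) in enumerate(head_params):
    if head_name in fsdp_variants:
        head_params[i][0] = fsdp_variants[head_name]
# head_params is now [["swav_head_fsdp", {...}], ["mlp_fsdp", {...}]]
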
Example #7
def infer_and_assert_hydra_config(cfg, engine_name: str):
    """
    Infer the values of a few parameters in the config file from the values of other
    config parameters:
    1. Infer the losses config.
    2. Auto-scale the learning rate if the user has enabled auto scaling.
    3. Infer meter names (the model layer name being evaluated) since we support list
       meters that have multiple outputs and the same target. This is very common in
       self-supervised learning where we want to evaluate a metric for several layers
       of the model. VISSL supports running evaluation for multiple model layers in a
       single training run.
    4. Support multi-gpu DDP eval models by attaching a dummy parameter. This is
       particularly helpful for multi-gpu feature extraction, especially when the
       dataset being extracted is large.
    5. Infer what kind of labels are being used. If the user has specified a label
       source, we set LABEL_TYPE to "standard" (also the vissl default); otherwise,
       if no label is specified, we set LABEL_TYPE to "sample_index".
    """
    cfg = infer_losses_config(cfg)
    cfg = infer_learning_rate(cfg)
    assert_transforms(cfg)

    # pass the seed to cfg["MODEL"] so that model init on different nodes can
    # use the same seed.
    # TODO (Min): once FSDP supports sync'ing weights from rank 0, we don't need
    #             this anymore.
    cfg["MODEL"]["_MODEL_INIT_SEED"] = cfg.SEED_VALUE

    # in case of linear evaluation, we often evaluate several layers at a time. For each
    # layer, there's a separate accuracy meter. In such a case, we want to output the layer
    # name in the meters output to make it easy to interpret the results. This is
    # currently only supported for cases where we have linear evaluation.
    if cfg.METERS is not None:
        from vissl.models import is_feature_extractor_model

        # Ensure backwards compatibility of cfg.METERS.name.
        meter_name = cfg.METERS.get("name", "")
        if meter_name:
            meter_names = set(cfg.METERS.get("names", []))
            meter_names.add(meter_name)
            cfg.METERS.names = list(meter_names)

        meter_names = cfg.METERS.get("names", [])
        valid_meters = [
            "accuracy_list_meter",
            "mean_ap_list_meter",
            "precision_at_k_list_meter",
            "recall_at_k_list_meter",
        ]

        for meter_name in meter_names:
            if meter_name in valid_meters:
                feat_eval_ops_map = (
                    cfg.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP
                )
                all_meter_names = [item[0] for item in feat_eval_ops_map]
                if is_feature_extractor_model(cfg.MODEL):
                    cfg.METERS[meter_name]["num_meters"] = len(feat_eval_ops_map)
                    cfg.METERS[meter_name]["meter_names"] = all_meter_names
                elif engine_name == "extract_label_predictions":
                    if len(feat_eval_ops_map) > 0:
                        cfg.METERS[meter_name]["num_meters"] = len(feat_eval_ops_map)
                        cfg.METERS[meter_name]["meter_names"] = all_meter_names
                    else:
                        # if the user is not extracting from multiple layers, we assume
                        # the model head is being used.
                        cfg.METERS[meter_name]["num_meters"] = 1

    # in SSL pre-training we don't want to use annotated labels, and during feature
    # extraction we may not have annotated labels for some datasets. In such cases, we
    # set the label type to the image index in the dataset, unless the user has
    # specifically provided "zero" as the label type, which is necessary when the
    # CutMixUp collator is being used for self-supervised training.
    if len(cfg.DATA.TRAIN.LABEL_SOURCES) == 0 and cfg.DATA.TRAIN.LABEL_TYPE != "zero":
        cfg.DATA.TRAIN.LABEL_TYPE = "sample_index"
    if len(cfg.DATA.TEST.LABEL_SOURCES) == 0 and cfg.DATA.TEST.LABEL_TYPE != "zero":
        cfg.DATA.TEST.LABEL_TYPE = "sample_index"

    # if the user has specified the model initialization from a params_file, we check if
    # the params_file is a url. If it is, we download the file to a local cache directory
    # and use that instead
    from vissl.utils.checkpoint import get_checkpoint_folder
    from vissl.utils.io import cache_url, is_url

    if is_url(cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE):
        checkpoint_dir = get_checkpoint_folder(cfg)
        cache_dir = f"{checkpoint_dir}/params_file_cache/"
        cached_url_path = cache_url(url=cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE,
                                    cache_dir=cache_dir)
        cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE = cached_url_path

    # ZeRO2: Infer the settings for ShardedDDP, which shards the optimizer state
    # and the model weights. For ShardedDDP, we must use the OSS optimizer,
    # set the right task name, and use PyTorch AMP if AMP is used.
    if cfg.MODEL.SHARDED_DDP_SETUP.USE_SDP:
        cfg.OPTIMIZER.use_zero = True
        cfg.TRAINER.TASK_NAME = "self_supervision_sdp_task"
        if cfg.MODEL.AMP_PARAMS.USE_AMP:
            cfg.MODEL.AMP_PARAMS.AMP_TYPE = "pytorch"

    # if we use a zero optimizer, we nest the optimizer related settings under the
    # base_optimizer.
    if cfg.OPTIMIZER.use_zero:
        cfg.OPTIMIZER["base_optimizer"] = cfg.OPTIMIZER.copy()
        cfg.OPTIMIZER.name = "zero"
        del cfg.OPTIMIZER.base_optimizer["param_schedulers"]
        del cfg.OPTIMIZER.base_optimizer["regularize_bn"]
        del cfg.OPTIMIZER.base_optimizer["regularize_bias"]
        del cfg.OPTIMIZER.base_optimizer["num_epochs"]
        del cfg.OPTIMIZER.base_optimizer["use_zero"]
        del cfg.OPTIMIZER.base_optimizer["head_optimizer_params"]

    # Infer fsdp settings
    cfg = infer_fsdp_setup(cfg)

    if cfg.DATA.TRAIN.BASE_DATASET == "generic_ssl":
        assert (
            cfg.DATA.TRAIN.get("TRAIN_PHASES_PER_EPOCH", 1) == 1
        ), "When using the generic_ssl, we must set TRAIN_PHASES_PER_EPOCH = 1."

    if cfg.METERS.model_output_mask:
        assert (
            len(cfg.DATA.TEST.DATA_SOURCES) > 0
        ), "Model output mask is only applicable when there is a test dataset."

        assert (
            cfg.DATA.TEST.BASE_DATASET == "generic_ssl"
        ), "Model output mask is only supported with the generic_ssl dataset."

        # Remove CHECK_NAN hooks, as model output masking casts the logits
        # to -inf, which will throw an error from the CHECK_NAN hooks.
        cfg.HOOKS.CHECK_NAN = False

    if cfg.HOOKS.EMA_MODEL.ENABLE_EMA_METERS:
        assert cfg.METERS.get("name", "") or cfg.METERS.get(
            "names", []
        ), "Please specify METERS.name or METERS.names if you are enabling the EMA_MODEL hook."