Beispiel #1
0
def test_set_lr_yaml():
    """
    A SetLearningRateModifier loaded from YAML, one re-loaded from that
    modifier's string form, and one built directly in Python must all
    agree on learning_rate and start_epoch.
    """
    start_epoch = 10.0
    yaml_str = """
    !SetLearningRateModifier
        learning_rate: {}
        start_epoch: {}
    """.format(SET_LR, start_epoch)

    from_yaml = SetLearningRateModifier.load_obj(
        yaml_str)  # type: SetLearningRateModifier
    # round-trip: serialize the YAML-loaded modifier and load it back
    round_tripped = SetLearningRateModifier.load_obj(
        str(from_yaml))  # type: SetLearningRateModifier
    constructed = SetLearningRateModifier(
        learning_rate=SET_LR, start_epoch=start_epoch)

    assert isinstance(from_yaml, SetLearningRateModifier)
    for attr_name in ("learning_rate", "start_epoch"):
        assert (
            getattr(from_yaml, attr_name)
            == getattr(round_tripped, attr_name)
            == getattr(constructed, attr_name)
        )
##############################
#
# SetLearningRateModifier tests
#
##############################


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False),
    reason="Skipping pytorch tests",
)
@pytest.mark.parametrize(
    "modifier_lambda",
    [
        lambda: SetLearningRateModifier(learning_rate=0.1),
        lambda: SetLearningRateModifier(learning_rate=0.03, start_epoch=5),
    ],
    scope="function",
)
@pytest.mark.parametrize("model_lambda", [LinearNet], scope="function")
@pytest.mark.parametrize(
    "optim_lambda",
    [
        lambda model: SGD(model.parameters(), INIT_LR),
        lambda model: Adam(model.parameters(), INIT_LR),
    ],
    scope="function",
)
class TestSetLRModifierImpl(ScheduledModifierTest):
    def test_lifecycle(
def _none_to_default(value, default=-1):
    """
    Return ``value``, or ``default`` when ``value`` is ``None``.

    :param value: a possibly-unset value (e.g. an epoch read from the DB)
    :param default: the replacement used when ``value`` is ``None``
    :return: ``value`` if it is not ``None``, otherwise ``default``
    """
    return value if value is not None else default


def create_config(project: Project, optim: ProjectOptimization,
                  framework: str) -> str:
    """
    Creates an optimization config yaml for a given project and optimization

    :param project: project to create with
    :param optim: project optimizer to create with
    :param framework: the framework to create the config for,
        either "pytorch" or "tensorflow"
    :return: the string form of a ScheduledModifierManager built from the
        optimization's epoch range, pruning, LR schedule and trainable
        modifiers
    :raises ValidationError: if the framework is not supported
    """
    # add imports in function so they don't fail if they don't have env setup
    # for frameworks other than the requested
    if framework == "pytorch":
        from sparseml.pytorch.optim import (
            EpochRangeModifier,
            GMPruningModifier,
            LearningRateModifier,
            ScheduledModifierManager,
            SetLearningRateModifier,
            TrainableParamsModifier,
        )
    elif framework == "tensorflow":
        from sparseml.tensorflow_v1.optim import (
            EpochRangeModifier,
            GMPruningModifier,
            LearningRateModifier,
            ScheduledModifierManager,
            SetLearningRateModifier,
            TrainableParamsModifier,
        )
    else:
        _LOGGER.error("Unsupported framework {} provided".format(framework))
        raise ValidationError(
            "Unsupported framework {} provided".format(framework))

    # maps the stored LR schedule "clazz" values onto the lr_class names
    # passed to LearningRateModifier; loop-invariant, so built once here
    lr_class_mapping = {
        "step": "StepLR",
        "multi_step": "MultiStepLR",
        "exponential": "ExponentialLR",
    }

    mods = [
        EpochRangeModifier(
            start_epoch=_none_to_default(optim.start_epoch),
            end_epoch=_none_to_default(optim.end_epoch),
        )
    ]
    # weight name of every prunable node in the model analysis, by node id
    node_weight_name_lookup = {
        node["id"]: node["weight_name"]
        for node in project.model.analysis["nodes"] if node["prunable"]
    }

    for mod in optim.pruning_modifiers:
        # group weight names by target sparsity so one GMPruningModifier
        # is emitted per distinct sparsity level
        sparsity_to_params = {}

        for node in mod.nodes:
            node_id = node["node_id"]
            # node is coming from DB, so already had prunable checks
            # add assert here to fail early for non prunable nodes
            assert node_id in node_weight_name_lookup

            weight_name = node_weight_name_lookup[node_id]
            sparsity = node["sparsity"]

            if sparsity is None:
                continue

            sparsity_to_params.setdefault(sparsity, []).append(weight_name)

        for sparsity, params in sparsity_to_params.items():
            gm_pruning = GMPruningModifier(
                init_sparsity=0.05,
                final_sparsity=sparsity,
                start_epoch=_none_to_default(mod.start_epoch),
                end_epoch=_none_to_default(mod.end_epoch),
                # any falsy update_frequency (None or 0) falls back to -1
                update_frequency=mod.update_frequency
                if mod.update_frequency else -1,
                params=params,
            )

            if mod.mask_type:
                gm_pruning.mask_type = mod.mask_type

            mods.append(gm_pruning)

    for lr_schedule_modifier in optim.lr_schedule_modifiers:
        for lr_mod in lr_schedule_modifier.lr_mods:
            lr_fields = ProjectOptimizationModifierLRSchema().dump(lr_mod)
            start_epoch = _none_to_default(lr_fields["start_epoch"])
            end_epoch = _none_to_default(lr_fields["end_epoch"])

            if lr_fields["clazz"] == "set":
                # a set-LR modifier only takes a start epoch
                mods.append(
                    SetLearningRateModifier(
                        learning_rate=lr_fields["init_lr"],
                        start_epoch=start_epoch,
                    ))
            else:
                assert lr_fields["clazz"] in lr_class_mapping
                mods.append(
                    LearningRateModifier(
                        lr_class=lr_class_mapping[lr_fields["clazz"]],
                        lr_kwargs=lr_fields["args"],
                        init_lr=lr_fields["init_lr"],
                        start_epoch=start_epoch,
                        end_epoch=end_epoch,
                    ))

    for trainable_modifier in optim.trainable_modifiers:
        mod = ProjectOptimizationModifierTrainableSchema().dump(
            trainable_modifier)
        start_epoch = _none_to_default(mod["start_epoch"])
        end_epoch = _none_to_default(mod["end_epoch"])

        if "nodes" not in mod:
            continue
        trainable_nodes = []
        untrainable_nodes = []
        for node in mod["nodes"]:
            assert node["node_id"] in node_weight_name_lookup

            weight_name = node_weight_name_lookup[node["node_id"]]
            if node["trainable"]:
                trainable_nodes.append(weight_name)
            else:
                untrainable_nodes.append(weight_name)

        # one TrainableParamsModifier per trainable flag; empty groups
        # are skipped (trainable group is emitted first)
        for params, trainable in ((trainable_nodes, True),
                                  (untrainable_nodes, False)):
            if params:
                mods.append(
                    TrainableParamsModifier(
                        params,
                        trainable,
                        start_epoch=start_epoch,
                        end_epoch=end_epoch,
                    ))
    # TODO: add quantization support when ready

    return str(ScheduledModifierManager(mods))