Example No. 1
0

from tests.sparseml.pytorch.helpers import (  # noqa isort:skip
    test_epoch,
    test_loss,
    test_steps_per_epoch,
)


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False),
    reason="Skipping pytorch tests",
)
@pytest.mark.parametrize(
    "modifier_lambda",
    [lambda: ScheduledModifierManager([ScheduledModifierImpl()])],
    scope="function",
)
@pytest.mark.parametrize("model_lambda", [LinearNet], scope="function")
@pytest.mark.parametrize(
    "optim_lambda", [create_optim_sgd, create_optim_adam], scope="function"
)
class TestManagerImpl(ModifierTest):
    def test_yaml(
        self,
        modifier_lambda: Callable[[], Modifier],
        model_lambda: Callable[[], Module],
        optim_lambda: Callable[[Module], Optimizer],
        test_epoch: float,  # noqa: F811
        test_steps_per_epoch: float,  # noqa: F811
    ):
Example No. 2
0
def test_manager_yaml():
    """Serializing a one-modifier manager should yield non-empty yaml."""
    modifiers = [ScheduledModifierImpl()]
    manager = ScheduledModifierManager(modifiers)
    # keep the str() call outside the assert so it runs even under -O
    serialized = str(manager)
    assert serialized
Example No. 3
0
def create_config(project: Project, optim: ProjectOptimization,
                  framework: str) -> str:
    """
    Build the modifier config yaml for a project's optimization.

    The returned text is the string form of a ScheduledModifierManager that
    holds, in order: one EpochRangeModifier for the optimization's overall
    epoch range, GMPruningModifiers built from the pruning modifiers (one per
    distinct target sparsity), the learning rate modifiers, and the
    TrainableParamsModifiers.

    :param project: project to create with; its model analysis supplies the
        prunable node ids and their weight names
    :param optim: project optimization to create with
    :param framework: the framework to create the config for, either
        "pytorch" or "tensorflow" (the latter uses sparseml.tensorflow_v1)
    :return: the config produced by str() of the ScheduledModifierManager
    :raises ValidationError: if framework is not "pytorch" or "tensorflow"
    :raises AssertionError: (when assertions are enabled) if a pruning or
        trainable node id is not a prunable node of the model analysis, or
        an lr modifier clazz is not one of the supported schedules
    """

    def _epoch_or_unset(epoch):
        # a missing (None) epoch is handed to the modifiers as -1
        return -1 if epoch is None else epoch

    if framework not in ("pytorch", "tensorflow"):
        message = "Unsupported framework {} provided".format(framework)
        _LOGGER.error(message)
        raise ValidationError(message)

    # the modifier classes are imported here, per framework, so that a missing
    # environment for the other framework does not make this function fail
    if framework == "pytorch":
        from sparseml.pytorch.optim import (
            EpochRangeModifier,
            GMPruningModifier,
            LearningRateModifier,
            ScheduledModifierManager,
            SetLearningRateModifier,
            TrainableParamsModifier,
        )
    else:
        from sparseml.tensorflow_v1.optim import (
            EpochRangeModifier,
            GMPruningModifier,
            LearningRateModifier,
            ScheduledModifierManager,
            SetLearningRateModifier,
            TrainableParamsModifier,
        )

    config_modifiers = [
        EpochRangeModifier(
            start_epoch=_epoch_or_unset(optim.start_epoch),
            end_epoch=_epoch_or_unset(optim.end_epoch),
        )
    ]

    # weight name of every prunable node in the model analysis, keyed by id
    prunable_weight_names = {
        analysis_node["id"]: analysis_node["weight_name"]
        for analysis_node in project.model.analysis["nodes"]
        if analysis_node["prunable"]
    }

    for pruning_mod in optim.pruning_modifiers:
        # group this modifier's weight names by their target sparsity
        params_by_sparsity = {}

        for node in pruning_mod.nodes:
            # nodes come from the DB and already passed the prunable checks;
            # the assert makes a non prunable node fail early anyway
            assert node["node_id"] in prunable_weight_names
            target_sparsity = node["sparsity"]
            weight_name = prunable_weight_names[node["node_id"]]

            if target_sparsity is not None:
                params_by_sparsity.setdefault(target_sparsity,
                                              []).append(weight_name)

        for target_sparsity, params in params_by_sparsity.items():
            gm_pruning = GMPruningModifier(
                init_sparsity=0.05,
                final_sparsity=target_sparsity,
                start_epoch=_epoch_or_unset(pruning_mod.start_epoch),
                end_epoch=_epoch_or_unset(pruning_mod.end_epoch),
                # NOTE(review): unlike the epochs this uses truthiness, so an
                # update_frequency of 0 is also replaced by -1; confirm intent
                update_frequency=pruning_mod.update_frequency or -1,
                params=params,
            )

            if pruning_mod.mask_type:
                gm_pruning.mask_type = pruning_mod.mask_type

            config_modifiers.append(gm_pruning)

    # lr modifier clazz -> lr_class for LearningRateModifier;
    # the "set" clazz is handled separately by SetLearningRateModifier
    lr_class_mapping = {
        "step": "StepLR",
        "multi_step": "MultiStepLR",
        "exponential": "ExponentialLR",
    }

    for lr_schedule_modifier in optim.lr_schedule_modifiers:
        for lr_mod in lr_schedule_modifier.lr_mods:
            lr_fields = ProjectOptimizationModifierLRSchema().dump(lr_mod)
            start_epoch = _epoch_or_unset(lr_fields["start_epoch"])
            end_epoch = _epoch_or_unset(lr_fields["end_epoch"])
            clazz = lr_fields["clazz"]

            if clazz == "set":
                config_modifiers.append(
                    SetLearningRateModifier(
                        learning_rate=lr_fields["init_lr"],
                        start_epoch=start_epoch,
                    ))
                continue

            assert clazz in lr_class_mapping
            config_modifiers.append(
                LearningRateModifier(
                    lr_class=lr_class_mapping[clazz],
                    lr_kwargs=lr_fields["args"],
                    init_lr=lr_fields["init_lr"],
                    start_epoch=start_epoch,
                    end_epoch=end_epoch,
                ))

    for trainable_modifier in optim.trainable_modifiers:
        trainable_fields = ProjectOptimizationModifierTrainableSchema().dump(
            trainable_modifier)
        start_epoch = _epoch_or_unset(trainable_fields["start_epoch"])
        end_epoch = _epoch_or_unset(trainable_fields["end_epoch"])

        if "nodes" not in trainable_fields:
            continue

        # partition the node weight names by their trainable flag
        weight_names_by_flag = {True: [], False: []}
        for node in trainable_fields["nodes"]:
            assert node["node_id"] in prunable_weight_names
            weight_name = prunable_weight_names[node["node_id"]]
            weight_names_by_flag[bool(node["trainable"])].append(weight_name)

        # emit the trainable group first, then the untrainable one,
        # each only when it has at least one weight name
        for trainable_flag in (True, False):
            weight_names = weight_names_by_flag[trainable_flag]
            if weight_names:
                config_modifiers.append(
                    TrainableParamsModifier(weight_names,
                                            trainable_flag,
                                            start_epoch=start_epoch,
                                            end_epoch=end_epoch))

    # TODO: add quantization support when ready
    return str(ScheduledModifierManager(config_modifiers))