Code Example #1
def test_gm_pruning_yaml():
    init_sparsity = 0.05
    final_sparsity = 0.8
    start_epoch = 5.0
    end_epoch = 15.0
    update_frequency = 1.0
    params = ["re:.*weight"]
    inter_func = "cubic"
    mask_type = "filter"
    global_sparsity = False
    yaml_str = f"""
    !GMPruningModifier
        init_sparsity: {init_sparsity}
        final_sparsity: {final_sparsity}
        start_epoch: {start_epoch}
        end_epoch: {end_epoch}
        update_frequency: {update_frequency}
        params: {params}
        inter_func: {inter_func}
        mask_type: {mask_type}
        global_sparsity: {global_sparsity}
    """
    # load the modifier from the YAML string
    yaml_modifier = GMPruningModifier.load_obj(
        yaml_str)  # type: GMPruningModifier
    # round trip: serialize the loaded modifier back to YAML and reload it
    serialized_modifier = GMPruningModifier.load_obj(
        str(yaml_modifier))  # type: GMPruningModifier
    # build the same modifier directly through the constructor
    obj_modifier = GMPruningModifier(
        init_sparsity=init_sparsity,
        final_sparsity=final_sparsity,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        update_frequency=update_frequency,
        params=params,
        inter_func=inter_func,
        mask_type=mask_type,
        global_sparsity=global_sparsity,
    )

    # all three construction paths should agree on type and property values
    assert isinstance(yaml_modifier, GMPruningModifier)
    assert (yaml_modifier.init_sparsity == serialized_modifier.init_sparsity ==
            obj_modifier.init_sparsity)
    assert (yaml_modifier.final_sparsity == serialized_modifier.final_sparsity
            == obj_modifier.final_sparsity)
    assert (yaml_modifier.start_epoch == serialized_modifier.start_epoch ==
            obj_modifier.start_epoch)
    assert (yaml_modifier.end_epoch == serialized_modifier.end_epoch ==
            obj_modifier.end_epoch)
    assert (yaml_modifier.update_frequency ==
            serialized_modifier.update_frequency ==
            obj_modifier.update_frequency)
    assert yaml_modifier.params == serialized_modifier.params == obj_modifier.params
    assert (yaml_modifier.inter_func == serialized_modifier.inter_func ==
            obj_modifier.inter_func)
    assert (str(yaml_modifier.mask_type) == str(serialized_modifier.mask_type)
            == str(obj_modifier.mask_type))
    assert (str(yaml_modifier.global_sparsity) == str(
        serialized_modifier.global_sparsity) == str(
            obj_modifier.global_sparsity))
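
For context, the !GMPruningModifier block the test feeds to GMPruningModifier.load_obj uses the same syntax that appears inside a full recipe file. Below is a minimal sketch of loading such a recipe through a manager; the recipe.yaml path and the use of ScheduledModifierManager.from_yaml are illustrative assumptions, not part of the test above.

from sparseml.pytorch.optim import ScheduledModifierManager

# assumption: a top-level `modifiers:` list wrapping the same block shown in the test
recipe = """\
modifiers:
    - !GMPruningModifier
        init_sparsity: 0.05
        final_sparsity: 0.8
        start_epoch: 5.0
        end_epoch: 15.0
        update_frequency: 1.0
        params: ["re:.*weight"]
        inter_func: cubic
        mask_type: filter
"""

with open("recipe.yaml", "w") as recipe_file:
    recipe_file.write(recipe)

manager = ScheduledModifierManager.from_yaml("recipe.yaml")
print(manager)  # serializes back to YAML, mirroring str(yaml_modifier) in the test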
Code Example #2
def test_magnitude_pruning_yaml():
    init_sparsity = 0.05
    final_sparsity = 0.8
    start_epoch = 5.0
    end_epoch = 15.0
    update_frequency = 1.0
    params = "__ALL_PRUNABLE__"
    inter_func = "cubic"
    mask_type = "filter"
    yaml_str = f"""
    !MagnitudePruningModifier
        init_sparsity: {init_sparsity}
        final_sparsity: {final_sparsity}
        start_epoch: {start_epoch}
        end_epoch: {end_epoch}
        update_frequency: {update_frequency}
        params: {params}
        inter_func: {inter_func}
        mask_type: {mask_type}
    """
    yaml_modifier = MagnitudePruningModifier.load_obj(
        yaml_str)  # type: MagnitudePruningModifier
    serialized_modifier = GMPruningModifier.load_obj(
        str(yaml_modifier))  # type: MagnitudePruningModifier
    obj_modifier = GMPruningModifier(
        init_sparsity=init_sparsity,
        final_sparsity=final_sparsity,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        update_frequency=update_frequency,
        params=params,
        inter_func=inter_func,
        mask_type=mask_type,
    )

    assert isinstance(yaml_modifier, MagnitudePruningModifier)
    assert (yaml_modifier.init_sparsity == serialized_modifier.init_sparsity ==
            obj_modifier.init_sparsity)
    assert (yaml_modifier.final_sparsity == serialized_modifier.final_sparsity
            == obj_modifier.final_sparsity)
    assert (yaml_modifier.start_epoch == serialized_modifier.start_epoch ==
            obj_modifier.start_epoch)
    assert (yaml_modifier.end_epoch == serialized_modifier.end_epoch ==
            obj_modifier.end_epoch)
    assert (yaml_modifier.update_frequency ==
            serialized_modifier.update_frequency ==
            obj_modifier.update_frequency)
    assert yaml_modifier.params == serialized_modifier.params == obj_modifier.params
    assert (yaml_modifier.inter_func == serialized_modifier.inter_func ==
            obj_modifier.inter_func)
    assert (str(yaml_modifier.mask_type) == str(serialized_modifier.mask_type)
            == str(obj_modifier.mask_type))
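
Unlike Example #1, params here is the __ALL_PRUNABLE__ token rather than an explicit regex list, which targets the weight of every prunable layer. The sketch below approximates what that token expands to for a given model using plain PyTorch; the Conv/Linear filter is an illustrative approximation, not sparseml's actual prunable-layer logic.

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Linear(16, 10))

# roughly what __ALL_PRUNABLE__ covers: the weight of every Conv/Linear layer
prunable_params = [
    f"{name}.weight"
    for name, layer in model.named_modules()
    if isinstance(layer, (nn.Conv2d, nn.Linear))
]
print(prunable_params)  # ['0.weight', '2.weight']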
Code Example #3
def test_gm_pruning_yaml(params, final_sparsity):
    init_sparsity = 0.05
    start_epoch = 5.0
    end_epoch = 15.0
    update_frequency = 1.0
    inter_func = "cubic"
    mask_type = "filter"
    global_sparsity = False
    yaml_str = f"""
    !LegacyGMPruningModifier
        init_sparsity: {init_sparsity}
        final_sparsity: {final_sparsity}
        start_epoch: {start_epoch}
        end_epoch: {end_epoch}
        update_frequency: {update_frequency}
        params: {params}
        inter_func: {inter_func}
        mask_type: {mask_type}
        global_sparsity: {global_sparsity}
    """
    yaml_modifier = GMPruningModifier.load_obj(yaml_str)  # type: GMPruningModifier
    serialized_modifier = GMPruningModifier.load_obj(
        str(yaml_modifier)
    )  # type: GMPruningModifier
    obj_modifier = GMPruningModifier(
        init_sparsity=init_sparsity,
        final_sparsity=final_sparsity,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        update_frequency=update_frequency,
        params=params,
        inter_func=inter_func,
        mask_type=mask_type,
        global_sparsity=global_sparsity,
    )

    assert isinstance(yaml_modifier, GMPruningModifier)
    _test_pruning_modifier_serialization_vals(
        yaml_modifier, serialized_modifier, obj_modifier
    )
    assert (
        str(yaml_modifier.global_sparsity)
        == str(serialized_modifier.global_sparsity)
        == str(obj_modifier.global_sparsity)
    )
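
Example #3 delegates the shared property checks to _test_pruning_modifier_serialization_vals, which is not shown here. A plausible reconstruction of that helper, based on the explicit assertions in Examples #1 and #2 (the actual helper in the test suite may differ):

def _test_pruning_modifier_serialization_vals(
    yaml_modifier, serialized_modifier, obj_modifier
):
    # the three construction paths (YAML, re-serialized YAML, constructor)
    # should agree on every shared property
    for prop in (
        "init_sparsity",
        "final_sparsity",
        "start_epoch",
        "end_epoch",
        "update_frequency",
        "params",
        "inter_func",
    ):
        assert (
            getattr(yaml_modifier, prop)
            == getattr(serialized_modifier, prop)
            == getattr(obj_modifier, prop)
        )
    # mask_type may deserialize to a wrapper object, so compare its string form
    assert (
        str(yaml_modifier.mask_type)
        == str(serialized_modifier.mask_type)
        == str(obj_modifier.mask_type)
    )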
Code Example #4
            assert torch.all(param == 0.0)


@flaky(max_runs=3, min_passes=2)
@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False),
    reason="Skipping pytorch tests",
)
@pytest.mark.parametrize(
    "modifier_lambda",
    [
        lambda: GMPruningModifier(
            init_sparsity=0.05,
            final_sparsity=0.95,
            start_epoch=0.0,
            end_epoch=15.0,
            update_frequency=1.0,
            params=["re:.*weight"],
            inter_func="linear",
        ),
        lambda: GMPruningModifier(
            params=["re:seq.block1.*weight"],
            init_sparsity=0.05,
            final_sparsity=0.95,
            start_epoch=10.0,
            end_epoch=25.0,
            update_frequency=1.0,
            inter_func="cubic",
            global_sparsity=True,
        ),
        lambda: GMPruningModifier(
Code Example #5
            obj_modifier.end_epoch)
    assert yaml_modifier.params == serialized_modifier.params == obj_modifier.params


@pytest.mark.skipif(
    os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False),
    reason="Skipping pytorch tests",
)
@pytest.mark.parametrize(
    "modifier_lambda",
    [
        lambda: GMPruningModifier(
            init_sparsity=0.05,
            final_sparsity=0.95,
            start_epoch=0.0,
            end_epoch=15.0,
            update_frequency=1.0,
            params=["re:.*weight"],
            inter_func="linear",
        ),
        lambda: GMPruningModifier(
            params=["re:seq.block1.*weight"],
            init_sparsity=0.05,
            final_sparsity=0.95,
            start_epoch=10.0,
            end_epoch=25.0,
            update_frequency=1.0,
            inter_func="cubic",
        ),
        lambda: GMPruningModifier(
            params=["seq.fc1.weight", "seq.fc2.weight"],
Code Example #6
def create_config(project: Project, optim: ProjectOptimization,
                  framework: str) -> str:
    """
    Creates a optimization config yaml for a given project and optimization

    :param project: project to create with
    :param optim: project optimizer to create with
    :param framework: the framework to create the config for
    """
    # import inside the function so that frameworks other than the requested
    # one do not need to be installed in the environment
    if framework == "pytorch":
        from sparseml.pytorch.optim import (
            EpochRangeModifier,
            GMPruningModifier,
            LearningRateModifier,
            ScheduledModifierManager,
            SetLearningRateModifier,
            TrainableParamsModifier,
        )
    elif framework == "tensorflow":
        from sparseml.tensorflow_v1.optim import (
            EpochRangeModifier,
            GMPruningModifier,
            LearningRateModifier,
            ScheduledModifierManager,
            SetLearningRateModifier,
            TrainableParamsModifier,
        )
    else:
        _LOGGER.error("Unsupported framework {} provided".format(framework))
        raise ValidationError(
            "Unsupported framework {} provided".format(framework))

    mods = [
        EpochRangeModifier(
            start_epoch=optim.start_epoch
            if optim.start_epoch is not None else -1,
            end_epoch=optim.end_epoch if optim.end_epoch is not None else -1,
        )
    ]
    node_weight_name_lookup = {
        node["id"]: node["weight_name"]
        for node in project.model.analysis["nodes"] if node["prunable"]
    }

    for mod in optim.pruning_modifiers:
        sparsity_to_params = {}

        for node in mod.nodes:
            # the node comes from the DB, so it has already passed prunable checks;
            # assert here to fail early on non-prunable nodes
            assert node["node_id"] in node_weight_name_lookup

            sparsity = node["sparsity"]
            node_id = node["node_id"]
            weight_name = node_weight_name_lookup[node_id]

            if sparsity is None:
                continue

            if sparsity not in sparsity_to_params:
                sparsity_to_params[sparsity] = []

            sparsity_to_params[sparsity].append(weight_name)

        for sparsity, params in sparsity_to_params.items():
            gm_pruning = GMPruningModifier(
                init_sparsity=0.05,
                final_sparsity=sparsity,
                start_epoch=mod.start_epoch
                if mod.start_epoch is not None else -1,
                end_epoch=mod.end_epoch if mod.end_epoch is not None else -1,
                update_frequency=mod.update_frequency
                if mod.update_frequency else -1,
                params=params,
            )

            if mod.mask_type:
                gm_pruning.mask_type = mod.mask_type

            mods.append(gm_pruning)

    for lr_schedule_modifier in optim.lr_schedule_modifiers:
        for mod in lr_schedule_modifier.lr_mods:
            mod = ProjectOptimizationModifierLRSchema().dump(mod)
            start_epoch = mod["start_epoch"] if mod[
                "start_epoch"] is not None else -1
            end_epoch = mod["end_epoch"] if mod["end_epoch"] is not None else -1

            if mod["clazz"] == "set":
                mods.append(
                    SetLearningRateModifier(
                        learning_rate=mod["init_lr"],
                        start_epoch=start_epoch,
                    ))
            else:
                lr_class_mapping = {
                    "step": "StepLR",
                    "multi_step": "MultiStepLR",
                    "exponential": "ExponentialLR",
                }
                assert mod["clazz"] in lr_class_mapping
                mods.append(
                    LearningRateModifier(
                        lr_class=lr_class_mapping[mod["clazz"]],
                        lr_kwargs=mod["args"],
                        init_lr=mod["init_lr"],
                        start_epoch=start_epoch,
                        end_epoch=end_epoch,
                    ))

    for trainable_modifier in optim.trainable_modifiers:
        mod = ProjectOptimizationModifierTrainableSchema().dump(
            trainable_modifier)
        start_epoch = mod["start_epoch"] if mod[
            "start_epoch"] is not None else -1
        end_epoch = mod["end_epoch"] if mod["end_epoch"] is not None else -1

        if "nodes" not in mod:
            continue
        trainable_nodes = []
        untrainable_nodes = []
        for node in mod["nodes"]:
            assert node["node_id"] in node_weight_name_lookup

            weight_name = node_weight_name_lookup[node["node_id"]]
            if node["trainable"]:
                trainable_nodes.append(weight_name)
            else:
                untrainable_nodes.append(weight_name)

        if len(trainable_nodes) > 0:
            mods.append(
                TrainableParamsModifier(trainable_nodes,
                                        True,
                                        start_epoch=start_epoch,
                                        end_epoch=end_epoch))

        if len(untrainable_nodes) > 0:
            mods.append(
                TrainableParamsModifier(
                    untrainable_nodes,
                    False,
                    start_epoch=start_epoch,
                    end_epoch=end_epoch,
                ))
    # TODO: add quantization support when ready

    return str(ScheduledModifierManager(mods))
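
create_config returns the recipe as a plain YAML string (str(ScheduledModifierManager(mods))), so callers typically just persist it. A hedged usage sketch, with the project and optim records assumed to be loaded elsewhere (e.g. from the data layer):

# hypothetical usage: `project` and `optim` are existing Project /
# ProjectOptimization records obtained from the data layer
config_yaml = create_config(project, optim, framework="pytorch")

with open("recipe.yaml", "w") as recipe_file:
    recipe_file.write(config_yaml)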