Esempio n. 1
0
def test_gm_pruning_yaml():
    """Check that GMPruningModifier settings agree across three construction paths.

    The same settings are used to (1) load a modifier from a YAML string,
    (2) reload it from that modifier's ``str()`` serialization, and
    (3) construct it directly from keyword arguments.  Every compared
    attribute must be equal across all three objects.
    """
    init_sparsity = 0.05
    final_sparsity = 0.8
    start_epoch = 5.0
    end_epoch = 15.0
    update_frequency = 1.0
    params = ["re:.*weight"]
    inter_func = "cubic"
    mask_type = "filter"
    global_sparsity = False
    yaml_str = f"""
    !GMPruningModifier
        init_sparsity: {init_sparsity}
        final_sparsity: {final_sparsity}
        start_epoch: {start_epoch}
        end_epoch: {end_epoch}
        update_frequency: {update_frequency}
        params: {params}
        inter_func: {inter_func}
        mask_type: {mask_type}
        global_sparsity: {global_sparsity}
    """

    from_yaml = GMPruningModifier.load_obj(yaml_str)  # type: GMPruningModifier
    round_tripped = GMPruningModifier.load_obj(
        str(from_yaml))  # type: GMPruningModifier
    from_kwargs = GMPruningModifier(
        init_sparsity=init_sparsity,
        final_sparsity=final_sparsity,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        update_frequency=update_frequency,
        params=params,
        inter_func=inter_func,
        mask_type=mask_type,
        global_sparsity=global_sparsity,
    )

    assert isinstance(from_yaml, GMPruningModifier)

    # Attributes compared directly by value.
    for attr_name in ("init_sparsity", "final_sparsity", "start_epoch",
                      "end_epoch", "update_frequency", "params",
                      "inter_func"):
        reference = getattr(from_yaml, attr_name)
        assert (reference == getattr(round_tripped, attr_name) ==
                getattr(from_kwargs, attr_name)), attr_name

    # These attributes are compared through their string form; presumably the
    # stored objects are not directly comparable by value (TODO confirm).
    for attr_name in ("mask_type", "global_sparsity"):
        assert (str(getattr(from_yaml, attr_name)) ==
                str(getattr(round_tripped, attr_name)) ==
                str(getattr(from_kwargs, attr_name))), attr_name
Esempio n. 2
0
def test_magnitude_pruning_yaml():
    """Check that MagnitudePruningModifier settings agree across three paths.

    The same settings are used to (1) load a modifier from a YAML string,
    (2) reload it from that modifier's ``str()`` serialization, and
    (3) construct it directly from keyword arguments.  All three objects are
    built through MagnitudePruningModifier itself, and every compared
    attribute must be equal across them.
    """
    init_sparsity = 0.05
    final_sparsity = 0.8
    start_epoch = 5.0
    end_epoch = 15.0
    update_frequency = 1.0
    params = "__ALL_PRUNABLE__"
    inter_func = "cubic"
    mask_type = "filter"
    yaml_str = f"""
    !MagnitudePruningModifier
        init_sparsity: {init_sparsity}
        final_sparsity: {final_sparsity}
        start_epoch: {start_epoch}
        end_epoch: {end_epoch}
        update_frequency: {update_frequency}
        params: {params}
        inter_func: {inter_func}
        mask_type: {mask_type}
    """
    yaml_modifier = MagnitudePruningModifier.load_obj(
        yaml_str)  # type: MagnitudePruningModifier
    # Deserialize and construct through MagnitudePruningModifier (not
    # GMPruningModifier) so this test exercises the class it is named for.
    serialized_modifier = MagnitudePruningModifier.load_obj(
        str(yaml_modifier))  # type: MagnitudePruningModifier
    obj_modifier = MagnitudePruningModifier(
        init_sparsity=init_sparsity,
        final_sparsity=final_sparsity,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        update_frequency=update_frequency,
        params=params,
        inter_func=inter_func,
        mask_type=mask_type,
    )

    assert isinstance(yaml_modifier, MagnitudePruningModifier)
    # The str() serialization must round-trip to the same class, not merely
    # to an object with equal attribute values.
    assert isinstance(serialized_modifier, MagnitudePruningModifier)
    assert (yaml_modifier.init_sparsity == serialized_modifier.init_sparsity ==
            obj_modifier.init_sparsity)
    assert (yaml_modifier.final_sparsity == serialized_modifier.final_sparsity
            == obj_modifier.final_sparsity)
    assert (yaml_modifier.start_epoch == serialized_modifier.start_epoch ==
            obj_modifier.start_epoch)
    assert (yaml_modifier.end_epoch == serialized_modifier.end_epoch ==
            obj_modifier.end_epoch)
    assert (yaml_modifier.update_frequency ==
            serialized_modifier.update_frequency ==
            obj_modifier.update_frequency)
    assert yaml_modifier.params == serialized_modifier.params == obj_modifier.params
    assert (yaml_modifier.inter_func == serialized_modifier.inter_func ==
            obj_modifier.inter_func)
    # mask_type is compared via str(); presumably the stored object is not
    # directly comparable by value (TODO confirm).
    assert (str(yaml_modifier.mask_type) == str(serialized_modifier.mask_type)
            == str(obj_modifier.mask_type))