Example #1
def _test_constructor(layer, param_name, mask_creator):
    """Verify that a freshly built ModuleParamPruningMask exposes the
    expected layer/param wiring and default flag values."""
    pruning_mask = ModuleParamPruningMask(
        [layer], [param_name], mask_creator=mask_creator
    )
    # wiring: the single layer/param pair we passed in comes back at index 0
    assert pruning_mask.layers[0] == layer
    assert pruning_mask.param_names[0] == param_name
    # defaults: nothing stored, no grad-momentum tracking, disabled, local sparsity
    assert not pruning_mask.store_init
    assert not pruning_mask.store_unmasked
    assert pruning_mask.track_grad_mom == -1.0
    assert not pruning_mask.global_sparsity
    assert not pruning_mask.enabled
    assert pruning_mask.mask_creator == mask_creator
Example #2
def _test_set_param_mask_from_sparsity(
    layer, param_name, param, sparsity, mask_creator
):
    """Request a target sparsity and verify the produced mask achieves it
    (within 1%), including grouped-structure checks when applicable."""
    pruning_mask = ModuleParamPruningMask(
        [layer], [param_name], mask_creator=mask_creator
    )
    pruning_mask.set_param_data(param, 0)
    pruning_mask.set_param_masks_from_sparsity(sparsity)
    applied_mask = pruning_mask.param_masks[0]
    # measured sparsity must be within 1% of the requested level
    assert abs(tensor_sparsity(applied_mask) - sparsity) < 0.01
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        # grouped creators must also produce structurally valid group masks
        _test_grouped_sparsity_mask_output(mask_creator, applied_mask)
Example #3
def _test_set_param_mask_from_abs_threshold(
    layer,
    param_name,
    param,
    threshold,
    expected_sparsity,
    mask_creator,
):
    """Threshold the parameter by absolute magnitude and verify the resulting
    mask sparsity matches the expectation (within 1%)."""
    pruning_mask = ModuleParamPruningMask(
        [layer], [param_name], mask_creator=mask_creator
    )
    pruning_mask.set_param_data(param, 0)
    pruning_mask.set_param_masks_from_abs_threshold(threshold)
    applied_mask = pruning_mask.param_masks[0]
    # measured sparsity must be within 1% of what the threshold should yield
    assert abs(tensor_sparsity(applied_mask) - expected_sparsity) < 0.01
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        # grouped creators must also produce structurally valid group masks
        _test_grouped_sparsity_mask_output(mask_creator, applied_mask)
Example #4
def _test_set_param_mask(layer, param_name, param_mask):
    """Set an explicit mask and verify the reported delta and applied zeros.

    ``set_param_masks`` returns a per-element delta tensor:
    1.0 = newly unmasked, -1.0 = newly masked, 0.0 = unchanged.  Starting
    from an all-ones mask, no element can become unmasked, every zero in
    ``param_mask`` becomes newly masked, and every one stays unchanged.
    Afterwards, applying the enabled mask must zero exactly the masked
    positions of the parameter data.
    """
    mask = ModuleParamPruningMask([layer], [param_name])
    result = mask.set_param_masks([param_mask])[0]
    res_unmasked = (result == 1.0).type(torch.float32)
    res_masked = (result == -1.0).type(torch.float32)
    res_no_change = (result == 0.0).type(torch.float32)
    mask_ones = (param_mask == 1.0).type(torch.float32)
    mask_zeros = (param_mask == 0.0).type(torch.float32)

    # no element can transition to unmasked from the initial all-ones mask
    assert torch.sum(res_unmasked.abs()) < sys.float_info.epsilon
    # newly-masked delta positions must be exactly the zeros of param_mask
    assert torch.sum((res_masked - mask_zeros).abs()) < sys.float_info.epsilon
    # unchanged positions must be exactly the ones of param_mask
    assert torch.sum((res_no_change - mask_ones).abs()) < sys.float_info.epsilon

    mask.enabled = True
    mask.apply()
    # BUG FIX: Tensor.type("float32") is not a valid torch type string and
    # raises at runtime; use torch.float32 as everywhere else in this file
    param_data_zeros = (mask.params_data[0] == 0.0).type(torch.float32)

    # applying the mask must zero exactly the masked positions
    assert torch.sum((param_data_zeros - mask_zeros).abs()) < sys.float_info.epsilon
Example #5
def _test_set_param_data(layer, param_name, data):
    """Set the tracked parameter's data and verify it round-trips exactly.

    BUG FIX: the constructor takes a *list* of param names parallel to the
    list of layers (as in every other example here); passing the bare string
    ``param_name`` would be iterated character-by-character or rejected.
    """
    mask = ModuleParamPruningMask([layer], [param_name])
    mask.set_param_data(data, 0)
    # stored data must be element-wise identical to what was set
    assert torch.sum((mask.params_data[0] - data).abs()) < sys.float_info.epsilon