Exemple #1
0
def test_magnitude_pruning():
    """Exercise MagnitudeParameterPruner and ParameterMasker on a 4-D tensor.

    The test checks three things:
      * constructing the pruner with ``None`` thresholds raises AssertionError;
      * with a default ("*") threshold of 0.4, the mask built for a tensor of
        ones containing a single 0.2 has a sparsity of exactly one element;
      * the resulting mask can be applied to the source tensor and to another
        tensor of the same shape, zeroing one element in each.
    """
    # 4-D tensor of ones with exactly one small-magnitude entry
    weights = torch.ones(3, 64, 32, 32)
    weights[1, 4, 17, 31] = 0.2

    # Masks dictionary holding a single ParameterMasker, keyed by parameter name
    masks = {'a': distiller.ParameterMasker('a')}

    # Building the pruner WITHOUT a thresholds dictionary must be rejected
    with pytest.raises(AssertionError):
        distiller.pruning.MagnitudeParameterPruner("test", None)

    # Build it again, this time providing the default ("*") threshold
    pruner = distiller.pruning.MagnitudeParameterPruner("test", {"*": 0.4})
    # Nothing is zero in the tensor yet
    assert distiller.sparsity(weights) == 0

    # Ask the pruner to compute the mask for parameter 'a'
    pruner.set_param_mask(weights, 'a', masks, None)
    assert common.almost_equal(distiller.sparsity(masks['a'].mask),
                               1/distiller.volume(weights))

    # Apply the mask: this is what actually prunes the tensor
    param_masker = masks['a']
    param_masker.apply_mask(weights)
    assert common.almost_equal(distiller.sparsity(weights),
                               1/distiller.volume(weights))

    # The mask already exists, so applying it to another tensor of the same
    # shape prunes by position - no thresholding happens here, even though
    # every value (0.3) is below the 0.4 threshold.
    other = torch.ones(3, 64, 32, 32)
    other[:] = 0.3
    param_masker.apply_mask(other)
    assert common.almost_equal(distiller.sparsity(other),
                               1/distiller.volume(weights))
Exemple #2
0
def create_model_masks(model):
    """Build a dictionary with one ParameterMasker per model parameter.

    Args:
        model: object exposing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs.

    Returns:
        dict mapping each parameter name to a ``distiller.ParameterMasker``
        constructed with that name.
    """
    # Only the parameter names are needed; the parameter tensors are ignored.
    return {
        param_name: distiller.ParameterMasker(param_name)
        for param_name, _ in model.named_parameters()
    }
Exemple #3
0
def setup_test(arch, dataset, parallel):
    """Create a model and a matching dictionary of parameter maskers.

    Args:
        arch: architecture identifier forwarded to ``create_model``.
        dataset: dataset identifier forwarded to ``create_model``.
        parallel: forwarded to ``create_model`` as its ``parallel`` keyword.

    Returns:
        A ``(model, zeros_mask_dict)`` tuple, where ``zeros_mask_dict`` maps
        every parameter name of the model to a ``distiller.ParameterMasker``.
    """
    # First positional argument is False, exactly as in the original call
    # (presumably the "pretrained" flag - TODO confirm against create_model).
    model = create_model(False, dataset, arch, parallel=parallel)
    assert model is not None

    # One masker per named parameter; the parameter tensors are not needed.
    maskers = {
        param_name: distiller.ParameterMasker(param_name)
        for param_name, _ in model.named_parameters()
    }
    return model, maskers