Ejemplo n.º 1
0
def test_png_generation():
    """Draw the ResNet20-CIFAR classifier to ``model.png`` twice.

    The drawing helper is called once with its last positional argument set
    to ``True`` and once with ``False``, in that order.
    """
    dataset_name = "cifar10"
    arch_name = "resnet20_cifar"
    net, _ = common.setup_test(arch_name, dataset_name, parallel=True)
    # Exercise both values of the trailing flag, True first
    for flag in (True, False):
        distiller.draw_img_classifier_to_file(net, 'model.png', dataset_name, flag)
Ejemplo n.º 2
0
def test_png_generation(display_param_nodes):
    """Draw the ResNet20-CIFAR classifier to ``model.png``.

    ``display_param_nodes`` is forwarded unchanged as the last positional
    argument of ``distiller.draw_img_classifier_to_file``.
    """
    arch_name, dataset_name = "resnet20_cifar", "cifar10"
    net, _ = common.setup_test(arch_name, dataset_name, parallel=True)
    png_path = 'model.png'
    distiller.draw_img_classifier_to_file(net, png_path, dataset_name, display_param_nodes)
Ejemplo n.º 3
0
def ranked_filter_pruning(config, ratio_to_prune, is_parallel):
    """Exercise L1-ranked filter pruning followed by physical thinning.

    For each entry of ``config.module_pairs`` the weights of the first module
    are masked by an ``L1RankedStructureParameterPruner`` and the resulting
    sparsity is checked.  Once every pair has been masked,
    ``distiller.remove_filters`` is invoked and the channel counts of the
    last processed pair are verified.

    Returns the model together with its masks dictionary.
    """
    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)

    for pair in config.module_pairs:
        weights_name = pair[0] + ".weight"

        # The weights tensor of the first module must be reachable by name
        weights = distiller.model_find_param(model, weights_name)
        assert weights is not None
        filter_cnt = weights.size(0)

        # No filter may be all-zeros before pruning
        assert distiller.sparsity_3D(weights) == 0.0

        # Build the filter-ranking pruner and let it compute the mask
        pruner = distiller.pruning.L1RankedStructureParameterPruner(
            "filter_pruner",
            group_type="Filters",
            desired_sparsity=ratio_to_prune,
            weights=weights_name)
        pruner.set_param_mask(weights, weights_name, zeros_mask_dict, meta=None)

        first_module = common.find_module_by_name(model, pair[0])
        assert first_module is not None
        # Only a whole number of filters can be pruned, so the expected
        # sparsity is derived from the truncated filter count.
        pruned_cnt = int(ratio_to_prune * first_module.out_channels)
        expected_sparsity = pruned_cnt / first_module.out_channels
        masker = zeros_mask_dict[weights_name]
        assert masker is not None
        assert distiller.sparsity_3D(masker.mask) == expected_sparsity

        # The weights become sparse only once the mask is applied
        assert distiller.sparsity_3D(weights) == 0
        masker.apply_mask(weights)
        assert distiller.sparsity_3D(weights) == expected_sparsity

        # Before thinning, both modules still report the original channel count
        second_module = common.find_module_by_name(model, pair[1])
        assert second_module is not None
        assert first_module.out_channels == filter_cnt
        assert second_module.in_channels == filter_cnt

    # Thinning: physically remove the masked filters
    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
    remaining_cnt = filter_cnt - pruned_cnt
    assert first_module.out_channels == remaining_cnt
    assert second_module.in_channels == remaining_cnt
    return model, zeros_mask_dict
Ejemplo n.º 4
0
def test_negative():
    """Requesting an unsupported summary type must raise ``ValueError``."""
    dataset_name = "cifar10"
    net, _ = common.setup_test("resnet20_cifar", dataset_name, parallel=True)

    # 'png' is not one of the supported summary types, hence the ValueError
    with pytest.raises(ValueError):
        distiller.model_summary(net, what='png', dataset=dataset_name)
Ejemplo n.º 5
0
def test_negative():
    """Check that ``model_summary`` rejects the ``'png'`` summary type."""
    arch_name, dataset_name = "resnet20_cifar", "cifar10"
    net, _ = common.setup_test(arch_name, dataset_name, parallel=True)

    # The call is expected to fail because 'png' is not a supported summary type
    with pytest.raises(ValueError):
        distiller.model_summary(net, what='png', dataset=dataset_name)
Ejemplo n.º 6
0
def test_conv_fc_interface(model=None, zeros_mask_dict=None):
    """Filter-prune a convolution whose next layer is fully-connected.

    The test runs on the "vgg19" architecture with the "imagenet" dataset:
    filters of ``features.34`` are masked and removed, and the input size of
    the linear layer ``classifier.0`` is expected to shrink accordingly.

    When either ``model`` or ``zeros_mask_dict`` is ``None``, a fresh pair is
    created with ``common.setup_test``.
    """
    arch, dataset = "vgg19", "imagenet"
    ratio_to_prune = 0.1
    conv_name, fc_name = "features.34", "classifier.0"
    conv_weights_name = conv_name + ".weight"
    dummy_input = torch.randn(1, 3, 224, 224)

    if model is None or zeros_mask_dict is None:
        model, zeros_mask_dict = common.setup_test(arch, dataset)

    # A forward/backward pass creates the gradients and the optimizer state
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    conv = common.find_module_by_name(model, conv_name)
    assert conv is not None

    conv_weights = distiller.model_find_param(model, conv_weights_name)
    assert conv_weights is not None
    assert conv_weights.dim() == 4

    # Build the filter-ranking pruner and let it compute the mask
    reg_regims = {conv_weights_name: [ratio_to_prune, "3D"]}
    pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner", reg_regims)
    pruner.set_param_mask(conv_weights, conv_weights_name, zeros_mask_dict, meta=None)

    # Zero out the masked filters
    masker = zeros_mask_dict[conv_weights_name]
    assert masker is not None
    masker.apply_mask(conv_weights)
    filter_cnt = conv_weights.size(0)
    pruned_cnt = int(ratio_to_prune * conv.out_channels)

    fc = common.find_module_by_name(model, fc_name)
    assert fc is not None

    # Thinning: record the pre-thinning geometry, remove filters, then compare
    fm_size = fc.in_features // conv.out_channels
    remaining_cnt = filter_cnt - pruned_cnt
    distiller.remove_filters(model, zeros_mask_dict, arch, dataset, optimizer)
    assert conv.out_channels == remaining_cnt
    assert fc.in_features == fm_size * remaining_cnt

    # Two more passes make sure the optimizer and gradient shapes were updated
    for _ in range(2):
        run_forward_backward(model, optimizer, dummy_input)
Ejemplo n.º 7
0
def test_summary(arch, add_softmax):
    """Export ``arch`` to ONNX, writing into a temporary file.

    Architectures whose name ends with ``'cifar'`` use the 'cifar10' dataset;
    every other architecture uses 'imagenet'.
    """
    if arch.endswith('cifar'):
        dataset = 'cifar10'
    else:
        dataset = 'imagenet'
    net, _ = common.setup_test(arch, dataset, parallel=True)

    with tempfile.NamedTemporaryFile() as onnx_file:
        distiller.export_img_classifier_to_onnx(net, onnx_file.name, dataset, add_softmax=add_softmax)
Ejemplo n.º 8
0
def test_summary():
    """Generate each of four model summaries for ResNet20-CIFAR."""
    dataset_name = "cifar10"
    net, _ = common.setup_test("resnet20_cifar", dataset_name, parallel=True)

    # Same summary kinds, in the same order, as separate calls would use
    for summary_kind in ('sparsity', 'compute', 'model', 'modules'):
        distiller.model_summary(net, what=summary_kind, dataset=dataset_name)
Ejemplo n.º 9
0
def test_summary():
    """Request the sparsity, compute, model and modules summaries in turn."""
    arch_name, dataset_name = "resnet20_cifar", "cifar10"
    net, _ = common.setup_test(arch_name, dataset_name, parallel=True)

    summary_kinds = ['sparsity', 'compute', 'model', 'modules']
    for summary_kind in summary_kinds:
        distiller.model_summary(net, what=summary_kind, dataset=dataset_name)
Ejemplo n.º 10
0
def conv_fc_interface_test(arch, dataset, conv_names, fc_names, is_parallel=parallel, model=None, zeros_mask_dict=None):
    """Filter-prune a convolution whose next layer is fully-connected.

    ``conv_names`` and ``fc_names`` each hold two layer names: index 0 is used
    when ``is_parallel`` is false, index 1 when it is true.  When either
    ``model`` or ``zeros_mask_dict`` is ``None``, a fresh pair is created with
    ``common.setup_test``.
    """
    ratio_to_prune = 0.1
    # Pick the layer names that match the data-parallelism setting
    names_idx = 1 if is_parallel else 0
    conv_name = conv_names[names_idx]
    fc_name = fc_names[names_idx]
    conv_weights_name = conv_name + ".weight"

    dummy_input = torch.randn(1, 3, 224, 224).cuda()

    if model is None or zeros_mask_dict is None:
        model, zeros_mask_dict = common.setup_test(arch, dataset, is_parallel)

    # A forward/backward pass creates the gradients and the optimizer state
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    conv = common.find_module_by_name(model, conv_name)
    assert conv is not None

    conv_weights = distiller.model_find_param(model, conv_weights_name)
    assert conv_weights is not None
    assert conv_weights.dim() == 4

    # Build the filter-ranking pruner and let it compute the mask
    pruner = distiller.pruning.L1RankedStructureParameterPruner(
        "filter_pruner",
        group_type="Filters",
        desired_sparsity=ratio_to_prune,
        weights=conv_weights_name)
    pruner.set_param_mask(conv_weights,
                          conv_weights_name,
                          zeros_mask_dict,
                          meta=None)

    # Zero out the masked filters
    masker = zeros_mask_dict[conv_weights_name]
    assert masker is not None
    masker.apply_mask(conv_weights)
    filter_cnt = conv_weights.size(0)
    pruned_cnt = int(ratio_to_prune * conv.out_channels)

    fc = common.find_module_by_name(model, fc_name)
    assert fc is not None

    # Thinning: record the pre-thinning geometry, remove filters, then compare
    fm_size = fc.in_features // conv.out_channels
    remaining_cnt = filter_cnt - pruned_cnt
    input_shape = tuple(distiller.apputils.classification_get_input_shape(dataset))
    distiller.remove_filters(model, zeros_mask_dict, input_shape, optimizer)
    assert conv.out_channels == remaining_cnt
    assert fc.in_features == fm_size * remaining_cnt

    # Two more passes make sure the optimizer and gradient shapes were updated
    for _ in range(2):
        run_forward_backward(model, optimizer, dummy_input)
Ejemplo n.º 11
0
def test_ranked_channel_pruning():
    """Rank and prune channels of ``layer1.0.conv1``, then thin the model.

    An ``L1RankedStructureParameterPruner`` with a desired sparsity of 0.1
    masks the channels of the convolution weights.  After
    ``distiller.remove_channels`` runs, both ``conv1.out_channels`` and
    ``layer1.0.conv1.in_channels`` are expected to drop from 16 to 15.
    """
    dataset_name = "cifar10"
    module_name = "layer1.0.conv1"
    weights_name = "layer1.0.conv1.weight"
    model, zeros_mask_dict = common.setup_test("resnet20_cifar", dataset_name, parallel=False)

    # The weights tensor must be reachable by name
    weights = distiller.model_find_param(model, weights_name)
    assert weights is not None

    # No channel may be all-zeros before pruning
    assert distiller.sparsity_ch(weights) == 0.0

    # Build the channel-ranking pruner and let it compute the mask
    pruner = distiller.pruning.L1RankedStructureParameterPruner("channel_pruner",
                                                                group_type="Channels",
                                                                desired_sparsity=0.1,
                                                                weights=weights_name)
    pruner.set_param_mask(weights, weights_name, zeros_mask_dict, meta=None)

    conv1 = common.find_module_by_name(model, module_name)
    assert conv1 is not None

    # Only a whole number of channels can be pruned, so the expected sparsity
    # is derived from the truncated channel count.
    logger.info("layer1.0.conv1 = {}".format(conv1))
    expected_sparsity = int(0.1 * conv1.in_channels) / conv1.in_channels
    masker = zeros_mask_dict[weights_name]
    assert distiller.sparsity_ch(masker.mask) == expected_sparsity

    # The weights become sparse only once the mask is applied
    assert distiller.sparsity_ch(weights) == 0
    zeros_mask_dict[weights_name].apply_mask(weights)
    assert distiller.sparsity_ch(weights) == expected_sparsity

    # Before thinning, the feeding convolution still has all 16 channels
    conv0 = common.find_module_by_name(model, "conv1")
    assert conv0 is not None
    assert conv0.out_channels == 16
    assert conv1.in_channels == 16

    # Thinning: physically remove the masked channels
    input_shape = tuple(distiller.apputils.classification_get_input_shape(dataset_name))
    distiller.remove_channels(model, zeros_mask_dict, input_shape, optimizer=None)
    assert conv0.out_channels == 15
    assert conv1.in_channels == 15
Ejemplo n.º 12
0
def test_compute_summary():
    """Compare the per-module MACs of two models against known values.

    The 'MACs' column of ``distiller.model_performance_summary`` is checked
    for "simplenet_cifar" on "cifar10" and then for "mobilenet" on "imagenet".
    """
    def collect_macs(arch, dataset):
        # Build the model and return the 'MACs' column of its summary as a list
        model, _ = common.setup_test(arch, dataset, parallel=True)
        summary = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
        return summary.loc[:, 'MACs'].to_list()

    #                    [conv1,  conv2,  fc1,   fc2,   fc3]
    simplenet_expected = [352800, 240000, 48000, 10080, 840]
    assert collect_macs("simplenet_cifar", "cifar10") == simplenet_expected

    mobilenet_expected = [
        10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224,
        903168, 25690112, 1806336, 51380224, 451584, 25690112, 903168,
        51380224, 903168, 51380224, 903168, 51380224, 903168, 51380224, 903168,
        51380224, 225792, 25690112, 451584, 51380224, 1024000
    ]
    assert collect_macs("mobilenet", "imagenet") == mobilenet_expected
Ejemplo n.º 13
0
def test_sg_macs():
    """Check that SummaryGraph and the model summary agree on MACs.

    For every ``nn.Conv2d`` and ``nn.Linear`` module of MobileNet, the MACs
    value reported by ``distiller.model_performance_summary`` must equal the
    'MACs' attribute of the matching SummaryGraph operation.
    """
    import common
    sg = create_graph('imagenet', 'mobilenet')
    assert sg
    model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False)
    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet'))
    modules_macs = df_compute.loc[:, ['Name', 'MACs']]
    for name, mod in model.named_modules():
        if not isinstance(mod, (nn.Conv2d, nn.Linear)):
            continue
        # Calling int() directly on a one-element Series is deprecated in
        # recent pandas releases.  Series.item() extracts the scalar and still
        # fails loudly when the name matches zero or several rows.
        summary_macs = int(modules_macs.loc[modules_macs.Name == name, 'MACs'].item())
        sg_macs = sg.find_op(name)['attrs']['MACs']
        assert summary_macs == sg_macs
Ejemplo n.º 14
0
def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
    """Test removal of arbitrary channels.

    The test receives a specification of channels to remove.
    Based on this specification, the channels are pruned and then physically
    removed from the model (via a "thinning" process).

    Only the first entry of ``config.module_pairs`` is exercised: the mask is
    installed on the weights of the pair's second module, and the channel
    counts of both modules (and of ``config.bn_name``, when set) are checked
    after thinning.  The thinned model is then saved and re-loaded several
    times to exercise checkpoint handling.

    Args:
        config: object providing ``arch``, ``dataset``, ``module_pairs`` and
            ``bn_name``.
        channels_to_remove: collection of channel indices to prune.
        is_parallel: forwarded to ``common.setup_test`` and ``create_model``.
    """
    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)

    pair = config.module_pairs[0]
    conv2 = common.find_module_by_name(model, pair[1])
    assert conv2 is not None

    # Test that we can access the weights tensor of the second module of the pair
    conv2_p = distiller.model_find_param(model, pair[1] + ".weight")
    assert conv2_p is not None

    assert conv2_p.dim() == 4
    # Dimension 1 of the 4D weights tensor is used as the channel count
    num_channels = conv2_p.size(1)
    cnt_nnz_channels = num_channels - len(channels_to_remove)
    mask = create_channels_mask(conv2_p, channels_to_remove)
    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
    # Cool, so now we have a mask for pruning our channels.

    # Use the mask to prune
    zeros_mask_dict[pair[1] + ".weight"].mask = mask
    zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
    # The removed channels are computed for logging only
    all_channels = set([ch for ch in range(num_channels)])
    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight"))
    channels_removed = all_channels - nnz_channels
    logger.info("Channels removed {}".format(channels_removed))

    # Now, let's do the actual network thinning
    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
    conv1 = common.find_module_by_name(model, pair[0])
    assert conv1
    # Both the module attributes and the actual tensor shapes must have shrunk
    assert conv1.out_channels == cnt_nnz_channels
    assert conv2.in_channels == cnt_nnz_channels
    assert conv1.weight.size(0) == cnt_nnz_channels
    assert conv2.weight.size(1) == cnt_nnz_channels
    if config.bn_name is not None:
        # The batch-norm module named by the config must have shrunk as well
        bn1 = common.find_module_by_name(model, config.bn_name)
        assert bn1.running_var.size(0) == cnt_nnz_channels
        assert bn1.running_mean.size(0) == cnt_nnz_channels
        assert bn1.num_features == cnt_nnz_channels
        assert bn1.bias.size(0) == cnt_nnz_channels
        assert bn1.weight.size(0) == cnt_nnz_channels

    # Make sure the thinned model can still run a forward/backward pass
    dummy_input = common.get_dummy_input(config.dataset)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    # Let's test saving and loading a thinned model.
    # We save 3 times, and load twice, to make sure to cover some corner cases:
    #   - Make sure that after loading, the model still has hold of the thinning recipes
    #   - Make sure that after a 2nd load, there is no problem loading (in this case, the
    #     tensors are already thin, so this is a new flow)
    # (1) Save without a scheduler; loading this checkpoint is expected to raise KeyError
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
    model_2 = create_model(False, config.dataset, config.arch, parallel=is_parallel)
    model(dummy_input)
    model_2(dummy_input)
    conv2 = common.find_module_by_name(model_2, pair[1])
    assert conv2 is not None
    with pytest.raises(KeyError):
        model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    compression_scheduler = distiller.CompressionScheduler(model)
    # NOTE(review): the result of this hasattr() call is discarded, so nothing
    # is checked here - presumably an ``assert`` was intended; confirm.
    hasattr(model, 'thinning_recipes')

    run_forward_backward(model, optimizer, dummy_input)

    # (2) Save with a scheduler; the loaded model must carry the thinning recipes
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done")

    # (3) Save the already-loaded model and load it once more
    save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done 2")
def test_mnist(what):
    """Generate the ``what`` summary for the SimpleNet-MNIST model."""
    arch_name, dataset_name = "simplenet_mnist", "mnist"
    net, _ = common.setup_test(arch_name, dataset_name, parallel=True)
    distiller.model_summary(net, what, dataset=dataset_name)
Ejemplo n.º 16
0
def ranked_filter_pruning(config,
                          ratio_to_prune,
                          is_parallel,
                          rounding_fn=math.floor):
    """Test L1 ranking and pruning of filters.

    First we rank and prune the filters of a Convolutional layer using
    a L1RankedStructureParameterPruner.  Then we physically remove the
    filters from the model (via a "thinning" process), and finally run a
    forward/backward pass on the thinned model.

    Args:
        config: object providing ``arch``, ``dataset`` and ``module_pairs``.
        ratio_to_prune: desired fraction of filters to prune.
        is_parallel: forwarded to ``common.setup_test``.
        rounding_fn: forwarded to the pruner and also used to compute the
            expected number of removed filters.

    Returns:
        The thinned model and its masks dictionary.
    """
    # Lazy %-style arguments: the message text is unchanged
    logger.info("executing: %s (invoked by %s)",
                inspect.currentframe().f_code.co_name,
                inspect.currentframe().f_back.f_code.co_name)

    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset,
                                               is_parallel)

    for pair in config.module_pairs:
        weights_name = pair[0] + ".weight"

        # Test that we can access the weights tensor of the first module of the pair
        conv1_p = distiller.model_find_param(model, weights_name)
        assert conv1_p is not None
        num_filters = conv1_p.size(0)

        # Test that there are no zero-filters
        assert distiller.sparsity_3D(conv1_p) == 0.0

        # Create a filter-ranking pruner
        pruner = distiller.pruning.L1RankedStructureParameterPruner(
            "filter_pruner",
            group_type="Filters",
            desired_sparsity=ratio_to_prune,
            weights=weights_name,
            rounding_fn=rounding_fn)
        pruner.set_param_mask(conv1_p,
                              weights_name,
                              zeros_mask_dict,
                              meta=None)

        conv1 = common.find_module_by_name(model, pair[0])
        assert conv1 is not None
        # Test that the mask has the correct fraction of filters pruned.
        # Only a whole number of filters can be removed.  The expectation uses
        # the same rounding function that is handed to the pruner, so that a
        # non-default rounding_fn (e.g. math.ceil) is not compared against a
        # truncated count.
        # NOTE(review): assumes the pruner applies rounding_fn to
        # ratio * number-of-filters - confirm against the pruner.
        expected_cnt_removed_filters = int(
            rounding_fn(ratio_to_prune * conv1.out_channels))
        expected_pruning = expected_cnt_removed_filters / conv1.out_channels
        masker = zeros_mask_dict[weights_name]
        assert masker is not None
        assert distiller.sparsity_3D(masker.mask) == expected_pruning

        # Use the mask to prune
        assert distiller.sparsity_3D(conv1_p) == 0
        masker.apply_mask(conv1_p)
        assert distiller.sparsity_3D(conv1_p) == expected_pruning

        # Before thinning, both modules still report the original channel count
        conv2 = common.find_module_by_name(model, pair[1])
        assert conv2 is not None
        assert conv1.out_channels == num_filters
        assert conv2.in_channels == num_filters

    # Test thinning (the checks below cover the last processed pair)
    distiller.remove_filters(model,
                             zeros_mask_dict,
                             config.arch,
                             config.dataset,
                             optimizer=None)
    assert conv1.out_channels == num_filters - expected_cnt_removed_filters
    assert conv2.in_channels == num_filters - expected_cnt_removed_filters

    # Test the thinned model
    dummy_input = distiller.get_dummy_input(config.dataset,
                                            distiller.model_device(model))
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    return model, zeros_mask_dict
Ejemplo n.º 17
0
def test_summary(what):
    """Generate the ``what`` summary for the ResNet20-CIFAR model."""
    arch_name, dataset_name = "resnet20_cifar", "cifar10"
    net, _ = common.setup_test(arch_name, dataset_name, parallel=True)

    distiller.model_summary(net, what, dataset=dataset_name)
Ejemplo n.º 18
0
def arbitrary_channel_pruning(config, channels_to_remove):
    """Test removal of arbitrary channels.

    The test receives a specification of channels to remove.
    Based on this specification, the channels are pruned and then physically
    removed from the model (via a "thinning" process).

    A 4D mask is built by hand for the weights of ``config.conv2_name``, and
    the channel counts of ``config.conv1_name``, ``config.conv2_name`` (and of
    ``config.bn_name``, when set) are checked after thinning.  The thinned
    model is then saved and re-loaded several times to exercise checkpoint
    handling.

    Args:
        config: object providing ``arch``, ``dataset``, ``conv1_name``,
            ``conv2_name`` and ``bn_name``.
        channels_to_remove: collection of channel indices to prune.
    """
    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset)

    conv2 = common.find_module_by_name(model, config.conv2_name)
    assert conv2 is not None

    # Test that we can access the weights tensor of the conv2 module
    conv2_p = distiller.model_find_param(model, config.conv2_name + ".weight")
    assert conv2_p is not None

    assert conv2_p.dim() == 4
    # Unpack the four dimensions of the weights tensor
    num_filters = conv2_p.size(0)
    num_channels = conv2_p.size(1)
    kernel_height = conv2_p.size(2)
    kernel_width = conv2_p.size(3)
    cnt_nnz_channels = num_channels - len(channels_to_remove)

    # Let's build our 4D mask.
    # We start with a 1D mask of channels, with all but our specified channels set to one
    channels = torch.ones(num_channels)
    for ch in channels_to_remove:
        channels[ch] = 0

    # Now let's expand back up to a 4D mask
    mask = channels.expand(num_filters, num_channels)
    # Append two trailing singleton dimensions in place, then expand them to the kernel size
    mask.unsqueeze_(-1)
    mask.unsqueeze_(-1)
    mask = mask.expand(num_filters, num_channels, kernel_height, kernel_width).contiguous()

    assert mask.shape == conv2_p.shape
    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels

    # Cool, so now we have a mask for pruning our channels.
    # Use the mask to prune
    zeros_mask_dict[config.conv2_name + ".weight"].mask = mask
    zeros_mask_dict[config.conv2_name + ".weight"].apply_mask(conv2_p)
    # The removed channels are computed for logging only
    all_channels = set([ch for ch in range(num_channels)])
    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, config.conv2_name + ".weight"))
    channels_removed = all_channels - nnz_channels
    logger.info("Channels removed {}".format(channels_removed))

    # Now, let's do the actual network thinning
    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset)
    conv1 = common.find_module_by_name(model, config.conv1_name)
    logger.info(conv1)
    logger.info(conv2)
    # Both the module attributes and the actual tensor shapes must have shrunk
    assert conv1.out_channels == cnt_nnz_channels
    assert conv2.in_channels == cnt_nnz_channels
    assert conv1.weight.size(0) == cnt_nnz_channels
    assert conv2.weight.size(1) == cnt_nnz_channels
    if config.bn_name is not None:
        # The batch-norm module named by the config must have shrunk as well
        bn1 = common.find_module_by_name(model, config.bn_name)
        assert bn1.running_var.size(0) == cnt_nnz_channels
        assert bn1.running_mean.size(0) == cnt_nnz_channels
        assert bn1.num_features == cnt_nnz_channels
        assert bn1.bias.size(0) == cnt_nnz_channels
        assert bn1.weight.size(0) == cnt_nnz_channels

    # Let's test saving and loading a thinned model.
    # We save 3 times, and load twice, to make sure to cover some corner cases:
    #   - Make sure that after loading, the model still has hold of the thinning recipes
    #   - Make sure that after a 2nd load, there is no problem loading (in this case, the
    #     tensors are already thin, so this is a new flow)
    # (1) Save without a scheduler; loading this checkpoint is expected to raise KeyError
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
    model_2 = create_model(False, config.dataset, config.arch, parallel=False)
    # NOTE(review): the input shape is hard-coded to (1, 3, 32, 32) regardless
    # of config.dataset - confirm this matches every config passed in.
    dummy_input = torch.randn(1, 3, 32, 32)
    model(dummy_input)
    model_2(dummy_input)
    conv2 = common.find_module_by_name(model_2, config.conv2_name)
    assert conv2 is not None
    with pytest.raises(KeyError):
        model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
    compression_scheduler = distiller.CompressionScheduler(model)
    # NOTE(review): the result of this hasattr() call is discarded, so nothing
    # is checked here - presumably an ``assert`` was intended; confirm.
    hasattr(model, 'thinning_recipes')

    # (2) Save with a scheduler; the loaded model must carry the thinning recipes
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done")

    # (3) Save the already-loaded model and load it once more
    save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler)
    model_2, compression_scheduler, start_epoch = load_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done 2")