Ejemplo n.º 1
0
def collect_conv_details(model, dataset):
    """Gather per-layer details for every ``torch.nn.Conv2d`` in ``model``.

    The model is moved to the GPU via ``model.cuda()`` and traced with the
    dummy input returned by ``get_dummy_input(dataset)`` to build a
    ``SummaryGraph``. The graph supplies weight volume, MAC count and
    input/output feature-map sizes for each convolution.

    Args:
        model: the network to inspect. Note that ``model.cuda()`` is called,
            which moves the module's parameters to the GPU in place.
        dataset: dataset name forwarded to ``get_dummy_input``.

    Returns:
        A tuple ``(conv_layers, total_macs, total_nnz)``:
          * ``conv_layers`` - ``OrderedDict`` mapping a running index
            (0, 1, ...) to a ``SimpleNamespace`` describing one conv layer
            (``t``, ``k``, ``stride``, ``weights_vol``, ``macs``, ``ofm_h``,
            ``ofm_w``, ``ifm_h``, ``ifm_w``, ``name``, ``id``).
          * ``total_macs`` - sum over conv layers of the graph MAC count
            scaled by ``distiller.density_ch`` of the layer's weight.
          * ``total_nnz`` - sum of the graph's ``weights_vol`` attribute over
            all conv layers. NOTE(review): despite the name, this value is
            not scaled by any density; confirm whether a non-zero count was
            intended.
    """
    dummy_input = get_dummy_input(dataset)
    g = SummaryGraph(model.cuda(), dummy_input.cuda())
    conv_layers = OrderedDict()
    total_macs = 0
    total_nnz = 0
    # ``module_idx`` is the position of the module in named_modules() and is
    # stored as ``conv.id``; it is named so as not to shadow the ``id`` builtin.
    for module_idx, (name, m) in enumerate(model.named_modules()):
        if not isinstance(m, torch.nn.Conv2d):
            continue
        conv = SimpleNamespace()
        conv.t = len(conv_layers)
        conv.k = m.kernel_size[0]
        conv.stride = m.stride

        # Use the SummaryGraph to obtain the remaining details of the layer.
        conv_op = g.find_op(normalize_module_name(name))
        assert conv_op is not None, "SummaryGraph has no op for layer %s" % name

        conv.weights_vol = conv_op['attrs']['weights_vol']
        total_nnz += conv.weights_vol
        conv.macs = conv_op['attrs']['MACs']
        conv_p = distiller.model_find_param(model, name + ".weight")
        # Discount the MACs by the channel density of the layer's weight.
        conv.macs *= distiller.density_ch(conv_p)
        total_macs += conv.macs

        # Look each feature-map shape up once; indices 2 and 3 are read as
        # height and width respectively.
        ofm_shape = g.param_shape(conv_op['outputs'][0])
        ifm_shape = g.param_shape(conv_op['inputs'][0])
        conv.ofm_h = ofm_shape[2]
        conv.ofm_w = ofm_shape[3]
        conv.ifm_h = ifm_shape[2]
        conv.ifm_w = ifm_shape[3]

        conv.name = name
        conv.id = module_idx
        conv_layers[len(conv_layers)] = conv

    return conv_layers, total_macs, total_nnz
Ejemplo n.º 2
0
def collect_conv_details(model, dataset):
    """Gather per-layer details for every ``torch.nn.Conv2d`` in ``model``.

    A random dummy input sized for ``dataset`` is created, and the model
    (moved to the GPU via ``model.cuda()``) is traced into a ``SummaryGraph``
    that supplies MAC counts and feature-map sizes.

    Args:
        model: the network to inspect. ``model.cuda()`` is called, which
            moves the module's parameters to the GPU in place.
        dataset: ``'imagenet'`` (1x3x224x224 input) or ``'cifar10'``
            (1x3x32x32 input).

    Returns:
        ``(conv_layers, total_macs)`` where ``conv_layers`` is an
        ``OrderedDict`` mapping a running index (0, 1, ...) to a
        ``SimpleNamespace`` (``t``, ``k``, ``stride``, ``macs``, ``ofm_h``,
        ``ofm_w``, ``ifm_h``, ``ifm_w``, ``name``, ``id``), and
        ``total_macs`` is the sum of the graph MAC counts of all conv layers.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'imagenet':
        dummy_input = torch.randn(1, 3, 224, 224)
    elif dataset == 'cifar10':
        dummy_input = torch.randn(1, 3, 32, 32)
    else:
        raise ValueError("dataset %s is not supported" % dataset)

    g = SummaryGraph(model.cuda(), dummy_input.cuda())
    conv_layers = OrderedDict()
    total_macs = 0
    # ``module_idx`` is stored as ``conv.id``; named so as not to shadow the
    # ``id`` builtin.
    for module_idx, (name, m) in enumerate(model.named_modules()):
        if not isinstance(m, torch.nn.Conv2d):
            continue
        conv = SimpleNamespace()
        conv.t = len(conv_layers)
        conv.k = m.kernel_size[0]
        conv.stride = m.stride

        # Use the SummaryGraph to obtain the remaining details of the layer.
        conv_op = g.find_op(normalize_module_name(name))
        assert conv_op is not None, "SummaryGraph has no op for layer %s" % name

        conv.macs = conv_op['attrs']['MACs']
        total_macs += conv.macs

        # Look each feature-map shape up once; indices 2 and 3 are read as
        # height and width respectively.
        ofm_shape = g.param_shape(conv_op['outputs'][0])
        ifm_shape = g.param_shape(conv_op['inputs'][0])
        conv.ofm_h = ofm_shape[2]
        conv.ofm_w = ofm_shape[3]
        conv.ifm_h = ifm_shape[2]
        conv.ifm_w = ifm_shape[3]

        conv.name = name
        conv.id = module_idx
        conv_layers[len(conv_layers)] = conv

    return conv_layers, total_macs
Ejemplo n.º 3
0
def get_model_compute_budget(model, dataset, layers_to_prune=None):
    """Return the total MAC count of all ``torch.nn.Conv2d`` layers in ``model``.

    The model is traced with ``distiller.get_dummy_input(dataset)`` into a
    ``SummaryGraph``, and the ``'MACs'`` attribute of each conv op is summed.

    Args:
        model: the image-classifier to measure.
        dataset: dataset name forwarded to ``distiller.get_dummy_input``.
        layers_to_prune: accepted for signature compatibility but currently
            ignored; every Conv2d layer contributes to the total.

    Returns:
        The summed MAC count of the convolution layers.
    """
    graph = SummaryGraph(model, distiller.get_dummy_input(dataset))
    conv_names = (
        name for name, module in model.named_modules()
        if isinstance(module, torch.nn.Conv2d)
    )
    budget = 0
    for conv_name in conv_names:
        op = graph.find_op(normalize_module_name(conv_name))
        assert op is not None
        budget += op['attrs']['MACs']
    # Drop the graph reference explicitly before returning.
    del graph
    return budget
Ejemplo n.º 4
0
def create_graph(dataset, arch):
    """Build a ``SummaryGraph`` for a freshly created ``arch`` model.

    Args:
        dataset: dataset name passed to ``get_input`` (to obtain the dummy
            input) and to ``create_model``.
        arch: architecture name passed to ``create_model``.

    Returns:
        A ``SummaryGraph`` traced from the new model with the dummy input.

    Raises:
        ValueError: if ``get_input`` returns ``None`` for ``dataset``.
    """
    dummy_input = get_input(dataset)
    # Explicit raise rather than ``assert``: this validates caller input and
    # must still fire when Python runs with ``-O`` (which strips asserts).
    if dummy_input is None:
        raise ValueError(
            "Unsupported dataset ({}) - aborting draw operation".format(dataset))

    # NOTE(review): the positional ``False`` is presumably ``pretrained`` -
    # confirm against create_model's signature.
    model = create_model(False, dataset, arch, parallel=False)
    assert model is not None
    return SummaryGraph(model, dummy_input)
Ejemplo n.º 5
0
def create_graph(dataset, arch):
    """Build a ``SummaryGraph`` for a freshly created ``arch`` model.

    A random dummy input is generated with a shape chosen by ``dataset``:
    ``'imagenet'`` -> (1, 3, 224, 224), ``'cifar10'`` -> (1, 3, 32, 32).

    Args:
        dataset: ``'imagenet'`` or ``'cifar10'``; also passed to
            ``create_model``.
        arch: architecture name passed to ``create_model``.

    Returns:
        A ``SummaryGraph`` traced from the new model with the dummy input.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    input_shapes = {
        'imagenet': (1, 3, 224, 224),
        'cifar10': (1, 3, 32, 32),
    }
    # Previously an unsupported dataset left ``dummy_input`` unbound, so the
    # intended "Unsupported dataset" check crashed with UnboundLocalError.
    if dataset not in input_shapes:
        raise ValueError(
            "Unsupported dataset ({}) - aborting draw operation".format(dataset))
    dummy_input = torch.randn(input_shapes[dataset])

    # NOTE(review): the positional ``False`` is presumably ``pretrained`` -
    # confirm against create_model's signature.
    model = create_model(False, dataset, arch, parallel=False)
    assert model is not None
    return SummaryGraph(model, dummy_input)
Ejemplo n.º 6
0
def collect_conv_details(model,
                         dataset,
                         perform_thinning,
                         layers_to_prune=None):
    """Gather per-layer details and model-wide totals for Conv2d layers.

    ``model`` is traced with ``distiller.get_dummy_input(dataset)`` into a
    ``SummaryGraph``. For every ``torch.nn.Conv2d``, the function records
    kernel size, stride, weight volume, MACs and feature-map sizes.

    Args:
        model: the network to inspect.
        dataset: dataset name forwarded to ``distiller.get_dummy_input``.
        perform_thinning: when false, each layer's MACs are scaled by
            ``distiller.density_3D`` of its weight (filter density); when
            true, the graph's raw MAC count is used unchanged.
        layers_to_prune: optional collection of module names. Only layers
            whose name is in it are returned in ``conv_layers``; ``None``
            returns every conv layer.

    Returns:
        ``(conv_layers, total_macs, total_params)``:
          * ``conv_layers`` - ``OrderedDict`` keyed 0, 1, ... of
            ``SimpleNamespace`` records (``t``, ``k``, ``stride``,
            ``weights_vol``, ``macs``, ``ofm_h``, ``ofm_w``, ``ifm_h``,
            ``ifm_w``, ``name``, ``id``).
          * ``total_macs`` / ``total_params`` - accumulated over *all* Conv2d
            layers, including those filtered out by ``layers_to_prune``.
    """
    dummy_input = distiller.get_dummy_input(dataset)
    g = SummaryGraph(model, dummy_input)
    conv_layers = OrderedDict()
    total_macs = 0
    total_params = 0
    # ``module_idx`` is stored as ``conv.id``; named so as not to shadow the
    # ``id`` builtin.
    for module_idx, (name, m) in enumerate(model.named_modules()):
        if not isinstance(m, torch.nn.Conv2d):
            continue
        conv = SimpleNamespace()
        # Index among the layers kept so far (conv_layers only grows for
        # layers that pass the layers_to_prune filter).
        conv.t = len(conv_layers)
        conv.k = m.kernel_size[0]
        conv.stride = m.stride

        # Use the SummaryGraph to obtain the remaining details of the layer.
        conv_op = g.find_op(normalize_module_name(name))
        assert conv_op is not None, "SummaryGraph has no op for layer %s" % name

        conv.weights_vol = conv_op['attrs']['weights_vol']
        total_params += conv.weights_vol
        conv.macs = conv_op['attrs']['MACs']
        conv_pname = name + ".weight"
        conv_p = distiller.model_find_param(model, conv_pname)
        if not perform_thinning:
            # Without thinning, pruned filters are presumably still present in
            # the graph, so discount the MACs by the weight's filter (3D)
            # density. (Channel pruning would use distiller.density_ch here.)
            conv.macs *= distiller.density_3D(conv_p)
        total_macs += conv.macs

        # Look each feature-map shape up once; indices 2 and 3 are read as
        # height and width respectively.
        ofm_shape = g.param_shape(conv_op['outputs'][0])
        ifm_shape = g.param_shape(conv_op['inputs'][0])
        conv.ofm_h = ofm_shape[2]
        conv.ofm_w = ofm_shape[3]
        conv.ifm_h = ifm_shape[2]
        conv.ifm_w = ifm_shape[3]

        conv.name = name
        conv.id = module_idx
        if layers_to_prune is None or name in layers_to_prune:
            conv_layers[len(conv_layers)] = conv
    return conv_layers, total_macs, total_params
Ejemplo n.º 7
0
def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
    """Remove filters from ``model`` according to a filter-thinning recipe.

    A ``SummaryGraph`` is traced with a random dummy input whose shape is
    chosen by ``arch``. The graph, the model and ``zeros_mask_dict`` are then
    passed to ``create_thinning_recipe_filters``, and the resulting recipe is
    applied via ``apply_and_save_recipe``.

    Args:
        model: the network to thin.
        zeros_mask_dict: pruning masks, forwarded to both recipe helpers.
        arch: architecture name. ``'deblurGAN'`` selects a 1x3x224x224 input;
            any other value selects a 1x1x224x224 input.
        dataset: not used by this implementation.
        optimizer: forwarded to ``apply_and_save_recipe``.

    Returns:
        The same ``model`` object that was passed in.
    """
    # The input shape is picked by architecture rather than by dataset: only
    # deblurGAN gets a 3-channel input, everything else a single channel.
    in_channels = 3 if arch == 'deblurGAN' else 1
    dummy_input = torch.randn((1, in_channels, 224, 224))

    # NOTE(review): only the input is moved to CUDA; this assumes ``model``
    # already lives on a CUDA device - confirm.
    graph = SummaryGraph(model, dummy_input.cuda())

    recipe = create_thinning_recipe_filters(graph, model, zeros_mask_dict)
    apply_and_save_recipe(model, zeros_mask_dict, recipe, optimizer)
    return model