Esempio n. 1
0
def compute_raw_weight(model: nn.Module,
                       input_size: typing.Tuple[int, int],
                       cuda=False):
    """
    compute d_flops_in and d_flops_out for every convolutional layers in the model

    Forward hooks are attached to every ``nn.Conv2d``
    (``_conv_raw_weight_hook``) and every ``nn.Linear``
    (``_linear_raw_weight_hook``) submodule, a single forward pass with a
    random ``(8, 3, *input_size)`` batch is executed so the hooks can observe
    each layer, and the hooks are detached again afterwards.

    Note: this method needs to do forward pass, which is time-consuming

    :param model: the network to profile. ``model._flops_weight_computed``
     is set to True as a side effect.
    :param input_size: the two spatial dimensions of the random input batch
    :param cuda: if True, the random input and the model are moved to the GPU
     before the forward pass
    """
    model._flops_weight_computed = True

    hook_handles = []
    try:
        # register hooks
        for submodule in model.modules():
            if isinstance(submodule, nn.Conv2d):
                hook_handles.append(
                    submodule.register_forward_hook(_conv_raw_weight_hook))
            elif isinstance(submodule, nn.Linear):
                hook_handles.append(
                    submodule.register_forward_hook(_linear_raw_weight_hook))

        # do forward pass to compute the input spatial size for each layer
        random_input = torch.rand(8, 3, *input_size)
        if cuda:
            random_input = random_input.cuda()
            model = model.cuda()

        # this pass only exists to trigger the hooks, so there is no reason
        # to record an autograd graph for it
        with torch.no_grad():
            model(random_input)
    finally:
        # always detach the hooks, even when the forward pass fails;
        # otherwise they would keep firing on every later forward call
        for handle in hook_handles:
            handle.remove()
Esempio n. 2
0
def compute_conv_flops_weight(model: nn.Module,
                              building_block,
                              input_size: typing.Tuple[int, int] = (32, 32),
                              cuda=False) -> typing.List[typing.Tuple[int]]:
    """
    compute the conv_flops_weight for the model

    The raw weights are recomputed (which requires a forward pass through
    ``compute_raw_weight``) when
    1. the flops weight has never been computed, or
    2. the input size differs from the one used for the last computation.
    Otherwise the values already stored on the building blocks are returned.

    On recomputation every building block gets the same ``raw_weight_min``
    and ``raw_weight_max`` (the global min / max over all raw weights), which
    the blocks presumably use to rescale their weights to [0, 1].

    :param model: the network; must expose ``flops_weight_computed`` and
     ``input_size``. ``model.input_size`` and
     ``model._flops_weight_computed`` are updated after a successful
     recomputation.
    :param building_block: the basic building block for CNN.
     Use BasicBlock for ResNet-56. Use VGGBlock for CIFAR VGG.
    :param input_size: the input spatial size of the network
    :param cuda: forwarded to ``compute_raw_weight``
    :return: the ``conv_flops_weight`` of every building block, in
     ``model.modules()`` order
    :raises ValueError: if a recomputation is needed but the model yields no
     raw weights (e.g. it contains no ``building_block`` instance)
    """
    # collect the building blocks once instead of re-filtering the module
    # tree for every step below
    blocks = [
        submodule for submodule in model.modules()
        if isinstance(submodule, building_block)
    ]

    needs_update = (not model.flops_weight_computed
                    or model.input_size is None
                    or model.input_size != input_size)
    if needs_update:
        # compute d_flops_in and d_flops_out
        compute_raw_weight(model, input_size, cuda=cuda)

        # the weight is raw weight (without scaling!)
        # flatten the per-block raw weights to find the global min and max
        raw_weights = [
            weight
            for block in blocks
            for weight in block.get_conv_flops_weight(update=True,
                                                      scaling=False)
        ]
        if not raw_weights:
            raise ValueError(
                "cannot compute conv flops weight: the model has no raw "
                "weights for building block {!r}".format(building_block))

        min_weights = min(raw_weights)
        max_weights = max(raw_weights)

        # for all blocks, the raw_weight_min and the raw_weight_max are same
        for block in blocks:
            block.raw_weight_min = min_weights
            block.raw_weight_max = max_weights

        # mark the cache as valid only after everything above succeeded, so
        # a failed computation is retried on the next call instead of
        # silently returning stale weights
        model._flops_weight_computed = True
        model.input_size = input_size

    return [block.conv_flops_weight for block in blocks]