def compute_raw_weight(model: nn.Module, input_size: typing.Tuple[int, int], cuda=False):
    """
    Run one probe forward pass with temporary hooks attached to every
    convolutional and linear layer.

    ``_conv_raw_weight_hook`` is registered on each ``nn.Conv2d`` and
    ``_linear_raw_weight_hook`` on each ``nn.Linear``; both hooks are defined
    elsewhere in this file and are expected to record d_flops_in and
    d_flops_out on the layer they are attached to.

    Note: this method needs to do a forward pass, which is time-consuming.
    The pass runs in whatever train/eval mode the model is currently in.

    :param model: the network to probe; ``model._flops_weight_computed`` is
     set to True on success and reset to False if the pass fails
    :param input_size: the two trailing (spatial) dimensions of the random
     probe batch, which has shape ``(8, 3, *input_size)``
    :param cuda: if True, move both the probe batch and the model to the GPU
     before the forward pass
    :raises: whatever hook registration or the forward pass raises; the hooks
     are always removed before the exception propagates
    """
    model._flops_weight_computed = True

    hook_handles = []
    try:
        # register hooks
        for submodule in model.modules():
            if isinstance(submodule, nn.Conv2d):
                hook_handles.append(
                    submodule.register_forward_hook(_conv_raw_weight_hook))
            elif isinstance(submodule, nn.Linear):
                hook_handles.append(
                    submodule.register_forward_hook(_linear_raw_weight_hook))

        # do forward pass to compute the input spatial size for each layer
        random_input = torch.rand(8, 3, *input_size)
        if cuda:
            random_input = random_input.cuda()
            model = model.cuda()
        # The output is discarded, so there is no reason to build an
        # autograd graph for the probe batch.
        with torch.no_grad():
            model(random_input)
    except BaseException:
        # A failed pass may have updated only some layers; do not leave the
        # model marked as computed. Re-raised unchanged.
        model._flops_weight_computed = False
        raise
    finally:
        # remove hooks, even when the forward pass raised, so they cannot
        # leak into later (training) forward passes
        for handle in hook_handles:
            handle.remove()
def compute_conv_flops_weight(model: nn.Module, building_block,
                              input_size: typing.Tuple[int, int] = (32, 32),
                              cuda=False) -> typing.List[typing.Tuple[int]]:
    """
    compute the conv_flops_weight for the model

    The raw weights are recomputed only when
    1. the flops has never been computed (``model.flops_weight_computed`` is
       falsy), or
    2. the input size is unset or differs from ``input_size``.
    Otherwise the values already stored on the building blocks are returned.

    :param model: the network; must expose ``flops_weight_computed`` and
     ``input_size`` and contain at least one ``building_block`` instance
    :param building_block: the basic building block for CNN.
     Use BasicBlock for ResNet-56. Use VGGBlock for CIFAR VGG.
    :param input_size: the input spatial size of the network
    :param cuda: forwarded to ``compute_raw_weight``
    :return: ``conv_flops_weight`` of every building block, in
     ``model.modules()`` order
    :raises ValueError: if the model yields no raw weight at all (e.g. it
     contains no instance of ``building_block``)
    """
    needs_update = (not model.flops_weight_computed
                    or model.input_size is None
                    or model.input_size != input_size)

    if needs_update:
        # set flag
        model._flops_weight_computed = True
        model.input_size = input_size

        try:
            # compute d_flops_in and d_flops_out
            compute_raw_weight(model, input_size, cuda=cuda)

            blocks = [m for m in model.modules()
                      if isinstance(m, building_block)]

            # the weight is raw weight (without scaling!)
            raw_weights = [
                block.get_conv_flops_weight(update=True, scaling=False)
                for block in blocks
            ]
            flat_weights = [w for block_weights in raw_weights
                            for w in block_weights]
            if not flat_weights:
                raise ValueError(
                    "no conv flops weight found: the model contains no "
                    "{!r} building block".format(building_block))

            # compute the min_weights and max_weights for rescaling to [0, 1]
            max_weights = max(flat_weights)
            min_weights = min(flat_weights)

            # for all blocks, the raw_weight_min and the raw_weight_max
            # are the same
            for block in blocks:
                block.raw_weight_min = min_weights
                block.raw_weight_max = max_weights
        except BaseException:
            # Do not leave the cache marked valid after a failed update;
            # otherwise the next call would skip recomputation and return
            # stale or missing weights. Re-raised unchanged.
            model._flops_weight_computed = False
            raise

    return [m.conv_flops_weight for m in model.modules()
            if isinstance(m, building_block)]