def collect_conv_details(model, dataset):
    """Collect per-layer details for every ``torch.nn.Conv2d`` in ``model``.

    The model and a dummy input obtained from ``get_dummy_input(dataset)`` are
    both moved to CUDA and traced with ``SummaryGraph``; the graph is then
    queried for each convolution's attributes and tensor shapes.

    Args:
        model: the model to inspect. ``model.cuda()`` is called on it.
        dataset: passed through to ``get_dummy_input``.

    Returns:
        A tuple ``(conv_layers, total_macs, total_nnz)``:

        * ``conv_layers`` -- ``OrderedDict`` keyed by a running index, whose
          values are ``SimpleNamespace`` records with the fields ``t``, ``k``,
          ``stride``, ``weights_vol``, ``macs``, ``ofm_h``, ``ofm_w``,
          ``ifm_h``, ``ifm_w``, ``name`` and ``id``.
        * ``total_macs`` -- sum of the per-layer ``macs`` values, each scaled
          by ``distiller.density_ch`` of the layer's weight parameter.
        * ``total_nnz`` -- sum of the per-layer ``weights_vol`` attributes.

    Raises:
        AssertionError: if the graph has no op matching a Conv2d module.
    """
    sample = get_dummy_input(dataset)
    graph = SummaryGraph(model.cuda(), sample.cuda())

    def tensor_dim(tensor_id, axis):
        # Size of one axis of a tensor known to the graph.
        return graph.param_shape(tensor_id)[axis]

    details = OrderedDict()
    macs_sum = 0
    nnz_sum = 0
    for module_idx, (module_name, module) in enumerate(model.named_modules()):
        if not isinstance(module, torch.nn.Conv2d):
            continue

        position = len(details)
        kernel = module.kernel_size[0]  # only the first kernel dimension is kept
        stride = module.stride

        # The SummaryGraph supplies the details the module itself does not carry.
        op = graph.find_op(normalize_module_name(module_name))
        assert op is not None
        attrs = op['attrs']

        weights_vol = attrs['weights_vol']
        nnz_sum += weights_vol

        macs = attrs['MACs']
        weight = distiller.model_find_param(model, module_name + ".weight")
        # Scale the MACs by the channel density of the weights tensor.
        macs *= distiller.density_ch(weight)
        macs_sum += macs

        details[position] = SimpleNamespace(
            t=position,
            k=kernel,
            stride=stride,
            weights_vol=weights_vol,
            macs=macs,
            ofm_h=tensor_dim(op['outputs'][0], 2),
            ofm_w=tensor_dim(op['outputs'][0], 3),
            ifm_h=tensor_dim(op['inputs'][0], 2),
            ifm_w=tensor_dim(op['inputs'][0], 3),
            name=module_name,
            id=module_idx,  # index within model.named_modules()
        )

    return details, macs_sum, nnz_sum
def collect_conv_details(model, dataset):
    """Collect per-layer details for every ``torch.nn.Conv2d`` in ``model``.

    A random dummy input is created for the named dataset (1x3x224x224 for
    ``'imagenet'``, 1x3x32x32 for ``'cifar10'``). The model and the input are
    moved to CUDA and traced with ``SummaryGraph``, which is then queried for
    each convolution's MACs and tensor shapes.

    Args:
        model: the model to inspect. ``model.cuda()`` is called on it.
        dataset: ``'imagenet'`` or ``'cifar10'``.

    Returns:
        A tuple ``(conv_layers, total_macs)``: an ``OrderedDict`` keyed by a
        running index whose values are ``SimpleNamespace`` records with the
        fields ``t``, ``k``, ``stride``, ``macs``, ``ofm_h``, ``ofm_w``,
        ``ifm_h``, ``ifm_w``, ``name`` and ``id``; and the sum of the
        per-layer ``macs`` values.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
        AssertionError: if the graph has no op matching a Conv2d module.
    """
    # Only the spatial size differs between the supported datasets.
    if dataset == 'imagenet':
        side = 224
    elif dataset == 'cifar10':
        side = 32
    else:
        raise ValueError("dataset %s is not supported" % dataset)
    sample = torch.randn(1, 3, side, side)

    graph = SummaryGraph(model.cuda(), sample.cuda())

    def tensor_dim(tensor_id, axis):
        # Size of one axis of a tensor known to the graph.
        return graph.param_shape(tensor_id)[axis]

    details = OrderedDict()
    macs_sum = 0
    for module_idx, (module_name, module) in enumerate(model.named_modules()):
        if not isinstance(module, torch.nn.Conv2d):
            continue

        position = len(details)
        kernel = module.kernel_size[0]  # only the first kernel dimension is kept
        stride = module.stride

        # The SummaryGraph supplies the details the module itself does not carry.
        op = graph.find_op(normalize_module_name(module_name))
        assert op is not None

        macs = op['attrs']['MACs']
        macs_sum += macs

        details[position] = SimpleNamespace(
            t=position,
            k=kernel,
            stride=stride,
            macs=macs,
            ofm_h=tensor_dim(op['outputs'][0], 2),
            ofm_w=tensor_dim(op['outputs'][0], 3),
            ifm_h=tensor_dim(op['inputs'][0], 2),
            ifm_w=tensor_dim(op['inputs'][0], 3),
            name=module_name,
            id=module_idx,  # index within model.named_modules()
        )

    return details, macs_sum
def collect_conv_details(model, dataset, perform_thinning, layers_to_prune=None):
    """Collect per-layer details for ``torch.nn.Conv2d`` modules in ``model``.

    The model is traced with ``SummaryGraph`` using the dummy input returned by
    ``distiller.get_dummy_input(dataset)``; the graph is then queried for each
    convolution's attributes and tensor shapes.

    Args:
        model: the model to inspect.
        dataset: passed through to ``distiller.get_dummy_input``.
        perform_thinning: when falsy, each layer's MACs are scaled by
            ``distiller.density_3D`` of its weight parameter; when truthy the
            MACs reported by the graph are used unchanged.
        layers_to_prune: optional container of module names. When given, only
            convolutions whose name is in it are stored in the returned
            mapping. The totals always cover every Conv2d module.

    Returns:
        A tuple ``(conv_layers, total_macs, total_params)``:

        * ``conv_layers`` -- ``OrderedDict`` keyed by a running index, whose
          values are ``SimpleNamespace`` records with the fields ``t``, ``k``,
          ``stride``, ``weights_vol``, ``macs``, ``ofm_h``, ``ofm_w``,
          ``ifm_h``, ``ifm_w``, ``name`` and ``id``.
        * ``total_macs`` -- sum of the (possibly density-scaled) MACs.
        * ``total_params`` -- sum of the per-layer ``weights_vol`` attributes.

    Raises:
        AssertionError: if the graph has no op matching a Conv2d module.
    """
    sample = distiller.get_dummy_input(dataset)
    graph = SummaryGraph(model, sample)

    def tensor_dim(tensor_id, axis):
        # Size of one axis of a tensor known to the graph.
        return graph.param_shape(tensor_id)[axis]

    details = OrderedDict()
    macs_sum = 0
    params_sum = 0
    for module_idx, (module_name, module) in enumerate(model.named_modules()):
        if not isinstance(module, torch.nn.Conv2d):
            continue

        # Number of layers stored so far; also the key used if this one is kept.
        position = len(details)
        kernel = module.kernel_size[0]  # only the first kernel dimension is kept
        stride = module.stride

        # The SummaryGraph supplies the details the module itself does not carry.
        op = graph.find_op(normalize_module_name(module_name))
        assert op is not None
        attrs = op['attrs']

        weights_vol = attrs['weights_vol']
        params_sum += weights_vol

        macs = attrs['MACs']
        weight = distiller.model_find_param(model, module_name + ".weight")
        if not perform_thinning:
            # Filter pruning uses the 3D density of the weights; channel
            # pruning would use distiller.density_ch(weight) instead.
            macs *= distiller.density_3D(weight)
        macs_sum += macs

        record = SimpleNamespace(
            t=position,
            k=kernel,
            stride=stride,
            weights_vol=weights_vol,
            macs=macs,
            ofm_h=tensor_dim(op['outputs'][0], 2),
            ofm_w=tensor_dim(op['outputs'][0], 3),
            ifm_h=tensor_dim(op['inputs'][0], 2),
            ifm_w=tensor_dim(op['inputs'][0], 3),
            name=module_name,
            id=module_idx,  # index within model.named_modules()
        )

        if layers_to_prune is None or module_name in layers_to_prune:
            details[position] = record

    return details, macs_sum, params_sum