def prune_model(model, amount, prune_mask, method=prune.L1Unstructured):
    """Globally prune ``fc1``..``fc4`` of ``model`` and make the result permanent.

    The model is first moved to the CPU (together with its masks via
    ``model.mask_to_device``).  The masks accumulated so far are re-applied to
    every ``torch.nn.Linear`` sub-module, a global unstructured pruning pass
    is run over the four ``fc`` layers, and finally the pruning
    re-parametrisation is removed so only plain (zeroed) weights remain.

    Args:
        model: module exposing ``fc1``..``fc4`` and ``mask_to_device``.
        amount: forwarded to ``prune.global_unstructured``.
        prune_mask: mapping from module name to its mask tensor; each entry is
            updated in place with the logical AND of itself and the new mask.
        method: pruning method class handed to ``prune.global_unstructured``.

    Returns:
        The same ``model`` object.
    """
    model.to('cpu')
    model.mask_to_device('cpu')

    # Pruning only touches parameters, buffers and hooks, never the module
    # tree, so the set of Linear layers can be collected once and reused.
    linear_layers = [
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear)
    ]

    # Re-apply the masks gathered in earlier rounds.
    for layer_name, layer in linear_layers:
        prune.custom_from_mask(layer, "weight", prune_mask[layer_name])

    targets = (
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
        (model.fc4, 'weight'),
    )
    prune.global_unstructured(targets, pruning_method=method, amount=amount)

    # Fold the new masks into the running masks, then drop the
    # ``weight_orig`` / ``weight_mask`` pair so the pruning is baked in.
    for layer_name, layer in linear_layers:
        running_mask = prune_mask[layer_name]
        torch.logical_and(layer.weight_mask, running_mask, out=running_mask)
        prune.remove(layer, 'weight')

    return model
def prune_fc(layer, paths) -> None:
    """Mask every weight parameter of ``layer`` with ``paths.get_fc()``.

    Args:
        layer: module whose parameters with ``"weight"`` in their name are pruned.
        paths: object providing ``get_fc()``, which returns the mask to apply.
    """
    # Snapshot the names first: ``custom_from_mask`` renames the parameter
    # (``weight`` -> ``weight_orig``), so the parameter dict must not be
    # iterated lazily while pruning.
    weight_names = [name for name, _ in layer.named_parameters() if "weight" in name]
    for name in weight_names:
        prune.custom_from_mask(layer, name=name, mask=paths.get_fc())
def prune_conv_layer(model, layer_index, filter_index, criterion='lrp', cuda_flag=False):
    """Zero out one output filter of a convolution inside ``model``.

    A per-filter ``output_mask`` (one entry per output channel) is kept on the
    convolution itself; the requested entry is cleared and the resulting mask
    is applied to the weight and, when present, the bias.

    Args:
        model: current model.
        layer_index: key of the layer to prune in ``dict(model.named_modules())``,
            i.e. the module's qualified name.
        filter_index: index of the filter (output channel) to prune.
        criterion: unused in this function.
        cuda_flag: when true, the pruned ``weight`` attribute is moved to CUDA.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: if the filter has already been pruned.
    """
    conv = dict(model.named_modules())[layer_index]

    if not hasattr(conv, "output_mask"):
        # Mask of shape (num_output_channels,).
        conv.output_mask = torch.ones(conv.weight.shape[0])
    # ``output_mask`` is a plain attribute, so ``module.to(...)`` does not move
    # it; keep it on the weight's device, otherwise applying the mask to a
    # GPU-resident layer fails with a device mismatch.
    conv.output_mask = conv.output_mask.to(conv.weight.device)

    # Make sure the filter was not pruned before.
    assert conv.output_mask[filter_index] != 0
    conv.output_mask[filter_index] = 0

    mask_weight = conv.output_mask.view(-1, 1, 1, 1).expand_as(conv.weight)
    torch_prune.custom_from_mask(conv, "weight", mask_weight)

    if conv.bias is not None:
        torch_prune.custom_from_mask(conv, "bias", conv.output_mask)

    if cuda_flag:
        conv.weight = conv.weight.cuda()

    return model
def global_unstructured(
    self, pruning_method: torch.nn.utils.prune.BasePruningMethod, **kwargs
):
    """Globally prune ``self.params_to_prune`` using custom importance scores.

    Based on
    https://pytorch.org/docs/stable/_modules/torch/nn/utils/prune.html#global_unstructured,
    except that the values ranked by the pruning method come from
    ``self.get_prune_score()`` instead of the parameters themselves.

    Args:
        pruning_method: pruning method class; it is instantiated with
            ``**kwargs`` and must have ``PRUNING_TYPE == "unstructured"``.
        **kwargs: forwarded to ``pruning_method``.

    Raises:
        TypeError: if the instantiated method is not unstructured.
    """
    assert isinstance(self.params_to_prune, Iterable)

    flat_scores = torch.nn.utils.parameters_to_vector(self.get_prune_score())

    # Flatten the existing masks, falling back to all-ones for parameters
    # that have not been pruned yet.
    current_masks = []
    for module, name in self.params_to_prune:
        fallback = torch.ones_like(getattr(module, name))
        current_masks.append(getattr(module, name + "_mask", fallback))
    flat_mask = torch.nn.utils.parameters_to_vector(current_masks)

    # Reuse the canonical machinery: a container combines the mask produced
    # by the new method with the pre-existing one, even though the "tensor"
    # here is the flattened concatenation of all scores.
    container = prune.PruningContainer()
    container._tensor_name = "temp"
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            "the `pruning_method`. Found method {} of type {}".format(
                pruning_method, method.PRUNING_TYPE
            )
        )
    container.add_pruning_method(method)
    final_mask = container.compute_mask(flat_scores, flat_mask)

    # Cut the flat mask back into per-parameter pieces and register each one
    # through the regular forward-pre-hook mechanism.
    start = 0
    for module, name in self.params_to_prune:
        param = getattr(module, name)
        stop = start + param.numel()
        prune.custom_from_mask(module, name, final_mask[start:stop].view_as(param))
        start = stop
def prune_layer(layer, prune_rate, reset=True):
    """Prune ``layer`` by ``prune_rate`` and optionally rewind surviving weights.

    Only ``nn.Linear`` and ``nn.Conv2d`` layers are handled; anything else is
    left untouched.  The new mask is obtained from ``prune_mask``.  When
    ``reset`` is true the pruning is temporarily removed, the weights are
    replaced by a copy of ``layer.weight_init`` and the new mask is applied on
    top of them.  Calls with ``prune_rate`` = 0.0 do not remove weights.
    """
    if not isinstance(layer, (nn.Linear, nn.Conv2d)):
        return

    new_mask = prune_mask(layer, prune_rate)
    if not reset:
        return

    # Drop the re-parametrisation so ``weight`` is an ordinary parameter
    # again, rewind it to the stored initial values, then re-apply the mask.
    prune.remove(layer, name='weight')
    layer.weight = nn.Parameter(layer.weight_init.clone())
    prune.custom_from_mask(layer, name='weight', mask=new_mask)
def setup_graph(self):
    """Initialise the masks and attach them to the layers.

    When ``self._mask_init_method`` is ``'random'`` every entry of
    ``self.masks`` is replaced by a random mask of the same size and dtype
    (drawn with ``get_mask_random`` at ``self._default_sparsity``) placed on
    ``device``.  Afterwards each layer in ``self.layers`` is pruned with the
    mask at the same position, discarding any pruning it already carried.
    """
    if self._mask_init_method == 'random':
        for position, old_mask in enumerate(self.masks):
            random_mask = get_mask_random(old_mask.size(), self._default_sparsity)
            self.masks[position] = torch.tensor(
                random_mask, dtype=old_mask.dtype).to(device)

    for position, layer in enumerate(self.layers):
        # Start from a clean layer so masks do not stack up in a container.
        if prune.is_pruned(layer):
            prune.remove(layer, 'weight')
        prune.custom_from_mask(layer, 'weight', self.masks[position])
def prune_connections(model, paths, output_type="m1"):
    """
    Prune the ``x2h`` and ``h2h`` weights of a model with a path mask.

    Parameters
    ----------
    model : Model
        Model which is to be pruned; must expose ``x2h`` and ``h2h``.
    paths : Paths
        Paths object holding the NumPy masks ``m1`` and ``m2``.
    output_type : str, optional
        Which mask to use, ``"m1"`` or ``"m2"``. The default is "m1".

    Returns
    -------
    model : Model
        The same model, pruned.

    Raises
    ------
    ValueError
        If ``output_type`` is neither ``"m1"`` nor ``"m2"``.
    """
    if output_type == "m1":
        mask_array = paths.m1
    elif output_type == "m2":
        mask_array = paths.m2
    else:
        raise ValueError("output_type kwarg must be m1 or m2")

    # The same mask is applied to both projections.
    path = torch.from_numpy(mask_array)
    prune.custom_from_mask(model.x2h, "weight", path)
    prune.custom_from_mask(model.h2h, "weight", path)
    return model
def prune_cell_m2(layer, paths) -> None:
    """Prune an M2 cell.

    Since M2 cells only have one output, the paths are adapted accordingly:
    weights whose name contains ``"hh"`` receive ``paths.get_m2(True)`` and
    all other weights receive ``paths.get_m2()``.

    Args:
        layer: module whose parameters with ``"weight"`` in their name are pruned.
        paths: object providing ``get_m2``.
    """
    # Snapshot the names before pruning; ``custom_from_mask`` re-registers the
    # parameter under a new name, so the parameter dict changes underneath us.
    weight_names = [name for name, _ in layer.named_parameters() if "weight" in name]
    for name in weight_names:
        if "hh" in name:
            # hh layer has (4 * n_out, n_out) dims
            mask = paths.get_m2(True)
        else:
            # ih layer has (4 * n_out, n_in) dims
            mask = paths.get_m2()
        prune.custom_from_mask(layer, name=name, mask=mask)
def prune_cell_m1(layer, paths) -> None:
    """Prune an M1 cell.

    Every parameter of ``layer`` with ``"weight"`` in its name is masked with
    ``paths.get_m1()``.  According to the original note each such weight has
    shape ``(4 * hiddens, n_input + n_output)``.  Biases are not pruned.

    Args:
        layer: module to prune.
        paths: object providing ``get_m1()``.
    """
    # Only the names are needed (the parameter tensors themselves are not),
    # and they must be collected before pruning starts renaming parameters.
    weight_names = [name for name, _ in layer.named_parameters() if "weight" in name]
    for name in weight_names:
        prune.custom_from_mask(layer, name=name, mask=paths.get_m1())
def test_prune_conv_layer_correctly(self):
    """
    Prune an already pruned convolutional layer in one step.

    Four of the 16 weights are masked initially, so 12 remain; the three
    (ceil(12*0.2)=3) remaining weights with the lowest magnitude should be
    zeroed out in addition.
    """
    shape = (2, 2, 2, 2)
    weights = torch.tensor([
        1.2, -0.1, 1.2, 4.3, -2.1, -1.1, -0.8, 1.2,
        0.5, 0.2, 0.4, 1.4, 2.2, -0.8, 0.4, 0.9
    ]).view(*shape)
    mask = torch.tensor([
        1., 0., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 1., 1., 1., 0., 1.
    ]).view(*shape)

    # Conv layer with the given weights, a stored copy of them as
    # ``weight_init`` and the initial mask applied.
    conv = nn.Conv2d(2, 2, kernel_size=2, padding=1)
    conv.weight = nn.Parameter(weights.clone())
    conv.register_buffer('weight_init', weights.clone())
    conv = prune.custom_from_mask(conv, name='weight', mask=mask)

    mp.prune_layer(layer=conv, prune_rate=0.2)

    expected = torch.tensor([
        1.2, -0., 1.2, 4.3, -2.1, -1.1, -0., 1.2,
        0., 0., 0., 1.4, 2.2, -0., 0., 0.9
    ]).view(*shape)
    all_equal = (conv.weight == expected).all().item()
    self.assertIs(all_equal, True)
def pruning_generate(model, state_dict):
    """Re-create the pruning masks stored in ``state_dict`` on ``model``.

    Every ``nn.Conv2d`` / ``nn.ConvTranspose2d`` sub-module named ``<name>`` is
    pruned with the mask found under ``"<name>.weight_mask"``.

    Args:
        model: module whose convolutions are pruned in place.
        state_dict: mapping that holds a ``"<name>.weight_mask"`` entry for
            each convolution of ``model``.

    Raises:
        KeyError: if the mask of a convolution is missing from ``state_dict``.
    """
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            prune.custom_from_mask(
                module, name='weight', mask=state_dict[name + ".weight_mask"])
def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
    """Drop and regrow connections of ``layer`` based on gradient magnitude.

    The score is ``|grad| + noise_std`` (a constant offset, not a random
    draw).  ``self.drop_minimum`` removes connections using the score
    restricted to the active mask; ``self.grow_maximum`` then revives as many
    connections as were dropped using the score restricted to the inactive
    positions.  The resulting mask replaces the pruning on ``layer``.

    Returns:
        ``(layer_mask, new_mask)`` as returned by ``self.grow_maximum``.
    """
    saliency = self.grad_dict[layer].abs() + noise_std

    # Drop step: only currently active connections compete.
    dropped_mask, n_prune = self.drop_minimum(saliency * layer_mask, layer_mask)

    # Grow step: only currently inactive connections compete.
    regrow_score = saliency * (~dropped_mask)
    layer_mask, new_mask = self.grow_maximum(regrow_score, dropped_mask, n_prune)

    # Swap the mask registered on the layer.
    if prune.is_pruned(layer):
        prune.remove(layer, 'weight')
    prune.custom_from_mask(layer, 'weight', layer_mask)

    return layer_mask, new_mask
def prune_model(model, index_list):
    """Keep only the input columns listed in ``index_list`` of ``model.fc[0]``.

    A mask with the same shape as the layer's weight is built in which the
    columns named in ``index_list`` are 1 and everything else is 0, and the
    layer is pruned with it.

    Args:
        model: object whose ``fc[0]`` is the layer to prune.
        index_list: iterable of column indices (input features) to keep.

    Returns:
        The same ``model`` object.
    """
    module = model.fc[0]

    # Size the mask from the layer itself instead of assuming a fixed
    # (100, 512) weight.
    keep_mask = torch.zeros_like(module.weight)
    for index in index_list:
        keep_mask[:, index] = 1

    model.fc[0] = prune.custom_from_mask(module, name='weight', mask=keep_mask)
    return model
def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
    """Drop low-magnitude weights of ``layer`` and regrow by gradient.

    ``sparsify_weight`` removes connections based on ``|weight|``, the current
    mask and ``self._drop_fraction``.  ``self.grow_maximum`` then revives
    ``int(drop_n * self._grow_fraction)`` connections, scored by the raw
    gradient restricted to the inactive positions.  The resulting mask
    replaces the pruning on ``layer``.  ``noise_std`` is not used here.

    Returns:
        ``(layer_mask, new_mask)`` as returned by ``self.grow_maximum``.
    """
    weight = layer.weight
    grad = self.grad_dict[layer]

    # Drop step.
    dropped_mask, drop_n = sparsify_weight(weight.abs(), layer_mask, self._drop_fraction)

    # Grow step: only inactive connections carry a score.
    n_grow = int(drop_n * self._grow_fraction)
    regrow_score = grad * (~dropped_mask)
    layer_mask, new_mask = self.grow_maximum(regrow_score, dropped_mask, n_grow)

    # Swap the mask registered on the layer.
    if prune.is_pruned(layer):
        prune.remove(layer, 'weight')
    prune.custom_from_mask(layer, 'weight', layer_mask)

    return layer_mask, new_mask
def update_layer_mask(self, layer, layer_mask, noise_std=1e-5):
    """Drop by weight magnitude and regrow by gradient magnitude.

    Drop score: ``|mask * weight|`` plus ``self._random_normal`` noise with
    standard deviation ``noise_std``.  Grow score: ``|grad|`` restricted to
    the inactive positions plus ``self._random_uniform`` noise scaled by
    ``noise_std``.  As many connections are regrown as were dropped, and the
    resulting mask replaces the pruning on ``layer``.

    Returns:
        ``(layer_mask, new_mask)`` as returned by ``self.grow_maximum``.
    """
    weight = layer.weight
    weight_shape = weight.size()

    # Drop step: small noise breaks ties between equal magnitudes.
    active_weight = layer_mask * weight
    drop_score = active_weight.abs() + self._random_normal(weight_shape, std=noise_std)
    dropped_mask, n_prune = self.drop_minimum(drop_score, layer_mask)

    # Grow step: revive n_prune connections among the inactive ones.
    grad_magnitude = self.grad_dict[layer].abs()
    grow_score = (grad_magnitude * (~dropped_mask)
                  + self._random_uniform(weight_shape) * noise_std)
    layer_mask, new_mask = self.grow_maximum(grow_score, dropped_mask, n_prune)

    # Swap the mask registered on the layer.
    if prune.is_pruned(layer):
        prune.remove(layer, 'weight')
    prune.custom_from_mask(layer, 'weight', layer_mask)

    return layer_mask, new_mask
def load_state(self, modelpath='/raid/zhassylbekov/sungbae/model/initstate.pth'):
    """Reload the weights stored at ``modelpath`` while keeping the current mask.

    The mask currently attached to ``self.v_embeddings`` is saved, the pruning
    is removed so the module's state-dict keys match the plain checkpoint,
    the checkpoint is loaded into this instance, and the saved mask is
    applied again.

    The previous implementation rebound the local name ``self`` to a freshly
    built model, so the checkpoint was loaded into a throw-away object and
    the instance the method was called on never changed.

    Args:
        modelpath: path of a state dict saved from the unpruned model.

    Raises:
        KeyError: if ``self.v_embeddings`` carries no ``weight_mask`` buffer.
    """
    embeddings = self.v_embeddings
    vmask = dict(embeddings.named_buffers())['weight_mask'].detach().clone()

    # Back to a plain ``weight`` parameter so the checkpoint keys line up.
    prune.remove(embeddings, 'weight')

    # ``load_state_dict`` copies into the existing parameters, so they stay on
    # whatever device the model lives on; load to CPU to avoid requiring the
    # device the checkpoint was saved from.
    self.load_state_dict(torch.load(modelpath, map_location='cpu'))

    prune.custom_from_mask(
        embeddings, name='weight', mask=vmask.to(embeddings.weight.device))
def prune_step_change(self, pstep, prune_mode):
    """Prune a further ``pstep`` of ``v_embeddings`` ranked by weight change.

    A scratch embedding is filled with a score computed from the current
    weights ``vc`` and the reference weights ``vi`` (second value returned by
    ``self.load_weights()``):

    * ``'change'``: ``vc - vi``
    * ``'absolute change'``: ``|vc| - |vi|``
    * anything else: ``vc``

    The existing mask is re-applied to the scratch embedding and L1
    unstructured pruning removes ``pstep`` more entries (entries with larger
    absolute score survive).  The current weights are then written back and
    the scratch embedding replaces ``self.v_embeddings``.

    Args:
        pstep: ``amount`` passed to ``prune.l1_unstructured``.
        prune_mode: scoring mode, see above.
    """
    cuda_using = next(self.parameters()).is_cuda

    # Reference ("fixed") weights; only the v part is used.
    _, vi = self.load_weights()
    # Snapshot of the current weights.
    vc = self.v_embeddings.weight.data.clone().cpu()

    # Make sure a mask exists.  The check has to look at ``v_embeddings``,
    # the module that actually gets pruned.
    if not prune.is_pruned(self.v_embeddings):
        prune.identity(self.v_embeddings, name='weight')
    vmask = dict(self.v_embeddings.named_buffers())['weight_mask'].cpu()

    if prune_mode == 'change':
        score = vc - vi
    elif prune_mode == 'absolute change':
        score = torch.abs(vc) - torch.abs(vi)
    else:
        score = vc

    v_temp = torch.nn.Embedding(self.vocab_size, self.emb_dimension)
    v_temp.weight.data.copy_(score)
    prune.custom_from_mask(v_temp, name='weight', mask=vmask)
    if cuda_using:
        v_temp.cuda()
    prune.l1_unstructured(v_temp, name='weight', amount=pstep)

    # Restore the real weights.  After pruning, ``weight`` is recomputed from
    # ``weight_orig * weight_mask`` on every forward pass, so the values must
    # go into ``weight_orig``; writing only to ``weight`` would be undone by
    # the next forward and leave the score in place of the embeddings.
    v_temp.weight_orig.data.copy_(vc)
    v_temp.weight.data.copy_(v_temp.weight_orig.data * v_temp.weight_mask)

    self.v_embeddings = v_temp
def prune_equal_fanin(
        model: torch.nn.Module,
        epoch: int,
        prune_epoch: int,
        k: int = 2,
        device: torch.device = torch.device('cpu')) -> torch.nn.Module:
    """
    Prune the linear layers of the network such that each neuron has the same fan-in.

    For every ``torch.nn.Linear`` that is a direct child of ``model`` the
    ``k`` largest-magnitude weights of each output neuron are kept and all
    others are masked.  Nothing happens unless ``epoch == prune_epoch``;
    when pruning does run the model is switched to eval mode.

    :param model: pytorch model.
    :param epoch: current training epoch.
    :param prune_epoch: training epoch when pruning needs to be applied.
    :param k: fan-in. Values at or above a layer's number of inputs leave
        that layer fully connected.
    :param device: device the mask is moved to before it is applied.
    :return: Pruned model
    """
    if epoch != prune_epoch:
        return model

    model.eval()
    for module in model.children():
        # prune only Linear layers
        if not isinstance(module, torch.nn.Linear):
            continue

        weight = module.weight.detach()
        # Build the mask next to the weight so the indices and the mask share
        # a device.
        mask = torch.ones_like(weight)

        # Number of weights to drop per neuron; never negative, so a ``k``
        # larger than the fan-in simply keeps everything.
        n_drop = max(weight.shape[1] - k, 0)
        if n_drop:
            # The largest values of -|w| are the smallest magnitudes.
            drop_idx = torch.topk(-weight.abs(), k=n_drop, dim=1).indices
            mask.scatter_(1, drop_idx, 0.0)

        prune.custom_from_mask(module, name="weight", mask=mask.to(device))

    return model
def test_prune_mask_for_linear_layer_correctly(self):
    """
    Prune the mask of an unpruned linear layer in one step.

    With 10 weights and a rate of 0.2 the two weights with the lowest
    magnitude should be zeroed out in the returned mask.
    """
    weights = torch.tensor([[1., -2., 3., -1.5, -3.],
                            [-1., 2., -4., 0.5, 1.5]])

    # Linear layer with the given weights and an all-ones mask.
    layer = nn.Linear(2, 5)
    layer.weight = nn.Parameter(weights.clone())
    unpruned_mask = torch.ones_like(layer.weight)
    layer = prune.custom_from_mask(layer, name='weight', mask=unpruned_mask)

    pruned_mask = mp.prune_mask(layer=layer, prune_rate=0.2)

    expected_mask = torch.tensor([[0., 1., 1., 1., 1.],
                                  [1., 1., 1., 0., 1.]])
    self.assertIs(pruned_mask.equal(expected_mask), True)
def test_do_not_apply_init_weight_after_pruning_linear_layer(self):
    """
    Generate, modify and prune an unpruned linear layer with ``reset=False``.

    The surviving weights should keep their modified values instead of being
    reset to the stored initial weights.
    """
    init_weights = torch.tensor([[1., -2., 3.], [-4., 5., -6.]])

    # Linear layer whose weights were "trained" (doubled) after the initial
    # values were stored as ``weight_init``; the mask is all ones.
    layer = nn.Linear(2, 3)
    layer.weight = nn.Parameter(2 * init_weights.clone())
    layer.register_buffer('weight_init', init_weights.clone())
    unpruned_mask = torch.ones_like(layer.weight)
    layer = prune.custom_from_mask(layer, name='weight', mask=unpruned_mask)

    mp.prune_layer(layer=layer, prune_rate=0.2, reset=False)

    expected = torch.tensor([[0., -0., 6.], [-8., 10., -12.]])
    self.assertIs(layer.weight.equal(expected), True)
def test_prune_linear_layer_correctly(self):
    """
    Prune an already pruned linear layer in one step.

    Two of the 10 weights are masked initially, so 8 remain; the two
    (ceil(8*0.2)=2) remaining weights with the lowest magnitude should be
    zeroed out in addition.
    """
    weights = torch.tensor([[1., -2., 3., -1.5, -3.],
                            [-1., 2., -4., 0.5, 1.5]])
    mask = torch.tensor([[0., 1., 1., 1., 1.],
                         [1., 1., 1., 0., 1.]])

    # Linear layer with the given weights, a stored copy of them as
    # ``weight_init`` and the initial mask applied.
    layer = nn.Linear(2, 5)
    layer.weight = nn.Parameter(weights.clone())
    layer.register_buffer('weight_init', weights.clone())
    layer = prune.custom_from_mask(layer, name='weight', mask=mask)

    mp.prune_layer(layer=layer, prune_rate=0.2)

    expected = torch.tensor([[0., -2., 3., -0., -3.],
                             [-0., 2., -4., 0., 1.5]])
    self.assertIs(layer.weight.equal(expected), True)
def test_prune_mask_for_conv_layer_correctly(self):
    """
    Prune the mask of an unpruned convolutional layer in one step.

    The expected mask below zeroes the four lowest-magnitude weights out of
    the 16 given ones.
    """
    shape = (2, 2, 2, 2)
    weights = torch.tensor([
        1.2, -0.1, 1.2, 4.3, -2.1, -1.1, -0.8, 1.2,
        0.5, 0.2, 0.4, 1.4, 2.2, -0.8, 0.4, 0.9
    ]).view(*shape)

    # Conv layer with the given weights and an all-ones mask.
    conv = nn.Conv2d(2, 2, kernel_size=2, padding=1)
    conv.weight = nn.Parameter(weights.clone())
    unpruned_mask = torch.ones_like(conv.weight)
    conv = prune.custom_from_mask(conv, name='weight', mask=unpruned_mask)

    pruned_mask = mp.prune_mask(layer=conv, prune_rate=0.2)

    expected_mask = torch.tensor([
        1., 0., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 1., 1., 1., 0., 1.
    ]).view(*shape)
    self.assertIs(pruned_mask.equal(expected_mask), True)
def __init__(self, in_planes, planes, num_bits, num_bits_weight, stride,
             type_prune, sparsity, layer_num, option='A'):
    """Build one of three block variants selected by ``in_planes``.

    * ``in_planes == 3``: quantised 3x3 stem convolution ("conv") followed by
      batch norm ("bn") and ReLU ("relu").
    * ``in_planes == 1``: average pooling, flatten and a 64 -> 10 linear
      classifier ("fc").
    * otherwise: a ``BasicBlock`` registered as "conv".

    For the first two variants the learnable layer is pruned at construction
    time: ``type_prune == 'channel'`` uses L2 structured pruning along dim 0,
    ``type_prune == 'group'`` masks groups of four consecutive weights by
    their RMS value; any other value leaves the layer unpruned.
    """
    super(ResNetBlock, self).__init__()

    def group_keep_mask(layer):
        """Return the boolean keep-mask for group pruning of ``layer.weight``."""
        width = 4
        scores = layer.weight.data.clone()
        original_size = scores.size()
        scores = scores.view(original_size[0], -1)
        # Pad each row with a copy of its leading columns so its length
        # becomes a multiple of ``width``.
        append_size = width - scores.shape[1] % width
        scores = torch.cat((scores, scores[:, 0:append_size]), 1)
        scores = scores.view(scores.shape[0], -1, width)
        # Every weight gets the RMS of the group it belongs to.
        scores = scores.pow(2.0).mean(2, keepdim=True).pow(0.5).expand(scores.shape)
        flat_scores = scores.flatten()
        num = flat_scores.shape[0] * (1 - sparsity)
        threshold = torch.topk(flat_scores, int(num), sorted=True).values[-1]
        keep = scores.ge(threshold)
        keep = keep.view(original_size[0], -1)
        # Drop the padding columns again.
        keep = keep[:, 0:layer.weight.data[0].nelement()]
        return keep.contiguous().view(original_size)

    def pruned(layer):
        """Apply the pruning selected by ``type_prune`` and return the layer."""
        if type_prune == 'channel':
            return prune.ln_structured(layer, name='weight', amount=sparsity, n=2, dim=0)
        if type_prune == 'group':
            return prune.custom_from_mask(
                layer, name='weight', mask=group_keep_mask(layer))
        return layer

    if in_planes == 3:
        stem = QConv2d(in_planes, planes, num_bits, num_bits_weight,
                       kernel_size=3, stride=1, padding=1, bias=False)
        self.add_module("conv", pruned(stem))
        self.add_module("bn", nn.BatchNorm2d(planes))
        self.add_module("relu", nn.ReLU(inplace=True))
    elif in_planes == 1:
        self.add_module("avg_pool", nn.AvgPool2d(kernel_size=8, stride=1))
        self.add_module("flatten", Flatten())
        classifier = nn.Linear(in_features=64, out_features=10)
        self.add_module("fc", pruned(classifier))
    else:
        block = BasicBlock(in_planes, planes, num_bits, num_bits_weight, stride,
                           type_prune, sparsity, layer_num, option)
        self.add_module("conv", block)
def setup_masks(layer):
    """Attach an all-ones weight mask to a linear or convolutional layer.

    Layers that are neither ``nn.Linear`` nor ``nn.Conv2d`` are left untouched.
    """
    if not isinstance(layer, (nn.Linear, nn.Conv2d)):
        return
    all_ones = torch.ones_like(layer.weight)
    prune.custom_from_mask(layer, name='weight', mask=all_ones)
def _cifar_prunable_convs(net):
    """Yield every ``nn.Conv2d`` of ``net`` whose weight does not have 3 input channels."""
    for module in net.modules():
        if isinstance(module, nn.Conv2d) and module.weight.shape[1] != 3:
            yield module


def _cifar_score_map(conv, prune_type):
    """Return a per-weight importance tensor for ``conv`` ('group' or 'filter')."""
    weight = conv.weight.data.clone()
    flat = weight.view(weight.size(0), -1)
    if prune_type == 'group':
        # Pad each row with a copy of its leading columns so its length is a
        # multiple of 4, then score every weight by the mean |w| of its group.
        append_size = 4 - flat.shape[1] % 4
        flat = torch.cat((flat, flat[:, 0:append_size]), 1)
        grouped = flat.view(flat.shape[0], -1, 4)
        return grouped.abs().mean(2, keepdim=True).expand(grouped.shape)
    # 'filter': every weight is scored by the mean |w| of its output filter.
    return flat.abs().mean(1, keepdim=True).expand(flat.shape)


def _cifar_keep_mask(conv, prune_type, threshold):
    """Return the boolean keep-mask of ``conv`` for a global score ``threshold``."""
    original_size = conv.weight.size()
    keep = _cifar_score_map(conv, prune_type).ge(threshold)
    keep = keep.view(original_size[0], -1)
    # Drop the padding columns added for group scoring (no-op otherwise).
    keep = keep[:, 0:conv.weight.data[0].nelement()]
    return keep.contiguous().view(original_size)


def _cifar_global_prune(net, prune_type, sparsity):
    """Prune all eligible convolutions of ``net`` against one global threshold."""
    convs = list(_cifar_prunable_convs(net))
    scores = torch.cat([_cifar_score_map(conv, prune_type).flatten() for conv in convs])
    num = scores.shape[0] * (1 - sparsity)
    threshold = torch.topk(scores, int(num), sorted=True).values[-1]
    for conv in convs:
        prune.custom_from_mask(
            conv, name='weight', mask=_cifar_keep_mask(conv, prune_type, threshold))


def _cifar_loaders(dataset, batch_size, loader_kwargs):
    """Return ``(train_loader, test_loader)`` for 'cifar-10' or 'cifar-100'."""
    if dataset == 'cifar-10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    elif dataset == 'cifar-100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
    else:
        raise ValueError(f'unsupported dataset: {dataset}')

    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=True, download=True, transform=train_transform),
        batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=test_transform),
        batch_size=batch_size, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


def main(args):
    """Train a (optionally pruned) CIFAR ResNet described by ``args``.

    Steps: create the output folder and logger, seed the RNGs, build the
    ResNet of the requested depth, create the data loaders, optimizer and
    scheduler, optionally prune the convolutions globally, then train.

    Raises:
        ValueError: for an unsupported model or dataset.
        KeyError: for an unsupported ResNet depth.
        AssertionError: for ``prune_method == 'uniform'`` (not implemented).
    """
    save_folder = args.save_folder
    model_folder = os.path.join(args.model_root, save_folder)
    makedirs(model_folder)
    setattr(args, 'model_folder', model_folder)
    logger = create_logger(model_folder, 'train', 'info')
    print_args(args, logger)

    # seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # model
    if "ResNet" not in args.model:
        raise ValueError(f'unsupported model: {args.model}')
    depth_ = args.model.split('-')[1]
    # For 'dst' the ResNet blocks prune themselves, so they need the pruning
    # type; previously ``p_type`` was left unbound on that path.
    # NOTE(review): passing ``args.prune_type`` here is presumed from context
    # - confirm against the resnet constructors.
    p_type = args.prune_type if args.prune_method == 'dst' else None
    # Build only the requested network.  NOTE: earlier all four depths were
    # instantiated, which also advanced the RNG before depths other than 20
    # were initialised.
    builders = {'20': resnet20, '32': resnet32, '44': resnet44, '56': resnet56}
    net = builders[depth_](num_classes=int(args.dataset.split('-')[1]),
                           prune_type=p_type)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # set trainer
    trainer = Trainer(args, logger)

    # loss
    loss = nn.CrossEntropyLoss()

    # dataloader
    kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    train_loader, test_loader = _cifar_loaders(args.dataset, args.batch_size, kwargs)

    # optimizer & scheduler
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[80, 120], gamma=0.1)

    # pruning
    if args.prune_method == 'global':
        if args.prune_type in ('group', 'filter'):
            _cifar_global_prune(net, args.prune_type, args.sparsity)
        print(f'model pruned!!(sparsity : {args.sparsity : .2f}, prune_method : {args.prune_method}, prune_type : {args.prune_type}-level pruning')
    elif args.prune_method == 'uniform':
        # Explicit raise so the guard survives ``python -O``.
        raise AssertionError('uniform code is not ready')
    elif args.prune_method == 'dst':
        print(f'model pruned!!(prune_method : {args.prune_method}, prune_type : {args.prune_type}-level pruning')

    # Training
    trainer.train(net, loss, device, train_loader, test_loader,
                  optimizer=optimizer, scheduler=scheduler)
def _imagenet_prunable_convs(net):
    """Yield every ``nn.Conv2d`` of ``net`` with more than 3 input channels in its weight."""
    for module in net.modules():
        if isinstance(module, nn.Conv2d) and module.weight.shape[1] > 3:
            yield module


def _imagenet_score_map(conv, prune_type, block_size):
    """Return a per-weight importance tensor for ``conv``.

    Each weight is scored by the mean |w| of the structure it belongs to:
    a block of ``block_size`` consecutive inputs ('group_filter'), a block of
    ``block_size`` consecutive filters ('group_channel'), its whole output
    filter ('filter') or its input position across all filters ('channel').
    """
    weight = conv.weight.data.clone()
    flat = weight.view(weight.size(0), -1)  # (out, in * h * w)
    if prune_type == 'group_filter':
        blocks = flat.view(flat.shape[0], -1, block_size)
        return blocks.abs().mean(2, keepdim=True).expand(blocks.shape)
    if prune_type == 'group_channel':
        blocks = flat.view(-1, block_size, flat.shape[1])
        return blocks.abs().mean(1, keepdim=True).expand(blocks.shape)
    if prune_type == 'filter':
        return flat.abs().mean(1, keepdim=True).expand(flat.shape)
    if prune_type == 'channel':
        return flat.abs().mean(0, keepdim=True).expand(flat.shape)
    raise ValueError(f'unsupported prune_type: {prune_type}')


def _imagenet_global_prune(net, prune_type, sparsity, block_size):
    """Prune all eligible convolutions of ``net`` against one global threshold."""
    convs = list(_imagenet_prunable_convs(net))
    scores = torch.cat(
        [_imagenet_score_map(conv, prune_type, block_size).flatten() for conv in convs])
    num = scores.shape[0] * (1 - sparsity)
    threshold = torch.topk(scores, int(num), sorted=True).values[-1]
    for conv in convs:
        keep = _imagenet_score_map(conv, prune_type, block_size).ge(threshold)
        keep = keep.contiguous().view(conv.weight.size())  # (out, ch, h, w)
        prune.custom_from_mask(conv, name='weight', mask=keep)


def _imagenet_fold_masks(state):
    """Turn a pruned ``DataParallel`` state dict into a plain one.

    ``*_orig`` entries are multiplied by their ``*_mask`` and stored under the
    original parameter name, mask entries are dropped, and the ``module.``
    prefix is stripped from every key.
    """
    folded = OrderedDict()
    for key in state:
        if 'orig' in key:
            base = key.split('_orig')[0]
            folded[key.split('module.')[1].split('_orig')[0]] = state[key] * state[base + '_mask']
        elif 'mask' not in key:
            folded[key.split('module.')[1]] = state[key]
    return folded


def main(args):
    """Train a (optionally pruned, optionally distilled) ImageNet model.

    Steps: create the output folder and logger, seed the RNGs, build the
    requested network, optionally freeze its first convolution and load
    custom weights, create the ImageNet loaders, optimizer and scheduler,
    optionally prune the convolutions globally, train, and finally store a
    copy of the model with the pruning masks folded into the weights.

    Raises:
        ValueError: for an unsupported model or scheduler.
        KeyError: for an unsupported ResNet depth.
        NotImplementedError: if ``conv1_not_train`` is requested for a model
            family without a known first convolution.
        AssertionError: for ``prune_method == 'uniform'`` (not implemented).
    """
    save_folder = args.save_folder
    model_folder = os.path.join(args.model_root, save_folder)
    makedirs(model_folder)
    setattr(args, 'model_folder', model_folder)
    logger = create_logger(model_folder, 'train', 'info')
    print_args(args, logger)

    # seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # model
    if "ResNet" in args.model:
        depth_ = args.model.split('-')[1]
        # For 'dst' the ResNet blocks prune themselves, so they need the
        # pruning type; previously ``p_type`` was left unbound on that path.
        # NOTE(review): passing ``args.prune_type`` here is presumed from
        # context - confirm against the resnet constructors.
        p_type = args.prune_type if args.prune_method == 'dst' else None
        # Build (and, if requested, download) only the selected depth instead
        # of all four networks.
        builders = {'18': resnet18, '34': resnet34, '50': resnet50, '101': resnet101}
        net = builders[depth_](pretrained=args.pretrained, progress=True,
                               prune_type=p_type)
    elif args.model == 'efficientnet_b0':
        print('efficientnet-b0 load...')
        net = efficientnet_b0(pretrained=args.pretrained)
    elif args.model == "mobilenetv3-large-1.0":
        print('mobilenetv3-large-1.0')
        net = mobilenetv3_large_100(pretrained=args.pretrained)
    elif args.model == 'once-mobilenetv3-large-1.0':
        print('once-mobilenetv3-large-1.0')
        # NOTE(review): this identifier looks garbled ("[email protected]");
        # restore the real OFA specialised-net name before running.
        net, image_size = ofa_specialized(
            'note8_lat@[email protected]_finetune@25', pretrained=args.pretrained)
    elif args.model == 'mobilenetv2-120d':
        print('mobilenetv2-120d load...')
        net = mobilenetv2_120d(pretrained=args.pretrained)
    else:
        raise ValueError(f'unsupported model: {args.model}')

    # conv1 trainable
    if args.conv1_not_train:
        print('conv1 weight not train')
        if args.model == "mobilenetv3-large-1.0":
            first_conv = net.conv_stem
        elif "ResNet" in args.model:
            first_conv = net.conv1
        else:
            # Was ``assert (False, 'not ready')``: asserting a non-empty tuple
            # always passes, so unsupported models were silently accepted.
            raise NotImplementedError('not ready')
        for param in first_conv.parameters():
            param.requires_grad = False

    # custom pretrain path
    if args.pretrain_path:
        print('load custom pretrain weight...')
        net.load_state_dict(torch.load(args.pretrain_path))

    net2 = copy.deepcopy(net)  # for save removed_models

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = nn.DataParallel(net)
    net.to(device)

    # KD
    if args.KD:
        print('knowledge distillation model load!')
        # NOTE(review): this identifier looks garbled ("[email protected]");
        # the original comment mentions 79.6% top-1 - restore the real name.
        teacher_net, image_size = ofa_specialized(
            'flops@[email protected]_finetune@75', pretrained=True)
        teacher_net = nn.DataParallel(teacher_net)
        teacher_net.to(device)

    # set trainer
    trainer = Trainer_KD(args, logger) if args.KD else Trainer(args, logger)

    # loss
    loss = nn.CrossEntropyLoss()

    # dataloader
    # Only the once-for-all model brings its own resolution.  The comparison
    # previously used 'once-mobilenetv3-large-1' (missing '.0'), which never
    # matched and always forced 224.
    if args.model != 'once-mobilenetv3-large-1.0':
        image_size = 224

    if args.dataset == 'imagenet':
        normalize = transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
        train_loader = torch.utils.data.DataLoader(
            datasets.ImageNet(
                '/data/imagenet/', split='train', download=False,
                transform=transforms.Compose([
                    # RandomSizedCrop is the deprecated alias of this transform.
                    transforms.RandomResizedCrop(image_size),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_worker, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.ImageNet(
                '/data/imagenet/', split='val', download=False,
                transform=transforms.Compose([
                    transforms.Resize(int(image_size / 0.875)),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_worker, pin_memory=True)

    # optimizer & scheduler
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.scheduler == 'multistep':
        # SECURITY NOTE: ``eval`` executes arbitrary code from the command
        # line; prefer ``ast.literal_eval`` if only list literals are passed.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=eval(args.multi_step_epoch),
            gamma=args.multi_step_gamma)
    elif args.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, mode='max', patience=3, verbose=True,
            factor=0.3, threshold=1e-4, min_lr=1e-6)
    else:
        raise ValueError(f'unsupported scheduler: {args.scheduler}')

    # pruning
    if args.prune_method == 'global':
        if args.prune_type in ('group_filter', 'group_channel', 'filter', 'channel'):
            # ``block_size`` is only needed by the group variants.
            _imagenet_global_prune(net, args.prune_type, args.sparsity,
                                   getattr(args, 'block_size', None))
        print(
            f'model pruned!!(sparsity : {args.sparsity : .2f}, prune_method : {args.prune_method}, prune_type : {args.prune_type}-level pruning'
        )
    elif args.prune_method == 'uniform':
        # Explicit raise so the guard survives ``python -O``.
        raise AssertionError('uniform code is not ready')
    elif args.prune_method == 'dst':
        print(
            f'model pruned!!(prune_method : {args.prune_method}, prune_type : {args.prune_type}-level pruning'
        )
    elif args.prune_method is None:
        print('Not pruned model training started!')

    # Training
    if args.KD:
        trainer.train(net, teacher_net, loss, device, train_loader, test_loader,
                      optimizer=optimizer, scheduler=scheduler)
    else:
        trainer.train(net, loss, device, train_loader, test_loader,
                      optimizer=optimizer, scheduler=scheduler)

    # save removed models: fold the masks of the stored checkpoint into plain
    # weights and load them into the un-wrapped copy of the network.
    filename = os.path.join(args.model_folder, 'pruned_models.pth')
    temp_dict = _imagenet_fold_masks(torch.load(filename))
    net2.load_state_dict(temp_dict)
    save_model(net2, os.path.join(args.model_folder, 'removed_models.pth'))
    print('saved removed models')
def __init__(self,
             in_planes,
             planes,
             num_bits,
             num_bits_weight,
             stride,
             type_prune,
             sparsity,
             layer_num,
             option='A'):
    """Build a two-convolution residual block with optional weight pruning.

    Args:
        in_planes: Input channel count of the first convolution.
        planes: Output channel count of both 3x3 convolutions.
        num_bits: Forwarded unchanged to ``QConv2d`` (presumably the
            activation bit-width -- confirm against ``QConv2d``).
        num_bits_weight: Forwarded unchanged to ``QConv2d`` (presumably
            the weight bit-width -- confirm against ``QConv2d``).
        stride: Stride of the first convolution and of the option-B
            shortcut convolution.
        type_prune: ``'channel'`` applies ``prune.ln_structured`` with
            the L2 norm along dim 0 of the weight; ``'group'`` masks
            groups of four consecutive weights of each flattened filter
            by their RMS magnitude; any other value leaves both
            convolutions unpruned.
        sparsity: Fraction of the weights to prune.
        layer_num: Not used by this constructor; kept so existing
            callers keep working.
        option: Shortcut variant used when ``stride != 1`` or the
            channel count changes. ``'A'`` is a parameter-free
            subsample-and-pad shortcut, ``'B'`` a 1x1 ``QConv2d``
            followed by batch norm. Any other value keeps the empty
            ``nn.Sequential`` shortcut.
    """
    super(BasicBlock, self).__init__()

    group_width = 4

    def group_mask(weight):
        """Return a boolean keep-mask with the same shape as ``weight``."""
        full_shape = weight.size()
        # One row per output filter: (out, in * kh * kw).
        rows = weight.view(full_shape[0], -1)
        row_len = rows.shape[1]
        # Pad only when the row does not already split evenly into groups.
        # The previous ``width - row_len % width`` added a whole extra group
        # for aligned rows, and that group skewed the top-k threshold.
        pad = (-row_len) % group_width
        if pad:
            # Fill the last, partial group by wrapping around to the
            # leading weights of the same filter.
            rows = torch.cat((rows, rows[:, 0:pad]), 1)
        grouped = rows.view(rows.shape[0], -1, group_width)
        # RMS magnitude of each group, broadcast back to every member.
        rms = grouped.pow(2.0).mean(2, keepdim=True).pow(0.5).expand(
            grouped.shape)
        scores = rms.flatten()
        keep = int(scores.shape[0] * (1 - sparsity))
        # The smallest of the ``keep`` largest scores is the cut-off.
        threshold = torch.topk(scores, keep, sorted=True).values[-1]
        mask = rms.ge(threshold).view(full_shape[0], -1)
        # Drop the padding columns before restoring the weight shape.
        mask = mask[:, 0:row_len]
        return mask.contiguous().view(full_shape)

    def apply_pruning(conv):
        """Attach the requested pruning re-parametrization to ``conv``."""
        if type_prune == 'channel':
            return prune.ln_structured(conv,
                                       name='weight',
                                       amount=sparsity,
                                       n=2,
                                       dim=0)
        if type_prune == 'group':
            return prune.custom_from_mask(conv,
                                          name='weight',
                                          mask=group_mask(conv.weight.data))
        return conv

    self.conv1 = apply_pruning(
        QConv2d(in_planes,
                planes,
                num_bits,
                num_bits_weight,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False))
    self.bn1 = nn.BatchNorm2d(planes)

    self.conv2 = apply_pruning(
        QConv2d(planes,
                planes,
                num_bits,
                num_bits_weight,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False))
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != planes:
        if option == 'A':
            # For CIFAR10 the ResNet paper uses option A: take every second
            # pixel in both spatial dims and zero-pad the channel dim by
            # planes // 4 on each side.
            self.shortcut = LambdaLayer(lambda x: F.pad(
                x[:, :, ::2, ::2],
                (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
        elif option == 'B':
            self.shortcut = nn.Sequential(
                QConv2d(in_planes,
                        self.expansion * planes,
                        num_bits,
                        num_bits_weight,
                        kernel_size=1,
                        stride=stride,
                        bias=False),
                nn.BatchNorm2d(self.expansion * planes))