def _init_model(self, checkpoint_path: str) -> None:
    """Create a model instance and load weights."""
    # load weights
    logger.info(f"Load weights from the checkpoint {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    state_dict = checkpoint["state_dict"]
    self.orig_acc = checkpoint["test_acc"]

    is_pruned = any("mask" in name for name in state_dict)

    if is_pruned:
        logger.info("Dummy pruning to load pruned weights")
        model_utils.dummy_pruning(self.params_all)

    model_utils.initialize_params(self.model, state_dict)
    logger.info("Initialized weights")

    # if the trained model is pruned, extract its masks and strip the
    # reparameterization so prepare_qat can run on plain parameters
    if is_pruned:
        logger.info(
            "Get masks and remove pruning reparameterization for prepare_qat"
        )
        self.mask = model_utils.get_masks(self.model)
        model_utils.remove_pruning_reparameterization(self.params_all)
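# A minimal standalone sketch (illustrative, not part of the class) of why
# _init_model probes the state dict for "mask" keys: torch.nn.utils.prune
# reparameterizes each pruned tensor into "<name>_orig" plus a "<name>_mask"
# buffer, so a pruned checkpoint only loads into a model that has the same
# dummy pruning applied first.
def _pruned_checkpoint_demo() -> None:
    import torch.nn as nn
    import torch.nn.utils.prune as prune

    conv = nn.Conv2d(3, 8, kernel_size=3)
    prune.l1_unstructured(conv, name="weight", amount=0.5)
    # the plain "weight" key is gone; an orig/mask pair replaces it
    assert sorted(conv.state_dict().keys()) == ["bias", "weight_mask", "weight_orig"]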
def _generate_reshaped_bn(
    self, bn: nn.BatchNorm2d, mask_idx: torch.Tensor
) -> nn.BatchNorm2d:
    """Generate a new bn layer given the old bn and the kept-channel indices."""
    reshaped_bn = torch.nn.BatchNorm2d(
        num_features=mask_idx.size()[0],
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
    ).to(self.device)
    prune_bn: Tuple[Tuple[nn.BatchNorm2d, str], Tuple[nn.BatchNorm2d, str]] = (
        (reshaped_bn, "weight"),
        (reshaped_bn, "bias"),
    )
    model_utils.dummy_pruning(prune_bn)

    # gather the surviving channels into the reshaped bn
    reshaped_bn.running_mean = torch.gather(bn.running_mean, 0, mask_idx)  # type: ignore
    reshaped_bn.running_var = torch.gather(bn.running_var, 0, mask_idx)  # type: ignore
    reshaped_bn.weight_mask.data = torch.gather(bn.weight_mask, 0, mask_idx)  # type: ignore
    reshaped_bn.weight_orig.data = torch.gather(bn.weight_orig, 0, mask_idx)  # type: ignore
    reshaped_bn.bias_mask.set_(torch.gather(bn.bias_mask, 0, mask_idx))  # type: ignore
    reshaped_bn.bias_orig.set_(torch.gather(bn.bias_orig, 0, mask_idx))  # type: ignore
    reshaped_bn.num_batches_tracked = bn.num_batches_tracked

    return reshaped_bn
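# A minimal sketch (toy values, assumed shapes) of the 1-D torch.gather used
# above: mask_idx holds the indices of the kept channels, and gathering along
# dim 0 selects the surviving statistics in order.
def _bn_gather_demo() -> None:
    import torch

    running_mean = torch.tensor([0.1, 0.2, 0.3, 0.4])
    channel_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
    mask_idx = (channel_mask == 1).nonzero().view(-1)  # tensor([0, 2, 3])
    kept = torch.gather(running_mean, 0, mask_idx)
    assert torch.equal(kept, torch.tensor([0.1, 0.3, 0.4]))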
def _generate_reshaped_conv(
    self,
    in_mask: Optional[torch.Tensor],
    out_mask: torch.Tensor,
    conv: nn.Conv2d,
) -> nn.Conv2d:
    """Generate a new conv given the old conv and masks (in and out, or out only)."""
    # Shrink both input and output channels of the conv, and extract weight (orig, mask)
    [_, i, h, w] = getattr(conv, "weight").size()

    # Make mask for input
    if in_mask is not None:
        # masking matrix of shape [o, i]: outer product of out_mask and in_mask,
        # then flattened to [o*i]
        mask_flattened = in_mask.unsqueeze(1).T * out_mask.unsqueeze(1)
        mask_flattened = mask_flattened.reshape(-1)
        mask_idx = (mask_flattened == 1).nonzero().view(-1, 1, 1).repeat(1, h, w)

        new_out = (out_mask == 1).nonzero().size()[0]
        new_in = (in_mask == 1).nonzero().size()[0]

        orig = conv.weight_orig.reshape(-1, h, w)  # type: ignore
        mask = conv.weight_mask.reshape(-1, h, w)  # type: ignore
        orig = torch.gather(orig, 0, mask_idx).reshape(new_out, new_in, h, w)
        mask = torch.gather(mask, 0, mask_idx).reshape(new_out, new_in, h, w)
    # Case when there is only an out_mask
    else:
        # extract the indices of the kept output channels
        out_mask = (out_mask == 1).nonzero().view(-1, 1, 1).repeat(1, h, w)
        out_mask = out_mask.unsqueeze(1).repeat(1, i, 1, 1)
        orig = torch.gather(conv.weight_orig, 0, out_mask)  # type: ignore
        mask = torch.gather(conv.weight_mask, 0, out_mask)  # type: ignore

    # Create the reshaped conv
    reshaped_conv = torch.nn.Conv2d(
        in_channels=mask.size()[1],
        out_channels=mask.size()[0],
        kernel_size=conv.kernel_size,  # type: ignore
        bias=conv.bias is not None,
        padding=conv.padding,  # type: ignore
        dilation=conv.dilation,  # type: ignore
        groups=conv.groups,
        stride=conv.stride,  # type: ignore
    ).to(self.device)

    # dummy prune to copy orig, mask into the new conv
    # Note: pruned conv bias is not supported
    prune_conv = ((reshaped_conv, "weight"),)
    model_utils.dummy_pruning(prune_conv)

    # Overwrite data in the new (reshaped) conv
    reshaped_conv.weight_orig.data = orig
    reshaped_conv.weight_mask.data = mask

    return reshaped_conv
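# A minimal sketch (toy shapes) of the index construction in the in_mask
# branch above: the outer product of out_mask and in_mask flags the surviving
# (out, in) channel pairs, and each surviving pair selects one [h, w] kernel
# slice from the weight viewed as [o*i, h, w].
def _conv_gather_demo() -> None:
    import torch

    o, i, h, w = 3, 2, 5, 5
    weight = torch.arange(o * i * h * w, dtype=torch.float32).reshape(o, i, h, w)
    out_mask = torch.tensor([1.0, 0.0, 1.0])  # keep out channels 0 and 2
    in_mask = torch.tensor([0.0, 1.0])  # keep in channel 1

    # [o, i] pair mask, flattened to [o*i] to line up with weight.reshape(-1, h, w)
    pair_mask = (out_mask.unsqueeze(1) * in_mask.unsqueeze(0)).reshape(-1)
    mask_idx = (pair_mask == 1).nonzero().view(-1, 1, 1).repeat(1, h, w)

    kept = torch.gather(weight.reshape(-1, h, w), 0, mask_idx).reshape(2, 1, h, w)
    assert torch.equal(kept[0, 0], weight[0, 1])
    assert torch.equal(kept[1, 0], weight[2, 1])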
def _load_masks(self) -> None:
    """Load masks."""
    if not self.mask:
        return

    model_utils.dummy_pruning(self.params_all)
    for name, _ in self.model.named_buffers():
        if name not in self.mask:
            continue
        # walk the dotted buffer name down to the owning module
        # (avoids eval, and works for indexed submodules such as "layer1.0")
        module_name, mask_name = name.rsplit(".", 1)
        module = self.model
        for attr in module_name.split("."):
            module = getattr(module, attr)
        module._buffers[mask_name] = self.mask[name]
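# A minimal sketch (illustrative names) of the buffer round trip _load_masks
# relies on: prune.identity attaches an all-ones "weight_mask" that appears in
# named_buffers() and can later be overwritten in place with a saved mask.
def _mask_buffer_demo() -> None:
    import torch.nn as nn
    import torch.nn.utils.prune as prune

    model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3))
    prune.identity(model[0], "weight")  # dummy prune: mask of ones

    saved = {n: b.clone() for n, b in model.named_buffers() if "mask" in n}
    assert "0.weight_mask" in saved

    for name, mask in saved.items():
        module_name, mask_name = name.rsplit(".", 1)
        module = model
        for attr in module_name.split("."):
            module = getattr(module, attr)
        module._buffers[mask_name] = mask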
def __init__(
    self,
    config: Dict[str, Any],
    dir_prefix: str,
    wandb_log: bool,
    wandb_init_params: Dict[str, Any],
    device: torch.device,
) -> None:
    """Initialize."""
    super(Pruner, self).__init__(config, dir_prefix)
    self.wandb_log = wandb_log
    self.pretrain_dir_name = "pretrain"
    self.dir_postfix = "pruned"
    self.init_params_name = "init_params"
    self.init_params_path = ""
    self.device = device

    self.plotter = Plotter(self.wandb_log)

    # create an initial model
    self.trainer = Trainer(
        config=self.config["TRAIN_CONFIG"],
        dir_prefix=dir_prefix,
        checkpt_dir=self.pretrain_dir_name,
        wandb_log=wandb_log,
        wandb_init_params=wandb_init_params,
        device=device,
    )
    self.model = self.trainer.model

    self.model_params = model_utils.get_params(
        self.model,
        (
            (nn.Conv2d, "weight"),
            (nn.Conv2d, "bias"),
            (nn.BatchNorm2d, "weight"),
            (nn.BatchNorm2d, "bias"),
            (nn.Linear, "weight"),
            (nn.Linear, "bias"),
        ),
    )
    self.params_to_prune = self.get_params_to_prune()

    # apply dummy pruning so sparsity is calculated properly
    model_utils.dummy_pruning(self.model_params)
    model_utils.dummy_pruning(self.params_to_prune)
def __init__(
    self,
    config: Dict[str, Any],
    checkpoint_path: str,
    dir_prefix: str,
    device: torch.device,
) -> None:
    """Initialize."""
    super(Shrinker, self).__init__(config, dir_prefix)
    self.train_config = self.config["TRAIN_CONFIG"]
    self.checkpoint_path = checkpoint_path
    self.device = device

    # create a trainer
    self.trainer = Trainer(
        config=self.config["TRAIN_CONFIG"],
        dir_prefix=dir_prefix,
        checkpt_dir="",
        device=self.device,
        wandb_log=False,
        wandb_init_params=None,
    )
    self.model = self.trainer.model

    # create adjacent module getter
    input_size = (1, *self.trainer.input_size)
    self.adjmodule_getter = AdjModuleGetter(
        self.model, input_size=input_size, device=self.device
    )

    # Note: the model must contain nn.Flatten to get the last conv shape info
    self.last_conv_shape = self.adjmodule_getter.last_conv_shape

    # dummy pruning
    self.params_all = model_utils.get_params(
        self.model,
        (
            (nn.Conv2d, "weight"),
            (nn.Conv2d, "bias"),
            (nn.BatchNorm2d, "weight"),
            (nn.BatchNorm2d, "bias"),
            (nn.Linear, "weight"),
            (nn.Linear, "bias"),
        ),
    )
    model_utils.dummy_pruning(self.params_all)
def _generate_reshaped_fc(self, mask: torch.Tensor, fc: nn.Linear) -> nn.Linear:
    """Generate a new fc given the old fc, mask, and last_conv_shape."""
    # expand the per-channel mask over the spatial dims of the last conv output:
    # bn_dim * last_dim * last_dim features feed into the fc, so repeat each
    # channel entry last_dim * last_dim times
    in_mask = torch.flatten(
        mask.view(-1, 1, 1).repeat(1, self.last_conv_shape, self.last_conv_shape)
    )

    # shrink the fc
    in_features_size = int((in_mask == 1).sum())
    out_features, _ = fc.weight.size()
    weight_mask = in_mask.repeat(out_features)
    weight_mask_idx = (weight_mask == 1).nonzero().squeeze(1)
    weight = fc.weight.detach().clone()
    weight = torch.gather(torch.flatten(weight), 0, weight_mask_idx).reshape(
        out_features, in_features_size
    )

    reshaped_fc = torch.nn.Linear(
        in_features=in_features_size,
        out_features=out_features,
        bias=fc.bias is not None,
    ).to(self.device)
    param_to_prune = ((reshaped_fc, "weight"),)
    model_utils.dummy_pruning(param_to_prune)
    reshaped_fc.weight_orig.data = weight  # type: ignore
    reshaped_fc.weight_mask.data = torch.ones_like(weight)  # type: ignore

    # Note: this doesn't work if the bias is pruned and its dimension changed.
    # Only valid for networks whose single last layer is this fc.
    if fc.bias is not None:
        reshaped_fc.bias.data = fc.bias
    prune.remove(reshaped_fc, "weight")

    return reshaped_fc
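# A minimal sketch (toy sizes) of the mask expansion above: a per-channel mask
# from the last conv is repeated last_conv_shape**2 times because nn.Flatten
# lays out the fc input channel-major as [c, h, w].
def _fc_mask_demo() -> None:
    import torch

    last_conv_shape = 2  # 2x2 spatial output
    channel_mask = torch.tensor([1.0, 0.0, 1.0])  # keep channels 0 and 2
    in_mask = torch.flatten(
        channel_mask.view(-1, 1, 1).repeat(1, last_conv_shape, last_conv_shape)
    )
    expected = torch.tensor([1.0] * 4 + [0.0] * 4 + [1.0] * 4)
    assert torch.equal(in_mask, expected)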