Example #1
    def _init_model(self, checkpoint_path: str) -> None:
        """Create a model instance and load weights."""
        # load weights
        logger.info(f"Load weights from the checkpoint {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device("cpu"))

        state_dict = checkpoint["state_dict"]
        self.orig_acc = checkpoint["test_acc"]

        is_pruned = any("mask" in name for name in state_dict)

        if is_pruned:
            logger.info("Dummy prunning to load pruned weights")
            model_utils.dummy_pruning(self.params_all)

        model_utils.initialize_params(self.model, state_dict)
        logger.info("Initialized weights")

        # if the checkpoint was pruned, recover its masks and strip the reparameterization
        if is_pruned:
            logger.info(
                "Get masks and remove prunning reparameterization for prepare_qat"
            )
            self.mask = model_utils.get_masks(self.model)
            model_utils.remove_pruning_reparameterization(self.params_all)
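
Loading a pruned checkpoint only works if the model is first put into the same "pruned" state, because torch.nn.utils.prune replaces each pruned parameter with a `<name>_orig` parameter plus a `<name>_mask` buffer, and those are the keys stored in the state dict. `model_utils.dummy_pruning` is not shown here; a minimal sketch of what such a helper presumably does, assuming it simply applies identity pruning to the given (module, parameter-name) pairs:

import torch.nn.utils.prune as prune

def dummy_pruning(params):
    # Hypothetical sketch: give each (module, name) pair the `<name>_orig`
    # parameter and all-ones `<name>_mask` buffer that a pruned checkpoint expects.
    for module, name in params:
        prune.identity(module, name)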
Example #2
    def _generate_reshaped_bn(self, bn: nn.BatchNorm2d,
                              mask_idx: torch.Tensor) -> nn.BatchNorm2d:
        """Generate new bn given old_bn, mask."""
        reshaped_bn = torch.nn.BatchNorm2d(
            num_features=mask_idx.size()[0],
            eps=bn.eps,
            momentum=bn.momentum,
            affine=bn.affine,
            track_running_stats=bn.track_running_stats,
        ).to(self.device)

        prune_bn: Tuple[Tuple[nn.BatchNorm2d, str],
                        Tuple[nn.BatchNorm2d, str]] = (
                            (reshaped_bn, "weight"),
                            (reshaped_bn, "bias"),
                        )
        model_utils.dummy_pruning(prune_bn)

        # copy statistics and parameters of the kept channels into the reshaped bn
        reshaped_bn.running_mean = torch.gather(bn.running_mean, 0,
                                                mask_idx)  # type: ignore
        reshaped_bn.running_var = torch.gather(bn.running_var, 0,
                                               mask_idx)  # type: ignore
        reshaped_bn.weight_mask.data = torch.gather(bn.weight_mask, 0,
                                                    mask_idx)  # type: ignore
        reshaped_bn.weight_orig.data = torch.gather(bn.weight_orig, 0,
                                                    mask_idx)  # type: ignore
        reshaped_bn.bias_mask.set_(torch.gather(bn.bias_mask, 0,
                                                mask_idx))  # type: ignore
        reshaped_bn.bias_orig.set_(torch.gather(bn.bias_orig, 0,
                                                mask_idx))  # type: ignore
        reshaped_bn.num_batches_tracked = bn.num_batches_tracked

        return reshaped_bn
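
Here `mask_idx` is an index tensor (the positions of the kept channels), not the 0/1 mask itself, which is why `torch.gather` along dimension 0 can pull out the surviving running statistics and weights directly. A small usage sketch; the variable names and values are illustrative, not from the project:

import torch

channel_mask = torch.tensor([1., 0., 1., 1.])       # keep channels 0, 2, 3
mask_idx = (channel_mask == 1).nonzero().view(-1)   # tensor([0, 2, 3])
# new_bn = self._generate_reshaped_bn(old_bn, mask_idx)  # -> BatchNorm2d(3)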
Example #3
    def _generate_reshaped_conv(
        self,
        in_mask: Optional[torch.Tensor],
        out_mask: torch.Tensor,
        conv: nn.Conv2d,
    ) -> nn.Conv2d:
        """Generate new conv given old conv and masks(in and out or out only)."""
        # Shrink both input, output channel of conv, and extract weight(orig, mask)
        [_, i, h, w] = getattr(conv, "weight").size()

        # Make mask for input
        if in_mask is not None:
            # make masking matrix[o, i]: in_mask.T * out_mask
            # mask_flattened : [o*i]
            mask_flattened = in_mask.unsqueeze(1).T * out_mask.unsqueeze(1)
            mask_flattened = mask_flattened.reshape(-1)
            mask_idx = (mask_flattened == 1).nonzero().view(-1, 1,
                                                            1).repeat(1, h, w)

            new_out = (out_mask == 1).nonzero().size()[0]
            new_in = (in_mask == 1).nonzero().size()[0]

            orig = conv.weight_orig.reshape(-1, h, w)  # type: ignore
            mask = conv.weight_mask.reshape(-1, h, w)  # type: ignore
            orig = torch.gather(orig, 0,
                                mask_idx).reshape(new_out, new_in, h, w)
            mask = torch.gather(mask, 0,
                                mask_idx).reshape(new_out, new_in, h, w)

        # Case: only out_mask is given
        else:
            # expand out_mask into gather indices over kernel and input-channel dims
            out_mask = (out_mask == 1).nonzero().view(-1, 1, 1).repeat(1, h, w)
            out_mask = out_mask.unsqueeze(1).repeat(1, i, 1, 1)

            orig = torch.gather(conv.weight_orig, 0, out_mask)  # type: ignore
            mask = torch.gather(conv.weight_mask, 0, out_mask)  # type: ignore

        # Create reshaped conv
        reshaped_conv = torch.nn.Conv2d(
            in_channels=mask.size()[1],
            out_channels=mask.size()[0],
            kernel_size=conv.kernel_size,  # type: ignore
            bias=conv.bias is not None,
            padding=conv.padding,  # type: ignore
            dilation=conv.dilation,  # type: ignore
            groups=conv.groups,
            stride=conv.stride,  # type: ignore
        ).to(self.device)

        # dummy prune to copy orig, mask to new conv
        # Note: pruned conv bias is not supported
        prune_conv = ((reshaped_conv, "weight"), )
        model_utils.dummy_pruning(prune_conv)

        # Overwrite data to new(reshaped) conv
        reshaped_conv.weight_orig.data = orig
        reshaped_conv.weight_mask.data = mask

        return reshaped_conv
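
The in-mask branch works by viewing the conv weight as [o*i, h, w], building the [o, i] keep matrix as the outer product of out_mask and in_mask, and gathering only the surviving kernels before reshaping to [new_out, new_in, h, w]. A toy illustration of that index arithmetic; the shapes and values are assumptions, not project code:

import torch

o, i, h, w = 3, 2, 1, 1
out_mask = torch.tensor([1., 0., 1.])    # keep output channels 0 and 2
in_mask = torch.tensor([1., 1.])         # keep both input channels
keep = (in_mask.unsqueeze(1).T * out_mask.unsqueeze(1)).reshape(-1)   # [o*i]
idx = (keep == 1).nonzero().view(-1, 1, 1).repeat(1, h, w)
weight = torch.arange(float(o * i)).reshape(o * i, h, w)              # fake kernels
shrunk = torch.gather(weight, 0, idx).reshape(2, 2, h, w)             # rows 0, 1, 4, 5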
Example #4
    def _load_masks(self) -> None:
        """Load masks."""
        if not self.mask:
            return

        model_utils.dummy_pruning(self.params_all)
        for name, _ in self.model.named_buffers():
            if name in self.mask:
                module_name, mask_name = name.rsplit(".", 1)
                module = eval("self.model." + module_name)
                module._buffers[mask_name] = self.mask[name]
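
`self.mask` is keyed by the same dotted buffer names that `named_buffers()` yields (keys ending in `_mask`), which is how the saved masks from Example #1 line up with the freshly re-created dummy masks here. A plausible sketch of the `model_utils.get_masks` counterpart used in Example #1, offered as an assumption since that module is not shown:

def get_masks(model):
    # Collect every pruning mask buffer by its dotted name, e.g.
    # {"features.0.weight_mask": tensor, "classifier.weight_mask": tensor, ...}
    return {name: buf.detach().clone()
            for name, buf in model.named_buffers()
            if name.endswith("_mask")}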
Example #5
    def __init__(
        self,
        config: Dict[str, Any],
        dir_prefix: str,
        wandb_log: bool,
        wandb_init_params: Dict[str, Any],
        device: torch.device,
    ) -> None:
        """Initialize."""
        super().__init__(config, dir_prefix)
        self.wandb_log = wandb_log
        self.pretrain_dir_name = "pretrain"
        self.dir_postfix = "pruned"
        self.init_params_name = "init_params"
        self.init_params_path = ""
        self.device = device

        self.plotter = Plotter(self.wandb_log)

        # create an initial model
        self.trainer = Trainer(
            config=self.config["TRAIN_CONFIG"],
            dir_prefix=dir_prefix,
            checkpt_dir=self.pretrain_dir_name,
            wandb_log=wandb_log,
            wandb_init_params=wandb_init_params,
            device=device,
        )
        self.model = self.trainer.model

        self.model_params = model_utils.get_params(
            self.model,
            (
                (nn.Conv2d, "weight"),
                (nn.Conv2d, "bias"),
                (nn.BatchNorm2d, "weight"),
                (nn.BatchNorm2d, "bias"),
                (nn.Linear, "weight"),
                (nn.Linear, "bias"),
            ),
        )
        self.params_to_prune = self.get_params_to_prune()

        # to calculate sparsity properly
        model_utils.dummy_pruning(self.model_params)
        model_utils.dummy_pruning(self.params_to_prune)
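
`model_utils.get_params` is not shown; judging by how its result is fed into `dummy_pruning` and by the (module, name) tuples that `torch.nn.utils.prune` expects, it presumably walks the model and returns (module, parameter-name) pairs for every module whose type matches one of the requested patterns. A hedged sketch under that assumption:

import torch.nn as nn

def get_params(model, patterns):
    # Hypothetical sketch: collect (module, parameter_name) pairs for every
    # module matching one of the (type, parameter_name) patterns, skipping
    # parameters that do not exist (e.g. bias=False layers).
    params = []
    for module in model.modules():
        for mtype, pname in patterns:
            if isinstance(module, mtype) and getattr(module, pname, None) is not None:
                params.append((module, pname))
    return tuple(params)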
Example #6
    def __init__(
        self,
        config: Dict[str, Any],
        checkpoint_path: str,
        dir_prefix: str,
        device: torch.device,
    ) -> None:
        """Initialize."""
        super().__init__(config, dir_prefix)
        self.train_config = self.config["TRAIN_CONFIG"]
        self.checkpoint_path = checkpoint_path
        self.device = device

        # create a trainer
        self.trainer = Trainer(
            config=self.config["TRAIN_CONFIG"],
            dir_prefix=dir_prefix,
            checkpt_dir="",
            device=self.device,
            wandb_log=False,
            wandb_init_params=None,
        )
        self.model = self.trainer.model

        # create adjacent module getter
        input_size = (1, *self.trainer.input_size)
        self.adjmodule_getter = AdjModuleGetter(self.model,
                                                input_size=input_size,
                                                device=self.device)

        # Note: model must have nn.Flatten to get last conv shape info
        self.last_conv_shape = self.adjmodule_getter.last_conv_shape

        # dummy pruning so pruned checkpoint keys (*_orig, *_mask) can be loaded
        self.params_all = model_utils.get_params(
            self.model,
            (
                (nn.Conv2d, "weight"),
                (nn.Conv2d, "bias"),
                (nn.BatchNorm2d, "weight"),
                (nn.BatchNorm2d, "bias"),
                (nn.Linear, "weight"),
                (nn.Linear, "bias"),
            ),
        )
        model_utils.dummy_pruning(self.params_all)
Example #7
    def _generate_reshaped_fc(self, mask: torch.Tensor,
                              fc: nn.Linear) -> nn.Linear:
        """Generate new fc given old fc, mask, last_conv_shape."""
        # expand considering last_dim
        # ex) bn_dim * last_dim * last_dim feed into NN
        # repeat last_dim * last_dim
        in_mask = torch.flatten(
            mask.view(-1, 1, 1).repeat(1, self.last_conv_shape,
                                       self.last_conv_shape))

        # Do shrink on fc
        in_features_size = int((in_mask == 1).sum())

        out_features, _ = fc.weight.size()
        weight_mask = in_mask.repeat(out_features)
        weight_mask_idx = (weight_mask == 1).nonzero().squeeze(1)
        weight = fc.weight.detach().clone()
        weight = torch.gather(torch.flatten(weight), 0,
                              weight_mask_idx).reshape(out_features,
                                                       in_features_size)

        reshaped_fc = torch.nn.Linear(
            in_features=in_features_size,
            out_features=out_features,
            bias=fc.bias is not None,
        ).to(self.device)
        param_to_prune = ((reshaped_fc, "weight"), )
        model_utils.dummy_pruning(param_to_prune)

        reshaped_fc.weight_orig.data = weight  # type: ignore
        reshaped_fc.weight_mask.data = torch.ones_like(weight)  # type: ignore
        # Note: this does not work if the bias is pruned and its dimension changed;
        # it is only valid for networks whose final layer is a single fc.
        if fc.bias is not None:
            reshaped_fc.bias.data = fc.bias.data
        prune.remove(reshaped_fc, "weight")

        return reshaped_fc
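
The key step is expanding the per-channel mask to the flattened feature vector the fc actually sees: each kept channel of the last conv contributes a last_conv_shape x last_conv_shape patch of activations. A toy illustration with assumed sizes, not project code:

import torch

mask = torch.tensor([1., 0.])   # keep 1 of the 2 channels of the last conv
last_conv_shape = 3             # 3x3 spatial output before nn.Flatten
in_mask = torch.flatten(
    mask.view(-1, 1, 1).repeat(1, last_conv_shape, last_conv_shape))
# in_mask has nine 1s then nine 0s, so the shrunk fc keeps 9 input features
assert int((in_mask == 1).sum()) == 9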