Example #1
    def train(self, epoch):
        """
        Train one epoch for auxnet
        :param epoch: index of epoch
        """

        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()

        iters = len(self.train_loader)
        self.update_lr(epoch)
        # Switch to train mode
        self.model.train()

        start_time = time.time()
        end_time = start_time

        for i, (images, labels) in enumerate(self.train_loader):
            start_time = time.time()
            data_time = start_time - end_time

            if self.settings.n_gpus == 1:
                images = images.cuda()
            labels = labels.cuda()

            # forward
            output, loss = self.forward(images, labels)
            self.backward(loss)

            # compute loss and error rate
            single_error, single_loss, single5_error = utils.compute_singlecrop_error(
                outputs=output, labels=labels,
                loss=loss, top5_flag=True)

            top1_error.update(single_error, images.size(0))
            top1_loss.update(single_loss, images.size(0))
            top5_error.update(single5_error, images.size(0))

            end_time = time.time()
            iter_time = end_time - start_time

            if i % self.settings.print_frequency == 0:
                utils.print_result(epoch, self.settings.n_epochs, i + 1,
                                   iters, self.lr, data_time, iter_time,
                                   single_error,
                                   single_loss,
                                   mode="Train",
                                   logger=self.logger)

        if self.tensorboard_logger is not None:
            self.tensorboard_logger.scalar_summary('train_top1_error', top1_error.avg, self.run_count)
            self.tensorboard_logger.scalar_summary('train_top5_error', top5_error.avg, self.run_count)
            self.tensorboard_logger.scalar_summary('train_loss', top1_loss.avg, self.run_count)
            self.tensorboard_logger.scalar_summary("lr", self.lr, self.run_count)

        self.logger.info("|===>Training Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}"
                         .format(top1_error.avg, top1_loss.avg, top5_error.avg))
        return top1_error.avg, top1_loss.avg, top5_error.avg
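All of these examples accumulate metrics through utils.AverageMeter, which is not shown on this page. Below is a minimal sketch reconstructed from the call sites above; update(value, n), .avg, and .sum are the only members assumed:

    class AverageMeter(object):
        """Running sum/average of a metric (assumed behaviour, inferred from usage)."""

        def __init__(self):
            self.sum = 0.0
            self.count = 0
            self.avg = 0.0

        def update(self, value, n=1):
            # accumulate a batch-level value weighted by the batch size n
            self.sum += value * n
            self.count += n
            self.avg = self.sum / self.count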
Example #2
    def val(self, epoch):
        """
        Validation
        :param epoch: index of epoch
        """

        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()

        self.pruned_model.eval()

        iters = len(self.val_loader)
        start_time = time.time()
        end_time = start_time

        with torch.no_grad():
            for i, (images, labels) in enumerate(self.val_loader):
                start_time = time.time()
                data_time = start_time - end_time

                if self.settings.n_gpus == 1:
                    images = images.cuda()
                labels = labels.cuda()

                output, loss = self.forward(images, labels)

                single_error, single_loss, single5_error = utils.compute_singlecrop_error(
                    outputs=output, loss=loss,
                    labels=labels, top5_flag=True)

                top1_error.update(single_error, images.size(0))
                top1_loss.update(single_loss, images.size(0))
                top5_error.update(single5_error, images.size(0))

                end_time = time.time()
                iter_time = end_time - start_time

                if i % self.settings.print_frequency == 0:
                    utils.print_result(
                        epoch, self.settings.n_epochs, i + 1,
                        iters, self.lr, data_time, iter_time,
                        single_error, single_loss,
                        top5error=single5_error,
                        mode="Validation",
                        logger=self.logger)

        self.scalar_info['network_wise_fine_tune_val_top1_error'] = top1_error.avg
        self.scalar_info['network_wise_fine_tune_val_top5_error'] = top5_error.avg
        self.scalar_info['network_wise_fine_tune_val_loss'] = top1_loss.avg
        if self.tensorboard_logger is not None:
            for tag, value in self.scalar_info.items():
                self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1
        self.logger.info(
            "|===>Validation Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}".format(top1_error.avg, top1_loss.avg,
                                                                                    top5_error.avg))
        return top1_error.avg, top1_loss.avg, top5_error.avg
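utils.compute_singlecrop_error is also defined elsewhere. For the single-output case used here, the following hypothetical sketch matches how its return values (top-1 error in percent, loss value, top-5 error in percent) are consumed above:

    import torch

    def compute_singlecrop_error(outputs, labels, loss, top5_flag=False):
        # hypothetical reimplementation inferred from the call sites
        batch_size = labels.size(0)
        _, pred = outputs.topk(5, dim=1)          # [batch, 5] predicted classes
        correct = pred.t().eq(labels.view(1, -1).expand_as(pred.t()))
        top1 = correct[:1].reshape(-1).float().sum().item()
        top5 = correct[:5].reshape(-1).float().sum().item()
        top1_error = 100.0 * (1.0 - top1 / batch_size)
        top5_error = 100.0 * (1.0 - top5 / batch_size)
        if top5_flag:
            return top1_error, loss.item(), top5_error
        return top1_error, loss.item()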
Example #3
    def val(self, epoch):
        """
        Validation
        :param epoch: index of epoch
        """

        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()

        self.model.eval()

        iters = len(self.val_loader)
        start_time = time.time()
        end_time = start_time

        with torch.no_grad():
            for i, (images, labels) in enumerate(self.val_loader):
                start_time = time.time()
                data_time = start_time - end_time

                if self.settings.n_gpus == 1:
                    images = images.cuda()
                labels = labels.cuda()

                output, loss = self.forward(images, labels)

                # compute loss and error rate
                single_error, single_loss, single5_error = utils.compute_singlecrop_error(
                    outputs=output, labels=labels,
                    loss=loss, top5_flag=True)

                top1_error.update(single_error, images.size(0))
                top1_loss.update(single_loss, images.size(0))
                top5_error.update(single5_error, images.size(0))

                end_time = time.time()
                iter_time = end_time - start_time

                if i % self.settings.print_frequency == 0:
                    utils.print_result(epoch, self.settings.n_epochs, i + 1,
                                       iters, self.lr, data_time, iter_time,
                                       single_error,
                                       single_loss,
                                       mode="Validation",
                                       logger=self.logger)

        if self.tensorboard_logger is not None:
            self.tensorboard_logger.scalar_summary("val_top1_error", top1_error.avg, self.run_count)
            self.tensorboard_logger.scalar_summary("val_top5_error", top5_error.avg, self.run_count)
            self.tensorboard_logger.scalar_summary("val_loss", top1_loss.avg, self.run_count)

        self.run_count += 1
        self.logger.info("|===>Testing Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}"
                         .format(top1_error.avg, top1_loss.avg, top5_error.avg))
        return top1_error.avg, top1_loss.avg, top5_error.avg
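The tensorboard_logger.scalar_summary(tag, value, step) calls map directly onto TensorBoard scalar logging. The logger class itself is not shown; a thin wrapper with that interface could be built on torch.utils.tensorboard (an assumption, not the repository's actual class):

    from torch.utils.tensorboard import SummaryWriter

    class TensorboardLogger(object):
        # hypothetical wrapper matching the scalar_summary(tag, value, step) call sites
        def __init__(self, log_dir):
            self.writer = SummaryWriter(log_dir=log_dir)

        def scalar_summary(self, tag, value, step):
            self.writer.add_scalar(tag, value, global_step=step)
            self.writer.flush()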
Example #4
    def __init__(self, trainer, train_loader, val_loader, settings, checkpoint,
                 logger, tensorboard_logger):
        self.segment_wise_trainer = trainer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.settings = settings
        self.checkpoint = checkpoint

        self.logger = logger
        self.tensorboard_logger = tensorboard_logger

        self.feature_cache_original_input = {}
        self.feature_cache_original_output = {}
        self.feature_cache_pruned_input = {}
        self.feature_cache_pruned_output = {}

        self.criterion_mse = nn.MSELoss().cuda()
        self.criterion_softmax = nn.CrossEntropyLoss().cuda()

        self.logger_counter = 0

        self.record_time = utils.AverageMeter()
        self.record_selection_mse_loss = utils.AverageMeter()
        self.record_selection_softmax_loss = utils.AverageMeter()
        self.record_selection_loss = utils.AverageMeter()
        self.record_sub_problem_softmax_loss = utils.AverageMeter()
        self.record_sub_problem_mse_loss = utils.AverageMeter()
        self.record_sub_problem_loss = utils.AverageMeter()
        self.record_sub_problem_top1_error = utils.AverageMeter()
        self.record_sub_problem_top5_error = utils.AverageMeter()
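The two criteria created here feed the joint objective used during channel selection (see _layer_channel_selection in Example #7): a feature-reconstruction MSE term plus a classification cross-entropy term. A self-contained illustration of that combination; the tensor shapes and the two weights are invented for the example:

    import torch
    import torch.nn as nn

    criterion_mse = nn.MSELoss()
    criterion_softmax = nn.CrossEntropyLoss()

    pruned_feature = torch.randn(8, 256, 14, 14)    # features from the pruned network
    original_feature = torch.randn(8, 256, 14, 14)  # features from the original network
    logits = torch.randn(8, 10)                     # auxiliary classifier output
    labels = torch.randint(0, 10, (8,))

    mse_weight, softmax_weight = 1.0, 1.0           # hypothetical settings values
    loss = (mse_weight * criterion_mse(pruned_feature, original_feature)
            + softmax_weight * criterion_softmax(logits, labels))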
Example #5
    def __init__(self, trainer, train_loader, val_loader, settings, checkpoint, logger, tensorboard_logger):
        self.segment_wise_trainer = trainer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.settings = settings

        # liuxu adds: drop the pruning-rate entries that belong to the shortcut layers so the
        # values are easy to index. The list length becomes 20; since this code only prunes
        # conv2, the pruning rate of each layer is pruning_rate[num_block * 2].
        self.settings.pruning_rate = (
            self.settings.pruning_rate[:7]
            + self.settings.pruning_rate[8:14]
            + self.settings.pruning_rate[15:]
            if len(self.settings.pruning_rate) > 1
            else self.settings.pruning_rate[0]
        )

        self.checkpoint = checkpoint

        self.logger = logger
        self.tensorboard_logger = tensorboard_logger

        self.feature_cache_original_input = {}
        self.feature_cache_original_output = {}
        self.feature_cache_pruned_input = {}
        self.feature_cache_pruned_output = {}

        self.criterion_mse = nn.MSELoss().cuda()
        self.criterion_softmax = nn.CrossEntropyLoss().cuda()

        self.logger_counter = 0

        self.record_time = utils.AverageMeter()
        self.record_selection_mse_loss = utils.AverageMeter()
        self.record_selection_softmax_loss = utils.AverageMeter()
        self.record_selection_loss = utils.AverageMeter()
        self.record_sub_problem_softmax_loss = utils.AverageMeter()
        self.record_sub_problem_mse_loss = utils.AverageMeter()
        self.record_sub_problem_loss = utils.AverageMeter()
        self.record_sub_problem_top1_error = utils.AverageMeter()
        self.record_sub_problem_top5_error = utils.AverageMeter()
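The slicing above removes the entries at indices 7 and 14 (the shortcut layers), which is what makes pruning_rate[num_block * 2] valid afterwards. A standalone check of the index arithmetic, assuming an illustrative 22-entry list:

    rates = list(range(22))                       # stand-in for a 22-entry pruning-rate list
    rates = rates[:7] + rates[8:14] + rates[15:]  # drop the shortcut entries at indices 7 and 14
    assert len(rates) == 20
    assert 7 not in rates and 14 not in rates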
Example #6
    def val(self, epoch):
        """
        Validation
        :param epoch: index of epoch
        """

        top1_error = []
        top5_error = []
        top1_loss = []
        num_segments = len(self.segments)
        for i in range(num_segments):
            self.segments[i].eval()
            self.aux_fc[i].eval()
            top1_error.append(utils.AverageMeter())
            top5_error.append(utils.AverageMeter())
            top1_loss.append(utils.AverageMeter())

        iters = len(self.val_loader)

        start_time = time.time()
        end_time = start_time

        with torch.no_grad():
            for i, (images, labels, _) in enumerate(self.val_loader, start=1):
                start_time = time.time()
                data_time = start_time - end_time

                if self.settings.n_gpus == 1:
                    images = images.cuda()
                labels = labels.cuda()

                outputs, losses = self.auxnet_forward(images, labels)

                # compute loss and error rate
                single_error, single_loss, single5_error = utils.compute_singlecrop_error(
                    outputs=outputs, labels=labels,
                    loss=losses, top5_flag=True)

                for j in range(num_segments):
                    top1_error[j].update(single_error[j], images.size(0))
                    top5_error[j].update(single5_error[j], images.size(0))
                    top1_loss[j].update(single_loss[j], images.size(0))

                end_time = time.time()
                iter_time = end_time - start_time

                if i % self.settings.print_frequency == 0:
                    # i is already 1-based here (enumerate starts at 1)
                    utils.print_result(epoch, self.settings.n_epochs, i,
                                       iters, self.lr, data_time, iter_time,
                                       single_error,
                                       single_loss,
                                       mode="Validation",
                                       logger=self.logger)

        top1_error_list, top1_loss_list, top5_error_list = self._convert_results(
            top1_error=top1_error, top1_loss=top1_loss, top5_error=top5_error)

        if self.tensorboard_logger is not None:
            for i in range(num_segments):
                self.tensorboard_logger.scalar_summary(
                    "auxnet_val_top1_error_{:d}".format(i), top1_error[i].avg, self.run_count)
                self.tensorboard_logger.scalar_summary(
                    "auxnet_val_top5_error_{:d}".format(i), top5_error[i].avg, self.run_count)
                self.tensorboard_logger.scalar_summary(
                    "auxnet_val_loss_{:d}".format(i), top1_loss[i].avg, self.run_count)
        self.run_count += 1

        self.logger.info("|===>Validation Error: {:4f}/{:4f}, Loss: {:4f}".format(
            top1_error[-1].avg, top5_error[-1].avg, top1_loss[-1].avg))
        return top1_error_list, top1_loss_list, top5_error_list
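self._convert_results is not shown on this page. Given that it receives per-segment lists of AverageMeters and the method returns plain *_list values, a plausible sketch (an assumption) is simply collecting the averages:

    def _convert_results(self, top1_error, top1_loss, top5_error):
        # hypothetical: collapse per-segment AverageMeters into lists of averages
        top1_error_list = [meter.avg for meter in top1_error]
        top1_loss_list = [meter.avg for meter in top1_loss]
        top5_error_list = [meter.avg for meter in top5_error]
        return top1_error_list, top1_loss_list, top5_error_list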
Example #7
    def train(self, epoch):
        """
        Train one epoch for auxnet
        :param epoch: index of epoch
        """

        iters = len(self.train_loader)
        self.update_lr(epoch)

        top1_error = []
        top5_error = []
        top1_loss = []
        num_segments = len(self.segments)
        for i in range(num_segments):
            self.segments[i].train()  # each segment together with its aux fc is treated as one model
            self.aux_fc[i].train()
            top1_error.append(utils.AverageMeter())
            top5_error.append(utils.AverageMeter())
            top1_loss.append(utils.AverageMeter())

        start_time = time.time()
        end_time = start_time

        for i, (images, labels, _) in enumerate(self.train_loader, start=1):
            start_time = time.time()
            data_time = start_time - end_time

            if self.settings.n_gpus == 1:
                images = images.cuda()
            labels = labels.cuda()

            # forward
            outputs, losses = self.auxnet_forward(images, labels)
            # backward
            for j in range(len(self.seg_optimizer)):
                self.auxnet_backward_for_loss_i(losses[j], j)

            # compute loss and error rate
            single_error, single_loss, single5_error = utils.compute_singlecrop_error(
                outputs=outputs, labels=labels,
                loss=losses, top5_flag=True)

            for j in range(num_segments):
                top1_error[j].update(single_error[j], images.size(0))
                top5_error[j].update(single5_error[j], images.size(0))
                top1_loss[j].update(single_loss[j], images.size(0))

            end_time = time.time()
            iter_time = end_time - start_time

            if i % self.settings.print_frequency == 0:
                # i is already 1-based here (enumerate starts at 1)
                utils.print_result(epoch, self.settings.n_epochs, i,
                                   iters, self.lr, data_time, iter_time,
                                   single_error,
                                   single_loss,
                                   mode="Train",
                                   logger=self.logger)

        top1_error_list, top1_loss_list, top5_error_list = self._convert_results(
            top1_error=top1_error, top1_loss=top1_loss, top5_error=top5_error)
        if self.tensorboard_logger is not None:
            for i in range(num_segments):
                self.tensorboard_logger.scalar_summary(
                    "auxnet_train_top1_error_{:d}".format(i), top1_error[i].avg,
                    self.run_count)
                self.tensorboard_logger.scalar_summary(
                    "auxnet_train_top5_error_{:d}".format(i), top5_error[i].avg,
                    self.run_count)
                self.tensorboard_logger.scalar_summary(
                    "auxnet_train_loss_{:d}".format(i), top1_loss[i].avg, self.run_count)
            self.tensorboard_logger.scalar_summary("lr", self.lr, self.run_count)

        self.logger.info("|===>Training Error: {:4f}/{:4f}, Loss: {:4f}".format(
            top1_error[-1].avg, top5_error[-1].avg, top1_loss[-1].avg))
        return top1_error_list, top1_loss_list, top5_error_list
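self.auxnet_backward_for_loss_i(losses[j], j) implies one optimizer per segment in self.seg_optimizer. A minimal sketch of such a step; the retain_graph detail is an assumption, needed because all segment losses share one forward graph:

    def auxnet_backward_for_loss_i(self, loss, i):
        # hypothetical: each segment/aux-fc pair owns one optimizer
        self.seg_optimizer[i].zero_grad()
        loss.backward(retain_graph=True)  # later segments reuse the same forward pass
        self.seg_optimizer[i].step()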
    def train(self, epoch):
        """
        Train one epoch of network-wise fine-tuning
        :param epoch: index of epoch
        """

        top1_error = utils.AverageMeter()
        top1_loss = utils.AverageMeter()
        top5_error = utils.AverageMeter()

        iters = len(self.train_loader)
        self.update_lr(epoch)
        # switch to train mode
        self.pruned_model.train()

        start_time = time.time()
        end_time = start_time

        for i, (images, labels) in enumerate(self.train_loader):
            start_time = time.time()
            data_time = start_time - end_time

            if self.settings.n_gpus == 1:
                images = images.cuda()
            labels = labels.cuda()

            output, loss = self.forward(images, labels)
            self.backward(loss)

            single_error, single_loss, single5_error = utils.compute_singlecrop(
                outputs=output, labels=labels,
                loss=loss, top5_flag=True, mean_flag=True)

            top1_error.update(single_error, images.size(0))
            top1_loss.update(single_loss, images.size(0))
            top5_error.update(single5_error, images.size(0))

            end_time = time.time()
            iter_time = end_time - start_time

            utils.print_result(
                epoch, self.settings.network_wise_n_epochs, i + 1,
                iters, self.network_wise_lr, data_time, iter_time,
                single_error,
                single_loss, top5error=single5_error,
                mode="Train",
                logger=self.logger)

        self.scalar_info['network_wise_fine_tune_train_top1_error'] = top1_error.avg
        self.scalar_info['network_wise_fine_tune_train_top5_error'] = top5_error.avg
        self.scalar_info['network_wise_fine_tune_train_loss'] = top1_loss.avg
        self.scalar_info['network_wise_fine_tune_lr'] = self.network_wise_lr

        if self.tensorboard_logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}

        self.logger.info(
            "|===>Training Error: {:.4f} Loss: {:.4f}, Top5 Error: {:.4f}".format(top1_error.avg, top1_loss.avg,
                                                                                  top5_error.avg))
        return top1_error.avg, top1_loss.avg, top5_error.avg
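self.update_lr(epoch) is called at the top of every training epoch but defined elsewhere. A common multi-step schedule consistent with the self.lr attribute read by the logging code; the milestone list self.settings.step and the 0.1 factor are assumptions:

    def update_lr(self, epoch):
        # hypothetical multi-step decay: divide the base LR by 10 at each milestone
        gamma = sum(1 for step in self.settings.step if epoch + 1 > int(step))
        self.lr = self.settings.lr * (0.1 ** gamma)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr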
    def train(self, epoch, index):
        """
        Train one epoch of segment-wise fine-tuning
        :param epoch: index of epoch
        :param index: index of segment
        """

        iters = len(self.train_loader)
        self.update_lr(epoch)

        top1_error = []
        top5_error = []
        top1_loss = []
        num_segments = len(self.pruned_segments)
        for i in range(num_segments):
            self.pruned_segments[i].train()
            if i != index and i != num_segments - 1:
                self.aux_fc[i].eval()
            else:
                self.aux_fc[i].train()
            top1_error.append(utils.AverageMeter())
            top5_error.append(utils.AverageMeter())
            top1_loss.append(utils.AverageMeter())

        start_time = time.time()
        end_time = start_time
        for i, (images, labels) in enumerate(self.train_loader):
            start_time = time.time()
            data_time = start_time - end_time

            if self.settings.n_gpus == 1:
                images = images.cuda()
            labels = labels.cuda()

            # forward
            outputs, losses = self.forward(images, labels)
            # backward
            self.backward(losses, index)

            # compute loss and error rate
            single_error, single_loss, single5_error = utils.compute_singlecrop(
                outputs=outputs, labels=labels,
                loss=losses, top5_flag=True, mean_flag=True)

            for j in range(num_segments):
                top1_error[j].update(single_error[j], images.size(0))
                top5_error[j].update(single5_error[j], images.size(0))
                top1_loss[j].update(single_loss[j], images.size(0))

            end_time = time.time()
            iter_time = end_time - start_time

            utils.print_result(epoch, self.settings.segment_wise_n_epochs, i + 1,
                               iters, self.segment_wise_lr, data_time, iter_time,
                               single_error,
                               single_loss,
                               mode="Train",
                               logger=self.logger)

        top1_error_list, top1_loss_list, top5_error_list = self._convert_results(
            top1_error=top1_error, top1_loss=top1_loss, top5_error=top5_error)
        if self.tensorboard_logger is not None:
            for i in range(num_segments):
                self.tensorboard_logger.scalar_summary(
                    "segment_wise_fine_tune_train_top1_error_{:d}".format(i), top1_error[i].avg,
                    self.run_count)
                self.tensorboard_logger.scalar_summary(
                    "segment_wise_fine_tune_train_top5_error_{:d}".format(i), top5_error[i].avg,
                    self.run_count)
                self.tensorboard_logger.scalar_summary(
                    "segment_wise_fine_tune_train_loss_{:d}".format(i), top1_loss[i].avg, self.run_count)
            self.tensorboard_logger.scalar_summary("segment_wise_fine_tune_lr", self.segment_wise_lr, self.run_count)

        self.logger.info("|===>Training Error: {:4f}/{:4f}, Loss: {:4f}".format(
            top1_error[-1].avg, top5_error[-1].avg, top1_loss[-1].avg))
        return top1_error_list, top1_loss_list, top5_error_list
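self.backward(losses, index) pairs with the eval/train switching above: only segment index (and the final classifier) should receive updates. A minimal sketch under that assumption:

    def backward(self, losses, index):
        # hypothetical: step only the optimizer of the segment being fine-tuned
        self.seg_optimizer[index].zero_grad()
        losses[index].backward()
        self.seg_optimizer[index].step()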
    def _layer_channel_selection(self,
                                 net_origin,
                                 net_pruned,
                                 aux_fc,
                                 module,
                                 block_count,
                                 layer_name="conv2"):
        """
        Conduct channel selection for a module
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param module: the module to be pruned
        :param block_count: index of the current block
        :param layer_name: the name of the layer to be pruned
        """

        self.logger.info(
            "|===>layer-wise channel selection: block-{}-{}".format(
                block_count, layer_name))
        # layer-wise channel selection
        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        else:
            assert False, "unsupported layer: {}".format(layer_name)

        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)

            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)

            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            layer = temp_conv

        # define criterion
        criterion_mse = nn.MSELoss().cuda()
        criterion_softmax = nn.CrossEntropyLoss().cuda()

        # register hook
        if layer_name == "conv2":
            hook_origin = net_origin[block_count].conv2.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv2.register_forward_hook(
                self._hook_pruned_feature)
        elif layer_name == "conv3":
            hook_origin = net_origin[block_count].conv3.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv3.register_forward_hook(
                self._hook_pruned_feature)

        net_origin_parallel = utils.data_parallel(net_origin,
                                                  self.settings.n_gpus)
        net_pruned_parallel = utils.data_parallel(net_pruned,
                                                  self.settings.n_gpus)

        # avoid computing the gradient
        for params in net_origin_parallel.parameters():
            params.requires_grad = False
        for params in net_pruned_parallel.parameters():
            params.requires_grad = False

        net_origin_parallel.eval()
        net_pruned_parallel.eval()

        layer.pruned_weight.requires_grad = True
        aux_fc.cuda()
        logger_counter = 0
        record_time = utils.AverageMeter()

        for channel in range(layer.in_channels):
            if layer.d.eq(0).sum() <= math.floor(
                    layer.in_channels * self.settings.pruning_rate):
                break

            time_start = time.time()
            cum_grad = None
            record_selection_mse_loss = utils.AverageMeter()
            record_selection_softmax_loss = utils.AverageMeter()
            record_selection_loss = utils.AverageMeter()
            img_count = 0
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                net_origin_parallel(images)
                output = net_pruned_parallel(images)
                softmax_loss = criterion_softmax(aux_fc(output), labels)

                origin_feature = self._concat_gpu_data(
                    self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(
                    self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)

                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                record_selection_loss.update(loss.item(), images.size(0))
                record_selection_mse_loss.update(mse_loss.item(),
                                                 images.size(0))
                record_selection_softmax_loss.update(softmax_loss.item(),
                                                     images.size(0))

                if cum_grad is None:
                    cum_grad = layer.pruned_weight.grad.data.clone()
                else:
                    cum_grad.add_(layer.pruned_weight.grad.data)
                # reset the gradient after every batch so the first batch is not double-counted
                layer.pruned_weight.grad = None

                img_count += images.size(0)
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

            # write tensorboard log
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
                value=record_selection_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_selection_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_selection_softmax_loss.avg,
                step=logger_counter)
            cum_grad.abs_()
            # calculate gradient F norm
            grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

            # pick the not-yet-selected channel with the largest gradient norm
            while True:
                _, max_index = torch.topk(grad_fnorm, 1)
                if layer.d[max_index[0]] == 0:
                    layer.d[max_index[0]] = 1
                    # restore the original weights of the newly selected channel
                    layer.pruned_weight.data[:, max_index[0], :, :] = \
                        layer.weight.data[:, max_index[0], :, :].clone()
                    break
                else:
                    grad_fnorm[max_index[0]] = -1

            # fine-tune average meter
            record_finetune_softmax_loss = utils.AverageMeter()
            record_finetune_mse_loss = utils.AverageMeter()
            record_finetune_loss = utils.AverageMeter()

            record_finetune_top1_error = utils.AverageMeter()
            record_finetune_top5_error = utils.AverageMeter()

            # define optimizer
            params_list = []
            params_list.append({
                "params": layer.pruned_weight,
                "lr": self.settings.layer_wise_lr
            })
            if layer.bias is not None:
                layer.bias.requires_grad = True
                params_list.append({"params": layer.bias, "lr": 0.001})
            optimizer = torch.optim.SGD(
                params=params_list,
                weight_decay=self.settings.weight_decay,
                momentum=self.settings.momentum,
                nesterov=True)
            img_count = 0
            for epoch in range(1):
                for i, (images, labels) in enumerate(self.train_loader):
                    images = images.cuda()
                    labels = labels.cuda()
                    features = net_pruned_parallel(images)
                    net_origin_parallel(images)
                    output = aux_fc(features)
                    softmax_loss = criterion_softmax(output, labels)

                    origin_feature = self._concat_gpu_data(
                        self.feature_cache_origin)
                    self.feature_cache_origin = {}
                    pruned_feature = self._concat_gpu_data(
                        self.feature_cache_pruned)
                    self.feature_cache_pruned = {}
                    mse_loss = criterion_mse(pruned_feature, origin_feature)

                    top1_error, _, top5_error = utils.compute_singlecrop(
                        outputs=output,
                        labels=labels,
                        loss=softmax_loss,
                        top5_flag=True,
                        mean_flag=True)

                    # update parameters
                    optimizer.zero_grad()
                    loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(layer.parameters(),
                                                   max_norm=10.0)
                    layer.pruned_weight.grad.data.mul_(
                        layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(
                            3).expand_as(layer.pruned_weight))
                    optimizer.step()
                    # update record info
                    record_finetune_softmax_loss.update(
                        softmax_loss.item(), images.size(0))
                    record_finetune_mse_loss.update(mse_loss.item(),
                                                    images.size(0))
                    record_finetune_loss.update(loss.item(), images.size(0))
                    record_finetune_top1_error.update(top1_error,
                                                      images.size(0))
                    record_finetune_top5_error.update(top5_error,
                                                      images.size(0))

                    img_count += images.size(0)
                    if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                        break

            layer.pruned_weight.grad = None
            if layer.bias is not None:
                layer.bias.requires_grad = False

            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_finetune_softmax_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Loss".format(block_count, layer_name),
                value=record_finetune_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_finetune_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
                value=record_finetune_top1_error.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
                value=record_finetune_top5_error.avg,
                step=logger_counter)

            # write log information to file
            self._write_log(
                dir_name=os.path.join(self.settings.save_path, "log"),
                file_name="log_block-{:0>2d}_{}.txt".format(
                    block_count, layer_name),
                log_str="{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".format(
                    int(layer.d.sum()), record_selection_loss.avg,
                    record_selection_mse_loss.avg,
                    record_selection_softmax_loss.avg,
                    record_finetune_loss.avg, record_finetune_mse_loss.avg,
                    record_finetune_softmax_loss.avg,
                    record_finetune_top1_error.avg,
                    record_finetune_top5_error.avg))
            log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
                block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
            log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_selection_loss.avg, record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg)
            log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_finetune_loss.avg, record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg)
            log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
                record_finetune_top1_error.avg, record_finetune_top5_error.avg)
            self.logger.info(log_str)

            logger_counter += 1
            time_interval = time.time() - time_start
            record_time.update(time_interval)

        for params in net_origin_parallel.parameters():
            params.requires_grad = True
        for params in net_pruned_parallel.parameters():
            params.requires_grad = True

        # remove hook
        hook_origin.remove()
        hook_pruned.remove()
        log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
            block_count, layer_name,
            str(datetime.timedelta(seconds=record_time.sum)),
            str(datetime.timedelta(seconds=record_time.avg)))
        self.logger.info(log_str)
        log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
            record_finetune_loss.avg, record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg, record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        self.logger.info("|===>remove hook")