예제 #1
0
class ExperimentDesign:
    """Experiment pipeline for sphere-face models: train, prune, fine-tune.

    The constructor wires up every component (data loader, model,
    checkpoint, LR policy, trainer) via :meth:`prepare`; :meth:`run` then
    drives one train / fine-tune / prune round per ``run_count``.
    """

    def __init__(self, options=None):
        # Experiment configuration; fall back to the default Option().
        self.settings = options or Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0

        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)
        self.test_input = None
        self.lr_master = None
        self.prepare()

    def prepare(self):
        """Run the full setup sequence.

        Order matters: the model must exist before the checkpoint is
        loaded, and parallelization wraps the (possibly restored) model.
        """
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_parallel()
        self._set_lr_policy()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        """Seed CPU/GPU RNGs and select the configured GPU device."""
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count(
        ) - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        print("|===>Set GPU done!")

    def _set_dataloader(self):
        """Create the project DataLoader wrapper from the settings."""
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)
        print("|===>Set data loader done!")

    def _set_checkpoint(self):
        """Create the checkpoint manager and optionally restore weights.

        ``retrain`` restores model weights only; ``resume`` loads a full
        checkpoint but (see note below) only its model weights are used.
        """
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            # NOTE(review): epoch and optimizer state are deliberately left
            # unrestored here, so a resumed run restarts at epoch 0 with a
            # fresh optimizer — confirm this is intended.
            # self.start_epoch = epoch
            # self.optimizer_state = optimizer_state
        print("|===>Set checkpoint done!")

    def _set_model(self):
        """Instantiate the network selected by ``settings.netType``.

        Only the "sphere" dataset is supported; any other combination of
        dataset/netType aborts with an assertion.
        """
        if self.settings.dataset == "sphere":
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "wcSphereNet":
                self.model = md.wcSphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim,
                    rate=self.settings.rate)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupport data set: " + self.settings.dataset
        print("|===>Set model done!")

    def _set_parallel(self):
        """Wrap the model for multi-GPU execution per nGPU/GPU settings."""
        self.model = utils.data_parallel(self.model, self.settings.nGPU,
                                         self.settings.GPU)

    def _set_lr_policy(self):
        """(Re)build the learning-rate schedule from the current settings.

        Called again by :meth:`fine_tuning` / :meth:`retrain` after they
        mutate the schedule-related settings.
        """
        self.lr_master = utils.LRPolicy(self.settings.lr, self.settings.nIters,
                                        self.settings.lrPolicy)
        params_dict = {
            'gamma': self.settings.gamma,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        self.lr_master.set_params(params_dict=params_dict)

    def _set_trainer(self):
        """Create the Trainer around the model, loaders and LR policy."""
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               lr_master=self.lr_master,
                               n_epochs=self.settings.nEpochs,
                               n_iters=self.settings.nIters,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               feature_dim=self.settings.featureDim,
                               momentum=self.settings.momentum,
                               weight_decay=self.settings.weightDecay,
                               optimizer_state=self.optimizer_state,
                               logger=self.logger)

    def _draw_net(self):
        """Cache a random test input and optionally draw the network graph."""
        if self.settings.dataset == "sphere":
            # 1x3x112x96 matches the sphere-face input used elsewhere here.
            rand_input = torch.randn(1, 3, 112, 96)
        else:
            assert False, "invalid data set"
        rand_input = Variable(rand_input.cuda())
        self.test_input = rand_input

        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(rand_input)
            self.visualize.save_network(rand_output)
            print("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def pruning(self, run_count=0):
        """Prune the current model, reset the trainer and save the result.

        :param run_count: round index used to tag the saved pruned model.
        """
        net_type = None
        if self.settings.dataset == "sphere":
            if self.settings.netType == "wcSphereNet":
                net_type = "SphereNet"

        assert net_type is not None, "net_type for prune is NoneType"

        # Measure accuracy before pruning.
        self.trainer.test()

        # Unwrap DataParallel so the pruner sees the raw module.
        if isinstance(self.model, nn.DataParallel):
            model = self.model.module
        else:
            model = self.model

        if net_type == "SphereNet":
            model_prune = prune.SpherePrune(model)
        model_prune.run()
        self.trainer.reset_model(model_prune.model)
        self.model = self.trainer.model

        # Measure accuracy after pruning and persist the pruned model.
        self.trainer.test()
        self.checkpoint.save_model(self.trainer.model,
                                   index=run_count,
                                   tag="pruning")

        # Analyse parameter count / FLOPs of the pruned model.
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.model_analyse.params_count()
        self.model_analyse.flops_compute(self.test_input)

    def fine_tuning(self, run_count=0):
        """Run a short fine-tuning schedule (applied after pruning)."""
        # Shorter/milder schedule than full training; values in the
        # trailing comments are the full-training counterparts.
        self.settings.lr = 0.01  # 0.1
        self.settings.nIters = 12000  # 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [6000]  # [16000, 24000]

        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

        # run fine-tuning
        self.training(run_count, tag="fine-tuning")

    def retrain(self, run_count=0):
        """Run the full training schedule."""
        self.settings.lr = 0.1
        self.settings.nIters = 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [16000, 24000]
        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

        # run retrain
        self.training(run_count, tag="training")

    def run(self, run_count=0):
        """One experiment round.

        From round 1 on: full training + fine-tuning, then pruning; the
        classifier layers are frozen and plain ``Conv2d`` layers replaced
        by prunable ``wcConv2d`` layers for the next round.
        """
        if run_count >= 1:
            print("|===> training")
            self.retrain(run_count)

            self.trainer.reset_model(self.model)
            print("|===> fine-tuning")
            self.fine_tuning(run_count)

        print("|===> pruning")
        self.pruning(run_count)

        # Keep margin_linear static: freeze the classification layers so
        # only the convolutional backbone trains in the next round.
        layer_count = 0
        for layer in self.model.modules():
            if isinstance(layer, md.MarginLinear):
                layer.iteration.fill_(0)
                layer.margin_type.data.fill_(1)
                layer.weight.requires_grad = False

            elif isinstance(layer, nn.Linear):
                layer.weight.requires_grad = False
                layer.bias.requires_grad = False

            elif isinstance(layer, nn.Conv2d):
                # Replace the conv with a prunable wcConv2d of identical
                # geometry, copying weights (and bias, when present).
                bias_flag = layer.bias is not None
                new_layer = prune.wcConv2d(layer.weight.size(1),
                                           layer.weight.size(0),
                                           kernel_size=layer.kernel_size,
                                           stride=layer.stride,
                                           padding=layer.padding,
                                           bias=bias_flag,
                                           rate=self.settings.rate)
                new_layer.weight.data.copy_(layer.weight.data)
                if layer.bias is not None:
                    new_layer.bias.data.copy_(layer.bias.data)
                # NOTE(review): assumes the model exposes conv2..conv4
                # directly (i.e. is not wrapped in DataParallel here) and
                # has at most four convs — confirm.
                if layer_count == 1:
                    self.model.conv2 = new_layer
                elif layer_count == 2:
                    self.model.conv3 = new_layer
                elif layer_count == 3:
                    self.model.conv4 = new_layer
                layer_count += 1

        print(self.model)
        self.trainer.reset_model(self.model)

    def training(self, run_count=0, tag="training"):
        """Train the current model and return the best top-1 error.

        Logs per-epoch results, saves a checkpoint every epoch and the
        best-performing model whenever it improves.

        :param run_count: experiment round, used for tagging/logging.
        :param tag: label under which best models are saved.
        :return: best top-1 error (percent) observed during training.
        """
        best_top1 = 100
        self.trainer.test()
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.trainer.n_iters:
                break
            # BUGFIX: was a dead local `start_epoch = 0`; the attribute must
            # be reset so a resume offset applies only to the first pass
            # (matches the sibling implementations in this file).
            self.start_epoch = 0
            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()

            test_error_mean = 100 - acc_mean * 100
            # write and print result
            log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc

            self.visualize.write_log(log_str)
            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f\n" %
                        (run_count, best_top1), "red"))
            else:
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f\n" %
                        (run_count, best_top1), "blue"))

            self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer,
                                            epoch)

            if best_flag:
                self.checkpoint.save_model(self.model,
                                           best_flag=best_flag,
                                           tag="%s_%d" % (tag, run_count))

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

            # Histogram summaries of weights (and grads, when available).
            for name, value in self.model.named_parameters():
                if 'weight' in name:
                    name = name.replace('.', '/')
                    self.logger.histo_summary(
                        name,
                        value.data.cpu().numpy(),
                        run_count * self.settings.nEpochs + epoch + 1)
                    if value.grad is not None:
                        self.logger.histo_summary(
                            name + "/grad",
                            value.grad.data.cpu().numpy(),
                            run_count * self.settings.nEpochs + epoch + 1)

        # Unwrap DataParallel before analysis/serialization.
        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()

        # save experimental results
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.visualize.write_readme("Best Result of all is: Top1 Error: %f\n" %
                                    best_top1)

        # analyse model
        params_num = self.model_analyse.params_count()

        # save analyse result to file
        self.visualize.write_readme("Number of parameters is: %d" %
                                    (params_num))
        self.model_analyse.prune_rate()

        self.model_analyse.flops_compute(self.test_input)

        return best_top1
예제 #2
0
    # Decay the learning rate by 10x at each milestone epoch in config.step.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=config.step,
                                                     gamma=0.1)

    # Set trainer
    logger = Logger(config.save_path)
    trainer = AlexNetTrainer(config.lr, train_loader, valid_loader, model,
                             optimizer, scheduler, logger, device)

    print(model)

    time_start = time.time()

    # Train for nEpochs, saving a model snapshot after every epoch.
    for epoch in range(1, config.nEpochs + 1):
        cls_loss_, accuracy, accuracy_valid = trainer.train(epoch)
        checkpoint.save_model(model, index=epoch)

    time_end = time.time()
    # Total wall-clock training time in seconds.
    print(time_end - time_start)

    # # Count the number of samples
    # imagedb.count_num_sample()
    #
    # # Record the character set
    # imagedb.write_char_set()
    #
    # # Randomly select a subset of characters for the model to use
    # imagedb.random_select()
    #
    # # Load the data
    # train_dataset, valid_dataset = imagedb.load_data()
예제 #3
0
# Select the compute device; let cuDNN autotune conv algorithms.
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# Set dataloader — worker/pinning options only make sense on GPU.
if use_cuda:
    kwargs = {'num_workers': config.nThreads, 'pin_memory': True}
else:
    kwargs = {}
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
train_loader = torch.utils.data.DataLoader(
    FaceDataset(config.annoPath, transform=transform, is_train=True),
    batch_size=config.batchSize,
    shuffle=True,
    **kwargs)

# Set model, moved onto the selected device.
model = ONet().to(device)

# Set checkpoint
checkpoint = CheckPoint(config.save_path)

# Set optimizer: Adam with step-wise LR decay (x0.1 at each milestone).
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=config.step, gamma=0.1)

# Set trainer
logger = Logger(config.save_path)
trainer = ONetTrainer(config.lr, train_loader, model, optimizer, scheduler,
                      logger, device)

# Train for nEpochs, snapshotting the model after every epoch.
for epoch in range(1, config.nEpochs + 1):
    trainer.train(epoch)
    checkpoint.save_model(model, index=epoch)
예제 #4
0
class ExperimentDesign(object):
    """
    run experiments with pre-defined pipeline
    """

    def __init__(self):
        self.settings = Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None

        self.trainer = None
        # Optimizer / auxiliary-classifier states restored on resume.
        self.seg_opt_state = None
        self.fc_opt_state = None
        self.aux_fc_state = None
        self.start_epoch = 0

        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def prepare(self):
        """
        preparing experiments
        """
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        """Seed CPU/GPU RNGs, select the GPU and enable cudnn autotuning."""
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count(
        ) - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """Create the project DataLoader wrapper from the settings."""
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)

    def _set_checkpoint(self):
        """Create the checkpoint manager and optionally restore state.

        ``retrain`` restores model weights only; ``resume`` restores model
        weights plus optimizer/auxiliary-classifier states.
        """
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            model_state = check_point_params["model"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.fc_opt_state = check_point_params["fc_opt"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.model = self.checkpoint.load_state(self.model, model_state)
            # NOTE(review): the resume epoch is hard-coded because the
            # checkpoint file does not store it — consider persisting the
            # epoch in _save_checkpoint instead.
            self.start_epoch = 90

    def _set_model(self):
        """Instantiate the CIFAR network selected by ``settings.netType``."""
        print("netType:", self.settings.netType)
        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.netType == "DARTSNet":
                genotype = md.genotypes.DARTS
                self.model = md.DARTSNet(self.settings.init_channels,
                                         self.settings.nClasses,
                                         self.settings.layers,
                                         self.settings.auxiliary, genotype)
            elif self.settings.netType == "PreResNet":
                self.model = md.PreResNet(depth=self.settings.depth,
                                          num_classes=self.settings.nClasses,
                                          wide_factor=self.settings.wideFactor)
            elif self.settings.netType == "CifarResNeXt":
                self.model = md.CifarResNeXt(self.settings.cardinality,
                                             self.settings.depth,
                                             self.settings.nClasses,
                                             self.settings.base_width,
                                             self.settings.widen_factor)

            elif self.settings.netType == "ResNet":
                self.model = md.ResNet(self.settings.depth,
                                       self.settings.nClasses)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupported data set: " + self.settings.dataset

    def _set_trainer(self):
        """Build the LR policy and the Trainer; restore optimizer and
        auxiliary-classifier states when resuming."""
        # set lr master
        lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        lr_master.set_params(params_dict=params_dict)
        # set trainer
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger)
        if self.seg_opt_state is not None:
            self.trainer.resume(aux_fc_state=self.aux_fc_state,
                                seg_opt_state=self.seg_opt_state,
                                fc_opt_state=self.fc_opt_state)

    def _save_best_model(self, model, aux_fc):
        """Save model + auxiliary-classifier weights as the best snapshot.

        DataParallel wrappers are unwrapped so loading is wrapper-agnostic.
        """
        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        aux_fc_state = []
        for i in range(len(aux_fc)):
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())
        check_point_params["aux_fc"] = aux_fc_state
        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "best_model_with_AuxFC.pth"))

    def _save_checkpoint(self,
                         model,
                         seg_optimizers,
                         fc_optimizers,
                         aux_fc,
                         index=0):
        """Save model, optimizer and auxiliary-classifier states.

        :param index: checkpoint file index (``checkpoint_%03d.pth``).
        """
        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            # Store the unwrapped module so loading is wrapper-agnostic.
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        seg_opt_state = []
        fc_opt_state = []
        aux_fc_state = []
        for i in range(len(seg_optimizers)):
            seg_opt_state.append(seg_optimizers[i].state_dict())
        # NOTE(review): assumes len(aux_fc) == len(fc_optimizers) — confirm.
        for i in range(len(fc_optimizers)):
            fc_opt_state.append(fc_optimizers[i].state_dict())
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())

        check_point_params["seg_opt"] = seg_opt_state
        check_point_params["fc_opt"] = fc_opt_state
        check_point_params["aux_fc"] = aux_fc_state

        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "checkpoint_%03d.pth" % index))

    def run(self, run_count=0):
        """
        training and testing
        """

        best_top1 = 100
        best_top5 = 100
        start_time = time.time()
        # Sanity-check restored weights before continuing training.
        if self.settings.resume is not None or self.settings.retrain is not None:
            self.trainer.test(0)
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            # The resume offset only applies to the first loop pass.
            self.start_epoch = 0

            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)

            # write and print result; array-valued metrics carry one column
            # set per auxiliary classifier, the final classifier last
            if isinstance(train_error, np.ndarray):
                log_str = "%d\t" % (epoch)
                for i in range(len(train_error)):
                    log_str += "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                        train_error[i],
                        train_loss[i],
                        test_error[i],
                        test_loss[i],
                        train5_error[i],
                        test5_error[i],
                    )

                best_flag = False
                if best_top1 >= test_error[-1]:
                    best_top1 = test_error[-1]
                    best_top5 = test5_error[-1]
                    best_flag = True

            else:
                log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                    epoch, train_error, train_loss, test_error, test_loss,
                    train5_error, test5_error)
                best_flag = False
                if best_top1 >= test_error:
                    best_top1 = test_error
                    best_top5 = test5_error
                    best_flag = True

            self.visualize.write_log(log_str)

            if best_flag:
                self.checkpoint.save_model(self.trainer.model,
                                           best_flag=best_flag)
                self._save_best_model(self.trainer.model, self.trainer.auxfc)

                # FIX: was Python 2 `print colored(...)` statement syntax,
                # a SyntaxError under Python 3 (the rest of this file uses
                # the print() function).
                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                    % (run_count, best_top1, best_top5), "red"))
            else:
                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                    % (run_count, best_top1, best_top5), "blue"))

            # NOTE(review): no index is passed, so this overwrites
            # checkpoint_000.pth every epoch (keeps only the latest) —
            # confirm this is intended.
            self._save_checkpoint(self.trainer.model,
                                  self.trainer.seg_optimizer,
                                  self.trainer.fc_optimizer,
                                  self.trainer.auxfc)

        end_time = time.time()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        print(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)

        # save experimental results
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f, Top5 Error: %f\n" %
            (best_top1, best_top5))
예제 #5
0
# Training data: annotated faces, shuffled each epoch.
face_dataset = FaceDataset(train_config.annoPath,
                           transform=transform,
                           is_train=True)
train_loader = torch.utils.data.DataLoader(face_dataset,
                                           batch_size=train_config.batchSize,
                                           shuffle=True,
                                           **kwargs)

# Set model on the selected device.
model = ONet(config.NUM_LANDMARKS).to(device)

# Set checkpoint
checkpoint = CheckPoint(train_config.save_path)

# Set optimizer: Adam with step-wise LR decay (x0.1 at each milestone).
optimizer = torch.optim.Adam(model.parameters(), lr=train_config.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=train_config.step,
                                                 gamma=0.1)

# Set trainer
logger = Logger(train_config.save_path)
trainer = ONetTrainer(train_config.lr, train_loader, model, optimizer,
                      scheduler, logger, device)

# Train for nEpochs, snapshotting the model after every epoch under a
# landmark-count-specific tag.
for epoch in range(1, train_config.nEpochs + 1):
    trainer.train(epoch)
    checkpoint.save_model(model,
                          index=epoch,
                          tag=str(config.NUM_LANDMARKS) + '_landmarks')
class ExperimentDesign:
    """Pre-defined experiment pipeline for sphere-face training.

    Sets up GPU, data, model, checkpointing, logging and the trainer in
    the constructor, then :meth:`run` drives the training loop.
    """

    def __init__(self, options=None):
        # Experiment configuration; falls back to the default Option().
        self.settings = options or Option()
        self.checkpoint = None
        self.train_loader = None
        self.test_loader = None
        self.model = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.test_input = None
        self.model_analyse = None

        # Restrict visible GPUs before any CUDA initialization happens.
        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices

        self.settings.set_save_path()
        self.logger = self.set_logger()
        self.settings.paramscheck(self.logger)
        self.visualize = vs.Visualization(self.settings.save_path, self.logger)
        self.tensorboard_logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def set_logger(self):
        """Create the 'sphereface' logger writing to both file and stdout."""
        logger = logging.getLogger('sphereface')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """Run the setup sequence; the model must exist before checkpoint."""
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        """Seed CPU/GPU RNGs, select the GPU and enable cudnn autotuning."""
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count(
        ) - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """Create train/test loaders via the project DataLoader wrapper."""
        # create data loader
        data_loader = DataLoader(dataset=self.settings.dataset,
                                 batch_size=self.settings.batchSize,
                                 data_path=self.settings.dataPath,
                                 n_threads=self.settings.nThreads,
                                 ten_crop=self.settings.tenCrop,
                                 logger=self.logger)
        self.train_loader, self.test_loader = data_loader.getloader()

    def _set_checkpoint(self):
        """Create the checkpoint manager; ``retrain`` restores weights only,
        ``resume`` also restores optimizer state and start epoch."""
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = epoch
            self.optimizer_state = optimizer_state

    def _set_model(self):
        """Instantiate the network selected by ``settings.netType`` and cache
        a random 1x3x112x96 CUDA input for visualization/analysis."""
        if self.settings.dataset in ["sphere", "sphere_large"]:
            self.test_input = Variable(torch.randn(1, 3, 112, 96).cuda())
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "SphereMobileNet_v2":
                self.model = md.SphereMobleNet_v2(
                    num_features=self.settings.featureDim)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupport data set: " + self.settings.dataset

    def _set_trainer(self):
        """Build the LR policy and create the Trainer."""
        lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        lr_master.set_params(params_dict=params_dict)

        self.trainer = Trainer(model=self.model,
                               train_loader=self.train_loader,
                               test_loader=self.test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger,
                               tensorboard_logger=self.tensorboard_logger,
                               optimizer_state=self.optimizer_state)

    def _draw_net(self):
        """Optionally draw the network graph, then persist settings."""
        # visualize model
        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(self.test_input)
            self.visualize.save_network(rand_output)
            self.logger.info("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def _model_analyse(self, model):
        """Report parameter count, zero rate and MAdds of *model*."""
        # analyse model
        model_analyse = utils.ModelAnalyse(model, self.visualize)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))

        # save analyse result to file
        self.visualize.write_readme(
            "Number of parameters is: %d, number of zeros is: %d, zero rate is: %f"
            % (params_num, zero_num, zero_rate))

        # model_analyse.flops_compute(self.test_input)
        model_analyse.madds_compute(self.test_input)

    def run(self, run_count=0):
        """Train until nEpochs or nIters, logging results and saving
        checkpoints along the way.

        :param run_count: experiment round index used in log messages.
        :return: best top-1 error (percent) observed during training.
        """
        self.logger.info("|===>Start training")
        best_top1 = 100
        start_time = time.time()
        # self._model_analyse(self.model)
        # assert False
        self.trainer.test()
        # assert False
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.settings.nIters:
                break
            # the resume offset only applies to the first loop pass
            self.start_epoch = 0
            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()

            test_error_mean = 100 - acc_mean * 100
            # write and print result
            log_str = "{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc

            self.visualize.write_log(log_str)
            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                self.logger.info(
                    "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                        run_count, best_top1))
            else:
                self.logger.info(
                    "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                        run_count, best_top1))

            self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer,
                                            epoch)

            if best_flag:
                self.checkpoint.save_model(self.model, best_flag=best_flag)

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

        end_time = time.time()

        # unwrap DataParallel before analysis
        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
                   str(datetime.timedelta(seconds=time_interval)) + "\n"
        self.logger.info(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)
        # analyse model
        self._model_analyse(self.model)

        return best_top1
class Experiment(object):
    """
    run experiments with pre-defined pipeline
    """
    def __init__(self, options=None, conf_path=None):
        """
        Initialize experiment state, output directory, loggers and the
        whole setup pipeline (via prepare()).

        :param options: pre-built Option object; when None, a new one is
                        constructed from conf_path
        :param conf_path: configuration file path used when options is None
        """
        self.settings = options or Option(conf_path)

        # placeholders populated later by prepare() / checkpoint loading
        for attr_name in ("checkpoint", "train_loader", "val_loader",
                          "ori_model", "pruned_model",
                          "segment_wise_trainer", "aux_fc_state",
                          "aux_fc_opt_state", "seg_opt_state",
                          "current_pivot_index"):
            setattr(self, attr_name, None)

        # flags/state restored when resuming from a checkpoint
        self.is_segment_wise_finetune = False
        self.is_channel_selection = False
        self.epoch = 0

        # per-GPU feature maps filled in by the forward hooks below
        self.feature_cache_origin = {}
        self.feature_cache_pruned = {}

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu

        self.settings.set_save_path()
        self.write_settings()
        self.logger = self.set_logger()
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)

        self.prepare()

    def write_settings(self):
        """
        save experimental settings to a file
        """

        with open(os.path.join(self.settings.save_path, "settings.log"),
                  "w") as f:
            for k, v in self.settings.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def set_logger(self):
        """
        initialize logger
        """
        logger = logging.getLogger('channel_selection')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """
        Run the full setup pipeline in order: GPU/seed, data loaders,
        models, auxiliary-loss pivots, checkpoint loading, trainer.
        """
        setup_steps = (
            self._set_gpu,
            self._set_dataloader,
            self._set_model,
            self._cal_pivot,
            self._set_checkpoint,
            self._set_trainier,  # spelling matches the method definition
        )
        for step in setup_steps:
            step()

    def _set_gpu(self):
        """
        Seed the CPU and CUDA random number generators with the configured
        seed, pin execution to GPU 0 and enable cudnn autotuning.
        """
        # seed both generators identically for reproducibility
        for seeder in (torch.manual_seed, torch.cuda.manual_seed):
            seeder(self.settings.seed)
        torch.cuda.set_device(0)
        # let cudnn benchmark kernels for fixed-size inputs
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        Create the train and validation loaders used for channel pruning.

        cifar10 gets random-crop/flip augmentation; imagenet gets the
        standard 224x224 training crop and 256->224 eval pipeline. Any
        other dataset now fails fast (mirroring _set_model) instead of
        silently leaving both loaders as None.
        """

        def make_loader(dataset, shuffle):
            # shared DataLoader configuration for both splits
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=self.settings.batch_size,
                shuffle=shuffle,
                pin_memory=True,
                num_workers=self.settings.n_threads)

        if self.settings.dataset == 'cifar10':
            data_root = os.path.join(self.settings.data_path, "cifar")

            # channel-wise statistics of the CIFAR-10 training set
            norm_mean = [0.49139968, 0.48215827, 0.44653124]
            norm_std = [0.24703233, 0.24348505, 0.26158768]
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])

            self.train_loader = make_loader(
                datasets.CIFAR10(root=data_root,
                                 train=True,
                                 transform=train_transform,
                                 download=True),
                shuffle=True)
            self.val_loader = make_loader(
                datasets.CIFAR10(root=data_root,
                                 train=False,
                                 transform=val_transform),
                shuffle=False)
        elif self.settings.dataset == 'imagenet':
            dataset_path = os.path.join(self.settings.data_path, "imagenet")
            traindir = os.path.join(dataset_path, "train")
            valdir = os.path.join(dataset_path, 'val')
            # standard ImageNet normalization constants
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

            self.train_loader = make_loader(
                datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(224),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                shuffle=True)
            self.val_loader = make_loader(
                datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                shuffle=False)
        else:
            assert False, "unsupported data set: {}".format(
                self.settings.dataset)

    def _set_trainier(self):
        """
        Build the segment-wise trainer and, when checkpoint state was
        loaded earlier, push the restored auxiliary-FC and optimizer
        states into it.
        """
        trainer_kwargs = dict(
            ori_model=self.ori_model,
            pruned_model=self.pruned_model,
            train_loader=self.train_loader,
            val_loader=self.val_loader,
            settings=self.settings,
            logger=self.logger,
            tensorboard_logger=self.tensorboard_logger)
        self.segment_wise_trainer = SegmentWiseTrainer(**trainer_kwargs)

        # restore trainer-side state from a resume/retrain checkpoint
        if self.aux_fc_state is not None:
            self.segment_wise_trainer.update_aux_fc(
                self.aux_fc_state, self.aux_fc_opt_state, self.seg_opt_state)

    def _set_model(self):
        """
        Instantiate the original and to-be-pruned networks for the
        configured dataset/architecture pair; both start structurally
        identical.
        """
        depth = self.settings.depth
        n_classes = self.settings.n_classes

        if self.settings.dataset in ["cifar10", "cifar100"]:
            assert self.settings.net_type == "preresnet", \
                "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)
            self.ori_model = md.PreResNet(depth=depth, num_classes=n_classes)
            self.pruned_model = md.PreResNet(depth=depth,
                                             num_classes=n_classes)
        elif self.settings.dataset in ["imagenet"]:
            assert self.settings.net_type == "resnet", \
                "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)
            self.ori_model = md.ResNet(depth=depth, num_classes=n_classes)
            self.pruned_model = md.ResNet(depth=depth, num_classes=n_classes)
        else:
            assert False, "unsupported data set: {}".format(
                self.settings.dataset)

    def _set_checkpoint(self):
        """
        Create the CheckPoint helper, then apply pre-trained weights
        (retrain) and/or full resume state on top of the fresh models.
        """
        assert self.ori_model is not None and self.pruned_model is not None, \
            "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        # retrain weights first, then resume state takes precedence
        self._load_retrain()
        self._load_resume()

    def _load_retrain(self):
        """
        Load a pre-trained model from settings.retrain, if configured.

        Two checkpoint layouts are supported:
        * a bare state dict -> loaded into both ori_model and pruned_model;
        * a dict with "ori_model"/"pruned_model"/"aux_fc" entries -> each
          network gets its own state and the auxiliary-FC state is stashed
          for _set_trainier to consume.
        """
        if self.settings.retrain is None:
            return

        check_point_params = torch.load(self.settings.retrain)
        if "ori_model" not in check_point_params:
            # bare state dict: both networks start from the same weights
            model_state = check_point_params
            self.ori_model = self.checkpoint.load_state(
                self.ori_model, model_state)
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, model_state)
            # fixed typo in log message: was "restrain"
            self.logger.info("|===>load retrain file: {}".format(
                self.settings.retrain))
        else:
            ori_model_state = check_point_params["ori_model"]
            pruned_model_state = check_point_params["pruned_model"]
            # keep aux classifier state for the trainer
            self.aux_fc_state = check_point_params["aux_fc"]
            self.ori_model = self.checkpoint.load_state(
                self.ori_model, ori_model_state)
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.logger.info("|===>load pre-trained model: {}".format(
                self.settings.retrain))

    def _load_resume(self):
        """
        Load a full resume checkpoint (models, auxiliary FCs, optimizers
        and progress flags) from settings.resume, if configured.
        """

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            # model weights plus trainer-side optimizer states
            ori_model_state = check_point_params["ori_model"]
            pruned_model_state = check_point_params["pruned_model"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.aux_fc_opt_state = check_point_params["aux_fc_opt"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.current_pivot_index = check_point_params["current_pivot"]
            # progress flags determine where channel_selection() restarts
            self.is_segment_wise_finetune = check_point_params[
                "segment_wise_finetune"]
            self.is_channel_selection = check_point_params["channel_selection"]
            self.epoch = check_point_params["epoch"]
            # NOTE(review): the saved epoch above is immediately overwritten,
            # forcing a resumed run to treat the segment as fully fine-tuned.
            # This looks like a leftover override -- confirm intent before
            # removing either assignment.
            self.epoch = self.settings.segment_wise_n_epochs
            self.current_block_count = check_point_params[
                "current_block_count"]

            # convert conv layers to MaskConv2d first, presumably so the
            # module structure matches the saved state dict -- TODO confirm
            if self.is_channel_selection or \
                    (self.is_segment_wise_finetune and self.current_pivot_index > self.settings.pivot_set[0]):
                self.replace_layer()
            self.ori_model = self.checkpoint.load_state(
                self.ori_model, ori_model_state)
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.logger.info("|===>load resume file: {}".format(
                self.settings.resume))

    def _cal_pivot(self):
        """
        Compute the block indices ("pivots") where auxiliary losses are
        inserted, splitting the network into n_losses + 1 segments, and
        store them in self.settings.pivot_set.
        """
        self.num_segments = self.settings.n_losses + 1
        total_blocks = block_num[self.settings.net_type +
                                 str(self.settings.depth)]
        num_block_per_segment = total_blocks // self.num_segments + 1
        # one pivot after each segment except the last
        pivot_set = [
            num_block_per_segment * (i + 1)
            for i in range(self.num_segments - 1)
        ]
        self.settings.pivot_set = pivot_set
        self.logger.info("pivot set: {}".format(pivot_set))

    def segment_wise_fine_tune(self, index):
        """
        Conduct segment-wise fine-tuning with auxiliary and final losses,
        tracking the best top-1/top-5 validation error, saving the best
        model and checkpointing every epoch.

        :param index: segment index being fine-tuned
        """

        best_top1 = 100
        best_top5 = 100

        # resume mid-segment when a segment-wise-finetune checkpoint was
        # loaded (the stored epoch is consumed once, then reset)
        start_epoch = 0
        if self.is_segment_wise_finetune and self.epoch != 0:
            start_epoch = self.epoch + 1
            self.epoch = 0

        for epoch in range(start_epoch, self.settings.segment_wise_n_epochs):
            train_error, train_loss, train5_error = self.segment_wise_trainer.train(
                epoch, index)
            val_error, val_loss, val5_error = self.segment_wise_trainer.val(
                epoch)

            # the trainer reports either one value per auxiliary head
            # (list) or a single scalar; track the last head either way
            if isinstance(train_error, list):
                current_top1, current_top5 = val_error[-1], val5_error[-1]
            else:
                current_top1, current_top5 = val_error, val5_error

            # update best result and persist the best model
            if best_top1 >= current_top1:
                best_top1 = current_top1
                best_top5 = current_top5
                self.checkpoint.save_model(
                    ori_model=self.ori_model,
                    pruned_model=self.pruned_model,
                    aux_fc=self.segment_wise_trainer.aux_fc,
                    current_pivot=self.current_pivot_index,
                    segment_wise_finetune=True,
                    index=index)

            self.logger.info(
                "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n".
                format(best_top1, best_top5))
            self.logger.info(
                "|===>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n"
                .format(100 - best_top1, 100 - best_top5))

            # single checkpoint call instead of two duplicated branches:
            # imagenet runs additionally record the epoch so a long segment
            # can resume mid-way
            checkpoint_kwargs = dict(
                ori_model=self.ori_model,
                pruned_model=self.pruned_model,
                aux_fc=self.segment_wise_trainer.aux_fc,
                aux_fc_opt=self.segment_wise_trainer.fc_optimizer,
                seg_opt=self.segment_wise_trainer.seg_optimizer,
                current_pivot=self.current_pivot_index,
                segment_wise_finetune=True,
                index=index)
            if self.settings.dataset in ["imagenet"]:
                checkpoint_kwargs["epoch"] = epoch
            self.checkpoint.save_checkpoint(**checkpoint_kwargs)

    def replace_layer(self):
        """
        Replace eligible convolutions in the pruned model with MaskConv2d
        layers (weights and bias copied over) for the first
        self.current_block_count blocks. conv2 is converted in every
        residual block type; Bottleneck blocks also get conv3 converted.
        """

        def to_mask_conv(layer):
            # build a MaskConv2d mirroring `layer`'s configuration and
            # copy its parameters; caller swaps it into the module
            mask_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            mask_conv.weight.data.copy_(layer.weight.data)
            if layer.bias is not None:
                mask_conv.bias.data.copy_(layer.bias.data)
            return mask_conv

        if self.settings.net_type not in ["preresnet", "resnet"]:
            return

        block_count = 0
        for module in self.pruned_model.modules():
            if not isinstance(module, (PreBasicBlock, BasicBlock,
                                       Bottleneck)):
                continue
            block_count += 1
            # only blocks up to the recorded count are converted
            if block_count > self.current_block_count:
                continue
            if not isinstance(module.conv2, MaskConv2d):
                module.conv2 = to_mask_conv(module.conv2)
            # bottleneck blocks carry a third convolution
            if isinstance(module, Bottleneck) and not isinstance(
                    module.conv3, MaskConv2d):
                module.conv3 = to_mask_conv(module.conv3)

    def channel_selection(self):
        """
        Conduct channel selection segment by segment.

        Per segment: fine-tune with auxiliary/final losses, reload the best
        fine-tuned weights, (for segment 0) copy the pruned stem and blocks
        back into the original model, run layer-wise channel selection up to
        the segment pivot, refresh the trainer and write a checkpoint.
        Supports restarting from a resumed checkpoint.
        """

        # get testing error
        self.segment_wise_trainer.val(0)
        time_start = time.time()

        restart_index = None
        # find restart segment index (only set when resuming)
        if self.current_pivot_index:
            if self.current_pivot_index in self.settings.pivot_set:
                restart_index = self.settings.pivot_set.index(
                    self.current_pivot_index)
            else:
                # pivot past the pivot set: run stopped inside the last segment
                restart_index = len(self.settings.pivot_set)

        for index in range(self.num_segments):
            if restart_index is not None:
                if index < restart_index:
                    # segment already completed before the resume
                    continue
                elif index == restart_index:
                    if self.is_channel_selection and self.current_block_count == self.current_pivot_index:
                        # selection for the restart segment already finished
                        self.is_channel_selection = False
                        continue

            # the final segment ends at the last block rather than a pivot
            if index == self.num_segments - 1:
                self.current_pivot_index = self.segment_wise_trainer.final_block_count
            else:
                self.current_pivot_index = self.settings.pivot_set[index]

            # fine tune the network with additional loss and final loss,
            # unless the resumed checkpoint says this stage is already done
            if (not self.is_segment_wise_finetune and not self.is_channel_selection) or \
                    (self.is_segment_wise_finetune and self.epoch != self.settings.segment_wise_n_epochs - 1):
                self.segment_wise_fine_tune(index)
            else:
                self.is_segment_wise_finetune = False

            # load best model saved by segment_wise_fine_tune
            best_model_path = os.path.join(
                self.checkpoint.save_path,
                'model_{:0>3d}_swft.pth'.format(index))
            check_point_params = torch.load(best_model_path)
            ori_model_state = check_point_params["ori_model"]
            pruned_model_state = check_point_params["pruned_model"]
            aux_fc_state = check_point_params["aux_fc"]
            self.ori_model = self.checkpoint.load_state(
                self.ori_model, ori_model_state)
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.segment_wise_trainer.update_model(self.ori_model,
                                                   self.pruned_model,
                                                   aux_fc_state)

            # replace the baseline model: after the first segment, copy the
            # pruned stem/blocks into the original network so later segments
            # are selected against the updated baseline
            if index == 0:
                if self.settings.net_type in ['preresnet']:
                    self.ori_model.conv = copy.deepcopy(self.pruned_model.conv)
                    for ori_module, pruned_module in zip(
                            self.ori_model.modules(),
                            self.pruned_model.modules()):
                        if isinstance(ori_module, PreBasicBlock):
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            ori_module.conv1 = copy.deepcopy(
                                pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(
                                pruned_module.conv2)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                    self.ori_model.bn = copy.deepcopy(self.pruned_model.bn)
                    self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)
                elif self.settings.net_type in ['resnet']:
                    # NOTE(review): copies pruned_model.conv into
                    # ori_model.conv1 -- confirm the source attribute name
                    # (preresnet branch uses .conv on both sides)
                    self.ori_model.conv1 = copy.deepcopy(
                        self.pruned_model.conv)
                    self.ori_model.bn1 = copy.deepcopy(self.pruned_model.bn1)
                    for ori_module, pruned_module in zip(
                            self.ori_model.modules(),
                            self.pruned_model.modules()):
                        if isinstance(ori_module, BasicBlock):
                            ori_module.conv1 = copy.deepcopy(
                                pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(
                                pruned_module.conv2)
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                        if isinstance(ori_module, Bottleneck):
                            ori_module.conv1 = copy.deepcopy(
                                pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(
                                pruned_module.conv2)
                            ori_module.conv3 = copy.deepcopy(
                                pruned_module.conv3)
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            ori_module.bn3 = copy.deepcopy(pruned_module.bn3)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                    self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)

                # re-extract aux FC state dicts (unwrapping DataParallel)
                # so the trainer is refreshed against the new baseline
                aux_fc_state = []
                for i in range(len(self.segment_wise_trainer.aux_fc)):
                    if isinstance(self.segment_wise_trainer.aux_fc[i],
                                  nn.DataParallel):
                        temp_state = self.segment_wise_trainer.aux_fc[
                            i].module.state_dict()
                    else:
                        temp_state = self.segment_wise_trainer.aux_fc[
                            i].state_dict()
                    aux_fc_state.append(temp_state)
                self.segment_wise_trainer.update_model(self.ori_model,
                                                       self.pruned_model,
                                                       aux_fc_state)
            self.segment_wise_trainer.val(0)

            # conduct channel selection
            # contains [0:index] segments
            net_origin_list = []
            net_pruned_list = []
            for j in range(index + 1):
                net_origin_list += utils.model2list(
                    self.segment_wise_trainer.ori_segments[j])
                net_pruned_list += utils.model2list(
                    self.segment_wise_trainer.pruned_segments[j])

            net_origin = nn.Sequential(*net_origin_list)
            net_pruned = nn.Sequential(*net_pruned_list)

            self._seg_channel_selection(
                net_origin=net_origin,
                net_pruned=net_pruned,
                aux_fc=self.segment_wise_trainer.aux_fc[index],
                pivot_index=self.current_pivot_index,
                index=index)

            # update optimizer: aux FC states changed during selection,
            # so extract them again (unwrapping DataParallel) and rebuild
            aux_fc_state = []
            for i in range(len(self.segment_wise_trainer.aux_fc)):
                if isinstance(self.segment_wise_trainer.aux_fc[i],
                              nn.DataParallel):
                    temp_state = self.segment_wise_trainer.aux_fc[
                        i].module.state_dict()
                else:
                    temp_state = self.segment_wise_trainer.aux_fc[
                        i].state_dict()
                aux_fc_state.append(temp_state)

            self.segment_wise_trainer.update_model(self.ori_model,
                                                   self.pruned_model,
                                                   aux_fc_state)

            # persist progress so a crash can resume at this segment
            self.checkpoint.save_checkpoint(
                self.ori_model,
                self.pruned_model,
                self.segment_wise_trainer.aux_fc,
                self.segment_wise_trainer.fc_optimizer,
                self.segment_wise_trainer.seg_optimizer,
                self.current_pivot_index,
                channel_selection=True,
                index=index,
                block_count=self.current_pivot_index)

            self.logger.info(self.ori_model)
            self.logger.info(self.pruned_model)
            self.segment_wise_trainer.val(0)
            self.current_pivot_index = None

        # final model after all segments are selected
        self.checkpoint.save_model(self.ori_model,
                                   self.pruned_model,
                                   self.segment_wise_trainer.aux_fc,
                                   self.segment_wise_trainer.final_block_count,
                                   index=self.num_segments)
        time_interval = time.time() - time_start
        log_str = "cost time: {}".format(
            str(datetime.timedelta(seconds=time_interval)))
        self.logger.info(log_str)

    def _hook_origin_feature(self, module, input, output):
        gpu_id = str(output.get_device())
        self.feature_cache_origin[gpu_id] = output

    def _hook_pruned_feature(self, module, input, output):
        gpu_id = str(output.get_device())
        self.feature_cache_pruned[gpu_id] = output

    @staticmethod
    def _concat_gpu_data(data):
        data_cat = data["0"]
        for i in range(1, len(data)):
            data_cat = torch.cat((data_cat, data[str(i)].cuda(0)))
        return data_cat

    def _layer_channel_selection(self,
                                 net_origin,
                                 net_pruned,
                                 aux_fc,
                                 module,
                                 block_count,
                                 layer_name="conv2"):
        """
        conduct channel selection for module
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param module: the module need to be pruned
        :param block_count: current block no.
        :param layer_name: the name of layer need to be pruned
        """

        self.logger.info(
            "|===>layer-wise channel selection: block-{}-{}".format(
                block_count, layer_name))
        # layer-wise channel selection
        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        else:
            assert False, "unsupport layer: {}".format(layer_name)

        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)

            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)

            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            layer = temp_conv

        # define criterion
        criterion_mse = nn.MSELoss().cuda()
        criterion_softmax = nn.CrossEntropyLoss().cuda()

        # register hook
        if layer_name == "conv2":
            hook_origin = net_origin[block_count].conv2.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv2.register_forward_hook(
                self._hook_pruned_feature)
        elif layer_name == "conv3":
            hook_origin = net_origin[block_count].conv3.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv3.register_forward_hook(
                self._hook_pruned_feature)

        net_origin_parallel = utils.data_parallel(net_origin,
                                                  self.settings.n_gpus)
        net_pruned_parallel = utils.data_parallel(net_pruned,
                                                  self.settings.n_gpus)

        # avoid computing the gradient
        for params in net_origin_parallel.parameters():
            params.requires_grad = False
        for params in net_pruned_parallel.parameters():
            params.requires_grad = False

        net_origin_parallel.eval()
        net_pruned_parallel.eval()

        layer.pruned_weight.requires_grad = True
        aux_fc.cuda()
        logger_counter = 0
        record_time = utils.AverageMeter()

        for channel in range(layer.in_channels):
            if layer.d.eq(0).sum() <= math.floor(
                    layer.in_channels * self.settings.pruning_rate):
                break

            time_start = time.time()
            cum_grad = None
            # Running averages for the selection phase of this channel step.
            record_selection_mse_loss = utils.AverageMeter()
            record_selection_softmax_loss = utils.AverageMeter()
            record_selection_loss = utils.AverageMeter()
            img_count = 0
            # ---- selection phase: accumulate the gradient w.r.t. the pruned
            # layer's weight over (capped) training batches to score channels.
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                # Forward the original net only for its cached features; its
                # return value is discarded.
                net_origin_parallel(images)
                output = net_pruned_parallel(images)
                softmax_loss = criterion_softmax(aux_fc(output), labels)

                # Gather the per-GPU feature caches and reset them
                # (presumably filled by the forward hooks removed at the end
                # of this method — confirm with the code above this chunk).
                origin_feature = self._concat_gpu_data(
                    self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(
                    self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)

                # Joint reconstruction + classification objective.
                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                record_selection_loss.update(loss.item(), images.size(0))
                record_selection_mse_loss.update(mse_loss.item(),
                                                 images.size(0))
                record_selection_softmax_loss.update(softmax_loss.item(),
                                                     images.size(0))

                # Accumulate the weight gradient across batches.
                # NOTE(review): pruned_weight.grad is only cleared in the
                # else branch, so the first batch's gradient remains in the
                # autograd buffer and is added again on the second batch —
                # confirm whether this double-count is intended.
                if cum_grad is None:
                    cum_grad = layer.pruned_weight.grad.data.clone()
                else:
                    cum_grad.add_(layer.pruned_weight.grad.data)
                    layer.pruned_weight.grad = None

                img_count += images.size(0)
                # Cap the number of samples used for selection (-1 = no cap).
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

            # write tensorboard log
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
                value=record_selection_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_selection_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_selection_softmax_loss.avg,
                step=logger_counter)
            # NOTE: abs_() does not affect the norm below (values are squared
            # anyway); it only changes cum_grad in place.
            cum_grad.abs_()
            # calculate gradient F norm
            # Per-input-channel Frobenius norm: square, sum over the spatial
            # dims (2, 3), sqrt, then sum over output channels (dim 0),
            # leaving one score per input channel.
            grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

            # find grad_fnorm with maximum absolute gradient
            # Pick the not-yet-selected channel (layer.d == 0) with the
            # largest score; already-selected channels are masked to -1.
            while True:
                _, max_index = torch.topk(grad_fnorm, 1)
                if layer.d[max_index[0]] == 0:
                    layer.d[max_index[0]] = 1
                    # Restore the original weights of the newly selected
                    # input channel into the pruned layer.
                    layer.pruned_weight.data[:, max_index[
                        0], :, :] = layer.weight[:,
                                                 max_index[0], :, :].data.clone(
                                                 )
                    break
                else:
                    grad_fnorm[max_index[0]] = -1

            # fine-tune average meter
            record_finetune_softmax_loss = utils.AverageMeter()
            record_finetune_mse_loss = utils.AverageMeter()
            record_finetune_loss = utils.AverageMeter()

            record_finetune_top1_error = utils.AverageMeter()
            record_finetune_top5_error = utils.AverageMeter()

            # define optimizer
            params_list = []
            params_list.append({
                "params": layer.pruned_weight,
                "lr": self.settings.layer_wise_lr
            })
            if layer.bias is not None:
                layer.bias.requires_grad = True
                # NOTE(review): bias learning rate is hard-coded to 0.001
                # rather than taken from settings — confirm intended.
                params_list.append({"params": layer.bias, "lr": 0.001})
            optimizer = torch.optim.SGD(
                params=params_list,
                weight_decay=self.settings.weight_decay,
                momentum=self.settings.momentum,
                nesterov=True)
            img_count = 0
            # ---- fine-tune phase: one (sample-capped) epoch of layer-wise
            # training on the currently selected channels only. ----
            for epoch in range(1):
                for i, (images, labels) in enumerate(self.train_loader):
                    images = images.cuda()
                    labels = labels.cuda()
                    features = net_pruned_parallel(images)
                    # Original net forwarded only to refresh its feature
                    # cache; output discarded.
                    net_origin_parallel(images)
                    output = aux_fc(features)
                    softmax_loss = criterion_softmax(output, labels)

                    origin_feature = self._concat_gpu_data(
                        self.feature_cache_origin)
                    self.feature_cache_origin = {}
                    pruned_feature = self._concat_gpu_data(
                        self.feature_cache_pruned)
                    self.feature_cache_pruned = {}
                    mse_loss = criterion_mse(pruned_feature, origin_feature)

                    top1_error, _, top5_error = utils.compute_singlecrop(
                        outputs=output,
                        labels=labels,
                        loss=softmax_loss,
                        top5_flag=True,
                        mean_flag=True)

                    # update parameters
                    optimizer.zero_grad()
                    loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(layer.parameters(),
                                                   max_norm=10.0)
                    # Zero the gradient of unselected input channels so only
                    # channels with d == 1 receive updates.
                    layer.pruned_weight.grad.data.mul_(
                        layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(
                            3).expand_as(layer.pruned_weight))
                    optimizer.step()
                    # update record info
                    record_finetune_softmax_loss.update(
                        softmax_loss.item(), images.size(0))
                    record_finetune_mse_loss.update(mse_loss.item(),
                                                    images.size(0))
                    record_finetune_loss.update(loss.item(), images.size(0))
                    record_finetune_top1_error.update(top1_error,
                                                      images.size(0))
                    record_finetune_top5_error.update(top5_error,
                                                      images.size(0))

                    img_count += images.size(0)
                    if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                        break

            # Drop stale gradients and freeze the bias again.
            layer.pruned_weight.grad = None
            if layer.bias is not None:
                layer.bias.requires_grad = False

            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_finetune_softmax_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Loss".format(block_count, layer_name),
                value=record_finetune_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_finetune_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
                value=record_finetune_top1_error.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
                value=record_finetune_top5_error.avg,
                step=logger_counter)

            # write log information to file
            self._write_log(
                dir_name=os.path.join(self.settings.save_path, "log"),
                file_name="log_block-{:0>2d}_{}.txt".format(
                    block_count, layer_name),
                log_str=
                "{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".
                format(int(layer.d.sum()), record_selection_loss.avg,
                       record_selection_mse_loss.avg,
                       record_selection_softmax_loss.avg,
                       record_finetune_loss.avg, record_finetune_mse_loss.avg,
                       record_finetune_softmax_loss.avg,
                       record_finetune_top1_error.avg,
                       record_finetune_top5_error.avg))
            # Console summary: selected/total channel counts plus both
            # phases' average losses and errors.
            log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
                block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
            log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_selection_loss.avg, record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg)
            log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_finetune_loss.avg, record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg)
            log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
                record_finetune_top1_error.avg, record_finetune_top5_error.avg)
            self.logger.info(log_str)

            logger_counter += 1
            time_interval = time.time() - time_start
            record_time.update(time_interval)

        # Re-enable gradients on both networks (presumably frozen before the
        # selection loop — confirm with the code above this chunk).
        for params in net_origin_parallel.parameters():
            params.requires_grad = True
        for params in net_pruned_parallel.parameters():
            params.requires_grad = True

        # remove hook
        hook_origin.remove()
        hook_pruned.remove()
        log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
            block_count, layer_name,
            str(datetime.timedelta(seconds=record_time.sum)),
            str(datetime.timedelta(seconds=record_time.avg)))
        self.logger.info(log_str)
        # NOTE(review): uses the fine-tune meters from the last iteration of
        # the loop above; raises NameError if that loop body never ran.
        log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
            record_finetune_loss.avg, record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg, record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        self.logger.info("|===>remove hook")

    @staticmethod
    def _write_log(dir_name, file_name, log_str):
        """
        Append a log string to a file, creating the directory if needed.

        :param dir_name: the path of the directory holding the log file
        :param file_name: the name of the saved file
        :param log_str: the string that needs to be appended
        """
        # makedirs (unlike the previous mkdir) also creates missing parent
        # directories, and exist_ok=True removes the race between checking
        # for the directory and creating it.
        os.makedirs(dir_name, exist_ok=True)
        with open(os.path.join(dir_name, file_name), "a+") as f:
            f.write(log_str)

    def _seg_channel_selection(self, net_origin, net_pruned, aux_fc,
                               pivot_index, index):
        """
        conduct segment channel selection
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param pivot_index: the layer index of the additional loss
        :param index: the index of segment
        :return:
        """
        # block_count numbers the residual blocks within this segment and is
        # used both for logging tags and for checkpoint naming.
        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in net_pruned.modules():
                # Basic (two-conv) blocks: only conv2 is channel-pruned.
                if isinstance(module, (PreBasicBlock, BasicBlock)):
                    block_count += 1
                    # We will not prune the pruned blocks again
                    # (a layer already replaced by MaskConv2d was selected in
                    # a previous run, e.g. when resuming).
                    if not isinstance(module.conv2, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv2")
                        self.logger.info("|===>checking layer type: {}".format(
                            type(module.conv2)))

                        # Persist progress after each pruned block so the
                        # selection can resume from here.
                        self.checkpoint.save_model(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)
                        self.checkpoint.save_checkpoint(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            self.segment_wise_trainer.fc_optimizer,
                            self.segment_wise_trainer.seg_optimizer,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)

                # Bottleneck (three-conv) blocks: conv2 and conv3 are both
                # channel-pruned.
                elif isinstance(module, Bottleneck):
                    block_count += 1
                    if not isinstance(module.conv2, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv2")

                    # NOTE(review): the checkpoint save below is nested under
                    # this conv3 branch, so nothing is saved when conv3 was
                    # already a MaskConv2d even if conv2 was just pruned
                    # above — confirm this is intended.
                    if not isinstance(module.conv3, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv3")

                        self.checkpoint.save_model(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)
                        self.checkpoint.save_checkpoint(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            self.segment_wise_trainer.fc_optimizer,
                            self.segment_wise_trainer.seg_optimizer,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)