Ejemplo n.º 1
0
    def create_mtcnn_net(self, p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True):
        """Load the requested MTCNN stages (PNet / RNet / ONet) for inference.

        A stage is only built when its model path is given. Each built stage
        has its weights restored through ``CheckPoint``, is moved to
        ``self.device`` when ``use_cuda`` is set, and is switched to
        evaluation mode.

        Args:
            p_model_path: path of the PNet weights, or None to skip PNet.
            r_model_path: path of the RNet weights, or None to skip RNet.
            o_model_path: path of the ONet weights, or None to skip ONet.
            use_cuda: request "cuda:0"; the CPU is used when CUDA is not
                available.

        Returns:
            Tuple ``(pnet, rnet, onet)``; an entry is None for a skipped stage.
        """
        self.device = torch.device(
            "cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")

        given_paths = [path
                       for path in (p_model_path, r_model_path, o_model_path)
                       if path is not None]
        if not given_paths:
            # No stage requested: there is nothing to load.
            return None, None, None

        # Take the checkpoint directory from the first supplied path so that
        # PNet can be skipped without calling os.path.split(None).
        dirname, _ = os.path.split(given_paths[0])
        checkpoint = CheckPoint(dirname)

        def _load(net, model_path):
            """Restore weights into ``net`` and put it in evaluation mode."""
            model_state = checkpoint.load_model(model_path)
            net = checkpoint.load_state(net, model_state)
            if use_cuda:
                net.to(self.device)
            net.eval()
            return net

        pnet = _load(PNet(), p_model_path) if p_model_path is not None else None
        rnet = _load(RNet(), r_model_path) if r_model_path is not None else None
        onet = _load(ONet(), o_model_path) if o_model_path is not None else None

        return pnet, rnet, onet
Ejemplo n.º 2
0
class ExperimentDesign:
    """Train / prune / fine-tune pipeline for models on the "sphere" dataset.

    Creating an instance immediately runs ``prepare``, which configures the
    GPU, data loader, model, checkpoint, data-parallel wrapper, learning-rate
    policy and trainer, and finally writes the settings (optionally drawing
    the network first).
    """

    def __init__(self, options=None):
        """Initialise state holders and run ``prepare``.

        Args:
            options: experiment settings; ``Option()`` is used when this is
                None (or otherwise falsy).
        """
        self.settings = options or Option()

        # Filled in by the _set_* steps of prepare().
        self.checkpoint = None
        self.data_loader = None
        self.model = None
        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)
        self.test_input = None
        self.lr_master = None
        self.prepare()

    def prepare(self):
        """Run every setup step in its required order."""
        setup_steps = (
            self._set_gpu,
            self._set_dataloader,
            self._set_model,
            self._set_checkpoint,
            self._set_parallel,
            self._set_lr_policy,
            self._set_trainer,
            self._draw_net,
        )
        for step in setup_steps:
            step()

    def _set_gpu(self):
        """Seed torch (CPU and CUDA) and select the configured GPU."""
        seed = self.settings.manualSeed
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        last_gpu_id = torch.cuda.device_count() - 1
        assert self.settings.GPU <= last_gpu_id, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        print("|===>Set GPU done!")

    def _set_dataloader(self):
        """Create the project ``DataLoader`` from the settings."""
        cfg = self.settings
        loader_args = dict(dataset=cfg.dataset,
                           batch_size=cfg.batchSize,
                           data_path=cfg.dataPath,
                           n_threads=cfg.nThreads,
                           ten_crop=cfg.tenCrop)
        self.data_loader = DataLoader(**loader_args)
        print("|===>Set data loader done!")

    def _set_checkpoint(self):
        """Create the checkpoint helper and load retrain / resume weights."""
        assert self.model is not None, "please create model first"

        cfg = self.settings
        self.checkpoint = CheckPoint(cfg.save_path)

        if cfg.retrain is not None:
            retrain_state = self.checkpoint.load_model(cfg.retrain)
            self.model = self.checkpoint.load_state(self.model, retrain_state)

        if cfg.resume is not None:
            # Only the model weights are restored; the stored optimizer state
            # and epoch are read but not applied.
            resume_state, _optimizer_state, _epoch = \
                self.checkpoint.load_checkpoint(cfg.resume)
            self.model = self.checkpoint.load_state(self.model, resume_state)
        print("|===>Set checkpoint done!")

    def _set_model(self):
        """Instantiate the network named by ``settings.netType``."""
        cfg = self.settings
        if cfg.dataset != "sphere":
            assert False, "unsupport data set: " + cfg.dataset
        else:
            if cfg.netType == "SphereNet":
                self.model = md.SphereNet(depth=cfg.depth,
                                          num_features=cfg.featureDim)
            elif cfg.netType == "SphereNIN":
                self.model = md.SphereNIN(num_features=cfg.featureDim)
            elif cfg.netType == "wcSphereNet":
                self.model = md.wcSphereNet(depth=cfg.depth,
                                            num_features=cfg.featureDim,
                                            rate=cfg.rate)
            else:
                assert False, "use %s data while network is %s" % (
                    cfg.dataset, cfg.netType)
        print("|===>Set model done!")

    def _set_parallel(self):
        """Pass the model through ``utils.data_parallel``."""
        cfg = self.settings
        self.model = utils.data_parallel(self.model, cfg.nGPU, cfg.GPU)

    def _set_lr_policy(self):
        """Build ``self.lr_master`` from the current learning-rate settings."""
        cfg = self.settings
        self.lr_master = utils.LRPolicy(cfg.lr, cfg.nIters, cfg.lrPolicy)
        self.lr_master.set_params(params_dict={
            'gamma': cfg.gamma,
            'step': cfg.step,
            'end_lr': cfg.endlr,
            'decay_rate': cfg.decayRate,
        })

    def _set_trainer(self):
        """Create the ``Trainer`` from the prepared model and data loaders."""
        cfg = self.settings
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               lr_master=self.lr_master,
                               n_epochs=cfg.nEpochs,
                               n_iters=cfg.nIters,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               feature_dim=cfg.featureDim,
                               momentum=cfg.momentum,
                               weight_decay=cfg.weightDecay,
                               optimizer_state=self.optimizer_state,
                               logger=self.logger)

    def _draw_net(self):
        """Create the random test input, optionally draw the net, dump settings."""
        if self.settings.dataset != "sphere":
            assert False, "invalid data set"
        else:
            noise = torch.randn(1, 3, 112, 96)
        self.test_input = Variable(noise.cuda())

        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(self.test_input)
            self.visualize.save_network(rand_output)
            print("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def pruning(self, run_count=0):
        """Prune the current model, test it, save it and analyse it.

        Args:
            run_count: index passed to ``checkpoint.save_model``.
        """
        prunable = (self.settings.dataset == "sphere"
                    and self.settings.netType == "wcSphereNet")
        net_type = "SphereNet" if prunable else None
        assert net_type is not None, "net_type for prune is NoneType"

        self.trainer.test()

        # Prune the bare module, not the DataParallel wrapper.
        wrapped = isinstance(self.model, nn.DataParallel)
        model = self.model.module if wrapped else self.model

        if net_type == "SphereNet":
            model_prune = prune.SpherePrune(model)
        model_prune.run()
        self.trainer.reset_model(model_prune.model)
        self.model = self.trainer.model

        self.trainer.test()
        self.checkpoint.save_model(self.trainer.model,
                                   index=run_count,
                                   tag="pruning")

        # analyse model
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.model_analyse.params_count()
        self.model_analyse.flops_compute(self.test_input)

    def _restart_schedule(self, lr, n_iters, step):
        """Overwrite the multi-step LR settings and give the trainer a new policy."""
        cfg = self.settings
        cfg.lr = lr
        cfg.nIters = n_iters
        cfg.lrPolicy = "multi_step"
        cfg.decayRate = 0.1
        cfg.step = step

        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

    def fine_tuning(self, run_count=0):
        """Run ``training`` with the short fine-tuning schedule (tag "fine-tuning")."""
        self._restart_schedule(lr=0.01, n_iters=12000, step=[6000])
        self.training(run_count, tag="fine-tuning")

    def retrain(self, run_count=0):
        """Run ``training`` with the full schedule (tag "training")."""
        self._restart_schedule(lr=0.1, n_iters=28000, step=[16000, 24000])
        self.training(run_count, tag="training")

    def run(self, run_count=0):
        """Execute one round of the pipeline.

        For ``run_count >= 1`` the model is first retrained and fine-tuned.
        Every round then prunes the model, freezes the margin / linear
        layers and swaps convolutions for ``prune.wcConv2d`` copies.

        Args:
            run_count: index of the current round.
        """
        if run_count >= 1:
            print("|===> training")
            self.retrain(run_count)

            self.trainer.reset_model(self.model)
            print("|===> fine-tuning")
            self.fine_tuning(run_count)

        print("|===> pruning")
        self.pruning(run_count)

        # keep margin_linear static
        # Only the 2nd, 3rd and 4th convolutions met are written back.
        conv_slots = {1: "conv2", 2: "conv3", 3: "conv4"}
        convs_seen = 0
        for layer in self.model.modules():
            if isinstance(layer, md.MarginLinear):
                layer.iteration.fill_(0)
                layer.margin_type.data.fill_(1)
                layer.weight.requires_grad = False
                continue

            if isinstance(layer, nn.Linear):
                layer.weight.requires_grad = False
                layer.bias.requires_grad = False
                continue

            if not isinstance(layer, nn.Conv2d):
                continue

            has_bias = layer.bias is not None
            replacement = prune.wcConv2d(layer.weight.size(1),
                                         layer.weight.size(0),
                                         kernel_size=layer.kernel_size,
                                         stride=layer.stride,
                                         padding=layer.padding,
                                         bias=has_bias,
                                         rate=self.settings.rate)
            replacement.weight.data.copy_(layer.weight.data)
            if has_bias:
                replacement.bias.data.copy_(layer.bias.data)

            slot = conv_slots.get(convs_seen)
            if slot is not None:
                setattr(self.model, slot, replacement)
            convs_seen += 1

        print(self.model)
        self.trainer.reset_model(self.model)

    def training(self, run_count=0, tag="training"):
        """Train and test epoch by epoch, tracking the best top-1 error.

        Args:
            run_count: round index, used in log output, tensorboard steps and
                the saved-model tag.
            tag: prefix of the tag given to the best saved model.

        Returns:
            The lowest test error seen (100 if no epoch was run).
        """
        best_top1 = 100
        self.trainer.test()

        for epoch in range(self.start_epoch, self.settings.nEpochs):
            # Stop once the trainer has used up its iteration budget.
            if self.trainer.iteration >= self.trainer.n_iters:
                break

            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()
            test_error_mean = 100 - acc_mean * 100

            # write and print result
            log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            log_str += "".join("%.4f\t" % acc for acc in acc_all)
            self.visualize.write_log(log_str)

            improved = bool(best_top1 >= test_error_mean)
            if improved:
                best_top1 = test_error_mean
            print(
                colored(
                    "# %d ==>Best Result is: Top1 Error: %f\n" %
                    (run_count, best_top1), "red" if improved else "blue"))

            self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer,
                                            epoch)

            if improved:
                self.checkpoint.save_model(self.model,
                                           best_flag=improved,
                                           tag="%s_%d" % (tag, run_count))

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

            # Histogram every weight tensor (and its gradient, if present).
            global_step = run_count * self.settings.nEpochs + epoch + 1
            for name, value in self.model.named_parameters():
                if 'weight' not in name:
                    continue
                name = name.replace('.', '/')
                self.logger.histo_summary(name,
                                          value.data.cpu().numpy(),
                                          global_step)
                if value.grad is not None:
                    self.logger.histo_summary(name + "/grad",
                                              value.grad.data.cpu().numpy(),
                                              global_step)

        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()

        # save experimental results
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.visualize.write_readme("Best Result of all is: Top1 Error: %f\n" %
                                    best_top1)

        # analyse model and save the result to file
        params_num = self.model_analyse.params_count()
        self.visualize.write_readme("Number of parameters is: %d" %
                                    params_num)
        self.model_analyse.prune_rate()
        self.model_analyse.flops_compute(self.test_input)

        return best_top1
Ejemplo n.º 3
0
class ExperimentDesign(object):
    """
    Run experiments with a pre-defined pipeline.

    Creating an instance runs ``prepare`` (GPU, data loader, model,
    checkpoint, trainer); ``run`` then trains and tests epoch by epoch,
    saving the best model and a checkpoint that includes the auxiliary
    classifiers and both optimizer groups.
    """
    def __init__(self):
        self.settings = Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None

        self.trainer = None
        # States restored from a resume checkpoint by _set_checkpoint().
        self.seg_opt_state = None
        self.fc_opt_state = None
        self.aux_fc_state = None
        self.start_epoch = 0

        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def prepare(self):
        """
        Prepare the experiment: GPU, data loader, model, checkpoint, trainer.
        """
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        """Seed torch (CPU and CUDA), select the GPU and enable cudnn benchmark."""
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count(
        ) - 1, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """Create the project ``DataLoader`` from the settings."""
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)

    def _set_checkpoint(self):
        """Create the checkpoint helper and load retrain / resume states."""
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            model_state = check_point_params["model"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.fc_opt_state = check_point_params["fc_opt"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.model = self.checkpoint.load_state(self.model, model_state)
            # NOTE(review): the resume epoch is hard-coded; the checkpoint
            # written by _save_checkpoint() does not store an epoch.
            self.start_epoch = 90

    def _set_model(self):
        """Instantiate the network named by ``settings.netType``."""
        print("netType:", self.settings.netType)
        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.netType == "DARTSNet":
                genotype = md.genotypes.DARTS
                self.model = md.DARTSNet(self.settings.init_channels,
                                         self.settings.nClasses,
                                         self.settings.layers,
                                         self.settings.auxiliary, genotype)
            elif self.settings.netType == "PreResNet":
                self.model = md.PreResNet(depth=self.settings.depth,
                                          num_classes=self.settings.nClasses,
                                          wide_factor=self.settings.wideFactor)
            elif self.settings.netType == "CifarResNeXt":
                self.model = md.CifarResNeXt(self.settings.cardinality,
                                             self.settings.depth,
                                             self.settings.nClasses,
                                             self.settings.base_width,
                                             self.settings.widen_factor)

            elif self.settings.netType == "ResNet":
                self.model = md.ResNet(self.settings.depth,
                                       self.settings.nClasses)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupported data set: " + self.settings.dataset

    def _set_trainer(self):
        """Build the LR policy and trainer; resume the trainer if states were loaded."""
        # set lr master
        lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        lr_master.set_params(params_dict=params_dict)
        # set trainer
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger)
        if self.seg_opt_state is not None:
            self.trainer.resume(aux_fc_state=self.aux_fc_state,
                                seg_opt_state=self.seg_opt_state,
                                fc_opt_state=self.fc_opt_state)

    @staticmethod
    def _module_state(module):
        """Return ``module``'s state dict, unwrapping ``nn.DataParallel`` first."""
        if isinstance(module, nn.DataParallel):
            module = module.module
        return module.state_dict()

    def _save_best_model(self, model, aux_fc):
        """Save the model and auxiliary classifier states as "best_model_with_AuxFC.pth".

        Args:
            model: network to save (may be wrapped in ``nn.DataParallel``).
            aux_fc: sequence of auxiliary classifier modules.
        """
        check_point_params = {
            "model": self._module_state(model),
            "aux_fc": [self._module_state(fc) for fc in aux_fc],
        }
        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "best_model_with_AuxFC.pth"))

    def _save_checkpoint(self,
                         model,
                         seg_optimizers,
                         fc_optimizers,
                         aux_fc,
                         index=0):
        """Save model, optimizer and auxiliary classifier states.

        The file is written to ``checkpoint_<index>.pth`` (index zero-padded
        to three digits) under the checkpoint save path.

        Args:
            model: network to save (may be wrapped in ``nn.DataParallel``).
            seg_optimizers: sequence of optimizers stored under "seg_opt".
            fc_optimizers: sequence of optimizers stored under "fc_opt".
            aux_fc: auxiliary classifiers; one is read per entry of
                ``fc_optimizers``, by position.
            index: number used in the checkpoint file name.
        """
        seg_opt_state = [opt.state_dict() for opt in seg_optimizers]
        fc_opt_state = []
        aux_fc_state = []
        # aux_fc is indexed alongside fc_optimizers, so a shorter aux_fc
        # still raises IndexError rather than being silently truncated.
        for i, fc_optimizer in enumerate(fc_optimizers):
            fc_opt_state.append(fc_optimizer.state_dict())
            aux_fc_state.append(self._module_state(aux_fc[i]))

        check_point_params = {
            "model": self._module_state(model),
            "seg_opt": seg_opt_state,
            "fc_opt": fc_opt_state,
            "aux_fc": aux_fc_state,
        }

        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "checkpoint_%03d.pth" % index))

    def run(self, run_count=0):
        """
        Train and test for the configured number of epochs.

        Args:
            run_count: index of this run, used only in the printed summary.
        """

        best_top1 = 100
        best_top5 = 100
        start_time = time.time()
        if self.settings.resume is not None or self.settings.retrain is not None:
            self.trainer.test(0)
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            # range() above is already built; this only affects later calls.
            self.start_epoch = 0

            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)

            # write and print result
            if isinstance(train_error, np.ndarray):
                # One group of columns per array entry; the last entry
                # decides whether this epoch is the best so far.
                log_str = "%d\t" % (epoch)
                for i in range(len(train_error)):
                    log_str += "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                        train_error[i],
                        train_loss[i],
                        test_error[i],
                        test_loss[i],
                        train5_error[i],
                        test5_error[i],
                    )

                best_flag = False
                if best_top1 >= test_error[-1]:
                    best_top1 = test_error[-1]
                    best_top5 = test5_error[-1]
                    best_flag = True

            else:
                log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                    epoch, train_error, train_loss, test_error, test_loss,
                    train5_error, test5_error)
                best_flag = False
                if best_top1 >= test_error:
                    best_top1 = test_error
                    best_top5 = test5_error
                    best_flag = True

            self.visualize.write_log(log_str)

            # print() with a single argument behaves the same on Python 2
            # and Python 3.
            if best_flag:
                self.checkpoint.save_model(self.trainer.model,
                                           best_flag=best_flag)
                self._save_best_model(self.trainer.model, self.trainer.auxfc)

                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                    % (run_count, best_top1, best_top5), "red"))
            else:
                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                    % (run_count, best_top1, best_top5), "blue"))

            # NOTE(review): no index is passed, so every epoch overwrites
            # checkpoint_000.pth - confirm this is intended.
            self._save_checkpoint(self.trainer.model,
                                  self.trainer.seg_optimizer,
                                  self.trainer.fc_optimizer,
                                  self.trainer.auxfc)

        end_time = time.time()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        print(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)

        # save experimental results
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f, Top5 Error: %f\n" %
            (best_top1, best_top5))
class ExperimentDesign:
    """Training / testing pipeline that reports through a ``logging`` logger.

    Creating an instance sets the CUDA environment variables, builds the
    logger and visualisation helpers, and then runs ``prepare``.
    """

    def __init__(self, options=None):
        """Initialise state holders, logging helpers and run ``prepare``.

        Args:
            options: experiment settings; ``Option()`` is used when this is
                None (or otherwise falsy).
        """
        self.settings = options or Option()

        # Filled in by the _set_* steps of prepare().
        self.checkpoint = None
        self.train_loader = None
        self.test_loader = None
        self.model = None
        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.test_input = None
        self.model_analyse = None

        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices

        self.settings.set_save_path()
        self.logger = self.set_logger()
        self.settings.paramscheck(self.logger)
        self.visualize = vs.Visualization(self.settings.save_path, self.logger)
        self.tensorboard_logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def set_logger(self):
        """Return the 'sphereface' logger with a file and a console handler.

        The file handler writes timestamped records to ``train_test.log``
        under the save path; the console handler writes bare messages to
        stdout. The logger level is set to INFO.
        """
        log_path = os.path.join(self.settings.save_path, "train_test.log")
        outputs = (
            # file log
            (logging.FileHandler(log_path),
             logging.Formatter('%(asctime)s %(levelname)s: %(message)s')),
            # console log
            (logging.StreamHandler(sys.stdout),
             logging.Formatter('%(message)s')),
        )

        logger = logging.getLogger('sphereface')
        for handler, formatter in outputs:
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """Run every setup step in its required order."""
        setup_steps = (
            self._set_gpu,
            self._set_dataloader,
            self._set_model,
            self._set_checkpoint,
            self._set_trainer,
            self._draw_net,
        )
        for step in setup_steps:
            step()

    def _set_gpu(self):
        """Seed torch (CPU and CUDA), select the GPU and enable cudnn benchmark."""
        seed = self.settings.manualSeed
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        last_gpu_id = torch.cuda.device_count() - 1
        assert self.settings.GPU <= last_gpu_id, "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """Create the project ``DataLoader`` and keep its train / test loaders."""
        cfg = self.settings
        loader_factory = DataLoader(dataset=cfg.dataset,
                                    batch_size=cfg.batchSize,
                                    data_path=cfg.dataPath,
                                    n_threads=cfg.nThreads,
                                    ten_crop=cfg.tenCrop,
                                    logger=self.logger)
        self.train_loader, self.test_loader = loader_factory.getloader()

    def _set_checkpoint(self):
        """Create the checkpoint helper and load retrain / resume states."""
        assert self.model is not None, "please create model first"

        cfg = self.settings
        self.checkpoint = CheckPoint(cfg.save_path, self.logger)

        if cfg.retrain is not None:
            retrain_state = self.checkpoint.load_model(cfg.retrain)
            self.model = self.checkpoint.load_state(self.model, retrain_state)

        if cfg.resume is not None:
            # Resuming also restores the start epoch and the optimizer state.
            resume_state, optimizer_state, epoch = \
                self.checkpoint.load_checkpoint(cfg.resume)
            self.model = self.checkpoint.load_state(self.model, resume_state)
            self.start_epoch = epoch
            self.optimizer_state = optimizer_state

    def _set_model(self):
        """Create the random test input and the network named by ``netType``."""
        cfg = self.settings
        if cfg.dataset not in ["sphere", "sphere_large"]:
            assert False, "unsupport data set: " + cfg.dataset
        else:
            self.test_input = Variable(torch.randn(1, 3, 112, 96).cuda())
            if cfg.netType == "SphereNet":
                self.model = md.SphereNet(depth=cfg.depth,
                                          num_features=cfg.featureDim)
            elif cfg.netType == "SphereNIN":
                self.model = md.SphereNIN(num_features=cfg.featureDim)
            elif cfg.netType == "SphereMobileNet_v2":
                self.model = md.SphereMobleNet_v2(
                    num_features=cfg.featureDim)
            else:
                assert False, "use %s data while network is %s" % (
                    cfg.dataset, cfg.netType)

    def _set_trainer(self):
        """Build the LR policy and the ``Trainer``."""
        cfg = self.settings
        lr_master = utils.LRPolicy(cfg.lr, cfg.nEpochs, cfg.lrPolicy)
        lr_master.set_params(params_dict={
            'power': cfg.power,
            'step': cfg.step,
            'end_lr': cfg.endlr,
            'decay_rate': cfg.decayRate,
        })

        self.trainer = Trainer(model=self.model,
                               train_loader=self.train_loader,
                               test_loader=self.test_loader,
                               lr_master=lr_master,
                               settings=cfg,
                               logger=self.logger,
                               tensorboard_logger=self.tensorboard_logger,
                               optimizer_state=self.optimizer_state)

    def _draw_net(self):
        """Optionally draw the network, then write the settings."""
        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(self.test_input)
            self.visualize.save_network(rand_output)
            self.logger.info("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def _model_analyse(self, model):
        """Log and record parameter / zero counts, then compute MAdds.

        Args:
            model: network handed to ``utils.ModelAnalyse``.
        """
        analyser = utils.ModelAnalyse(model, self.visualize)
        params_num = analyser.params_count()
        zero_num = analyser.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))

        # save analyse result to file
        summary = (
            "Number of parameters is: %d, number of zeros is: %d, zero rate is: %f"
            % (params_num, zero_num, zero_rate))
        self.visualize.write_readme(summary)

        analyser.madds_compute(self.test_input)

    def run(self, run_count=0):
        """Train and test epoch by epoch, tracking the best top-1 error.

        Args:
            run_count: index of this run, used only in the log output.

        Returns:
            The lowest test error seen (100 if no epoch was run).
        """
        self.logger.info("|===>Start training")
        best_top1 = 100
        start_time = time.time()
        self.trainer.test()

        for epoch in range(self.start_epoch, self.settings.nEpochs):
            # Stop once the iteration budget is used up.
            if self.trainer.iteration >= self.settings.nIters:
                break
            # range() above is already built; this only affects later calls.
            self.start_epoch = 0

            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()
            test_error_mean = 100 - acc_mean * 100

            # write and print result
            log_str = "{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            log_str += "".join("%.4f\t" % acc for acc in acc_all)
            self.visualize.write_log(log_str)

            improved = bool(best_top1 >= test_error_mean)
            if improved:
                best_top1 = test_error_mean
            # The same summary line is logged whether or not this epoch won.
            self.logger.info(
                "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                    run_count, best_top1))

            self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer,
                                            epoch)

            if improved:
                self.checkpoint.save_model(self.model, best_flag=improved)

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

        end_time = time.time()

        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()

        # compute cost time, log it and write it to file
        elapsed = datetime.timedelta(seconds=end_time - start_time)
        t_string = "Running Time is: " + str(elapsed) + "\n"
        self.logger.info(t_string)
        self.visualize.write_readme(t_string)

        # analyse model
        self._model_analyse(self.model)

        return best_top1