예제 #1
0
파일: ADC.py 프로젝트: aam12/distiller
    def compute_reward(self):
        """Fine-tune and validate the current model, then compute its reward.

        The method logs the weights sparsity, measures the remaining MACs and
        non-zero parameters, trains for ``NUM_TRAINING_EPOCHS`` epochs with a
        freshly created SGD optimizer, validates, and finally passes the
        validation results to ``self.amc_cfg.reward_fn``.

        Returns:
            tuple: ``(reward, top1, total_macs, total_nnz)``.
        """
        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])

        # The convolution details are needed regardless of the thinning mode.
        _, total_macs, total_nnz = collect_conv_details(self.model,
                                                        self.app_args.dataset)
        if PERFORM_THINNING:
            # Filters were physically removed: compare the 4-D parameter
            # count against the dense model size.
            remaining_params = distiller.model_numel(self.model,
                                                     param_dims=[4])
            compression = remaining_params / self.dense_model_size
        else:
            # Weights are only masked: derive the density from the sparsity
            # and scale the non-zero count accordingly (flagged as a hack by
            # the original author).
            compression = 1 - distiller.model_sparsity(self.model) / 100
            total_nnz *= compression

        macs_percent = total_macs / self.dense_model_macs * 100
        msglogger.info("Total parameters left: %.2f%%" % (compression * 100))
        msglogger.info("Total compute left: %.2f%%" % (macs_percent,))

        # Train for zero or more epochs
        sgd_cfg = self.app_args.optimizer_data
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=sgd_cfg['lr'],
                                    momentum=sgd_cfg['momentum'],
                                    weight_decay=sgd_cfg['weight_decay'])
        for _ in range(NUM_TRAINING_EPOCHS):
            self.services.train_fn(
                model=self.model,
                compression_scheduler=self.create_scheduler(),
                optimizer=optimizer,
                epoch=self.debug_stats['episode'])

        # Validate
        top1, top5, vloss = self.services.validate_fn(
            model=self.model, epoch=self.debug_stats['episode'])
        reward = self.amc_cfg.reward_fn(top1, top5, vloss, total_macs)

        # Insertion order of the pairs is the order of the logged columns.
        stat_pairs = [
            ('Loss', vloss),
            ('Top1', top1),
            ('Top5', top5),
            ('reward', reward),
            ('total_macs', int(total_macs)),
            ('macs_normalized', macs_percent),
            ('log(total_macs)', math.log(total_macs)),
            ('total_nnz', int(total_nnz)),
        ]
        stats = ('Peformance/Validation/', OrderedDict(stat_pairs))
        distiller.log_training_progress(
            stats,
            None,
            self.debug_stats['episode'],
            steps_completed=0,
            total_steps=1,
            log_freq=1,
            loggers=[self.tflogger, self.pylogger])
        return reward, top1, total_macs, total_nnz
예제 #2
0
    def compute_reward(self, log_stats=True):
        """Fine-tune, validate and score the current model.

        Args:
            log_stats: when true, the remaining parameters/compute and the
                end-of-episode statistics are written to the loggers.

        Returns:
            tuple: ``(reward, top1, total_macs, total_nnz)``.
        """
        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
        resources = self.net_wrapper.get_model_resources_requirements(
            self.model)
        total_macs, total_nnz = resources

        if not self.amc_cfg.perform_thinning:
            # Weights are only masked: derive the density from the sparsity
            # and scale the non-zero count accordingly (flagged as a hack by
            # the original author).
            compression = 1 - distiller.model_sparsity(self.model) / 100
            total_nnz *= compression
        else:
            remaining_params = distiller.model_numel(self.model,
                                                     param_dims=[4])
            compression = remaining_params / self.dense_model_size

        # Fine-tune, and keep a record of the fine-tuning accuracies.
        accuracies = self.net_wrapper.train(self.amc_cfg.num_ft_epochs,
                                            self.episode)
        self.ft_stats_file.add_record([self.episode, accuracies])

        top1, top5, vloss = self.net_wrapper.validate()
        reward = self.amc_cfg.reward_fn(self, top1, top5, vloss, total_macs)
        outcome = (reward, top1, total_macs, total_nnz)

        if not log_stats:
            return outcome

        macs_percent = total_macs / self.dense_model_macs * 100
        msglogger.info("Total parameters left: %.2f%%" % (compression * 100))
        msglogger.info("Total compute left: %.2f%%" % (macs_percent,))

        # Insertion order of the pairs is the order of the logged columns.
        stat_pairs = [
            ('Loss', vloss),
            ('Top1', top1),
            ('Top5', top5),
            ('reward', reward),
            ('total_macs', int(total_macs)),
            ('macs_normalized', macs_percent),
            ('log(total_macs)', math.log(total_macs)),
            ('total_nnz', int(total_nnz)),
        ]
        stats = ('Performance/EpisodeEnd/', OrderedDict(stat_pairs))
        distiller.log_training_progress(
            stats,
            None,
            self.episode,
            steps_completed=0,
            total_steps=1,
            log_freq=1,
            loggers=[self.tflogger, self.pylogger])
        return outcome
예제 #3
0
def train(**kwargs):
    """Train the configured model, optionally pruning it with Distiller.

    Keyword arguments are forwarded to ``opt._parse`` to override fields of
    the module-level configuration object ``opt``.  The function builds the
    model, optimizer, LR scheduler and data loaders, then for every epoch:
    trains on the training set, validates, records the validation scores and
    saves a checkpoint when the current epoch ranks best (sorted by sparsity,
    then top-1, top-5 and epoch).

    Returns:
        None.  Training stops early (plain ``return``) as soon as
        ``check_date`` rejects a batch.
    """
    opt._parse(kwargs)
    train_writer = None
    value_writer = None
    if opt.vis:
        train_writer = SummaryWriter(
            log_dir='./runs/train_' +
            datetime.now().strftime('%y%m%d-%H-%M-%S'))
        value_writer = SummaryWriter(
            log_dir='./runs/val_' + datetime.now().strftime('%y%m%d-%H-%M-%S'))
    best_precision = 0  # best top-1 accuracy seen so far
    start_epoch = 0
    perf_scores_history = []  # one validation record per epoch, best first

    # step1: criterion (cross entropy, the usual choice for softmax
    # classification)
    criterion = t.nn.CrossEntropyLoss().to(opt.device)
    # step2: meters
    train_losses = AverageMeter()
    train_top1 = AverageMeter()
    train_top5 = AverageMeter()
    pylogger = PythonLogger(msglogger)
    # step3: configure model
    model = getattr(models, opt.model)()
    compression_scheduler = distiller.CompressionScheduler(model)
    optimizer = model.get_optimizer(lr=opt.lr, weight_decay=opt.weight_decay) \
        if False else model.get_optimizer(opt.lr, opt.weight_decay)
    if opt.load_model_path:
        # NOTE: ``t.load`` unpickles the file, so only load trusted
        # checkpoints.  Pass ``map_location`` (e.g. 'cpu' or
        # {'cuda:1': 'cuda:0'}) here if the checkpoint was saved on another
        # device.
        checkpoint = t.load(opt.load_model_path)
        start_epoch = checkpoint["epoch"]
        best_precision = checkpoint["best_precision"]
        model.load_state_dict(checkpoint["state_dict"])
        # The checkpoint stores the optimizer object itself (see the
        # ``model.save`` call below), not its state dict.
        optimizer = checkpoint['optimizer']
    model.to(opt.device)

    if opt.compress:
        # Load the pruning schedule from the compression config file.
        compression_scheduler = distiller.file_config(
            model, optimizer, opt.compress, compression_scheduler)
        model.to(opt.device)
    # Learning-rate scheduler
    lr_scheduler = get_scheduler(optimizer, opt)
    # step4: data
    train_data = DatasetFromFilename(opt.data_root, flag='train')
    val_data = DatasetFromFilename(opt.data_root, flag='test')
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)
    # train
    for epoch in range(start_epoch, opt.max_epoch):
        model.train()
        if opt.pruning:
            compression_scheduler.on_epoch_begin(epoch)
        # Reset ALL meters so that every epoch reports its own averages.
        train_losses.reset()
        train_top1.reset()
        train_top5.reset()
        total_samples = len(train_dataloader.sampler)
        steps_per_epoch = math.ceil(total_samples / opt.batch_size)
        # Read the learning rate first so the progress bar shows the value
        # used in this epoch rather than the previous one.
        lr = lr_scheduler.get_lr()[0]
        train_progressor = ProgressBar(mode="Train  ",
                                       epoch=epoch,
                                       total_epoch=opt.max_epoch,
                                       model_name=opt.model,
                                       lr=lr,
                                       total=len(train_dataloader))
        for ii, (data, labels, img_path, tag) in enumerate(train_dataloader):
            if not check_date(img_path, tag, msglogger):
                return
            if opt.pruning:
                compression_scheduler.on_minibatch_begin(
                    epoch, ii, steps_per_epoch, optimizer)
            train_progressor.current = ii + 1
            # Monotonic TensorBoard step that is unique across epochs.
            global_step = epoch * steps_per_epoch + ii
            # train model
            inputs = data.to(opt.device)
            target = labels.to(opt.device)
            if train_writer:
                # Undo the input normalisation (approximately) for display.
                grid = make_grid(
                    (inputs.data.cpu() * 0.225 + 0.45).clamp(min=0, max=1))
                train_writer.add_image('train_images', grid, global_step)
            score = model(inputs)
            loss = criterion(score, target)
            if opt.pruning:
                # Before running the backward phase, we allow the scheduler to modify the loss
                # (e.g. add regularization loss)
                agg_loss = compression_scheduler.before_backward_pass(
                    epoch,
                    ii,
                    steps_per_epoch,
                    loss,
                    optimizer=optimizer,
                    return_loss_components=True)
                loss = agg_loss.overall_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if opt.pruning:
                compression_scheduler.on_minibatch_end(epoch, ii,
                                                       steps_per_epoch,
                                                       optimizer)

            precision1_train, precision5_train = accuracy(
                score, target, topk=(1, 5))

            # Each batch is recorded exactly once in every meter.
            batch_size = inputs.size(0)
            train_losses.update(loss.item(), batch_size)
            train_top1.update(precision1_train[0].item(), batch_size)
            train_top5.update(precision5_train[0].item(), batch_size)
            train_progressor.current_loss = train_losses.avg
            train_progressor.current_top1 = train_top1.avg
            train_progressor.current_top5 = train_top5.avg
            train_progressor()  # print the progress line
            if ii % opt.print_freq == 0:
                if train_writer:
                    train_writer.add_scalar('loss', train_losses.avg,
                                            global_step)
                    train_writer.add_text(
                        'top1', 'train accuracy top1 %s' % train_top1.avg,
                        global_step)
                    train_writer.add_scalars(
                        'accuracy', {
                            'top1': train_top1.avg,
                            'top5': train_top5.avg,
                            'loss': train_losses.avg
                        }, global_step)
        # validate and visualize
        if opt.pruning:
            distiller.log_weights_sparsity(model, epoch,
                                           loggers=[pylogger])
            compression_scheduler.on_epoch_end(epoch, optimizer)
        val_loss, val_top1, val_top5 = val(model, criterion, val_dataloader,
                                           epoch, value_writer, lr)
        sparsity = distiller.model_sparsity(model)
        perf_scores_history.append(
            distiller.MutableNamedTuple(
                {
                    'sparsity': sparsity,
                    'top1': val_top1,
                    'top5': val_top5,
                    'epoch': epoch + 1,
                    'lr': lr,
                    'loss': val_loss
                }, ))
        # Keep the history sorted from best to worst: sparsity is the main
        # key, followed by top1, top5 and epoch.
        perf_scores_history.sort(key=operator.attrgetter(
            'sparsity', 'top1', 'top5', 'epoch'),
                                 reverse=True)
        best = perf_scores_history[0]
        # Report the learning rate that belongs to the best record, not the
        # one of the current epoch.
        msglogger.info(
            '==> Best [Top1: %.3f   Top5: %.3f   Sparsity: %.2f on epoch: %d   Lr: %f   Loss: %f]',
            best.top1, best.top5, best.sparsity, best.epoch, best.lr,
            best.loss)

        best_precision = max(best.top1, best_precision)
        # Records store ``epoch + 1``, so compare against the same value.
        is_best = epoch + 1 == best.epoch
        if is_best:
            model.save({
                "epoch":
                epoch + 1,
                "model_name":
                opt.model,
                "state_dict":
                model.state_dict(),
                "best_precision":
                best_precision,
                "optimizer":
                optimizer,
                "valid_loss": [val_loss, val_top1, val_top5],
                'compression_scheduler':
                compression_scheduler.state_dict(),
            })
        # update learning rate
        lr_scheduler.step(epoch)
        # Release cached GPU memory that is no longer referenced.
        t.cuda.empty_cache()
예제 #4
0
    def train(self):
        """Run the training loop, validating and checkpointing every epoch.

        The model, compression schedule and summary writers are prepared by
        ``train_load_model``, ``load_compress`` and ``create_write``.  For
        every epoch the method reloads the data, trains on
        ``self.train_dataloader``, validates on ``self.val_dataloader``,
        records the validation scores and calls ``train_save_model`` when
        the current epoch ranks best (sorted by sparsity, then top-1, top-5
        and epoch).  ``self.best_precision`` is updated with the best top-1
        accuracy.
        """
        perf_scores_history = []  # one validation record per epoch, best first
        pylogger = PythonLogger(msglogger)
        self.train_load_model()
        self.load_compress()
        self.create_write()
        # Use the instance configuration consistently (the rest of the
        # method reads ``self.opt``).
        lr_scheduler = get_scheduler(self.optimizer, self.opt)
        for epoch in range(self.start_epoch, self.opt.max_epoch):
            self.model.train()
            self.load_data()
            if self.opt.pruning:
                self.compression_scheduler.on_epoch_begin(epoch)
            # Reset ALL meters so that every epoch reports its own averages.
            self.train_losses.reset()
            self.train_top1.reset()
            self.train_top5.reset()
            total_samples = len(self.train_dataloader.sampler)
            steps_per_epoch = math.ceil(total_samples / self.opt.batch_size)
            train_progressor = ProgressBar(mode="Train  ",
                                           epoch=epoch,
                                           total_epoch=self.opt.max_epoch,
                                           model_name=self.opt.model,
                                           total=len(self.train_dataloader))
            # ``get_lr`` returns one value per parameter group; keep the
            # first as a float so it can be formatted with ``%f`` below.
            lr = lr_scheduler.get_lr()[0]
            for ii, (data, labels,
                     img_path) in enumerate(self.train_dataloader):
                if self.opt.pruning:
                    self.compression_scheduler.on_minibatch_begin(
                        epoch, ii, steps_per_epoch, self.optimizer)
                train_progressor.current = ii + 1
                # train model
                inputs = data.to(self.opt.device)
                target = labels.to(self.opt.device)
                score = self.model(inputs)
                loss = self.criterion(score, target)
                if self.opt.pruning:
                    # Before running the backward phase, we allow the scheduler to modify the loss
                    # (e.g. add regularization loss)
                    agg_loss = self.compression_scheduler.before_backward_pass(
                        epoch,
                        ii,
                        steps_per_epoch,
                        loss,
                        optimizer=self.optimizer,
                        return_loss_components=True)
                    loss = agg_loss.overall_loss
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if self.opt.pruning:
                    self.compression_scheduler.on_minibatch_end(
                        epoch, ii, steps_per_epoch, self.optimizer)

                precision1_train, precision5_train = accuracy(
                    score, target, topk=(1, 5))

                # Each batch is recorded exactly once in every meter.
                batch_size = inputs.size(0)
                self.train_losses.update(loss.item(), batch_size)
                self.train_top1.update(precision1_train[0].item(), batch_size)
                self.train_top5.update(precision5_train[0].item(), batch_size)
                train_progressor.current_loss = self.train_losses.avg
                train_progressor.current_top1 = self.train_top1.avg
                train_progressor.current_top5 = self.train_top5.avg
                train_progressor()  # print the progress line
                if (ii + 1) % self.opt.print_freq == 0:
                    self.visualization_train(inputs, ii, epoch)
            if self.opt.pruning:
                distiller.log_weights_sparsity(self.model,
                                               epoch,
                                               loggers=[pylogger])
                self.compression_scheduler.on_epoch_end(
                    epoch, self.optimizer)
            val_loss, val_top1, val_top5 = val(self.model, self.criterion,
                                               self.val_dataloader, epoch,
                                               self.value_writer)
            sparsity = distiller.model_sparsity(self.model)
            perf_scores_history.append(
                distiller.MutableNamedTuple(
                    {
                        'sparsity': sparsity,
                        'top1': val_top1,
                        'top5': val_top5,
                        'epoch': epoch + 1,
                        'lr': lr,
                        'loss': val_loss
                    }, ))
            # Keep the history sorted from best to worst: sparsity is the
            # main key, followed by top1, top5 and epoch.
            perf_scores_history.sort(key=operator.attrgetter(
                'sparsity', 'top1', 'top5', 'epoch'),
                                     reverse=True)
            best = perf_scores_history[0]
            # Report the learning rate that belongs to the best record.
            msglogger.info(
                '==> Best [Top1: %.3f   Top5: %.3f   Sparsity: %.2f on epoch: %d   Lr: %f   Loss: %f]',
                best.top1, best.top5, best.sparsity, best.epoch, best.lr,
                best.loss)

            # Records store ``epoch + 1``, so compare against the same value;
            # comparing the raw ``epoch`` never matches the current record.
            is_best = epoch + 1 == best.epoch
            self.best_precision = max(best.top1, self.best_precision)
            if is_best:
                self.train_save_model(epoch, val_loss, val_top1, val_top5)
            # update learning rate: advance the scheduler, otherwise the
            # learning rate never changes during training.
            # NOTE(review): mirrors the module-level ``train`` function;
            # confirm nothing else steps this scheduler.
            lr_scheduler.step(epoch)
            # Release cached GPU memory that is no longer referenced.
            t.cuda.empty_cache()
예제 #5
0
def main():
    """Command-line entry point for training/compressing a ResNet152 model.

    The function parses the command-line arguments, configures logging and
    the compute device, builds the model, criterion and Adam optimizer, and
    then dispatches on the arguments.  The following modes return early:
    automated deep compression (``args.AMC``), greedy search
    (``args.greedy``), model summary (``args.summary``), sensitivity
    analysis (``args.sensitivity``), evaluation (``args.evaluate``) and
    thinning (``args.thinnify``).  Otherwise it runs the training loop
    (optionally with a compression schedule and knowledge distillation),
    saving a checkpoint every epoch, and finally evaluates on the test set.
    """
    script_dir = os.path.dirname(__file__)
    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
    global msglogger

    # Parse arguments
    args = parser.get_parser().parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(
        os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

    # Log various details about the execution environment.  It is sometimes useful
    # to refer to past experiment executions and this information may be useful.
    apputils.log_execution_env_state(args.compress,
                                     msglogger.logdir,
                                     gitroot=module_path)
    msglogger.debug("Distiller: %s", distiller.__version__)

    start_epoch = 0
    perf_scores_history = []
    if args.deterministic:
        # Experiment reproducibility is sometimes important.  Pete Warden expounded about this
        # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        # In Pytorch, support for deterministic execution is still a bit clunky.
        if args.workers > 1:
            msglogger.error(
                'ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1'
            )
            exit(1)  # abort with a non-zero (error) exit status
        # Use a well-known seed, for repeatability of experiments
        distiller.set_deterministic()
    else:
        # This issue: https://github.com/pytorch/pytorch/issues/3659
        # Implies that cudnn.benchmark should respect cudnn.deterministic, but empirically we see that
        # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
        cudnn.benchmark = True

    if args.cpu or not torch.cuda.is_available():
        # Set GPU index to -1 if using CPU
        args.device = 'cpu'
        args.gpus = -1
    else:
        args.device = 'cuda'
        if args.gpus is not None:
            try:
                args.gpus = [int(s) for s in args.gpus.split(',')]
            except ValueError:
                msglogger.error(
                    'ERROR: Argument --gpus must be a comma-separated list of integers only'
                )
                exit(1)
            available_gpus = torch.cuda.device_count()
            for dev_id in args.gpus:
                if dev_id >= available_gpus:
                    msglogger.error(
                        'ERROR: GPU device ID {0} requested, but only {1} devices available'
                        .format(dev_id, available_gpus))
                    exit(1)
            # Set default device in case the first one on the list != 0
            torch.cuda.set_device(args.gpus[0])

    # The dataset name is hard-coded here rather than inferred from the
    # model name.
    # NOTE(review): 'cousm' looks like a typo for 'custom' - confirm against
    # the consumers of args.dataset before changing it.
    args.dataset = 'cousm'

    if args.earlyexit_thresholds:
        args.num_exits = len(args.earlyexit_thresholds) + 1
        args.loss_exits = [0] * args.num_exits
        args.losses_exits = []
        args.exiterrors = []

    # Create the model
    model = ResNet152()
    # model = torch.nn.DataParallel(model, device_ids=args.gpus)  # multi-GPU data parallelism
    model.to(args.device)
    compression_scheduler = None  # created below (checkpoint, YAML config or default)
    # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
    # that can be read by Google's Tensor Board.  PythonLogger writes to the Python logger.
    tflogger = TensorBoardLogger(msglogger.logdir)
    pylogger = PythonLogger(msglogger)

    # capture thresholds for early-exit training
    if args.earlyexit_thresholds:
        msglogger.info('=> using early-exit threshold values of %s',
                       args.earlyexit_thresholds)

    # We can optionally resume from a checkpoint
    if args.resume:  # load a previously trained model
        # checkpoint = torch.load(args.resume)
        # model.load_state_dict(checkpoint['state_dict'])
        model, compression_scheduler, start_epoch = apputils.load_checkpoint(
            model, chkpt_file=args.resume)
        model.to(args.device)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(args.device)

    # optimizer = torch.optim.SGD(model.fc.parameters(), lr=args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)
    # Only the parameters of the final fully-connected layer
    # (model.model.fc) are handed to the optimizer.
    optimizer = torch.optim.Adam(model.model.fc.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    msglogger.info('Optimizer Type: %s', type(optimizer))
    msglogger.info('Optimizer Args: %s', optimizer.defaults)

    if args.AMC:  # automated deep compression
        return automated_deep_compression(model, criterion, optimizer,
                                          pylogger, args)
    if args.greedy:  # greedy search
        return greedy(model, criterion, optimizer, pylogger, args)

    # This sample application can be invoked to produce various summary reports.
    if args.summary:
        return summarize_model(model, args.dataset, which_summary=args.summary)
    # Activation statistics collectors
    activations_collectors = create_activation_stats_collectors(
        model, *args.activation_stats)

    if args.qe_calibration:
        msglogger.info('Quantization calibration stats collection enabled:')
        msglogger.info(
            '\tStats will be collected for {:.1%} of test dataset'.format(
                args.qe_calibration))
        msglogger.info(
            '\tSetting constant seeds and converting model to serialized execution'
        )
        distiller.set_deterministic()
        model = distiller.make_non_parallel_copy(model)
        activations_collectors.update(
            create_quantization_stats_collector(model))  # quantization statistics collector
        # Calibration runs in evaluation mode on the requested fraction of
        # the test set.
        args.evaluate = True
        args.effective_test_size = args.qe_calibration

    # Load the datasets.  Unlike the upstream sample application, the
    # dataset is not inferred from args.arch: the data root directory is
    # hard-coded in the call below.

    train_loader, val_loader, test_loader, _ = get_data_loaders(
        datasets_fn, r'/home/tian/Desktop/image_yasuo', args.batch_size,
        args.workers, args.validation_split, args.deterministic,
        args.effective_train_size, args.effective_valid_size,
        args.effective_test_size)
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler),
                   len(test_loader.sampler))
    # This sample application can be invoked to run a sensitivity analysis
    # on the model.
    if args.sensitivity is not None:
        sensitivities = np.arange(args.sensitivity_range[0],
                                  args.sensitivity_range[1],
                                  args.sensitivity_range[2])
        return sensitivity_analysis(model, criterion, test_loader, pylogger,
                                    args, sensitivities)

    if args.evaluate:
        return evaluate_model(model, criterion, test_loader, pylogger,
                              activations_collectors, args,
                              compression_scheduler)

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        compression_scheduler = distiller.file_config(model, optimizer,
                                                      args.compress,
                                                      compression_scheduler)
        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
        model.to(args.device)
    elif compression_scheduler is None:
        # No schedule file and no checkpoint: fall back to an empty scheduler.
        compression_scheduler = distiller.CompressionScheduler(model)

    if args.thinnify:
        # zeros_mask_dict = distiller.create_model_masks_dict(model)
        assert args.resume is not None, "You must use --resume to provide a checkpoint file to thinnify"
        distiller.remove_filters(model,
                                 compression_scheduler.zeros_mask_dict,
                                 args.arch,
                                 args.dataset,
                                 optimizer=None)
        apputils.save_checkpoint(0,
                                 args.arch,
                                 model,
                                 optimizer=None,
                                 scheduler=compression_scheduler,
                                 name="{}_thinned".format(
                                     args.resume.replace(".pth.tar", "")),
                                 dir=msglogger.logdir)
        print(
            "Note: your model may have collapsed to random inference, so you may want to fine-tune"
        )
        return

    args.kd_policy = None  # knowledge-distillation policy (set below if enabled)
    if args.kd_teacher:
        teacher = create_model(args.kd_pretrained,
                               args.dataset,
                               args.kd_teacher,
                               device_ids=args.gpus)
        if args.kd_resume:
            teacher, _, _ = apputils.load_checkpoint(teacher,
                                                     chkpt_file=args.kd_resume)
        dlw = distiller.DistillationLossWeights(args.kd_distill_wt,
                                                args.kd_student_wt,
                                                args.kd_teacher_wt)
        args.kd_policy = distiller.KnowledgeDistillationPolicy(
            model, teacher, args.kd_temp, dlw)
        compression_scheduler.add_policy(args.kd_policy,
                                         starting_epoch=args.kd_start_epoch,
                                         ending_epoch=args.epochs,
                                         frequency=1)

        msglogger.info('\nStudent-Teacher knowledge distillation enabled:')
        msglogger.info('\tTeacher Model: %s', args.kd_teacher)
        msglogger.info('\tTemperature: %s', args.kd_temp)
        msglogger.info('\tLoss Weights (distillation | student | teacher): %s',
                       ' | '.join(['{:.2f}'.format(val) for val in dlw]))
        msglogger.info('\tStarting from Epoch: %s', args.kd_start_epoch)
    lr = args.lr
    lr_decay = 0.5  # factor applied to the learning rate after a non-best epoch
    for epoch in range(start_epoch, args.epochs):
        # This is the main training loop.
        msglogger.info('\n')
        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch)

        # Train for one epoch
        with collectors_context(activations_collectors["train"]) as collectors:
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  compression_scheduler,
                  loggers=[tflogger, pylogger],
                  args=args)
            distiller.log_weights_sparsity(model,
                                           epoch,
                                           loggers=[tflogger, pylogger])
            distiller.log_activation_statsitics(
                epoch,
                "train",
                loggers=[tflogger],
                collector=collectors["sparsity"])
            if args.masks_sparsity:  # log the masks sparsity table at the end of each epoch
                msglogger.info(
                    distiller.masks_sparsity_tbl_summary(
                        model, compression_scheduler))

        # evaluate on validation set
        with collectors_context(activations_collectors["valid"]) as collectors:
            top1, top5, vloss = validate(val_loader, model, criterion,
                                         [pylogger], args, epoch)
            distiller.log_activation_statsitics(
                epoch,
                "valid",
                loggers=[tflogger],
                collector=collectors["sparsity"])
            save_collectors_data(collectors, msglogger.logdir)

        stats = ('Peformance/Validation/',
                 OrderedDict([('Loss', vloss), ('Top1', top1),
                              ('Top5', top5)]))
        distiller.log_training_progress(stats,
                                        None,
                                        epoch,
                                        steps_completed=0,
                                        total_steps=1,
                                        log_freq=1,
                                        loggers=[tflogger])

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch, optimizer)

        # Update the list of top scores achieved so far, and save the checkpoint
        sparsity = distiller.model_sparsity(model)
        perf_scores_history.append(
            distiller.MutableNamedTuple({
                'sparsity': sparsity,
                'top1': top1,
                'top5': top5,
                'epoch': epoch
            }))
        # Keep perf_scores_history sorted from best to worst
        # Sort by sparsity as main sort key, then sort by top1, top5 and epoch
        perf_scores_history.sort(key=operator.attrgetter(
            'sparsity', 'top1', 'top5', 'epoch'),
                                 reverse=True)
        for score in perf_scores_history[:args.num_best_scores]:
            msglogger.info(
                '==> Best [Top1: %.3f   Top5: %.3f   Sparsity: %.2f on epoch: %d]',
                score.top1, score.top5, score.sparsity, score.epoch)

        is_best = epoch == perf_scores_history[0].epoch
        apputils.save_checkpoint(epoch, args.arch, model, optimizer,
                                 compression_scheduler,
                                 perf_scores_history[0].top1, is_best,
                                 args.name, msglogger.logdir)
        if not is_best:
            # The epoch that just finished did not rank best: scale the
            # learning rate of every parameter group by lr_decay.
            lr = lr * lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Finally run results on the test set
    test(test_loader,
         model,
         criterion, [pylogger],
         activations_collectors,
         args=args)