Exemplo n.º 1
0
def validate_general(val_loader, model, infer_model_fn, cuda=False):
    """Measure top-1 / top-5 accuracy of ``model`` over ``val_loader``.

    Inference is delegated to ``infer_model_fn`` so that the caller decides
    how ``model`` is actually executed.

    Args:
        val_loader: Iterable yielding ``(input_, target)`` pairs. ``len()``
            is taken of it when progress lines are printed.
        model: Passed through, untouched, as the first argument of
            ``infer_model_fn``.
        infer_model_fn: Callable invoked as ``infer_model_fn(model, input_)``;
            its result is handed to ``accuracy`` as the prediction.
        cuda: When true, each ``target`` is moved with
            ``target.cuda(None, non_blocking=True)`` before accuracy is
            computed. The input batch itself is not moved here.

    Returns:
        A ``(top1.avg, top5.avg)`` tuple taken from the two running meters.
    """
    from examples.classification.main import AverageMeter, accuracy

    progress_template = ('IE Test : [{0}/{1}]\t'
                         'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
    summary_template = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'

    meter_top1, meter_top5 = AverageMeter(), AverageMeter()

    for batch_idx, (input_, target) in enumerate(val_loader):
        # Run inference first; the target is only needed for scoring.
        predictions = infer_model_fn(model, input_)
        labels = target.cuda(None, non_blocking=True) if cuda else target

        batch_acc1, batch_acc5 = accuracy(predictions, labels, topk=(1, 5))

        # Both meters are weighted by the number of samples in the batch.
        batch_size = input_.size(0)
        meter_top1.update(batch_acc1, batch_size)
        meter_top5.update(batch_acc5, batch_size)

        # Progress is reported on every tenth batch only.
        if batch_idx % 10:
            continue
        print(progress_template.format(
            batch_idx, len(val_loader), top1=meter_top1, top5=meter_top5))

    print(summary_template.format(top1=meter_top1, top5=meter_top5))
    return meter_top1.avg, meter_top5.avg
def train_epoch_bin(train_loader, batch_multiplier, model, criterion, optimizer,
                    optimizer_scheduler: BinarizationOptimizerScheduler, kd_loss_calculator: KDLossCalculator,
                    compression_ctrl, epoch, config, is_inception=False):
    """Train ``model`` for one epoch with criterion loss plus KD loss.

    For every batch the criterion loss and the loss returned by
    ``kd_loss_calculator.loss(input_, output)`` are summed and
    back-propagated. Gradients are accumulated over ``batch_multiplier``
    consecutive batches before a single ``optimizer.step()``. Running
    averages of timings, losses and top-1 / top-5 accuracy are logged every
    ``config.print_freq`` batches and, on the main process, written to
    ``config.tb`` on every batch.

    Args:
        train_loader: Iterable of ``(input_, target)`` batches; ``len()`` of
            it is used for progress reporting and scheduler stepping.
        batch_multiplier: Number of consecutive batches whose gradients are
            accumulated before the optimizer steps. Must be non-zero.
        model: Module to train; switched to train mode on entry.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Optimizer whose ``zero_grad`` / ``step`` are driven here.
        optimizer_scheduler: Stepped once per batch with the fractional
            position ``i / len(train_loader)`` inside the epoch.
        kd_loss_calculator: Object providing ``loss(input_, output)``.
        compression_ctrl: Provides ``scheduler`` (stepped once per batch) and
            ``statistics()`` (numeric entries are written to ``config.tb``).
        epoch: Index of the current epoch; used for logging and for the
            global step written to ``config.tb``.
        config: Supplies ``device``, ``print_freq``, ``rank``,
            ``multiprocessing_distributed`` and ``tb``.
        is_inception: When true, ``model(input_)`` is expected to return
            ``(output, aux_outputs)`` and the auxiliary loss is added with
            weight 0.4.

    Returns nothing; results are reported through ``logger`` and
    ``config.tb``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    kd_losses_meter = AverageMeter()
    criterion_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    compression_scheduler = compression_ctrl.scheduler

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input_, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        num_batches = len(train_loader)

        input_ = input_.to(config.device)
        target = target.to(config.device)
        batch_size = input_.size(0)

        # compute output
        if is_inception:
            # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
            output, aux_outputs = model(input_)
            loss1 = criterion(output, target)
            loss2 = criterion(aux_outputs, target)
            criterion_loss = loss1 + 0.4 * loss2
        else:
            output = model(input_)
            criterion_loss = criterion(output, target)

        # compute KD loss
        kd_loss = kd_loss_calculator.loss(input_, output)
        loss = criterion_loss + kd_loss

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        kd_losses_meter.update(kd_loss.item(), batch_size)
        criterion_losses.update(criterion_loss.item(), batch_size)
        # Each meter is updated exactly once per batch so that their sample
        # counts stay in sync.
        top1.update(acc1, batch_size)
        top5.update(acc5, batch_size)

        # compute gradient and do SGD step
        # Gradients are cleared at the start of each accumulation group and
        # applied at its end. Clearing right before the step would throw away
        # everything accumulated by the preceding batches of the group.
        # NOTE: the loss is not divided by `batch_multiplier`, so the applied
        # gradient is the sum over the group.
        if i % batch_multiplier == 0:
            optimizer.zero_grad()
        loss.backward()
        # Also step on the final batch so a trailing, incomplete group is not
        # silently discarded by the next epoch's zero_grad().
        if (i + 1) % batch_multiplier == 0 or i + 1 == num_batches:
            optimizer.step()

        compression_scheduler.step()
        optimizer_scheduler.step(float(i) / num_batches)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0:
            logger.info(
                '{rank}: '
                'Epoch: [{0}][{1}/{2}] '
                'Lr: {3:.3} '
                'Wd: {4:.3} '
                'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
                'CE_loss: {ce_loss.val:.4f} ({ce_loss.avg:.4f}) '
                'KD_loss: {kd_loss.val:.4f} ({kd_loss.avg:.4f}) '
                'Loss: {loss.val:.4f} ({loss.avg:.4f}) '
                'Acc@1: {top1.val:.3f} ({top1.avg:.3f}) '
                'Acc@5: {top5.val:.3f} ({top5.avg:.3f})'.format(
                    epoch, i, num_batches, get_lr(optimizer), get_wd(optimizer), batch_time=batch_time,
                    data_time=data_time, ce_loss=criterion_losses, kd_loss=kd_losses_meter,
                    loss=losses, top1=top1, top5=top5,
                    rank='{}:'.format(config.rank) if config.multiprocessing_distributed else ''
                ))

        if is_main_process():
            # Global step counts batches across epochs so curves from
            # successive epochs line up on one axis.
            global_step = num_batches * epoch
            config.tb.add_scalar("train/learning_rate", get_lr(optimizer), i + global_step)
            config.tb.add_scalar("train/criterion_loss", criterion_losses.avg, i + global_step)
            config.tb.add_scalar("train/kd_loss", kd_losses_meter.avg, i + global_step)
            config.tb.add_scalar("train/loss", losses.avg, i + global_step)
            config.tb.add_scalar("train/top1", top1.avg, i + global_step)
            config.tb.add_scalar("train/top5", top5.avg, i + global_step)

            # Only plain numeric statistics can be written as scalars.
            for stat_name, stat_value in compression_ctrl.statistics().items():
                if isinstance(stat_value, (int, float)):
                    config.tb.add_scalar('train/statistics/{}'.format(stat_name), stat_value, i + global_step)