Example #1
def test_model(test_loader, model, it, cfg):
    print('\nModel testing starts ...\n')
    test_tic = time.time()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if cfg.mode_custom:
        if isinstance(model, nn.DataParallel):
            model.module.eval_mode(it, cfg)
        else:
            model.eval_mode(it, cfg)
    else:
        model.eval()

    with torch.no_grad():
        cost = AverageMeter()  # time spent processing images
        for i, (images, labels) in enumerate(test_loader):
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)

            xtic = time.time()
            if images.dim() == 4:
                outputs = model(images)
            elif images.dim() == 5:
                bh, ncrop, c, h, w = images.size()
                outputs = model(images.view(-1, c, h, w))
                if isinstance(outputs, (list, tuple)):
                    outputs = [ot.view(bh, ncrop, -1).mean(1) for ot in outputs]
                else:
                    outputs = outputs.view(bh, ncrop, -1).mean(1)
            else:
                raise NotImplementedError('expected images.dim() in [4, 5], but got %s' % images.dim())
            cost.update(time.time() - xtic, images.size(0))

            if isinstance(outputs, (list, tuple)):
                if isinstance(cfg.xfc_which, int):
                    outputs = outputs[cfg.xfc_which]
                elif cfg.xfc_which == 'all-avg':
                    outputs = sum(outputs) / len(outputs)
                elif cfg.xfc_which == 'aux-avg':
                    outputs = sum(outputs[:-1]) / len(outputs[:-1])
                else:
                    raise NotImplementedError

            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

    current_time = (time.time() - test_tic) / 60
    cfg.test_total_time += current_time

    print('\n*** Test Model at Iteration {it:d} --> '
          'Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\t '
          'CurrentTestTime {vtime:.4f} minutes\t TotalTestTime {ttime:.4f} minutes, '
          'SingleImage-AvgTestTime {sitime:.5f}*** \n'
          .format(it=it, top1=top1, top5=top5, vtime=current_time, ttime=cfg.test_total_time, sitime=cost.avg))
    return top1.avg, top5.avg
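This example (and the ones below) relies on AverageMeter and accuracy helpers that are not shown on this page. A minimal sketch of what such helpers typically look like, modeled on the standard PyTorch ImageNet example; this is an assumption about the codebase, not code taken from it:

import torch


class AverageMeter(object):
    """Keeps the latest value (.val) and a running average (.avg) of a metric."""

    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes precision@k over a batch for each k in topk."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / target.size(0))
                for k in topk]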
Example #2
def valid_model(val_loader, model, criterion, it, cfg, writer=None):
    valid_tic = time.time()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if cfg.mode_custom:
        if isinstance(model, nn.DataParallel):
            model.module.eval_mode(it, cfg)
        else:
            model.eval_mode(it, cfg)
    else:
        model.eval()

    with torch.no_grad():
        btic = time.time()
        for i, (images, labels) in enumerate(val_loader):
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)

            if images.dim() == 4:
                outputs = model(images)
            elif images.dim() == 5:
                bh, ncrop, c, h, w = images.size()
                outputs = model(images.view(-1, c, h, w))
                if isinstance(outputs, (list, tuple)):
                    outputs = [
                        ot.view(bh, ncrop, -1).mean(1) for ot in outputs
                    ]
                else:
                    outputs = outputs.view(bh, ncrop, -1).mean(1)
            else:
                raise NotImplementedError(
                    'expected images.dim() in [4, 5], but got %s' % images.dim())

            if isinstance(outputs, (list, tuple)):
                all_loss = [criterion(out, labels) for out in outputs]
                # validate a single xfc head, or the average over the xfc heads
                if isinstance(cfg.xfc_which, int):
                    loss = all_loss[cfg.xfc_which]
                    outputs = outputs[cfg.xfc_which]
                elif cfg.xfc_which == 'all-avg':
                    loss = sum(all_loss) / len(all_loss)
                    outputs = sum(outputs) / len(outputs)
                elif cfg.xfc_which == 'aux-avg':
                    loss = sum(all_loss[:-1]) / len(all_loss[:-1])
                    outputs = sum(outputs[:-1]) / len(outputs[:-1])
                else:
                    raise NotImplementedError
            else:
                loss = criterion(outputs, labels)

            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - btic)
            btic = time.time()

            if it == 0 and i % 20 == 0:
                print('Valid: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} \t Loss {loss.val:.4f} \t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

    if writer is not None:
        writer.add_scalar(tag='val-loss',
                          scalar_value=losses.avg,
                          global_step=it)
        writer.add_scalar(tag='val-prec1',
                          scalar_value=top1.avg,
                          global_step=it)
        writer.add_scalar(tag='val-prec5',
                          scalar_value=top5.avg,
                          global_step=it)

    current_time = (time.time() - valid_tic) / 60
    cfg.valid_total_time += current_time
    print(
        '\nValidate Iteration <{it:d}> --> '
        'prec@1-{top1.avg:.3f} prec@5-{top5.avg:.3f} Loss {losses.avg:.3f}\t '
        'CurrentValidTime {vtime:.4f} minutes\t TotalValidTime {ttime:.4f} minutes ****** \n'
        .format(it=it,
                top1=top1,
                top5=top5,
                losses=losses,
                vtime=current_time,
                ttime=cfg.valid_total_time))

    return top1.avg, top5.avg, losses.avg
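In both test_model and valid_model, the images.dim() == 5 branch handles multi-crop evaluation, where the dataloader yields several crops per sample (for example a torchvision TenCrop transform). A small, self-contained illustration of the reshape-and-average step; the shapes and the 1000-class output are assumptions made for the sketch:

import torch

bh, ncrop, c, h, w = 2, 10, 3, 224, 224           # batch of 2 samples, 10 crops each
images = torch.randn(bh, ncrop, c, h, w)          # what a TenCrop-style loader yields
flat = images.view(-1, c, h, w)                   # (bh * ncrop, c, h, w) is what the model sees
logits = torch.randn(bh * ncrop, 1000)            # stand-in for the model output
avg_logits = logits.view(bh, ncrop, -1).mean(1)   # average the predictions over the crops
print(avg_logits.shape)                           # torch.Size([2, 1000])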
Example #3
def train_model(cfg=Config()):

    # resumer-flow
    if cfg.resume:
        if cfg.resume_config:
            if not os.path.isfile(cfg.resume):
                raise FileNotFoundError(
                    '\nCannot find the .ckpt file: %s ...' % cfg.resume)
            print("\nloading config from checkpoint file at %s ..." %
                  (cfg.resume, ))
            # checkpoint = torch.load(f=cfg.resume, map_location=cfg.device)
            cfg.dict_to_class(torch.load(f=cfg.resume)['config'],
                              exclude=cfg.exclude_keys)
            cfg.start_iter = cfg.current_iter + 1
            print('config loaded, resuming from epoch {0} ......\n'.format(
                cfg.current_epoch))
            print(
                'current training state: train_prec1-%0.4f val_prec1-%0.4f val-prec5-%0.4f\n'
                %
                (cfg.best_train_prec1, cfg.best_val_prec1, cfg.best_val_prec5))
            print('model will be saved at ckpt_dir: {0} ......\n'.format(
                cfg.ckpt_dir))
            print('log will be saved at log_dir: {0} ......\n'.format(
                cfg.log_dir))
            checkpoint = None
    print('\nExp-%s start ... \n' % (cfg.exp_version, ))

    # data-flow
    train_loader, val_loader, test_loader = \
        data_factory(cfg.dataset, cfg.data_root, cfg.bsize_train, cfg.bsize_val,
                     cfg.data_augment, cfg.data_workers, result='loader', **cfg.data_kwargs)
    assert cfg.batch_nums == len(train_loader), \
        '<batch_nums> must equal <train_size> / <batch_size>'
    assert cfg.max_iters % cfg.batch_nums == 0, \
        'max_iters must be divisible by batch_nums, otherwise the last epoch sees an incomplete pass over the data'
    cfg.max_epochs = cfg.max_iters // cfg.batch_nums
    print(
        '\nTrain BatchNums-> %s : Val BatchNums-> %s : Test BatchNums-> %s\n' %
        (len(train_loader), len(val_loader), len(test_loader)))

    # model-flow
    model, params, gflops, mdepth = model_factory(cfg.arch_name,
                                                  cfg.arch_kwargs,
                                                  cfg.dataset,
                                                  with_info='return')
    model = model.to(cfg.device)
    if len(cfg.gpu_ids) > 1:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)
    if cfg.resume:
        model = xtils.load_ckpt_weights(model,
                                        cfg.resume,
                                        cfg.device,
                                        cfg.mgpus_to_sxpu,
                                        strict=cfg.resume_strict)

    # logger-flow
    if cfg.train_val_test[0]:
        writer = SummaryWriter(cfg.log_dir)
        # writer.add_graph(model, torch.zeros(4, 3, 224, 224).to(cfg.device))

    # judger
    criterion = nn.CrossEntropyLoss()

    # optimizer
    if cfg.optim_custom:
        if isinstance(model, nn.DataParallel):
            optimizer = model.module.init_optimizer(cfg=cfg)
        else:
            optimizer = model.init_optimizer(cfg=cfg)
    else:
        optimizer = get_optimizer(model, cfg=cfg)
    if cfg.resume and cfg.resume_optimizer:
        optimizer.load_state_dict(torch.load(cfg.resume)['optimizer'])

    # batch iterations flow
    epoch = 0
    current_lr, data_iter = None, None
    epoch_tic = time.time()
    epoch_time = AverageMeter()
    if cfg.resume:
        epoch = cfg.start_iter // cfg.batch_nums
        cfg.current_epoch = epoch
        data_iter = iter(train_loader)
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        data_time = AverageMeter()
        batch_time = AverageMeter()

    print(
        '\n=> Warning: Train/Val/Test iterations are about to start, please confirm your config!\n'
    )
    cfg.show_config()

    # iteration-flow
    if cfg.train_val_test[0]:
        for it in range(cfg.start_iter, cfg.max_iters):

            cfg.current_iter = it
            batch_tic = time.time()

            if it == 0 or it % cfg.batch_nums == 0:
                epoch = it // cfg.batch_nums
                cfg.current_epoch = epoch
                if cfg.data_shuffle:
                    train_loader, _, _ = data_factory(cfg.dataset,
                                                      cfg.data_root,
                                                      cfg.bsize_train,
                                                      cfg.bsize_val,
                                                      cfg.data_augment,
                                                      cfg.data_workers,
                                                      result='loader',
                                                      **cfg.data_kwargs)
                data_iter = iter(train_loader)
                losses = AverageMeter()
                top1 = AverageMeter()
                top5 = AverageMeter()
                batch_time = AverageMeter()
                data_time = AverageMeter()
                epoch_time.update(time.time() - epoch_tic)
                epoch_tic = time.time()

            if cfg.mode_custom:
                if isinstance(model, nn.DataParallel):
                    model.module.train_mode(it, cfg)
                else:
                    model.train_mode(it, cfg)
            else:
                model.train()

            current_lr = adjust_learning_rate(optimizer, it, cfg)

            # data to device
            tic = time.time()
            images, labels = next(data_iter)
            data_time.update(time.time() - tic)
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)

            # forward
            outputs = model(images)
            if isinstance(outputs, (list, tuple)):
                all_loss = [criterion(out, labels) for out in outputs]
                loss = sum(all_loss)
            else:
                loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if isinstance(outputs, (list, tuple)):
                loss = all_loss[cfg.xfc_which]
                outputs = outputs[cfg.xfc_which]

            # measure accuracy and record loss of
            # current batch(.val) and current epoch(.avg)
            prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - batch_tic)
            cfg.train_total_time += batch_time.val

            cfg.curr_train_prec1, cfg.curr_train_prec5 = top1.val, top5.val

            # print progress info
            if (it + 1) % cfg.print_frequency == 0:
                print(
                    'Epoch: [{0}/{1}][{2}/{3} {4:.2f}%] '
                    'BatchTime:{batch_time.val:.4f}s({batch_time.avg:.4f}s) '
                    'DataTime:{data_time.val:.5f}s({data_time.avg:.5f}s) '
                    'Loss-{loss.val:.4f}({loss.avg:.4f}) '
                    'Prec@1-{top1.val:.3f}({top1.avg:.4f}) '
                    'Prec@5-{top5.val:.3f}({top5.avg:.4f}) '
                    'lr-{lr:.5f} {cost:.1f}Hours EpochTime:{epoch_time:.2f}Minutes'
                    .format(epoch,
                            cfg.max_epochs,
                            it,
                            cfg.max_iters,
                            100 * it / cfg.max_iters,
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            top1=top1,
                            top5=top5,
                            lr=current_lr,
                            cost=cfg.train_total_time / 3600,
                            epoch_time=epoch_time.val / 60))

            # plot curves to TensorBoard
            if (it + 1) % cfg.plot_frequency == 0:
                writer.add_scalar(tag='train-loss',
                                  scalar_value=losses.val,
                                  global_step=it)
                writer.add_scalar(tag='train-prec1',
                                  scalar_value=top1.val,
                                  global_step=it)
                writer.add_scalar(tag='train-prec5',
                                  scalar_value=top5.val,
                                  global_step=it)
                writer.add_scalar(tag='learning-rate',
                                  scalar_value=current_lr,
                                  global_step=it)

            # model validation
            if cfg.train_val_test[1] and (
                (it + 1) >= cfg.val_frequency[0] and
                (it + 1) % cfg.val_frequency[1] == 0):
                cfg.curr_val_prec1, cfg.curr_val_prec5, curr_val_loss \
                    = valid_model(val_loader, model, criterion, it, cfg, writer)

            # model testing
            if cfg.train_val_test[2] and (
                (it + 1) >= cfg.test_frequency[0] and
                (it + 1) % cfg.test_frequency[1] == 0):
                curr_test_prec1, curr_test_prec5 = test_model(
                    test_loader, model, it, cfg)

            # record the best values seen so far and save the corresponding model
            # only start tracking a new best after a given number of iterations (or epochs),
            # to avoid saving too many checkpoints early in training
            cfg.best_val_prec1 = max(cfg.curr_val_prec1, cfg.best_val_prec1)
            cfg.best_train_prec1 = max(cfg.curr_train_prec1,
                                       cfg.best_train_prec1)
            cfg.best_val_prec5 = max(cfg.curr_val_prec5, cfg.best_val_prec5)
            cfg.best_train_prec5 = max(cfg.curr_train_prec5,
                                       cfg.best_train_prec5)

            if (it + 1) >= cfg.batch_nums * cfg.best_prec[
                    'best_start'] and cfg.best_prec[
                        'val_prec1'] < cfg.curr_val_prec1:
                cfg.best_prec['val_prec1'] = cfg.curr_val_prec1
                cfg.best_prec['best_ok'] = True
                cfg.ckpt_suffix = '-best'
            else:
                cfg.best_prec['best_ok'] = False
                cfg.ckpt_suffix = '-norm'

            # save the model checkpoint
            if ((it + 1) >= cfg.save_frequency[0] and (it + 1) % cfg.save_frequency[1] == 0) \
                    or it == cfg.max_iters - 1 or cfg.best_prec['best_ok']:
                checkpoint = {
                    'model': model.state_dict(),
                    'config': cfg.class_to_dict(),
                    'optimizer': optimizer.state_dict()
                }
                filename = '%s-%s%s-ep%s-it%d-acc%.2f-best%.2f-topv%.2f-par%.2fM%s-%s.ckpt' % \
                           (cfg.dataset, cfg.arch_name, mdepth, epoch, it, cfg.curr_val_prec1,
                            cfg.best_val_prec1, cfg.best_val_prec5, params, cfg.ckpt_suffix, cfg.exp_version)
                if not os.path.exists(cfg.ckpt_dir):
                    os.makedirs(cfg.ckpt_dir)
                print('\n *** Model will be saved at: %s ..., %s \n' %
                      (cfg.ckpt_dir, filename))
                torch.save(checkpoint, f=os.path.join(cfg.ckpt_dir, filename))

    if cfg.train_val_test[1]:
        valid_model(val_loader, model, criterion, 0, cfg, writer=None)

    if cfg.train_val_test[2]:
        test_model(test_loader, model, 0, cfg)
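train_model additionally depends on get_optimizer and adjust_learning_rate helpers that are not shown here. A hedged sketch of plausible implementations, using plain SGD and a step decay by iteration; the cfg fields lr_start and lr_decay_iters are hypothetical names used only for this illustration:

import torch


def get_optimizer(model, cfg):
    # a common default: SGD with momentum and weight decay (assumed, not taken from this repo)
    return torch.optim.SGD(model.parameters(), lr=cfg.lr_start,
                           momentum=0.9, weight_decay=1e-4)


def adjust_learning_rate(optimizer, it, cfg):
    # simple step decay by iteration count (assumed schedule, for illustration only)
    lr = cfg.lr_start * (0.1 ** (it // cfg.lr_decay_iters))
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr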