コード例 #1
0
def eval_model(cfg, net):
    """Evaluate ``net`` on the validation split under four protocols.

    The protocols are run in this order and reported under these heads:
    ``single_scale``, ``single_scale_crop``, ``ms_flip`` and
    ``ms_flip_crop``.

    Args:
        cfg: Configuration object. This function reads ``cfg.n_cats``,
            ``cfg.eval_crop`` and ``cfg.eval_scales`` and forwards ``cfg``
            to ``get_data_loader``.
        net: Model to evaluate. It must expose a writable ``aux_mode``
            attribute and an ``eval()`` method.

    Returns:
        A ``(heads, mious)`` pair of equal-length lists: the protocol names
        and the value each evaluator returned for that protocol.

    Side effects:
        ``net.eval()`` is called and is not undone here. ``net.aux_mode`` is
        temporarily set to ``'eval'`` and is restored to its previous value
        on exit, including when an evaluator raises.
    """
    org_aux = net.aux_mode
    net.aux_mode = 'eval'
    try:
        is_dist = dist.is_initialized()
        dl = get_data_loader(cfg, mode='val', distributed=is_dist)
        net.eval()

        heads, mious = [], []
        logger = logging.getLogger()

        # NOTE(review): MscEvalV0's positional arguments look like
        # (scales, flip) -- confirm against its definition.
        single_scale = MscEvalV0((1., ), False)
        mIOU = single_scale(net, dl, cfg.n_cats)
        heads.append('single_scale')
        mious.append(mIOU)
        logger.info('single mIOU is: %s\n', mIOU)

        single_crop = MscEvalCrop(
            cropsize=cfg.eval_crop,
            cropstride=2. / 3,
            flip=False,
            scales=(1., ),
            lb_ignore=255,
        )
        mIOU = single_crop(net, dl, cfg.n_cats)
        heads.append('single_scale_crop')
        mious.append(mIOU)
        logger.info('single scale crop mIOU is: %s\n', mIOU)

        ms_flip = MscEvalV0(cfg.eval_scales, True)
        mIOU = ms_flip(net, dl, cfg.n_cats)
        heads.append('ms_flip')
        mious.append(mIOU)
        logger.info('ms flip mIOU is: %s\n', mIOU)

        ms_flip_crop = MscEvalCrop(
            cropsize=cfg.eval_crop,
            cropstride=2. / 3,
            flip=True,
            scales=cfg.eval_scales,
            lb_ignore=255,
        )
        mIOU = ms_flip_crop(net, dl, cfg.n_cats)
        heads.append('ms_flip_crop')
        mious.append(mIOU)
        logger.info('ms crop mIOU is: %s\n', mIOU)
    finally:
        # Always put the model back in the aux mode the caller had set,
        # even if data loading or one of the evaluators fails.
        net.aux_mode = org_aux
    return heads, mious
コード例 #2
0
def eval_model(cfg, net):
    """Run four validation protocols on ``net`` and collect their scores.

    The number of classes (19), the crop size (1024) and the multi-scale
    factors are hard-coded in this variant; ``cfg`` is only forwarded to
    ``get_data_loader``.

    Args:
        cfg: Configuration object passed through to ``get_data_loader``.
        net: Model to evaluate; ``net.eval()`` is called on it.

    Returns:
        ``(heads, mious)``: the protocol names and, in the same order, the
        value each evaluator returned.
    """
    distributed = dist.is_initialized()
    val_loader = get_data_loader(cfg, mode='val', distributed=distributed)
    net.eval()

    heads, mious = [], []
    log = logging.getLogger()

    # Each entry: (head name, log format, factory building the evaluator).
    # Factories keep construction lazy so every evaluator is created right
    # before it is run, one after the other.
    protocols = (
        (
            'single_scale',
            'single mIOU is: %s\n',
            lambda: MscEvalV0((1., ), False),
        ),
        (
            'single_scale_crop',
            'single scale crop mIOU is: %s\n',
            lambda: MscEvalCrop(
                cropsize=1024,
                cropstride=2. / 3,
                flip=False,
                scales=(1., ),
                lb_ignore=255,
            ),
        ),
        (
            'ms_flip',
            'ms flip mIOU is: %s\n',
            lambda: MscEvalV0((0.5, 0.75, 1, 1.25, 1.5, 1.75), True),
        ),
        (
            'ms_flip_crop',
            'ms crop mIOU is: %s\n',
            lambda: MscEvalCrop(
                cropsize=1024,
                cropstride=2. / 3,
                flip=True,
                scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75),
                lb_ignore=255,
            ),
        ),
    )

    for head, message, build_evaluator in protocols:
        evaluator = build_evaluator()
        score = evaluator(net, val_loader, 19)
        heads.append(head)
        mious.append(score)
        log.info(message, score)

    return heads, mious
コード例 #3
0
def train():
    """Train the model, save the final weights and evaluate them.

    Reads the module-level ``cfg`` (``max_iter``, ``warmup_iters``,
    ``use_fp16``, ``respth``). Builds the data loader, model, losses,
    optimizer, gradient scaler, meters and LR scheduler through the
    project's ``set_*`` / ``get_data_loader`` helpers, runs one pass over
    the training loader, writes ``model_final.pth`` under ``cfg.respth``
    and logs a table of the evaluation results.

    The function works both when ``set_model_dist`` returns a wrapper that
    exposes the real model as ``.module`` and when it returns the model
    itself; likewise the checkpoint is written by rank 0 when a process
    group is initialised, and unconditionally otherwise.
    """
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg, mode='train', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## mixed precision training
    scaler = amp.GradScaler()

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## train loop
    for it, (im, lb) in enumerate(dl):
        im = im.cuda()
        lb = lb.cuda()

        ## drop dim 1 of the labels when it has size 1
        ## (presumably a channel axis -- confirm against the dataset)
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        with amp.autocast(enabled=cfg.use_fp16):
            ## first output is the main prediction, the rest are auxiliary
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [
                crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
            ]
            loss = loss_pre + sum(loss_aux)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        torch.cuda.synchronize()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        for mter, lss in zip(loss_aux_meters, loss_aux):
            mter.update(lss.item())

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(it, cfg.max_iter, lr, time_meter, loss_meter,
                          loss_pre_meter, loss_aux_meters)
        lr_schdr.step()

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to %s', save_pth)
    ## unwrap a DDP-style wrapper if there is one; fall back to the model
    ## itself so a non-wrapped run does not fail on ``.module``
    raw_net = getattr(net, 'module', net)
    state = raw_net.state_dict()
    ## only query the rank when a process group exists
    if not is_dist or dist.get_rank() == 0:
        torch.save(state, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(cfg, raw_net)
    logger.info(tabulate([
        mious,
    ], headers=heads, tablefmt='orgtbl'))