Example #1
def loading_data():
    mean_std = cfg.DATA.MEAN_STD
    train_simul_transform = own_transforms.Compose([
        own_transforms.Scale(int(cfg.TRAIN.IMG_SIZE[0] / 0.875)),
        own_transforms.RandomCrop(cfg.TRAIN.IMG_SIZE),
        own_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = own_transforms.Compose([
        own_transforms.Scale(int(cfg.TRAIN.IMG_SIZE[0] / 0.875)),
        own_transforms.CenterCrop(cfg.TRAIN.IMG_SIZE)
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        own_transforms.MaskToTensor(),
        own_transforms.ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)
    ])
    restore_transform = standard_transforms.Compose([
        own_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train', simul_transform=train_simul_transform, transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=cfg.TRAIN.BATCH_SIZE, num_workers=16, shuffle=True)
    val_set = CityScapes('val', simul_transform=val_simul_transform, transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=cfg.VAL.BATCH_SIZE, num_workers=16, shuffle=False)

    return train_loader, val_loader, restore_transform
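
A minimal usage sketch for the factory above, assuming cfg, own_transforms, and CityScapes are configured as in the snippet (the loop body is illustrative):

train_loader, val_loader, restore_transform = loading_data()
for img, target in train_loader:
    # img is a normalized float tensor; target is a label map with the
    # ignore label remapped to cfg.DATA.NUM_CLASSES - 1
    pass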
Example #2
def evaluate(respth='./res', dspth='./data'):
    ## logger
    logger = logging.getLogger()

    ## model
    logger.info('\n')
    logger.info('====' * 20)
    logger.info('evaluating the model ...\n')
    logger.info('setup and restore model')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    save_pth = osp.join(respth, 'model_final.pth')
    net.load_state_dict(torch.load(save_pth))
    net.cuda()
    net.eval()

    ## dataset
    batchsize = 5
    n_workers = 2
    dsval = CityScapes(dspth, mode='val')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    ## evaluator
    logger.info('compute the mIOU')
    evaluator = MscEval(net, dl)

    ## eval
    mIOU = evaluator.evaluate()
    logger.info('mIOU is: {:.6f}'.format(mIOU))
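
The function above assumes an already-configured root logger; a minimal setup using only the standard library might look like:

import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(message)s')
evaluate(respth='./res', dspth='./data')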
Example #3
def evaluate(respth='./snapshots', dspth='../data/cityscapes'):
    logger = logging.getLogger()

    logger.info('\n')
    logger.info('====' * 20)
    logger.info('evaluating the model ...\n')
    logger.info('setup and restore model')
    n_classes = 19
    net = AttaNet(n_classes=n_classes)
    save_pth = osp.join(respth, 'model_final.pth')
    state_dict = torch.load(save_pth, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.cuda()
    net.eval()

    # dataset
    batchsize = 5
    n_workers = 2
    dsval = CityScapes(dspth, mode='val')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    # evaluator
    logger.info('compute the mIOU')
    evaluator = MscEval(net, dl)

    # eval
    IOUs, mIOU = evaluator.evaluate()
    print(IOUs)
    print(mIOU)
    logger.info('mIOU is: {:.6f}'.format(mIOU))
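
For reference, the per-class IOUs and mIOU printed above are conventionally derived from a confusion matrix; a minimal sketch of that computation (MscEval's internals are not shown here, so this is an assumption about its semantics):

import numpy as np

def ious_from_confusion(hist):
    # hist[i, j] counts pixels of ground-truth class i predicted as class j
    inter = np.diag(hist)
    union = hist.sum(axis=0) + hist.sum(axis=1) - inter
    ious = inter / np.maximum(union, 1)
    return ious, ious.mean()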
Example #4
def __init__(self, cfg, *args, **kwargs):
    self.cfg = cfg
    self.distributed = dist.is_initialized()
    ## dataloader
    dsval = CityScapes(cfg, mode='val')
    sampler = None
    if self.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dsval)
    self.dl = DataLoader(dsval,
                         batch_size=cfg.eval_batchsize,
                         sampler=sampler,
                         shuffle=False,
                         num_workers=cfg.eval_n_workers,
                         drop_last=False)
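
With a DistributedSampler each rank only sees a shard of the validation set, so per-rank statistics must be reduced before the final score is computed; a sketch using torch.distributed (the confusion-matrix tensor hist is an assumed name, and must live on the GPU when the NCCL backend is used):

import torch.distributed as dist

def reduce_hist(hist):
    # sum the per-rank confusion matrices so every rank
    # scores the full validation set
    if dist.is_initialized():
        dist.all_reduce(hist, op=dist.ReduceOp.SUM)
    return hist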
Example #5
def evaluatev0(respth='./pretrained',
               dspth='./data',
               backbone='CatNetSmall',
               scale=0.75,
               use_boundary_2=False,
               use_boundary_4=False,
               use_boundary_8=False,
               use_boundary_16=False,
               use_conv_last=False):
    print('scale', scale)
    print('use_boundary_2', use_boundary_2)
    print('use_boundary_4', use_boundary_4)
    print('use_boundary_8', use_boundary_8)
    print('use_boundary_16', use_boundary_16)
    ## dataset
    batchsize = 5
    n_workers = 2
    dsval = CityScapes(dspth, mode='val')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    n_classes = 19
    print("backbone:", backbone)
    net = BiSeNet(backbone=backbone,
                  n_classes=n_classes,
                  use_boundary_2=use_boundary_2,
                  use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8,
                  use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()

    with torch.no_grad():
        single_scale = MscEvalV0(scale=scale)
        mIOU = single_scale(net, dl, 19)
    logger = logging.getLogger()
    logger.info('mIOU is: %s\n', mIOU)
Example #6
def evaluate(respth='./res', dspth='./data/cityscapes', checkpoint=None):
    ## logger
    logger = logging.getLogger()

    ## model
    logger.info('\n')
    logger.info('====' * 20)
    logger.info('evaluating the model ...\n')
    logger.info('setup and restore model')
    n_classes = 19  #19
    net = ShelfNet(n_classes=n_classes)

    if checkpoint is None:
        save_pth = osp.join(respth, 'model_final.pth')
    else:
        save_pth = checkpoint

    net.load_state_dict(torch.load(save_pth))
    net.cuda()
    net.eval()

    ## dataset
    batchsize = 4
    n_workers = 10
    dsval = CityScapes(dspth, mode='val')
    # print("sjdusgdsds",dsval)
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    ## evaluator
    logger.info('compute the mIOU')
    evaluator = MscEval(net, dl, scales=[1.0], flip=False)
    ## eval
    mIOU = evaluator.evaluate()
    logger.info('mIOU is: {:.6f}'.format(mIOU))
Example #7
def evaluate(respth='./res', dspth='./data'):
    ## logger
    logger = logging.getLogger()

    ## model
    logger.info('\n')
    logger.info('====' * 20)
    logger.info('evaluating the model ...\n')
    logger.info('setup and restore model')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    save_pth = osp.join(respth, 'model_final.pth')
    net.load_state_dict(torch.load(save_pth))
    net.cuda()
    net.eval()

    ## dataset
    batchsize = 2
    n_workers = 1
    dsval = CityScapes(dspth, mode='val')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    ## evaluator
    logger.info('compute the mIOU')
    evaluator = MscEval(net, dl)

    # ## export to ONNX
    # dummy_input = Variable(torch.randn(batchsize, 3, 1024, 1024)).cuda()
    # onnx.export(net, dummy_input, "bisenet.proto", verbose=True)

    ## eval
    mIOU = evaluator.evaluate()
    logger.info('mIOU is: {:.6f}'.format(mIOU))
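
The commented-out export above relies on the deprecated Variable wrapper; a sketch of the same idea with the current torch.onnx API (the output file name is illustrative):

import torch

dummy_input = torch.randn(batchsize, 3, 1024, 1024).cuda()
torch.onnx.export(net, dummy_input, 'bisenet.onnx', verbose=True)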
Example #8
def train(verbose=True, **kwargs):
    args = kwargs['args']
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
                backend = 'nccl',
                init_method = 'tcp://127.0.0.1:{}'.format(cfg.port),
                world_size = torch.cuda.device_count(),
                rank = args.local_rank
                )
    setup_logger(cfg.respth)
    logger = logging.getLogger()

    ## dataset
    ds = CityScapes(cfg, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size = cfg.ims_per_gpu,
                    shuffle = False,
                    sampler = sampler,
                    num_workers = cfg.n_workers,
                    pin_memory = True,
                    drop_last = True)

    ## model
    net = Deeplab_v3plus(cfg)
    net.train()
    net.cuda()
    net = nn.parallel.DistributedDataParallel(net,
            device_ids = [args.local_rank, ],
            output_device = args.local_rank
            )
    n_min = cfg.ims_per_gpu*cfg.crop_size[0]*cfg.crop_size[1]//16
    criteria = OhemCELoss(thresh=cfg.ohem_thresh, n_min=n_min).cuda()

    ## optimizer
    optim = Optimizer(
            net,
            cfg.lr_start,
            cfg.momentum,
            cfg.weight_decay,
            cfg.warmup_steps,
            cfg.warmup_start_lr,
            cfg.max_iter,
            cfg.lr_power
            )

    ## train loop
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    n_epoch = 0
    for it in range(cfg.max_iter):
        try:
            im, lb = next(diter)
            if im.size()[0] != cfg.ims_per_gpu: continue
        except StopIteration:
            n_epoch += 1
            sampler.set_epoch(n_epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()

        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits = net(im)
        loss = criteria(logits, lb)
        loss.backward()
        optim.step()

        loss_avg.append(loss.item())
        ## print training log message
        if it % cfg.msg_iter == 0 and it != 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((cfg.max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds = eta))
            msg = ', '.join([
                    'iter: {it}/{max_it}',
                    'lr: {lr:.4f}',
                    'loss: {loss:.4f}',
                    'eta: {eta}',
                    'time: {time:.4f}',
                ]).format(
                    it = it,
                    max_it = cfg.max_iter,
                    lr = lr,
                    loss = loss_avg,
                    time = t_intv,
                    eta = eta
                )
            logger.info(msg)
            loss_avg = []
            st = ed

    ## dump the final model and evaluate the result
    if verbose:
        net.cpu()
        save_pth = osp.join(cfg.respth, 'model_final.pth')
        state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
        if dist.get_rank()==0: torch.save(state, save_pth)
        logger.info('training done, model saved to: {}'.format(save_pth))
        logger.info('evaluating the final model')
        net.cuda()
        net.eval()
        evaluator = MscEval(cfg)
        mIOU = evaluator(net)
        logger.info('mIOU is: {}'.format(mIOU))
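
OhemCELoss itself is not shown in these examples; a minimal sketch of online hard example mining over per-pixel cross-entropy, matching the thresh/n_min parameters used above (assumed semantics):

import torch
import torch.nn.functional as F

def ohem_ce(logits, labels, thresh=0.7, n_min=1, ignore_lb=255):
    # keep pixels whose loss exceeds -log(thresh); if too few qualify,
    # fall back to the n_min hardest pixels
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_lb,
                           reduction='none').view(-1)
    loss, _ = loss.sort(descending=True)
    n_min = min(n_min, loss.numel() - 1)
    loss_thresh = -torch.log(torch.tensor(thresh, device=loss.device))
    if loss[n_min] > loss_thresh:
        loss = loss[loss > loss_thresh]
    else:
        loss = loss[:n_min]
    return loss.mean()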
Example #9
# The transform pipeline below can also include resize -> crop, etc.
transform_data = transforms.Compose([

    # not needed as transform cty will do the scaling   transforms.Resize(( opt.img_height, opt.img_width), Image.NEAREST),
    #    transforms.Resize(int(opt.img_height * 1.12), Image.NEAREST),
    #    transforms.RandomCrop((256, 256)),
    #    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
])

transform_cty = transforms.Compose(
    [transforms.Resize((opt.img_height, opt.img_width), Image.NEAREST)])

city_data = CityScapes(split='train',
                       transform_cty=transform_cty,
                       transform_data=transform_data)

city_data_val = CityScapes(split='val',
                           transform_cty=transform_cty,
                           transform_data=transform_data)

# Training data loader
dataloader = DataLoader(
    city_data,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

val_dataloader = DataLoader(
    city_data_val,
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu,
)
Example #10
def train():
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:33271',
                            world_size=torch.cuda.device_count(),
                            rank=args.local_rank)
    setup_logger(respth)
    logger = logging.getLogger()

    ## dataset
    n_classes = 19
    n_img_per_gpu = 8
    n_workers = 4
    cropsize = [1024, 1024]
    ds = CityScapes('./data', cropsize=cropsize, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=n_workers,
                    pin_memory=True,
                    drop_last=True)

    ## model
    ignore_idx = 255
    net = BiSeNet(n_classes=n_classes)
    if args.ckpt is not None:
        net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[
                                                  args.local_rank,
                                              ],
                                              output_device=args.local_rank)
    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    criteria_p = OhemCELoss(thresh=score_thres,
                            n_min=n_min,
                            ignore_lb=ignore_idx)
    criteria_16 = OhemCELoss(thresh=score_thres,
                             n_min=n_min,
                             ignore_lb=ignore_idx)
    criteria_32 = OhemCELoss(thresh=score_thres,
                             n_min=n_min,
                             ignore_lb=ignore_idx)

    ## optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optim = Optimizer(model=net.module,
                      lr0=lr_start,
                      momentum=momentum,
                      wd=weight_decay,
                      warmup_steps=warmup_steps,
                      warmup_start_lr=warmup_start_lr,
                      max_iter=max_iter,
                      power=power)

    ## train loop
    msg_iter = 50
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            if im.size()[0] != n_img_per_gpu: raise StopIteration
        except StopIteration:
            epoch += 1
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        out, out16, out32 = net(im)
        lossp = criteria_p(out, lb)
        loss2 = criteria_16(out16, lb)
        loss3 = criteria_32(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
        optim.step()

        loss_avg.append(loss.item())
        ## print training log message
        if (it + 1) % msg_iter == 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(it=it + 1,
                      max_it=max_iter,
                      lr=lr,
                      loss=loss_avg,
                      time=t_intv,
                      eta=eta)
            logger.info(msg)
            loss_avg = []
            st = ed

    ## dump the final model
    save_pth = osp.join(respth, 'model_final.pth')
    net.cpu()
    state = net.module.state_dict() if hasattr(net,
                                               'module') else net.state_dict()
    if dist.get_rank() == 0: torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
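
parse_args is not shown; a sketch consistent with how args is used in this example (--local_rank is injected by torch.distributed.launch, --ckpt optionally warm-starts the model):

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ckpt', type=str, default=None)
    return parser.parse_args()

# typically launched as: python -m torch.distributed.launch --nproc_per_node=<N> train.py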
Example #11
def train(verbose=True, **kwargs):
    args = kwargs['args']
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:{}'.format(cfg.port),
                            world_size=torch.cuda.device_count(),
                            rank=args.local_rank)
    setup_logger(cfg.respth)
    logger = logging.getLogger()

    ## dataset
    ds = CityScapes(cfg, mode='train', num_copys=2)
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=cfg.ims_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=cfg.n_workers,
                    collate_fn=collate_fn2,
                    pin_memory=True,
                    drop_last=True)

    ## model
    net = Deeplab_v3plus(cfg)
    net.train()
    net.cuda()
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[
                                                  args.local_rank,
                                              ],
                                              output_device=args.local_rank)
    n_min = cfg.ims_per_gpu * cfg.crop_size[0] * cfg.crop_size[1] // 16
    criteria = OhemCELoss(thresh=cfg.ohem_thresh, n_min=n_min).cuda()
    Criterion = pgc_loss(use_pgc=[0, 1, 2], criteria=criteria)
    ## optimizer
    optim = Optimizer(net, cfg.lr_start, cfg.momentum, cfg.weight_decay,
                      cfg.warmup_steps, cfg.warmup_start_lr, cfg.max_iter,
                      cfg.lr_power)
    alpha, beta = cfg.alpha, cfg.beta
    ## train loop
    loss_avg = []
    pgc_avg = []
    ce_avg = []
    ssp_avg = []
    ohem_avg = []

    st = glob_st = time.time()
    diter = iter(dl)
    n_epoch = 0
    for it in range(cfg.max_iter):
        try:
            im, lb, overlap, flip = next(diter)
            if im.size()[0] != cfg.ims_per_gpu // 2:
                continue
        except StopIteration:
            n_epoch += 1
            sampler.set_epoch(n_epoch)
            diter = iter(dl)
            im, lb, overlap, flip = next(diter)
        im = im.cuda()
        lb = lb.cuda()

        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)
        optim.zero_grad()
        im1, im2 = im[::2], im[1::2]
        lb1, lb2 = lb[::2], lb[1::2]
        logits1 = net(im1)
        logits2 = net(im2)
        # logits = torch.cat([logits1[-1], logits2[-1]], dim=0)

        outputs = []
        for f1, f2 in zip(logits1, logits2):
            outputs.append([f1, f2])
        logits = torch.cat([logits1[-1], logits2[-1]], dim=0)

        mse, sym_ce, mid_mse, mid_ce, mid_l1, ce = Criterion(
            outputs, overlap, flip, lb)
        # loss = criteria(logits, lb)
        loss = beta * sym_ce + ce
        gc_loss = sum(mid_mse)
        loss += alpha * gc_loss
        loss.backward()

        optim.step()

        loss_avg.append(loss.item())
        ohem_avg.append(ce.item())
        pgc_avg.append(gc_loss.item())
        ssp_avg.append(sym_ce.item())
        ## print training log message
        if it % cfg.msg_iter == 0 and it != 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            ohem = sum(ohem_avg) / len(ohem_avg)
            pgc = sum(pgc_avg) / len(pgc_avg)
            ssp = sum(ssp_avg) / len(ssp_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((cfg.max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'iter: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'ohem: {ohem:.4f}',
                'pgc: {pgc:.4f}',
                'ssp: {ssp:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=it,
                max_it=cfg.max_iter,
                lr=lr,
                loss=loss_avg,
                time=t_intv,
                eta=eta,
                ohem=ohem,
                pgc=pgc,
                ssp=ssp,
            )
            logger.info(msg)
            loss_avg = []
            pgc_avg = []
            ssp_avg = []
            ohem_avg = []
            st = ed

    ## dump the final model and evaluate the result
    if verbose:
        net.cpu()
        save_pth = osp.join(cfg.respth, 'model_final.pth')
        state = net.module.state_dict() if hasattr(
            net, 'module') else net.state_dict()
        if dist.get_rank() == 0: torch.save(state, save_pth)
        logger.info('training done, model saved to: {}'.format(save_pth))
        logger.info('evaluating the final model')
        net.cuda()
        net.eval()
        evaluator = MscEval(cfg)
        mIOU = evaluator(net)
        logger.info('mIOU is: {}'.format(mIOU))
Example #12
class AverageMeter(object):  # name assumed; the class header is truncated in the source
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

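A short usage sketch of the running-average meter above (values are illustrative):

meter = AverageMeter()
meter.update(0.5, n=4)  # a batch of 4 samples with mean value 0.5
meter.update(0.3, n=4)
print(meter.avg)        # (0.5*4 + 0.3*4) / 8 = 0.4
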

if __name__ == "__main__":
    #----init dataloader----
    crop_size = [CROP_SIZE, CROP_SIZE]
    train_dataset = CityScapes(DATA_PATH, cropsize=crop_size, mode='train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=BATCH_SIZE,
                                                   num_workers=NUM_WORKERS,
                                                   shuffle=True,
                                                   pin_memory=True)

    #-----init model-----
    model = bisenet(19, training=True)
    model.train()  #---set status to training----
    resnet_state_dict = torch.load(
        RESNET_MODEL_PATH)  #---load pretrained resnet model---
    #model.cp.res18.load_state_dict({k:v for k, v in resnet_state_dict.items() if k in model.cp.res18.state_dict().keys()})
    res18_dict = {
        k: v
        for k, v in resnet_state_dict.items()
        if k in model.cp.res18.state_dict()
    }
    model.cp.res18.load_state_dict(res18_dict)
Example #13
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)

from cityscapes import CityScapes
from paz import processors as pr
from processors import Round, MasksToColors
from model import UNET_VGG16

label_path = '/home/octavio/Downloads/dummy/gtFine/'
image_path = '/home/octavio/Downloads/dummy/RGB_images/leftImg8bit/'
data_manager = CityScapes(image_path, label_path, 'test')
data = data_manager.load_data()


class PostprocessSegmentation(pr.SequentialProcessor):
    def __init__(self, model, colors=None):
        super(PostprocessSegmentation, self).__init__()
        self.add(pr.UnpackDictionary(['image_path']))
        self.add(pr.LoadImage())
        self.add(pr.ResizeImage(model.input_shape[1:3]))
        self.add(pr.ConvertColorSpace(pr.RGB2BGR))
        self.add(pr.SubtractMeanImage(pr.BGR_IMAGENET_MEAN))
        self.add(pr.ExpandDims(0))
        self.add(pr.Predict(model))
        self.add(pr.Squeeze(0))
        self.add(Round())
        self.add(MasksToColors(model.output_shape[-1], colors))
        self.add(pr.DenormalizeImage())
        self.add(pr.CastImage('uint8'))
        self.add(pr.ShowImage())
Example #14
def train(verbose=True, **kwargs):
    args = kwargs['args']
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:{}'.format(cfg.port),
                            world_size=torch.cuda.device_count(),
                            rank=args.local_rank)
    setup_logger(cfg.respth)
    logger = logging.getLogger()

    ## dataset
    ds = CityScapes(cfg, mode='train_val')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=cfg.ims_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=cfg.n_workers,
                    pin_memory=True,
                    drop_last=True)

    ## model
    net = EaNet(cfg)
    net.cuda()
    it_start = 0
    n_epoch = 0
    bestMIOU = 0

    ## optimizer
    optim = Optimizer(
        net,
        cfg.lr_start,
        cfg.momentum,
        cfg.weight_decay,
        cfg.warmup_steps,
        cfg.warmup_start_lr,
        cfg.max_iter,
        cfg.lr_power,
        # start_iter = it_start
    )

    ## resume
    if cfg.resume:
        print("=> loading checkpoint '{}'".format(cfg.resume))
        checkpoint = torch.load(cfg.resume)
        if '.tar' in cfg.resume:
            net.load_state_dict(checkpoint['model'])
            optim.optim.load_state_dict(checkpoint['optimizer'])
            # it_start = checkpoint['it']
            n_epoch = checkpoint['epoch']
            bestMIOU = checkpoint['mIOU']
            # optim.it = it_start

            print('Pth.Tar Load model from {}'.format(cfg.resume))
        else:
            net.load_state_dict(checkpoint)
            print('Pth Load model from {}'.format(cfg.resume))
        print('pretrained model loaded')
        net.eval()
        evaluator = MscEval(cfg)
        mIOU = evaluator(net)
        print('mIOU start from %f' % mIOU)
        del checkpoint

    net.train()

    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[
                                                  args.local_rank,
                                              ],
                                              output_device=args.local_rank)
    n_min = cfg.ims_per_gpu * cfg.crop_size[0] * cfg.crop_size[1] // 16
    #criteria = OhemCELoss(thresh=cfg.ohem_thresh, n_min=n_min).cuda()
    criteria = ECELoss(thresh=cfg.ohem_thresh,
                       n_min=n_min,
                       n_classes=cfg.n_classes,
                       alpha=cfg.alpha,
                       radius=cfg.radius,
                       beta=cfg.beta,
                       ignore_lb=cfg.ignore_label,
                       mode=cfg.mode).cuda()

    ## train loop
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    # n_epoch = 0
    counter = 0
    # count of epochs already finished
    epochF = 0

    for it in range(it_start, cfg.max_iter):
        try:
            im, lb = next(diter)
            if im.size()[0] != cfg.ims_per_gpu: continue
        except StopIteration:
            n_epoch += 1
            sampler.set_epoch(n_epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()

        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)

        try:
            optim.zero_grad()
            logits = net(im)
            loss = criteria(logits, lb)

            loss.backward()
            optim.step()
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory')
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                raise e
        '''
        logits = net(im)
        loss = criteria(logits, lb)
        loss = loss / (cfg.ims_per_gpu)
        counter += 1
        loss.backward()
        
        if counter == cfg.ims_per_gpu:
            optim.step()
            optim.zero_grad()
            counter = 0
        '''
        loss_avg.append(loss.item())
        ## print training log message
        if it % cfg.msg_iter == 0 and it != 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((cfg.max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'iter: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(it=it,
                      max_it=cfg.max_iter,
                      lr=lr,
                      loss=loss_avg,
                      time=t_intv,
                      eta=eta)

            logger.info(msg)
            loss_avg = []
            st = ed
        # evaluate periodically
        if n_epoch > epochF and n_epoch > 20:
            # mark this epoch as evaluated
            epochF = n_epoch
            #if (n_epoch > 35) and it%(5*cfg.msg_iter) == 0 and not it==0:
            # net.cpu()
            # save_pth = osp.join(cfg.respth, 'model_final_best.pth')
            # state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            # if dist.get_rank()==0: torch.save(state, save_pth)
            # logger.info('training done, model saved to: {}'.format(save_pth))
            # logger.info('evaluating the final model')
            # net.cuda()
            net.eval()
            evaluator = MscEval(cfg)
            mIOU = evaluator(net)
            logger.info('mIOU is: {}'.format(mIOU))

            # save checkpoint
            save_pth = osp.join(cfg.respth, 'checkpoint.pth.tar')
            state = net.module.state_dict() if hasattr(
                net, 'module') else net.state_dict()
            if dist.get_rank() == 0:
                stateF = {
                    'model': state,
                    'lr': optim.lr,
                    'mIOU': mIOU,
                    'it': it,
                    'epoch': n_epoch,
                    'optimizer': optim.optim.state_dict(),
                }
                torch.save(stateF, save_pth)

            if mIOU > bestMIOU:
                logger.info('Got a new best mIOU: {} at epoch: {}'.format(
                    mIOU, n_epoch))
                bestMIOU = mIOU
                #net.cpu()
                save_pth = osp.join(cfg.respth,
                                    'model_final_{}.pth'.format(n_epoch))
                state = net.module.state_dict() if hasattr(
                    net, 'module') else net.state_dict()
                if dist.get_rank() == 0: torch.save(state, save_pth)
                # reload onto the GPU
                #net.cuda()

            net.train()
    if verbose:
        net.cpu()
        save_pth = osp.join(cfg.respth, 'model_final.pth.tar')
        state = net.module.state_dict() if hasattr(
            net, 'module') else net.state_dict()
        stateF = {
            'model': state,
            'lr': optim.lr,
            'mIOU': mIOU,
            'it': it,
            'epoch': n_epoch,
            'optimizer': optim.optim.state_dict(),
        }
        torch.save(stateF, save_pth)
        #if dist.get_rank()==0: torch.save(state, save_pth)
        logger.info('training done, model saved to: {}'.format(save_pth))
        logger.info('evaluating the final model')
        net.cuda()
        net.eval()
        evaluator = MscEval(cfg)
        mIOU = evaluator(net)
        logger.info('mIOU is: {}'.format(mIOU))
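
A minimal sketch for resuming from the .pth.tar checkpoint dict saved above (the keys mirror stateF exactly; net, optim, and cfg are assumed to be constructed as in the example):

import os.path as osp
import torch

checkpoint = torch.load(osp.join(cfg.respth, 'checkpoint.pth.tar'),
                        map_location='cpu')
net.load_state_dict(checkpoint['model'])
optim.optim.load_state_dict(checkpoint['optimizer'])
n_epoch = checkpoint['epoch']
bestMIOU = checkpoint['mIOU']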
Example #15
def train():
    args = parse_args()
    
    save_pth_path = os.path.join(args.respath, 'pths')
    dspth = './data'
    
    # print(save_pth_path)
    # print(osp.exists(save_pth_path))
    # if not osp.exists(save_pth_path) and dist.get_rank()==0: 
    if not osp.exists(save_pth_path):
        os.makedirs(save_pth_path)
    
    
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
                backend = 'nccl',
                init_method = 'tcp://127.0.0.1:33274',
                world_size = torch.cuda.device_count(),
                rank=args.local_rank
                )
    
    setup_logger(args.respath)
    logger = logging.getLogger()
    ## dataset
    n_classes = 19
    n_img_per_gpu = args.n_img_per_gpu
    n_workers_train = args.n_workers_train
    n_workers_val = args.n_workers_val
    use_boundary_16 = args.use_boundary_16
    use_boundary_8 = args.use_boundary_8
    use_boundary_4 = args.use_boundary_4
    use_boundary_2 = args.use_boundary_2
    
    mode = args.mode
    cropsize = [1024, 512]
    randomscale = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)

    if dist.get_rank()==0: 
        logger.info('n_workers_train: {}'.format(n_workers_train))
        logger.info('n_workers_val: {}'.format(n_workers_val))
        logger.info('use_boundary_2: {}'.format(use_boundary_2))
        logger.info('use_boundary_4: {}'.format(use_boundary_4))
        logger.info('use_boundary_8: {}'.format(use_boundary_8))
        logger.info('use_boundary_16: {}'.format(use_boundary_16))
        logger.info('mode: {}'.format(args.mode))
    
    
    ds = CityScapes(dspth, cropsize=cropsize, mode=mode, randomscale=randomscale)
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size = n_img_per_gpu,
                    shuffle = False,
                    sampler = sampler,
                    num_workers = n_workers_train,
                    pin_memory = False,
                    drop_last = True)
    # exit(0)
    dsval = CityScapes(dspth, mode='val', randomscale=randomscale)
    sampler_val = torch.utils.data.distributed.DistributedSampler(dsval)
    dlval = DataLoader(dsval,
                    batch_size = 2,
                    shuffle = False,
                    sampler = sampler_val,
                    num_workers = n_workers_val,
                    drop_last = False)

    ## model
    ignore_idx = 255
    net = BiSeNet(backbone=args.backbone,
                  n_classes=n_classes,
                  pretrain_model=args.pretrain_path,
                  use_boundary_2=use_boundary_2,
                  use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8,
                  use_boundary_16=use_boundary_16,
                  use_conv_last=args.use_conv_last)

    if args.ckpt is not None:
        net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
            device_ids = [args.local_rank, ],
            output_device = args.local_rank,
            find_unused_parameters=True
            )

    score_thres = 0.7
    n_min = n_img_per_gpu*cropsize[0]*cropsize[1]//16
    criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    boundary_loss_func = DetailAggregateLoss()
    ## optimizer
    maxmIOU50 = 0.
    maxmIOU75 = 0.
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = args.max_iter
    save_iter_sep = args.save_iter_sep
    power = 0.9
    warmup_steps = args.warmup_steps
    warmup_start_lr = 1e-5

    if dist.get_rank()==0: 
        print('max_iter: ', max_iter)
        print('save_iter_sep: ', save_iter_sep)
        print('warmup_steps: ', warmup_steps)
    optim = Optimizer(
            model = net.module,
            loss = boundary_loss_func,
            lr0 = lr_start,
            momentum = momentum,
            wd = weight_decay,
            warmup_steps = warmup_steps,
            warmup_start_lr = warmup_start_lr,
            max_iter = max_iter,
            power = power)
    
    ## train loop
    msg_iter = 50
    loss_avg = []
    loss_boundery_bce = []
    loss_boundery_dice = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            if im.size()[0] != n_img_per_gpu: raise StopIteration
        except StopIteration:
            epoch += 1
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()


        if use_boundary_2 and use_boundary_4 and use_boundary_8:
            out, out16, out32, detail2, detail4, detail8 = net(im)
        
        if (not use_boundary_2) and use_boundary_4 and use_boundary_8:
            out, out16, out32, detail4, detail8 = net(im)

        if (not use_boundary_2) and (not use_boundary_4) and use_boundary_8:
            out, out16, out32, detail8 = net(im)

        if (not use_boundary_2) and (not use_boundary_4) and (not use_boundary_8):
            out, out16, out32 = net(im)

        lossp = criteria_p(out, lb)
        loss2 = criteria_16(out16, lb)
        loss3 = criteria_32(out32, lb)
        
        boundery_bce_loss = 0.
        boundery_dice_loss = 0.
        
        
        if use_boundary_2: 
            # if dist.get_rank()==0:
            #     print('use_boundary_2')
            boundery_bce_loss2,  boundery_dice_loss2 = boundary_loss_func(detail2, lb)
            boundery_bce_loss += boundery_bce_loss2
            boundery_dice_loss += boundery_dice_loss2
        
        if use_boundary_4:
            # if dist.get_rank()==0:
            #     print('use_boundary_4')
            boundery_bce_loss4,  boundery_dice_loss4 = boundary_loss_func(detail4, lb)
            boundery_bce_loss += boundery_bce_loss4
            boundery_dice_loss += boundery_dice_loss4

        if use_boundary_8:
            # if dist.get_rank()==0:
            #     print('use_boundary_8')
            boundery_bce_loss8,  boundery_dice_loss8 = boundary_loss_func(detail8, lb)
            boundery_bce_loss += boundery_bce_loss8
            boundery_dice_loss += boundery_dice_loss8

        loss = lossp + loss2 + loss3 + boundery_bce_loss + boundery_dice_loss
        
        loss.backward()
        optim.step()

        loss_avg.append(loss.item())

        loss_boundery_bce.append(boundery_bce_loss.item())
        loss_boundery_dice.append(boundery_dice_loss.item())

        ## print training log message
        if (it+1)%msg_iter==0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))

            loss_boundery_bce_avg = sum(loss_boundery_bce) / len(loss_boundery_bce)
            loss_boundery_dice_avg = sum(loss_boundery_dice) / len(loss_boundery_dice)
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'boundery_bce_loss: {boundery_bce_loss:.4f}',
                'boundery_dice_loss: {boundery_dice_loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it = it+1,
                max_it = max_iter,
                lr = lr,
                loss = loss_avg,
                boundery_bce_loss = loss_boundery_bce_avg,
                boundery_dice_loss = loss_boundery_dice_avg,
                time = t_intv,
                eta = eta
            )
            
            logger.info(msg)
            loss_avg = []
            loss_boundery_bce = []
            loss_boundery_dice = []
            st = ed
            # print(boundary_loss_func.get_params())
        if (it + 1) % save_iter_sep == 0:
            
            ## model
            logger.info('evaluating the model ...')
            logger.info('setup and restore model')
            
            net.eval()

            # ## evaluator
            logger.info('compute the mIOU')
            with torch.no_grad():
                single_scale1 = MscEvalV0()
                mIOU50 = single_scale1(net, dlval, n_classes)

                single_scale2= MscEvalV0(scale=0.75)
                mIOU75 = single_scale2(net, dlval, n_classes)


            save_pth = osp.join(save_pth_path,
                                'model_iter{}_mIOU50_{}_mIOU75_{}.pth'.format(
                                    it + 1, round(mIOU50, 4), round(mIOU75, 4)))
            
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            if dist.get_rank()==0: 
                torch.save(state, save_pth)

            logger.info('training iteration {}, model saved to: {}'.format(it+1, save_pth))

            if mIOU50 > maxmIOU50:
                maxmIOU50 = mIOU50
                save_pth = osp.join(save_pth_path, 'model_maxmIOU50.pth')
                state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
                if dist.get_rank()==0: 
                    torch.save(state, save_pth)
                    
                logger.info('max mIOU model saved to: {}'.format(save_pth))
            
            if mIOU75 > maxmIOU75:
                maxmIOU75 = mIOU75
                save_pth = osp.join(save_pth_path, 'model_maxmIOU75.pth')
                state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
                if dist.get_rank()==0: torch.save(state, save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))
            
            logger.info('mIOU50 is: {}, mIOU75 is: {}'.format(mIOU50, mIOU75))
            logger.info('maxmIOU50 is: {}, maxmIOU75 is: {}.'.format(maxmIOU50, maxmIOU75))

            net.train()
    
    ## dump the final model
    save_pth = osp.join(save_pth_path, 'model_final.pth')
    net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank()==0: torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
    print('epoch: ', epoch)
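
The Optimizer used throughout these examples is not shown; a sketch of the warmup-then-polynomial ("poly") schedule its parameters (lr0, warmup_steps, warmup_start_lr, max_iter, power) suggest, as an assumption about its behavior:

def poly_lr(it, lr0=1e-2, warmup_steps=1000, warmup_start_lr=1e-5,
            max_iter=80000, power=0.9):
    # linear warmup to lr0, then polynomial decay to zero at max_iter
    if it < warmup_steps:
        return warmup_start_lr + (lr0 - warmup_start_lr) * it / warmup_steps
    frac = (it - warmup_steps) / (max_iter - warmup_steps)
    return lr0 * (1.0 - frac) ** power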