def prepare_dataloader(self):
        cfg = cfguh().cfg
        dataset = get_dataset()()
        loader = get_loader()()
        transforms = get_transform()()
        formatter = get_formatter()()

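        # The get_*() helpers appear to follow a registry pattern: the first
        # call returns the factory and the second call resolves it (presumably
        # keyed off the active cfg) to the concrete dataset class, loader,
        # transform and formatter instantiated below.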
        trainset = dataset(
            mode = cfg.DATA.DATASET_MODE, 
            loader = loader, 
            estimator = None, 
            transforms = transforms, 
            formatter = formatter,
        )
        sampler = DistributedSampler(
            dataset=trainset)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size = cfg.TRAIN.BATCH_SIZE_PER_GPU, 
            sampler = sampler, 
            num_workers = cfg.DATA.NUM_WORKERS_PER_GPU, 
            drop_last = False, pin_memory = False,
            collate_fn = collate(), 
        )
        return {
            'dataloader' : trainloader,
            'sampler'    : sampler}
    def get_classifier(self, RANK):
        from easydict import EasyDict as edict
        from lib.model_zoo.get_model import get_model
        from lib.optimizer.get_optimizer import get_optimizer
        cfg = cfguh().cfg
        cfgm = edict()
        cfgm.RESNET = edict()
        cfgm.RESNET.MODEL_TAGS = ['resnet50']
        cfgm.RESNET.PRETRAINED_PTH = cfg.TRAIN.CLASSIFIER_PATH
        cfgm.RESNET.INPUT_CHANNEL_NUM = 1
        cfgm.RESNET.CONV_TYPE = 'conv'
        cfgm.RESNET.BN_TYPE = 'bn'
        cfgm.RESNET.RELU_TYPE = 'relu'
        cfgm.RESNET.CLASS_NUM = 37
        cfgm.RESNET.IGNORE_LABEL = cfg.DATA.IGNORE_LABEL
        net = get_model()('resnet', cfgm)
        if cfg.CUDA:
            net.to(RANK)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[RANK], 
                find_unused_parameters=True)        
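            # find_unused_parameters=True lets DDP tolerate parameters that get
            # no gradient in a given step (e.g. conditionally executed branches).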
        net.train()
        if not cfg.TRAIN.UPDATE_CLASSIFIER:
            from lib.model_zoo.utils import eval_bn
            # keep BatchNorm in eval mode so the running mean/var are not updated
            net = eval_bn(net)

        optimizer = get_optimizer(net, opmgr=None)
        return net, optimizer
    def save(self, net, **kwargs):
        cfg = cfguh().cfg
        output_model_file = osp.join(
            cfg.LOG_DIR,
            '{}_{}_last.pth'.format(
                cfg.EXPERIMENT_ID, cfg.MODEL.MODEL_NAME))
        print_log('Saving model file {0}'.format(output_model_file))
        save_state_dict(net, output_model_file)
# Example 4
    def output_f(self, item):
        outdir = osp.join(cfguh().cfg.LOG_DIR, 'result')
        if not osp.exists(outdir):
            os.makedirs(outdir)
        outformat = osp.join(outdir, '{}.png')
        for i, fni in enumerate(item['fn']):
            p = (item['prfn'][i] * 255).astype(np.uint8)
            PIL.Image.fromarray(p).save(outformat.format(fni))
    def __call__(self, 
                 RANK,
                 **kwargs):
        self.RANK = RANK
        cfg = self.cfg
        cfguh().save_cfg(cfg) 
        dist.init_process_group(
            backend = cfg.DIST_BACKEND,
            init_method = cfg.DIST_URL,
            rank = RANK,
            world_size = cfg.GPU_COUNT,
        )

        # need to set random seed again
        if isinstance(cfg.RND_SEED, int):
            np.random.seed(cfg.RND_SEED)
            torch.manual_seed(cfg.RND_SEED)
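            # a seed set in the parent does not carry over to freshly spawned
            # worker processes, hence the re-seeding inside each rank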

        time_start = timeit.default_timer()

        para = {
            'RANK':RANK,
            'itern_total':0}
        dl_para = self.prepare_dataloader()
        if not isinstance(dl_para, dict):
            raise ValueError
        para.update(dl_para)
        md_para = self.prepare_model()
        if not isinstance(md_para, dict):
            raise ValueError
        para.update(md_para)

        for stage in self.registered_stages:
            stage_para = stage(**para)
            if stage_para is not None:
                para.update(stage_para)

        # save the model
        if RANK == 0:
            if 'TRAIN' in cfg:
                self.save(**para)
        print_log(
            'Total {:.2f} seconds'.format(timeit.default_timer() - time_start))
        self.RANK = None
        dist.destroy_process_group()
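
# A minimal launch sketch (assumption; not shown in this source):
# exec_container.__call__ takes the process rank as its first argument, which
# matches the calling convention of torch.multiprocessing.spawn, so a container
# instance could plausibly be launched as follows. `train_exec`, `train_stage`
# and `cfg` are hypothetical placeholders.
#
#   import torch.multiprocessing as mp
#   train_exec = exec_container(cfg)        # hypothetical config object
#   train_exec.register_stage(train_stage)  # hypothetical stage callable
#   mp.spawn(train_exec, nprocs=cfg.GPU_COUNT)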
    def __call__(self, **para):
        rv = super().__call__(**para)
        cfg = cfguh().cfg
        if cfg.TRAIN.UPDATE_CLASSIFIER:
            output_model_file = osp.join(
                cfg.LOG_DIR,
                '{}_resnet50_clsnet.pth'.format(cfg.EXPERIMENT_ID))
            print_log('Saving model file {0}'.format(output_model_file))
            save_state_dict(self.clsnet, output_model_file)
        return rv
    def main(self, **para):
        cfg = cfguh().cfg
        try:
            if para['itern'] == cfg.TRAIN.ACTIVATE_REFINEMENT_AT_ITER:
                try:
                    # a DDP-wrapped model exposes the real module under .module
                    para['net'].module.activate_refinement()
                except AttributeError:
                    para['net'].activate_refinement()
        except (KeyError, AttributeError):
            # itern not provided or refinement not configured: nothing to do
            pass
        return super().main(**para)
    def prepare_model(self):
        cfg = cfguh().cfg
        net = get_model()()
        paras = {}
        istrain = 'TRAIN' in cfg
        if istrain:
            if 'TEST' in cfg:
                raise ValueError

        # save the init model
        if istrain:
            if (cfg.TRAIN.SAVE_INIT_MODEL) and (self.RANK==0):
                output_model_file = osp.join(
                    cfg.LOG_DIR, '{}_{}.pth.init'.format(
                        cfg.EXPERIMENT_ID, cfg.MODEL.MODEL_NAME))
                save_state_dict(net, output_model_file)

        if cfg.CUDA:
            net.to(self.RANK)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.RANK], 
                find_unused_parameters=True)

        if istrain:
            net.train() 
            if cfg.TRAIN.USE_OPTIM_MANAGER:
                try:
                    # a DDP-wrapped model exposes the real module under .module
                    opmgr = net.module.opmgr
                except AttributeError:
                    opmgr = net.opmgr
                opmgr.set_lrscale(cfg.TRAIN.OPTIM_MANAGER_LRSCALE)
            else:
                opmgr = None

            optimizer = get_optimizer(net, opmgr = opmgr)
            compute_lr = lr_scheduler(cfg.TRAIN.LR_TYPE)
            paras.update({
                'net'       : net,
                'optimizer' : optimizer,
                'compute_lr': compute_lr,
                'opmgr'     : opmgr,
            })
        else:
            net.eval()
            paras.update({'net': net})
        return paras
    def main(self,
             batch,
             net,
             lr,
             optimizer,
             opmgr,
             RANK,
             isinit,
             itern,
             **kwargs):
        cfg = cfguh().cfg
        im, gtsem, _ = batch

        try:
            if itern == cfg.TRAIN.ACTIVATE_REFINEMENT_AT_ITER:
                try:
                    # a DDP-wrapped model exposes the real module under .module
                    net.module.activate_refinement()
                except AttributeError:
                    net.activate_refinement()
        except (KeyError, AttributeError):
            # ACTIVATE_REFINEMENT_AT_ITER not configured: nothing to activate
            pass

        if cfg.CUDA:
            im = im.to(RANK)
            gtsem = gtsem.to(RANK)

        adjust_lr(optimizer, lr, opmgr=opmgr)
        optimizer.zero_grad()
        loss_item = net(im, gtsem)

        if self.lossf is None:
            self.lossf = myloss.finalize_loss(
                weight=cfg.TRAIN.LOSS_WEIGHT, 
                normalize_weight=cfg.TRAIN.LOSS_WEIGHT_NORMALIZED)
        loss, loss_item = self.lossf(loss_item)

        loss.backward()
        if isinit:
            optimizer.zero_grad()
        else:
            optimizer.step()

        return {'item': loss_item}
    def save(self, itern, epochn, **paras):
        cfg = cfguh().cfg
        net = paras['net']
        if itern is not None:
            save_state_dict(
                net,
                osp.join(
                    cfg.LOG_DIR,
                    '{}_iter_{}.pth'.format(cfg.EXPERIMENT_ID, itern)))
        elif epochn is not None:
            save_state_dict(
                net,
                osp.join(
                    cfg.LOG_DIR,
                    '{}_epoch_{}.pth'.format(cfg.EXPERIMENT_ID, epochn)))
        else:
            save_state_dict(
                net,
                osp.join(
                    cfg.LOG_DIR,
                    '{}.pth'.format(cfg.EXPERIMENT_ID)))
# Example 11
    def __call__(self, RANK, dataloader, net, **paras):
        cfg = cfguh().cfg
        evaluator = eva.distributed_evaluator(name=['rfn'],
                                              sample_n=len(dataloader.dataset))

        time_check = timeit.default_timer()

        for idx, batch in enumerate(dataloader):
            item = self.main(RANK=RANK, batch=batch, net=net, **paras)
            gtsem, prfn = [item[i] for i in ['gtsem', 'prfn']]

            evaluator['rfn'].bw_iandu(prfn,
                                      gtsem,
                                      class_n=cfg.DATA.EFFECTIVE_CLASS_NUM)
            evaluator.merge()

            if cfg.TEST.OUTPUT_RESULT:
                self.output_f(item)

            if cfg.TEST.VISUAL:
                raise NotImplementedError

            if idx % cfg.TEST.DISPLAY == cfg.TEST.DISPLAY - 1:
                print_log('processed.. {}, Time:{:.2f}s'.format(
                    idx + 1,
                    timeit.default_timer() - time_check))
                time_check = timeit.default_timer()

        sem_cname = dataloader.dataset.get_semantic_classname()

        eval_result = evaluator['rfn'].miou(classname=sem_cname,
                                            find_n_worst=cfg.TEST.FIND_N_WORST)
        evaluator['rfn'].fscore(classname=sem_cname)

        if RANK == 0:
            evaluator.summary()
            resultf = osp.join(cfg.LOG_DIR, 'result.json')
            evaluator.save(resultf, cfg)
        return eval_result
    def main(self,
             batch,
             net,
             lr,
             optimizer,
             opmgr,
             RANK,
             itern,
             isinit = False,
             **kwargs):
        cfg = cfguh().cfg
        roi_size = cfg.TRAIN.ROI_ALIGN_SIZE
        update_cls = cfg.TRAIN.UPDATE_CLASSIFIER
        act_after = cfg.TRAIN.ACTIVATE_CLASSIFIER_FOR_SEGMODEL_AFTER

        im, sem, bbx, chins, chcls, _ = batch
        # add batch index at front in bbx
        bbx = [
            torch.cat([torch.ones(ci.shape[0], 1).float()*idx, ci], dim=1) \
                for idx, ci in enumerate(bbx)]
        bbx = torch.cat(bbx, dim=0)
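        # each row of bbx is now <batch_idx, h1, w1, h2, w2, class>;
        # the class index in column 5 is read back below as chpredcls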

        if cfg.CUDA:
            im = im.to(RANK)
            sem = sem.to(RANK)
            bbx = bbx.to(RANK)
        zero = torch.zeros([], dtype=torch.float32, device=im.device)

        if self.clsnet is None:
            self.clsnet, self.clsoptim = self.get_classifier(RANK)

        if self.lossf is None:
            self.lossf = myloss.finalize_loss(
                weight=cfg.TRAIN.LOSS_WEIGHT, 
                normalize_weight=cfg.TRAIN.LOSS_WEIGHT_NORMALIZED)

        adjust_lr(optimizer, lr, opmgr=opmgr)
        optimizer.zero_grad()

        if update_cls:
            adjust_lr(self.clsoptim, lr, opmgr=None)
        self.clsoptim.zero_grad()

        loss_item = net(im, sem)
        pred = loss_item.pop('pred')

        h, w = pred.shape[-2:]
        osh, osw = im.shape[-2]/h, im.shape[-1]/w
        bbx[:, 1] /= osh
        bbx[:, 3] /= osh
        bbx[:, 2] /= osw
        bbx[:, 4] /= osw
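        # the boxes were given at input-image resolution; dividing by the output
        # stride (osh, osw) maps them onto the prediction grid before roi_align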

        if cfg.TRAIN.ROI_BBOX_PADDING_TYPE == 'semcrop':
            # the boxes have already been squared;
            # no further action is needed.
            bbx_reordered = torch.stack(
                [bbx[:, i] for i in [0, 2, 1, 4, 3]], dim=-1)
            # input bbx is <bs, w1, h1, w2, h2>
            # pred[:, 1:2] means we only get the fg part
            chpred = torchutils.roi_align(roi_size)(
                pred[:, 1:2], bbx_reordered)
        elif cfg.TRAIN.ROI_BBOX_PADDING_TYPE == 'inscrop':
            # the boxes have not been squared yet: square them before
            # roi_align, then zero out values that fall outside the original box.
            dh, dw = [bbx[:, i]-bbx[:, i-2] for i in (3, 4)]
            bbx_sq = bbx.clone()
            bbx_sq[dw>dh , 1] -=  (dw-dh)[dw>dh]/2 # modify h1
            bbx_sq[dw>dh , 3] +=  (dw-dh)[dw>dh]/2 # modify h2
            bbx_sq[dw<=dh, 2] -= (dh-dw)[dw<=dh]/2 # modify w1
            bbx_sq[dw<=dh, 4] += (dh-dw)[dw<=dh]/2 # modify w2

            dhw = torch.max(dh, dw)
            bbx_offset = bbx[:, 1:5] - bbx_sq[:, 1:5]
            bbx_offset[:, 0] *= roi_size[0]/dhw
            bbx_offset[:, 2] *= roi_size[0]/dhw
            bbx_offset[:, 2] += roi_size[0]
            bbx_offset[:, 1] *= roi_size[1]/dhw
            bbx_offset[:, 3] *= roi_size[1]/dhw
            bbx_offset[:, 3] += roi_size[1]
            bbx_offset[:, 0:2] = torch.floor(bbx_offset[:, 0:2])
            bbx_offset[:, 2:4] = torch.ceil( bbx_offset[:, 2:4])
            bbx_offset = bbx_offset.long()
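            # bbx_offset is now in roi-grid pixels: the <h1, w1, h2, w2> window
            # inside the square crop that overlaps the original box; everything
            # outside it is zero-padded below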

            bbx_reordered = torch.stack(
                [bbx_sq[:, i] for i in [0, 2, 1, 4, 3]], dim=-1)
            
            chpred = torchutils.roi_align(roi_size)(
                pred[:, 1:2], bbx_reordered)

            chpred_zeropad = torch.zeros(
                chpred.shape, device=chpred.device, 
                dtype=chpred.dtype)
            for idxi in range(chpred.shape[0]):
                h1, w1, h2, w2 = bbx_offset[idxi]
                chpred_zeropad[idxi, :, h1:h2, w1:w2] = chpred[idxi, :, h1:h2, w1:w2]
            chpred = chpred_zeropad
        else:
            raise ValueError

        chpredcls = bbx[:, 5].long()

        # compute the extra loss that involves clsnet's output,
        # but do not update clsnet's weights in this step.

        loss_item['losscls'] = zero
        if update_cls:
            loss_item['lossupdatecls'] = zero

        if (chpred.shape[0] > 1) & (itern >= act_after):
            lossclsp_item = self.clsnet(chpred, chpredcls)
            loss_item['losscls'] = lossclsp_item['losscls']
        else:
            # run a dummy forward/backward with the loss scaled to zero;
            # otherwise DDP's gradient synchronization can hang when some
            # ranks skip the classifier branch.
            chpred_dummy = torch.zeros(
                [2, 1]+list(roi_size), dtype=torch.float32, 
                device=im.device)
            chpredcls_dummy = torch.zeros(
                [2], dtype=torch.int64,
                device=im.device)
            lossclsp_item_dummy = self.clsnet(chpred_dummy, chpredcls_dummy)
            loss_item['losscls'] = lossclsp_item_dummy['losscls'] * 0

        loss, loss_display = self.lossf(loss_item)
        loss.backward()
        if isinit:
            optimizer.zero_grad()
        else:
            optimizer.step()
        self.clsoptim.zero_grad()

        # update clsnet's weights using the ground-truth chins;
        # do not update net here.
        if update_cls:
            chins = torch.cat(chins, dim=0)
            chcls = torch.cat(chcls, dim=0)
            loss_item = {ni:zero for ni in loss_item.keys()}

            if (chins.shape[0] > 1):
                if cfg.CUDA:
                    chins = chins.to(RANK)
                    chcls = chcls.to(RANK)
                cls_item = self.clsnet(chins.unsqueeze(1), chcls)
                loss_item['lossupdatecls'] = cls_item['losscls']
                # debug
                # print(cls_item['accnum'])

            loss, loss_display2 = self.lossf(loss_item)
            loss.backward()
            if isinit:
                self.clsoptim.zero_grad()
            else:
                self.clsoptim.step()
            optimizer.zero_grad()
            loss_display['lossupdatecls'] = loss_display2['lossupdatecls']

        return {'item': loss_display}
    def __call__(self,
                 **paras):
        cfg = cfguh().cfg
        logm = log_manager()
        epochn, itern = 0, 0

        dataloader = paras['dataloader']
        compute_lr = paras['compute_lr']
        RANK       = paras['RANK']

        while cfg.MAINLOOP_EXECUTE:
            for idx, batch in enumerate(dataloader):
                if not isinstance(batch[0], list):
                    batch_n = batch[0].shape[0]
                else:
                    batch_n = len(batch[0])

                if cfg.TRAIN.SKIP_PARTIAL \
                        and (batch_n != cfg.TRAIN.BATCH_SIZE_PER_GPU):
                    continue

                if cfg.TRAIN.LR_ITER_BY == 'epoch':
                    lr = compute_lr(epochn)
                elif cfg.TRAIN.LR_ITER_BY == 'iter':
                    lr = compute_lr(itern)
                else:
                    raise ValueError

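                # warm-up pass: main() is run once with isinit=True, doing a full
                # forward/backward but zeroing the gradients instead of stepping,
                # presumably so lazily-built parts (e.g. the loss function or the
                # auxiliary classifier in main) exist before the first real update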
                if itern==0:
                    self.main(
                        batch=batch,
                        lr=lr, 
                        isinit=True,
                        itern=itern,
                        **paras)

                paras_new = self.main(
                    batch=batch, 
                    lr=lr,
                    isinit=False,
                    itern=itern,
                    **paras)

                paras.update(paras_new)

                logm.accumulate(batch_n, paras['item'])
                itern += 1

                if itern % cfg.TRAIN.DISPLAY == 0:
                    print_log(logm.pop(
                        RANK, itern, epochn, (idx+1)*cfg.TRAIN.BATCH_SIZE, lr))

                if not isinstance(cfg.TRAIN.VISUAL, bool):
                    if itern % cfg.TRAIN.VISUAL == 0:
                        self.visual_f(paras['plot_item'])

                if cfg.TRAIN.MAX_STEP_TYPE == 'iter':
                    if itern >= cfg.TRAIN.MAX_STEP:
                        break
                    if itern % cfg.TRAIN.CKPT_EVERY == 0:
                        if RANK == 0:
                            print_log('Checkpoint... {}'.format(itern))
                            self.save(itern=itern, epochn=None, **paras)

                # loop end

            epochn += 1

            if cfg.TRAIN.MAX_STEP_TYPE == 'iter':
                if itern >= cfg.TRAIN.MAX_STEP:
                    break

            elif cfg.TRAIN.MAX_STEP_TYPE == 'epoch':
                if epochn >= cfg.TRAIN.MAX_STEP:
                    break
                if epochn % cfg.TRAIN.CKPT_EVERY == 0:
                    if RANK == 0:
                        print_log('Checkpoint... {}'.format(epochn))
                        self.save(itern=None, epochn=epochn, **paras)
    common_argparse, common_initiates

from lib.data_factory import \
    get_dataset, collate, \
    get_loader, get_transform, \
    get_formatter, DistributedSampler

from lib.model_zoo import \
    get_model, save_state_dict

from lib.optimizer import \
    get_optimizer, adjust_lr, lr_scheduler

from lib.log_service import print_log, torch_to_numpy, log_manager

cfguh().add_code(osp.basename(__file__))

class exec_container(object):
    def __init__(self,
                 cfg,
                 **kwargs):
        self.cfg = cfg
        self.registered_stages = []
        self.RANK = None

    def register_stage(self, stage):
        self.registered_stages.append(stage)

    def __call__(self, 
                 RANK,
                 **kwargs):
# Example 15
    def main(self, RANK, batch, net, **kwargs):
        cfg = cfguh().cfg
        ac = cfg.TEST.INFERENCE_MS_ALIGN_CORNERS

        im, gtsem, fn = batch
        bs, _, oh, ow = im.shape

        if cfg.CUDA:
            im = im.to(RANK)

        # ms-flip inference
        psemc_ms, prfnc_ms, pcount_ms = {}, {}, {}
        pattkey, patt = {}, {}
        for mstag, mssize in cfg.TEST.INFERENCE_MS:
            # area-based scaling: resize so the image has roughly mssize**2 pixels
            ratio = np.sqrt(mssize**2 / (oh * ow))
            th, tw = int(oh * ratio), int(ow * ratio)
            tw = tw // 32 * 32 + 1
            th = th // 32 * 32 + 1
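            # snap to (32*k + 1): segmentation backbones typically downsample by
            # 32, and a 32k+1 input keeps the feature grid aligned when
            # interpolating with align_corners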

            imi = {
                'nofp':
                torchutils.interpolate_2d(size=(th, tw),
                                          mode='bilinear',
                                          align_corners=ac)(im)
            }
            if cfg.TEST.INFERENCE_FLIP:
                imi['flip'] = torch.flip(imi['nofp'], dims=[-1])

            for fliptag, imii in imi.items():
                with torch.no_grad():
                    pred = net(imii)
                    psem = torchutils.interpolate_2d(size=(oh, ow),
                                                     mode='bilinear',
                                                     align_corners=ac)(
                                                         pred['predsem'])
                    prfn = torchutils.interpolate_2d(size=(oh, ow),
                                                     mode='bilinear',
                                                     align_corners=ac)(
                                                         pred['predrfn'])

                    if fliptag == 'flip':
                        psem = torch.flip(psem, dims=[-1])
                        prfn = torch.flip(prfn, dims=[-1])
                    elif fliptag == 'nofp':
                        pass
                    else:
                        raise ValueError

                try:
                    psemc_ms[mstag] += psem
                    prfnc_ms[mstag] += prfn
                    pcount_ms[mstag] += 1
                except KeyError:
                    # first pass at this scale: initialize the accumulators
                    psemc_ms[mstag] = psem
                    prfnc_ms[mstag] = prfn
                    pcount_ms[mstag] = 1

            # note: when flipping is enabled, 'pred' here comes from the last
            # (flipped) pass, so the stored attention is the flipped one.
            try:
                pattkey[mstag] = pred['att_key']
            except KeyError:
                pattkey[mstag] = None
            try:
                patt[mstag] = pred['att']
            except KeyError:
                patt[mstag] = None

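        # average the accumulated logits: for both the semantic and refined
        # outputs, build one prediction averaged over all scale/flip passes plus
        # a per-scale dict, then unpack them in the fixed order below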
        predc = []
        for predci in [psemc_ms, prfnc_ms]:
            p = sum([pi for pi in predci.values()])
            p /= sum([ni for ni in pcount_ms.values()])
            predc.append(p)
            p = {ki: pi / pcount_ms[ki] for ki, pi in predci.items()}
            predc.append(p)
        psemc, psemc_ms, prfnc, prfnc_ms = predc

        psem = torch.argmax(psemc, dim=1)
        prfn = torch.argmax(prfnc, dim=1)
        im, gtsem, psemc, psemc_ms, psem, prfnc, prfnc_ms, prfn = \
            torch_to_numpy(
                im, gtsem, psemc, psemc_ms, psem, prfnc, prfnc_ms, prfn)
        pattkey, patt = torch_to_numpy(pattkey, patt)

        return {
            'im': im,
            'gtsem': gtsem,
            'psem': psem,
            'psemc': psemc,
            'psemc_ms': psemc_ms,
            'prfn': prfn,
            'prfnc': prfnc,
            'prfnc_ms': prfnc_ms,
            'pattkey': pattkey,
            'patt': patt,
            'fn': fn,
        }