def train():
    """Train the segmentation model, optionally resuming from a checkpoint.

    Reads hyper-parameters from the module-level ``cfg`` and command-line
    options from ``args`` (``loadCheckpointLocation``, ``saveOnEveryIt``,
    ``saveCheckpointDir``). Runs at most ``cfg.max_iter`` optimisation steps in
    total, counting the steps already done before a resume. Writes
    intermediate checkpoints every ``args.saveOnEveryIt`` iterations and a
    final checkpoint (only when ``args.saveCheckpointDir`` is set). Then
    evaluates the final model on ``cfg.val_im_anns`` and logs the mIoU table.
    """
    logger = logging.getLogger()
    is_dist = dist.is_initialized()
    # Only one process writes checkpoint files; several DDP ranks writing the
    # same path concurrently can leave a corrupt file behind.
    is_main = not is_dist or dist.get_rank() == 0

    ## dataset
    dl = get_data_loader(cfg.im_root,
                         cfg.train_im_anns,
                         cfg.ims_per_gpu,
                         cfg.scales,
                         cfg.cropsize,
                         cfg.max_iter,
                         mode='train',
                         distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## load a checkpoint, if one was given, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_iteration = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_iteration = 0

    ## train loop
    for current_it, (im, lb) in enumerate(dl):
        # On a resumed run the global iteration continues from the saved one;
        # otherwise start_iteration is 0.
        it = current_it + start_iteration
        # get_data_loader is built from cfg.max_iter independently of
        # start_iteration (presumably yielding that many batches), so stop
        # at the global budget instead of stepping the poly LR schedule past
        # its end.
        if it >= cfg.max_iter:
            break
        im = im.cuda()
        lb = lb.cuda()

        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [
            crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
        ]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        for mter, lss in zip(loss_aux_meters, loss_aux):
            mter.update(lss.item())

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(it, cfg.max_iter, lr, time_meter, loss_meter,
                          loss_pre_meter, loss_aux_meters)

        ## save an intermediate checkpoint every args.saveOnEveryIt iterations
        if ((it + 1) % args.saveOnEveryIt == 0
                and args.saveCheckpointDir is not None and is_main):
            checkpoint = {
                'iteration': it + 1,
                'state_dict': net.state_dict(),
                'optimizer': optim.state_dict(),
                'lr_schdr': lr_schdr.state_dict(),
            }
            # Zero-pad so checkpoint files sort in iteration order.
            iteration_no_str = (str(it + 1)).zfill(len(str(cfg.max_iter)))
            ckt_name = 'checkpoint_it_' + iteration_no_str + '.pt'
            save_pth = osp.join(args.saveCheckpointDir, ckt_name)
            logger.info(
                '\nsaving intermidiate checkpoint to {}'.format(save_pth))
            save_ckp(checkpoint, save_pth)

    ## dump the final model and evaluate the result
    if args.saveCheckpointDir is None:
        # Previously osp.join(None, ...) raised a TypeError here.
        logger.warning('saveCheckpointDir is not set, final model not saved')
    elif is_main:
        checkpoint = {
            'iteration': cfg.max_iter,
            'state_dict': net.state_dict(),
            'optimizer': optim.state_dict(),
            'lr_schdr': lr_schdr.state_dict(),
        }
        save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
        logger.info('\nsave Final models to {}'.format(save_pth))
        save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
    logger.info(tabulate([
        mious,
    ], headers=heads, tablefmt='orgtbl'))
    return
# Example no. 2
# 0
def train():
    """Train the model with DDP and save periodic and final weights.

    Reads all hyper-parameters from the module-level ``cfg``. Every 1000
    iterations the bare (unwrapped) model weights are passed to
    ``save_checkpoint``. At the end the weights go to
    ``cfg.respth/model_final.pth`` and the model is evaluated on
    ``cfg.val_im_anns``. Files are written only by rank 0, or by the single
    process when torch.distributed is not initialised.
    """
    logger = logging.getLogger()
    is_dist = dist.is_initialized()
    # dist.get_rank() raises when no process group exists, so only query it
    # in the distributed case.
    is_main = not is_dist or dist.get_rank() == 0

    ## dataset
    dl = get_data_loader(cfg.im_root,
                         cfg.train_im_anns,
                         cfg.ims_per_gpu,
                         cfg.scales,
                         cfg.cropsize,
                         cfg.max_iter,
                         mode='train',
                         distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)
    # Weights are stored without the DDP wrapper's "module." prefix.
    bare_net = getattr(net, 'module', net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## train loop
    for it, (im, lb) in enumerate(dl):
        im = im.cuda()
        lb = lb.cuda()

        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [
            crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
        ]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        for mter, lss in zip(loss_aux_meters, loss_aux):
            mter.update(lss.item())

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(it, cfg.max_iter, lr, time_meter, loss_meter,
                          loss_pre_meter, loss_aux_meters)
        # Count completed iterations (like the logging above) so the first
        # snapshot is taken after 1000 steps rather than before any training.
        # Only one rank writes, to avoid concurrent writes to the same file.
        if (it + 1) % 1000 == 0 and is_main:
            save_checkpoint('bisenet_citys_{}.pth'.format(it + 1),
                            bare_net.state_dict())

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to {}'.format(save_pth))
    state = bare_net.state_dict()
    if is_main:
        torch.save(state, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
    logger.info(tabulate([
        mious,
    ], headers=heads, tablefmt='orgtbl'))

    return
# Example no. 3
# 0
def train():
    """Train for ``cfg.n_epochs`` epochs, validating after every epoch.

    Builds non-distributed train/val loaders and trains with the main plus
    auxiliary losses. After each epoch it computes the mean validation loss
    and saves ``cfg.respth/best_validation.pth`` whenever that loss improves.
    Finally it saves ``cfg.respth/model_final.pth`` and evaluates the final
    model on ``cfg.test_im_anns``.
    """
    logger = logging.getLogger()

    is_dist = False

    ## dataset
    dl = get_data_loader(
        cfg.im_root, cfg.train_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='train', distributed=is_dist)

    valid = get_data_loader(
        cfg.im_root, cfg.val_im_anns,
        cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
        cfg.max_iter, mode='val', distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()
    print(net)
    print(f'n_parameters: {sum(p.numel() for p in net.parameters())}')

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    # NOTE(review): the poly schedule spans cfg.max_iter steps, but it is
    # stepped once per batch for cfg.n_epochs passes over `dl`. Confirm that
    # n_epochs * len(dl) <= cfg.max_iter, or the schedule runs past its end.
    lr_schdr = WarmupPolyLrScheduler(optim, power=0.9,
        max_iter=cfg.max_iter, warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1, warmup='exp', last_epoch=-1,)

    best_validation = np.inf

    for i in range(cfg.n_epochs):
        ## train loop
        # The validation pass below leaves the net in eval mode; switch back
        # once per epoch instead of on every batch.
        net.train()
        for im, lb in Bar(dl):
            im = im.cuda()
            lb = lb.cuda()

            lb = torch.squeeze(lb, 1)

            optim.zero_grad()
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)

            loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
            loss = loss_pre + sum(loss_aux)
            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            torch.cuda.synchronize()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            for mter, lss in zip(loss_aux_meters, loss_aux):
                mter.update(lss.item())

            # free the batch before the next one is loaded
            del im
            del lb

        ## print training log message
        lr = lr_schdr.get_lr()
        lr = sum(lr) / len(lr)
        print_log_msg(
            i, cfg.max_iter, lr, time_meter, loss_meter,
            loss_pre_meter, loss_aux_meters)

        ## validation loop
        net.eval()
        validation_loss = []
        for im, lb in Bar(valid):
            im = im.cuda()
            lb = lb.cuda()

            lb = torch.squeeze(lb, 1)

            with torch.no_grad():
                logits, *logits_aux = net(im)
                loss_pre = criteria_pre(logits, lb)
                loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
                loss = loss_pre + sum(loss_aux)
                validation_loss.append(loss.item())

            del im
            del lb

        ## report the mean validation loss
        if not validation_loss:
            # Averaging an empty list would raise ZeroDivisionError.
            print('Validation loader is empty, skipping best-model check')
            continue
        validation_loss = sum(validation_loss) / len(validation_loss)
        print(f'Validation loss: {validation_loss}')

        if best_validation > validation_loss:
            print('new best performance, storing model')
            best_validation = validation_loss
            state = net.state_dict()
            torch.save(state,  osp.join(cfg.respth, 'best_validation.pth'))

    ## dump the final model and evaluate the result
    save_pth = osp.join(cfg.respth, 'model_final.pth')
    logger.info('\nsave models to {}'.format(save_pth))
    state = net.state_dict()

    torch.save(state, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.test_im_anns)
    logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))

    return
# Example no. 4
# 0
def train(loginfo):
    """Train for ``cfg.max_epoch`` epochs, validating and testing each epoch.

    Args:
        loginfo: suffix for the best-validation-loss weight file,
            ``cfg.weight_path/model_bestValidLoss-{loginfo}.pth``.

    Uses the module-level ``device`` and ``device_count``. The model is
    wrapped in ``nn.DataParallel`` when two or more devices are counted.
    After every epoch:

    * the model is validated with ``valid``, and its weights are saved when
      the validation loss improves;
    * the model is evaluated with ``test_model``.

    After the last epoch it is evaluated once more with ``eval_model``.
    """
    logger = logging.getLogger()

    logger.info("config: \n{}".format(list(cfg.__dict__.items())))

    ## dataset
    dl = prepare_data_loader(cfg.train_img_root,
                             cfg.train_img_anns,
                             cfg.input_size,
                             cfg.imgs_per_gpu,
                             device_count,
                             cfg.scales,
                             cfg.cropsize,
                             cfg.anns_ignore,
                             mode='train',
                             distributed=False)

    # On CUDA each counted device gets cfg.imgs_per_gpu images per step;
    # otherwise a step processes a single cfg.imgs_per_gpu batch.
    n_devices = device_count if device == 'cuda' else 1
    max_iter = cfg.max_epoch * len(dl.dataset) // (cfg.imgs_per_gpu * n_devices)
    # Log about five times per epoch. Clamp to an int >= 1: a dataset with
    # fewer than five batches previously gave 0.0, and `% 0.0` raises
    # ZeroDivisionError.
    progress_iter = max(
        1, int(len(dl.dataset) / (cfg.imgs_per_gpu * n_devices) // 5))

    ## model
    net, criteria_pre, criteria_aux = set_model()
    net.to(device)
    if device_count >= 2:
        net = nn.DataParallel(net)
        torch.backends.cudnn.benchmark = True
        torch.multiprocessing.set_sharing_strategy('file_system')

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    # TODO: switch to DistributedDataParallel (set_model_dist), see
    # https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters(
        max_iter)

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## train loop
    n_epoch = 0
    n_iter = 0  # global iteration counter across epochs
    best_valid_loss = np.inf
    while n_epoch < cfg.max_epoch:
        # valid()/test_model() may leave the net in eval mode
        net.train()
        for (img, tar) in tqdm(dl,
                               desc='train epoch {:d}/{:d}'.format(
                                   n_epoch + 1, cfg.max_epoch)):
            img = img.to(device)
            tar = tar.to(device)

            tar = torch.squeeze(tar, 1)

            optim.zero_grad()
            logits, *logits_aux = net(img)
            loss_pre = criteria_pre(logits, tar)
            loss_aux = [
                crit(lgt, tar) for crit, lgt in zip(criteria_aux, logits_aux)
            ]
            loss = loss_pre + sum(loss_aux)
            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            # Training also runs on CPU; torch.cuda.synchronize() raises
            # when no CUDA device is available.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            for mter, lss in zip(loss_aux_meters, loss_aux):
                mter.update(lss.item())

            ## print training log message
            if (n_iter + 1) % progress_iter == 0:
                lr = lr_schdr.get_lr()
                lr = sum(lr) / len(lr)
                print_log_msg(n_epoch, cfg.max_epoch, n_iter, max_iter, lr,
                              time_meter, loss_meter, loss_pre_meter,
                              loss_aux_meters)

            n_iter = n_iter + 1

        ## validate, keeping the weights with the lowest validation loss
        logger.info('vaildating the {} epoch model'.format(n_epoch + 1))
        valid_loss = valid(net, criteria_pre, criteria_aux, n_epoch, cfg,
                           logger)
        if valid_loss < best_valid_loss:
            os.makedirs(cfg.weight_path, exist_ok=True)
            save_path = os.path.join(
                cfg.weight_path, 'model_bestValidLoss-{}.pth'.format(loginfo))
            logger.info('save models to {}'.format(save_path))
            torch.save(net.state_dict(), save_path)
            best_valid_loss = valid_loss

        ## evaluate the current epoch's model
        logger.info('evaluating the {} epoch model'.format(n_epoch + 1))
        torch.cuda.empty_cache()  ## release cached CUDA memory before testing
        heads, mious, eious = test_model(net, cfg, device_count,
                                         cfg.val_img_root, cfg.val_img_anns,
                                         cfg.n_classes, cfg.anns_ignore)
        logger.info('\n' + tabulate(
            [
                mious,
            ], headers=heads, tablefmt='github', floatfmt=".8f"))
        logger.info('\n' + tabulate(np.array(eious).transpose(),
                                    headers=heads,
                                    tablefmt='github',
                                    floatfmt=".8f",
                                    showindex=True))

        n_epoch = n_epoch + 1

    heads, mious, eious = eval_model(net, cfg, device_count, cfg.val_img_root,
                                     cfg.val_img_anns, cfg.n_classes,
                                     cfg.anns_ignore)
    logger.info(
        '\n' +
        tabulate([
            mious,
        ], headers=heads, tablefmt='github', floatfmt=".8f"))
    logger.info('\n' + tabulate(np.array(eious).transpose(),
                                headers=heads,
                                tablefmt='github',
                                floatfmt=".8f",
                                showindex=True))

    return
# Example no. 5
# 0
def train():
    """Distributed training with Weights & Biases logging.

    Rank 0 initialises the wandb run and logs the running meters every
    iteration. Every 2000 iterations it saves
    ``cfg.respth/{exp_name}_{it}.pth`` and uploads the file with
    ``wandb.save``. All ranks take part in the periodic ``eval_model`` call.
    After the loop, the last weights are saved in the same way unless the
    final iteration already produced a 2000-iteration checkpoint.
    """
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg.im_root,
                         cfg.train_im_anns,
                         cfg.ims_per_gpu,
                         cfg.scales,
                         cfg.cropsize,
                         cfg.max_iter,
                         mode='train',
                         distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    # Used both as the wandb run name and as the checkpoint file prefix.
    exp_name = "cityscapes_repl"
    if dist.get_rank() == 0:
        wandb.init(project="bisenet", name=exp_name)
        wandb.watch(net)

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## train loop
    last_it = -1  # stays -1 if the loader yields nothing
    for it, (im, lb) in enumerate(dl):
        last_it = it
        # eval_model (called every 2000 iterations) may leave eval mode on
        net.train()
        im = im.cuda()
        lb = lb.cuda()

        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [
            crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
        ]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        for mter, lss in zip(loss_aux_meters, loss_aux):
            mter.update(lss.item())

        lr = lr_schdr.get_lr()
        lr = sum(lr) / len(lr)
        ## log training metrics (rank 0 only)
        if dist.get_rank() == 0:
            loss_avg = loss_meter.get()[0]
            # commit=False: the values are committed by the
            # wandb.log({"t": it}, step=it) call at the end of the iteration.
            wandb.log(
                {
                    "lr": lr,
                    "time": time_meter.get()[0],
                    "loss": loss_avg,
                    "loss_pre": loss_pre_meter.get()[0],
                    **{
                        f"loss_aux_{el.name}": el.get()[0]
                        for el in loss_aux_meters
                    }
                },
                commit=False)
            if (it + 1) % 100 == 0: print(it, ' - ', lr, ' - ', loss_avg)

            if (it + 1) % 2000 == 0:
                # dump the model weights (without the DDP wrapper)
                save_pth = osp.join(cfg.respth, f"{exp_name}_{it}.pth")
                state = net.module.state_dict()
                torch.save(state, save_pth)
                wandb.save(save_pth)
        # every rank participates in evaluation
        if ((it + 1) % 2000 == 0):
            logger.info('\nevaluating the model')
            heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns, it)
            logger.info(tabulate([
                mious,
            ], headers=heads, tablefmt='orgtbl'))
            if (dist.get_rank() == 0):
                wandb.log({k: v for k, v in zip(heads, mious)}, commit=False)
        if (dist.get_rank() == 0):
            wandb.log({"t": it}, step=it)

    ## keep the final weights even when the iteration count is not a
    ## multiple of 2000 (previously the tail of training was never saved)
    if dist.get_rank() == 0 and last_it >= 0 and (last_it + 1) % 2000 != 0:
        save_pth = osp.join(cfg.respth, f"{exp_name}_{last_it}.pth")
        torch.save(net.module.state_dict(), save_pth)
        wandb.save(save_pth)
    return
# Example no. 6
# 0
def _read_best_miou(path):
    """Return the last number recorded in *path*, or ``-inf`` if there is none.

    The file holds ``"<epoch>   <miou>"``; only the last whitespace-separated
    token is read. A missing or empty file means no best score is known yet.
    """
    try:
        with open(path, 'r') as f:
            tokens = f.read().split()
    except FileNotFoundError:
        return float('-inf')
    if not tokens:
        return float('-inf')
    # float() instead of eval(): file contents must never be executed as code.
    return float(tokens[-1])


def _write_epoch_record(path, epoch, value):
    """Overwrite *path* with ``"<epoch>   <value>"``."""
    with open(path, 'w') as f:
        f.write(str(epoch) + '   ')
        f.write(str(value))


def main():
    """Train BANet epoch by epoch; evaluate and checkpoint at the last epoch.

    Training stops at ``cfg.epoch`` or ``args.epoch_to_train``, whichever
    comes first. At that epoch the model is evaluated and saved to
    ``cfg.respth/args.store_name``, and the LR is recorded in
    ``lr_record.txt``. When the mIoU beats the value stored in
    ``best_miou.txt`` (a missing or empty file counts as no previous best),
    that file is updated and the checkpoint is also saved as best.
    """
    os.makedirs(cfg.respth, exist_ok=True)
    setup_logger('{}-train'.format('banet'), cfg.respth)

    logger = logging.getLogger()

    ## model
    net, criteria = set_model()
    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)
    ## lr scheduler
    # NOTE(review): 371 is presumably the number of iterations per epoch for
    # this dataset/batch size; confirm before changing either.
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.epoch * 371,
        warmup_iter=cfg.warmup_iters * 371,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    for epoch in range(cfg.start_epoch, args.epoch_to_train):
        lr_schdr, time_meter, loss_meter = train(epoch, optim, net, criteria,
                                                 lr_schdr)
        lr = lr_schdr.get_lr()
        print(lr)
        lr = sum(lr) / len(lr)
        loss_avg = print_log_msg(epoch, cfg.epoch, lr, time_meter,
                                 loss_meter)
        writer.add_scalar('loss', loss_avg, epoch + 1)

        is_last_epoch = ((epoch + 1) == cfg.epoch
                         or (epoch + 1) == args.epoch_to_train)
        if not is_last_epoch:
            continue

        torch.cuda.empty_cache()
        heads, mious, miou = eval_model(net,
                                        ims_per_gpu=2,
                                        im_root=cfg.im_root,
                                        im_anns=cfg.val_im_anns,
                                        it=epoch)
        filename = osp.join(cfg.respth, args.store_name)
        state = net.state_dict()
        save_checkpoint(state, False, filename=filename)

        print('lr to store', lr)
        _write_epoch_record('lr_record.txt', epoch + 1, lr)

        best_miou = _read_best_miou('best_miou.txt')
        is_best = miou > best_miou
        if is_best:
            print('Is best? : ', is_best)
            _write_epoch_record('best_miou.txt', epoch + 1, miou)
            save_checkpoint(state, is_best, filename)
        print('Have Stored Checkpoint')

        state = net.state_dict()
        torch.cuda.empty_cache()
        logger.info(tabulate([
            mious,
        ], headers=heads, tablefmt='orgtbl'))
        save_checkpoint(state, False, filename)
        print('Have Saved Final Model')
        break
# Example no. 7
# 0
def train():
    """Epoch-based training with TensorBoard logging and resumable checkpoints.

    The epoch budget ``max_epoch`` is derived from ``cfg.max_iter`` and the
    dataset size. When ``args.loadCheckpointLocation`` is set, training
    resumes at the stored epoch and continues up to the same ``max_epoch``
    total.

    Every ``args.saveOnEveryEpoch`` epochs the model is evaluated on
    ``cfg.val_im_anns``; a checkpoint is also written when
    ``args.saveCheckpointDir`` is set. At the end the final checkpoint is
    written (again only when that directory is set) and the model is
    evaluated once more.
    """
    logger = logging.getLogger()

    ## dataset
    dl = get_data_loader(cfg.im_root,
                         cfg.train_im_anns,
                         cfg.ims_per_gpu,
                         cfg.scales,
                         cfg.cropsize,
                         cfg.max_iter,
                         mode='train',
                         distributed=False)

    # send a few training images to tensorboard
    addImage_Tensorboard(dl)
    # derive the epoch budget from the iteration budget
    dataset_length = len(dl.dataset)
    print("Dataset length: ", dataset_length)
    batch_size = cfg.ims_per_gpu
    print("Batch Size: ", batch_size)
    # Clamp to 1 so a dataset smaller than one batch does not make the
    # division below raise ZeroDivisionError.
    iteration_per_epoch = max(1, dataset_length // batch_size)
    max_epoch = int(cfg.max_iter / iteration_per_epoch)
    print("Max_epoch: ", max_epoch)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## load a checkpoint, if one was given, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_epoch = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_epoch = 0
    # send the model structure to tensorboard
    addGraph_Tensorboard(net, dl)

    ## train loop
    global_it = start_epoch * iteration_per_epoch
    # Resume at start_epoch but keep max_epoch as the total. Previously a
    # resumed run trained max_epoch *additional* epochs, overrunning the
    # LR schedule (which is restored from the checkpoint).
    for epoch in range(start_epoch, max_epoch):
        for it, (im, lb) in enumerate(dl):

            im = im.to(device)
            lb = lb.to(device)

            lb = torch.squeeze(lb, 1)

            optim.zero_grad()
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [
                crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
            ]
            loss = loss_pre + sum(loss_aux)

            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            for mter, lss in zip(loss_aux_meters, loss_aux):
                mter.update(lss.item())

            ## print training log message
            global_it = it + epoch * iteration_per_epoch
            if (global_it + 1) % 100 == 0:
                lr = lr_schdr.get_lr()
                lr = sum(lr) / len(lr)
                # write important scalars to tensorboard
                addScalars_loss_Tensorboard(global_it, loss_meter)
                addScalars_lr_Tensorboard(global_it, lr)
                print_log_msg(global_it, cfg.max_iter, lr, time_meter,
                              loss_meter, loss_pre_meter, loss_aux_meters)

        ## save a checkpoint / evaluate every args.saveOnEveryEpoch epochs
        if (epoch + 1) % args.saveOnEveryEpoch == 0:
            if args.saveCheckpointDir is not None:
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optim.state_dict(),
                    'lr_schdr': lr_schdr.state_dict(),
                }
                epoch_no_str = (str(epoch + 1)).zfill(len(str(cfg.max_iter)))
                ckt_name = 'checkpoint_epoch_' + epoch_no_str + '.pt'
                save_pth = osp.join(args.saveCheckpointDir, ckt_name)
                logger.info(
                    '\nsaving intermidiate checkpoint to {}'.format(save_pth))
                save_ckp(checkpoint, save_pth)

            # compute validation accuracy in terms of mIoU
            logger.info('\nevaluating the model after ' + str(epoch + 1) +
                        ' epoches')
            heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns,
                                      cfg.cropsize)
            addScalars_val_accuracy_Tensorboard(global_it, heads, mious)
            # eval_model may switch to eval mode; set back to training mode
            net.train()

    ## dump the final model and evaluate the result
    if args.saveCheckpointDir is None:
        # Previously osp.join(None, ...) raised a TypeError here.
        logger.warning('saveCheckpointDir is not set, final model not saved')
    else:
        checkpoint = {
            'epoch': max_epoch,
            'state_dict': net.state_dict(),
            'optimizer': optim.state_dict(),
            'lr_schdr': lr_schdr.state_dict(),
        }
        save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
        logger.info('\nsave Final models to {}'.format(save_pth))
        save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns,
                              cfg.cropsize)
    logger.info(tabulate([
        mious,
    ], headers=heads, tablefmt='orgtbl'))
    return