def train():
    logger = logging.getLogger()
    is_dist = dist.is_initialized()

    ## dataset
    dl = get_data_loader(cfg.im_root,
                         cfg.train_im_anns,
                         cfg.ims_per_gpu,
                         cfg.scales,
                         cfg.cropsize,
                         cfg.max_iter,
                         mode='train',
                         distributed=is_dist)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
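    # apex AMP: opt level 'O1' enables mixed-precision training, 'O0' keeps pure FP32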
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## ddp training
    net = set_model_dist(net)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## load checkpoint if one exists, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_iteration = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_iteration = 0

    ## train loop
    for current_it, (im, lb) in enumerate(dl):
        # on a resumed run, 'it' continues from the saved start_iteration; it is 0 otherwise
        it = current_it + start_iteration
        im = im.cuda()
        lb = lb.cuda()

        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        logits, *logits_aux = net(im)
        loss_pre = criteria_pre(logits, lb)
        loss_aux = [
            crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
        ]
        loss = loss_pre + sum(loss_aux)
        if has_apex:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optim.step()
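        # block until queued GPU work finishes so the time meter below reflects the full step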
        torch.cuda.synchronize()
        lr_schdr.step()

        time_meter.update()
        loss_meter.update(loss.item())
        loss_pre_meter.update(loss_pre.item())
        _ = [
            mter.update(lss.item())
            for mter, lss in zip(loss_aux_meters, loss_aux)
        ]

        ## print training log message
        if (it + 1) % 100 == 0:
            lr = lr_schdr.get_lr()
            lr = sum(lr) / len(lr)
            print_log_msg(it, cfg.max_iter, lr, time_meter, loss_meter,
                          loss_pre_meter, loss_aux_meters)

        # save a checkpoint every `saveOnEveryIt` iterations
        if (it + 1) % args.saveOnEveryIt == 0:
            if args.saveCheckpointDir is not None:
                checkpoint = {
                    'iteration': it + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optim.state_dict(),
                    'lr_schdr': lr_schdr.state_dict(),
                }
                iteration_no_str = (str(it + 1)).zfill(len(str(cfg.max_iter)))
                ckt_name = 'checkpoint_it_' + iteration_no_str + '.pt'
                save_pth = osp.join(args.saveCheckpointDir, ckt_name)
                logger.info(
                    '\nsaving intermediate checkpoint to {}'.format(save_pth))
                save_ckp(checkpoint, save_pth)

    ## dump the final model and evaluate the result
    checkpoint = {
        'iteration': cfg.max_iter,
        'state_dict': net.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_schdr': lr_schdr.state_dict(),
    }
    save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
    logger.info('\nsaving final model to {}'.format(save_pth))
    save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
    logger.info(tabulate([
        mious,
    ], headers=heads, tablefmt='orgtbl'))
    return
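

# The checkpoint helpers save_ckp() and load_ckp() are not shown in this
# example. Below is a minimal sketch of what they might look like, inferred
# purely from the call sites above; the original repository's implementation
# may differ (e.g. in how it handles DDP-wrapped state dicts).
import torch  # already imported by the original script


def save_ckp(checkpoint, save_pth):
    # persist the checkpoint dict (iteration/epoch counter plus model,
    # optimizer and lr-scheduler state dicts) to disk
    torch.save(checkpoint, save_pth)


def load_ckp(ckp_pth, net, optim, lr_schdr):
    # restore model, optimizer and lr-scheduler state, and return the
    # iteration (or epoch) count at which training should resume
    checkpoint = torch.load(ckp_pth, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    optim.load_state_dict(checkpoint['optimizer'])
    lr_schdr.load_state_dict(checkpoint['lr_schdr'])
    start_count = checkpoint.get('iteration', checkpoint.get('epoch', 0))
    return net, optim, lr_schdr, start_count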
Example No. 2
def train():
    logger = logging.getLogger()

    ## dataset
    dl = get_data_loader(cfg.im_root,
                         cfg.train_im_anns,
                         cfg.ims_per_gpu,
                         cfg.scales,
                         cfg.cropsize,
                         cfg.max_iter,
                         mode='train',
                         distributed=False)

    # send a few training images to TensorBoard
    addImage_Tensorboard(dl)
    # work out how many epochs cfg.max_iter corresponds to
    dataset_length = len(dl.dataset)
    print("Dataset length: ", dataset_length)
    batch_size = cfg.ims_per_gpu
    print("Batch Size: ", batch_size)
    iteration_per_epoch = int(dataset_length / batch_size)
    max_epoch = int(cfg.max_iter / iteration_per_epoch)
    print("Max_epoch: ", max_epoch)

    ## model
    net, criteria_pre, criteria_aux = set_model()

    ## optimizer
    optim = set_optimizer(net)

    ## fp16
    if has_apex:
        opt_level = 'O1' if cfg.use_fp16 else 'O0'
        net, optim = amp.initialize(net, optim, opt_level=opt_level)

    ## meters
    time_meter, loss_meter, loss_pre_meter, loss_aux_meters = set_meters()

    ## lr scheduler
    lr_schdr = WarmupPolyLrScheduler(
        optim,
        power=0.9,
        max_iter=cfg.max_iter,
        warmup_iter=cfg.warmup_iters,
        warmup_ratio=0.1,
        warmup='exp',
        last_epoch=-1,
    )

    ## load checkpoint if one exists, to resume training
    if args.loadCheckpointLocation is not None:
        net, optim, lr_schdr, start_epoch = load_ckp(
            args.loadCheckpointLocation, net, optim, lr_schdr)
    else:
        start_epoch = 0
    # send the model graph to TensorBoard
    addGraph_Tensorboard(net, dl)

    ## train loop
    for current_epoch in range(max_epoch):
        # on a resumed run, the epoch count continues from the saved start_epoch; it is 0 otherwise
        epoch = start_epoch + current_epoch

        for it, (im, lb) in enumerate(dl):

            im = im.to(device)
            lb = lb.to(device)

            lb = torch.squeeze(lb, 1)

            optim.zero_grad()
            logits, *logits_aux = net(im)
            loss_pre = criteria_pre(logits, lb)
            loss_aux = [
                crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)
            ]
            loss = loss_pre + sum(loss_aux)

            if has_apex:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            lr_schdr.step()

            time_meter.update()
            loss_meter.update(loss.item())
            loss_pre_meter.update(loss_pre.item())
            _ = [
                mter.update(lss.item())
                for mter, lss in zip(loss_aux_meters, loss_aux)
            ]

            ## print training log message
            global_it = it + epoch * iteration_per_epoch
            if (global_it + 1) % 100 == 0:
                lr = lr_schdr.get_lr()
                lr = sum(lr) / len(lr)
                # write loss and learning-rate scalars to TensorBoard
                addScalars_loss_Tensorboard(global_it, loss_meter)
                addScalars_lr_Tensorboard(global_it, lr)
                print_log_msg(global_it, cfg.max_iter, lr, time_meter,
                              loss_meter, loss_pre_meter, loss_aux_meters)

        # save a checkpoint every `saveOnEveryEpoch` epochs
        if (epoch + 1) % args.saveOnEveryEpoch == 0:
            if args.saveCheckpointDir is not None:
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optim.state_dict(),
                    'lr_schdr': lr_schdr.state_dict(),
                }
                epoch_no_str = str(epoch + 1).zfill(len(str(max_epoch)))
                ckt_name = 'checkpoint_epoch_' + epoch_no_str + '.pt'
                save_pth = osp.join(args.saveCheckpointDir, ckt_name)
                logger.info(
                    '\nsaving intermediate checkpoint to {}'.format(save_pth))
                save_ckp(checkpoint, save_pth)

            # compute validation accuracy in terms of mIoU
            logger.info('\nevaluating the model after ' + str(epoch + 1) +
                        ' epochs')
            heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns,
                                      cfg.cropsize)
            addScalars_val_accuracy_Tensorboard(global_it, heads, mious)
            # set the network back to training mode after evaluation
            net.train()

    ## dump the final model and evaluate the result
    checkpoint = {
        'epoch': max_epoch,
        'state_dict': net.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_schdr': lr_schdr.state_dict(),
    }
    save_pth = osp.join(args.saveCheckpointDir, 'model_final.pt')
    logger.info('\nsaving final model to {}'.format(save_pth))
    save_ckp(checkpoint, save_pth)

    logger.info('\nevaluating the final model')
    torch.cuda.empty_cache()
    heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns,
                              cfg.cropsize)
    logger.info(tabulate([
        mious,
    ], headers=heads, tablefmt='orgtbl'))
    return
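

# None of the addXxx_Tensorboard helpers called above are defined in this
# snippet. The sketch below shows one plausible way to implement them with
# torch.utils.tensorboard; the tag names, the module-level `writer`, and the
# meter attribute used for the loss are assumptions, not the original code.
import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()


def addImage_Tensorboard(dl):
    # write one sample batch as an image grid so the training inputs can be inspected
    im, _ = next(iter(dl))
    writer.add_image('train/sample_batch', torchvision.utils.make_grid(im), 0)


def addGraph_Tensorboard(net, dl):
    # trace the network once with a sample batch to record its graph
    im, _ = next(iter(dl))
    writer.add_graph(net, im.to(next(net.parameters()).device))


def addScalars_loss_Tensorboard(step, loss_meter):
    # log the running average loss; the .avg attribute is an assumed meter API
    writer.add_scalar('train/loss', loss_meter.avg, step)


def addScalars_lr_Tensorboard(step, lr):
    # log the (averaged) learning rate reported by the scheduler
    writer.add_scalar('train/lr', lr, step)


def addScalars_val_accuracy_Tensorboard(step, heads, mious):
    # log one mIoU scalar per evaluation head returned by eval_model
    for head, miou in zip(heads, mious):
        writer.add_scalar('val/miou/{}'.format(head), miou, step)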