def main(args):
    """Collect test images, build the segmentation model and evaluate it.

    Reads ``args.dataset`` ('city' or 'pascal'), ``args.model``
    ('espnetv2' or 'dicenet'), ``args.im_size`` and ``args.weights_test``;
    the 'pascal' branch additionally reads ``args.data_path`` and
    ``args.split``. For 'espnetv2', ``args.classes`` is set as a side effect.

    Args:
        args: options namespace carrying the attributes listed above.
    """
    # `image_source` records where the images were looked up so that the
    # "no files" message below has something to report for every dataset.
    if args.dataset == 'city':
        # The lookup under <data_path>/leftImg8bit/<split>/*/*.png is
        # disabled; images are read from a fixed local folder instead.
        image_source = './input/*.png'
        image_list = glob.glob(image_source)
        print(image_list)
        from data_loader.segmentation.cityscapes import CITYSCAPE_CLASS_LIST
        seg_classes = len(CITYSCAPE_CLASS_LIST)
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOC_CLASS_LIST
        seg_classes = len(VOC_CLASS_LIST)
        data_file = os.path.join(args.data_path, 'VOC2012', 'list', '{}.txt'.format(args.split))
        image_source = data_file
        if not os.path.isfile(data_file):
            print_error_message('{} file does not exist'.format(data_file))
        image_list = []
        with open(data_file, 'r') as lines:
            for line in lines:
                # The first whitespace-separated token of each line is the
                # image path relative to <data_path>/VOC2012.
                rgb_img_loc = '{}/{}/{}'.format(args.data_path, 'VOC2012', line.split()[0])
                if not os.path.isfile(rgb_img_loc):
                    print_error_message('{} image file does not exist'.format(rgb_img_loc))
                image_list.append(rgb_img_loc)
    else:
        print_error_message('{} dataset not yet supported'.format(args.dataset))
        # Without this, `image_list` and `seg_classes` would be unbound below.
        exit(-1)

    if not image_list:
        print_error_message('No files in directory: {}'.format(image_source))

    print_info_message('# of images for testing: {}'.format(len(image_list)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    # model information
    num_params = model_parameters(model)
    flops = compute_flops(model, input=torch.Tensor(1, 3, args.im_size[0], args.im_size[1]))
    print_info_message('FLOPs for an input of size {}x{}: {:.2f} million'.format(args.im_size[0], args.im_size[1], flops))
    print_info_message('# of parameters: {}'.format(num_params))

    if args.weights_test:
        print_info_message('Loading model weights')
        # Load onto the CPU first; the model is moved to its device below.
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        # `.format` (not `, format`) so the placeholder is actually filled
        # and the helper receives a single message argument.
        print_error_message('weight file does not exist or not specified. Please check: {}'.format(args.weights_test))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    evaluate(args, model, image_list, device=device)
Exemple #2
0
def main(args):
    """Build a segmentation model and run it on every image in a folder.

    Reads ``args.dataset`` ('city' or 'pascal') to pick the number of
    classes, globs ``args.data_path`` for each suffix in
    ``IMAGE_EXTENSIONS``, builds the model named by ``args.model``
    ('espnetv2' or 'dicenet'), loads ``args.weights_test`` and hands
    everything to ``run_segmentation``. For 'espnetv2', ``args.classes`` is
    set as a side effect.

    Args:
        args: options namespace carrying the attributes listed above.
    """
    if args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CITYSCAPE_CLASS_LIST
        seg_classes = len(CITYSCAPE_CLASS_LIST)
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOC_CLASS_LIST
        seg_classes = len(VOC_CLASS_LIST)
    else:
        print_error_message('{} dataset not yet supported'.format(args.dataset))
        # Without this, `seg_classes` would be unbound when the model is built.
        exit(-1)

    # Gather the images directly inside args.data_path, one glob per suffix.
    image_list = []
    for extn in IMAGE_EXTENSIONS:
        image_list.extend(glob.glob(args.data_path + os.sep + '*' + extn))

    if not image_list:
        print_error_message('No files in directory: {}'.format(args.data_path))

    print_info_message('# of images used for demonstration: {}'.format(len(image_list)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    if args.weights_test:
        print_info_message('Loading model weights')
        # Load onto the CPU first; the model is moved to its device below.
        weight_dict = torch.load(args.weights_test, map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print_info_message('Weight loaded successfully')
    else:
        # `.format` (not `, format`) so the placeholder is actually filled
        # and the helper receives a single message argument.
        print_error_message('weight file does not exist or not specified. Please check: {}'.format(args.weights_test))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    if torch.backends.cudnn.is_available():
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        cudnn.deterministic = True

    run_segmentation(args, model, image_list, device=device)
Exemple #3
0
def main(args):
    """Run an ESPNetv2 segmentation model on every image in a folder.

    Globs ``args.data_path`` for each suffix in ``IMAGE_EXTENSIONS``, builds
    the 'espnetv2' model with ``args.classes`` forced to 20, loads
    ``args.weights_test`` when given, moves the model to CUDA and calls
    ``run_segmentation``.

    Args:
        args: options namespace; ``args.classes`` is overwritten.
    """
    image_list = []
    for extn in IMAGE_EXTENSIONS:
        image_list.extend(glob.glob(args.data_path + os.sep + '*' + extn))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = 20
        model = espnetv2_seg(args)
    else:
        # Only espnetv2 is wired up here; without this guard `model` would be
        # unbound and the code below would die with a NameError.
        print('{} network not yet supported'.format(args.model))
        exit(-1)

    if args.weights_test:
        print('Loading model weights')
        # Load onto the CPU first; the model is moved to the GPU below.
        weight_dict = torch.load(args.weights_test,
                                 map_location=torch.device('cpu'))
        model.load_state_dict(weight_dict)
        print('Weight loaded successfully')
    else:
        # No weights given: report it and carry on with the model as built.
        print("ERRORRRR")

    # This variant is CUDA-only: there is no CPU fallback.
    model = model.cuda()
    run_segmentation(model, image_list, device='cuda')
def main(args):
    """Train and validate a segmentation network, logging to TensorBoard.

    In order, this function:

    1. builds the train/val datasets for ``args.dataset`` ('pascal' or
       'city') together with the per-class loss weights;
    2. builds the model for ``args.model`` ('espnetv2' or 'dicenet'),
       optionally loading ``args.finetune`` weights and freezing BatchNorm;
    3. creates an SGD optimizer with separate learning rates for the base
       network and the segmentation head;
    4. optionally resumes from the ``args.resume`` checkpoint;
    5. wraps model and criterion for the available GPU(s);
    6. builds the data loaders and the learning-rate scheduler selected by
       ``args.scheduler``;
    7. writes the arguments to ``arguments.json`` in ``args.savedir``;
    8. runs the epoch loop, saving a checkpoint and logging scalars after
       every epoch.

    Args:
        args: options namespace. It is mutated: ``args.classes`` is set for
            'espnetv2', and the keys 'model_params' and 'flops' are added
            through ``vars(args)`` when the arguments are dumped.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Per-class loss weights for class indices 0..19; index 19 gets a
        # weight of zero so it does not contribute to the weighted loss.
        cityscapes_wts = (
            2.8149201869965,
            6.9850029945374,
            3.7890393733978,
            9.9428062438965,
            9.7702074050903,
            9.5110931396484,
            10.311357498169,
            10.026463508606,
            4.6323022842407,
            9.5608062744141,
            7.8698215484619,
            9.5168733596802,
            10.373730659485,
            6.6616044044495,
            10.260489463806,
            10.287888526917,
            10.289801597595,
            10.405355453491,
            10.138095855713,
            0.0,
        )
        class_wts = torch.ones(seg_classes)
        for cls_idx, cls_wt in enumerate(cityscapes_wts):
            class_wts[cls_idx] = cls_wt
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device('cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Two parameter groups: group 0 is the base network, group 1 the
    # segmentation head trained at lr * lr_mult. The epoch loop below
    # relies on this ordering when it updates the learning rates.
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        # Graph export is best-effort. Catch Exception rather than using a
        # bare except so Ctrl-C / SystemExit are not swallowed here.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            # Checkpoints store 'epoch' as (last finished epoch + 1), so it
            # is directly the epoch to start from.
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            # NOTE(review): optimizer state is restored while the model is
            # still on the CPU and the model is moved to the GPU further
            # down; confirm the restored state ends up on the right device.
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        # Epochs at which the rate decays: every multiple of step_size that
        # the range below produces.
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        # Non-zero status, matching the other unsupported-option exits.
        exit(-1)

    print_info_message(lr_scheduler)

    import json
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        # vars(args) is the namespace's own dict, so the two keys added
        # here also become attributes of `args`.
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    try:
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))
            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            # On CUDA the model was wrapped above, so the underlying network
            # is reached through `.module`.
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # These two log best mIOU with FLOPs / parameter count as the
            # step value, instead of the epoch.
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))
    finally:
        # Close the event file even if training is interrupted or fails.
        writer.close()
Exemple #5
0
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST, color_encoding
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               coarse=False)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 10 / 2.8149201869965
        class_wts[1] = 10 / 6.9850029945374
        class_wts[2] = 10 / 3.7890393733978
        class_wts[3] = 10 / 9.9428062438965
        class_wts[4] = 10 / 9.7702074050903
        class_wts[5] = 10 / 9.5110931396484
        class_wts[6] = 10 / 10.311357498169
        class_wts[7] = 10 / 10.026463508606
        class_wts[8] = 10 / 4.6323022842407
        class_wts[9] = 10 / 9.5608062744141
        class_wts[10] = 10 / 7.8698215484619
        class_wts[11] = 10 / 9.5168733596802
        class_wts[12] = 10 / 10.373730659485
        class_wts[13] = 10 / 6.6616044044495
        class_wts[14] = 10 / 10.260489463806
        class_wts[15] = 10 / 10.287888526917
        class_wts[16] = 10 / 10.289801597595
        class_wts[17] = 10 / 10.405355453491
        class_wts[18] = 10 / 10.138095855713
        class_wts[19] = 0.0

    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GREENHOUSE_CLASS_LIST, color_encoding
        train_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path,
            list_name=args.train_list,
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth,
            use_traversable=args.greenhouse_use_trav)
        val_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path,
            list_name=args.val_list,
            train=False,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth,
            use_traversable=args.greenhouse_use_trav)
        class_weights = np.load('class_weights.npy')  # [:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        print(GREENHOUSE_CLASS_LIST)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
#        color_encoding = OrderedDict([
#            ('end_of_plant', (0, 255, 0)),
#            ('other_part_of_plant', (0, 255, 255)),
#            ('artificial_objects', (255, 0, 0)),
#            ('ground', (255, 255, 0)),
#            ('background', (0, 0, 0))
#        ])
    elif args.dataset == 'ishihara':
        print(args.use_depth)
        from data_loader.segmentation.ishihara_rgbd import IshiharaRGBDSegmentation, ISHIHARA_RGBD_CLASS_LIST
        train_dataset = IshiharaRGBDSegmentation(
            root=args.data_path,
            list_name='ishihara_rgbd_train.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = IshiharaRGBDSegmentation(
            root=args.data_path,
            list_name='ishihara_rgbd_val.txt',
            train=False,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)

        seg_classes = len(ISHIHARA_RGBD_CLASS_LIST)

        class_wts = torch.ones(seg_classes)

        color_encoding = OrderedDict([('Unlabeled', (0, 0, 0)),
                                      ('Building', (70, 70, 70)),
                                      ('Fence', (190, 153, 153)),
                                      ('Others', (72, 0, 90)),
                                      ('Pedestrian', (220, 20, 60)),
                                      ('Pole', (153, 153, 153)),
                                      ('Road ', (157, 234, 50)),
                                      ('Road', (128, 64, 128)),
                                      ('Sidewalk', (244, 35, 232)),
                                      ('Vegetation', (107, 142, 35)),
                                      ('Car', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)),
                                      ('Traffic ', (220, 220, 0))])
    elif args.dataset == 'sun':
        print(args.use_depth)
        from data_loader.segmentation.sun_rgbd import SUNRGBDSegmentation, SUN_RGBD_CLASS_LIST
        train_dataset = SUNRGBDSegmentation(root=args.data_path,
                                            list_name='sun_rgbd_train.txt',
                                            train=True,
                                            size=crop_size,
                                            ignore_idx=args.ignore_idx,
                                            scale=args.scale,
                                            use_depth=args.use_depth)
        val_dataset = SUNRGBDSegmentation(root=args.data_path,
                                          list_name='sun_rgbd_val.txt',
                                          train=False,
                                          size=crop_size,
                                          ignore_idx=args.ignore_idx,
                                          scale=args.scale,
                                          use_depth=args.use_depth)

        seg_classes = len(SUN_RGBD_CLASS_LIST)

        class_wts = torch.ones(seg_classes)

        color_encoding = OrderedDict([('Background', (0, 0, 0)),
                                      ('Bed', (0, 255, 0)),
                                      ('Books', (70, 70, 70)),
                                      ('Ceiling', (190, 153, 153)),
                                      ('Chair', (72, 0, 90)),
                                      ('Floor', (220, 20, 60)),
                                      ('Furniture', (153, 153, 153)),
                                      ('Objects', (157, 234, 50)),
                                      ('Picture', (128, 64, 128)),
                                      ('Sofa', (244, 35, 232)),
                                      ('Table', (107, 142, 35)),
                                      ('TV', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)),
                                      ('Window', (220, 220, 0))])
    elif args.dataset == 'camvid':
        print(args.use_depth)
        from data_loader.segmentation.camvid import CamVidSegmentation, CAMVID_CLASS_LIST, color_encoding
        train_dataset = CamVidSegmentation(
            root=args.data_path,
            list_name='train_camvid.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            label_conversion=args.label_conversion,
            normalize=args.normalize)
        val_dataset = CamVidSegmentation(
            root=args.data_path,
            list_name='val_camvid.txt',
            train=False,
            size=crop_size,
            scale=args.scale,
            label_conversion=args.label_conversion,
            normalize=args.normalize)

        if args.label_conversion:
            from data_loader.segmentation.greenhouse import GREENHOUSE_CLASS_LIST, color_encoding
            seg_classes = len(GREENHOUSE_CLASS_LIST)
            class_wts = torch.ones(seg_classes)
        else:
            seg_classes = len(CAMVID_CLASS_LIST)
            tmp_loader = torch.utils.data.DataLoader(train_dataset,
                                                     batch_size=1,
                                                     shuffle=False)

            class_wts = calc_cls_class_weight(tmp_loader,
                                              seg_classes,
                                              inverted=True)
            class_wts = torch.from_numpy(class_wts).float().to(device)
            #            class_wts = torch.ones(seg_classes)
            print("class weights : {}".format(class_wts))

        args.use_depth = False
    elif args.dataset == 'forest':
        from data_loader.segmentation.freiburg_forest import FreiburgForestDataset, FOREST_CLASS_LIST, color_encoding
        train_dataset = FreiburgForestDataset(train=True,
                                              size=crop_size,
                                              scale=args.scale,
                                              normalize=args.normalize)
        val_dataset = FreiburgForestDataset(train=False,
                                            size=crop_size,
                                            scale=args.scale,
                                            normalize=args.normalize)

        seg_classes = len(FOREST_CLASS_LIST)
        tmp_loader = torch.utils.data.DataLoader(train_dataset,
                                                 batch_size=1,
                                                 shuffle=False)

        class_wts = calc_cls_class_weight(tmp_loader,
                                          seg_classes,
                                          inverted=True)
        class_wts = torch.from_numpy(class_wts).float().to(device)
        #        class_wts = torch.ones(seg_classes)
        print("class weights : {}".format(class_wts))

        args.use_depth = False
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnetue_seg2(args, fix_pyr_plane_proj=True)
    elif args.model == 'deeplabv3':
        # from model.segmentation.deeplabv3 import DeepLabV3
        from torchvision.models.segmentation.segmentation import deeplabv3_resnet101

        args.classes = seg_classes
        # model = DeepLabV3(seg_classes)
        model = deeplabv3_resnet101(num_classes=seg_classes, aux_loss=True)
        torch.backends.cudnn.enabled = False
    elif args.model == 'unet':
        from model.segmentation.unet import UNet
        model = UNet(in_channels=3, out_channels=seg_classes)
#        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
#            in_channels=3, out_channels=seg_classes, init_features=32, pretrained=False)

    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    if args.model == 'deeplabv3' or args.model == 'unet':
        train_params = [{'params': model.parameters(), 'lr': args.lr}]

    elif args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr * args.lr_mult
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, 288, 480))
    except:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        if args.model == 'espdnetue' or (
            (args.model == 'deeplabv3' or args.model == 'unet')
                and args.use_aux):
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        iou_train, train_loss = train(model,
                                      train_loader,
                                      optimizer,
                                      criterion,
                                      seg_classes,
                                      epoch,
                                      device=device,
                                      use_depth=args.use_depth,
                                      add_criterion=nid_loss)
        iou_val, val_loss = val(model,
                                val_loader,
                                criterion,
                                seg_classes,
                                device=device,
                                use_depth=args.use_depth,
                                add_criterion=nid_loss)

        batch_train = iter(train_loader).next()
        batch = iter(val_loader).next()
        if args.use_depth:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                depths=batch_train[2].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          depths=batch[2].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

            image_grid = torchvision.utils.make_grid(
                batch[2].to(device=device).data.cpu()).numpy()
            print(type(image_grid))
            writer.add_image('Segmentation/depths', image_grid, epoch)
        else:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)


#            image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
#            writer.add_image('Segmentation/results/val', image_grid, epoch)

# remember best miou and save checkpoint
        miou_val = iou_val[[1, 2, 3]].mean()
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train',
                          iou_train[[1, 2, 3]].mean(), epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/plant_IOU/val', iou_val[1], epoch)
        writer.add_scalar('Segmentation/ao_IOU/val', iou_val[2], epoch)
        writer.add_scalar('Segmentation/ground_IOU/val', iou_val[3], epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
Exemple #6
0
def main(args):
    """Evaluate a segmentation network on a validation split and print its mIOU.

    The function builds the validation dataset selected by ``args.dataset``,
    instantiates the network selected by ``args.model``, loads weights from
    ``args.weights_test`` when given, and runs ``val`` over the validation
    loader.

    Args:
        args: Parsed command-line namespace. The attributes read directly
            here are ``dataset``, ``data_path``, ``s``, ``model``,
            ``im_size`` and ``weights_test``. ``args.classes`` is written
            before building ``espnetv2``. The namespace is also handed to
            the model constructors, which may read further attributes.
    """
    # ------------------------------------------------------------------
    # Validation dataset and per-class loss weights
    # ------------------------------------------------------------------
    if args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=(256, 256),
                                             scale=args.s,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded Cityscapes class weights, indexed by class id.
        # The last entry (index 19) is zero.
        cityscapes_wts = (
            2.8149201869965, 6.9850029945374, 3.7890393733978,
            9.9428062438965, 9.7702074050903, 9.5110931396484,
            10.311357498169, 10.026463508606, 4.6323022842407,
            9.5608062744141, 7.8698215484619, 9.5168733596802,
            10.373730659485, 6.6616044044495, 10.260489463806,
            10.287888526917, 10.289801597595, 10.405355453491,
            10.138095855713, 0.0,
        )
        class_wts = torch.ones(seg_classes)
        for class_id, weight in enumerate(cityscapes_wts):
            class_wts[class_id] = weight
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=(256, 256),
                                      scale=args.s)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey':
        from data_loader.segmentation.hockey import HockeySegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        # Only the validation split is needed for evaluation.
        val_dataset = HockeySegmentationDataset(root=args.data_path,
                                                train=False,
                                                crop_size=(256, 256),
                                                scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey_rink_seg':
        from data_loader.segmentation.hockey_rink_seg import HockeyRinkSegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        # Only the validation split is needed for evaluation.
        val_dataset = HockeyRinkSegmentationDataset(root=args.data_path,
                                                    train=False,
                                                    crop_size=(256, 256),
                                                    scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))
        # Mirror the unsupported-model branch below: without an explicit exit
        # the code would continue with val_dataset/seg_classes unbound.
        exit(-1)

    if len(val_dataset) == 0:
        # Report the dataset root; no glob pattern exists in this function.
        print_error_message('No files in directory: {}'.format(
            args.data_path))

    print_info_message('# of images for testing: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Network
    # ------------------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    # Model information
    im_height, im_width = args.im_size[0], args.im_size[1]
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, im_height, im_width))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            im_height, im_width, flops))
    print_info_message('# of parameters: {}'.format(num_params))

    # ------------------------------------------------------------------
    # Weights
    # ------------------------------------------------------------------
    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test,
                                 map_location=torch.device('cpu'))

        # Accept either a checkpoint dict that wraps the weights under
        # 'state_dict' or a bare state dict.
        if isinstance(weight_dict, dict) and 'state_dict' in weight_dict:
            weight_dict = weight_dict['state_dict']
        model.load_state_dict(weight_dict)

        print_info_message('Weight loaded successfully')
    else:
        print_error_message(
            'weight file does not exist or not specified. Please check: {}'
            .format(args.weights_test))

    # ------------------------------------------------------------------
    # Device placement, data loader and loss
    # ------------------------------------------------------------------
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=40,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=4)

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type='ce',
                                 device=device,
                                 ignore_idx=255,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU the criterion needs no DataParallel wrapper,
            # so only the model is wrapped.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model).cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model).cuda()
            criterion = DataParallelCriteria(criterion).cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    miou_val, val_loss = val(model,
                             val_loader,
                             criterion,
                             seg_classes,
                             device=device)
    print_info_message('mIOU: {}'.format(miou_val))
Exemple #7
0
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Training the model with 13 classes of CamVid dataset
    # TODO: This process should be done only if specified
    #
    if not args.finetune:
        train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
            label_conversion=False)  # 13 classes
        args.use_depth = False  # 'use_depth' is always false for camvid

        print_info_message('Training samples: {}'.format(len(train_dataset)))
        print_info_message('Validation samples: {}'.format(len(val_dataset)))

        # Import model
        if args.model == 'espnetv2':
            from model.segmentation.espnetv2 import espnetv2_seg
            args.classes = seg_classes
            model = espnetv2_seg(args)
        elif args.model == 'espdnet':
            from model.segmentation.espdnet import espdnet_seg
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            print("Segmentation classes : {}".format(seg_classes))
            model = espdnet_seg(args)
        elif args.model == 'espdnetue':
            from model.segmentation.espdnet_ue import espdnetue_seg2
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            ("Segmentation classes : {}".format(seg_classes))
            print(args.weights)
            model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
        else:
            print_error_message('Arch: {} not yet supported'.format(
                args.model))
            exit(-1)

        # Freeze batch normalization layers?
        if args.freeze_bn:
            freeze_bn_layer(model)

        # Set learning rates
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

        # Define an optimizer
        optimizer = optim.SGD(train_params,
                              lr=args.lr * args.lr_mult,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Compute the FLOPs and the number of parameters, and display it
        num_params, flops = show_network_stats(model, crop_size)

        try:
            writer.add_graph(model,
                             input_to_model=torch.Tensor(
                                 1, 3, crop_size[0], crop_size[1]))
        except:
            print_log_message(
                "Not able to generate the graph. Likely because your model is not supported by ONNX"
            )

        #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
        criterion = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
        nid_loss = NIDLoss(image_bin=32,
                           label_bin=seg_classes) if args.use_nid else None

        if num_gpus >= 1:
            if num_gpus == 1:
                # for a single GPU, we do not need DataParallel wrapper for Criteria.
                # So, falling back to its internal wrapper
                from torch.nn.parallel import DataParallel
                model = DataParallel(model)
                model = model.cuda()
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss.cuda()
            else:
                from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
                model = DataParallelModel(model)
                model = model.cuda()
                criterion = DataParallelCriteria(criterion)
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss = DataParallelCriteria(nid_loss)
                    nid_loss = nid_loss.cuda()

            if torch.backends.cudnn.is_available():
                import torch.backends.cudnn as cudnn
                cudnn.benchmark = True
                cudnn.deterministic = True

        # Get data loaders for training and validation data
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=20,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # Get a learning rate scheduler
        lr_scheduler = get_lr_scheduler(args.scheduler)

        write_stats_to_json(num_params, flops)

        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
        #
        # Main training loop of 13 classes
        #
        start_epoch = 0
        best_miou = 0.0
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # Use different training functions for espdnetue
            if args.model == 'espdnetue':
                from utilities.train_eval_seg import train_seg_ue as train
                from utilities.train_eval_seg import val_seg_ue as val
            else:
                from utilities.train_eval_seg import train_seg as train
                from utilities.train_eval_seg import val_seg as val

            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device,
                                           use_depth=args.use_depth,
                                           add_criterion=nid_loss)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device,
                                     use_depth=args.use_depth,
                                     add_criterion=nid_loss)

            batch_train = iter(train_loader).next()
            batch = iter(val_loader).next()
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))

        # Save the pretrained weights
        model_dict = copy.deepcopy(model.state_dict())
        del model
        torch.cuda.empty_cache()

    #
    # Finetuning with 4 classes
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    #set_parameters_for_finetuning()

    # Import model
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        new_model_dict = model.state_dict()
        #        for k, v in model_dict.items():
        #            if k.lstrip('module.') in new_model_dict:
        #                print('In:{}'.format(k.lstrip('module.')))
        #            else:
        #                print('Not In:{}'.format(k.lstrip('module.')))
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        no_overlap_dict = {
            k.replace('module.', ''): v
            for k, v in new_model_dict.items()
            if k.replace('module.', '') not in new_model_dict
            or new_model_dict[k.replace('module.', '')].size() != v.size()
        }
        print(no_overlap_dict.keys())

        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())

    print(seg_classes)
    print(class_wts.size())
    #print(model_dict.keys())
    #print(new_model_dict.keys())
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Set learning rates
    args.lr /= 100
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # With a single GPU the criterion needs no DataParallel wrapper, so
            # only the model is wrapped, using PyTorch's built-in DataParallel.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])
    #
    # Main training loop of 13 classes
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device,
                                 use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        batch_train = iter(train_loader).next()
        batch = iter(val_loader).next()
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()