Example #1
0
 def __init__(self):
     """Build a 5-class ERFNet, load pretrained weights into it and set eval mode.

     The network is stored on ``self.model``; weights come from the
     ``state_dict`` entry of ``erfnet/pretrained/ERFNet_pretrained.tar``.
     """
     num_class = 5
     self.model = models.ERFNet(num_class)
     # torch.load unpickles the file; only load trusted checkpoints.
     checkpoint = torch.load('erfnet/pretrained/ERFNet_pretrained.tar')
     # Load into self.model -- a bare ``model`` name is not defined in this scope.
     self.model.load_state_dict(checkpoint['state_dict'])
     cudnn.benchmark = True
     cudnn.fastest = True
     self.model.eval()
Example #2
0
def main():
    """Evaluate a 20-class ERFNet (Paddle dygraph) on the LaneDet validation list.

    Parses CLI arguments into the global ``args`` and creates
    ``args.output_dir`` if missing. Loads weights from ``args.resume``, or
    from ``trained/ERFNet_trained`` when no resume path is given. Then runs
    ``validate`` once over ``args.val_list``, with frames resized to
    1024x576 and normalised with the model's mean/std.

    Raises:
        ValueError: if ``args.dataset`` is not 'LaneDet'.
    """
    global args, best_mIoU
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.dataset == 'LaneDet':
        num_class = 20
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    # get places
    places = fluid.cuda_places()

    with fluid.dygraph.guard():
        model = models.ERFNet(num_class, [576, 1024])
        input_mean = model.input_mean
        input_std = model.input_std

        if args.resume:
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint, _ = fluid.load_dygraph(args.resume)
            model.load_dict(checkpoint)
            print("=> checkpoint loaded successfully")
        else:
            print(("=> loading checkpoint '{}'".format('trained/ERFNet_trained')))
            checkpoint, _ = fluid.load_dygraph('trained/ERFNet_trained')
            model.load_dict(checkpoint)
            print("=> default checkpoint loaded successfully")

        # Broadcastable (1, 1, C) arrays, built once instead of per sample.
        mean = np.asarray(input_mean)[None, None, :]
        std = np.asarray(input_std)[None, None, :]

        # Data loading code
        test_dataset = ds.LaneDataSet(
            dataset_path='datasets/PreliminaryData',
            data_list=args.val_list,
            transform=[
                lambda x: cv2.resize(x, (1024, 576)),
                # Parenthesised: subtract the mean first, then divide by std.
                # Without the parentheses only the mean would be scaled.
                lambda x: (x - mean) / std,
            ]
        )

        test_loader = DataLoader(
            test_dataset,
            places=places[0],
            batch_size=1,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=collate_fn
        )

        ### evaluate ###
        validate(test_loader, model)
    return
def main():
    """Evaluate ERFNet once on ``args.val_list`` for the dataset named on the CLI.

    Parses CLI arguments into the global ``args`` and exposes only
    ``args.gpus`` through CUDA_VISIBLE_DEVICES. Builds ERFNet, optionally
    restores a checkpoint (epoch, best mIoU, weights) from ``args.resume``,
    and runs ``validate`` with a class-weighted NLL loss.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    global args, best_mIoU
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(gpu) for gpu in args.gpus)
    # From here on args.gpus is a device count; visible devices are 0..n-1.
    args.gpus = len(args.gpus)

    if args.dataset in ('VOCAug', 'VOC2012', 'COCO'):
        num_class = 21
        ignore_label = 255
    elif args.dataset == 'Cityscapes':
        num_class = 19
        ignore_label = 255
    elif args.dataset == 'ApolloScape':
        num_class = 37
        ignore_label = 255
    elif args.dataset in ('CULane', 'L4E'):
        num_class = 5
        ignore_label = 255
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    # NOTE(review): other entry points call models.ERFNet(num_class); here the
    # first positional argument is 3 -- confirm this matches ERFNet's signature.
    model = models.ERFNet(3, num_class)
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=range(args.gpus)).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            torch.nn.Module.load_state_dict(model, checkpoint['state_dict'])
            # Log the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True
    cudnn.fastest = True

    # Data loading code
    test_loader = torch.utils.data.DataLoader(
        getattr(ds, 'VOCAugDataSet')(
            data_list=args.val_list,
            transform=torchvision.transforms.Compose([
                tf.GroupRandomScaleNew(size=(args.img_width, args.img_height),
                                       interpolation=(cv2.INTER_LINEAR,
                                                      cv2.INTER_NEAREST)),
                tf.GroupNormalize(mean=(input_mean, (0, )),
                                  std=(input_std, (1, ))),
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    # define loss function (criterion) and evaluator
    # One weight per class: NLLLoss requires len(weight) == num_class, so a
    # hard-coded length of 5 would fail for every dataset except CULane/L4E.
    weights = [1.0] * num_class
    weights[0] = 0.4  # down-weight class 0
    class_weights = torch.FloatTensor(weights).cuda()
    criterion = torch.nn.NLLLoss(ignore_index=ignore_label,
                                 weight=class_weights).cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))
    evaluator = EvalSegmentation(num_class, ignore_label)

    ### evaluate ###
    validate(test_loader, model, criterion, 0, evaluator)
    return
        else:
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

gpus = [0]
resume = "trained/ERFNet_trained.tar"
evaluate = False
# Expose only the listed GPUs; they are then renumbered 0..len(gpus)-1,
# which is why DataParallel below uses range(gpus).
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in gpus)
gpus = len(gpus)  # from here on ``gpus`` is a device count, not a list of ids

num_class = 5
ignore_label = 255

model = models.ERFNet(num_class)
input_mean = model.input_mean
input_std = model.input_std
policies = model.get_optim_policies()
model = torch.nn.DataParallel(model, device_ids=range(gpus)).cuda()

if resume:
    if os.path.isfile(resume):
        print(("=> loading checkpoint '{}'".format(resume)))
        # torch.load unpickles the file; only point ``resume`` at trusted files.
        checkpoint = torch.load(resume)
        start_epoch = checkpoint['epoch']
        best_mIoU = checkpoint['best_mIoU']
        torch.nn.Module.load_state_dict(model, checkpoint['state_dict'])
        # Log the checkpoint path that was loaded, not the ``evaluate`` flag.
        print(("=> loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'])))
    else:
        print(("=> no checkpoint found at '{}'".format(resume)))
Example #5
0
# Target size (width, height) in pixels that image_feed resizes the cropped frame to.
img_width = 976
img_height = 208
# Number of rows removed from the top of each frame by image_feed before resizing.
H_offset = 400
# NOTE(review): not used in the visible code; presumably a threshold applied to
# the lane output maps -- confirm where it is consumed.
LANE_THRESH = 110
# NOTE(review): machine-specific sample image path; unused in the visible code
# (the loop below reads frames from "2.mp4" instead).
image_path = "D:/github/dataset/driver_23_30frame/05151640_0419.MP4/00000.jpg"

def image_feed(img):
    """Turn an H x W x C image array into a (1, C, img_height, img_width) float tensor.

    The top ``H_offset`` rows are cropped away before resizing to
    (``img_width``, ``img_height``). The input array is not modified.

    Returns:
        (tensor, crop_h, crop_w): the batched tensor plus the height and width
        of the cropped region *before* resizing. Callers presumably use these
        to map predictions back onto the frame.
    """
    cropped = img[H_offset:, :, :].copy()
    crop_h, crop_w = cropped.shape[:2]
    resized = cv2.resize(cropped, (img_width, img_height))
    # HWC -> CHW, then add the leading batch dimension.
    chw = torch.from_numpy(resized).permute(2, 0, 1).contiguous().float()
    return chw.unsqueeze_(0), crop_h, crop_w

model = models.ERFNet(5)
# torch.load unpickles the file; PATH (defined elsewhere) must be a trusted checkpoint.
checkpoint = torch.load(PATH)
model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
# strict=False: checkpoint keys the model lacks (and vice versa) are ignored.
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()

cap = cv2.VideoCapture("2.mp4")
try:
    while True:
        ok, frame = cap.read()
        # Check the read result before touching the frame: at end of stream
        # cap.read() returns (False, None) and resizing None would raise.
        if not ok:
            break
        frame = cv2.resize(frame, (1280, 720))

        imagez = frame
        input_img, H_ori, W_ori = image_feed(imagez)
finally:
    # Release the video file handle even if processing raises.
    cap.release()
def main():
    """Fine-tune ERFNet on the dataset named by ``--dataset`` and checkpoint it.

    Parses CLI arguments into the global ``args`` and restricts the process
    to the GPUs in ``args.gpus``. Builds ERFNet and optionally warm-starts it
    from ``args.resume``, skipping every ``output_conv`` parameter and any key
    the model does not have. Then trains for ``args.epochs`` epochs. It
    validates every ``args.eval_freq`` epochs and on the last epoch, and
    saves a checkpoint after each validation.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    global args, best_mIoU
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(gpu) for gpu in args.gpus)

    if args.dataset in ('VOCAug', 'VOC2012', 'COCO'):
        num_class = 21
        ignore_label = 255
    elif args.dataset == 'Cityscapes':
        num_class = 19
        ignore_label = 255  # 0
    elif args.dataset == 'ApolloScape':
        num_class = 37  # merge the noise and ignore labels
        ignore_label = 255
    elif args.dataset == 'CULane':
        num_class = 5
        ignore_label = 255
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = models.ERFNet(num_class)
    input_mean = model.input_mean
    input_std = model.input_std

    model = model.cuda()
    # CUDA_VISIBLE_DEVICES renumbers the selected GPUs as 0..k-1, so DataParallel
    # must receive those local indices, not the physical ids in args.gpus
    # (e.g. ``--gpus 2 3`` must become device_ids [0, 1]).
    model = torch.nn.DataParallel(model,
                                  device_ids=list(range(len(args.gpus))))

    def load_my_state_dict(model, state_dict):
        """Copy matching checkpoint tensors into ``model`` in place and return it.

        Keys the model does not have, and every key containing
        ``output_conv``, are skipped. Presumably this lets a checkpoint with a
        different classifier head be reused.
        """
        own_state = model.state_dict()
        reused = 0
        for name, param in state_dict.items():
            # Dict membership test: O(1) instead of scanning a key list per item.
            if name not in own_state or 'output_conv' in name:
                continue
            own_state[name].copy_(param)
            reused += 1
        print('#reused param: {}'.format(reused))
        return model

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model = load_my_state_dict(model, checkpoint['state_dict'])
            # Log the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True
    cudnn.fastest = True

    # Data loading code -- CULane reuses the VOCAug-style dataset class.
    dataset_cls = getattr(
        ds, args.dataset.replace("CULane", "VOCAug") + 'DataSet')

    train_loader = torch.utils.data.DataLoader(
        dataset_cls(
            data_list=args.train_list,
            transform=torchvision.transforms.Compose([
                tf.GroupRandomScale(size=(0.595, 0.621),
                                    interpolation=(cv2.INTER_LINEAR,
                                                   cv2.INTER_NEAREST)),
                tf.GroupRandomCropRatio(size=(args.img_width,
                                              args.img_height)),
                tf.GroupRandomRotation(degree=(-1, 1),
                                       interpolation=(cv2.INTER_LINEAR,
                                                      cv2.INTER_NEAREST),
                                       padding=(input_mean, (ignore_label, ))),
                tf.GroupNormalize(mean=(input_mean, (0, )),
                                  std=(input_std, (1, ))),
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        dataset_cls(
            data_list=args.val_list,
            transform=torchvision.transforms.Compose([
                tf.GroupRandomScale(size=(0.595, 0.621),
                                    interpolation=(cv2.INTER_LINEAR,
                                                   cv2.INTER_NEAREST)),
                tf.GroupRandomCropRatio(size=(args.img_width,
                                              args.img_height)),
                tf.GroupNormalize(mean=(input_mean, (0, )),
                                  std=(input_std, (1, ))),
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    # define loss function (criterion) optimizer and evaluator
    # One weight per class: NLLLoss requires len(weight) == num_class.
    weights = [1.0] * num_class
    weights[0] = 0.4  # down-weight class 0
    class_weights = torch.FloatTensor(weights).cuda()
    criterion = torch.nn.NLLLoss(ignore_index=ignore_label,
                                 weight=class_weights).cuda()
    criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    evaluator = EvalSegmentation(num_class, ignore_label)

    # NOTE(review): this unconditionally overrides the --evaluate flag, so the
    # evaluation-only branch below never runs; set to True to evaluate only.
    args.evaluate = False  #True

    if args.evaluate:
        validate(val_loader, model, criterion, 0, evaluator)
        return

    # Always starts at epoch 0; args.start_epoch from the checkpoint is not used.
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, criterion_exist, optimizer,
              epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            mIoU = validate(val_loader, model, criterion,
                            (epoch + 1) * len(train_loader), evaluator)
            # remember best mIoU and save checkpoint
            is_best = mIoU > best_mIoU
            best_mIoU = max(mIoU, best_mIoU)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_mIoU': best_mIoU,
                }, is_best)
def main():
    """Evaluate ERFNet at four fixed test scales on ``args.val_list``.

    Parses CLI arguments into the global ``args``, exposes only
    ``args.gpus`` and optionally initialises synchronised BN. Builds ERFNet,
    optionally restores a checkpoint from ``args.resume``, and runs
    ``validate`` once. The dataset receives a list of four transform
    pipelines, one per test scale.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    global args, best_mIoU
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(gpu) for gpu in args.gpus)
    # From here on args.gpus is a device count; visible devices are 0..n-1.
    args.gpus = len(args.gpus)

    if args.no_partialbn:
        sync_bn.Synchronize.init(args.gpus)

    if args.dataset in ('VOCAug', 'VOC2012', 'COCO'):
        num_class = 21
        ignore_label = 255
    elif args.dataset == 'Cityscapes':
        num_class = 19
        ignore_label = 255  # 0
    elif args.dataset == 'ApolloScape':
        num_class = 37  # merge the noise and ignore labels
        ignore_label = 255  # 0
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = models.ERFNet(num_class, partial_bn=not args.no_partialbn)
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=range(args.gpus)).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mIoU = checkpoint['best_mIoU']
            torch.nn.Module.load_state_dict(model, checkpoint['state_dict'])
            # Log the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True
    cudnn.fastest = True

    # Data loading code
    def _scale_pipeline(size):
        # Rescale with the given 4-tuple (passed straight to
        # tf.GroupRandomScaleRatio), then normalise image and label.
        return torchvision.transforms.Compose([
            tf.GroupRandomScaleRatio(
                size=size,
                interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)),
            tf.GroupNormalize(mean=(input_mean, (0, )),
                              std=(input_std, (1, ))),
        ])

    test_scales = [
        (1692, 1692, 505, 505),
        (1861, 1861, 556, 556),
        (1624, 1624, 485, 485),
        (2030, 2030, 606, 606),
    ]
    dataset_cls = getattr(
        ds, args.dataset.replace("ApolloScape", "VOCAug") + 'DataSet')
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(
            data_list=args.val_list,
            transform=[_scale_pipeline(size) for size in test_scales]),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    # define loss function (criterion) and evaluator
    # One weight per class: NLLLoss requires len(weight) == num_class, so a
    # hard-coded length of 37 would fail for every dataset except ApolloScape.
    weights = [1.0] * num_class
    weights[0] = 0.05
    if args.dataset == 'ApolloScape':
        weights[36] = 0.05  # class 36 only exists for ApolloScape (37 classes)
    class_weights = torch.FloatTensor(weights).cuda()
    criterion = torch.nn.NLLLoss(ignore_index=ignore_label,
                                 weight=class_weights).cuda()
    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))
    evaluator = EvalSegmentation(num_class, ignore_label)

    ### evaluate ###
    validate(test_loader, model, criterion, 0, evaluator)
    return
def main():
    """Train a 20-class ERFNet (Paddle dygraph) on LaneDet, checkpointing as it goes.

    Uses the module-level ``args`` and creates ``args.save_dir`` if missing.
    Builds the train and validation loaders, a class-weighted NLL loss and a
    momentum optimiser with cosine-decayed learning rate. It then does one of
    two things:

    - resumes model and optimiser state from ``args.resume``, taking the
      start epoch from the digits in its file name; or
    - tries to load pretrained weights from ``args.weight``.

    Trains from ``start_epoch`` up to ``args.epochs``. It saves every
    ``args.save_freq`` epochs and validates every ``args.eval_freq`` epochs,
    both also on the final epoch.

    Raises:
        ValueError: if ``args.dataset`` is not 'LaneDet', or if the
            ``args.resume`` file name contains no digits to take the epoch from.
    """
    global best_mIoU, start_epoch

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    if args.dataset == 'LaneDet':
        num_class = 20
        ignore_label = 255
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    # get places
    places = fluid.cuda_places()

    with fluid.dygraph.guard():
        model = models.ERFNet(num_class, [args.img_height, args.img_width])
        input_mean = model.input_mean
        input_std = model.input_std

        # Data loading code
        train_dataset = ds.LaneDataSet(
            dataset_path='datasets/PreliminaryData',
            data_list=args.train_list,
            transform=[
                tf.GroupRandomScale(size=(int(args.img_width), int(args.img_width * 1.2)),
                                    interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)),
                tf.GroupRandomCropRatio(size=(args.img_width, args.img_height)),
                tf.GroupNormalize(mean=(input_mean, (0,)), std=(input_std, (1,))),
            ]
        )

        train_loader = DataLoader(
            train_dataset,
            places=places[0],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            drop_last=True
        )

        # NOTE(review): validation reads args.train_list (with is_val=False), so
        # the reported mIoU is measured on training data -- confirm intended.
        val_dataset = ds.LaneDataSet(
            dataset_path='datasets/PreliminaryData',
            data_list=args.train_list,
            transform=[
                tf.GroupRandomScale(size=args.img_width, interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)),
                tf.GroupNormalize(mean=(input_mean, (0,)), std=(input_std, (1,))),
            ],
            is_val=False
        )

        val_loader = DataLoader(
            val_dataset,
            places=places[0],
            batch_size=1,
            shuffle=False,
            num_workers=args.workers,
        )

        # define loss function (criterion) optimizer and evaluator
        weights = [1.0] * num_class
        weights[0] = 0.25  # down-weight class 0
        weights = fluid.dygraph.to_variable(np.array(weights, dtype=np.float32))
        criterion = fluid.dygraph.NLLLoss(weight=weights, ignore_index=ignore_label)
        evaluator = EvalSegmentation(num_class, ignore_label)

        optimizer = fluid.optimizer.MomentumOptimizer(
            learning_rate=fluid.dygraph.CosineDecay(
                args.lr, len(train_loader), args.epochs),
            momentum=args.momentum,
            parameter_list=model.parameters(),
            regularization=fluid.regularizer.L2Decay(
                regularization_coeff=args.weight_decay))

        if args.resume:
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # The start epoch is encoded as the digits of the checkpoint file name.
            digits = ''.join(ch for ch in os.path.basename(args.resume) if ch.isdigit())
            if not digits:
                raise ValueError(
                    "cannot infer epoch from checkpoint name '{}'".format(args.resume))
            start_epoch = int(digits)
            checkpoint, optim_checkpoint = fluid.load_dygraph(args.resume)
            model.load_dict(checkpoint)
            optimizer.set_dict(optim_checkpoint)
            print(("=> loaded checkpoint (epoch {})".format(start_epoch)))
        else:
            try:
                checkpoint, _ = fluid.load_dygraph(args.weight)
                model.load_dict(checkpoint)
                print("=> pretrained model loaded successfully")
            except Exception as err:
                # Best effort: continue from scratch when pretrained weights can't
                # be loaded, but don't swallow KeyboardInterrupt/SystemExit.
                print(("=> no pretrained model found at '{}' ({})".format(args.weight, err)))

        for epoch in range(start_epoch, args.epochs):
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch)

            if (epoch + 1) % args.save_freq == 0 or epoch == args.epochs - 1:
                save_checkpoint(model.state_dict(), epoch)
                save_checkpoint(optimizer.state_dict(), epoch)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                mIoU = validate(val_loader, model, evaluator, epoch)

                # remember best mIoU
                is_best = mIoU > best_mIoU
                best_mIoU = max(mIoU, best_mIoU)
                if is_best:
                    tag_best(epoch, best_mIoU)
Example #9
0
def main():
    """Train ERFNet on ApolloScape-style data and checkpoint the best model.

    Parses CLI arguments into the global ``args``, exposes only
    ``args.gpus`` and optionally initialises synchronised BN. Builds ERFNet,
    optionally restores weights from ``args.resume``, then either evaluates
    once (``args.evaluate``) or trains for ``args.epochs`` epochs. During
    training it validates every ``args.eval_freq`` epochs and on the last
    epoch, saving a checkpoint after each validation.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    global args, best_mIoU
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        str(gpu) for gpu in args.gpus)
    # From here on args.gpus is a device count; visible devices are 0..n-1.
    args.gpus = len(args.gpus)

    if args.no_partialbn:
        sync_bn.Synchronize.init(args.gpus)

    if args.dataset in ('VOCAug', 'VOC2012', 'COCO'):
        num_class = 21
        ignore_label = 255
    elif args.dataset == 'Cityscapes':
        num_class = 19
        ignore_label = 255  # 0
    elif args.dataset == 'ApolloScape':
        num_class = 37  # merge the noise and ignore labels
        ignore_label = 255  # 0
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model = models.ERFNet(num_class, partial_bn=not args.no_partialbn)
    input_mean = model.input_mean
    input_std = model.input_std
    model = torch.nn.DataParallel(model, device_ids=range(args.gpus)).cuda()

    def load_my_state_dict(model, state_dict):
        """Copy checkpoint tensors into ``model`` where the (renamed) key exists.

        Checkpoint keys are mapped ``module.features`` -> ``module`` before
        lookup; keys that still don't match are skipped. Returns ``model``.
        """
        own_state = model.state_dict()
        reused = 0
        for name, param in state_dict.items():
            target = name.replace('module.features', 'module')
            # Dict membership test: O(1) instead of scanning a key list per item.
            if target not in own_state:
                continue
            own_state[target].copy_(param)
            reused += 1
        print('#reused param: {}'.format(reused))
        return model

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # For checkpoints whose keys don't match exactly, use
            # ``model = load_my_state_dict(model, checkpoint['state_dict'])``.
            torch.nn.Module.load_state_dict(model, checkpoint['state_dict'])
            # Log the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True
    cudnn.fastest = True

    # Data loading code -- ApolloScape reuses the VOCAug-style dataset class.
    dataset_cls = getattr(
        ds, args.dataset.replace("ApolloScape", "VOCAug") + 'DataSet_train')
    # Crop to train_size wide by a third of that high.
    crop_size = (args.train_size, int(args.train_size * 1 / 3))

    train_loader = torch.utils.data.DataLoader(
        dataset_cls(
            data_list=args.train_list,
            transform=torchvision.transforms.Compose([
                tf.GroupRandomScale(size=(0.5, 0.5),
                                    interpolation=(cv2.INTER_LINEAR,
                                                   cv2.INTER_NEAREST)),
                tf.GroupRandomCropRatio(size=crop_size),
                tf.GroupRandomRotation(degree=(-10, 10),
                                       interpolation=(cv2.INTER_LINEAR,
                                                      cv2.INTER_NEAREST),
                                       padding=(input_mean, (ignore_label, ))),
                tf.GroupNormalize(mean=(input_mean, (0, )),
                                  std=(input_std, (1, ))),
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        dataset_cls(
            data_list=args.val_list,
            transform=torchvision.transforms.Compose([
                tf.GroupRandomScale(size=(0.5, 0.5),
                                    interpolation=(cv2.INTER_LINEAR,
                                                   cv2.INTER_NEAREST)),
                tf.GroupRandomCropRatio(size=crop_size),
                tf.GroupNormalize(mean=(input_mean, (0, )),
                                  std=(input_std, (1, ))),
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False)

    # define loss function (criterion) optimizer and evaluator
    # One weight per class: NLLLoss requires len(weight) == num_class, so a
    # hard-coded length of 37 would fail for every dataset except ApolloScape.
    weights = [1.0] * num_class
    weights[0] = 0.05
    if args.dataset == 'ApolloScape':
        weights[36] = 0.05  # class 36 only exists for ApolloScape (37 classes)
    class_weights = torch.FloatTensor(weights).cuda()
    criterion = torch.nn.NLLLoss(ignore_index=ignore_label,
                                 weight=class_weights).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    evaluator = EvalSegmentation(num_class, ignore_label)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, evaluator)
        return

    # Always starts at epoch 0; args.start_epoch from the checkpoint is not used.
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            mIoU = validate(val_loader, model, criterion,
                            (epoch + 1) * len(train_loader), evaluator)
            # remember best mIoU and save checkpoint
            is_best = mIoU > best_mIoU
            best_mIoU = max(mIoU, best_mIoU)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_mIoU': best_mIoU,
                }, is_best)