Exemplo n.º 1
0
def load_data():
    """Build the LSUN training and validation DataLoaders.

    Both splits are passed through the same kind of pipeline: a random crop
    to ``IM_SIZE``, a random horizontal flip, tensor conversion, and
    normalisation with the ImageNet per-channel mean/std. Each split gets its
    own freshly constructed pipeline object. Uses the module-level
    ``IM_SIZE``, ``BATCH_SIZE`` and ``DATADIR`` settings and prints the size
    of each split.

    Returns:
        tuple: ``(trainloader, valloader)``. Both loaders use 2 workers,
        pinned memory and drop the last incomplete batch; only the training
        loader shuffles.
    """

    def build_transform():
        # NOTE(review): the validation split also receives this random
        # crop/flip, so validation metrics vary between runs — confirm intended.
        return transforms.Compose([
            transforms.RandomCrop(IM_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    train_transform = build_transform()
    val_transform = build_transform()

    def build_split(split_name, split_transform, shuffle_batches):
        # Returns the dataset together with its loader so the caller can
        # report the dataset size.
        split_set = dataset.LSUN(DATADIR, split_name, split_transform)
        split_loader = torch.utils.data.DataLoader(
            split_set,
            batch_size=BATCH_SIZE,
            shuffle=shuffle_batches,
            num_workers=2,
            pin_memory=True,
            drop_last=True,
        )
        return split_set, split_loader

    trainset, trainloader = build_split('train', train_transform, True)
    print(f"Train set size: {len(trainset)}")

    valset, valloader = build_split('val', val_transform, False)
    print(f"Val set size: {len(valset)}")

    return trainloader, valloader
Exemplo n.º 2
0
def train_seg(args):
    """Run a multi-cycle active-learning experiment for semantic segmentation.

    For each of ``num_cycles`` cycles:

    1. A fresh ``DRNSeg`` model is built. It is optionally loaded from
       ``args.pretrained`` and optionally wrapped in ``ActiveLearning`` or
       ``DiscriminativeActiveLearning``.
    2. ``images_per_cycle`` new training indices are selected and appended to
       the labeled pool.
    3. The model is trained on the labeled pool for ``args.epochs`` epochs.
    4. The cycle's best validation accuracy and mAP are appended to running
       lists, which are printed at the end of the cycle.

    Args:
        args: parsed command-line namespace. All ``args.*`` fields used are
            read directly in this function.

    Returns:
        None. When ``args.evaluate`` is set, a single validation pass is run
        during the first cycle and the function returns.
    """
    rand_state = np.random.RandomState(1311)
    torch.manual_seed(1311)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 150 new images per cycle over 10 cycles labels 1500 images in total
    # (~1/2 of the 2975 images in the training set).
    images_per_cycle = 150
    num_cycles = 10

    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    # Data loading code
    data_dir = args.data_dir
    # Context manager so the info.json handle is closed promptly.
    with open(join(data_dir, 'info.json'), 'r') as info_file:
        info = json.load(info_file)
    normalize = data_transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    if args.random_rotate > 0:
        t.append(data_transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(data_transforms.RandomScale(args.random_scale))
    t.extend([
        data_transforms.RandomCrop(crop_size),
        data_transforms.RandomHorizontalFlip(),
        data_transforms.ToTensor(), normalize
    ])
    # Augmented view of the training split: used to train each cycle's model.
    dataset = SegList(data_dir,
                      'train',
                      data_transforms.Compose(t),
                      list_dir=args.list_dir)
    # Un-augmented view of the same split: passed to the image-selection step.
    training_dataset_no_augmentation = SegList(
        data_dir,
        'train',
        data_transforms.Compose([data_transforms.ToTensor(), normalize]),
        list_dir=args.list_dir)

    # The validation loader does not depend on the cycle, so build it once.
    # NOTE(review): validation uses RandomCrop, so scores are not
    # deterministic across runs — confirm this is intended.
    val_dataset = SegList(data_dir,
                          'val',
                          data_transforms.Compose([
                              data_transforms.RandomCrop(crop_size),
                              data_transforms.ToTensor(),
                              normalize,
                          ]),
                          list_dir=args.list_dir)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             drop_last=True)

    # Label used when reporting results. It mirrors the selection branch
    # below: the ground-truth-loss path takes priority over the AL flags.
    if args.choose_images_with_highest_loss:
        strategy_name = "Ground-truth loss"
    elif args.use_loss_prediction_al or args.use_discriminative_al:
        strategy_name = "Active Learning"
    else:
        strategy_name = "Random"

    unlabeled_idx = list(range(len(dataset)))
    labeled_idx = []
    validation_accuracies = list()
    validation_mAPs = list()
    progress = tqdm.tqdm(range(num_cycles))
    for cycle in progress:
        # A fresh model is built in every cycle.
        single_model = DRNSeg(args.arch, args.classes, None, pretrained=True)
        if args.pretrained:
            single_model.load_state_dict(torch.load(args.pretrained))

        # Wrap our model in Active Learning Model.
        if args.use_loss_prediction_al:
            single_model = ActiveLearning(single_model,
                                          global_avg_pool_size=6,
                                          fc_width=256)
        elif args.use_discriminative_al:
            single_model = DiscriminativeActiveLearning(single_model)
        optim_parameters = single_model.optim_parameters()

        model = torch.nn.DataParallel(single_model).cuda()

        # Don't apply a 'mean' reduction, we need the whole loss vector.
        criterion = nn.NLLLoss(ignore_index=255, reduction='none')

        criterion.cuda()

        if args.choose_images_with_highest_loss:
            # Choosing images based on the ground truth labels.
            # We want to check if predicting loss with 100% accuracy would result to
            # a good active learning algorithm.
            new_indices, entropies = choose_new_labeled_indices_using_gt(
                model, cycle, rand_state, unlabeled_idx,
                training_dataset_no_augmentation, device, criterion,
                images_per_cycle)
        else:
            new_indices, entropies = choose_new_labeled_indices(
                model,
                training_dataset_no_augmentation,
                cycle,
                rand_state,
                labeled_idx,
                unlabeled_idx,
                device,
                images_per_cycle,
                args.use_loss_prediction_al,
                args.use_discriminative_al,
                input_pickle_file=None)
        labeled_idx.extend(new_indices)
        print("Running on {} labeled images.".format(len(labeled_idx)))
        if args.output_superannotate_csv_file is not None:
            # Write image paths to csv file which can be uploaded to annotate.online.
            write_entropies_csv(training_dataset_no_augmentation, new_indices,
                                entropies, args.output_superannotate_csv_file)

        train_loader = torch.utils.data.DataLoader(data.Subset(
            dataset, labeled_idx),
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers,
                                                   pin_memory=True,
                                                   drop_last=True)

        # define optimizer.
        optimizer = torch.optim.SGD(optim_parameters,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        cudnn.benchmark = True
        best_prec1 = 0
        best_mAP = 0
        start_epoch = 0

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            validate(val_loader,
                     model,
                     criterion,
                     eval_score=accuracy,
                     num_classes=args.classes,
                     use_loss_prediction_al=args.use_loss_prediction_al)
            return

        progress_epoch = tqdm.tqdm(range(start_epoch, args.epochs))
        for epoch in progress_epoch:
            lr = adjust_learning_rate(args, optimizer, epoch)
            logger.info('Cycle {0} Epoch: [{1}]\tlr {2:.06f}'.format(
                cycle, epoch, lr))
            # train for one epoch
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  eval_score=accuracy,
                  use_loss_prediction_al=args.use_loss_prediction_al,
                  active_learning_lamda=args.lamda)

            # evaluate on validation set
            prec1, mAP1 = validate(
                val_loader,
                model,
                criterion,
                eval_score=accuracy,
                num_classes=args.classes,
                use_loss_prediction_al=args.use_loss_prediction_al)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            best_mAP = max(mAP1, best_mAP)
            checkpoint_path = os.path.join(args.save_path,
                                           'checkpoint_latest.pth.tar')
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'best_mAP': best_mAP,
                },
                is_best,
                filename=checkpoint_path)
            if (epoch + 1) % args.save_iter == 0:
                # NOTE(review): the file name contains no cycle index, so later
                # cycles overwrite earlier cycles' history checkpoints.
                history_path = os.path.join(
                    args.save_path,
                    'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
                shutil.copyfile(checkpoint_path, history_path)
        validation_accuracies.append(best_prec1)
        validation_mAPs.append(best_mAP)
        print("{} accuracies: {} mAPs {}".format(
            strategy_name, str(validation_accuracies), str(validation_mAPs)))
Exemplo n.º 3
0
    #     t.append(transforms.RandomScale(0))
    normalize = transforms.Normalize(mean=[102.9801, 115.9465, 122.7717],
                                     std=[1., 1., 1.])
    # t.extend([transforms.RandomCrop((768, 320), 4),
    #           transforms.RandomHorizontalFlip(),
    #           # transforms.ToNumpy(1/255.0),
    #           # transforms.RandomGammaImg((0.7,1.5)),
    #           # transforms.RandomBrightnessImg(0.2),
    #           # transforms.RandomContrastImg((0.8, 1.2)),
    #           # transforms.RandomGaussianNoiseImg(0.02),
    #           # transforms.ToNumpy(255.0),
    #           transforms.ToTensor(convert_pix_range=False),
    #           normalize])

    t.extend([
        transforms.RandomCrop((768, 320), 4),
        transforms.RandomHorizontalFlip(),
        # transforms.ToNumpy(1/255.0),
        # transforms.RandomGammaImg((0.7,1.5)),
        # transforms.RandomBrightnessImg(0.2),
        # transforms.RandomContrastImg((0.8, 1.2)),
        # transforms.RandomGaussianNoiseImg(0.02),
        # transforms.ToNumpy(255.0),
        transforms.ToTensor(convert_pix_range=False),
        normalize
    ])

    # data_dir = '/home/hzjiang/workspace/Data/CityScapes'
    data_dir = '/home/hzjiang/workspace/Data/KITTI_Semantics'
    train_data = SegList(data_dir,
                         'train',
Exemplo n.º 4
0
def train_seg(args):
    """Train a DLA-up segmentation model on Argoverse tracking frames.

    Setup:
    - The model named by ``args.arch`` is taken from ``dla_up``, passed
      through ``convert_model`` and wrapped in ``DataParallel`` on the GPU.
    - Images come from hard-coded Argoverse directories; masks come from the
      ``args.target`` sub-directory.

    After every epoch the model is checkpointed to ``args.checkpoint_dir``
    and validated. Validation mIoU and loss are logged to TensorBoard.

    Args:
        args: parsed command-line namespace.

    Returns:
        None. When ``args.evaluate`` is set, one validation pass is run and
        the function returns without training.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)

    single_model = convert_model(single_model)

    model = torch.nn.DataParallel(single_model).cuda()
    print('model_created')

    # Optional class weights: class 1 is weighted by args.edge_weight. The
    # weight vector has two entries, so this assumes args.classes == 2.
    weight = None
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
    # nn.NLLLoss accepts (N, C, H, W) input directly; nn.NLLLoss2d is a
    # deprecated alias of it. Pixels labelled -1 are ignored.
    criterion = nn.NLLLoss(ignore_index=-1, weight=weight)

    criterion.cuda()

    # Training augmentation. No ToTensor/Normalize step is added here;
    # presumably BasicDataset handles conversion — verify.
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  #TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.append(transforms.RandomHorizontalFlip())  #TODO

    # NOTE(review): validation also uses a random crop, so scores vary per run.
    t_val = [transforms.RandomCrop(crop_size)]

    dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/image_02/'
    dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/' + args.target + '/'
    my_train = BasicDataset(dir_img,
                            dir_mask,
                            transforms.Compose(t),
                            is_train=True)

    val_dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/image_02/'
    val_dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/' + args.target + '/'
    # NOTE(review): the validation dataset is also built with is_train=True
    # — confirm intended.
    my_val = BasicDataset(val_dir_img,
                          val_dir_mask,
                          transforms.Compose(t_val),
                          is_train=True)

    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        my_val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)  #TODO  batch_size
    print("loader created")
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = None  #TODO

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Label set for the confusion matrices used during training (0-4).
    confusion_labels = np.arange(0, 5)

    if args.evaluate:
        # Evaluation-only mode uses labels 0-1 with reduce=True.
        confusion_labels = np.arange(0, 2)
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1,
                                                      reduce=True)
        validate(val_loader,
                 model,
                 criterion,
                 confusion_matrix=val_confusion_matrix)
        return

    writer = SummaryWriter(comment=args.log)
    # Close the TensorBoard writer even if training or validation raises.
    try:
        for epoch in range(start_epoch, args.epochs):
            train_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                            ignore_label=-1)
            lr = adjust_learning_rate(args, optimizer, epoch)
            print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))

            # train for one epoch
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  lr_scheduler,
                  confusion_matrix=train_confusion_matrix,
                  writer=writer)

            checkpoint_path = os.path.join(checkpoint_dir,
                                           'checkpoint_{}.pth.tar'.format(epoch))
            # Saved before validation (presumably so the epoch's weights are
            # kept if validation fails); overwritten below with best_prec1.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict()
                },
                is_best=False,
                filename=checkpoint_path)

            # evaluate on validation set
            val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                          ignore_label=-1)
            prec1, loss_val = validate(val_loader,
                                       model,
                                       criterion,
                                       confusion_matrix=val_confusion_matrix)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            writer.add_scalar('mIoU/epoch', prec1, epoch + 1)
            writer.add_scalar('loss/epoch', loss_val, epoch + 1)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=checkpoint_path)

            if (epoch + 1) % args.save_freq == 0:
                # NOTE(review): the history copy is written to the current
                # working directory, not checkpoint_dir — confirm intended.
                history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
                shutil.copyfile(checkpoint_path, history_path)
    finally:
        writer.close()
Exemplo n.º 5
0
def train_seg(args):
    """Train a DLA-up segmentation model on a ``SegList`` dataset.

    The model is built from ``dla_up`` with ``args.classes`` and
    ``args.pretrained_base``. Mean/std for normalisation come from
    ``dataset.load_dataset_info(args.data_dir)``. After each epoch the model
    is validated, and ``checkpoint_latest.pth.tar`` is written to the current
    working directory. Every ``args.save_freq`` epochs, a numbered copy is
    also written there.

    Args:
        args: parsed command-line namespace.

    Returns:
        None. When ``args.evaluate`` is set, one validation pass is run and
        the function returns without training.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(args.classes,
                                                  pretrained_base,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()

    # Optional class weights: class 1 is weighted by args.edge_weight. The
    # weight vector has two entries, so this assumes args.classes == 2.
    weight = None
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
    # nn.NLLLoss accepts (N, C, H, W) input directly; nn.NLLLoss2d is a
    # deprecated alias of it. Pixels labelled 255 are ignored.
    criterion = nn.NLLLoss(ignore_index=255, weight=weight)

    criterion.cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor(), normalize])
    train_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'train', transforms.Compose(t), binary=(args.classes == 2)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    # NOTE(review): validation uses RandomCrop, so scores vary between runs.
    val_loader = torch.utils.data.DataLoader(
        SegList(
            data_dir,
            'val',
            transforms.Compose([
                transforms.RandomCrop(crop_size),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            binary=(args.classes == 2)),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        # Relative path: written to the current working directory.
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
Exemplo n.º 6
0
def get_loader(args, split, out_name=False, customized_task_set=None):
    """Returns data loader depending on dataset and split.

    Args:
        args: namespace providing ``dataset``, ``data_dir``, ``batch_size``,
            ``test_batch_size`` and ``workers``. Dataset-specific fields are
            also read: ``task_set``, and for cityscape ``list_dir``,
            ``crop_size``, ``random_rotate`` and ``random_scale``.
        split: ``'train'``, ``'val'`` or ``'adv_val'``. Both evaluation
            splits read the dataset's ``'val'`` data.
        out_name: forwarded to the evaluation datasets of voc, coco and
            cityscape.
        customized_task_set: when given, used instead of ``args.task_set``
            for the taskonomy loader.

    Returns:
        A ``torch.utils.data.DataLoader``, or ``None`` when the dataset/split
        combination is not recognised. Every loader built here pins memory
        and drops the last incomplete batch.
    """
    dataset = args.dataset
    loader = None

    if customized_task_set is None:
        task_set = args.task_set
    else:
        task_set = customized_task_set

    def make_loader(ds, batch_size, shuffle):
        # Settings shared by every loader constructed in this function.
        return torch.utils.data.DataLoader(ds,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           drop_last=True)

    if dataset == 'taskonomy':
        print('using taskonomy')
        if split in ('train', 'val', 'adv_val'):
            is_training = split == 'train'
            taskonomy_set = TaskonomyLoader(root=args.data_dir,
                                            is_training=is_training,
                                            threshold=1200,
                                            task_set=task_set,
                                            model_whitelist=None,
                                            model_limit=30,
                                            output_size=None)
            # NOTE(review): the taskonomy training loader is not shuffled,
            # and 'adv_val' uses test_batch_size rather than 1 as the other
            # datasets do — confirm intended.
            batch_size = (args.batch_size
                          if is_training else args.test_batch_size)
            loader = make_loader(taskonomy_set, batch_size, shuffle=False)

    elif dataset in ('voc', 'coco'):
        seg_class = VOCSegmentation if dataset == 'voc' else COCOSegmentation
        if split == 'train':
            train_set = seg_class(args=args,
                                  base_dir=args.data_dir,
                                  split='train')
            loader = make_loader(train_set, args.batch_size, shuffle=True)
        elif split in ('val', 'adv_val'):
            eval_set = seg_class(args=args,
                                 base_dir=args.data_dir,
                                 split='val',
                                 out_name=out_name)
            # 'adv_val' always uses batch size 1.
            batch_size = args.test_batch_size if split == 'val' else 1
            loader = make_loader(eval_set, batch_size, shuffle=False)

    elif dataset == 'cityscape':
        data_dir = args.data_dir
        # Context manager so the info.json handle is closed promptly.
        with open(join(data_dir, 'info.json'), 'r') as info_file:
            info = json.load(info_file)
        normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
        t = []
        if args.random_rotate > 0:
            t.append(transforms.RandomRotate(args.random_rotate))
        if args.random_scale > 0:
            t.append(transforms.RandomScale(args.random_scale))
        t.extend([
            transforms.RandomCrop(args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])

        caution = "\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n"
        task_set_present = hasattr(args, 'task_set')
        if split == 'train':
            if task_set_present:
                print(caution)
                list_class = SegDepthList
            else:
                list_class = SegList
            train_set = list_class(data_dir,
                                   'train',
                                   transforms.Compose(t),
                                   list_dir=args.list_dir)
            loader = make_loader(train_set, args.batch_size, shuffle=True)
        elif split in ('val', 'adv_val'):
            # NOTE(review): 'val' chooses the multitask dataset when
            # args.task_set != [], whereas 'train' and 'adv_val' only check
            # that the attribute exists. With task_set == [] the splits
            # therefore use different dataset classes. Confirm which rule is
            # intended.
            if split == 'val':
                multitask = args.task_set != []
            else:
                multitask = task_set_present
            if multitask:
                print(caution)
                list_class = SegDepthList
            else:
                if split == 'val':
                    print("city test eval!")
                list_class = SegList
            eval_set = list_class(data_dir,
                                  'val',
                                  transforms.Compose([
                                      transforms.ToTensor(),
                                      normalize,
                                  ]),
                                  list_dir=args.list_dir,
                                  out_name=out_name)
            # 'adv_val' always uses batch size 1.
            batch_size = args.test_batch_size if split == 'val' else 1
            loader = make_loader(eval_set, batch_size, shuffle=False)

    return loader
Exemplo n.º 7
0
        self.image_list = [line.strip() for line in open(image_path, 'r')]
        if exists(label_path):
            self.label_list = [line.strip() for line in open(label_path, 'r')]
            assert len(self.image_list) == len(self.label_list)


if __name__ == "__main__":
    # Manual smoke test: build a DataLoader over the Cityscapes 'train' split
    # via SegDepthList, using an 896-pixel random crop and normalisation.
    # NOTE(review): the loader is constructed but never iterated, so only
    # dataset/loader construction is exercised.
    data_dir = "/home/amogh/data/datasets/drn_data/DRN-move/cityscape_dataset/"
    # Context manager so the info.json handle is closed promptly.
    with open(join(data_dir, 'info.json'), 'r') as info_file:
        info = json.load(info_file)
    normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    # t.append(transforms.RandomRotate(0))
    # t.append(transforms.RandomScale(0))
    t.extend([
        transforms.RandomCrop(896),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    # loader = SegDepthList(data_dir="/home/amogh/data/datasets/drn_data/DRN-move/cityscape_dataset/",
    loader = torch.utils.data.DataLoader(SegDepthList(data_dir,
                                                      'train',
                                                      transforms.Compose(t),
                                                      list_dir=None),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=1,
                                         pin_memory=True,
                                         drop_last=True)
Exemplo n.º 8
0
def train_seg(args):
    """Train a DLA-up segmentation model on COCO, checkpointing every epoch.

    Builds the architecture named by ``args.arch`` from ``dla_up``, wraps it
    in ``DataParallel`` on CUDA, trains it with SGD, validates after every
    epoch and logs accuracy/loss/recall scalars through a ``SummaryWriter``.
    The writer is always closed, including on the ``args.evaluate`` early
    return and when an exception escapes.

    Args:
        args: parsed command-line namespace. Fields read here: ``log``,
            ``batch_size``, ``workers``, ``crop_size``, ``checkpoint_dir``,
            ``arch``, ``classes``, ``down``, ``bg_weight``,
            ``random_rotate``, ``random_scale``, ``random_color``, ``lr``,
            ``momentum``, ``weight_decay``, ``resume``, ``evaluate``,
            ``epochs`` and ``save_freq``.
    """
    writer = SummaryWriter(comment=args.log)
    try:
        batch_size = args.batch_size
        num_workers = args.workers
        crop_size = args.crop_size
        checkpoint_dir = args.checkpoint_dir

        print(' '.join(sys.argv))

        for k, v in args.__dict__.items():
            print(k, ':', v)

        single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                      down_ratio=args.down)
        model = torch.nn.DataParallel(single_model).cuda()
        print('model_created')

        # Optionally give class 0 its own loss weight (presumably the
        # background class, judging by the arg name); all other classes
        # keep weight 1.
        weight = None
        if args.bg_weight > 0:
            weight_array = np.ones(args.classes, dtype=np.float32)
            weight_array[0] = args.bg_weight
            weight = torch.from_numpy(weight_array)
        # Pixels labelled 255 are excluded from the loss.
        criterion = nn.NLLLoss2d(ignore_index=255, weight=weight)
        criterion.cuda()

        t = []
        if args.random_rotate > 0:
            t.append(transforms.RandomRotate(args.random_rotate))
        if args.random_scale > 0:
            t.append(transforms.RandomScale(args.random_scale))
        t.append(transforms.RandomCrop(crop_size))  #TODO
        if args.random_color:
            t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
        t.extend([transforms.RandomHorizontalFlip()])  #TODO

        # NOTE(review): validation also uses a random crop, so validation
        # scores presumably vary between runs.
        t_val = []
        t_val.append(transforms.RandomCrop(crop_size))

        train_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_train2017.json'
        train_root = '/shared/xudongliu/COCO/train2017/train2017'
        my_train = COCOSeg(train_root,
                           train_json,
                           transforms.Compose(t),
                           is_train=True)

        val_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_val2017.json'
        val_root = '/shared/xudongliu/COCO/2017val/val2017'
        # NOTE(review): the validation set is built with is_train=True as
        # well -- confirm this is intended.
        my_val = COCOSeg(val_root,
                         val_json,
                         transforms.Compose(t_val),
                         is_train=True)

        train_loader = torch.utils.data.DataLoader(my_train,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers,
                                                   pin_memory=True)
        val_loader = torch.utils.data.DataLoader(
            my_val,
            batch_size=20,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True)  #TODO  batch_size
        print("loader created")

        optimizer = torch.optim.SGD(
            single_model.optim_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)  #TODO adam optimizer

        # NOTE(review): the LR is set per epoch by adjust_learning_rate()
        # below, and this scheduler is also handed to train(); confirm the
        # two are not fighting over the learning rate.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                  T_max=32)  #TODO
        cudnn.benchmark = True
        best_prec1 = 0
        start_epoch = 0

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                # Tolerate checkpoints written without 'best_prec1'.
                best_prec1 = checkpoint.get('best_prec1', 0)
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            validate(val_loader, model, criterion, eval_score=accuracy)
            return

        for epoch in range(start_epoch, args.epochs):
            lr = adjust_learning_rate(args, optimizer, epoch)
            print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
            # train for one epoch
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  lr_scheduler,
                  eval_score=accuracy,
                  writer=writer)

            checkpoint_path = os.path.join(
                checkpoint_dir, 'checkpoint_{}.pth.tar'.format(epoch))
            # Save immediately after training so a failure during validation
            # does not lose the epoch. 'best_prec1' is included so this file
            # can be resumed from too.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best=False,
                filename=checkpoint_path)

            # evaluate on validation set
            prec1, loss_val, recall_val = validate(val_loader,
                                                   model,
                                                   criterion,
                                                   eval_score=accuracy)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            writer.add_scalar('accuracy/epoch', prec1, epoch + 1)
            writer.add_scalar('loss/epoch', loss_val, epoch + 1)
            writer.add_scalar('recall/epoch', recall_val, epoch + 1)

            # Overwrite the same file, now with the updated best score.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=checkpoint_path)

            # NOTE(review): this copy goes to the current working directory,
            # not checkpoint_dir.
            if (epoch + 1) % args.save_freq == 0:
                history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
                shutil.copyfile(checkpoint_path, history_path)
    finally:
        writer.close()
Exemplo n.º 9
0
def train_cnn(args):
    """Train the QGCNN model with a ComLoss objective, checkpointing each epoch.

    Configuration comes from the module-level ``cfg`` mapping (crop size,
    feature-loss spec, IQA model, patch size, pixel criterion, data/list
    directories) and from ``args`` (``batch_size``, ``workers``, ``lr``,
    ``weight_decay``, ``resume``, ``evaluate``, ``epochs``,
    ``store_interval``). Checkpoints are written under ``out_dir/weights/``.

    Args:
        args: parsed command-line namespace.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = cfg['CROP_SIZE']

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = QGCNN()
    # Move the model to the GPU once, up front (before the optimizer is
    # built), instead of calling .cuda() on every train/validate call.
    model = torch.nn.DataParallel(single_model).cuda()

    # cfg['FEATS'] is a list of single-entry {feature_name: weight} mappings;
    # split it into parallel tuples of names and weights.
    if cfg['FEATS']:
        feat_names, weights = zip(*(tuple(*f.items()) for f in cfg['FEATS']))
    else:
        feat_names, weights = None, None

    criterion = ComLoss(cfg['IQA_MODEL'],
                        weights,
                        feat_names,
                        patch_size=cfg['PATCH_SIZE'],
                        pixel_criterion=cfg['CRITERION'])
    criterion.cuda()

    # Data loading
    data_dir = cfg['DATA_DIR']
    list_dir = cfg['LIST_DIR']
    t = [
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]
    # Note that the cropsize could have a significant influence,
    # i.e., with a small cropsize the model would get overfitted
    # easily thus hard to train
    train_loader = torch.utils.data.DataLoader(DataList(data_dir,
                                                        'train',
                                                        transforms.Compose(t),
                                                        list_dir=list_dir),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    # The cropsize of the validation set dramatically affects the
    # evaluation accuracy, which means the quality of the whole
    # image might be very different from that of its cropped patches.
    #
    # Try setting batch_size = 1 and no crop (disable RandomCrop)
    # to improve the effect of early stopping.
    #
    # NOTE(review): this is the dataset itself, not a DataLoader, so
    # validate() presumably iterates it sample by sample -- confirm.
    val_loader = DataList(data_dir,
                          'val',
                          transforms.Compose([transforms.ToTensor()]),
                          list_dir=list_dir)

    optimizer = torch.optim.Adam(single_model.parameters(),
                                 lr=args.lr,
                                 betas=(0.9, 0.99),
                                 weight_decay=args.weight_decay)

    cudnn.benchmark = True

    weight_dir = join(out_dir, 'weights/')
    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(weight_dir, exist_ok=True)

    best_prec = 0
    start_epoch = 0

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger_s.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            # Restore the best score as well; otherwise the first epoch after
            # resuming would always be reported to save_checkpoint as is_best.
            best_prec = checkpoint.get('best_prec', 0)
            model.load_state_dict(checkpoint['state_dict'])
            logger_s.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger_f.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # Report on both loggers, like the successful-load message.
            logger_s.warning("=> no checkpoint found at '{}'".format(
                args.resume))
            logger_f.warning("=> no checkpoint found at '{}'".format(
                args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):

        lr = adjust_learning_rate(args, optimizer, epoch)

        # Every 100 epochs, scale the feature-loss weights down by 10x.
        # NOTE(review): requires ComLoss.weights to support in-place
        # division (a plain tuple would not) -- confirm.
        if criterion.weights is not None and (epoch + 1) % 100 == 0:
            criterion.weights /= 10.0

        logger_s.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)

        # Evaluate on validation set
        prec = validate(val_loader,
                        model,
                        criterion,
                        eval_score=accuracy)

        is_best = prec > best_prec
        best_prec = max(prec, best_prec)
        logger_s.info('current best {:.6f}'.format(best_prec))

        checkpoint_path = join(weight_dir, 'checkpoint_latest.pkl')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
            },
            is_best,
            filename=checkpoint_path)

        if (epoch + 1) % args.store_interval == 0:
            history_path = join(weight_dir,
                                'checkpoint_{:03d}.pkl'.format(epoch + 1))
            shutil.copyfile(checkpoint_path, history_path)
Exemplo n.º 10
0
def train_seg(args):
    """Train a DRNSeg segmentation model, checkpointing and uploading each epoch.

    Builds ``DRNSeg(args.arch, args.classes, ...)``, optionally loads
    ``args.pretrained`` weights, trains with SGD and validates every epoch.
    The latest checkpoint is uploaded to S3 every epoch, and a numbered
    history copy every 10 epochs. Uploads are best-effort: a failed upload
    is logged and training continues.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``batch_size``, ``workers``, ``crop_size``, ``arch``,
            ``classes``, ``pretrained``, ``data_dir``, ``downsample``,
            ``random_rotate``, ``random_scale``, ``lr``, ``momentum``,
            ``weight_decay``, ``resume``, ``evaluate`` and ``epochs``.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = DRNSeg(args.arch, args.classes, None, pretrained=True)
    if args.pretrained:
        single_model.load_state_dict(torch.load(args.pretrained))
    model = torch.nn.DataParallel(single_model).cuda()
    # Pixels labelled 255 are excluded from the loss.
    criterion = nn.NLLLoss2d(ignore_index=255)

    criterion.cuda()

    # Data loading code
    data_dir = args.data_dir
    # info.json supplies the normalisation statistics ('mean' / 'std').
    with open(join(data_dir, 'info.json'), 'r') as info_file:
        info = json.load(info_file)
    normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    if args.downsample:
        t.append(transforms.Scale(0.5))
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.extend([
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    train_loader = torch.utils.data.DataLoader(SegList(data_dir, 'train',
                                                       transforms.Compose(t)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    # NOTE(review): with drop_last=True a trailing partial batch of the
    # validation set is skipped, and the random crop presumably makes
    # validation scores vary between runs.
    val_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'val',
        transforms.Compose([
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             drop_last=True)

    # Define the optimizer
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        logger.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path,
            prefix=args.arch)
        if (epoch + 1) % 10 == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
            # Save historical data to s3. Best-effort, like the
            # latest-checkpoint upload below: a failed upload must not
            # abort training.
            try:
                upload_to_s3(history_path, prefix=args.arch)
            except Exception:
                logger.warning('failed to upload checkpoint %s to s3',
                               history_path, exc_info=True)
        # Save the latest checkpoint to s3 (best-effort). Catch Exception
        # rather than everything, so Ctrl-C still stops training.
        try:
            upload_to_s3(checkpoint_path, prefix=args.arch)
        except Exception:
            logger.warning('failed to upload latest checkpoint to s3',
                           exc_info=True)