Beispiel #1
0
def torch_loader(data_path, size):
    """Build CIFAR-10 train / val / augmented-val loaders wrapped in ``ModelData``.

    Calls ``download_cifar10(data_path)`` first when the ``train`` folder is
    missing. Run configuration is read from the module-level ``args`` object
    (``batch_size``, ``workers``, ``distributed``, ``prof``).

    Args:
        data_path: Root directory expected to contain ``train`` and ``test``
            image folders (``ImageFolder`` layout).
        size: Side length passed to ``RandomCrop`` for training images; also
            recorded on the returned object as ``data.sz``.

    Returns:
        A ``ModelData`` built from the prefetching train/val loaders, with
        ``aug_dl`` set to a loader over the test folder that applies the
        *training* augmentations, and ``trn_sampler`` set when a distributed
        sampler is used.
    """
    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'test')
    # Probe exactly the directory ImageFolder will read. The previous
    # ``data_path + '/train'`` check could disagree with ``traindir``
    # (e.g. an empty data_path checked '/train' at the filesystem root).
    if not os.path.exists(traindir):
        download_cifar10(data_path)

    # Per-channel normalization constants (presumably CIFAR-10 train-set stats).
    normalize = transforms.Normalize(mean=[0.4914, 0.48216, 0.44653],
                                     std=[0.24703, 0.24349, 0.26159])
    tfms = [transforms.ToTensor(), normalize]

    # Padding needed to grow a `size` crop back to `scale_size` (4 for size=32).
    # Not wired in yet: per the TODO below, `pad` currently assumes 4.
    scale_size = 40
    padding = int((scale_size - size) / 2)
    train_tfms = transforms.Compose([
        pad,  # TODO: use `padding` rather than assuming 4
        transforms.RandomCrop(size),
        transforms.ColorJitter(.25, .25, .25),
        transforms.RandomRotation(2),
        transforms.RandomHorizontalFlip(),
    ] + tfms)
    train_dataset = datasets.ImageFolder(traindir, train_tfms)
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_dataset)
        if args.distributed else None)
    # DataLoader rejects shuffle=True together with an explicit sampler,
    # so shuffle only when no distributed sampler is in use.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_tfms = transforms.Compose(tfms)
    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(valdir, val_tfms),
                                             batch_size=args.batch_size * 2,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Test images seen through the training augmentations.
    aug_loader = torch.utils.data.DataLoader(datasets.ImageFolder(valdir, train_tfms),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    train_loader = DataPrefetcher(train_loader)
    val_loader = DataPrefetcher(val_loader)
    aug_loader = DataPrefetcher(aug_loader)
    if args.prof:
        # Profiling run: limit the prefetchers via `stop_after`.
        train_loader.stop_after = 200
        val_loader.stop_after = 0

    data = ModelData(data_path, train_loader, val_loader)
    # Record the crop size these loaders were actually built with (previously
    # args.sz, which is wrong whenever the caller passes a different `size`).
    data.sz = size
    data.aug_dl = aug_loader
    if train_sampler is not None:
        data.trn_sampler = train_sampler

    return data
def torch_loader(data_path, size):
    """Build ImageNet-style train / val / augmented-val loaders as ``TorchModelData``.

    Run configuration is read from the module-level ``args`` object
    (``batch_size``, ``workers``, ``distributed``, ``prof``).

    Args:
        data_path: Root directory holding ``train`` and ``val`` image folders
            (``ImageFolder`` layout).
        size: Output side length of the training ``RandomResizedCrop`` and of
            the validation ``CenterCrop``.

    Returns:
        ``(data, train_sampler)``: the ``TorchModelData`` wrapping the three
        prefetching loaders, and the ``DistributedSampler`` used for training
        (``None`` when not running distributed).
    """
    train_root = os.path.join(data_path, 'train')
    val_root = os.path.join(data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training augmentation; `Lighting` is applied after ToTensor and before
    # normalization.
    train_tfms = transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(.4, .4, .4),
        transforms.ToTensor(),
        Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
        normalize,
    ])
    train_dataset = datasets.ImageFolder(train_root, train_tfms)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader forbids shuffle=True alongside an explicit sampler.
        shuffle=train_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    # Deterministic eval pipeline: resize to int(size * 1.14), center-crop `size`.
    val_tfms = transforms.Compose([
        transforms.Resize(int(size * 1.14)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])

    def _val_folder_loader(pipeline):
        # Unshuffled loader over the validation folder using `pipeline`.
        return torch.utils.data.DataLoader(
            datasets.ImageFolder(val_root, pipeline),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

    val_loader = _val_folder_loader(val_tfms)
    # Validation images seen through the *training* augmentations.
    aug_loader = _val_folder_loader(train_tfms)

    train_loader, val_loader, aug_loader = (
        DataPrefetcher(loader) for loader in (train_loader, val_loader, aug_loader)
    )
    if args.prof:
        # Profiling run: limit the prefetchers via `stop_after`.
        train_loader.stop_after = 200
        val_loader.stop_after = 0

    data = TorchModelData(data_path, train_loader, val_loader, aug_loader)
    if train_sampler is not None:
        data.trn_sampler = train_sampler
    return data, train_sampler