def torch_loader(data_path, use_val_sampler=True, min_scale=0.08, bs=192):
    """Build train/val ``DataLoader``s over image folders and wrap them in ``ModelData``.

    Images are read from ``<data_path>/train`` and ``<data_path>/val`` via
    ``datasets.ImageFolder``.  The crop size (``args.sz``), worker count
    (``args.workers``) and whether to shard with ``DistributedSampler``
    (``args.distributed``) are taken from the module-level ``args`` object.

    NOTE(review): SOURCE also contains a later ``torch_loader(data_path, size)``
    definition; if both live in the same module, the later one replaces this
    one at import time — confirm which is intended.

    Args:
        data_path: Root directory holding ``train`` and ``val`` subfolders.
        use_val_sampler: When true, the validation loader uses the validation
            ``DistributedSampler`` (if distributed); when false it is given
            no sampler.
        min_scale: Lower bound of the ``scale`` range passed to
            ``RandomResizedCrop`` (upper bound is 1.0).
        bs: Batch size used by both loaders.

    Returns:
        A ``ModelData`` instance with ``sz`` set to ``args.sz``.  In
        distributed mode it additionally has ``trn_sampler`` and
        ``val_sampler`` attributes (``val_sampler`` is attached even when
        ``use_val_sampler`` is false).
    """
    sz = args.sz
    n_workers = args.workers
    distributed = args.distributed

    # Final tensor conversion + per-channel normalization shared by both pipelines.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]

    # Training augmentation: random-area crop to `sz`, then random flip.
    train_aug = [
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
    # Validation: resize to int(sz * 1.14), then center-crop to `sz`.
    val_prep = [
        transforms.Resize(int(sz * 1.14)),
        transforms.CenterCrop(sz),
    ]

    train_ds = datasets.ImageFolder(
        os.path.join(data_path, 'train'),
        transforms.Compose(train_aug + to_normalized_tensor))
    val_ds = datasets.ImageFolder(
        os.path.join(data_path, 'val'),
        transforms.Compose(val_prep + to_normalized_tensor))

    if distributed:
        trn_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds)
    else:
        trn_sampler = val_sampler = None

    loader_kwargs = dict(batch_size=bs, num_workers=n_workers, pin_memory=True)
    # A sampler and shuffle=True are mutually exclusive in DataLoader, so
    # shuffling is only requested when no sampler is supplied.
    train_loader = torch.utils.data.DataLoader(
        train_ds, shuffle=trn_sampler is None, sampler=trn_sampler,
        **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_ds, shuffle=False,
        sampler=val_sampler if use_val_sampler else None,
        **loader_kwargs)

    data = ModelData(data_path, train_loader, val_loader)
    data.sz = sz
    if trn_sampler is not None:
        data.trn_sampler = trn_sampler
        data.val_sampler = val_sampler
    return data
def torch_loader(data_path, size):
    """Build prefetching train, val and augmented-val loaders in a ``TorchModelData``.

    Images are read from ``<data_path>/train`` and ``<data_path>/val`` via
    ``datasets.ImageFolder``.  Batch size (``args.batch_size``), worker count
    (``args.workers``), distributed sharding (``args.distributed``) and
    profiling mode (``args.prof``) come from the module-level ``args`` object.

    Args:
        data_path: Root directory holding ``train`` and ``val`` subfolders.
        size: Output crop size for both the training and validation pipelines.

    Returns:
        ``(data, train_sampler)``: the ``TorchModelData`` built from the three
        prefetching loaders, and the training ``DistributedSampler`` (or
        ``None`` when not distributed).  When a sampler exists it is also
        attached as ``data.trn_sampler``.
    """
    bs = args.batch_size
    n_workers = args.workers

    def _make_loader(dataset, **extra):
        # Shared DataLoader settings; callers supply shuffle/sampler.
        return torch.utils.data.DataLoader(
            dataset, batch_size=bs, num_workers=n_workers, pin_memory=True,
            **extra)

    train_dir = os.path.join(data_path, 'train')
    val_dir = os.path.join(data_path, 'val')
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

    # Training pipeline: random crop/flip, color jitter, then Lighting after
    # ToTensor (presumably PCA-based lighting noise — Lighting is defined
    # elsewhere; confirm).
    augment = transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(.4, .4, .4),
        transforms.ToTensor(),
        Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
        norm,
    ])

    train_ds = datasets.ImageFolder(train_dir, augment)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
    else:
        train_sampler = None
    # DataLoader forbids shuffle=True together with a sampler.
    raw_train = _make_loader(train_ds, shuffle=(train_sampler is None),
                             sampler=train_sampler)

    # Deterministic validation pipeline: resize to int(size * 1.14), center-crop.
    evaluate = transforms.Compose([
        transforms.Resize(int(size * 1.14)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        norm,
    ])
    raw_val = _make_loader(datasets.ImageFolder(val_dir, evaluate),
                           shuffle=False)
    # The validation images again, but passed through the *training*
    # augmentations (no sampler, no shuffling).
    raw_aug = _make_loader(datasets.ImageFolder(val_dir, augment),
                           shuffle=False)

    train_loader, val_loader, aug_loader = map(
        DataPrefetcher, (raw_train, raw_val, raw_aug))

    if args.prof:
        # Profiling mode sets DataPrefetcher.stop_after (presumably an
        # iteration cap — semantics defined elsewhere); aug_loader is untouched.
        train_loader.stop_after = 200
        val_loader.stop_after = 0

    data = TorchModelData(data_path, train_loader, val_loader, aug_loader)
    if train_sampler is not None:
        data.trn_sampler = train_sampler
    return data, train_sampler