def get_loader(args, ds_type):
    """Build the data loader for a single dataset split.

    Args:
        args: parsed command-line namespace. Always reads ``loader``; the
            ``'pytorch'`` backend also reads ``frames``, ``is_cropped``,
            ``crop_size``, ``root``, ``batchsize`` and ``world_size``; the
            ``'DALI'`` backend reads ``batchsize``, ``root``, ``frames`` and
            ``crop_size``.
        ds_type: split to load, ``'train'`` or ``'val'``. Data is read from
            ``os.path.join(args.root, ds_type)``.

    Returns:
        ``(loader, batches, sampler)``. ``batches`` is ``len(loader)``.
        ``sampler`` is the sampler attached to the PyTorch loader, or
        ``None`` for the DALI backend.

    Raises:
        ValueError: if ``ds_type`` is not ``'train'``/``'val'``, or if
            ``args.loader`` is not ``'pytorch'``/``'DALI'``.
    """
    # Compare by value. The previous ``is not`` identity test only passed for
    # interned literals and rejected equal strings built at runtime.
    if ds_type not in ('train', 'val'):
        raise ValueError("ds_type has to be either 'train' or 'val'")

    if args.loader == 'pytorch':
        split_root = os.path.join(args.root, ds_type)
        dataset = imageDataset(args.frames, args.is_cropped, args.crop_size,
                               split_root, args.batchsize, args.world_size)
        if args.world_size > 1:
            # Shard the dataset across processes.
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            sampler = torch.utils.data.RandomSampler(dataset)

        if ds_type == 'train':
            batch_size, num_workers = args.batchsize, 0
        else:
            # Validation runs one sample per batch.
            batch_size, num_workers = 1, 1

        # An explicit sampler is always supplied, so ``shuffle`` must be False
        # (DataLoader rejects shuffle=True combined with a sampler).
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True,
                            sampler=sampler, drop_last=True)
        # len(loader) is the exact number of batches produced: it accounts for
        # sharding and drop_last. The former ceil(len / (bsz * world_size))
        # over-counted by one when the shard was not a multiple of the batch size.
        batches = len(loader)
    elif args.loader == 'DALI':
        loader = DALILoader(args.batchsize, os.path.join(args.root, ds_type),
                            args.frames, args.crop_size)
        batches = len(loader)
        sampler = None
    else:
        raise ValueError('%s is not a valid option for --loader' % args.loader)
    return loader, batches, sampler
def get_loader(args):
    """Build the training data loader selected by ``args.loader``.

    Args:
        args: parsed command-line namespace. Reads ``loader``, ``frames``,
            ``is_cropped``, ``crop_size``, ``root`` and ``batchsize``; the
            ``'NVVL'`` backend additionally reads ``fp16``.

    Returns:
        ``(train_loader, train_batches)``, where ``train_batches`` is
        ``len(train_loader)``.

    Raises:
        ValueError: if ``args.loader`` is not ``'pytorch'``, ``'lintel'`` or
            ``'NVVL'``.
    """
    if args.loader in ('pytorch', 'lintel'):
        # The two backends differ only in the dataset class; the DataLoader
        # setup around it is shared. Only the selected name is evaluated.
        dataset_cls = imageDataset if args.loader == 'pytorch' else lintelDataset
        dataset = dataset_cls(args.frames, args.is_cropped, args.crop_size,
                              args.root, args.batchsize)
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
        # An explicit sampler is supplied, so ``shuffle`` must be False.
        train_loader = DataLoader(dataset, batch_size=args.batchsize,
                                  shuffle=False, num_workers=10,
                                  pin_memory=True, sampler=sampler,
                                  drop_last=True)
        # Count batches, not samples. The former len(dataset) returned the
        # sample count, which disagreed with the NVVL branch.
        train_batches = len(train_loader)
    elif args.loader == 'NVVL':
        train_loader = NVVL(args.frames, args.is_cropped, args.crop_size,
                            args.root, batchsize=args.batchsize,
                            shuffle=True, fp16=args.fp16)
        train_batches = len(train_loader)
    else:
        raise ValueError('%s is not a valid option for --loader' % args.loader)
    return train_loader, train_batches
def get_loader(args):
    """Build training and validation loaders selected by ``args.loader``.

    Training data is read from ``os.path.join(args.root, 'train')`` and
    validation data from ``os.path.join(args.root, 'val')``.

    Args:
        args: parsed command-line namespace. Reads ``loader``, ``root``,
            ``batchsize`` and ``frames``; the ``'pytorch'`` backend also reads
            ``is_cropped``, ``crop_size`` and ``world_size``; the ``'NVVL'``
            backend also reads ``is_cropped``, ``crop_size``, ``rank`` and
            ``fp16``.

    Returns:
        ``(train_loader, train_batches, val_loader, val_batches, sampler)``.
        ``sampler`` is the training loader's sampler for the ``'pytorch'``
        backend and ``None`` otherwise.

    Raises:
        ValueError: if ``args.loader`` is not ``'pytorch'``, ``'NVVL'`` or
            ``'DALI'``.
    """
    if args.loader == 'pytorch':
        train_dataset = imageDataset(args.frames, args.is_cropped,
                                     args.crop_size,
                                     os.path.join(args.root, 'train'),
                                     args.batchsize, args.world_size)
        # Returned to the caller: must remain the *training* sampler.
        sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
        train_loader = DataLoader(train_dataset, batch_size=args.batchsize,
                                  shuffle=False, num_workers=10,
                                  pin_memory=True, sampler=sampler,
                                  drop_last=True)
        # NOTE(review): the sampler is not sharded, yet the batch count is
        # divided by world_size. Presumably each rank runs only its share of
        # iterations per epoch; confirm against the training loop.
        effective_bsz = args.batchsize * args.world_size
        train_batches = math.ceil(len(train_dataset) / effective_bsz)

        val_dataset = imageDataset(args.frames, args.is_cropped,
                                   args.crop_size,
                                   os.path.join(args.root, 'val'),
                                   args.batchsize, args.world_size)
        # Separate name: previously this overwrote ``sampler``, so the
        # validation sampler was returned instead of the training one.
        val_sampler = torch.utils.data.sampler.RandomSampler(val_dataset)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                                num_workers=1, pin_memory=True,
                                sampler=val_sampler, drop_last=True)
        val_batches = math.ceil(len(val_dataset) / args.world_size)
    elif args.loader == 'NVVL':
        # NOTE(review): ``rank % 8`` assumes 8 GPUs per node; confirm.
        train_loader = NVVL(args.frames, args.is_cropped, args.crop_size,
                            os.path.join(args.root, 'train'),
                            batchsize=args.batchsize, shuffle=True,
                            distributed=False, device_id=args.rank % 8,
                            fp16=args.fp16)
        train_batches = len(train_loader)
        val_loader = NVVL(args.frames, args.is_cropped, args.crop_size,
                          os.path.join(args.root, 'val'), batchsize=1,
                          shuffle=True, distributed=False,
                          device_id=args.rank % 8, fp16=args.fp16)
        val_batches = len(val_loader)
        sampler = None
    elif args.loader == 'DALI':
        train_loader = DaliLoader(args.batchsize,
                                  os.path.join(args.root, 'train'),
                                  args.frames)
        train_batches = len(train_loader)
        val_loader = DaliLoader(args.batchsize,
                                os.path.join(args.root, 'val'), args.frames)
        val_batches = len(val_loader)
        sampler = None
    else:
        raise ValueError('%s is not a valid option for --loader' % args.loader)
    return train_loader, train_batches, val_loader, val_batches, sampler