def CreateSrcDataLoader(args, mode=None):
    """Create the source-domain DataLoader.

    Args:
        args: parsed command-line arguments. Reads ``source``, ``data_dir``,
            ``data_list`` (only for gta5 with ``mode=None`` and for synthia),
            ``batch_size`` and ``num_workers``.
        mode: GTA5 split selector. ``'train_semseg'`` and ``'val_semseg'``
            select the fixed list files under ``./dataset/gta5_list/``;
            ``None`` uses ``args.data_list``. Ignored for synthia.

    Raises:
        ValueError: if ``args.source`` is neither 'gta5' nor 'synthia', or if
            ``mode`` is not one of the recognised GTA5 splits.

    Returns:
        torch.utils.data.DataLoader over the selected source dataset, with
        shuffling and pinned memory enabled.
    """
    if args.source == 'gta5':
        # Resolve which image list to read. args.data_list is only consulted
        # when no explicit split is requested.
        if mode is None:
            list_path = args.data_list
        elif mode == 'train_semseg':
            list_path = './dataset/gta5_list/train_semseg_net.txt'
        elif mode == 'val_semseg':
            list_path = './dataset/gta5_list/val_semseg_net.txt'
        else:
            raise ValueError(
                "Unknown mode {!r} for the gta5 source; expected None, "
                "'train_semseg' or 'val_semseg'".format(mode))
        source_dataset = GTA5DataSet(args.data_dir, list_path,
                                     crop_size=image_sizes['cityscapes'],
                                     resize=image_sizes['gta5'],
                                     mean=IMG_MEAN)
    elif args.source == 'synthia':
        source_dataset = SYNDataSet(args.data_dir, args.data_list,
                                    crop_size=image_sizes['cityscapes'],
                                    resize=image_sizes['synthia'],
                                    mean=IMG_MEAN)
    else:
        raise ValueError('The source dataset must be either gta5 or synthia')

    source_dataloader = data.DataLoader(source_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
    return source_dataloader
def CreateSrcDataLoader(args):
    """Create the source-domain DataLoader.

    Args:
        args: parsed command-line arguments. Reads ``source``, ``data_dir``,
            ``data_list``, ``num_steps``, ``batch_size`` and ``num_workers``.

    Raises:
        ValueError: if ``args.source`` is not 'gta5', 'synthia' or 'triangle'.

    Returns:
        torch.utils.data.DataLoader over the selected source dataset, with
        shuffling and pinned memory enabled.
    """
    # Each dataset receives max_iters = num_steps * batch_size. Presumably this
    # lets it provide enough samples for num_steps batches; confirm against
    # the dataset classes.
    if args.source == 'gta5':
        source_dataset = GTA5DataSet(args.data_dir, args.data_list,
                                     max_iters=args.num_steps * args.batch_size,
                                     crop_size=image_sizes['gta5'],
                                     mean=IMG_MEAN)
    elif args.source == 'synthia':
        source_dataset = SYNDataSet(args.data_dir, args.data_list,
                                    max_iters=args.num_steps * args.batch_size,
                                    crop_size=image_sizes['synthia'],
                                    mean=IMG_MEAN)
    elif args.source == 'triangle':
        source_dataset = triangleDataset(args.data_dir, args.data_list,
                                         max_iters=args.num_steps * args.batch_size,
                                         crop_size=image_sizes['triangle'],
                                         mean=IMG_MEAN)
    else:
        raise ValueError(
            'The source dataset must be either gta5, synthia or triangle')

    source_dataloader = data.DataLoader(source_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
    return source_dataloader
def CreateSrcDataLoader(args):
    """Create the source-domain DataLoader.

    Args:
        args: parsed command-line arguments. Reads ``source``, ``data_dir``,
            ``data_list``, ``num_steps``, ``batch_size`` and ``num_workers``.

    Raises:
        ValueError: if ``args.source`` is neither 'gta5' nor 'synthia'.

    Returns:
        torch.utils.data.DataLoader over the selected source dataset, with
        shuffling and pinned memory enabled.
    """
    # Both datasets are cropped to the cityscapes size and resized to their
    # own native size. max_iters = num_steps * batch_size is forwarded as-is;
    # presumably it sizes the dataset for num_steps batches (verify in the
    # dataset classes).
    if args.source == 'gta5':
        source_dataset = GTA5DataSet(args.data_dir, args.data_list,
                                     crop_size=image_sizes['cityscapes'],
                                     resize=image_sizes['gta5'],
                                     mean=IMG_MEAN,
                                     max_iters=args.num_steps * args.batch_size)
    elif args.source == 'synthia':
        source_dataset = SYNDataSet(args.data_dir, args.data_list,
                                    crop_size=image_sizes['cityscapes'],
                                    resize=image_sizes['synthia'],
                                    mean=IMG_MEAN,
                                    max_iters=args.num_steps * args.batch_size)
    else:
        raise ValueError('The source dataset must be either gta5 or synthia')

    source_dataloader = data.DataLoader(source_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
    return source_dataloader