def get_dataset(mode, cfg, return_idx=False, return_category=False, use_target_domain=False):
    ''' Builds the dataset selected by ``cfg['data']['dataset']``.

    Args:
        mode (str): which split to load; one of 'train', 'val' or 'test'
        cfg (dict): config dictionary
        return_idx (bool): whether to include an index field ('idx' for
            Shapes3D, forwarded as ``return_idx`` to the other datasets)
        return_category (bool): whether to include a category field
            ('category' for Shapes3D, forwarded to the online_products
            dataset; not used for 'kitti' and 'images')
        use_target_domain (bool): read the dataset folder and categories
            from ``uda_path`` / ``uda_classes`` instead of ``path`` /
            ``classes``

    Returns:
        The dataset object built for the configured dataset type.

    Raises:
        KeyError: if ``mode`` is not 'train', 'val' or 'test'
        ValueError: if the configured dataset type is not supported
    '''
    method = cfg['method']
    data_cfg = cfg['data']
    # The dataset type always comes from cfg['data']['dataset'], also when
    # the target domain is requested.
    dataset_type = data_cfg['dataset']

    folder_key, classes_key = (
        ('uda_path', 'uda_classes') if use_target_domain
        else ('path', 'classes')
    )
    dataset_folder = data_cfg[folder_key]
    categories = data_cfg[classes_key]

    # All three split names are read from the config before picking one.
    split = {
        name: data_cfg[name + '_split'] for name in ('train', 'val', 'test')
    }[mode]

    if dataset_type == 'kitti':
        return data.KittiDataset(
            dataset_folder,
            img_size=data_cfg['img_size'],
            return_idx=return_idx,
        )

    if dataset_type == 'online_products':
        # NOTE(review): classes are read from cfg['data']['classes'] here even
        # when use_target_domain is set -- confirm this is intended.
        return data.OnlineProductDataset(
            dataset_folder,
            img_size=data_cfg['img_size'],
            classes=data_cfg['classes'],
            max_number_imgs=cfg['generation']['max_number_imgs'],
            return_idx=return_idx,
            return_category=return_category,
        )

    if dataset_type == 'images':
        return data.ImageDataset(
            dataset_folder,
            img_size=data_cfg['img_size'],
            return_idx=return_idx,
        )

    if dataset_type != 'Shapes3D':
        raise ValueError('Invalid dataset "%s"' % data_cfg['dataset'])

    # Shapes3D: start from the method-specific fields, then add the inputs.
    fields = method_dict[method].config.get_data_fields(mode, cfg)

    inputs_field = get_inputs_field(mode, cfg, use_target_domain)
    if inputs_field is not None:
        fields['inputs'] = inputs_field

    # Extra image field with target-domain inputs, only while training with
    # a UDA type configured.
    if mode == 'train' and cfg['training']['uda_type'] is not None:
        # Target-domain images are augmented with a random resized crop when
        # image augmentation is enabled, otherwise they are only resized.
        resize_op = (
            transforms.RandomResizedCrop(
                data_cfg['img_size'], (0.75, 1.), (1., 1.))
            if data_cfg['img_augment']
            else transforms.Resize((data_cfg['img_size']))
        )
        target_transform = transforms.Compose([
            resize_op,
            transforms.ToTensor(),
        ])
        # random_view=True enables randomness
        fields['inputs_target_domain'] = data.ImagesField(
            data_cfg['uda_path_train'],
            transform=target_transform,
            random_view=True,
            extensions=['jpg', 'jpeg', 'png'],
            image_based_hier=True,
        )

    if return_idx:
        fields['idx'] = data.IndexField()
    if return_category:
        fields['category'] = data.CategoryField()

    return data.Shapes3dDataset(
        dataset_folder,
        fields,
        split=split,
        categories=categories,
    )
def get_dataset(mode, cfg, return_idx=False, return_category=False):
    ''' Builds the dataset selected by ``cfg['data']['dataset']``.

    Args:
        mode (str): which split to load; one of 'train', 'val' or 'test'
        cfg (dict): config dictionary
        return_idx (bool): whether to include an index field ('idx' for
            Shapes3D, forwarded as ``return_idx`` to the other datasets)
        return_category (bool): whether to include a category field
            ('category' for Shapes3D, forwarded to the online_products
            dataset; not used for 'kitti' and 'images')

    Returns:
        The dataset object built for the configured dataset type.

    Raises:
        KeyError: if ``mode`` is not 'train', 'val' or 'test'
        ValueError: if the configured dataset type is not supported
    '''
    method = cfg['method']
    data_cfg = cfg['data']
    dataset_type = data_cfg['dataset']
    dataset_folder = data_cfg['path']
    categories = data_cfg['classes']

    # All three split names are read from the config before picking one.
    split_by_mode = {
        name: data_cfg[name + '_split'] for name in ('train', 'val', 'test')
    }
    split = split_by_mode[mode]

    # Image-based datasets are built straight from the config values.
    if dataset_type == 'kitti':
        return data.KittiDataset(
            dataset_folder,
            img_size=data_cfg['img_size'],
            return_idx=return_idx,
        )

    if dataset_type == 'online_products':
        return data.OnlineProductDataset(
            dataset_folder,
            img_size=data_cfg['img_size'],
            classes=data_cfg['classes'],
            max_number_imgs=cfg['generation']['max_number_imgs'],
            return_idx=return_idx,
            return_category=return_category,
        )

    if dataset_type == 'images':
        return data.ImageDataset(
            dataset_folder,
            img_size=data_cfg['img_size'],
            return_idx=return_idx,
        )

    if dataset_type != 'Shapes3D':
        raise ValueError('Invalid dataset "%s"' % data_cfg['dataset'])

    # Shapes3D: method-specific fields first, then the optional extras.
    fields = method_dict[method].config.get_data_fields(mode, cfg)

    inputs_field = get_inputs_field(mode, cfg)
    if inputs_field is not None:
        fields['inputs'] = inputs_field

    if return_idx:
        fields['idx'] = data.IndexField()
    if return_category:
        fields['category'] = data.CategoryField()

    return data.Shapes3dDataset(
        dataset_folder,
        fields,
        split=split,
        categories=categories,
    )
def get_dataset(mode, cfg, batch_size, shuffle, repeat_count, epoch, return_idx=False, return_category=False):
    """ Builds the dataset selected by ``cfg["data"]["dataset"]``.

    Args:
        mode (str): which split to load; one of "train", "val" or "test"
        cfg (dict): config dictionary
        batch_size: forwarded as ``batch_size`` to every dataset type
        shuffle: forwarded as ``shuffle`` to every dataset type
        repeat_count: forwarded to the Shapes3D dataset only
        epoch: forwarded to the Shapes3D dataset only
        return_idx (bool): whether to include an index field ("idx" for
            Shapes3D, forwarded as ``return_idx`` to the other datasets)
        return_category (bool): whether to include a category field
            ("category" for Shapes3D, forwarded to the online_products
            dataset; not used for "kitti" and "images")

    Returns:
        The dataset object built for the configured dataset type.

    Raises:
        KeyError: if ``mode`` is not "train", "val" or "test"
        ValueError: if the configured dataset type is not supported
    """
    method = cfg["method"]
    data_cfg = cfg["data"]
    dataset_type = data_cfg["dataset"]
    dataset_folder = data_cfg["path"]
    categories = data_cfg["classes"]

    # All three split names are read from the config before picking one.
    split = {
        name: data_cfg[name + "_split"] for name in ("train", "val", "test")
    }[mode]

    # Batching keywords shared by the image-based dataset constructors.
    batching = {"batch_size": batch_size, "shuffle": shuffle}

    if dataset_type == "kitti":
        return data.KittiDataset(
            dataset_folder,
            **batching,
            img_size=data_cfg["img_size"],
            return_idx=return_idx,
        )

    if dataset_type == "online_products":
        return data.OnlineProductDataset(
            dataset_folder,
            **batching,
            img_size=data_cfg["img_size"],
            classes=data_cfg["classes"],
            max_number_imgs=cfg["generation"]["max_number_imgs"],
            return_idx=return_idx,
            return_category=return_category,
        )

    if dataset_type == "images":
        return data.ImageDataset(
            dataset_folder,
            **batching,
            img_size=data_cfg["img_size"],
            return_idx=return_idx,
        )

    if dataset_type != "Shapes3D":
        raise ValueError('Invalid dataset "%s"' % data_cfg["dataset"])

    # Shapes3D: method-specific fields first, then the optional extras.
    fields = method_dict[method].config.get_data_fields(mode, cfg)

    inputs_field = get_inputs_field(mode, cfg)
    if inputs_field is not None:
        fields["inputs"] = inputs_field

    if return_idx:
        fields["idx"] = data.IndexField()
    if return_category:
        fields["category"] = data.CategoryField()

    return data.Shapes3dDataset(
        dataset_folder,
        fields,
        split=split,
        batch_size=batch_size,
        shuffle=shuffle,
        repeat_count=repeat_count,
        epoch=epoch,
        categories=categories,
    )