示例#1
0
文件: config.py 项目: b7leung/occ_uda
def get_dataset(mode, cfg, return_idx=False, return_category=False,
                use_target_domain=False):
    ''' Returns the dataset.

    Args:
        mode (str): split to load; must be 'train', 'val' or 'test'
            (any other value raises KeyError)
        cfg (dict): config dictionary
        return_idx (bool): whether to include an ID field
        return_category (bool): whether to include a category field
            (used by the Shapes3D and online_products datasets only)
        use_target_domain (bool): whether to use the target_domain dataset,
            i.e. take the folder and class list from
            ``cfg['data']['uda_path']`` / ``cfg['data']['uda_classes']``
            instead of ``cfg['data']['path']`` / ``cfg['data']['classes']``

    Raises:
        ValueError: if ``cfg['data']['dataset']`` is not a supported type.
    '''

    method = cfg['method']
    # The dataset *type* is shared by both domains; only the folder and the
    # class list are switched when loading the target domain.
    dataset_type = cfg['data']['dataset']
    if use_target_domain:
        dataset_folder = cfg['data']['uda_path']
        categories = cfg['data']['uda_classes']
    else:
        dataset_folder = cfg['data']['path']
        categories = cfg['data']['classes']

    # Get split (all three split keys are read, so each must exist in cfg)
    splits = {
        'train': cfg['data']['train_split'],
        'val': cfg['data']['val_split'],
        'test': cfg['data']['test_split'],
    }

    split = splits[mode]

    # Create dataset
    if dataset_type == 'Shapes3D':
        # Dataset fields
        # Method specific fields (usually correspond to output)
        fields = method_dict[method].config.get_data_fields(mode, cfg)
        # Input fields (for the source or target domain, per use_target_domain)
        inputs_field = get_inputs_field(mode, cfg, use_target_domain)
        if inputs_field is not None:
            fields['inputs'] = inputs_field

        # When training with a UDA method configured, additionally load
        # target-domain images from cfg['data']['uda_path_train'].
        if mode == 'train' and cfg['training']['uda_type'] is not None:
            # Optionally data-augment the target-domain images as well.
            if cfg['data']['img_augment']:
                resize_op = transforms.RandomResizedCrop(
                    cfg['data']['img_size'], (0.75, 1.), (1., 1.))
            else:
                resize_op = transforms.Resize((cfg['data']['img_size']))
            transform = transforms.Compose([
                resize_op, transforms.ToTensor(),
            ])

            # random_view=True enables randomness
            fields['inputs_target_domain'] = data.ImagesField(
                cfg['data']['uda_path_train'], transform=transform,
                random_view=True, extensions=['jpg', 'jpeg', 'png'],
                image_based_hier=True,
            )

        if return_idx:
            fields['idx'] = data.IndexField()

        if return_category:
            fields['category'] = data.CategoryField()

        dataset = data.Shapes3dDataset(
            dataset_folder, fields,
            split=split,
            categories=categories,
        )
    elif dataset_type == 'kitti':
        dataset = data.KittiDataset(
            dataset_folder, img_size=cfg['data']['img_size'],
            return_idx=return_idx
        )
    elif dataset_type == 'online_products':
        dataset = data.OnlineProductDataset(
            dataset_folder, img_size=cfg['data']['img_size'],
            # Use the domain-selected class list (as the folder and the
            # Shapes3D branch do) rather than always the source classes.
            classes=categories,
            max_number_imgs=cfg['generation']['max_number_imgs'],
            return_idx=return_idx, return_category=return_category
        )
    elif dataset_type == 'images':
        dataset = data.ImageDataset(
            dataset_folder, img_size=cfg['data']['img_size'],
            return_idx=return_idx,
        )
    else:
        raise ValueError('Invalid dataset "%s"' % cfg['data']['dataset'])

    return dataset
示例#2
0
def get_dataset(mode, cfg, return_idx=False, return_category=False):
    ''' Build and return the dataset object described by ``cfg``.

    Args:
        mode (str): split to load, one of 'train', 'val' or 'test'
            (any other value raises KeyError)
        cfg (dict): config dictionary
        return_idx (bool): whether to include an ID field
        return_category (bool): whether to include a category field
            (used by the Shapes3D and online_products datasets only)

    Raises:
        ValueError: if ``cfg['data']['dataset']`` names an unknown dataset type.
    '''
    method = cfg['method']
    data_cfg = cfg['data']
    kind = data_cfg['dataset']
    folder = data_cfg['path']
    class_list = data_cfg['classes']

    # Every split key is read up front, so all three must exist in cfg.
    split = dict(
        train=data_cfg['train_split'],
        val=data_cfg['val_split'],
        test=data_cfg['test_split'],
    )[mode]

    if kind == 'Shapes3D':
        # Method-specific (usually output) fields, then optional extras.
        fields = method_dict[method].config.get_data_fields(mode, cfg)
        inputs_field = get_inputs_field(mode, cfg)
        if inputs_field is not None:
            fields['inputs'] = inputs_field
        if return_idx:
            fields['idx'] = data.IndexField()
        if return_category:
            fields['category'] = data.CategoryField()
        return data.Shapes3dDataset(
            folder, fields, split=split, categories=class_list)

    if kind == 'kitti':
        return data.KittiDataset(
            folder, img_size=data_cfg['img_size'], return_idx=return_idx)

    if kind == 'online_products':
        return data.OnlineProductDataset(
            folder,
            img_size=data_cfg['img_size'],
            classes=class_list,
            max_number_imgs=cfg['generation']['max_number_imgs'],
            return_idx=return_idx,
            return_category=return_category,
        )

    if kind == 'images':
        return data.ImageDataset(
            folder, img_size=data_cfg['img_size'], return_idx=return_idx)

    raise ValueError('Invalid dataset "%s"' % kind)
示例#3
0
def get_dataset(mode,
                cfg,
                batch_size,
                shuffle,
                repeat_count,
                epoch,
                return_idx=False,
                return_category=False):
    """Build and return the dataset described by ``cfg`` for split ``mode``.

    Args:
        mode (str): split to load, one of "train", "val" or "test"
            (any other value raises KeyError)
        cfg (dict): config dictionary
        batch_size: forwarded to every dataset constructor
        shuffle: forwarded to every dataset constructor
        repeat_count: forwarded to the Shapes3D dataset only
        epoch: forwarded to the Shapes3D dataset only
        return_idx (bool): whether to include an ID field
        return_category (bool): whether to include a category field
            (used by the Shapes3D and online_products datasets only)

    Raises:
        ValueError: if ``cfg["data"]["dataset"]`` names an unknown dataset type.
    """
    method = cfg["method"]
    data_cfg = cfg["data"]
    kind = data_cfg["dataset"]
    folder = data_cfg["path"]
    class_list = data_cfg["classes"]

    # Every split key is read up front, so all three must exist in cfg.
    split = dict(
        train=data_cfg["train_split"],
        val=data_cfg["val_split"],
        test=data_cfg["test_split"],
    )[mode]

    if kind == "Shapes3D":
        # Method-specific (usually output) fields, then optional extras.
        fields = method_dict[method].config.get_data_fields(mode, cfg)
        inputs_field = get_inputs_field(mode, cfg)
        if inputs_field is not None:
            fields["inputs"] = inputs_field
        if return_idx:
            fields["idx"] = data.IndexField()
        if return_category:
            fields["category"] = data.CategoryField()
        return data.Shapes3dDataset(
            folder,
            fields,
            split=split,
            batch_size=batch_size,
            shuffle=shuffle,
            repeat_count=repeat_count,
            epoch=epoch,
            categories=class_list,
        )

    if kind == "kitti":
        return data.KittiDataset(
            folder,
            batch_size=batch_size,
            shuffle=shuffle,
            img_size=data_cfg["img_size"],
            return_idx=return_idx,
        )

    if kind == "online_products":
        return data.OnlineProductDataset(
            folder,
            batch_size=batch_size,
            shuffle=shuffle,
            img_size=data_cfg["img_size"],
            classes=class_list,
            max_number_imgs=cfg["generation"]["max_number_imgs"],
            return_idx=return_idx,
            return_category=return_category,
        )

    if kind == "images":
        return data.ImageDataset(
            folder,
            batch_size=batch_size,
            shuffle=shuffle,
            img_size=data_cfg["img_size"],
            return_idx=return_idx,
        )

    raise ValueError('Invalid dataset "%s"' % kind)