Example #1
0
def _create_mini_gld_dataset(
        cache_dir: str, image_dir: str) -> Tuple[ClientData, tf.data.Dataset]:
    """Generate mini federated GLDv2 dataset with the downloaded images.

  Args:
    cache_dir: The directory for caching the intermediate results.
    image_dir: The directory that contains the filtered images.

  Returns:
    A tuple of `ClientData`, `tf.data.Dataset`.
  """
    # Download the train/test split mapping files, verifying MD5 checksums.
    train_path = tf.keras.utils.get_file(
        MINI_GLD_TRAIN_SPLIT_FILE,
        origin=MINI_GLD_TRAIN_DOWNLOAD_URL,
        file_hash=MINI_GLD_TRAIN_SPLIT_FILE_MD5_CHECKSUM,
        hash_algorithm='md5',
        cache_dir=cache_dir)
    test_path = tf.keras.utils.get_file(
        MINI_GLD_TEST_SPLIT_FILE,
        origin=MINI_GLD_TEST_DOWNLOAD_URL,
        file_hash=MINI_GLD_TEST_SPLIT_FILE_MD5_CHECKSUM,
        hash_algorithm='md5',
        cache_dir=cache_dir)
    # Materialize the per-client train shards and the single test file under
    # the mini-GLD cache directory.
    mini_gld_cache = os.path.join(cache_dir, MINI_GLD_CACHE)
    _create_train_data_files(
        cache_dir=os.path.join(mini_gld_cache, TRAIN_SUB_DIR),
        image_dir=image_dir,
        mapping_file=train_path)
    _create_test_data_file(
        cache_dir=mini_gld_cache,
        image_dir=image_dir,
        mapping_file=test_path)
    # Use explicit keywords: the second positional parameter of
    # `load_data_from_cache` is `train_sub_dir` (see the sibling call sites),
    # so passing LOGGER positionally would bind it to the wrong parameter.
    return vision_datasets_utils.load_data_from_cache(
        cache_dir=mini_gld_cache,
        train_sub_dir=TRAIN_SUB_DIR,
        test_file_name=TEST_FILE_NAME,
        logger_tag=LOGGER)
Example #2
0
def _create_federated_gld_dataset(
        cache_dir: str, image_dir: str, train_mapping_file: str,
        test_mapping_file: str) -> Tuple[ClientData, tf.data.Dataset]:
    """Generate federated GLDv2 dataset with the downloaded images.

  Args:
    cache_dir: The directory for caching the intermediate results.
    image_dir: The directory that contains the filtered images.
    train_mapping_file: The mapping file for the train set.
    test_mapping_file: The mapping file for the test set.

  Returns:
    A tuple of `(ClientData, tf.data.Dataset)`.
  """
    # Materialize the per-client train shards and the single test file under
    # the federated-GLD cache directory.
    fed_gld_cache = os.path.join(cache_dir, FED_GLD_CACHE)
    _create_train_data_files(
        cache_dir=os.path.join(fed_gld_cache, TRAIN_SUB_DIR),
        image_dir=image_dir,
        mapping_file=train_mapping_file)
    _create_test_data_file(
        cache_dir=fed_gld_cache,
        image_dir=image_dir,
        mapping_file=test_mapping_file)
    return vision_datasets_utils.load_data_from_cache(
        cache_dir=fed_gld_cache,
        train_sub_dir=TRAIN_SUB_DIR,
        test_file_name=TEST_FILE_NAME,
        logger_tag=LOGGER)
Example #3
0
def _load_data_from_cache(
        cache_dir: str,
        split: INaturalistSplit) -> Tuple[ClientData, tf.data.Dataset]:
    """Load train and test data from the TFRecord files.

  Args:
    cache_dir: The directory containing the TFRecord files.
    split: The split of the federated iNaturalist 2017 dataset.

  Returns:
    A tuple of `ClientData`, `tf.data.Dataset`.
  """
    # Each split is cached under its own sub-directory, keyed by the enum name.
    split_cache_dir = os.path.join(cache_dir, split.name)
    return utils.load_data_from_cache(split_cache_dir, TRAIN_SUB_DIR,
                                      TEST_FILE_NAME, LOGGER)
Example #4
0
def load_data(num_worker: int = 1,
              cache_dir: str = 'cache',
              gld23k: bool = False,
              base_url: str = GLD_SHARD_BASE_URL):
    """Loads a federated version of the Google Landmark v2 dataset.

  The dataset consists of photos of various world landmarks, with images
  grouped by photographer to achieve a federated partitioning of the data.
  The dataset is downloaded and cached locally. If previously downloaded, it
  tries to load the dataset from cache.

  The `tf.data.Datasets` returned by
  `tff.simulation.datasets.ClientData.create_tf_dataset_for_client` will yield
  `collections.OrderedDict` objects at each iteration, with the following keys
  and values:

    -   `'image/decoded'`: A `tf.Tensor` with `dtype=tf.uint8` that
        corresponds to the pixels of the landmark images.
    -   `'class'`: A `tf.Tensor` with `dtype=tf.int64` and shape [1],
        corresponding to the class label of the landmark ([0, 203) for gld23k,
        [0, 2028) for gld160k).

  Two flavors of GLD datasets are available. When gld23k is true, a minimum
  version of the federated Google landmark dataset will be provided for faster
  iterations. The gld23k dataset contains 203 classes, 233 clients and 23080
  images.  When gld23k is false, the gld160k dataset
  (https://arxiv.org/abs/2003.08082) will be provided.  The gld160k dataset
  contains 2,028 classes, 1262 clients and 164,172 images.

  Args:
    num_worker: (Optional) The number of threads for downloading the GLD v2
      dataset.
    cache_dir: (Optional) The directory to cache the downloaded file. If `None`,
      caches in Keras' default cache directory.
    gld23k: (Optional) When true, a smaller version of the federated Google
      Landmark v2 dataset will be loaded. This gld23k dataset is used for faster
      prototyping.
    base_url: (Optional) The base url to download GLD v2 image shards.

  Returns:
    Tuple of (train, test) where the tuple elements are
    a `tff.simulation.datasets.ClientData` and a  `tf.data.Dataset`.
  """
    # `makedirs(..., exist_ok=True)` is robust against a missing parent
    # directory and against concurrent callers racing to create the cache,
    # unlike a bare `os.mkdir` guarded by `os.path.exists`.
    os.makedirs(cache_dir, exist_ok=True)
    # Log records from worker processes are funneled through a multiprocessing
    # queue to a single listener process that owns the log file.
    q = multiprocessing.Queue(-1)
    listener = multiprocessing.Process(
        target=_listener_process,
        args=(q, os.path.join(cache_dir, 'load_data.log')))
    listener.start()
    logger = logging.getLogger(LOGGER)
    qh = logging.handlers.QueueHandler(q)
    logger.addHandler(qh)
    logger.info('Start to load data.')
    if gld23k:
        existing_data_cache = os.path.join(cache_dir, MINI_GLD_CACHE)
    else:
        existing_data_cache = os.path.join(cache_dir, FED_GLD_CACHE)
    try:
        logger.info('Try loading dataset from cache')
        return vision_datasets_utils.load_data_from_cache(
            existing_data_cache, TRAIN_SUB_DIR, TEST_FILE_NAME, LOGGER)
    except Exception:  # pylint: disable=broad-except
        # A cache miss (or a corrupted cache) is expected on first use; fall
        # back to downloading and rebuilding the dataset from scratch.
        logger.info(
            'Loading from cache failed, start to download the data.',
            exc_info=True)
        fed_gld_train, fed_gld_test, mini_gld_train, mini_gld_test = _download_data(
            num_worker, cache_dir, base_url)
    finally:
        # Sentinel `None` tells the listener to drain the queue and exit; this
        # runs on both the cache-hit return path and the download path.
        q.put_nowait(None)
        listener.join()
    if gld23k:
        return mini_gld_train, mini_gld_test
    else:
        return fed_gld_train, fed_gld_test