Example #1

A test of LazyTfdsLoader builder memoization: tfds.builder is mocked, and the assertions check that accessing .builder with a previously seen (name, data_dir) pair is a cache hit, while a new dataset name or data directory triggers another tfds.builder call.

    # The mock_tfds_builder parameter implies tensorflow_datasets.builder
    # is patched for this test.
    @mock.patch("tensorflow_datasets.builder")
    def test_builder_memoization(self, mock_tfds_builder):
        mock_tfds_builder.side_effect = (
            lambda name, data_dir: ",".join([name, data_dir or ""]))

        ds1 = utils.LazyTfdsLoader("ds1")
        self.assertEqual("ds1,", ds1.builder)
        self.assertEqual(1, tfds.builder.call_count)

        # Builder should be cached with same name.
        self.assertEqual("ds1,", ds1.builder)
        self.assertEqual(1, tfds.builder.call_count)

        # Same name but different data dir is a cache miss.
        ds1_dir1 = utils.LazyTfdsLoader("ds1", "dir1")
        self.assertEqual("ds1,dir1", ds1_dir1.builder)
        self.assertEqual(2, tfds.builder.call_count)
        # Same name and data dir is a cache hit.
        self.assertEqual("ds1,dir1", ds1_dir1.builder)
        self.assertEqual(2, tfds.builder.call_count)

        # Different name is a cache miss.
        ds2 = utils.LazyTfdsLoader("ds2")
        self.assertEqual("ds2,", ds2.builder)
        self.assertEqual(3, tfds.builder.call_count)

        # Different split map name is a cache hit.
        ds2 = utils.LazyTfdsLoader("ds2", split_map={"train": "validation"})
        self.assertEqual("ds2,", ds2.builder)
        self.assertEqual(3, tfds.builder.call_count)

        # Try calling everything again, order shouldn't matter.
        self.assertEqual("ds1,", ds1.builder)
        self.assertEqual("ds1,dir1", ds1_dir1.builder)
        self.assertEqual("ds2,", ds2.builder)
        self.assertEqual(3, tfds.builder.call_count)
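
The behavior these assertions pin down is that builders are cached per (name, data_dir) and that split_map plays no part in the cache key. A minimal function-level sketch of such memoization (get_builder and the module-level cache are illustrative names, not the library's actual source):

import tensorflow_datasets as tfds

# Cache keyed by (name, data_dir); split_map is deliberately not part of the key.
_MEMOIZED_BUILDERS = {}


def get_builder(name, data_dir=None):
  # Each distinct (name, data_dir) pair costs exactly one tfds.builder call;
  # repeated lookups return the cached builder, matching the call_count
  # assertions above.
  key = (name, data_dir)
  if key not in _MEMOIZED_BUILDERS:
    _MEMOIZED_BUILDERS[key] = tfds.builder(name, data_dir=data_dir)
  return _MEMOIZED_BUILDERS[key]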
Example #2

A test of the split_map argument: a fake builder is planted directly in LazyTfdsLoader._MEMOIZED_BUILDERS, and the assertions check that .load(), .size(), and .files() translate canonical split names ("train", "validation") to the mapped TFDS splits, raising KeyError for unmapped names.

  # The mock_tfds_load parameter implies tensorflow_datasets.load is patched.
  @mock.patch("tensorflow_datasets.load")
  def test_split_map(self, mock_tfds_load):
    utils.LazyTfdsLoader._MEMOIZED_BUILDERS[("ds/c1", None)] = mock.Mock(
        info=mock.Mock(splits={
            "validation": mock.Mock(
                num_examples=420,
                file_instructions=["f1", "f2"]),
            "test": mock.Mock(
                num_examples=42,
                file_instructions=["f3"]),
        }))

    ds = utils.LazyTfdsLoader(
        "ds/c1", split_map={"train": "validation", "validation": "test"})

    # test .load()
    ds.load("train", shuffle_files=False)
    mock_tfds_load.assert_called_once_with(
        "ds/c1",
        split="validation",
        data_dir=None,
        shuffle_files=False,
        download=True,
        try_gcs=True)

    # test .size()
    self.assertEqual(420, ds.size(split="train"))
    self.assertEqual(42, ds.size(split="validation"))
    with self.assertRaises(KeyError):
      ds.size(split="test")

    # test .files()
    self.assertListEqual(["f1", "f2"], ds.files(split="train"))
    self.assertListEqual(["f3"], ds.files(split="validation"))
    with self.assertRaises(KeyError):
      ds.files(split="test")
Example #3

The typed TfdsTask constructor: it validates that the TFDS name carries a version number and wires the splits argument into a LazyTfdsLoader.

  def __init__(
      self,
      tfds_name: str,
      tfds_data_dir: Optional[str] = None,
      splits: Optional[Union[Iterable[str], Mapping[str, str]]] = None
  ):
    """TfdsTask constructor.

    Args:
      tfds_name: string, the name and version number of a TFDS dataset,
        optionally with a config.
      tfds_data_dir: string, an optional path to a specific TFDS data directory
        to use.
      splits: an iterable of allowable string split names, a dict mapping
        allowable canonical splits (e.g., 'validation') to TFDS splits or slices
        (e.g., 'train[:1%]'), or None. The default, None, uses all available
        splits from the TFDS dataset info.
    """
    if ":" not in tfds_name:
      raise ValueError("TFDS name must contain a version number, got: %s" %
                       tfds_name)

    self._tfds_dataset = utils.LazyTfdsLoader(
        tfds_name,
        data_dir=tfds_data_dir,
        split_map=splits if isinstance(splits, dict) else None)

    # If splits are not provided, we pass an empty tuple and use the lazy
    # lookup in the `splits` property.
    super().__init__(splits=splits or ())
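
For illustration, a call consistent with this signature (the dataset name and slices are hypothetical, and any extra arguments the parent class may require are omitted):

task = TfdsTask(
    "glue/cola:1.0.0",  # hypothetical TFDS name; the version suffix is required
    splits={
        # canonical split -> TFDS split or slice
        "train": "train[:90%]",
        "validation": "train[90%:]",
    })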
Example #4

An earlier-style TfdsTask constructor: in addition to the TFDS arguments it takes the task name, text_preprocessor, and metric_fns, and forwards a dataset_fn to the parent Task class.

  def __init__(
      self,
      name,
      tfds_name,
      text_preprocessor,
      metric_fns,
      tfds_data_dir=None,
      splits=None,
      **task_kwargs):
    """TfdsTask constructor.

    Args:
      name: string, a unique name for the Task. A ValueError will be raised if
        another task with this name is already registered.
      tfds_name: string, the name and version number of a TFDS dataset,
        optionally with a config.
      text_preprocessor: a function (or list of functions) that (each) takes in
        a tf.data.Dataset of string features and returns a tf.data.Dataset of
        string features. Can be set to None as a no-op. If a list is given,
        they will be executed sequentially.
      metric_fns: list(callable), list of metric functions with the signature
        metric_fn(targets, predictions) to use during evaluation.
      tfds_data_dir: string, an optional path to a specific TFDS data directory
        to use.
      splits: a list(string) of allowable splits to load, a dict mapping
        allowable canonical splits (e.g., 'validation') to TFDS splits or slices
        (e.g., 'train[:1%]'), or None. The default, None, uses all available
        splits from the TFDS dataset info.
      **task_kwargs: dict, additional keyword arguments for the parent `Task`
        class.
    """
    if ":" not in tfds_name:
      raise ValueError(
          "TFDS name must contain a version number, got: %s" % tfds_name)

    self._tfds_dataset = utils.LazyTfdsLoader(
        tfds_name,
        data_dir=tfds_data_dir,
        split_map=splits if isinstance(splits, dict) else None)

    def dataset_fn(split, shuffle_files, seed=None):
      return self._tfds_dataset.load(
          split, shuffle_files, seed=seed)

    super().__init__(
        name,
        dataset_fn=dataset_fn,
        splits=list(splits) if splits else None,
        text_preprocessor=text_preprocessor,
        metric_fns=metric_fns,
        **task_kwargs)
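
For illustration, a call consistent with this earlier-style signature (all values hypothetical):

task = TfdsTask(
    "my_cola_task",               # unique Task registry name
    tfds_name="glue/cola:1.0.0",  # must include a version number
    text_preprocessor=None,       # no-op, as the docstring allows
    metric_fns=[],
    splits=["train", "validation"])

Because splits here is a list rather than a dict, no split_map is passed to the LazyTfdsLoader and the list is forwarded to the parent Task unchanged.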
Example #5

A task-level test that verifies on-the-fly loading against fake datasets, first with a token preprocessor applied and then with the fake TFDS loader overridden to produce more examples.

  def test_get_dataset_onthefly(self):
    test_utils.verify_task_matches_fake_datasets(
        self.uncached_task, use_cached=False)

    # Test with token preprocessor.
    self.uncached_task._token_preprocessor = test_utils.test_token_preprocessor
    test_utils.verify_task_matches_fake_datasets(
        self.uncached_task, use_cached=False, token_preprocessed=True)

    # Override mock to get more examples.
    def fake_load(s, shuffle_files=False):
      del shuffle_files  # Unused, to mimic TFDS API
      return test_utils.get_fake_dataset(s).repeat().take(20)
    # Within this test harness, utils.LazyTfdsLoader is patched to a
    # namedtuple-based fake, which is why _replace is available here.
    test_utils.add_fake_tfds(
        utils.LazyTfdsLoader("fake:0.0.0")._replace(load=fake_load))