Example #1
    def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
        """
        Returns a distributed test :class:`~tf.data.Dataset`, along with the number of evaluation steps and examples.

        Args:
            test_dataset (:class:`~tf.data.Dataset`):
                The dataset to use. The dataset should yield tuples of ``(features, labels)`` where ``features`` is a
                dict of input features and ``labels`` is the labels. If ``labels`` is a tensor, the loss is calculated
                by the model by calling ``model(features, labels=labels)``. If ``labels`` is a dict, such as when using
                a QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
                ``model(features, **labels)``.

        Subclass and override this method if you want to inject some custom behavior.
        """

        num_examples = test_dataset.cardinality().numpy()

        if num_examples < 0:
            raise ValueError("The training dataset must have an asserted cardinality")

        approx = math.floor if self.args.dataloader_drop_last else math.ceil
        steps = approx(num_examples / self.args.eval_batch_size)
        ds = (
            test_dataset.repeat()
            .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )

        return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
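The cardinality check above rejects datasets whose size TensorFlow cannot infer (for example generator-backed ones), so such datasets need `assert_cardinality` first. A minimal sketch, where `trainer` is a hypothetical instance of the class this method belongs to:

import tensorflow as tf

# Generator-backed dataset: cardinality() would return UNKNOWN_CARDINALITY (-2).
test_dataset = tf.data.Dataset.from_generator(
    lambda: (({"input_ids": [1, 2, 3]}, 0) for _ in range(100)),
    output_signature=(
        {"input_ids": tf.TensorSpec([3], tf.int32)},
        tf.TensorSpec([], tf.int32),
    ),
)
# Assert the known number of examples so the check above passes.
test_dataset = test_dataset.apply(tf.data.experimental.assert_cardinality(100))

ds, steps, num_examples = trainer.get_test_tfdataset(test_dataset)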
Example #2
File: models.py Project: jackd/kblocks
def as_infinite_iterator(
        dataset: tf.data.Dataset,
        steps_per_epoch: Optional[int] = None) -> Tuple[tf.data.Iterator, int]:
    """
    Get an iterator for an infinite dataset and steps_per_epoch.

    Args:
        dataset: possibly infinite dataset.
        steps_per_epoch: number of steps per epoch. Must be provided if
            `dataset` has infinite cardinality; otherwise it must be `None` or
            equal to `dataset`'s cardinality.

    Returns:
        iterator: tf.data.Iterator of possibly repeated `dataset`.
        steps_per_epoch: number of elements in iterator considered one epoch.

    Raises:
        ValueError: if `steps_per_epoch` is not provided and `dataset` has
            infinite cardinality.
    """
    cardinality = tf.keras.backend.get_value(dataset.cardinality())
    if steps_per_epoch is None:
        steps_per_epoch = cardinality
        if cardinality == tf.data.INFINITE_CARDINALITY:
            raise ValueError(
                "steps_per_epoch must be provided if dataset has infinite "
                "cardinality")
        dataset = dataset.repeat()
    elif cardinality != tf.data.INFINITE_CARDINALITY:
        assert cardinality == steps_per_epoch
        dataset = dataset.repeat()
    return iter(dataset), steps_per_epoch
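A minimal usage sketch, assuming `as_infinite_iterator` above is importable: with a finite dataset and no explicit `steps_per_epoch`, the cardinality is used and the dataset is repeated.

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(2)  # cardinality 5
iterator, steps_per_epoch = as_infinite_iterator(dataset)
assert steps_per_epoch == 5

for _ in range(steps_per_epoch):  # one epoch of the infinite iterator
    batch = next(iterator)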
Example #3
def get_augmented_data(
    dataset: tf.data.Dataset,
    batch_size: int,
    map_func: Callable,
    shuffle_buffer: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    augment_seed: Optional[int] = None,
    use_stateless_map: bool = False,
) -> RepeatedData:
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed)
    dataset = dataset.batch(batch_size)
    steps_per_epoch = tf.keras.backend.get_value(dataset.cardinality())
    # repeat before map so stateless map is different across epochs
    dataset = dataset.repeat()
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    if use_stateless_map:
        dataset = dataset.apply(
            tfrng.data.stateless_map(
                map_func,
                seed=augment_seed,
                num_parallel_calls=AUTOTUNE,
            ))
    else:
        # if map_func has random elements this won't be deterministic
        dataset = dataset.map(map_func, num_parallel_calls=AUTOTUNE)
    dataset = dataset.prefetch(AUTOTUNE)
    return RepeatedData(dataset, steps_per_epoch)
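A usage sketch under the assumption that `RepeatedData` and `tfrng` are provided by the surrounding project; only the stateful `use_stateless_map=False` path is exercised here, and `map_func` is a hypothetical toy augmentation:

import tensorflow as tf

images = tf.data.Dataset.from_tensor_slices(tf.zeros([32, 8, 8, 3]))

def map_func(batch):
    # Toy per-batch random augmentation.
    return batch + tf.random.uniform([], maxval=0.1)

repeated = get_augmented_data(
    images,
    batch_size=4,
    map_func=map_func,
    shuffle_buffer=32,
    use_stateless_map=False,
)
# The result wraps an infinite dataset together with steps_per_epoch (8 here).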
Example #4
def _validate_data(data: tf.data.Dataset, steps: Optional[int]):
    cardinality = tf.keras.backend.get_value(data.cardinality())
    if cardinality == tf.data.INFINITE_CARDINALITY:
        assert steps is not None
    else:
        assert cardinality > 0
        assert steps is None
Example #5
def pad_dataset(dataset: tf.data.Dataset,
                *,
                batch_dims: Sequence[int],
                pad_up_to_batches: Optional[int] = None,
                cardinality: Optional[int] = None):
    """Adds padding to a dataset.

  Args:
    dataset: The dataset to be padded.
    batch_dims: List of batch dimension sizes. Multiple batch dimensions can
      be used to provide inputs for multiple devices, e.g.
      [jax.local_device_count(), batch_size // jax.device_count()].
    pad_up_to_batches: Set this option to process the entire dataset. When set,
      the dataset is first padded to the specified number of batches. A new
      feature called "mask" is added to every batch. This feature is set to
      `True` for every example that comes from `dataset`, and to `False` for
      every example that is padded to reach the specified number of batches.
      Note that `dataset` must yield at least `pad_up_to_batches` (possibly
      partial) batches. If `None`, it is derived from `batch_dims` and
      `cardinality` as `ceil(cardinality / prod(batch_dims))`. Note that
      `cardinality` is what you pass in, not necessarily the original full
      dataset size if you decide to shard it per host.
    cardinality: Number of examples in the dataset. Only needed when the
      cardinality cannot be retrieved via `ds.cardinality()` (e.g. because of
      using `ds.filter()`).

  Returns:
    The padded dataset, with the added feature "mask" that is set to `True` for
    examples from the original `dataset` and to `False` for padded examples.
  """
    if not isinstance(dataset.element_spec, dict):
        raise ValueError("The dataset must have dictionary elements.")
    if cardinality is None:
        cardinality = dataset.cardinality()
        if cardinality == tf.data.UNKNOWN_CARDINALITY:
            raise ValueError(
                "Cannot determine dataset cardinality. This can happen when you use "
                "a `.filter()` on the dataset. Please provide the cardinality as an "
                "argument to `create_dataset()`.")
    if "mask" in dataset.element_spec:
        raise ValueError("Dataset already contains a feature named \"mask\".")
    if pad_up_to_batches is None:
        pad_up_to_batches = int(np.ceil(cardinality / np.prod(batch_dims)))

    filler_element = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype)[None],
        dataset.element_spec)
    filler_element["mask"] = [False]
    filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)

    dataset = dataset.map(lambda features: dict(mask=True, **features),
                          num_parallel_calls=AUTOTUNE)
    padding = pad_up_to_batches * np.prod(batch_dims) - int(cardinality)
    assert padding >= 0, (
        f"Invalid padding={padding} (batch_dims={batch_dims}, cardinality="
        f"{cardinality}, pad_up_to_batches={pad_up_to_batches})")
    return dataset.concatenate(filler_dataset.repeat(padding))
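A small sketch of the padding behaviour, assuming the `pad_dataset` above (apparently from `clu.deterministic_data`) is in scope: ten examples are padded to three batches of four, and the two filler examples get `mask == False`.

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({"x": tf.range(10)})
padded = pad_dataset(ds, batch_dims=[4], pad_up_to_batches=3)

for batch in padded.batch(4):
    print(batch["x"].numpy(), batch["mask"].numpy())
# The final batch contains two zero-filled examples whose mask is False.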
Example #6
File: repeated.py Project: jackd/kblocks
    def __init__(self,
                 dataset: tf.data.Dataset,
                 steps_per_epoch: Optional[int] = None):
        cardinality = tf.keras.backend.get_value(dataset.cardinality())
        if steps_per_epoch is None:
            steps_per_epoch = cardinality
            if cardinality == tf.data.INFINITE_CARDINALITY:
                raise ValueError(
                    "steps_per_epoch must be provided if dataset has infinite "
                    "cardinality")
            dataset = dataset.repeat()
        elif cardinality != tf.data.INFINITE_CARDINALITY:
            assert cardinality == steps_per_epoch
            dataset = dataset.repeat()
        self._dataset = dataset
        self._steps_per_epoch = steps_per_epoch
Example #7
File: profile.py Project: jackd/kblocks
def profile_model(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    inference_only: bool = False,
    **kwargs,
):
    if dataset.cardinality() != tf.data.INFINITE_CARDINALITY:
        dataset = dataset.repeat()
    it = iter(dataset)
    model_func = (model.make_predict_function()
                  if inference_only else model.make_train_function())

    def func():
        return model_func(it)

    return profile_func(func,
                        **kwargs,
                        name="predict" if inference_only else "train")
Example #8
def benchmark_model(model: tf.keras.Model,
                    dataset: tf.data.Dataset,
                    inference_only=False,
                    **kwargs):
    if dataset.cardinality() != tf.data.INFINITE_CARDINALITY:
        dataset = dataset.repeat()
    inputs, labels, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(
        as_inputs(dataset))
    if inference_only:
        op = model(inputs)
    else:
        variables = model.trainable_variables
        with tf.GradientTape() as tape:
            predictions = model(inputs)
            loss = model.loss(labels, predictions, sample_weight=sample_weight)
        grads = tape.gradient(loss, variables)
        op = model.optimizer.apply_gradients(zip(grads, variables))
    return benchmark_op(op, **kwargs)
Example #9
def tfrecords_cache(
    dataset: tf.data.Dataset,
    cache_dir: str,
    num_parallel_calls: int = 1,
    compression: Optional[str] = None,
    deterministic: Optional[bool] = None,
):
    if tf.executing_eagerly():
        tf.io.gfile.makedirs(cache_dir)

    cardinality = dataset.cardinality()
    path = cache_dir + "/serialized.tfrecords"  # must work in graph mode
    if tf.shape(tf.io.matching_files(path))[0] == 0:
        logging.info(f"Saving tfrecords dataset to {path}")
        save(dataset, path, compression=compression)
    return parse(
        load(path, compression=compression),
        spec=dataset.element_spec,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
    ).apply(tf.data.experimental.assert_cardinality(cardinality))
Example #10
def compute_predictions_jax(
    model: PredictionModel, dataset: tf.data.Dataset, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the given JAX model on the given dataset.

  Note that this also works in multi-host configurations. You have to make
  sure that this function gets called on all hosts. The results are yielded
  only on the host whose jax.host_id() is 0.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
    batch_size: The batch size that should be used.

  Yields:
    The predictions of the model on the dataset.
  """

  def _gather(inputs):
    return jax.lax.all_gather(inputs, "i")
  gather = jax.pmap(_gather, axis_name="i")

  def infer(features):
    probabilities = model(features)
    return_vals = (probabilities, features["metadata"], features["mask"])
    return_vals_reshaped = jax.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]),
        return_vals
    )
    return jax.tree_map(lambda x: x[0], gather(return_vals_reshaped))

  if dataset.cardinality() < 0:
    raise ValueError(
        "The cardinality must be known when running JAX multi-host models.")
  total_batches = math.ceil(dataset.cardinality() / batch_size)
  lcm = lambda x, y: (x * y) // math.gcd(x, y)
  # We want each shard (host) to get an equal number of batches.
  total_batches_padded = lcm(jax.host_count(), total_batches)
  logging.info("Total batches %d, rounded up to %d",
               total_batches, total_batches_padded)

  def pad_strings(array):
    if array.dtype != tf.string:
      return array
    array_bytes = tf.strings.unicode_decode(array, "UTF-8")
    # The return type is either Tensor or RaggedTensor.
    try:
      # If it is a RaggedTensor, we need to convert it to a dense Tensor.
      # to_tensor() adds a leading dimension of size 1, which we drop.
      array_bytes = array_bytes.to_tensor()[0]
    except AttributeError:
      pass
    array_size = tf.size(array_bytes)
    with tf.control_dependencies([
        tf.compat.v1.assert_less_equal(array_size, 1024)]):
      packed = tf.pad(array_bytes, [[0, 1024 - array_size]])
    return {"__packed": tf.ensure_shape(packed, [1024])}

  def unpad_strings(array):
    if isinstance(array, dict):
      with_trailing_zeros = bytes(tf.strings.unicode_encode(
          np.asarray(array["__packed"]).reshape(-1), "UTF-8").numpy())
      return with_trailing_zeros.rstrip(b"\x00")
    else:
      return np.asarray(array)

  def pad_strings_in_metadata(features):
    """Only padding of the strings subject to a gather operation."""
    features["metadata"] = tf.nest.map_structure(pad_strings,
                                                 features["metadata"])
    return features

  dataset = clu_dd.pad_dataset(
      dataset.map(pad_strings_in_metadata),
      batch_dims=[batch_size],
      pad_up_to_batches=total_batches_padded,
      cardinality=None,  # It will be inferred from the dataset.
  ).batch(batch_size)

  # The shard for the current host.
  dataset_shard = dataset.shard(jax.host_count(), jax.host_id())
  logging.info("Batches per host: %d", dataset_shard.cardinality())
  for features in dataset_shard.as_numpy_iterator():
    time_start = time.time()
    # There is a bug in XLA: the following fails for int8s.
    features["mask"] = features["mask"].astype(np.int32)

    flatten = lambda array: array.reshape((-1,) + array.shape[2:])
    predictions, metadatas, masks = jax.tree_map(flatten, infer(features))

    time_end = time.time()
    time_delta_per_example = (time_end - time_start) / predictions.shape[0]
    predictions = np.asarray(predictions)  # Materialize.
    if jax.host_id() == 0:
      for i in range(predictions.shape[0]):
        if masks[i]:
          predictions_i = types.ModelPredictions(
              predictions=[predictions[i]], time_in_s=time_delta_per_example)
          metadata_i = _slice_dictionary(metadatas, i)
          is_leaf_fn = lambda x: isinstance(x, dict) and "__packed" in x
          metadata_i_unpadded = jax.tree_map(
              unpad_strings, metadata_i, is_leaf=is_leaf_fn)
          yield predictions_i, metadata_i_unpadded
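Because of the cardinality check above, a dataset whose size cannot be inferred (for example after a `.filter()`) needs an explicit cardinality assertion before being passed in. A hedged sketch, where `raw_dataset`, `num_examples` and `model` are hypothetical placeholders:

import tensorflow as tf

# `raw_dataset` is assumed to yield dicts containing a "metadata" sub-dict.
dataset = raw_dataset.apply(tf.data.experimental.assert_cardinality(num_examples))

for predictions, metadata in compute_predictions_jax(model, dataset, batch_size=8):
    ...  # Only the host with jax.host_id() == 0 receives results.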
Example #11
File: tfds.py Project: deepmind/acme
def _dataset_size_upperbound(dataset: tf.data.Dataset) -> int:
    if dataset.cardinality() != tf.data.experimental.UNKNOWN_CARDINALITY:
        return dataset.cardinality()
    return tf.cast(
        dataset.batch(1000).reduce(0, lambda x, step: x + 1000), tf.int64)
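A quick sketch of the fallback path, assuming the helper above is in scope: after a `.filter()` the cardinality is unknown, so the size is counted in chunks of 1000 and therefore over-approximated.

import tensorflow as tf

ds = tf.data.Dataset.range(2500).filter(lambda x: x % 2 == 0)  # 1250 elements
print(ds.cardinality().numpy())           # -2, i.e. UNKNOWN_CARDINALITY
print(int(_dataset_size_upperbound(ds)))  # 2000: rounded up to two chunks of 1000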