Exemplo n.º 1
0
  def get_padded_one_shot_dataset(self, *, batch_shape,
                                  split, shard_id, num_shards):
    """Builds a single-pass, unshuffled, sharded dataset padded to full batches.

    Nothing is dropped: the split is padded (via
    `deterministic_data.pad_dataset`) before sharding so that every shard can
    be batched completely. Augmentation is turned off.

    Args:
      batch_shape: leading shape of each emitted batch.
      split: dataset split to load; its size is looked up in a
        {'train', 'eval'} table, so other values raise KeyError.
      shard_id: index of the shard to keep (e.g. process_index).
      num_shards: total number of shards (e.g. process_count).

    Returns:
      The prefetched dataset of batches belonging to `shard_id`.
    """
    examples = self._load_tfds(split=split, shuffle_seed=None)
    examples = examples.map(
        functools.partial(self._preprocess, split=split, augment=False),
        num_parallel_calls=tf.data.AUTOTUNE)

    # Pad using the *global* batch dims (all shards together) so that, after
    # round-robin sharding, each shard holds only whole batches. Presumably
    # pad_dataset rounds the element count up to a multiple of the product of
    # these dims; see deterministic_data.pad_dataset.
    global_batch_dims = (num_shards, *batch_shape)
    split_sizes = {'train': self.num_train, 'eval': self.num_eval}
    examples = deterministic_data.pad_dataset(
        examples,
        batch_dims=global_batch_dims,
        cardinality=split_sizes[split])

    examples = examples.shard(index=shard_id, num_shards=num_shards)
    batched = batch_dataset(examples, batch_shape=batch_shape)
    return batched.prefetch(tf.data.AUTOTUNE)
 def test_pad_dataset(self):
   """pad_dataset appends zero-valued rows with mask=False after real rows."""
   num_real, num_fill, batch_size = 12, 8, 20
   source = tf.data.Dataset.from_tensor_slices(
       dict(x=tf.ones((num_real, 10)), y=tf.ones(num_real)))
   padded = deterministic_data.pad_dataset(
       source, batch_dims=[batch_size], pad_up_to_batches=2,
       cardinality=num_real)

   # First batch: the 12 real rows (all ones) followed by 8 zero fill rows.
   want_x = tf.concat(
       [tf.ones((num_real, 10)), tf.zeros((num_fill, 10))], axis=0)
   want_y = tf.concat([tf.ones(num_real), tf.zeros(num_fill)], axis=0)
   want_mask = tf.concat(
       [tf.ones(num_real, bool), tf.zeros(num_fill, bool)], axis=0)
   self.assertAllClose(
       dict(x=want_x, y=want_y, mask=want_mask),
       next(iter(padded.batch(batch_size))))
  def test_pad_nested_dataset(self):
    """pad_dataset pads every leaf of a nested structure and adds a mask."""
    num_real, num_fill = 12, 8

    def ones_then_zeros(*dims):
      # Expected leaf: 12 real rows of ones followed by 8 zero fill rows.
      real_rows = tf.ones((num_real,) + dims)
      fill_rows = tf.zeros((num_fill,) + dims)
      return tf.concat([real_rows, fill_rows], axis=0)

    source = tf.data.Dataset.from_tensor_slices({
        "x": {"z": (tf.ones((num_real, 10)), tf.ones(num_real))},
        "y": tf.ones((num_real, 4)),
    })
    padded = deterministic_data.pad_dataset(
        source, batch_dims=[20], pad_up_to_batches=2, cardinality=num_real)

    want_mask = tf.concat(
        [tf.ones(num_real, bool), tf.zeros(num_fill, bool)], axis=0)
    want = {
        "x": {"z": (ones_then_zeros(10), ones_then_zeros())},
        "y": ones_then_zeros(4),
        "mask": want_mask,
    }
    self.assertAllClose(want, next(iter(padded.batch(20))))
Exemplo n.º 4
0
def compute_predictions_jax(
    model: PredictionModel, dataset: tf.data.Dataset, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the given JAX model on the given dataset.

  Note that this also works in multi-host configurations. You have to make
  sure that this function gets called on all hosts. The results will be yielded
  only to the host with a jax.process_index() equal to 0.

  The dataset is padded with masked-out examples so that every host processes
  the same number of batches; padding examples are never yielded.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
      Its cardinality must be known. Each element must be a dict with a
      "metadata" entry; string leaves of the metadata are packed into 1024
      Unicode code points for the cross-host gather, so longer strings fail.
    batch_size: The batch size that should be used.

  Yields:
    (predictions, metadata) pairs, one per non-padding example, on process 0
    only.

  Raises:
    ValueError: If the cardinality of `dataset` is unknown or infinite.
  """
  num_hosts = jax.process_count()
  host_id = jax.process_index()

  def _gather(inputs):
    return jax.lax.all_gather(inputs, "i")
  gather = jax.pmap(_gather, axis_name="i")

  def infer(features):
    probabilities = model(features)
    return_vals = (probabilities, features["metadata"], features["mask"])
    # Split the batch axis across the local devices so the pmapped all_gather
    # can collect the examples of every host.
    num_local_devices = jax.local_device_count()
    return_vals_reshaped = jax.tree_util.tree_map(
        lambda x: x.reshape((num_local_devices, -1) + x.shape[1:]),
        return_vals)
    # Every local device holds the same gathered result; keep the first copy.
    return jax.tree_util.tree_map(lambda x: x[0], gather(return_vals_reshaped))

  cardinality = int(dataset.cardinality())
  if cardinality < 0:  # UNKNOWN or INFINITE cardinality.
    raise ValueError(
        "The cardinality must be known when running JAX multi-host models.")
  total_batches = -(-cardinality // batch_size)  # Ceiling division.
  # We want each shard (host) to get an equal number of batches, so round up
  # to the next multiple of the host count. (The least common multiple is also
  # a multiple, but can be up to num_hosts times larger, wasting compute on
  # padding batches.)
  total_batches_padded = -(-total_batches // num_hosts) * num_hosts
  logging.info("Total batches %d, rounded up to %d",
               total_batches, total_batches_padded)

  max_string_len = 1024  # Code points per packed string.

  def pad_strings(array):
    """Packs a string tensor into a fixed-size tensor of Unicode code points.

    Non-string tensors are returned unchanged. The packed form is a dict so
    that `unpad_strings` can recognize it after the gather.
    """
    if array.dtype != tf.string:
      return array
    code_points = tf.strings.unicode_decode(array, "UTF-8")
    # The return type is either Tensor or RaggedTensor.
    if isinstance(code_points, tf.RaggedTensor):
      # A RaggedTensor must be converted to a dense Tensor. to_tensor() adds a
      # leading dimension of size 1, which we drop.
      # NOTE(review): for a string tensor with more than one element, [0]
      # keeps only the first string -- confirm metadata strings are scalars
      # or shape [1].
      code_points = code_points.to_tensor()[0]
    num_code_points = tf.size(code_points)
    with tf.control_dependencies([
        tf.compat.v1.assert_less_equal(num_code_points, max_string_len)]):
      packed = tf.pad(code_points, [[0, max_string_len - num_code_points]])
    return {"__packed": tf.ensure_shape(packed, [max_string_len])}

  def unpad_strings(array):
    """Inverts `pad_strings` on a gathered host-side value."""
    if isinstance(array, dict):
      with_trailing_zeros = bytes(tf.strings.unicode_encode(
          np.asarray(array["__packed"]).reshape(-1), "UTF-8").numpy())
      # Zero padding encodes to NUL bytes; note this also drops any genuine
      # trailing NULs of the original string.
      return with_trailing_zeros.rstrip(b"\x00")
    else:
      return np.asarray(array)

  def pad_strings_in_metadata(features):
    """Only padding of the strings subject to a gather operation."""
    features["metadata"] = tf.nest.map_structure(pad_strings,
                                                 features["metadata"])
    return features

  def is_packed_string(x):
    return isinstance(x, dict) and "__packed" in x

  def flatten(array):
    # Merge the (device, per-device) leading axes back into one batch axis.
    return array.reshape((-1,) + array.shape[2:])

  dataset = clu_dd.pad_dataset(
      dataset.map(pad_strings_in_metadata),
      batch_dims=[batch_size],
      pad_up_to_batches=total_batches_padded,
      cardinality=None,  # It will be inferred from the dataset.
  ).batch(batch_size)

  # The shard for the current host.
  dataset_shard = dataset.shard(num_hosts, host_id)
  logging.info("Batches per host: %d", dataset_shard.cardinality())
  for features in dataset_shard.as_numpy_iterator():
    time_start = time.perf_counter()
    # There is a bug in XLA, the following fails for int8s.
    features["mask"] = features["mask"].astype(np.int32)

    predictions, metadatas, masks = jax.tree_util.tree_map(
        flatten, infer(features))
    # Materialize before stopping the clock: JAX dispatches asynchronously,
    # so reading the time first would not include the computation.
    predictions = np.asarray(predictions)
    time_delta_per_example = (
        (time.perf_counter() - time_start) / predictions.shape[0])

    if host_id != 0:
      continue
    # One host transfer instead of a device read per element below.
    masks = np.asarray(masks)
    for i in range(predictions.shape[0]):
      if not masks[i]:
        continue  # Padding example.
      predictions_i = types.ModelPredictions(
          predictions=[predictions[i]], time_in_s=time_delta_per_example)
      metadata_i = _slice_dictionary(metadatas, i)
      metadata_i_unpadded = jax.tree_util.tree_map(
          unpad_strings, metadata_i, is_leaf=is_packed_string)
      yield predictions_i, metadata_i_unpadded