Example #1
def testSerializer(self):
        path = self.create_tempdir().create_file('myfile.tfrecords').full_path
        serializer = rm.metrics.Serializer(path)
        predictions_and_metadata = [
            (types.ModelPredictions(np.array([[0., 1., 2.]], dtype=np.float32)),
             {'element_id': 1, 'label': 1}),
            (types.ModelPredictions(np.array([[3., 4., 5.]], dtype=np.float32)),
             {'element_id': 2, 'label': 2}),
            (types.ModelPredictions(np.array([[6., 7., 8.]], dtype=np.float32)),
             {'element_id': 3, 'label': 3}),
        ]

        serializer.add_predictions(*predictions_and_metadata[0])
        serializer.add_predictions(*predictions_and_metadata[1])
        serializer.flush()
        actual = list(serializer.read_predictions())
        for x, y in zip(predictions_and_metadata[:2], actual):
            self.assertAllEqual(x[0].predictions, y[0].predictions)
            self.assertEqual(x[1], y[1])

        serializer.add_predictions(*predictions_and_metadata[2])
        serializer.flush()
        actual = list(serializer.read_predictions())
        for x, y in zip(predictions_and_metadata, actual):
            self.assertAllEqual(x[0].predictions, y[0].predictions)
            self.assertEqual(x[1], y[1])

Example #2
    def read_predictions(
            self) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
        """Reads path in order to yield each prediction and metadata."""
        def parse(features_serialized):
            features = {
                "predictions": tf.io.FixedLenFeature([], tf.string),
                "metadata": tf.io.FixedLenFeature([], tf.string),
            }
            features = tf.io.parse_single_example(features_serialized,
                                                  features)
            features["predictions"] = tf.io.parse_tensor(
                features["predictions"], tf.float32)
            return features

        path = tf.convert_to_tensor(self._path)
        dataset = tf.data.TFRecordDataset(path).map(parse)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA)
        dataset = dataset.with_options(options)
        for example in dataset:
            prediction = types.ModelPredictions(
                predictions=example["predictions"].numpy())
            metadata = json.loads(example["metadata"].numpy())
            # Special-case lists of size 1: int-casting a Tensor with shape
            # [1] works (this may be the original element), but int-casting a
            # list of size 1 (this may be the deserialized element) does not.
            for key, value in metadata.items():
                if isinstance(value, list) and len(value) == 1:
                    metadata[key] = value[0]
            yield prediction, metadata
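
The write side of the Serializer is not shown above. As a minimal sketch,
assuming predictions are stored via tf.io.serialize_tensor and metadata as a
JSON string (the on-disk format that read_predictions above expects), it could
look like the following; the _writer attribute and the constructor shape are
assumptions, not the library's actual implementation:

import json
import tensorflow as tf

class Serializer:
    """Hypothetical sketch of a writer matching read_predictions above."""

    def __init__(self, path):
        self._path = path
        self._writer = tf.io.TFRecordWriter(path)

    def add_predictions(self, model_predictions, metadata):
        # Store the prediction tensor and the JSON-encoded metadata in a
        # single tf.train.Example record.
        serialized = tf.io.serialize_tensor(
            tf.constant(model_predictions.predictions, dtype=tf.float32))
        example = tf.train.Example(features=tf.train.Features(feature={
            "predictions": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[serialized.numpy()])),
            "metadata": tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[json.dumps(metadata).encode("utf-8")])),
        }))
        self._writer.write(example.SerializeToString())

    def flush(self):
        self._writer.flush()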
Example #3
 def add_predictions(self,
                     model_predictions: types.ModelPredictions) -> None:
     # Wrapping label into size-1 batch.
     if "label" in model_predictions.metadata:
         label = self.parse_label(model_predictions.metadata["label"])
     else:
         label = [model_predictions.metadata["labels_multi_hot"]]
     super().add_predictions(
         types.ModelPredictions(model_predictions.element_id,
                                {"label": label},
                                model_predictions.predictions))
Example #4
 def add_predictions(self, model_predictions: types.ModelPredictions,
                     metadata: types.Features) -> None:
     # Wrapping label into size-1 batch.
     try:
         label = self.parse_label(metadata["label"])
     except KeyError:
         label = [metadata["labels_multi_hot"]]
     model_output = types.ModelPredictions(
         predictions=model_predictions.predictions)
     metadata = dict(metadata)
     metadata["label"] = label
     super().add_predictions(model_output, metadata=metadata)
Example #5
def compute_predictions(
    model: PredictionModel, dataset: tf.data.Dataset,
    strategy: tf.distribute.Strategy, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the model on the given dataset.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
    strategy: The distribution strategy to use when computing.
    batch_size: The batch size that should be used.

  Yields:
    Pairs of model predictions and the corresponding metadata.
  """
  with strategy.scope():
    dataset = dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    dataset = dataset.with_options(options)

  for features in strategy.experimental_distribute_dataset(dataset):
    time_start = time.time()
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      # TODO(josipd): Figure this out better. We can't easily filter,
      #               as they are PerReplica values, not tensors.
      features_model = {"image": features["image"]}
    else:
      features_model = features
    predictions = materialize(strategy,
                              strategy.run(model, args=(features_model,)))
    time_end = time.time()
    time_delta_per_example = (time_end - time_start) / predictions.shape[0]
    metadatas = materialize(strategy, features["metadata"])
    for i in range(predictions.shape[0]):
      model_predictions = types.ModelPredictions(
          predictions=[predictions[i]],
          time_in_s=time_delta_per_example)
      metadata_i = _slice_dictionary(metadatas, i)
      yield model_predictions, metadata_i
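
The helper _slice_dictionary is not shown in these snippets. Consistent with
how it is used above (pulling row i out of every array in a possibly nested
metadata dictionary), a minimal sketch could be:

import tensorflow as tf

def _slice_dictionary(tensor_dict, i):
    # Take the i-th row of every leaf array, preserving the dict structure.
    return tf.nest.map_structure(lambda x: x[i], tensor_dict)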
Example #6
def compute_predictions(
        model: PredictionModel, dataset: tf.data.Dataset,
        strategy: tf.distribute.Strategy) -> Iterator[types.ModelPredictions]:
    """Yield the predictions of the model on the given dataset.

  Note that the dataset is expected to yield batches of tensors.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
    strategy: The distribution strategy to use when computing.

  Yields:
    The predictions of the model on the dataset.
  """

    for features in strategy.experimental_distribute_dataset(dataset):
        # TODO(josipd): Figure out how to pass only tpu-allowed types.
        time_start = time.time()
        predictions = materialize(
            strategy, strategy.run(model,
                                   args=({
                                       "image": features["image"]
                                   }, )))
        time_end = time.time()
        time_delta_per_example = (time_end - time_start) / predictions.shape[0]
        try:
            element_ids = materialize(strategy, features["element_id"])
        except KeyError:
            element_ids = [None] * predictions.shape[0]
        metadatas = materialize(strategy, features["metadata"])
        for i in range(predictions.shape[0]):
            yield types.ModelPredictions(element_id=element_ids[i],
                                         metadata=_slice_dictionary(
                                             metadatas, i),
                                         predictions=[predictions[i]],
                                         time_in_s=time_delta_per_example)
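
The materialize helper is also not shown. Since it is called on the
per-replica values returned by strategy.run and its result is then indexed
like a single batch, one plausible sketch is the following; the concatenation
behavior is an assumption:

import numpy as np
import tensorflow as tf

def materialize(strategy, values):
    # Unwrap PerReplica values into one entry per local replica, then
    # concatenate the replica results along the batch dimension, leaf by leaf.
    def _concat(*per_replica):
        return np.concatenate([np.asarray(v) for v in per_replica], axis=0)
    local_results = strategy.experimental_local_results(values)
    return tf.nest.map_structure(_concat, *local_results)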
Example #7
def add_batch(metric: base.Metric, predictions, **metadata):
    """Add a batch of predictions.

  Example usage:
  ```
  metric = rm.metrics.get("accuracy")()
  rm.metrics.add_batch(metric, [[.6, .4], [.9, .1]], label=[1, 0])
  metric.result()  # Returns {"accuracy": 0.5}.
  ```

  Args:
    metric: The metric where the predictions will be added.
    predictions: A 2d array (list or numpy array), containing one prediction per
      row.
    **metadata:
      The keys and values that will be used to construct the metadata. It can
      be any (arbitrarily) nested dictionary, with 2d arrays (list or numpy
      arrays) leaves, each holding one example per row.
  """
    for i, predictions_i in enumerate(predictions):
        metadata_i = _recursive_map(operator.itemgetter(i), metadata)
        metric.add_predictions(
            types.ModelPredictions(predictions=[predictions_i]), metadata_i)
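
The only piece of add_batch not shown is _recursive_map. Given the call
_recursive_map(operator.itemgetter(i), metadata), a minimal sketch that
applies a function to every leaf of a nested dictionary would be:

def _recursive_map(fn, dictionary):
    # Apply fn to every non-dict leaf, preserving the nesting.
    return {
        key: _recursive_map(fn, value) if isinstance(value, dict) else fn(value)
        for key, value in dictionary.items()
    }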
Example #8
def compute_predictions_jax(
    model: PredictionModel, dataset: tf.data.Dataset, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the given JAX model on the given dataset.

  Note that this also works in multi-host configurations. You have to make
  sure that this function gets called on all hosts. The results will be
  yielded only on the host with jax.host_id() equal to 0.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
    batch_size: The batch size that should be used.

  Yields:
    The predictions of the model on the dataset.
  """

  def _gather(inputs):
    return jax.lax.all_gather(inputs, "i")
  gather = jax.pmap(_gather, axis_name="i")

  def infer(features):
    probabilities = model(features)
    return_vals = (probabilities, features["metadata"], features["mask"])
    return_vals_reshaped = jax.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]),
        return_vals
    )
    return jax.tree_map(lambda x: x[0], gather(return_vals_reshaped))

  if dataset.cardinality() < 0:
    raise ValueError(
        "The cardinality must be known when running JAX multi-host models.")
  total_batches = math.ceil(dataset.cardinality() / batch_size)
  lcm = lambda x, y: (x * y) // math.gcd(x, y)
  # We want each shard (host) to get an equal number of batches.
  total_batches_padded = lcm(jax.host_count(), total_batches)
  logging.info("Total batches %d, rounded up to %d",
               total_batches, total_batches_padded)

  def pad_strings(array):
    if array.dtype != tf.string:
      return array
    array_bytes = tf.strings.unicode_decode(array, "UTF-8")
    # The return type is either Tensor or RaggedTensor.
    try:
      # If this is a RaggedTensor, we need to convert it to a dense Tensor;
      # to_tensor() adds a leading dimension of size 1, which we drop.
      array_bytes = array_bytes.to_tensor()[0]
    except AttributeError:
      pass
    array_size = tf.size(array_bytes)
    with tf.control_dependencies([
        tf.compat.v1.assert_less_equal(array_size, 1024)]):
      packed = tf.pad(array_bytes, [[0, 1024 - array_size]])
    return {"__packed": tf.ensure_shape(packed, [1024])}

  def unpad_strings(array):
    if isinstance(array, dict):
      with_trailing_zeros = bytes(tf.strings.unicode_encode(
          np.asarray(array["__packed"]).reshape(-1), "UTF-8").numpy())
      return with_trailing_zeros.rstrip(b"\x00")
    else:
      return np.asarray(array)

  def pad_strings_in_metadata(features):
    """Only padding of the strings subject to a gather operation."""
    features["metadata"] = tf.nest.map_structure(pad_strings,
                                                 features["metadata"])
    return features

  dataset = clu_dd.pad_dataset(
      dataset.map(pad_strings_in_metadata),
      batch_dims=[batch_size],
      pad_up_to_batches=total_batches_padded,
      cardinality=None,  # It will be inferred from the dataset.
  ).batch(batch_size)

  # The shard for the current host.
  dataset_shard = dataset.shard(jax.host_count(), jax.host_id())
  logging.info("Batches per host: %d", dataset_shard.cardinality())
  for features in dataset_shard.as_numpy_iterator():
    time_start = time.time()
    # There is a bug in XLA: the following fails for int8s.
    features["mask"] = features["mask"].astype(np.int32)

    flatten = lambda array: array.reshape((-1,) + array.shape[2:])
    predictions, metadatas, masks = jax.tree_map(flatten, infer(features))

    time_end = time.time()
    time_delta_per_example = (time_end - time_start) / predictions.shape[0]
    predictions = np.asarray(predictions)  # Materialize.
    if jax.host_id() == 0:
      for i in range(predictions.shape[0]):
        if masks[i]:
          predictions_i = types.ModelPredictions(
              predictions=[predictions[i]], time_in_s=time_delta_per_example)
          metadata_i = _slice_dictionary(metadatas, i)
          is_leaf_fn = lambda x: isinstance(x, dict) and "__packed" in x
          metadata_i_unpadded = jax.tree_map(
              unpad_strings, metadata_i, is_leaf=is_leaf_fn)
          yield predictions_i, metadata_i_unpadded
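
The string padding above exists because all_gather requires identically
shaped tensors on every host, so variable-length strings are decoded into
UTF-8 code points and padded to a fixed [1024] buffer. A standalone round
trip of the same scheme, as a sketch:

import numpy as np
import tensorflow as tf

s = tf.constant("hello")                              # scalar string tensor
codes = tf.strings.unicode_decode(s, "UTF-8")         # [5] int32 code points
padded = tf.pad(codes, [[0, 1024 - tf.size(codes)]])  # fixed shape [1024]
# ... `padded` now has a static shape and can be gathered across hosts ...
restored = bytes(
    tf.strings.unicode_encode(
        np.asarray(padded).reshape(-1), "UTF-8").numpy()).rstrip(b"\x00")
assert restored == b"hello"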