Example #1
0
    def _convert_per_replica_tensor(strategy: tf.distribute.Strategy,
                                    *per_replica_tensors) -> "list[tf.Tensor]":
        """
        Concat the tensors distributed over the different GPU replicas.

        Each positional argument is split into its local per-replica
        components with ``strategy.experimental_local_results`` and those
        components are concatenated along axis 0.

        Parameters
        ----------
        strategy: Strategy used to distribute the GPUs.
        per_replica_tensors: Any number of tensors distributed over the GPU
            replicas; each one is concatenated independently of the others.

        Returns
        -------
        A list holding one concatenated tensor per positional argument, in
        the same order as ``per_replica_tensors`` (an empty list when no
        tensors are given).

        """
        # One output per input. The pieces of a single input are joined along
        # the leading axis; different inputs are never mixed together.
        return [
            tf.concat(strategy.experimental_local_results(per_replica),
                      axis=0)
            for per_replica in per_replica_tensors
        ]
def materialize(strategy: tf.distribute.Strategy, value_or_nested_dict):
  """Materializes locally (possibly nested dict with) PerReplica values.

  Each leaf value is split into its local per-replica components via
  `strategy.experimental_local_results`. The components are concatenated
  along axis 0, and the result is converted with `.numpy()`.
  NOTE(review): `.numpy()` needs an eager tensor; confirm callers never
  invoke this inside a `tf.function`.

  Args:
    strategy: The strategy that will be used to evaluate.
    value_or_nested_dict: Either a single `PerReplica` object, or a nested dict
      with `PerReplica` values at the deepest level.

  Returns:
    For a non-dict input, the `.numpy()` result (a NumPy array) of
    concatenating its local results along axis 0. For a dict input, a plain
    `dict` with the same keys, whose values are materialized the same way
    recursively.
  """
  if not isinstance(value_or_nested_dict, dict):
    # Leaf: gather the replica-local pieces, stitch them together along the
    # leading axis, then leave TensorFlow by converting to NumPy.
    local_results = strategy.experimental_local_results(value_or_nested_dict)
    return tf.concat(local_results, axis=0).numpy()
  # Recurse so that arbitrarily deep dicts keep their key structure.
  return {
      key: materialize(strategy, value)
      for key, value in value_or_nested_dict.items()
  }