Exemple #1
0

def random_uniform(minval, maxval, dims, name=None):
    """Samples uniform random values, taking the output dtype from `minval`.

    `minval` is converted to a tensor first so that its dtype can be passed
    explicitly as the `dtype` of the underlying `random_ops.random_uniform`
    call.

    Args:
      minval: Lower bound; anything accepted by `ops.convert_to_tensor`.
      maxval: Upper bound, forwarded unchanged.
      dims: Output shape, forwarded as the first positional argument.
      name: Optional name forwarded to the underlying op.

    Returns:
      Whatever `random_ops.random_uniform` returns for these arguments.
    """
    lower_bound = ops.convert_to_tensor(minval)
    # The converted bound decides the dtype of the sampled values.
    result_dtype = lower_bound.dtype
    return random_ops.random_uniform(
        dims, lower_bound, maxval, dtype=result_dtype, name=name)


# Module-level aliases that expose op wrappers from `gen_xla_ops` under
# shorter names.
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce

# Register the "XlaVariadicReduce" op type as having no gradient.
ops.no_gradient("XlaVariadicReduce")


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
    """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Casts all tensors to the highest-ranked dtype found among them.

  A dtype's rank is its position in `dtype_hierarchy`: a dtype that appears
  later in the hierarchy outranks one that appears earlier.

  Args:
    tensors: Objects with a `dtype` attribute; every dtype must be listed in
      `dtype_hierarchy`.
    dtype_hierarchy: Sequence of dtypes ordered from lowest to highest rank.

  Returns:
    A list holding each element of `tensors` cast to the selected dtype.
  """
  # Internal invariant: callers only pass dtypes known to the hierarchy.
  assert not any(tensor.dtype not in dtype_hierarchy for tensor in tensors)
  candidate_dtypes = [tensor.dtype for tensor in tensors]
  widest_dtype = max(candidate_dtypes, key=dtype_hierarchy.index)
  return [math_ops.cast(tensor, widest_dtype) for tensor in tensors]


# Register the 'RaggedRange' op type as having no gradient.
ops.no_gradient('RaggedRange')

#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`