Example #1
# Shared imports assumed by the examples on this page.
import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static


def expand_dims_(x, axis, name=None):
    """Implementation of `expand_dims`."""
    # In the original source this is an inner function that closes over `axis`
    # and `name` from an enclosing `expand_dims` wrapper (compare Example #3);
    # they are made explicit parameters here so the snippet stands alone.
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        new_axis = tf.convert_to_tensor(axis,
                                        dtype_hint=tf.int32,
                                        name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(new_axis)
        is_neg_axis = new_axis < 0
        # Number of axes given as negative (right-relative) indices.
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, new_axis.dtype))
        new_axis = prefer_static.where(is_neg_axis, new_axis + nx, new_axis)
        new_axis = prefer_static.sort(new_axis)
        axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
        idx = prefer_static.argsort(
            prefer_static.concat(
                [axis_pos, prefer_static.range(nx), axis_neg], axis=0),
            stable=True)
        # Pad the shape with size-1 dims (na - k on the left, k on the right),
        # then permute them into their requested positions.
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
Example #2
def _axis_size(x, axis=None):
    """Get number of elements of `x` in `axis`, as type `x.dtype`."""
    if axis is None:
        return prefer_static.cast(prefer_static.size(x), x.dtype)
    return prefer_static.cast(
        prefer_static.reduce_prod(
            prefer_static.gather(prefer_static.shape(x), axis)), x.dtype)
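A quick usage sketch (not part of the original page; it assumes `tf` and `prefer_static` are imported as in Example #1 and that `_axis_size` above is in scope):

import tensorflow as tf

x = tf.zeros([2, 3, 5])
print(_axis_size(x))               # 30.0 -- total element count, cast to x.dtype
print(_axis_size(x, axis=[1, 2]))  # 15.0 -- product of the sizes of axes 1 and 2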
Example #3
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand."""
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(axis)
        is_neg_axis = axis < 0
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, axis.dtype))
        axis = prefer_static.where(is_neg_axis, axis + nx, axis)
        axis = prefer_static.sort(axis)
        axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
        idx = prefer_static.argsort(
            prefer_static.concat(
                [axis_pos, prefer_static.range(nx), axis_neg], axis=0),
            stable=True)
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
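A brief, hand-traced usage sketch (not from the original page). Note that each entry of the axis vector is interpreted relative to the input dims, so multi-axis results can differ from chaining `tf.expand_dims`:

import tensorflow as tf

x = tf.zeros([2, 3])
print(expand_dims(x, axis=[0]).shape)       # (1, 2, 3) -- same as tf.expand_dims(x, 0)
print(expand_dims(x, axis=[-2, -1]).shape)  # (2, 1, 3, 1) -- a new dim after each input dim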
Example #4
def _squeeze(x, axis):
    """A version of squeeze that works with dynamic axis."""
    x = tf.convert_to_tensor(x, name='x')
    if axis is None:
        return tf.squeeze(x, axis=None)
    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    # Ensure `axis` is at least 1-d (`_make_list_or_1d_tensor` is a helper
    # defined elsewhere in the same module).
    axis = _make_list_or_1d_tensor(axis)
    keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
    return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
Example #5
def _squeeze(x, axis):
    """A version of squeeze that works with dynamic axis."""
    x = tf.convert_to_tensor(x, name='x')
    if axis is None:
        return tf.squeeze(x, axis=None)
    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
    keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
    return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
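A short check that works with either `_squeeze` variant (not from the original page; assumes `tf` and the `ps` alias for `prefer_static` are imported):

import tensorflow as tf

x = tf.zeros([2, 1, 3, 1])
print(_squeeze(x, axis=1).shape)       # (2, 3, 1)
print(_squeeze(x, axis=[1, 3]).shape)  # (2, 3)
print(_squeeze(x, axis=None).shape)    # (2, 3) -- delegates to tf.squeeze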
Example #6
def _split_sample(self, x):
    """Splits `x` into per-distribution parts along the concatenation axis."""
    result_batch_shape = self._calculate_batch_shape()
    # Number of leading sample dims = rank(x) - batch rank - event rank.
    sample_shape_size = (ps.rank(x)
                         - ps.shape(result_batch_shape)[0]
                         - ps.rank(self.event_shape))
    all_batch_shapes = [d.batch_shape.as_list()
                        if tensorshape_util.is_fully_defined(d.batch_shape)
                        else d.batch_shape_tensor()
                        for d in self.distributions]
    original_shapes = ps.stack(all_batch_shapes, axis=0)
    # Sizes of each distribution's batch shape along the concatenation axis.
    all_compose_shapes = ps.gather(original_shapes, self._axis, axis=1)
    x_split = tf.split(x, all_compose_shapes,
                       axis=sample_shape_size + self._axis)
    return sample_shape_size, x_split
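A minimal sketch of the gather/split step above, with hypothetical batch shapes standing in for `self.distributions` (assumes the usual `tf` and `ps` imports):

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

# Two component batch shapes, concatenated along axis 1 (sizes 3 and 4).
original_shapes = ps.stack([[2, 3], [2, 4]], axis=0)
compose_shapes = ps.gather(original_shapes, 1, axis=1)  # column 1: [3, 4]
x = tf.zeros([2, 7])
x_split = tf.split(x, compose_shapes, axis=1)           # shapes (2, 3) and (2, 4)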
Example #7
def _move_dims_to_flat_end(x, axis, x_ndims, right_end=True):
    """Move dims corresponding to `axis` in `x` to the end, then flatten.

  Args:
    x: `Tensor` with shape `[B0,B1,...,Bb]`.
    axis:  Python list of indices into dimensions of `x`.
    x_ndims:  Python integer holding number of dimensions in `x`.
    right_end:  Python bool.  Whether to move dims to the right end (else left).

  Returns:
    `Tensor` with value from `x` and dims in `axis` moved to end into one single
      dimension.
  """

    if not axis:
        return x

    # Suppose x.shape = [a, b, c, d]
    # Suppose axis = [1, 3]

    # other_dims = [0, 2] in example above.
    other_dims = sorted(set(range(x_ndims)).difference(axis))
    # x_permed.shape = [a, c, b, d]
    perm = other_dims + list(axis) if right_end else list(axis) + other_dims
    x_permed = tf.transpose(a=x, perm=perm)

    if tensorshape_util.is_fully_defined(x.shape):
        x_shape = tensorshape_util.as_list(x.shape)
        # other_shape = [a, c], end_shape = [b * d]
        other_shape = [x_shape[i] for i in other_dims]
        end_shape = [np.prod([x_shape[i] for i in axis])]
        full_shape = (other_shape + end_shape
                      if right_end else end_shape + other_shape)
    else:
        other_shape = ps.gather(ps.shape(x), ps.cast(other_dims, tf.int64))
        full_shape = ps.concat(
            [other_shape, [-1]] if right_end else [[-1], other_shape], axis=0)
    return tf.reshape(x_permed, shape=full_shape)
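A quick shape check (not from the original page; assumes `tf`, `np`, `ps`, and `tensorshape_util` are imported so the function above runs):

import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])
y = _move_dims_to_flat_end(x, axis=[1, 3], x_ndims=4, right_end=True)
print(y.shape)  # (2, 4, 15) -- dims 1 and 3 moved right and flattened together
z = _move_dims_to_flat_end(x, axis=[1, 3], x_ndims=4, right_end=False)
print(z.shape)  # (15, 2, 4)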
Example #8
def count_integers(arr,
                   weights=None,
                   minlength=None,
                   maxlength=None,
                   axis=None,
                   dtype=tf.int32,
                   name=None):
    """Counts the number of occurrences of each value in an integer array `arr`.

  Works like `tf.math.bincount`, but provides an `axis` kwarg that specifies
  dimensions to reduce over.  With
    `~axis = [i for i in range(arr.ndim) if i not in axis]`,
  this function returns a `Tensor` of shape `[K] + arr.shape[~axis]`.

  If `minlength` and `maxlength` are not given, `K = tf.reduce_max(arr) + 1`
  if `arr` is non-empty, and 0 otherwise.
  If `weights` are non-None, then index `i` of the output stores the sum of the
  value in `weights` at each index where the corresponding value in `arr` is
  `i`.

  Args:
    arr: An `int32` `Tensor` of non-negative values.
    weights: If non-None, must be the same shape as arr. For each value in
      `arr`, the bin will be incremented by the corresponding weight instead of
      1.
    minlength: If given, ensures the output has length at least `minlength`,
      padding with zeros at the end if necessary.
    maxlength: If given, skips values in `arr` that are equal or greater than
      `maxlength`, ensuring that the output has length at most `maxlength`.
    axis: A `0-D` or `1-D` `int32` `Tensor` (with static values) designating
      dimensions in `arr` to reduce over.
      `Default value:` `None`, meaning reduce over all dimensions.
    dtype: If `weights` is None, determines the type of the output bins.
    name: A name scope for the associated operations (optional).

  Returns:
    A vector with the same dtype as `weights` or the given `dtype`. The bin
    values.
  """
    with tf.name_scope(name or 'count_integers'):
        if axis is None:
            return tf.math.bincount(arr,
                                    weights=weights,
                                    minlength=minlength,
                                    maxlength=maxlength,
                                    dtype=dtype)

        arr = tf.convert_to_tensor(arr, dtype=tf.int32, name='arr')
        arr_ndims = _get_static_ndims(arr, expect_static=True)

        axis = _make_static_axis_non_negative_list(axis, arr_ndims)

        # ~axis from docstring.  Dims in arr that are not in axis.
        not_axis = sorted(set(range(arr_ndims)).difference(axis))

        # If we're reducing over everything, just use standard bincount.
        if not not_axis:
            return tf.math.bincount(arr,
                                    weights=weights,
                                    minlength=minlength,
                                    maxlength=maxlength,
                                    dtype=dtype)

        # Move dims in ~axis to the left so we can tf.map_fn bincount over
        # them, producing counts for every index I in ~axis. Thus, flat_arr is
        # not totally flat; it just has the dims in ~axis flattened.
        flat_arr = _move_dims_to_flat_end(arr,
                                          not_axis,
                                          arr_ndims,
                                          right_end=False)
        minlength = (minlength if minlength is not None
                     else tf.reduce_max(arr) + 1)
        maxlength = (maxlength if maxlength is not None
                     else tf.reduce_max(arr) + 1)

        # tf.map_fn over dim 0.
        if weights is None:

            def one_bincount(arr_slice):
                return tf.math.bincount(arr_slice,
                                        weights=None,
                                        minlength=minlength,
                                        maxlength=maxlength,
                                        dtype=dtype)

            flat_counts = tf.map_fn(one_bincount,
                                    elems=flat_arr,
                                    fn_output_signature=dtype)
        else:
            weights = tf.convert_to_tensor(weights, name='weights')
            _get_static_ndims(weights,
                              expect_static=True,
                              expect_ndims=arr_ndims)
            flat_weights = _move_dims_to_flat_end(weights,
                                                  not_axis,
                                                  arr_ndims,
                                                  right_end=False)

            def one_bincount(arr_and_weights_slices):
                arr_slice, weights_slice = arr_and_weights_slices
                return tf.math.bincount(arr_slice,
                                        weights=weights_slice,
                                        minlength=minlength,
                                        maxlength=maxlength,
                                        dtype=dtype)

            flat_counts = tf.map_fn(one_bincount,
                                    elems=[flat_arr, flat_weights],
                                    fn_output_signature=weights.dtype)

        # flat_counts.shape = [prod(~axis), K], because map_fn stacked on axis 0.
        # bincount needs to have the K bins in axis 0, so transpose...
        flat_counts_t = tf.transpose(a=flat_counts, perm=[1, 0])

        # Throw in this assert, to ensure shape assumptions are correct.
        _get_static_ndims(flat_counts_t, expect_ndims=2, expect_static=True)

        # not_axis_shape = arr.shape[~axis]
        not_axis_shape = ps.gather(ps.shape(arr), indices=not_axis)

        # The first index of flat_counts_t runs over bins 0, ..., K - 1; the
        # second (flattened) index runs over the dims in ~axis.
        out_shape = ps.concat([[-1], not_axis_shape], axis=0)

        return tf.reshape(flat_counts_t, out_shape)
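Finally, a small worked example (not from the original page; it assumes the helpers `_get_static_ndims`, `_make_static_axis_non_negative_list`, and `_move_dims_to_flat_end` are in scope):

import tensorflow as tf

arr = tf.constant([[0, 1, 1],
                   [2, 1, 0]])
print(count_integers(arr))           # [2 3 1] -- plain bincount over everything
# One bincount per column, with bins on axis 0; shape [K, 3], K = max(arr) + 1 = 3.
print(count_integers(arr, axis=[0]))
# [[1 0 1]
#  [0 2 1]
#  [1 0 0]]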