def expand_dims_(x):
  """Implementation of `expand_dims`.

  Inserts a size-1 dimension into `x` for every entry of `axis`.
  Non-negative entries insert it before original dimension `axis[i]`.
  Negative entries insert it after original dimension `rank(x) + axis[i]`.

  NOTE(review): `axis` and `name` are not parameters; they are read from an
  enclosing scope that is not visible here (presumably this is a closure).
  Confirm against the caller.

  Args:
    x: Tensor-like to expand.

  Returns:
    `x` reshaped with one extra size-1 dimension per entry of `axis`.
  """
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    new_axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(new_axis)
    # Number of negative entries; after sorting they occupy the first `k` slots.
    k = prefer_static.reduce_sum(
        prefer_static.cast(new_axis < 0, new_axis.dtype))
    # Sort the *raw* values before shifting negatives by `nx`. Shifting first
    # can interleave negatives with non-negatives (e.g. axis=[0, -1] becomes
    # [0, 2]), and the split below would then mis-assign them.
    new_axis = prefer_static.sort(new_axis)
    is_neg_axis = new_axis < 0
    new_axis = prefer_static.where(is_neg_axis, new_axis + nx, new_axis)
    axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
    # A stable argsort places each non-negative axis before the original dim
    # with the same index, and each shifted negative axis after it.
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    # Padded layout matches the concat above:
    # [1] * len(axis_pos) + shape(x) + [1] * len(axis_neg).
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
def _axis_size(x, axis=None):
  """Returns how many elements of `x` lie along `axis`, cast to `x.dtype`.

  Args:
    x: Tensor-like with a `dtype` attribute.
    axis: Optional index or vector of indices into the dimensions of `x`.
      When `None`, every element of `x` is counted.

  Returns:
    The product of the sizes of the selected dimensions (or the total size of
    `x` when `axis` is `None`), cast to `x.dtype`.
  """
  if axis is not None:
    selected_dims = prefer_static.gather(prefer_static.shape(x), axis)
    count = prefer_static.reduce_prod(selected_dims)
  else:
    count = prefer_static.size(x)
  return prefer_static.cast(count, x.dtype)
def expand_dims(x, axis, name=None):
  """Like `tf.expand_dims` but accepts a vector of axes to expand.

  Each entry of `axis` inserts one size-1 dimension.
  Non-negative entries insert it before original dimension `axis[i]` of `x`,
  or at the end when `axis[i] == rank(x)`.
  Negative entries insert it after original dimension `rank(x) + axis[i]`.
  For a single axis this coincides with `tf.expand_dims`.

  Args:
    x: Tensor-like to expand.
    axis: Integer scalar or vector of axes at which to insert dimensions.
    name: Optional name scope. Defaults to 'expand_dims'.

  Returns:
    `x` reshaped with one extra size-1 dimension per entry of `axis`.
  """
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(axis)
    # Number of negative entries; after sorting they occupy the first `k` slots.
    k = prefer_static.reduce_sum(
        prefer_static.cast(axis < 0, axis.dtype))
    # Sort the *raw* values before shifting negatives by `nx`. Shifting first
    # can interleave negatives with non-negatives (e.g. axis=[0, -1] becomes
    # [0, 2]), and the split below would then mis-assign them.
    axis = prefer_static.sort(axis)
    is_neg_axis = axis < 0
    axis = prefer_static.where(is_neg_axis, axis + nx, axis)
    axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
    # A stable argsort places each non-negative axis before the original dim
    # with the same index, and each shifted negative axis after it.
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    # Padded layout matches the concat above:
    # [1] * len(axis_pos) + shape(x) + [1] * len(axis_neg).
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis.

  Args:
    x: Tensor-like to squeeze.
    axis: `None`, or an integer / integer vector (possibly dynamic) of the
      dimensions to drop. Negative values count from the end, as in
      `tf.squeeze`. When `None`, defers to `tf.squeeze(x, axis=None)`.

  Returns:
    `x` reshaped without the dimensions listed in `axis`.
  """
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  rank = ps.rank(x)
  # Map negative axes into [0, rank). Otherwise `setdiff1d` below never matches
  # them and those dimensions are silently kept. Do this before
  # `_make_list_or_1d_tensor`, which may hand back a Python list.
  axis = ps.where(axis < 0, axis + ps.cast(rank, axis.dtype), axis)
  axis = _make_list_or_1d_tensor(axis)  # Ensure at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, rank), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis.

  Args:
    x: Tensor-like to squeeze.
    axis: `None`, or an integer / integer vector (possibly dynamic) of the
      dimensions to drop. Negative values count from the end, as in
      `tf.squeeze`. When `None`, defers to `tf.squeeze(x, axis=None)`.

  Returns:
    `x` reshaped without the dimensions listed in `axis`.
  """
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  rank = ps.rank(x)
  # Map negative axes into [0, rank). Otherwise `setdiff1d` below never matches
  # them and those dimensions are silently kept.
  axis = ps.where(axis < 0, axis + ps.cast(rank, axis.dtype), axis)
  axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, rank), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
def _split_sample(self, x):
  """Splits `x` into one piece per component of `self.distributions`.

  Each component's batch size along `self._axis` determines the size of its
  piece. The split happens along dimension `sample_shape_size + self._axis`
  of `x`, i.e. `self._axis` counted past the leading sample dimensions.

  Args:
    x: Tensor whose trailing dims are the combined batch shape followed by the
      event shape.

  Returns:
    A pair `(sample_shape_size, x_split)`:
    - `sample_shape_size` is `rank(x)` minus the batch rank and the event rank.
    - `x_split` is the list returned by `tf.split`.
  """
  combined_batch_shape = self._calculate_batch_shape()
  # NOTE(review): confirm that `ps.rank(self.event_shape)` yields the number of
  # event dims. If `event_shape` is converted to a shape vector first, its rank
  # would always be 1.
  num_sample_dims = (ps.rank(x)
                     - ps.shape(combined_batch_shape)[0]
                     - ps.rank(self.event_shape))
  # Prefer static Python lists; fall back to the dynamic shape tensor.
  batch_shapes = []
  for dist in self.distributions:
    if tensorshape_util.is_fully_defined(dist.batch_shape):
      batch_shapes.append(dist.batch_shape.as_list())
    else:
      batch_shapes.append(dist.batch_shape_tensor())
  # Row i holds distribution i's batch shape; column `self._axis` holds the
  # per-distribution sizes to split on.
  stacked_shapes = ps.stack(batch_shapes, axis=0)
  split_sizes = ps.gather(stacked_shapes, self._axis, axis=1)
  pieces = tf.split(x, split_sizes, axis=num_sample_dims + self._axis)
  return num_sample_dims, pieces
def _move_dims_to_flat_end(x, axis, x_ndims, right_end=True):
  """Move dims corresponding to `axis` in `x` to the end, then flatten.

  Example: with `x.shape = [a, b, c, d]` and `axis = [1, 3]`:
  - `right_end=True` gives shape `[a, c, b * d]`.
  - `right_end=False` gives shape `[b * d, a, c]`.

  Args:
    x: `Tensor` with shape `[B0,B1,...,Bb]`.
    axis: Python list of indices into dimensions of `x`.
    x_ndims: Python integer holding number of dimensions in `x`.
    right_end: Python bool. Whether to move dims to the right end (else left).

  Returns:
    `Tensor` with value from `x` and dims in `axis` moved to end into one single
      dimension. Returns `x` unchanged when `axis` is empty.
  """
  if not axis:
    return x

  # Dims that are not moved, in ascending order (e.g. [0, 2] above).
  kept_dims = sorted(set(range(x_ndims)).difference(axis))
  moved_dims = list(axis)
  if right_end:
    perm = kept_dims + moved_dims
  else:
    perm = moved_dims + kept_dims
  x_permed = tf.transpose(a=x, perm=perm)

  if not tensorshape_util.is_fully_defined(x.shape):
    # Dynamic shape: gather the kept sizes and let reshape infer the flat dim.
    kept_shape = ps.gather(ps.shape(x), ps.cast(kept_dims, tf.int64))
    pieces = [kept_shape, [-1]] if right_end else [[-1], kept_shape]
    return tf.reshape(x_permed, shape=ps.concat(pieces, axis=0))

  # Static shape: compute the flattened size explicitly.
  static_shape = tensorshape_util.as_list(x.shape)
  kept_shape = [static_shape[d] for d in kept_dims]
  flat_shape = [np.prod([static_shape[d] for d in moved_dims])]
  if right_end:
    new_shape = kept_shape + flat_shape
  else:
    new_shape = flat_shape + kept_shape
  return tf.reshape(x_permed, shape=new_shape)
def count_integers(arr,
                   weights=None,
                   minlength=None,
                   maxlength=None,
                   axis=None,
                   dtype=tf.int32,
                   name=None):
  """Counts the number of occurrences of each value in an integer array `arr`.

  Works like `tf.math.bincount`, but provides an `axis` kwarg that specifies
  dimensions to reduce over. With
    `~axis = [i for i in range(arr.ndim) if i not in axis]`,
  this function returns a `Tensor` of shape `[K] + arr.shape[~axis]`.

  If `minlength` and `maxlength` are not given, `K = tf.reduce_max(arr) + 1`
  if `arr` is non-empty, and 0 otherwise. `minlength` raises `K` to at least
  `minlength` and `maxlength` caps it at `maxlength`, as in `tf.math.bincount`.
  The same `K` is used for every index of `~axis`.
  If `weights` are non-None, then index `i` of the output stores the sum of the
  value in `weights` at each index where the corresponding value in `arr` is
  `i`.

  Args:
    arr: An `int32` `Tensor` of non-negative values.
    weights: If non-None, must be the same shape as arr. For each value in
      `arr`, the bin will be incremented by the corresponding weight instead of
      1.
    minlength: If given, ensures the output has length at least `minlength`,
      padding with zeros at the end if necessary.
    maxlength: If given, skips values in `arr` that are equal or greater than
      `maxlength`, ensuring that the output has length at most `maxlength`.
    axis: A `0-D` or `1-D` `int32` `Tensor` (with static values) designating
      dimensions in `arr` to reduce over.
      `Default value:` `None`, meaning reduce over all dimensions.
    dtype: If `weights` is None, determines the type of the output bins.
    name: A name scope for the associated operations (optional).

  Returns:
    A `Tensor` of bin values with shape `[K] + arr.shape[~axis]` (shape `[K]`
    when reducing over every dimension). Its dtype is that of `weights`, or
    `dtype` when `weights` is None.
  """
  with tf.name_scope(name or 'count_integers'):
    if axis is None:
      return tf.math.bincount(
          arr,
          weights=weights,
          minlength=minlength,
          maxlength=maxlength,
          dtype=dtype)

    arr = tf.convert_to_tensor(arr, dtype=tf.int32, name='arr')
    arr_ndims = _get_static_ndims(arr, expect_static=True)

    axis = _make_static_axis_non_negative_list(axis, arr_ndims)

    # ~axis from docstring. Dims in arr that are not in axis.
    not_axis = sorted(set(range(arr_ndims)).difference(axis))

    # If we're reducing over everything, just use standard bincount.
    if not not_axis:
      return tf.math.bincount(
          arr,
          weights=weights,
          minlength=minlength,
          maxlength=maxlength,
          dtype=dtype)

    # Move dims in ~axis to the left, so we can tf.map_fn bincount over them,
    # producing counts for every index I in ~axis.
    # Thus, flat_arr is not totally flat, it just has the dims in ~axis
    # flattened.
    flat_arr = _move_dims_to_flat_end(arr, not_axis, arr_ndims, right_end=False)

    # tf.map_fn stacks the per-slice counts, so every slice must produce the
    # same number of bins K. Compute K once from the whole of `arr`, honoring
    # `minlength`/`maxlength` the way `tf.math.bincount` does, then pin both
    # limits of each per-slice call to K. Passing the caller's limits straight
    # through would let slice lengths differ, and defaulting maxlength to
    # max(arr) + 1 would override a larger `minlength`.
    # `tf.maximum(..., 0)` gives K = 0 for an empty `arr`, where
    # `tf.reduce_max` returns the lowest int32.
    num_bins = tf.maximum(tf.reduce_max(arr) + 1, 0)
    if minlength is not None:
      num_bins = tf.maximum(
          num_bins,
          tf.convert_to_tensor(minlength, dtype=tf.int32, name='minlength'))
    if maxlength is not None:
      num_bins = tf.minimum(
          num_bins,
          tf.convert_to_tensor(maxlength, dtype=tf.int32, name='maxlength'))

    # tf.map_fn over dim 0.
    if weights is None:

      def one_bincount(arr_slice):
        return tf.math.bincount(
            arr_slice,
            weights=None,
            minlength=num_bins,
            maxlength=num_bins,
            dtype=dtype)

      flat_counts = tf.map_fn(one_bincount, elems=flat_arr,
                              fn_output_signature=dtype)
    else:
      weights = tf.convert_to_tensor(weights, name='weights')
      _get_static_ndims(weights, expect_static=True, expect_ndims=arr_ndims)
      flat_weights = _move_dims_to_flat_end(
          weights, not_axis, arr_ndims, right_end=False)

      def one_bincount(arr_and_weights_slices):
        arr_slice, weights_slice = arr_and_weights_slices
        return tf.math.bincount(
            arr_slice,
            weights=weights_slice,
            minlength=num_bins,
            maxlength=num_bins,
            dtype=dtype)

      flat_counts = tf.map_fn(one_bincount, elems=[flat_arr, flat_weights],
                              fn_output_signature=weights.dtype)

    # flat_counts.shape = [prod(~axis), K], because map_fn stacked on axis 0.
    # bincount needs to have the K bins in axis 0, so transpose...
    flat_counts_t = tf.transpose(a=flat_counts, perm=[1, 0])

    # Throw in this assert, to ensure shape assumptions are correct.
    _get_static_ndims(flat_counts_t, expect_ndims=2, expect_static=True)

    # not_axis_shape = arr.shape[~axis]
    not_axis_shape = ps.gather(ps.shape(arr), indices=not_axis)

    # The first index of flat_counts_t indexes bins 0,..,K-1, the rest are ~axis
    out_shape = ps.concat([[-1], not_axis_shape], axis=0)

    return tf.reshape(flat_counts_t, out_shape)