def _cumulative_broadcast_dynamic(event_shape):
  """Cumulatively broadcasts the leading (all-but-last) dims of `event_shape`.

  Each entry of `event_shape` is a dynamic shape tensor. The i-th output shape
  is the broadcast of the leading dims of shapes 0..i, with shape i's own
  trailing entry re-attached.

  Args:
    event_shape: Sequence of 1-D integer shape tensors.

  Returns:
    List of 1-D integer shape tensors, one per input shape.
  """
  # Strip the trailing (event) entry from every shape.
  leading = [
      ps.slice(s, begin=[0], size=[ps.size(s) - 1]) for s in event_shape]
  # Fold a running broadcast over the leading shapes.
  running = leading[:1]
  for shape in leading[1:]:
    running.append(ps.broadcast_shape(shape, running[-1]))
  # Re-attach each shape's original trailing entry to its broadcast prefix.
  result = []
  for bcast, orig in zip(running, event_shape):
    tail = ps.slice(orig, begin=[ps.size(orig) - 1], size=[1])
    result.append(ps.concat([bcast, tail], axis=0))
  return result
# Example No. 2 (score: 0)
    def _inverse(self, y):
        """Inverts a cumulative sum along `self.axis` by differencing.

        Builds `shifted_y`, a copy of `y` shifted forward by one slot along
        the chosen axis (zero-padded at the front), then returns the
        element-wise difference `y - shifted_y`.
        """
        ndims = ps.rank(y)
        # One-hot mask selecting the cumsum axis (self.axis is negative,
        # so ndims + self.axis is the non-negative axis index).
        axis_mask = ps.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
        # Drop the final entry of y along the chosen axis.
        truncated = ps.slice(
            y, ps.zeros(ndims, dtype=tf.int32), ps.shape(y) - axis_mask)
        # Paddings matrix: [1, 0] on the chosen axis, [0, 0] elsewhere,
        # i.e. insert a single zero at the front of that axis.
        front_pad = ps.one_hot(
            ps.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
            2,
            dtype=tf.int32)
        shifted_y = ps.pad(truncated, paddings=front_pad)

        return y - shifted_y
# Example No. 3 (score: 0)
def _event_size(tensor_structure, event_ndims):
  """Returns the number of elements in the event-portion of a structure."""
  def _trailing_dims(t, nd):
    # The last `nd` entries of t's shape — the event dims.
    return ps.slice(ps.shape(t), [ps.rank(t) - nd], [nd])

  event_shapes = nest.map_structure(
      _trailing_dims, tensor_structure, event_ndims)
  total = 0
  for shape in nest.flatten(event_shapes):
    total += ps.reduce_prod(shape)
  return total