def _cumulative_broadcast_dynamic(event_shape):
    """Broadcast the leading dims of each shape against all earlier ones.

    Each entry of `event_shape` is treated as a rank-1 shape vector. Its
    final element is set aside, the remaining leading elements are
    broadcast cumulatively (entry `i` is broadcast with the running result
    of entries `0..i-1`), and the entry's own final element is appended
    back onto the result.

    Args:
      event_shape: sequence of shape vectors; must contain at least one
        entry, since the first one seeds the running broadcast.

    Returns:
      A list with one shape vector per entry of `event_shape`.
    """
    # Everything except the last element of each shape.
    leading_dims = []
    for part_shape in event_shape:
        leading_dims.append(
            ps.slice(part_shape, begin=[0], size=[ps.size(part_shape) - 1]))

    # Running broadcast: each result also feeds the next step.
    running = leading_dims[0]
    accumulated = [running]
    for dims in leading_dims[1:]:
        running = ps.broadcast_shape(dims, running)
        accumulated.append(running)

    # Re-attach each shape's own trailing element to its broadcast prefix.
    result = []
    for prefix, part_shape in zip(accumulated, event_shape):
        trailing_dim = ps.slice(
            part_shape, begin=[ps.size(part_shape) - 1], size=[1])
        result.append(ps.concat([prefix, trailing_dim], axis=0))
    return result
def _inverse(self, y):
    """Return `y` minus a copy of `y` shifted by one along `self.axis`.

    The shifted copy is built by dropping the last entry of `y` in the
    chosen dimension and padding one entry at the front of that same
    dimension, so the output holds consecutive differences along it
    (with the first entry left unchanged).

    NOTE(review): `ps.rank(y) + self.axis` is used as the position of the
    chosen dimension, which presumes `self.axis` is negative -- confirm
    against wherever `self.axis` is set.

    Args:
      y: tensor to invert.

    Returns:
      `y - shifted_y`, same shape as `y`.
    """
    rank = ps.rank(y)
    axis_index = rank + self.axis

    # Remove the last entry of y in the chosen dimension: the slice starts
    # at the origin and its size is y's shape reduced by one on that axis.
    origin = ps.zeros(rank, dtype=tf.int32)
    trimmed_shape = ps.shape(y) - ps.one_hot(axis_index, rank, dtype=tf.int32)
    trimmed = ps.slice(y, origin, trimmed_shape)

    # Insert zeros at the beginning of the chosen dimension. The inner
    # one-hot yields 0 at the chosen axis and -1 elsewhere; one-hot
    # encoding that with depth 2 gives a [1, 0] row for the chosen axis
    # and [0, 0] rows for every other axis.
    pad_selector = ps.one_hot(axis_index, rank, on_value=0, off_value=-1)
    front_padding = ps.one_hot(pad_selector, 2, dtype=tf.int32)

    shifted_y = ps.pad(trimmed, paddings=front_padding)
    return y - shifted_y
def _event_size(tensor_structure, event_ndims):
    """Count the elements in the event part of every tensor in a structure.

    For each tensor the trailing `event_ndims` dimensions of its shape are
    taken as the event shape; the products of those event shapes are then
    summed over all parts of the structure.

    Args:
      tensor_structure: (possibly nested) structure of tensors.
      event_ndims: structure matching `tensor_structure`, giving the number
        of trailing dimensions that belong to the event for each tensor.

    Returns:
      The summed event sizes across the flattened structure (0 when the
      structure has no parts).
    """
    def _trailing_dims(tensor, ndims):
        # Last `ndims` entries of the tensor's shape.
        full_shape = ps.shape(tensor)
        return ps.slice(full_shape, [ps.rank(tensor) - ndims], [ndims])

    per_part_shapes = nest.map_structure(
        _trailing_dims, tensor_structure, event_ndims)

    total = 0
    for part_shape in nest.flatten(per_part_shapes):
        total = total + ps.reduce_prod(part_shape)
    return total