def testEventShape(self, num_or_size_splits, expected_split_sizes):
  """Checks static `forward_event_shape`/`inverse_event_shape` of `Split`.

  Args:
    num_or_size_splits: Raw split specification; it is passed through
      `self.build_input` before being given to `tfb.Split`.
    expected_split_sizes: Python list of the sizes the bijector should produce
      along axis -2.
  """
  num_or_size_splits = self.build_input(num_or_size_splits)
  total_size = np.sum(expected_split_sizes)
  shape_in_static = tf.TensorShape([total_size, 2])
  shape_out_static = [
      tf.TensorShape([d, 2]) for d in expected_split_sizes]
  bijector = tfb.Split(
      num_or_size_splits=num_or_size_splits, axis=-2, validate_args=True)

  # Test that forward_ and inverse_event_shape are correct when
  # event_shape_in/_out are statically known, even when the input shapes
  # are only partially specified.
  self.assertAllEqual(
      bijector.forward_event_shape(shape_in_static), shape_out_static)
  self.assertEqual(
      bijector.inverse_event_shape(shape_out_static), shape_in_static)

  # Shape is always known for splitting in eager mode, so we skip these tests.
  if tf.executing_eagerly():
    return
  # An unknown non-split dimension must stay unknown in every output.
  self.assertAllEqual(
      [s.as_list() for s in bijector.forward_event_shape(
          tf.TensorShape([total_size, None]))],
      [[d, None] for d in expected_split_sizes])

  # Static split sizes as known to the bijector (may contain `None`).
  if bijector.split_sizes is None:
    static_split_sizes = tensorshape_util.constant_value_as_shape(
        expected_split_sizes).as_list()
  else:
    static_split_sizes = tensorshape_util.constant_value_as_shape(
        num_or_size_splits).as_list()

  # The inverse total size is only statically known if every split size is.
  static_total_size = None if None in static_split_sizes else total_size

  # Test correctness with an inverse input dimension of None that coincides
  # with the `-1` element in not-fully specified `split_sizes`
  shape_with_maybe_unknown_dim = (
      [[None, 3]] + [[d, 3] for d in expected_split_sizes[1:]])
  self.assertAllEqual(
      bijector.inverse_event_shape(shape_with_maybe_unknown_dim).as_list(),
      [static_total_size, 3])

  # Test correctness with an input dimension of None that does not coincide
  # with a `-1` split_size.
  # NOTE(review): indexing element 2 assumes at least three splits in every
  # parameterization of this test.
  shape_with_deducable_dim = [[d, 3] for d in expected_split_sizes]
  shape_with_deducable_dim[2] = [None, 3]
  self.assertAllEqual(
      bijector.inverse_event_shape(
          shape_with_deducable_dim).as_list(), [total_size, 3])

  # Test correctness for an input shape of known rank only.
  if bijector.split_sizes is not None:
    shape_with_unknown_total = (
        [[d, None] for d in static_split_sizes])
  else:
    shape_with_unknown_total = [[None, None]] * len(expected_split_sizes)
  self.assertAllEqual(
      [s.as_list() for s in bijector.forward_event_shape(
          tf.TensorShape([None, None]))],
      shape_with_unknown_total)
def _inverse_event_shape_tensor(self, output_shapes): """Shape of a single sample from a single batch as an `int32` 1D `Tensor`. Args: output_shapes: An iterable of `Tensor`, `int32` vectors indicating event-shapes passed into `inverse` function. The length of the iterable must be equal to the number of splits. Returns: inverse_event_shape_tensor: `Tensor`, `int32` vector indicating event-portion shape after applying `inverse`. """ # Validate `output_shapes` statically if possible and get assertions. is_validated = self._validate_output_shapes([ tensorshape_util.constant_value_as_shape(s) for s in output_shapes ]) if is_validated or not self.validate_args: assertions = [] else: assertions = self._validate_output_shape_tensors(output_shapes) with tf.control_dependencies(assertions): total_size = tf.reduce_sum([t[self.axis] for t in output_shapes]) inverse_event_shape = tf.tensor_scatter_nd_update( output_shapes[0], [[prefer_static.rank_from_shape(output_shapes[0]) + self.axis] ], [total_size]) return tf.identity( tf.convert_to_tensor(inverse_event_shape, dtype_hint=tf.int32, name='inverse_event_shape'))
def validate_init_args_statically(distribution, batch_shape):
  """Helper to __init__ which makes or raises assertions.

  Only checks that can be decided from static information are performed;
  anything not known statically is skipped.

  Args:
    distribution: The base distribution; only its static `batch_shape` is
      read.
    batch_shape: `Tensor` holding the requested batch shape.

  Raises:
    ValueError: if `batch_shape` is statically known not to be a vector, if
      its static size differs from that of `distribution.batch_shape`, or if
      any statically known element is less than 1.
  """
  ndims = tensorshape_util.rank(batch_shape.shape)
  if ndims is not None and ndims != 1:
    raise ValueError('`batch_shape` must be a vector '
                     '(saw rank: {}).'.format(ndims))

  static_shape = tensorshape_util.constant_value_as_shape(batch_shape)
  new_size = tensorshape_util.num_elements(static_shape)
  old_size = tensorshape_util.num_elements(distribution.batch_shape)
  sizes_are_known = new_size is not None and old_size is not None
  if sizes_are_known and new_size != old_size:
    raise ValueError('`batch_shape` size ({}) must match '
                     '`distribution.batch_shape` size ({}).'.format(
                         new_size, old_size))

  # Nothing more can be checked when even the rank is unknown.
  if tensorshape_util.dims(static_shape) is None:
    return
  dim_values = (tf.compat.dimension_value(dim) for dim in static_shape)
  # NOTE(review): the message says ">=-1", but any known element below 1
  # (including 0) is rejected here.
  if any(value is not None and value < 1 for value in dim_values):
    raise ValueError('`batch_shape` elements must be >=-1.')
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

  Args:
    original_shape: Integer vector for the shape being reshaped. Only its
      product is used, plus a rank check when `validate` is `True`.
    new_shape: Integer vector for the requested shape; may contain `-1` for
      the dimension to infer.
    validate: Python `bool`; when `True`, runtime assertions are built and
      returned (non-static path only).
    name: Optional name scope for the ops created on the non-static path.

  Returns:
    expanded_new_shape: `new_shape` with the `-1` replaced by the inferred
      size. This is a numpy `int32` array when `new_shape` is statically
      fully defined, and a `Tensor` otherwise.
    batch_shape_static: `TensorShape` of what is statically known about
      `new_shape`.
    validations: List of assertion ops. It is empty when `validate` is
      `False` or the static fast path is taken.
  """
  batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
  if tensorshape_util.is_fully_defined(batch_shape_static):
    # Fast path: return the static shape directly, with no assertions.
    # NOTE(review): on this path nothing checks consistency with
    # `original_shape`.
    return np.int32(batch_shape_static), batch_shape_static, []
  with tf.name_scope(name or 'calculate_reshape'):
    original_size = tf.reduce_prod(original_shape)
    implicit_dim = tf.equal(new_shape, -1)
    # With exactly one `-1`, `-reduce_prod(new_shape)` is the product of the
    # explicit dims, so the quotient is the implicit dim's size. `maximum(1,
    # ...)` keeps the divisor at least 1.
    size_implicit_dim = (original_size //
                         tf.maximum(1, -tf.reduce_prod(new_shape)))
    expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
        implicit_dim, size_implicit_dim, new_shape)
    validations = [] if not validate else [  # pylint: disable=g-long-ternary
        assert_util.assert_rank(
            original_shape, 1, message='Original shape must be a vector.'),
        assert_util.assert_rank(
            new_shape, 1, message='New shape must be a vector.'),
        assert_util.assert_less_equal(
            tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
            1,
            message='At most one dimension can be unknown.'),
        assert_util.assert_positive(
            expanded_new_shape, message='Shape elements must be >=-1.'),
        assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                 original_size,
                                 message='Shape sizes do not match.'),
    ]
    return expanded_new_shape, batch_shape_static, validations
def test_constant_value_as_shape(self):
  """Checks `constant_value_as_shape` on static and variable inputs."""
  # A static numpy vector converts to an equal `TensorShape`.
  x = np.array([1, 2, 3, 4], dtype=np.int32)
  s = tensorshape_util.constant_value_as_shape(x)
  self.assertIsInstance(s, tf.TensorShape)
  self.assertAllEqual(x, s)

  # A variable's value is not treated as constant: the shape is not fully
  # defined.
  x = tf.Variable([3])
  s = tensorshape_util.constant_value_as_shape(x)
  # `s` could be `TensorShape(None)` or `TensorShape([None])`, depending on
  # whether or not we're executing eagerly. We could improve
  # `constant_value_as_shape` to always return `TensorShape([None])`.
  self.assertFalse(s.is_fully_defined())

  # A variable whose own length is unknown yields an unknown-rank shape.
  self.assertEqual(
      tf.TensorShape(None),
      tensorshape_util.constant_value_as_shape(
          tf.Variable([7, 2], shape=tf.TensorShape([None]))))
def _forward_event_shape(self, input_shape):
  """Shape of a single sample from a single batch as a list of `TensorShape`s.

  Same meaning as `forward_event_shape_tensor`. May be only partially defined.

  Args:
    input_shape: `TensorShape` indicating event-portion shape passed into
      `forward` function.

  Returns:
    forward_event_shape: A list of (possibly unknown) `TensorShape`s
      indicating event-portion shape after applying `forward`. The length
      of the list is equal to the number of splits.
  """
  self._validate_input_shape(input_shape)
  if tensorshape_util.rank(input_shape) is None:
    # Unknown input rank: every output shape is unknown as well.
    output_shapes = [None] * self.num_splits
  else:
    input_shape = tf.TensorShape(input_shape).as_list()
    axis = tf.get_static_value(self.axis)
    if self.split_sizes is None:
      # Calculate `split_sizes` from `input_shape` and `num_splits`, if
      # possible.
      split_size = (None if input_shape[axis] is None else
                    input_shape[axis] // self.num_splits)
      split_sizes = [split_size] * self.num_splits
    else:
      static_split_sizes = tf.get_static_value(self.split_sizes)
      if static_split_sizes is None:
        static_split_sizes = [None] * self.num_splits
      # Converts the static sizes to a list of ints / `None`s.
      split_sizes = tensorshape_util.constant_value_as_shape(
          static_split_sizes).as_list()

    # If there is a single unknown element of `split_sizes` and the input
    # dimension is known, set the unknown element equal to the difference
    # between the input dimension and the sum of the known elements of
    # `split_sizes`.
    if sum(s is None for s in split_sizes) == 1:
      # NOTE(review): `input_shape` is a list here (from `.as_list()` above),
      # so `input_shape is not None` is always true.
      if input_shape is not None and input_shape[axis] is not None:
        total_size = input_shape[axis]
        deduced_split_size = (
            total_size - sum(s for s in split_sizes if s is not None))
        split_sizes = [
            deduced_split_size if s is None else s for s in split_sizes
        ]

    # Each output shape copies the input shape, with its split size placed
    # at `axis`.
    output_shapes = []
    for split_size in split_sizes:
      output_shape = input_shape[:]
      output_shape[axis] = split_size
      output_shapes.append(output_shape)

  return [tf.TensorShape(shape) for shape in output_shapes]
def _forward_event_shape_tensor(self, input_shape): """Shape of a sample from a single batch as a list of `int32` 1D `Tensor`s. Args: input_shape: `Tensor`, `int32` vector indicating event-portion shape passed into `forward` function. Returns: forward_event_shape_tensor: A list of `Tensor`, `int32` vectors indicating event-portion shape after applying `forward`. The length of the list is equal to the number of splits. """ # Validate `input_shape` statically if possible and get assertions. is_validated = self._validate_input_shape( tensorshape_util.constant_value_as_shape(input_shape)) if is_validated or not self.validate_args: assertions = [] else: assertions = self._validate_input_shape_tensor(input_shape) with tf.control_dependencies(assertions): if self.split_sizes is None: split_sizes = tf.convert_to_tensor( [input_shape[self.axis] // self.num_splits] * self.num_splits) else: # Deduce the value of the unknown element of `split_sizes`, if any. split_sizes = tf.convert_to_tensor(self.split_sizes) split_sizes = tf.where( split_sizes < 0, input_shape[self.axis] - tf.reduce_sum(split_sizes) - 1, # Cancel the unknown size `-1`. split_sizes) # Each element of the `output_shape_tensor` list is equal to the # `input_shape`, with the corresponding element of `split_sizes` # substituted in the `axis` position. positive_axis = prefer_static.rank_from_shape( input_shape) + self.axis tiled_input_shape = tf.tile(input_shape[tf.newaxis, :], [self.num_splits, 1]) fused_output_shapes = tf.concat([ tiled_input_shape[:, :positive_axis], split_sizes[..., tf.newaxis], tiled_input_shape[:, positive_axis + 1:] ], axis=1) output_shapes = tf.unstack(fused_output_shapes, num=self.num_splits) return [ tf.identity( tf.convert_to_tensor(t, dtype_hint=tf.int32, name='forward_event_shape')) for t in output_shapes ]
def _event_shape(self):
  # Event shape = sample_shape followed by the base distribution's event
  # shape; unknown rank on either side makes the whole result unknown.
  static_sample_shape = tf.get_static_value(self.sample_shape)
  if tensorshape_util.rank(static_sample_shape) != 1:
    # Not a static vector: recover what is statically known from the tensor.
    sample_shape = tensorshape_util.constant_value_as_shape(self.sample_shape)
  else:
    sample_shape = tf.TensorShape(static_sample_shape)

  if tensorshape_util.rank(sample_shape) is None:
    return tf.TensorShape(None)
  base_event_shape = self.distribution.event_shape
  if tensorshape_util.rank(base_event_shape) is None:
    return tf.TensorShape(None)
  return tensorshape_util.concatenate(sample_shape, base_event_shape)
def __init__(self,
             distribution,
             batch_shape,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct BatchReshape distribution.

  Args:
    distribution: The base distribution instance to reshape. Typically an
      instance of `Distribution`.
    batch_shape: Positive `int`-like vector-shaped `Tensor` representing
      the new shape of the batch dimensions. Up to one dimension may contain
      `-1`, meaning the remainder of the batch size.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: The name to give Ops created by the initializer.
      Default value: `"BatchReshape" + distribution.name`.

  Raises:
    ValueError: if `batch_shape` is (statically) not a vector.
    ValueError: if `batch_shape` has (statically known) non-positive
      elements.
    ValueError: if `batch_shape` size is (statically) not the same as a
      `distribution.batch_shape` size.
  """
  # Snapshot the constructor arguments; must run before any other local
  # variable is bound.
  parameters = dict(locals())
  name = name or 'BatchReshape' + distribution.name
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([batch_shape], dtype_hint=tf.int32)
    # The unexpanded batch shape may contain up to one dimension of -1.
    self._batch_shape_unexpanded = tensor_util.convert_nonref_to_tensor(
        batch_shape, dtype=dtype, name='batch_shape', as_shape_tensor=True)
    # Raise immediately on anything statically detectable as invalid.
    validate_init_args_statically(distribution, self._batch_shape_unexpanded)
    self._distribution = distribution
    self._batch_shape_static = tensorshape_util.constant_value_as_shape(
        self._batch_shape_unexpanded)
    super(BatchReshape, self).__init__(
        dtype=distribution.dtype,
        reparameterization_type=distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def __init__(self,
             dimension,
             batch_shape=tuple(),
             dtype=tf.float32,
             validate_args=False,
             allow_nan_stats=True,
             name='SphericalUniform'):
  """Creates a new `SphericalUniform` instance.

  Args:
    dimension: Python `int`. The dimension of the embedded space where the
      sphere resides.
    batch_shape: Positive `int`-like vector-shaped `Tensor` representing
      the new shape of the batch dimensions. Default value: [].
    dtype: DType of the generated samples.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `dimension` is negative.
  """
  # Snapshot the constructor arguments before any other local is bound.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if dimension < 0:
      raise ValueError(
          'Cannot sample negative-dimension unit vectors.')
    shape_dtype = dtype_util.common_dtype([batch_shape], dtype_hint=tf.int32)
    self._dimension = dimension
    self._batch_shape_parameter = tensor_util.convert_nonref_to_tensor(
        batch_shape, dtype=shape_dtype, name='batch_shape',
        as_shape_tensor=True)
    # Cache whatever is statically known about the batch shape.
    self._batch_shape_static = tensorshape_util.constant_value_as_shape(
        self._batch_shape_parameter)
    super(SphericalUniform, self).__init__(
        dtype=dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        parameters=parameters,
        name=name)
def _inverse_event_shape(self, output_shapes):
  """Shape of a sample from a single batch as a [nested] `TensorShape`.

  Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

  Args:
    output_shapes: Iterable of `TensorShape`s indicating the event shapes
      passed into `inverse` function. The length of the iterable must be
      equal to the number of splits.

  Returns:
    inverse_event_shape: `TensorShape` indicating event-portion shape after
      applying `inverse`. Possibly unknown.
  """
  self._validate_output_shapes(output_shapes)
  # Normalize each output shape to a list of ints / `None`s. If any has
  # unknown rank, nothing is known about the inverse shape.
  shapes = []
  for s in output_shapes:
    if tensorshape_util.rank(s) is None:
      return tf.TensorShape(None)
    shapes.append(tf.TensorShape(s).as_list())

  axis = tf.get_static_value(self.axis)
  if self.split_sizes is None:
    # Equal-sized splits: any statically known split dimension gives the
    # size of every split. Read the normalized `shapes` so the `None` test
    # does not depend on how `output_shapes` was passed in.
    split_size = None
    for shape in shapes:
      if shape[axis] is not None:
        split_size = shape[axis]
    split_sizes = [split_size] * self.num_splits
  else:
    static_split_sizes = tf.get_static_value(self.split_sizes)
    if static_split_sizes is None:
      static_split_sizes = [None] * self.num_splits
    split_sizes = tensorshape_util.constant_value_as_shape(
        static_split_sizes).as_list()

  # Deduce as much static information about `inverse_event_shape` as possible.
  # If all elements of `split_sizes` are known, the concatenated dimension
  # of `inverse_event_shape` is the sum of `split_sizes`.
  if not any(s is None for s in split_sizes):
    total_size = sum(split_sizes)
  else:
    # If at least one of `split_sizes` and `output_shape[axis]` is known
    # for each split, we can determine `total_size`.
    total_size = 0
    for split, output_shape in zip(split_sizes, shapes):
      # Compare against `None` explicitly. A known zero-sized split must not
      # fall through to a possibly-unknown output dimension.
      size = split if split is not None else output_shape[axis]
      if size is None:
        total_size = None
        break
      total_size += size

  shape = shapes[0]
  shape[axis] = total_size
  return tf.TensorShape(shape)
def test_transform_joint_to_joint(self, split_sizes):
  """Checks a `JointMap` of per-part bijectors over a joint distribution.

  Args:
    split_sizes: A `dict` or sequence of three event sizes. Its structure is
      also used to pack the batch shapes and the bijectors.
  """
  dist_batch_shape = tf.nest.pack_sequence_as(
      split_sizes,
      [tensorshape_util.constant_value_as_shape(s)
       for s in [[2, 3], [2, 1], [1, 3]]])
  bijector_batch_shape = [1, 3]

  # Build a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.random.uniform(
              minval=1., maxval=2.,
              shape=batch_shape + [size], seed=seed())),
      split_sizes, dist_batch_shape)
  if isinstance(split_sizes, dict):
    base_dist = tfd.JointDistributionNamed(component_dists)
  else:
    base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform the distribution by applying a separate bijector to each part.
  bijectors = [tfb.Exp(),
               tfb.Scale(
                   tf.random.uniform(
                       minval=1., maxval=2.,
                       shape=bijector_batch_shape, seed=seed())),
               tfb.Reshape([2, 1])]
  bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
                          validate_args=True)

  # Transform a joint distribution that has different batch shape components
  transformed_dist = tfd.TransformedDistribution(base_dist, bijector)
  # The string representation should mention name, shapes and dtype in order.
  self.assertRegex(
      str(transformed_dist),
      '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

  # Static and dynamic event shapes must both match the bijector's mapping.
  self.assertAllEqualNested(
      transformed_dist.event_shape,
      bijector.forward_event_shape(base_dist.event_shape))
  self.assertAllEqualNested(*self.evaluate((
      transformed_dist.event_shape_tensor(),
      bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

  # Test that the batch shape components of the input are the same as those of
  # the output.
  self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
  self.assertAllEqualNested(
      self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
  self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)
def _batch_shape(self):
  # Static view of the batch-shape override. `_is_maybe_batch_override` means
  # the override is either statically unknown or a non-empty list, so this
  # static view is the best available answer in that case.
  #
  # The bijector only ever changes the event shape, never the batch shape,
  # so (unlike `_event_shape`) it is not consulted here.
  override_static = tensorshape_util.constant_value_as_shape(
      self._override_batch_shape)
  if self._is_maybe_batch_override:
    return override_static
  return self.distribution.batch_shape
def _event_shape(self):
  # Static view of the event-shape override. `_is_maybe_event_override` means
  # the override is either statically unknown or a non-empty list, so this
  # static view is the best available answer in that case.
  override_static = tensorshape_util.constant_value_as_shape(
      self._override_event_shape)
  if self._is_maybe_event_override:
    pre_bijector_event_shape = override_static
  else:
    pre_bijector_event_shape = self.distribution.event_shape
  # The bijector may change the event shape, so it gets the final say.
  return self.bijector.forward_event_shape(pre_bijector_event_shape)
def __init__(self,
             loc,
             presoftplus_scale,
             batch_shape=tuple(),
             dtype=tf.float32,
             validate_args=False,
             allow_nan_stats=True,
             name='Radial'):
  r"""Constructor.

  Args:
    loc: `Tensor` representing the mean of the distribution.
    presoftplus_scale: `Tensor` representing the pre-softplus scale, `\rho`.
    batch_shape: Positive `int`-like vector-shaped `Tensor` (or other
      shape-like value) representing the new shape of the batch dimensions.
      Default value: [].
    dtype: the data type of the distribution.
    validate_args: Python `bool`, default `False`. When `True`, distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Snapshot the constructor arguments before any other local is bound.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    shape_dtype = dtype_util.common_dtype([batch_shape], dtype_hint=tf.int32)
    self._loc = loc
    self._presoftplus_scale = presoftplus_scale
    # Convert `batch_shape` as a shape tensor, the same way other
    # distributions convert their `batch_shape` argument, so shape-like
    # inputs are accepted consistently.
    self._batch_shape_parameter = tensor_util.convert_nonref_to_tensor(
        batch_shape, dtype=shape_dtype, name='batch_shape',
        as_shape_tensor=True)
    self._batch_shape_static = (
        tensorshape_util.constant_value_as_shape(self._batch_shape_parameter))
    super(Radial, self).__init__(
        dtype=dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=(tfp.distributions.FULLY_REPARAMETERIZED),
        parameters=parameters,
        name=name)
def _batch_shape(self):
  # The bijector never alters the batch shape (only the event shape), so the
  # answer comes either from a possible override or from the base
  # distribution. This parallels `_event_shape`.
  if not self._is_maybe_batch_override:
    base_batch_shape = self.distribution.batch_shape
    # As with `batch_shape_tensor`: a structured base batch shape feeding a
    # non-joint transformed distribution is broadcast across its components.
    if not tf.nest.is_nested(base_batch_shape) or self._is_joint:
      return base_batch_shape
    return functools.reduce(tf.broadcast_static_shape,
                            tf.nest.flatten(base_batch_shape))
  # `_is_maybe_batch_override`: the override's constant value is either
  # unknown or a non-empty list, so report what is statically known of it.
  return tensorshape_util.constant_value_as_shape(self._override_batch_shape)
def _event_shape(self):
  # Lift `self._k` to a length-1 vector (presumably it is a scalar — TODO
  # confirm) so it can be read as a static shape.
  k_as_vector = tf.expand_dims(self._k, axis=0)
  return tensorshape_util.constant_value_as_shape(k_as_vector)
def _parameter_control_dependencies(self, is_init):
  """Checks that the component distributions can be concatenated.

  Static checks run only when `is_init` is `True` and raise immediately.
  Runtime checks are returned as ops, and only when `self.validate_args` is
  `True`.

  Args:
    is_init: Python `bool`; whether this is called from the constructor.

  Returns:
    assertions: List of ops to use as control dependencies. Empty when
      `self.validate_args` is `False`.

  Raises:
    ValueError: if the static `axis` is negative, if fully defined event
      shapes differ, or if fully defined batch shapes differ in rank or are
      not broadcast-compatible outside `axis`.
  """
  assertions = []

  if is_init:
    axis_ = tf.get_static_value(self._axis)
    if axis_ is not None and axis_ < 0:
      raise ValueError('Axis should be positive, %d was given' % axis_)
    if axis_ is None and self.validate_args:
      # The axis is only known at runtime. Assert on the tensor itself; the
      # static value is `None` and cannot be asserted on.
      assertions.append(tf.assert_greater_equal(self._axis, 0))

    all_event_shapes = [d.event_shape for d in self._distributions]
    if all(tensorshape_util.is_fully_defined(event_shape)
           for event_shape in all_event_shapes):
      if all_event_shapes[1:] != all_event_shapes[:-1]:
        raise ValueError('Distributions must have the same `event_shape`; '
                         'found: {}'.format(all_event_shapes))

    all_batch_shapes = [d.batch_shape for d in self._distributions]
    # Indexing the Python shape lists requires a concrete axis. When the axis
    # is only known at runtime, skip this static check.
    if axis_ is not None and all(
        tensorshape_util.is_fully_defined(batch_shape)
        for batch_shape in all_batch_shapes):
      # Batch shapes must broadcast everywhere except along the concatenation
      # axis, which is masked out by setting it to 1.
      batch_shape = all_batch_shapes[0].as_list()
      batch_shape[axis_] = 1
      for b in all_batch_shapes[1:]:
        b = b.as_list()
        if len(batch_shape) != len(b):
          raise ValueError('Incompatible batch shape %s with %s' %
                           (batch_shape, b))
        b[axis_] = 1
        # Raises `ValueError` if the shapes are not broadcast-compatible.
        tf.broadcast_static_shape(
            tensorshape_util.constant_value_as_shape(batch_shape),
            tensorshape_util.constant_value_as_shape(b))

  if not self.validate_args:
    return []

  # Validate that event shapes all match.
  all_event_shapes = [d.event_shape for d in self._distributions]
  if not all(tensorshape_util.is_fully_defined(event_shape)
             for event_shape in all_event_shapes):
    all_event_shape_tensors = [d.event_shape_tensor()
                               for d in self._distributions]

    def _get_shapes(static_shape, dynamic_shape):
      # Prefer the static shape whenever it is fully known.
      if tensorshape_util.is_fully_defined(static_shape):
        return static_shape
      else:
        return dynamic_shape

    event_shapes = tf.nest.map_structure(_get_shapes,
                                         all_event_shapes,
                                         all_event_shape_tensors)
    event_shapes = tf.nest.flatten(event_shapes)
    assertions.extend(
        assert_util.assert_equal(
            e1, e2, message='Distributions should have same event shapes.')
        for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

  # Validate that batch shapes are broadcastable and concatenable along
  # the specified axis.
  if not all(tensorshape_util.is_fully_defined(d.batch_shape)
             for d in self._distributions):
    for i, d in enumerate(self._distributions[:-1]):
      assertions.append(tf.assert_equal(
          tf.size(d.batch_shape_tensor()),
          tf.size(self._distributions[i + 1].batch_shape_tensor())))
    batch_shape_tensors = [
        ps.tensor_scatter_nd_update(d.batch_shape_tensor(), updates=1,
                                    indices=[self._axis])
        for d in self._distributions
    ]
    # Fold the axis-masked batch shapes pairwise, starting from the first
    # shape tensor. `broadcast_dynamic_shape` fails at runtime on
    # incompatible shapes.
    assertions.append(
        functools.reduce(tf.broadcast_dynamic_shape, batch_shape_tensors))
  return assertions
def test_transform_joint_to_joint(self, split_sizes):
  """Checks a zip-map of per-part bijectors over a joint distribution.

  Args:
    split_sizes: A `dict` or sequence of three event sizes. Its structure is
      also used to pack the batch shapes and the bijectors.
  """
  dist_batch_shape = tf.nest.pack_sequence_as(
      split_sizes,
      [tensorshape_util.constant_value_as_shape(s)
       for s in [[2, 3], [2, 1], [1, 3]]])
  bijector_batch_shape = [1, 3]

  # Build a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.exp(
              tf.random.normal(batch_shape + [size], seed=seed()))),
      split_sizes, dist_batch_shape)
  if isinstance(split_sizes, dict):
    base_dist = tfd.JointDistributionNamed(component_dists)
  else:
    base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform the distribution by applying a separate bijector to each part.
  bijectors = [tfb.Exp(),
               tfb.Scale(tf.random.normal(bijector_batch_shape, seed=seed())),
               tfb.Reshape([2, 1])]
  bijector = ToyZipMap(tf.nest.pack_sequence_as(split_sizes, bijectors))

  # Overriding the batch or event shape of this joint transform must raise.
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from `unittest` in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Overriding the batch shape'):
    tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])
  with self.assertRaisesRegex(ValueError, 'Overriding the event shape'):
    tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

  # Transform a joint distribution that has different batch shape components
  transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

  self.assertAllEqualNested(
      transformed_dist.event_shape,
      bijector.forward_event_shape(base_dist.event_shape))
  self.assertAllEqualNested(*self.evaluate((
      transformed_dist.event_shape_tensor(),
      bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

  # Test that the batch shape components of the input are the same as those of
  # the output.
  self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
  self.assertAllEqualNested(
      self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
  self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

  # Check transformed `log_prob` against the base distribution.
  sample_shape = [3]
  sample = base_dist.sample(sample_shape, seed=seed())
  x = tf.nest.map_structure(tf.zeros_like, sample)
  y = bijector.forward(x)
  base_logprob = base_dist.log_prob(x)
  event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                      transformed_dist.event_shape)
  ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

  (transformed_logprob,
   base_logprob_plus_ildj,
   log_transformed_prob
  ) = self.evaluate([
      transformed_dist.log_prob(y),
      base_logprob + ildj,
      tf.math.log(transformed_dist.prob(y))
  ])
  self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
  self.assertAllClose(transformed_logprob, log_transformed_prob)

  # Test that `.sample()` works and returns a result of the expected structure
  # and shape.
  y_sampled = transformed_dist.sample(sample_shape, seed=seed())
  self.assertAllEqual(tf.nest.map_structure(lambda y: y.shape, y),
                      tf.nest.map_structure(lambda y: y.shape, y_sampled))
def _batch_shape(self):
  # Static view of the batch shape from `_calculate_batch_shape()`
  # (presumably a shape vector — TODO confirm).
  calculated_batch_shape = self._calculate_batch_shape()
  return tensorshape_util.constant_value_as_shape(calculated_batch_shape)
def test_transform_joint_to_joint(self, split_sizes):
  """Checks `JointMap` and `Restructure` transforms of a joint distribution.

  Args:
    split_sizes: A `dict` or sequence of three event sizes. Its structure is
      also used to pack the batch shapes and the bijectors.
  """
  dist_batch_shape = tf.nest.pack_sequence_as(split_sizes, [
      tensorshape_util.constant_value_as_shape(s)
      for s in [[2, 3], [2, 1], [1, 3]]
  ])
  bijector_batch_shape = [1, 3]

  # Build a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.random.uniform(minval=1., maxval=2.,
                                       shape=batch_shape + [size],
                                       seed=seed())),
      split_sizes, dist_batch_shape)
  if isinstance(split_sizes, dict):
    base_dist = tfd.JointDistributionNamed(component_dists)
  else:
    base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform the distribution by applying a separate bijector to each part.
  bijectors = [
      tfb.Exp(),
      tfb.Scale(
          tf.random.uniform(minval=1., maxval=2.,
                            shape=bijector_batch_shape, seed=seed())),
      tfb.Reshape([2, 1])
  ]
  bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
                          validate_args=True)

  # Transform a joint distribution that has different batch shape components
  transformed_dist = tfd.TransformedDistribution(base_dist, bijector)
  self.assertRegex(
      str(transformed_dist),
      '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

  self.assertAllEqualNested(
      transformed_dist.event_shape,
      bijector.forward_event_shape(base_dist.event_shape))
  self.assertAllEqualNested(
      *self.evaluate((transformed_dist.event_shape_tensor(),
                      bijector.forward_event_shape_tensor(
                          base_dist.event_shape_tensor()))))

  # Test that the batch shape components of the input are the same as those of
  # the output.
  self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
  self.assertAllEqualNested(
      self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
  self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

  # Check transformed `log_prob` against the base distribution.
  sample_shape = [3]
  sample = base_dist.sample(sample_shape, seed=seed())
  x = tf.nest.map_structure(tf.zeros_like, sample)
  y = bijector.forward(x)
  base_logprob = base_dist.log_prob(x)
  event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                      transformed_dist.event_shape)
  ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

  (transformed_logprob, base_logprob_plus_ildj,
   log_transformed_prob) = self.evaluate([
       transformed_dist.log_prob(y), base_logprob + ildj,
       tf.math.log(transformed_dist.prob(y))
   ])
  self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
  self.assertAllClose(transformed_logprob, log_transformed_prob)

  # Test that `.sample()` works and returns a result of the expected structure
  # and shape.
  y_sampled = transformed_dist.sample(sample_shape, seed=seed())
  self.assertAllEqual(
      tf.nest.map_structure(lambda y: y.shape, y),
      tf.nest.map_structure(lambda y: y.shape, y_sampled))

  # Test that a `Restructure` bijector applied to a `JointDistribution` works
  # as expected.
  num_components = len(split_sizes)
  input_keys = (split_sizes.keys() if isinstance(split_sizes, dict)
                else range(num_components))
  output_keys = [str(i) for i in range(num_components)]
  output_structure = dict(zip(output_keys, input_keys))
  restructure = tfb.Restructure(output_structure)
  restructured_dist = tfd.TransformedDistribution(base_dist,
                                                  bijector=restructure,
                                                  validate_args=True)

  # Check that attributes of the restructured distribution have the same
  # nested structure as the `output_structure` of the bijector. Pass a no-op
  # as the `assert_fn` since the contents of the structures are not
  # required to be the same.
  noop_assert_fn = lambda *_: None
  self.assertAllAssertsNested(noop_assert_fn, restructured_dist.event_shape,
                              output_structure)
  self.assertAllAssertsNested(noop_assert_fn, restructured_dist.batch_shape,
                              output_structure)
  self.assertAllAssertsNested(
      noop_assert_fn,
      self.evaluate(restructured_dist.event_shape_tensor()),
      output_structure)
  self.assertAllAssertsNested(
      noop_assert_fn,
      self.evaluate(restructured_dist.batch_shape_tensor()),
      output_structure)
  # The sample must also be compared against `output_structure`; with a
  # single structure the call has nothing to compare against.
  self.assertAllAssertsNested(
      noop_assert_fn,
      self.evaluate(restructured_dist.sample(seed=test_util.test_seed())),
      output_structure)
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  A static (`TensorShape`) answer is attempted first via
  `_replace_event_shape_in_tensorshape`. When that result is fully defined and
  no further runtime validation is needed, it is returned as a constant
  tensor. Otherwise the output shape is assembled with graph ops, optionally
  guarded by a runtime assertion.

  Args:
    input_shape: a rank-1 `Tensor` of integers.
    event_shape_in: the event shape expected to be present in the rightmost
      dims of `input_shape`. Negative entries (e.g. `-1`) are excluded from
      the compatibility check against `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
    output_tensorshape: the statically inferred `TensorShape` counterpart of
      `output_shape`; may be only partially defined, or `TensorShape(None)`.

  Raises:
    ValueError: (raised by `_replace_event_shape_in_tensorshape`) if it can be
      determined statically that `input_shape` has fewer entries than
      `event_shape_in`, or that the statically known event dims of
      `input_shape` do not match `event_shape_in`.
  """
  # Try to resolve the output shape (and validate it) statically first.
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  # NOTE: `map` produces a one-shot iterator. It is consumed by exactly one of
  # the two `tf.control_dependencies` calls below, because the first branch
  # returns before the second is reached.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  # Fast path: the static result is complete, and either it was already
  # checked statically or no checking was requested, so emit it as a constant.
  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          output_tensorshape, name='output_shape', dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  # Dynamic path: split `input_shape` into its leading (non-event) entries and
  # its trailing entries, which correspond to `event_shape_in`.
  with tf.control_dependencies(validation_dependencies):
    # Use the static length of `event_shape_in` when known; otherwise fall
    # back to computing it in-graph with `tf.size`.
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass  # Already checked statically by `_replace_event_shape_in_tensorshape`.
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  # Swap in `event_shape_out`, gated on the runtime assertion (if any).
  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0,
                             name='output_shape')

  return output_shape, output_tensorshape
def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Statically replaces the trailing event dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance whose trailing event dims
      should be swapped out.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (the rightmost dims of) `input_tensorshape`. Must be
      compatible with those dims.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement for `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in` dims
      replaced by `event_shape_out`. May be only partially defined, or
      `TensorShape(None)` when the result rank cannot be determined.
    is_validated: Python `bool`; `True` only if the static compatibility check
      between `input_tensorshape` and `event_shape_in` was actually performed.

  Raises:
    ValueError: if `input_tensorshape` has fewer dims than `event_shape_in`
      has entries.
    ValueError: if the event portion of `input_tensorshape` and
      `event_shape_in` are both statically known and are not compatible.
      "Compatible" means identical on every dim where `event_shape_in` is
      non-negative (i.e. not `-1`).
  """
  in_rank = tensorshape_util.rank(input_tensorshape)
  n_event_dims = tensorshape_util.num_elements(event_shape_in.shape)

  # Without both ranks nothing can be inferred or checked statically.
  if in_rank is None or n_event_dims is None:
    return tf.TensorShape(None), False

  n_leading_dims = in_rank - n_event_dims
  if n_leading_dims < 0:
    raise ValueError(
        'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
            in_rank, n_event_dims))

  leading_part = input_tensorshape[:n_leading_dims]
  event_part = input_tensorshape[n_leading_dims:]

  # A static check is only possible when the event dims of the input and the
  # value of `event_shape_in` are both known. Negative entries in
  # `event_shape_in` act as wildcards and are skipped.
  static_event_shape_in = tf.get_static_value(event_shape_in)
  is_validated = (
      tensorshape_util.is_fully_defined(event_part) and
      static_event_shape_in is not None)
  if is_validated:
    actual_event_dims = np.int32(event_part)
    checked = static_event_shape_in >= 0
    if any(actual_event_dims[checked] != static_event_shape_in[checked]):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in`. '
          '({} vs {}).'.format(actual_event_dims, static_event_shape_in))

  new_event_tensorshape = tensorshape_util.constant_value_as_shape(
      event_shape_out)
  output_tensorshape = (
      tf.TensorShape(None)
      if tensorshape_util.rank(new_event_tensorshape) is None
      else tensorshape_util.concatenate(leading_part, new_event_tensorshape))
  return output_tensorshape, is_validated
def __init__(self,
             image_shape,
             conditional_shape=None,
             num_resnet=5,
             num_hierarchies=3,
             num_filters=160,
             num_logistic_mix=10,
             receptive_field_dims=(3, 3),
             dropout_p=0.5,
             resnet_activation='concat_elu',
             use_weight_norm=True,
             use_data_init=True,
             high=255,
             low=0,
             dtype=tf.float32,
             name='PixelCNN'):
  """Construct Pixel CNN++ distribution.

  Args:
    image_shape: 3D `TensorShape` or tuple for the `[height, width, channels]`
      dimensions of the image. Must be fully defined.
    conditional_shape: `TensorShape` or tuple for the shape of the conditional
      input, or `None` if there is no conditional input. Must be fully
      defined when given.
    num_resnet: `int`, the number of layers (shown in Figure 2 of [2]) within
      each highest-level block of Figure 2 of [1].
    num_hierarchies: `int`, the number of highest-level blocks (separated by
      expansions/contractions of dimensions in Figure 2 of [1].)
    num_filters: `int`, the number of convolutional filters.
    num_logistic_mix: `int`, number of components in the logistic mixture
      distribution.
    receptive_field_dims: `tuple`, height and width in pixels of the receptive
      field of the convolutional layers above and to the left of a given
      pixel. The width (second element of the tuple) should be odd. Figure 1
      (middle) of [2] shows a receptive field of (3, 5) (the row containing
      the current pixel is included in the height). The default of (3, 3) was
      used to produce the results in [1].
    dropout_p: `float`, the dropout probability. Should be between 0 and 1.
    resnet_activation: `string`, the type of activation to use in the resnet
      blocks. May be 'concat_elu', 'elu', or 'relu'.
    use_weight_norm: `bool`, if `True` then use weight normalization (works
      only in Eager mode).
    use_data_init: `bool`, if `True` then use data-dependent initialization
      (has no effect if `use_weight_norm` is `False`).
    high: `int`, the maximum value of the input data (255 for an 8-bit image).
    low: `int`, the minimum value of the input data.
    dtype: Data type of the `Distribution`.
    name: `string`, the name of the `Distribution`.

  Raises:
    ValueError: if `image_shape` is not fully defined or is not of rank 3, or
      if `conditional_shape` is given but is not fully defined.
  """
  # Must stay the first statement so only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    super(PixelCNN, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=False,
        allow_nan_stats=True,
        parameters=parameters,
        name=name)

    if not tensorshape_util.is_fully_defined(image_shape):
      raise ValueError('`image_shape` must be fully defined.')

    if (conditional_shape is not None and
        not tensorshape_util.is_fully_defined(conditional_shape)):
      raise ValueError('`conditional_shape` must be fully defined.')

    if tensorshape_util.rank(image_shape) != 3:
      raise ValueError(
          '`image_shape` must have length 3, representing '
          '[height, width, channels] dimensions.')

    # Bounds of the data range, cast to the distribution dtype.
    self._high = tf.cast(high, self.dtype)
    self._low = tf.cast(low, self.dtype)
    self._num_logistic_mix = num_logistic_mix

    self.network = _PixelCNNNetwork(
        dropout_p=dropout_p,
        num_resnet=num_resnet,
        num_hierarchies=num_hierarchies,
        num_filters=num_filters,
        num_logistic_mix=num_logistic_mix,
        receptive_field_dims=receptive_field_dims,
        resnet_activation=resnet_activation,
        use_weight_norm=use_weight_norm,
        use_data_init=use_data_init,
        dtype=dtype)

    # Normalize user-supplied shapes (tuples or `TensorShape`s) to
    # `TensorShape`s.
    image_shape = tensorshape_util.constant_value_as_shape(image_shape)
    conditional_shape = (
        None if conditional_shape is None
        else tensorshape_util.constant_value_as_shape(conditional_shape))

    # The network is built with a leading unknown (`None`) dimension prepended
    # to each input shape (presumably the batch dimension).
    image_input_shape = tensorshape_util.concatenate([None], image_shape)
    if conditional_shape is None:
      input_shape = image_input_shape
    else:
      conditional_input_shape = tensorshape_util.concatenate(
          [None], conditional_shape)
      input_shape = [image_input_shape, conditional_input_shape]

    self.image_shape = image_shape
    self.conditional_shape = conditional_shape
    self.network.build(input_shape)