def assert_shapes_unchanged(target_shaped_dict, possibly_bcast_dict):
    """Asserts that each target entry kept its static shape.

    For every key of `target_shaped_dict`, the static shape of its value is
    compared (as a Python list) with the static shape of the value stored
    under the same key in `possibly_bcast_dict`.

    Args:
      target_shaped_dict: Mapping from name to an object exposing `.shape`.
      possibly_bcast_dict: Mapping with (at least) the same keys, whose values
        also expose `.shape`.

    Raises:
      AssertionError: (from `np.testing.assert_array_equal`) if any pair of
        shapes differs.
    """
    for name, expected in six.iteritems(target_shaped_dict):
        expected_dims = tensorshape_util.as_list(expected.shape)
        actual_dims = tensorshape_util.as_list(possibly_bcast_dict[name].shape)
        np.testing.assert_array_equal(expected_dims, actual_dims)
def testShape(self):
    """Checks shape handling of `CorrelationCholesky` in both directions.

    An input of shape `[5, 4, 6]` must map forward to `[5, 4, 4, 4]` and back,
    for concrete tensors, for static event shapes, and for event-shape
    tensors.
    """
    unconstrained_shape = tf.TensorShape([5, 4, 6])
    matrix_shape = tf.TensorShape([5, 4, 4, 4])
    unconstrained_dims = tensorshape_util.as_list(unconstrained_shape)
    matrix_dims = tensorshape_util.as_list(matrix_shape)

    bijector = tfb.CorrelationCholesky(validate_args=True)

    # Concrete tensors: forward, then round-trip through inverse.
    ones = tf.ones(shape=unconstrained_shape, dtype=tf.float32)
    forward_out = bijector.forward(ones)
    self.assertAllEqual(
        tensorshape_util.as_list(forward_out.shape), matrix_dims)
    inverse_out = bijector.inverse(forward_out)
    self.assertAllEqual(
        tensorshape_util.as_list(inverse_out.shape), unconstrained_dims)

    # Static event shapes.
    static_forward = bijector.forward_event_shape(unconstrained_shape)
    self.assertAllEqual(
        tensorshape_util.as_list(static_forward), matrix_dims)
    static_inverse = bijector.inverse_event_shape(matrix_shape)
    self.assertAllEqual(
        tensorshape_util.as_list(static_inverse), unconstrained_dims)

    # Event-shape tensors, evaluated to concrete values.
    dynamic_forward = self.evaluate(
        bijector.forward_event_shape_tensor(unconstrained_dims))
    self.assertAllEqual(dynamic_forward, matrix_dims)
    dynamic_inverse = self.evaluate(
        bijector.inverse_event_shape_tensor(matrix_dims))
    self.assertAllEqual(dynamic_inverse, unconstrained_dims)
def mixtures_same_family(draw,
                         batch_shape=None,
                         event_dim=None,
                         enable_vars=False,
                         depth=None):
    """Strategy for drawing `MixtureSameFamily` distributions.

    The component distribution is drawn from the `distributions` strategy.

    The Categorical mixture distributions are either shared across all batch
    members, or drawn independently for the full batch (as required by
    `MixtureSameFamily`).

    Args:
      draw: Hypothesis MacGuffin.  Supplied by `@hps.composite`.
      batch_shape: An optional `TensorShape`.  The batch shape of the resulting
        `MixtureSameFamily` distribution.  The component distribution will have
        a batch shape of 1 rank higher (for the components being mixed).
        Hypothesis will pick a batch shape if omitted.
      event_dim: Optional Python int giving the size of each of the component
        distribution's parameters' event dimensions.  This is shared across all
        parameters, permitting square event matrices, compatible location and
        scale Tensors, etc.  If omitted, Hypothesis will choose one.
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test.  If `False`, the returned parameters
        are all Tensors, never Variables or DeferredTensor.
      depth: Python `int` giving maximum nesting depth of compound
        Distributions.  Drawn from `depths()` if omitted.

    Returns:
      dists: A strategy for drawing `MixtureSameFamily` distributions with the
        specified `batch_shape` (or an arbitrary one if omitted).

    Raises:
      AssertionError: If the constructed distribution's batch shape differs
        from the requested one.
    """
    if depth is None:
        depth = draw(depths())

    if batch_shape is None:
        # Ensure the components dist has at least one batch dim (a component
        # dim).
        batch_shape = draw(
            tfp_hps.batch_shapes(min_ndims=1, min_lastdimsize=2))
    else:
        # This mixture adds a batch dim to its underlying components dist.
        batch_shape = tensorshape_util.concatenate(
            batch_shape,
            draw(tfp_hps.batch_shapes(
                min_ndims=1, max_ndims=1, min_lastdimsize=2)))

    component_dist = draw(
        distributions(batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      depth=depth - 1))
    logging.info(
        'component distribution: %s; parameters used: %s', component_dist,
        [k for k, v in six.iteritems(component_dist.parameters)
         if v is not None])

    # scalar or same-shaped categorical?
    mixture_batch_shape = draw(
        hps.one_of(hps.just(batch_shape[:-1]),
                   hps.just(tf.TensorShape([]))))
    mixture_dist = draw(
        base_distributions(
            dist_name='Categorical',
            batch_shape=mixture_batch_shape,
            event_dim=tensorshape_util.as_list(batch_shape)[-1],
            enable_vars=enable_vars))
    logging.info(
        'mixture distribution: %s; parameters used: %s', mixture_dist,
        [k for k, v in six.iteritems(mixture_dist.parameters)
         if v is not None])

    result_dist = tfd.MixtureSameFamily(
        components_distribution=component_dist,
        mixture_distribution=mixture_dist,
        validate_args=True)
    if batch_shape[:-1] != result_dist.batch_shape:
        # Name the strategy that actually failed so the report is not
        # misattributed to a different strategy.
        msg = ('MixtureSameFamily strategy generated a bad batch shape '
               'for {}, should have been {}.').format(result_dist,
                                                      batch_shape[:-1])
        raise AssertionError(msg)
    return result_dist
def _sample_n(self, n, seed=None):
    """Draws `n` samples from the mixture.

    Two code paths exist. When `self._use_static_graph` is set, every
    component is sampled `n` times and a one-hot mask built from the
    categorical draws selects the winning component. Otherwise the
    categorical draws are used to partition sample slots by component, each
    component is sampled only as often as it was selected, and the results
    are stitched back into their original positions.

    Args:
      n: Number of samples. In the dynamic path it is converted to a tensor
        and replaced by its Python `int` value when statically known.
      seed: Seed forwarded to the categorical sampler and used to build the
        `SeedStream` that seeds each component.

    Returns:
      A `Tensor` of samples; the inline comments below give the intended
      `[n, B, E]` layout.
    """
    if self._use_static_graph:
        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object, and maintaining
        # random seed management that is consistent with the non-static code
        # path.
        selected = self.cat.sample(n, seed=seed)
        seed_stream = SeedStream(seed, salt='Mixture')
        per_component = [
            self.components[c].sample(n, seed=seed_stream())
            for c in range(self.num_components)
        ]
        event_rank = tensorshape_util.rank(self._static_event_shape)
        stack_axis = -1 - event_rank
        stacked = tf.stack(per_component, axis=stack_axis)  # [n, B, k, E]
        np_dtype = dtype_util.as_numpy_dtype(stacked.dtype)
        one_hot = tf.one_hot(
            indices=selected,  # [n, B]
            depth=self._num_components,  # == k
            on_value=np_dtype(1),
            off_value=np_dtype(0))  # [n, B, k]
        one_hot = distribution_util.pad_mixture_dimensions(
            one_hot, self, self._cat, event_rank)  # [n, B, k, [1]*e]
        return tf.reduce_sum(stacked * one_hot, axis=stack_axis)  # [n, B, E]

    n = tf.convert_to_tensor(n, name='n')
    n_static = tf.get_static_value(n)
    if n_static is not None:
        n = int(n_static)
    selected = self.cat.sample(n, seed=seed)

    # Prefer static shape information; fall back to shape tensors otherwise.
    static_samples_shape = selected.shape
    if not tensorshape_util.is_fully_defined(static_samples_shape):
        samples_shape = tf.shape(selected)
        samples_size = tf.size(selected)
    else:
        samples_shape = tensorshape_util.as_list(static_samples_shape)
        samples_size = tensorshape_util.num_elements(static_samples_shape)

    static_batch_shape = self.batch_shape
    if not tensorshape_util.is_fully_defined(static_batch_shape):
        batch_shape = tf.shape(selected)[1:]
        batch_size = tf.reduce_prod(batch_shape)
    else:
        batch_shape = tensorshape_util.as_list(static_batch_shape)
        batch_size = tensorshape_util.num_elements(static_batch_shape)

    static_event_shape = self.event_shape
    if not tensorshape_util.is_fully_defined(static_event_shape):
        # Filled in from the first component sample inside the loop below.
        event_shape = None
    else:
        event_shape = np.array(
            tensorshape_util.as_list(static_event_shape), dtype=np.int32)

    # Get indices into the raw cat sampling tensor. We will need these to
    # stitch sample values back out after sampling within the component
    # partitions.
    flat_slot_ids = tf.reshape(tf.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use dynamic_stitch later to
    # reconstruct the samples from the known partitions.
    num_components = self.num_components
    slots_by_component = tf.dynamic_partition(
        data=flat_slot_ids,
        partitions=selected,
        num_partitions=num_components)

    # Copy the batch indices n times, as we will need to know these to pull
    # out the appropriate rows within the component partitions.
    tiled_batch_ids = tf.reshape(
        tf.tile(tf.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [1 1] and [0 0 0 0]
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 0, 0, 0
    #   (samples 0, 1, 2, 3).
    batch_ids_by_component = tf.dynamic_partition(
        data=tiled_batch_ids,
        partitions=selected,
        num_partitions=num_components)

    gathered = [None] * num_components
    seed_stream = SeedStream(seed, salt='Mixture')

    for c in range(num_components):
        n_class = tf.size(slots_by_component[c])
        component_draws = self.components[c].sample(
            n_class, seed=seed_stream())

        if event_shape is None:
            batch_ndims = prefer_static.rank_from_shape(batch_shape)
            event_shape = tf.shape(component_draws)[1 + batch_ndims:]

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.
        #
        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from batch_ids_by_component[c]; and shift
        # each element by the sample index. The final lookup can be thought
        # of as a matrix gather along locations (s, b) in component_draws
        # where the n_class rows correspond to samples within this component
        # and the batch_size columns correspond to batch elements within the
        # component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        row_offsets = batch_size * tf.range(n_class)
        lookup = row_offsets + batch_ids_by_component[c]
        flat_shape = tf.concat([[n_class * batch_size], event_shape], 0)
        flat_draws = tf.reshape(component_draws, flat_shape)
        gathered[c] = tf.gather(
            flat_draws, lookup, name='samples_class_c_gather')

    # Stitch back together the samples across the components.
    stitched = tf.dynamic_stitch(indices=slots_by_component, data=gathered)

    # Reshape back to proper sample, batch, and event shape.
    ret = tf.reshape(stitched, tf.concat([samples_shape, event_shape], 0))
    tensorshape_util.set_shape(
        ret,
        tensorshape_util.concatenate(static_samples_shape, self.event_shape))
    return ret
def __init__(self,
             output_shape=(32, 32, 3),
             num_glow_blocks=3,
             num_steps_per_block=32,
             coupling_bijector_fn=None,
             exit_bijector_fn=None,
             grab_after_block=None,
             use_actnorm=True,
             seed=None,
             validate_args=False,
             name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output, of the bijectors forward pass (the image).  Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. This must divide equally into both H and W,
        otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of
        channels (this will employ affine coupling), or a bijector which
        takes in a tensor with event shape `input_shape`, and returns a
        tensor with shape `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details
        can be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will
        take the integer floor of this number multiplied by the remaining
        number of channels. The default is half at each spatial hierarchy.
        Default value: None (this will take out half of the channels after
        each block, and none after the final block).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-
        dependent initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution
        initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.

    Raises:
      ValueError: If `output_shape` is not fully defined or not rank 3; if
        `num_glow_blocks` is not a statically known positive integer; if the
        height or width is not divisible by `2**num_glow_blocks`; if
        `grab_after_block` has the wrong length; or if the resulting channel
        splits leave a block without input or output channels.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
        raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
        raise ValueError('Shape ndims must be 3 for images.  Your shape is '
                         '{}'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
            int(num_glow_blocks_) != num_glow_blocks_ or
            num_glow_blocks_ < 1):
        raise ValueError(
            'Argument `num_glow_blocks` must be a statically known '
            'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
        grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Thing we know must be true: h and w are evenly divisible by 2, n times.
    # Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
        raise ValueError('Width must be divisible by 2 at least n times. '
                         'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
        raise ValueError(
            'Height should be divisible by 2 at least n times.')
    if h // 2**n < 1:
        raise ValueError(
            'num_glow_blocks ({0}) is too large. The image height '
            '({1}) must be divisible by 2 no more than {2} '
            'times.'.format(num_glow_blocks, h,
                            int(np.log(h) / np.log(2.))))
    if w // 2**n < 1:
        # The limit reported here must be derived from the width, not the
        # height.
        raise ValueError(
            'num_glow_blocks ({0}) is too large. The image width '
            '({1}) must be divisible by 2 no more than {2} '
            'times.'.format(num_glow_blocks, w,
                            int(np.log(w) / np.log(2.))))

    # Other things we want to be true:
    # - The number of times we take must be equal to the number of glow
    #   blocks.
    if len(grab_after_block) != num_glow_blocks:
        raise ValueError(
            'Length of grab_after_block ({0}) must match the number '
            'of blocks ({1}).'.format(len(grab_after_block),
                                      num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(
        output_shape, grab_after_block[::-1])

    # Now check on the values of blockwise splits. The offender index is
    # looked up in the list of *conditions*, not in the raw channel counts.
    no_input_channels = [bs[0] < 1 for bs in self._blockwise_splits]
    if any(no_input_channels):
        first_offender = no_input_channels.index(True)
        raise ValueError(
            'At at least one exit, you are taking out all of your '
            'channels, and therefore have no inputs to later blocks.'
            ' Try setting grab_after_block to a lower value at index '
            '{}.'.format(first_offender))

    # Special case: if specifically exiting no channels, then the exit is
    # just an identity bijector, so the output-channel check is skipped.
    exits_no_channels = any(np.isclose(gab, 0) for gab in grab_after_block)
    if not exits_no_channels:
        no_output_channels = [bs[1] < 1 for bs in self._blockwise_splits]
        if any(no_output_channels):
            first_offender = no_output_channels.index(True)
            raise ValueError(
                'At least one of your layers has < 1 output channels. '
                'This means you set grab_after_block too small. '
                'Try setting grab_after_block to a larger value at index '
                '{}.'.format(first_offender))

    # Lets start to build our bijector. We assume that the distribution is 1
    # dimensional. First, lets reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):
        # This is the shape of the current tensor
        current_shape = (h // 2**n * 2**i, w // 2**n * 2**i,
                         c * 4**(i + 1))

        # This is the shape of the input to both the glow block and exit
        # bijector.
        this_nchan = sum(self._blockwise_splits[i][0:2])
        this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

        glow_chain.append(
            invert.Invert(
                ExitBijector(current_shape, self._blockwise_splits[i],
                             exit_bijector_fn)))

        glow_block = GlowBlock(
            input_shape=this_input_shape,
            num_steps=nsteps,
            coupling_bijector_fn=coupling_bijector_fn,
            use_actnorm=use_actnorm,
            seedstream=seedstream)

        if self._blockwise_splits[i][2] == 0:
            # All channels are passed to the RealNVP
            glow_chain.append(glow_block)
        else:
            # Some channels are passed around the block.
            # This is done with the Blockwise bijector.
            glow_chain.append(
                blockwise.Blockwise(
                    [glow_block, identity.Identity()],
                    [
                        sum(self._blockwise_splits[i][0:2]),
                        self._blockwise_splits[i][2]
                    ]))

        # Finally, lets expand the channels into spatial features.
        glow_chain.append(
            Expand(input_shape=[
                h // 2**n * 2**i,
                w // 2**n * 2**i,
                c * 4**n // 4**i,
            ]))

    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
def positive_definite(x):
    """Builds `x @ x^T + 0.1 * I` and passes it through `symmetric`.

    Args:
      x: A tensor with a static shape; the identity is sized by its last
        dimension and batched over all but its last two dimensions.

    Returns:
      The result of `symmetric` applied to the regularized product.
    """
    dims = tensorshape_util.as_list(x.shape)
    gram = tf.matmul(x, x, transpose_b=True)
    jitter = .1 * tf.linalg.eye(dims[-1], batch_shape=dims[:-2])
    return symmetric(gram + jitter)
def _shape(self, x):
    """Returns the shape of `x`, static or evaluated.

    When `self.use_static_shape` is set the static shape is returned as a
    list; otherwise `tf.shape(x)` is evaluated and returned.
    """
    if not self.use_static_shape:
        return self.evaluate(tf.shape(x))
    return tensorshape_util.as_list(x.shape)
def _parameter_control_dependencies(self, is_init):
    """Validates the mixture/components pairing.

    Checks that can be decided statically raise immediately; checks that need
    runtime values are returned as assertion ops, and only when
    `self.validate_args` is set.

    Args:
      is_init: Python `bool`; the dtype check is only performed when this is
        true.

    Returns:
      A list of assertion ops (possibly empty).

    Raises:
      ValueError: If the mixture dtype is not integer (checked only when
        `is_init`), if the mixture has a statically non-scalar event shape,
        if the number of mixture categories statically differs from the
        rightmost components batch dimension, or if the two batch shapes are
        statically incompatible.
    """
    mixture = self.mixture_distribution
    assertions = []

    if is_init and not dtype_util.is_integer(mixture.dtype):
        raise ValueError(
            '`mixture_distribution.dtype` ({}) is not over integers'.format(
                dtype_util.name(mixture.dtype)))

    event_rank = tensorshape_util.rank(self.mixture_distribution.event_shape)
    if event_rank is None:
        # Rank unknown statically: defer to a runtime assertion if requested.
        if self.validate_args:
            assertions.append(
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=('`mixture_distribution` must have scalar '
                             '`event_dim`s')))
    elif event_rank != 0:
        raise ValueError(
            '`mixture_distribution` must have scalar `event_dim`s')

    # pylint: disable=protected-access
    if self.mixture_distribution._logits is None:
        mixture_dist_param = self.mixture_distribution._probs
    else:
        mixture_dist_param = self.mixture_distribution._logits
    # pylint: enable=protected-access

    km = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
    kc = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[-1])
    component_bst = None

    if km is None or kc is None:
        if self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = (
                    self.components_distribution.batch_shape_tensor())
                kc = component_bst[-1]
            assertions.append(
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')))
    elif km != kc:
        raise ValueError(
            '`mixture_distribution` components ({}) does not '
            'equal `components_distribution.batch_shape[-1]` '
            '({})'.format(km, kc))

    mdbs = self.mixture_distribution.batch_shape
    cdbs = tensorshape_util.with_rank_at_least(
        self.components_distribution.batch_shape, 1)[:-1]
    both_static = (tensorshape_util.is_fully_defined(mdbs) and
                   tensorshape_util.is_fully_defined(cdbs))

    if both_static:
        if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
            raise ValueError(
                '`mixture_distribution.batch_shape` (`{}`) is not '
                'compatible with `components_distribution.batch_shape` '
                '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                                tensorshape_util.as_list(cdbs)))
    elif self.validate_args:
        if not tensorshape_util.is_fully_defined(mdbs):
            mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
            mdbs = tf.shape(mixture_dist_param)[:-1]
        if not tensorshape_util.is_fully_defined(cdbs):
            # Reuse the batch shape tensor if the category check built it.
            if component_bst is None:
                component_bst = (
                    self.components_distribution.batch_shape_tensor())
            cdbs = component_bst[:-1]
        assertions.append(
            assert_util.assert_equal(
                distribution_utils.pick_vector(
                    tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
                cdbs,
                message=(
                    '`mixture_distribution.batch_shape` is not '
                    'compatible with `components_distribution.batch_shape`'
                )))

    return assertions
def __init__(self,
             mixture_distribution,
             components_distribution,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
    """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like
        instance. Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to
        reparameterize samples of the distribution using implicit
        reparameterization gradients [(Figurnov et al., 2018)][1]. The
        gradients for the mixture logits are equivalent to the ones described
        by [(Graves, 2016)][2]. The gradients for the components parameters
        are also computed using implicit reparameterization (as opposed to
        ancestral sampling), meaning that all components are updated every
        step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar
          distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.
      ValueError: if `reparameterize` is set but the components distribution
        is not fully reparameterized.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
    # Must stay first: it snapshots exactly the constructor arguments.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
        self._mixture_distribution = mixture_distribution
        self._components_distribution = components_distribution
        self._runtime_assertions = []

        # Event rank and size of the components; use the static rank when
        # it is available.
        event_shape_t = components_distribution.event_shape_tensor()
        static_event_ndims = tf.compat.dimension_value(event_shape_t.shape[0])
        self._event_ndims = static_event_ndims
        if self._event_ndims is None:
            self._event_ndims = tf.size(event_shape_t)
        self._event_size = tf.reduce_prod(event_shape_t)

        if not dtype_util.is_integer(mixture_distribution.dtype):
            raise ValueError(
                "`mixture_distribution.dtype` ({}) is not over integers".
                format(dtype_util.name(mixture_distribution.dtype)))

        mixture_event_rank = tensorshape_util.rank(
            mixture_distribution.event_shape)
        if mixture_event_rank is not None and mixture_event_rank != 0:
            raise ValueError(
                "`mixture_distribution` must have scalar `event_dim`s")
        elif validate_args:
            self._runtime_assertions.append(
                assert_util.assert_equal(
                    tf.size(mixture_distribution.event_shape_tensor()),
                    0,
                    message=("`mixture_distribution` must have scalar "
                             "`event_dim`s")))

        # Batch shapes: the mixture's must be scalar or match the
        # components' batch shape without its rightmost dimension.
        mdbs = mixture_distribution.batch_shape
        cdbs = tensorshape_util.with_rank_at_least(
            components_distribution.batch_shape, 1)[:-1]
        batch_shapes_static = (tensorshape_util.is_fully_defined(mdbs) and
                               tensorshape_util.is_fully_defined(cdbs))
        if batch_shapes_static:
            if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                raise ValueError(
                    "`mixture_distribution.batch_shape` (`{}`) is not "
                    "compatible with `components_distribution.batch_shape` "
                    "(`{}`)".format(tensorshape_util.as_list(mdbs),
                                    tensorshape_util.as_list(cdbs)))
        elif validate_args:
            mdbs = mixture_distribution.batch_shape_tensor()
            cdbs = components_distribution.batch_shape_tensor()[:-1]
            self._runtime_assertions.append(
                assert_util.assert_equal(
                    distribution_utils.pick_vector(
                        mixture_distribution.is_scalar_batch(), cdbs, mdbs),
                    cdbs,
                    message=(
                        "`mixture_distribution.batch_shape` is not "
                        "compatible with "
                        "`components_distribution.batch_shape`")))

        # Number of categories vs. rightmost components batch dimension.
        if mixture_distribution.logits is None:
            mixture_dist_param = mixture_distribution.probs
        else:
            mixture_dist_param = mixture_distribution.logits
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                mixture_dist_param.shape, 1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                components_distribution.batch_shape, 1)[-1])
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                "`mixture_distribution components` ({}) does not "
                "equal `components_distribution.batch_shape[-1]` "
                "({})".format(km, kc))
        elif validate_args:
            km = tf.shape(mixture_dist_param)[-1]
            kc = components_distribution.batch_shape_tensor()[-1]
            self._runtime_assertions.append(
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        "`mixture_distribution components` does not equal "
                        "`components_distribution.batch_shape[-1:]`")))
        elif km is None:
            km = tf.shape(mixture_dist_param)[-1]

        self._num_components = km

        self._reparameterize = reparameterize
        if not reparameterize:
            reparameterization_type = reparameterization.NOT_REPARAMETERIZED
        else:
            # Note: tfd.Independent passes through the reparameterization
            # type hence we do not need separate logic for Independent.
            if (self._components_distribution.reparameterization_type !=
                    reparameterization.FULLY_REPARAMETERIZED):
                raise ValueError("Cannot reparameterize a mixture of "
                                 "non-reparameterized components.")
            reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

        # pylint: disable=protected-access
        graph_parents = (self._mixture_distribution._graph_parents +
                         self._components_distribution._graph_parents)
        # pylint: enable=protected-access
        super(MixtureSameFamily, self).__init__(
            dtype=self._components_distribution.dtype,
            reparameterization_type=reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=graph_parents,
            name=name)
def __init__(self,
             num_or_size_splits,
             axis=-1,
             validate_args=False,
             name='split'):
    """Creates the bijector.

    Args:
      num_or_size_splits: Either a Python integer indicating the number of
        splits along `axis` or a 1-D integer `Tensor` or Python list
        containing the sizes of each output tensor along `axis`. If a
        list/`Tensor`, it may contain at most one value of `-1`, which
        indicates a split size that is unknown and determined from input.
      axis: A negative integer or scalar `int32` `Tensor`. The dimension
        along which to split. Must be negative to enable the bijector to
        support arbitrary batch dimensions. Defaults to -1 (note that this is
        different from the `tf.Split` default of `0`). Must be statically
        known.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      ValueError: If `num_or_size_splits` is neither an integer nor a 1-D
        `Tensor`, if a vector of split sizes has a statically unknown number
        of elements, if `axis` is not statically known, or if `axis` is not
        negative.
    """
    # Must stay first: it snapshots exactly the constructor arguments.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
        if isinstance(num_or_size_splits, numbers.Integral):
            self._num_splits = num_or_size_splits
            self._split_sizes = None
        else:
            self._split_sizes = tensor_util.convert_nonref_to_tensor(
                num_or_size_splits,
                name='num_or_size_splits',
                dtype=tf.int32)

            if tensorshape_util.rank(self._split_sizes.shape) != 1:
                raise ValueError(
                    '`num_or_size_splits` must be an integer or 1-D '
                    '`Tensor`.')

            num_splits = tensorshape_util.as_list(
                self._split_sizes.shape)[0]
            if num_splits is None:
                raise ValueError(
                    'If `num_or_size_splits` is a vector of split sizes '
                    'it must have a statically-known number of '
                    'elements.')
            self._num_splits = num_splits

        static_axis = tf.get_static_value(axis)
        if static_axis is None:
            raise ValueError('`axis` must be statically known.')
        if static_axis >= 0:
            raise ValueError(
                '`axis` must be negative. Got {}'.format(axis))

        self._axis = ps.convert_to_shape_tensor(axis, tf.int32)

        # `axis` may have been supplied as a (statically known) Tensor; the
        # minimum event ndims are derived from its validated static value so
        # that they are always plain Python ints.
        min_event_ndims = -int(static_axis)
        super(Split, self).__init__(
            forward_min_event_ndims=min_event_ndims,
            inverse_min_event_ndims=[min_event_ndims] * self.num_splits,
            is_constant_jacobian=True,
            validate_args=validate_args,
            parameters=parameters,
            name=name)
def broadcasting_params(draw, batch_shape, params_event_ndims, event_dim=None, enable_vars=False, constraint_fn_for=lambda param: identity_fn, mutex_params=(), dtype=np.float32):
    """Strategy for drawing parameters which jointly have the given batch shape.

    Specifically, the batch shapes of the returned parameters will broadcast
    to the requested batch shape.

    The dtypes of the returned parameters are determined by their respective
    constraint functions.

    Args:
      draw: Hypothesis strategy sampler supplied by `@hps.composite`.
      batch_shape: A `TensorShape`.  The returned parameters' batch shapes
        will broadcast to this.
      params_event_ndims: Python `dict` mapping the name of each parameter to
        a Python `int` giving the event ndims for that parameter.
      event_dim: Optional Python int giving the size of each parameter's
        event dimensions (except where overridden by any applicable
        constraint functions).  This is shared across all parameters,
        permitting square event matrices, compatible location and scale
        Tensors, etc.  If omitted, Hypothesis will choose one (an integer
        from 2 to 6).
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test.  If `False`, the returned parameters
        are all `tf.Tensor`s and not {`tf.Variable`,
        `tfp.util.DeferredTensor`, `tfp.util.TransformedVariable`}.
      constraint_fn_for: Python callable mapping parameter name to constraint
        function.  The latter is itself a Python callable which converts an
        unconstrained Tensor (currently with float32 values from -200 to
        +200) into one that meets the parameter's validity constraints.
      mutex_params: Python iterable of Python sets.  Each set gives a clique
        of mutually exclusive parameters (e.g., the 'probs' and 'logits' of a
        Categorical).  At most one parameter from each set will appear in the
        result.
      dtype: Dtype for generated parameters.

    Returns:
      params: A Hypothesis strategy for drawing Python `dict`s mapping
        parameter name to a `tf.Tensor`, `tf.Variable`,
        `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable`.  The
        batch shapes of the returned parameters broadcast together to the
        supplied `batch_shape`.  Only parameters whose names appear as keys
        in `params_event_ndims` will appear (but possibly not all of them,
        depending on `mutex_params`).
    """
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))

    params_event_ndims = params_event_ndims or {}

    # Pick parameters one at a time; choosing one removes every other member
    # of any mutual-exclusion set it belongs to.
    candidates = set(params_event_ndims)
    chosen = []
    while candidates:
        pick = draw(hps.sampled_from(sorted(candidates)))
        chosen.append(pick)
        candidates.remove(pick)
        for clique in mutex_params:
            if pick in clique:
                candidates -= clique

    batch_shapes_by_param = draw(
        broadcasting_named_shapes(batch_shape, chosen))

    params_kwargs = {}
    for name in chosen:
        batch_dims = tensorshape_util.as_list(batch_shapes_by_param[name])
        event_dims = [event_dim] * params_event_ndims[name]
        full_shape = batch_dims + event_dims

        # Reduce our risk of exceeding TF kernel broadcast limits.
        hp.assume(len(full_shape) < 6)

        # TODO(axch): Can I replace `params_event_ndims` and
        # `constraint_fn_for` with a map from params to `Support`s, and use
        # `tensors_in_support` here instead of this explicit
        # `constrained_tensors` function?
        tensor_strategy = constrained_tensors(
            constraint_fn_for(name), full_shape, dtype=dtype)
        params_kwargs[name] = draw(
            maybe_variable(tensor_strategy,
                           enable_vars=enable_vars,
                           dtype=dtype,
                           name=name))
    return params_kwargs
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in, event_shape_out, validate_args):
    """Replaces the rightmost dims in a `Tensor` representing a shape.

    Args:
      input_shape: a rank-1 `Tensor` of integers
      event_shape_in: the event shape expected to be present in rightmost
        dims of `input_shape`.
      event_shape_out: the event shape with which to replace `event_shape_in`
        in the rightmost dims of `input_shape`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      output_shape: A rank-1 integer `Tensor` with the same contents as
        `input_shape` except for the event dims, which are replaced with
        `event_shape_out`.
      output_tensorshape: The static shape computed by
        `_replace_event_shape_in_tensorshape` for the same replacement.
    """
    static_input_shape = tensorshape_util.constant_value_as_shape(input_shape)
    output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
        static_input_shape, event_shape_in, event_shape_out)

    # Fast path: the static result is complete and no further runtime
    # validation is owed.
    needs_runtime_checks = validate_args and not is_validated
    if (tensorshape_util.is_fully_defined(output_tensorshape) and
            not needs_runtime_checks):
        static_output = ps.convert_to_shape_tensor(
            tensorshape_util.as_list(output_tensorshape),
            name='output_shape',
            dtype_hint=tf.int32)
        return static_output, output_tensorshape

    static_event_ndims = tensorshape_util.num_elements(event_shape_in.shape)
    if static_event_ndims is None:
        event_shape_in_ndims = tf.size(event_shape_in)
    else:
        event_shape_in_ndims = static_event_ndims

    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

    additional_assertions = []
    if needs_runtime_checks:
        # Check that `input_event_shape` and `event_shape_in` are compatible
        # in the sense that they have equal entries in any position that
        # isn't a `-1` in `event_shape_in`. Note that our validations at
        # construction time ensure there is at most one such entry in
        # `event_shape_in`.
        explicit_dims = event_shape_in >= 0
        explicit_input_event_shape = tf.boolean_mask(
            input_event_shape, mask=explicit_dims)
        explicit_event_shape_in = tf.boolean_mask(
            event_shape_in, mask=explicit_dims)
        additional_assertions.append(
            assert_util.assert_equal(
                explicit_input_event_shape,
                explicit_event_shape_in,
                message='Input `event_shape` does not match `event_shape_in`.'))
        # We don't explicitly additionally verify
        # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
        # already makes this assertion.

    with tf.control_dependencies(additional_assertions):
        output_shape = tf.concat(
            [input_non_event_shape, event_shape_out],
            axis=0,
            name='output_shape')

    return output_shape, output_tensorshape