Example #1
def assert_shapes_unchanged(target_shaped_dict, possibly_bcast_dict):
    for param, target_param_val in six.iteritems(target_shaped_dict):
        np.testing.assert_array_equal(
            tensorshape_util.as_list(target_param_val.shape),
            tensorshape_util.as_list(possibly_bcast_dict[param].shape))
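A minimal usage sketch (not from the original test module; it assumes TF tensors as the dict values, since the helper only reads `.shape`):

import tensorflow as tf

target = {'loc': tf.zeros([3, 2]), 'scale': tf.ones([3, 2])}
maybe_broadcast = {'loc': tf.zeros([3, 2]), 'scale': tf.ones([3, 2])}
# Passes silently because every parameter keeps its shape; raises otherwise.
assert_shapes_unchanged(target, maybe_broadcast)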
Example #2
    def testShape(self):
        x_shape = tf.TensorShape([5, 4, 6])
        y_shape = tf.TensorShape([5, 4, 4, 4])

        b = tfb.CorrelationCholesky(validate_args=True)

        x = tf.ones(shape=x_shape, dtype=tf.float32)
        y_ = b.forward(x)
        self.assertAllEqual(tensorshape_util.as_list(y_.shape),
                            tensorshape_util.as_list(y_shape))
        x_ = b.inverse(y_)
        self.assertAllEqual(tensorshape_util.as_list(x_.shape),
                            tensorshape_util.as_list(x_shape))

        y_shape_ = b.forward_event_shape(x_shape)
        self.assertAllEqual(tensorshape_util.as_list(y_shape_),
                            tensorshape_util.as_list(y_shape))
        x_shape_ = b.inverse_event_shape(y_shape)
        self.assertAllEqual(tensorshape_util.as_list(x_shape_),
                            tensorshape_util.as_list(x_shape))

        y_shape_tensor = self.evaluate(
            b.forward_event_shape_tensor(tensorshape_util.as_list(x_shape)))
        self.assertAllEqual(y_shape_tensor, tensorshape_util.as_list(y_shape))
        x_shape_tensor = self.evaluate(
            b.inverse_event_shape_tensor(tensorshape_util.as_list(y_shape)))
        self.assertAllEqual(x_shape_tensor, tensorshape_util.as_list(x_shape))
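The x/y shapes in this test are consistent because `CorrelationCholesky` maps a length-m unconstrained vector to an n x n Cholesky factor with m = n * (n - 1) / 2; a quick check of that relation:

m = 6                                  # trailing dim of x_shape
n = int((1 + (1 + 8 * m) ** 0.5) / 2)  # solves n * (n - 1) / 2 == m, giving 4
assert [n, n] == [4, 4]                # matches the trailing dims of y_shape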
Example #3
def mixtures_same_family(draw,
                         batch_shape=None,
                         event_dim=None,
                         enable_vars=False,
                         depth=None):
    """Strategy for drawing `MixtureSameFamily` distributions.

  The component distribution is drawn from the `distributions` strategy.

  The Categorical mixture distributions are either shared across all batch
  members, or drawn independently for the full batch (as required by
  `MixtureSameFamily`).

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      `MixtureSameFamily` distribution.  The component distribution will have a
      batch shape of 1 rank higher (for the components being mixed).  Hypothesis
      will pick a batch shape if omitted.
    event_dim: Optional Python int giving the size of each of the component
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    depth: Python `int` giving maximum nesting depth of compound Distributions.

  Returns:
    dists: A strategy for drawing `MixtureSameFamily` distributions with the
      specified `batch_shape` (or an arbitrary one if omitted).
  """
    if depth is None:
        depth = draw(depths())

    if batch_shape is None:
        # Ensure the components dist has at least one batch dim (a component dim).
        batch_shape = draw(tfp_hps.batch_shapes(min_ndims=1,
                                                min_lastdimsize=2))
    else:  # This mixture adds a batch dim to its underlying components dist.
        batch_shape = tensorshape_util.concatenate(
            batch_shape,
            draw(
                tfp_hps.batch_shapes(min_ndims=1,
                                     max_ndims=1,
                                     min_lastdimsize=2)))

    component_dist = draw(
        distributions(batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      depth=depth - 1))
    logging.info('component distribution: %s; parameters used: %s',
                 component_dist, [
                     k for k, v in six.iteritems(component_dist.parameters)
                     if v is not None
                 ])
    # scalar or same-shaped categorical?
    mixture_batch_shape = draw(
        hps.one_of(hps.just(batch_shape[:-1]), hps.just(tf.TensorShape([]))))
    mixture_dist = draw(
        base_distributions(dist_name='Categorical',
                           batch_shape=mixture_batch_shape,
                           event_dim=tensorshape_util.as_list(batch_shape)[-1],
                           enable_vars=enable_vars))
    logging.info(
        'mixture distribution: %s; parameters used: %s', mixture_dist, [
            k
            for k, v in six.iteritems(mixture_dist.parameters) if v is not None
        ])
    result_dist = tfd.MixtureSameFamily(components_distribution=component_dist,
                                        mixture_distribution=mixture_dist,
                                        validate_args=True)
    if batch_shape[:-1] != result_dist.batch_shape:
        msg = ('MixtureSameFamily strategy generated a bad batch shape '
               'for {}, should have been {}.').format(result_dist,
                                                      batch_shape[:-1])
        raise AssertionError(msg)
    return result_dist
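A hypothetical way this strategy might be consumed in a Hypothesis test (assuming the surrounding test-helper module, with `distributions`, `base_distributions`, etc., is importable):

from hypothesis import given

@given(mixtures_same_family(batch_shape=tf.TensorShape([3]), depth=1))
def test_mixture_batch_shape(dist):
    assert tensorshape_util.as_list(dist.batch_shape) == [3]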
Example #4
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            stream = SeedStream(seed, salt='Mixture')

            for c in range(self.num_components):
                samples.append(self.components[c].sample(n, seed=stream()))
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seed)

        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below:
        #   Suppose the batch indices are [0, 1, 0, 1, 0, 1]
        #   (batch_size=2, n=3 samples) and the partitions (cat samples) are:
        #     [1 1 0 0 1 1]
        # After partitioning, the batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]        ->  [0 1]
        #     [batch_indices[x] for x in 0, 1, 4, 5]  ->  [0 1 0 1]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 0, 1
        #   (sample rows 0, 1), and for part 1 we want samples from batch
        #   entries 0, 1, 0, 1 (sample rows 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        stream = SeedStream(seed, salt='Mixture')

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            samples_class_c = self.components[c].sample(n_class, seed=stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
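To make the partition/stitch bookkeeping above concrete, here is a small standalone illustration (values chosen to mirror the comment's example; not part of the class):

import tensorflow as tf

cat_samples = tf.constant([[1, 1], [0, 0], [1, 1]])    # [n=3, B=2]
raw_idx = tf.reshape(tf.range(tf.size(cat_samples)), tf.shape(cat_samples))
parts = tf.dynamic_partition(raw_idx, cat_samples, 2)  # [[2, 3], [0, 1, 4, 5]]
# After per-component sampling, dynamic_stitch restores the original order:
stitched = tf.dynamic_stitch(parts, [tf.constant([20, 30]),
                                     tf.constant([0, 10, 40, 50])])
# stitched == [0, 10, 20, 30, 40, 50]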
Example #5
    def __init__(self,
                 output_shape=(32, 32, 3),
                 num_glow_blocks=3,
                 num_steps_per_block=32,
                 coupling_bijector_fn=None,
                 exit_bijector_fn=None,
                 grab_after_block=None,
                 use_actnorm=True,
                 seed=None,
                 validate_args=False,
                 name='glow'):
        """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output of the bijector's forward pass (the image).  Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. H and W must each be divisible by
        `2**num_glow_blocks`, otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each spatial hierarchy.
        Default value: `None` (this will take out half of the channels after
          each block).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
        # Make sure that the input shape is fully defined.
        if not tensorshape_util.is_fully_defined(output_shape):
            raise ValueError('Shape must be fully defined.')
        if tensorshape_util.rank(output_shape) != 3:
            raise ValueError('Shape ndims must be 3 for images. Saw rank '
                             '{}.'.format(tensorshape_util.rank(output_shape)))

        num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
        if (num_glow_blocks_ is None
                or int(num_glow_blocks_) != num_glow_blocks_
                or num_glow_blocks_ < 1):
            raise ValueError(
                'Argument `num_glow_blocks` must be a statically known '
                'positive `int` (saw: {}).'.format(num_glow_blocks))
        num_glow_blocks = int(num_glow_blocks_)

        output_shape = tensorshape_util.as_list(output_shape)
        h, w, c = output_shape
        n = num_glow_blocks
        nsteps = num_steps_per_block

        # Default Glow: Half of the channels are split off after each block,
        # and after the final block, no channels are split off.
        if grab_after_block is None:
            grab_after_block = tuple([0.5] * (n - 1) + [0.])

        # Things we know must be true: h and w are evenly divisible by 2, n
        # times. Otherwise, the squeeze bijector will not work.
        if w % 2**n != 0:
            raise ValueError('Width must be divisible by 2 at least n times. '
                             'Saw: {} % {} != 0'.format(w, 2**n))
        if h % 2**n != 0:
            raise ValueError('Height must be divisible by 2 at least n times. '
                             'Saw: {} % {} != 0'.format(h, 2**n))
        if h // 2**n < 1:
            raise ValueError(
                'num_glow_blocks ({0}) is too large. The image height '
                '({1}) must be divisible by 2 no more than {2} '
                'times.'.format(num_glow_blocks, h,
                                int(np.log(h) / np.log(2.))))
        if w // 2**n < 1:
            raise ValueError(
                'num_glow_blocks ({0}) is too large. The image width '
                '({1}) must be divisible by 2 no more than {2} '
                'times.'.format(num_glow_blocks, w,
                                int(np.log(w) / np.log(2.))))

        # Other things we want to be true:
        # - The number of grab fractions must equal the number of glow blocks.
        if len(grab_after_block) != num_glow_blocks:
            raise ValueError(
                'Length of grab_after_block ({0}) must match the number '
                'of blocks ({1}).'.format(len(grab_after_block),
                                          num_glow_blocks))

        self._blockwise_splits = self._get_blockwise_splits(
            output_shape, grab_after_block[::-1])

        # Now check on the values of blockwise splits
        if any([bs[0] < 1 for bs in self._blockwise_splits]):
            first_offender = [bs[0] < 1
                              for bs in self._blockwise_splits].index(True)
            raise ValueError(
                'At least one exit takes out all of your channels, and '
                'therefore leaves no inputs to later blocks. Try setting '
                'grab_after_block to a lower value at index '
                '{}.'.format(first_offender))

        if any(np.isclose(gab, 0) for gab in grab_after_block):
            # Special case: if specifically exiting no channels, then the exit is
            # just an identity bijector.
            pass
        elif any([bs[1] < 1 for bs in self._blockwise_splits]):
            first_offender = [bs[1] < 1
                              for bs in self._blockwise_splits].index(True)
            raise ValueError(
                'At least one of your layers has < 1 output channels. '
                'This means you set grab_after_block too small. '
                'Try setting grab_after_block to a larger value at index '
                '{}.'.format(first_offender))

        # Let's start to build our bijector. We assume that the distribution is
        # 1-dimensional. First, let's reshape it to an image.
        glow_chain = [
            reshape.Reshape(event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
                            event_shape_in=[h * w * c])
        ]

        seedstream = SeedStream(seed=seed, salt='random_beta')

        for i in range(n):

            # This is the shape of the current tensor
            current_shape = (h // 2**n * 2**i, w // 2**n * 2**i,
                             c * 4**(i + 1))

            # This is the shape of the input to both the glow block and exit bijector.
            this_nchan = sum(self._blockwise_splits[i][0:2])
            this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

            glow_chain.append(
                invert.Invert(
                    ExitBijector(current_shape, self._blockwise_splits[i],
                                 exit_bijector_fn)))

            glow_block = GlowBlock(input_shape=this_input_shape,
                                   num_steps=nsteps,
                                   coupling_bijector_fn=coupling_bijector_fn,
                                   use_actnorm=use_actnorm,
                                   seedstream=seedstream)

            if self._blockwise_splits[i][2] == 0:
                # All channels are passed to the RealNVP
                glow_chain.append(glow_block)
            else:
                # Some channels are passed around the block.
                # This is done with the Blockwise bijector.
                glow_chain.append(
                    blockwise.Blockwise(
                        [glow_block, identity.Identity()], [
                            sum(self._blockwise_splits[i][0:2]),
                            self._blockwise_splits[i][2]
                        ]))

            # Finally, let's expand the channels into spatial features.
            glow_chain.append(
                Expand(input_shape=[
                    h // 2**n * 2**i,
                    w // 2**n * 2**i,
                    c * 4**n // 4**i,
                ]))

        glow_chain = glow_chain[::-1]
        # To finish off, we initialize the bijector with the chain we've built
        # This way, the rest of the model attributes are taken care of for us.
        super(Glow, self).__init__(bijectors=glow_chain,
                                   validate_args=validate_args,
                                   name=name)
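A rough construction sketch for this bijector (assuming it is exported as `tfb.Glow` and that the default network factories `tfb.GlowDefaultNetwork` / `tfb.GlowDefaultExitNetwork` are available; treat the exact factory names as an assumption):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
glow = tfb.Glow(output_shape=(32, 32, 3),
                num_glow_blocks=3,
                num_steps_per_block=32,
                coupling_bijector_fn=tfb.GlowDefaultNetwork,
                exit_bijector_fn=tfb.GlowDefaultExitNetwork)
z = tf.random.normal([2, 32 * 32 * 3])  # flat event, as the Reshape above expects
images = glow.forward(z)                # event shape becomes [32, 32, 3]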
Example #6
def positive_definite(x):
    shp = tensorshape_util.as_list(x.shape)
    psd = (tf.matmul(x, x, transpose_b=True) +
           .1 * tf.linalg.eye(shp[-1], batch_shape=shp[:-2]))
    return symmetric(psd)
Example #7
    def _shape(self, x):
        if self.use_static_shape:
            return tensorshape_util.as_list(x.shape)
        else:
            return self.evaluate(tf.shape(x))
Example #8
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init and not dtype_util.is_integer(
                self.mixture_distribution.dtype):
            raise ValueError(
                '`mixture_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.mixture_distribution.dtype)))

        if tensorshape_util.rank(
                self.mixture_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.mixture_distribution.event_shape) != 0:
                raise ValueError(
                    '`mixture_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`mixture_distribution` must have scalar `event_dim`s'),
            ]

        # pylint: disable=protected-access
        mixture_dist_param = (self.mixture_distribution._probs
                              if self.mixture_distribution._logits is None else
                              self.mixture_distribution._logits)
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                self.components_distribution.batch_shape, 1)[-1])
        component_bst = None
        if km is not None and kc is not None:
            if km != kc:
                raise ValueError(
                    '`mixture_distribution` components ({}) does not '
                    'equal `components_distribution.batch_shape[-1]` '
                    '({})'.format(km, kc))
        elif self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = self.components_distribution.batch_shape_tensor(
                )
                kc = component_bst[-1]
            assertions += [
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')),
            ]

        mdbs = self.mixture_distribution.batch_shape
        cdbs = tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[:-1]
        if (tensorshape_util.is_fully_defined(mdbs)
                and tensorshape_util.is_fully_defined(cdbs)):
            if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                raise ValueError(
                    '`mixture_distribution.batch_shape` (`{}`) is not '
                    'compatible with `components_distribution.batch_shape` '
                    '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                                    tensorshape_util.as_list(cdbs)))
        elif self.validate_args:
            if not tensorshape_util.is_fully_defined(mdbs):
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                mdbs = tf.shape(mixture_dist_param)[:-1]
            if not tensorshape_util.is_fully_defined(cdbs):
                if component_bst is None:
                    component_bst = self.components_distribution.batch_shape_tensor(
                    )
                cdbs = component_bst[:-1]
            assertions += [
                assert_util.assert_equal(
                    distribution_utils.pick_vector(
                        tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
                    cdbs,
                    message=(
                        '`mixture_distribution.batch_shape` is not '
                        'compatible with `components_distribution.batch_shape`'
                    ))
            ]

        return assertions
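A small sketch of how the component-count check above surfaces at construction time (three mixture categories against two components; standard `tfd` API assumed):

import tensorflow_probability as tfp

tfd = tfp.distributions
try:
    tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=[0., 0., 0.]),         # 3 categories
        components_distribution=tfd.Normal(loc=[0., 1.], scale=[1., 1.]),  # 2 components
        validate_args=True)
except ValueError as e:
    print(e)  # components (3) does not equal `components_distribution.batch_shape[-1]` (2)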
Example #9
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if the number of `mixture_distribution` categories does not
        equal the rightmost batch dimension of `components_distribution`.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(s)
            self._event_size = tf.reduce_prod(s)

            if not dtype_util.is_integer(mixture_distribution.dtype):
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(dtype_util.name(mixture_distribution.dtype)))

            if (tensorshape_util.rank(mixture_distribution.event_shape)
                    is not None and tensorshape_util.rank(
                        mixture_distribution.event_shape) != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = tensorshape_util.with_rank_at_least(
                components_distribution.batch_shape, 1)[:-1]
            if tensorshape_util.is_fully_defined(
                    mdbs) and tensorshape_util.is_fully_defined(cdbs):
                if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(tensorshape_util.as_list(mdbs),
                                        tensorshape_util.as_list(cdbs)))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            mixture_dist_param = (mixture_distribution.probs
                                  if mixture_distribution.logits is None else
                                  mixture_distribution.logits)
            km = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                    1)[-1])
            kc = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(
                    components_distribution.batch_shape, 1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution` components ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(mixture_dist_param)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution` components does not equal "
                            "`components_distribution.batch_shape[-1]`")),
                ]
            elif km is None:
                km = tf.shape(mixture_dist_param)[-1]

            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization type hence
                # we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    self._mixture_distribution._graph_parents  # pylint: disable=protected-access
                    + self._components_distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
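For reference, a minimal well-formed construction matching this docstring (a two-component Gaussian mixture over scalars):

import tensorflow_probability as tfp

tfd = tfp.distributions
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]))
x = gm.sample(5)     # shape [5]
lp = gm.log_prob(x)  # shape [5]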
Example #10
    def __init__(self,
                 num_or_size_splits,
                 axis=-1,
                 validate_args=False,
                 name='split'):
        """Creates the bijector.

    Args:
      num_or_size_splits: Either a Python integer indicating the number of
        splits along `axis` or a 1-D integer `Tensor` or Python list containing
        the sizes of each output tensor along `axis`. If a list/`Tensor`, it may
        contain at most one value of `-1`, which indicates a split size that is
        unknown and determined from input.
      axis: A negative integer or scalar `int32` `Tensor`. The dimension along
        which to split. Must be negative to enable the bijector to support
        arbitrary batch dimensions. Defaults to -1 (note that this is different
        from the `tf.split` default of `0`). Must be statically known.
      validate_args: Python `bool` indicating whether arguments should
        be checked for correctness.
      name: Python `str`, name given to ops managed by this object.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:

            if isinstance(num_or_size_splits, numbers.Integral):
                self._num_splits = num_or_size_splits
                self._split_sizes = None
            else:
                self._split_sizes = tensor_util.convert_nonref_to_tensor(
                    num_or_size_splits,
                    name='num_or_size_splits',
                    dtype=tf.int32)

                if tensorshape_util.rank(self._split_sizes.shape) != 1:
                    raise ValueError(
                        '`num_or_size_splits` must be an integer or 1-D `Tensor`.'
                    )

                num_splits = tensorshape_util.as_list(
                    self._split_sizes.shape)[0]
                if num_splits is None:
                    raise ValueError(
                        'If `num_or_size_splits` is a vector of split sizes '
                        'it must have a statically-known number of '
                        'elements.')
                self._num_splits = num_splits

            static_axis = tf.get_static_value(axis)
            if static_axis is None:
                raise ValueError('`axis` must be statically known.')
            if static_axis >= 0:
                raise ValueError(
                    '`axis` must be negative. Got {}'.format(axis))

            self._axis = ps.convert_to_shape_tensor(axis, tf.int32)

            super(Split, self).__init__(forward_min_event_ndims=-axis,
                                        inverse_min_event_ndims=[-axis] *
                                        self.num_splits,
                                        is_constant_jacobian=True,
                                        validate_args=validate_args,
                                        parameters=parameters,
                                        name=name)
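A short usage sketch of the resulting bijector (assuming it is exported as `tfb.Split`):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
split = tfb.Split(num_or_size_splits=[2, 3], axis=-1)
x = tf.zeros([4, 5])
y1, y2 = split.forward(x)         # shapes [4, 2] and [4, 3]
x_back = split.inverse([y1, y2])  # shape [4, 5]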
Example #11
def broadcasting_params(draw,
                        batch_shape,
                        params_event_ndims,
                        event_dim=None,
                        enable_vars=False,
                        constraint_fn_for=lambda param: identity_fn,
                        mutex_params=(),
                        dtype=np.float32):
    """Streategy for drawing parameters which jointly have the given batch shape.

  Specifically, the batch shapes of the returned parameters will broadcast to
  the requested batch shape.

  The dtypes of the returned parameters are determined by their respective
  constraint functions.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: A `TensorShape`.  The returned parameters' batch shapes will
      broadcast to this.
    params_event_ndims: Python `dict` mapping the name of each parameter to a
      Python `int` giving the event ndims for that parameter.
    event_dim: Optional Python int giving the size of each parameter's event
      dimensions (except where overridden by any applicable constraint
      functions).  This is shared across all parameters, permitting square event
      matrices, compatible location and scale Tensors, etc. If omitted,
      Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test. If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`,
      `tfp.util.TransformedVariable`}.
    constraint_fn_for: Python callable mapping parameter name to constraint
      function.  The latter is itself a Python callable which converts an
      unconstrained Tensor (currently with float32 values from -200 to +200)
      into one that meets the parameter's validity constraints.
    mutex_params: Python iterable of Python sets.  Each set gives a clique of
      mutually exclusive parameters (e.g., the 'probs' and 'logits' of a
      Categorical).  At most one parameter from each set will appear in the
      result.
    dtype: Dtype for generated parameters.

  Returns:
    params: A Hypothesis strategy for drawing Python `dict`s mapping parameter
      name to a `tf.Tensor`, `tf.Variable`, `tfp.util.DeferredTensor`, or
      `tfp.util.TransformedVariable`.  The batch shapes of the returned
      parameters broadcast together to the supplied `batch_shape`.  Only
      parameters whose names appear as keys in `params_event_ndims` will appear
      (but possibly not all of them, depending on `mutex_params`).
  """
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))

    params_event_ndims = params_event_ndims or {}
    remaining_params = set(params_event_ndims.keys())
    params_to_use = []
    while remaining_params:
        param = draw(hps.sampled_from(sorted(remaining_params)))
        params_to_use.append(param)
        remaining_params.remove(param)
        for mutex_set in mutex_params:
            if param in mutex_set:
                remaining_params -= mutex_set

    param_batch_shapes = draw(
        broadcasting_named_shapes(batch_shape, params_to_use))
    params_kwargs = dict()
    for param in params_to_use:
        param_batch_shape = param_batch_shapes[param]
        param_event_rank = params_event_ndims[param]
        param_shape = (tensorshape_util.as_list(param_batch_shape) +
                       [event_dim] * param_event_rank)

        # Reduce our risk of exceeding TF kernel broadcast limits.
        hp.assume(len(param_shape) < 6)

        # TODO(axch): Can I replace `params_event_ndims` and `constraint_fn_for`
        # with a map from params to `Support`s, and use `tensors_in_support` here
        # instead of this explicit `constrained_tensors` function?
        param_strategy = constrained_tensors(constraint_fn_for(param),
                                             param_shape,
                                             dtype=dtype)
        params_kwargs[param] = draw(
            maybe_variable(param_strategy,
                           enable_vars=enable_vars,
                           dtype=dtype,
                           name=param))
    return params_kwargs
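A hypothetical consumption pattern inside a Hypothesis test (the parameter names and event ndims here are illustrative, not taken from any particular distribution):

from hypothesis import given
import hypothesis.strategies as hps
import tensorflow as tf

@given(hps.data())
def test_params_broadcast_to_batch_shape(data):
    params = data.draw(
        broadcasting_params(batch_shape=tf.TensorShape([2, 3]),
                            params_event_ndims={'loc': 0, 'scale': 0}))
    for p in params.values():
        # With event ndims of 0, each parameter's shape is a batch shape that
        # must broadcast with [2, 3]; this raises if it cannot.
        tf.broadcast_static_shape(tf.TensorShape(p.shape), tf.TensorShape([2, 3]))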
Example #12
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in,
                                         event_shape_out, validate_args):
    """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers.
    event_shape_in: the event shape expected to be present in the rightmost
      dims of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
  """
    output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
        tensorshape_util.constant_value_as_shape(input_shape), event_shape_in,
        event_shape_out)

    if (tensorshape_util.is_fully_defined(output_tensorshape)
            and (is_validated or not validate_args)):
        output_shape = ps.convert_to_shape_tensor(
            tensorshape_util.as_list(output_tensorshape),
            name='output_shape',
            dtype_hint=tf.int32)
        return output_shape, output_tensorshape

    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

    additional_assertions = []
    if is_validated:
        pass
    elif validate_args:
        # Check that `input_event_shape` and `event_shape_in` are compatible in the
        # sense that they have equal entries in any position that isn't a `-1` in
        # `event_shape_in`. Note that our validations at construction time ensure
        # there is at most one such entry in `event_shape_in`.
        mask = event_shape_in >= 0
        explicit_input_event_shape = tf.boolean_mask(input_event_shape,
                                                     mask=mask)
        explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
        additional_assertions.append(
            assert_util.assert_equal(
                explicit_input_event_shape,
                explicit_event_shape_in,
                message='Input `event_shape` does not match `event_shape_in`.')
        )
        # We don't explicitly additionally verify
        # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
        # already makes this assertion.

    with tf.control_dependencies(additional_assertions):
        output_shape = tf.concat([input_non_event_shape, event_shape_out],
                                 axis=0,
                                 name='output_shape')

    return output_shape, output_tensorshape
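A worked example of the replacement, assuming all shapes are known statically: with `input_shape = [2, 3, 4]`, `event_shape_in = [4]`, and `event_shape_out = [2, 2]`, the non-event prefix `[2, 3]` is preserved and the trailing `[4]` is swapped out, giving `output_shape = [2, 3, 2, 2]`.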