Example #1
    def _sample_n(self, n, seed):
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        batch_ndims = tf.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = tf.concat([[n], batch_shape, event_shape], 0)

        # Complexity: O(nbk**2)
        x = tf.random_normal(shape=shape,
                             mean=0.,
                             stddev=1.,
                             dtype=self.dtype,
                             seed=seed)

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        expanded_df = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(),
            dtype=self.df.dtype.base_dtype)
        g = tf.random_gamma(
            shape=[n],
            alpha=self._multi_gamma_sequence(0.5 * expanded_df,
                                             self.dimension),
            beta=0.5,
            dtype=self.dtype,
            seed=distribution_util.gen_new_seed(seed, "wishart"))

        # Complexity: O(nbk**2)
        x = tf.matrix_band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = tf.matrix_set_diag(x, tf.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk**2)
        perm = tf.concat([tf.range(1, ndims), [0]], 0)
        x = tf.transpose(x, perm)
        shape = tf.concat([batch_shape, [event_shape[0]], [-1]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. E.g., for LinearOperatorDiag, each matmul is O(k**2), so
        # this complexity is O(nbk**2). For LinearOperatorLowerTriangular,
        # each matmul is O(k**3) so this step has complexity O(nbk**3).
        x = self.scale_operator.matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        x = tf.transpose(x, perm)

        if not self.cholesky_input_output_matrices:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x
Example #2
  def _sample_n(self, n, seed):
    batch_shape = self.batch_shape_tensor()
    event_shape = self.event_shape_tensor()
    batch_ndims = tf.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = tf.concat([[n], batch_shape, event_shape], 0)

    # Complexity: O(nbk**2)
    x = tf.random_normal(
        shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    expanded_df = self.df * tf.ones(
        self.scale_operator.batch_shape_tensor(),
        dtype=self.df.dtype.base_dtype)
    g = tf.random_gamma(
        shape=[n],
        alpha=self._multi_gamma_sequence(0.5 * expanded_df, self.dimension),
        beta=0.5,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "wishart"))

    # Complexity: O(nbk**2)
    x = tf.matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = tf.matrix_set_diag(x, tf.sqrt(g))

    # Make batch-op ready.
    # Complexity: O(nbk**2)
    perm = tf.concat([tf.range(1, ndims), [0]], 0)
    x = tf.transpose(x, perm)
    shape = tf.concat([batch_shape, [event_shape[0]], [-1]], 0)
    x = tf.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system. For LinearOperatorLowerTriangular, each matmul is O(k**3)
    # so this step has complexity O(nbk**3).
    x = self.scale_operator.matmul(x)

    # Undo make batch-op ready.
    # Complexity: O(nbk**2)
    shape = tf.concat([batch_shape, event_shape, [n]], 0)
    x = tf.reshape(x, shape)
    perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
    x = tf.transpose(x, perm)

    if not self.input_output_cholesky:
      # Complexity: O(nbk**3)
      x = tf.matmul(x, x, adjoint_b=True)

    return x
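
Both Wishart samplers above build the Bartlett decomposition: a lower-triangular factor with standard normals strictly below the diagonal and square roots of chi-square draws (Gamma(alpha=(df - i)/2, beta=1/2), i.e. Chi2(df - i)) on the diagonal, which the scale operator then maps to the requested scale matrix. A minimal NumPy sketch of the same construction, assuming an identity scale operator so the final matmul is a no-op:

import numpy as np

def sample_wishart_bartlett(df, k, n, rng):
    # Lower-triangular Bartlett factor: N(0, 1) strictly below the diagonal,
    # sqrt(Chi2(df - i)) on the diagonal for i = 0..k-1 (beta=1/2 -> scale=2).
    a = np.tril(rng.standard_normal((n, k, k)), -1)
    g = rng.gamma(shape=0.5 * (df - np.arange(k)), scale=2.0, size=(n, k))
    a[:, np.arange(k), np.arange(k)] = np.sqrt(g)
    return a @ np.swapaxes(a, -1, -2)  # x @ x^T, as in the final matmul above

rng = np.random.default_rng(0)
w = sample_wishart_bartlett(df=7.0, k=3, n=100000, rng=rng)
print(np.round(w.mean(axis=0), 2))  # ~7 * eye(3), since E[W] = df * scale
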
Example #3
 def _sample_n(self, n, seed=None):
     expanded_concentration1 = tf.ones_like(
         self.total_concentration, dtype=self.dtype) * self.concentration1
     expanded_concentration0 = tf.ones_like(
         self.total_concentration, dtype=self.dtype) * self.concentration0
     gamma1_sample = tf.random_gamma(shape=[n],
                                     alpha=expanded_concentration1,
                                     dtype=self.dtype,
                                     seed=seed)
     gamma2_sample = tf.random_gamma(shape=[n],
                                     alpha=expanded_concentration0,
                                     dtype=self.dtype,
                                     seed=util.gen_new_seed(seed, "beta"))
     beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
     return beta_sample
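
The Beta sampler above uses the standard Gamma-ratio identity: if G1 ~ Gamma(a, 1) and G2 ~ Gamma(b, 1) are independent, then G1 / (G1 + G2) ~ Beta(a, b). A quick NumPy check of that identity:

import numpy as np

rng = np.random.default_rng(0)
a, b, n = 2.0, 5.0, 1000000
g1 = rng.gamma(a, size=n)
g2 = rng.gamma(b, size=n)
beta = g1 / (g1 + g2)
print(beta.mean(), a / (a + b))  # both ~0.2857
print(beta.var(), a * b / ((a + b) ** 2 * (a + b + 1)))  # both ~0.0255
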
Example #4
 def _sample_n(self, n, seed=None):
     n_draws = tf.cast(self.total_count, dtype=tf.int32)
     k = self.event_shape_tensor()[0]
     unnormalized_logits = tf.reshape(tf.log(
         tf.random_gamma(shape=[n],
                         alpha=self.concentration,
                         dtype=self.dtype,
                         seed=seed)),
                                      shape=[-1, k])
     draws = tf.multinomial(logits=unnormalized_logits,
                            num_samples=n_draws,
                            seed=distribution_util.gen_new_seed(
                                seed, salt="dirichlet_multinomial"))
     x = tf.reduce_sum(tf.one_hot(draws, depth=k), -2)
     final_shape = tf.concat([[n], self.batch_shape_tensor(), [k]], 0)
     x = tf.reshape(x, final_shape)
     return tf.cast(x, self.dtype)
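
Feeding tf.log of Gamma draws to tf.multinomial as unnormalized logits works because the multinomial op softmax-normalizes its logits, softmax(log g) = g / sum(g), and Gamma(alpha_i) draws normalized by their sum are exactly a Dirichlet(alpha) sample. A small NumPy illustration of both facts:

import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])
g = rng.gamma(alpha, size=(100000, 3))
p = g / g.sum(axis=1, keepdims=True)  # Dirichlet(alpha) samples
print(np.round(p.mean(axis=0), 3))  # ~alpha / alpha.sum() = [0.167 0.333 0.5]

logits = np.log(g[0])
softmax = np.exp(logits) / np.exp(logits).sum()
print(np.allclose(softmax, p[0]))  # True: softmax(log g) == g / sum(g)
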
Example #5
 def _sample_n(self, n, seed=None):
   # The sampling method comes from the fact that if:
   #   X ~ Normal(0, 1)
   #   Z ~ Chi2(df)
   #   Y = X / sqrt(Z / df)
   # then:
   #   Y ~ StudentT(df).
   shape = tf.concat([[n], self.batch_shape_tensor()], 0)
   normal_sample = tf.random_normal(shape, dtype=self.dtype, seed=seed)
   df = self.df * tf.ones(self.batch_shape_tensor(), dtype=self.dtype)
   gamma_sample = tf.random_gamma(
       [n],
       0.5 * df,
       beta=0.5,
       dtype=self.dtype,
       seed=distribution_util.gen_new_seed(seed, salt="student_t"))
   samples = normal_sample * tf.rsqrt(gamma_sample / df)
   return samples * self.scale + self.loc  # Abs(scale) not wanted.
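
The comment above states the scale-mixture construction; a quick NumPy check that X / sqrt(Z / df) has Student-t moments (mean 0, variance df / (df - 2)):

import numpy as np

rng = np.random.default_rng(0)
df, n = 5.0, 1000000
x = rng.standard_normal(n)
z = rng.chisquare(df, size=n)  # Chi2(df) == Gamma(alpha=df/2, beta=1/2)
y = x / np.sqrt(z / df)
print(y.mean(), y.var())  # mean ~0, variance ~df / (df - 2) = 1.667
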
Example #6
    def _sample_n(self, n, seed=None):
        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = self.batch_shape.num_elements()
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        # We need to "sample extra" from the mixture distribution if it doesn't
        # already specify a probs vector for each batch coordinate.
        # We only support this kind of reduced broadcasting, i.e., there is exactly
        # one probs vector for all batch dims or one for each.
        ids = self._mixture_distribution.sample(
            sample_shape=concat_vectors(
                [n],
                distribution_util.pick_vector(
                    self.mixture_distribution.is_scalar_batch(), [batch_size],
                    np.int32([]))),
            seed=distribution_util.gen_new_seed(
                seed, "poisson_lognormal_quadrature_compound"))
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `quadrature_size` for `batch_size` number of times.
        offset = tf.range(start=0,
                          limit=batch_size * self._quadrature_size,
                          delta=self._quadrature_size,
                          dtype=ids.dtype)
        ids += offset
        rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
        rate = tf.reshape(rate,
                          shape=concat_vectors([n], self.batch_shape_tensor()))
        return tf.random_poisson(lam=rate,
                                 shape=[],
                                 dtype=self.dtype,
                                 seed=seed)
Example #7
  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      batch_size = tf.reduce_prod(self.batch_shape_tensor())
    # We need to "sample extra" from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.mixture_distribution.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids += offset
    rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return tf.random_poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
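
In both versions, `ids += offset` converts per-batch quadrature indices into flat indices so a single tf.gather on the flattened rate tensor performs a batched 2-D lookup: flat_index = b * quadrature_size + ids[..., b]. A NumPy sketch of the trick (the names here are illustrative):

import numpy as np

batch_size, quadrature_size, n = 3, 4, 2
rng = np.random.default_rng(0)
rate = rng.uniform(1.0, 10.0, size=(batch_size, quadrature_size))
ids = rng.integers(0, quadrature_size, size=(n, batch_size))

offset = np.arange(0, batch_size * quadrature_size, quadrature_size)  # [0 4 8]
gathered = rate.reshape(-1)[ids + offset]  # shape [n, batch_size]
expected = rate[np.arange(batch_size), ids]  # direct 2-D fancy indexing
print(np.array_equal(gathered, expected))  # True
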
Example #8
 def testOnlyNoneReturnsNone(self):
     self.assertIsNotNone(distribution_util.gen_new_seed(0, 'salt'))
     self.assertIsNone(distribution_util.gen_new_seed(None, 'salt'))
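
This test pins down the contract every sampler above relies on: gen_new_seed(None, salt) stays None (no fixed seed requested), while a concrete seed is deterministically remixed per salt so that chained samplers (e.g. the normal and gamma draws in the Wishart sampler) don't reuse the same stream. A minimal sketch of such a helper; this is an illustration, not the actual distribution_util implementation:

import hashlib

def gen_new_seed_sketch(seed, salt):
    # None means "no seed": propagate it so downstream ops stay unseeded.
    if seed is None:
        return None
    # Any deterministic, salt-dependent remix works; md5 is one choice.
    digest = hashlib.md5(("%d_%s" % (seed, salt)).encode("utf-8")).hexdigest()
    return int(digest, 16) % (2 ** 31 - 1)

print(gen_new_seed_sketch(None, "salt"))  # None
print(gen_new_seed_sketch(0, "wishart"), gen_new_seed_sketch(0, "beta"))  # distinct ints
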
Example #9
    def _sample_n(self, n, seed=None):
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=seed)  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = self.batch_shape.num_elements()
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(
            sample_shape=concat_vectors(
                [n],
                distribution_util.pick_vector(self.is_scalar_batch(),
                                              np.int32([]),
                                              [batch_size // mix_batch_size])),
            seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if self.batch_shape.is_fully_defined():
            new_shape = [-1] + self.batch_shape.as_list() + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Example #10
  def _sample_n(self, n, seed=None):
    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      pi_samples = self.pi.sample(n, seed=seed)

      static_samples_shape = pi_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(pi_samples)
        samples_size = array_ops.size(pi_samples)
      static_batch_shape = self.batch_shape
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = math_ops.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw pi sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=pi_samples,
          num_partitions=self.num_dist)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [0 1] and [0 1 0 1]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 1, 0, 1
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=pi_samples,
          num_partitions=self.num_dist)
      samples_class = [None for _ in range(self.num_dist)]

      for c in range(self.num_dist):
        n_class = array_ops.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "ZeroInflated")
        samples_class_c = self.dist[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat([[n_class * batch_size], event_shape], 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the dist.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = array_ops.reshape(lhs_flat_ret,
                              array_ops.concat([samples_shape,
                                                self.event_shape_tensor()], 0))
      ret.set_shape(
          tensor_shape.TensorShape(static_samples_shape).concatenate(
              self.event_shape))
      return ret
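
The partition/stitch pattern above splits flat sample indices by sampled component, draws each component's batch in one shot, then dynamic_stitch scatters the results back into the original sample order. A NumPy sketch of that round trip (with identity processing per partition):

import numpy as np

partitions = np.array([1, 1, 0, 0, 1, 1])  # component id per flat sample slot
data = np.array([10, 20, 30, 40, 50, 60])
parts_idx = [np.flatnonzero(partitions == c) for c in range(2)]
parts_data = [data[idx] for idx in parts_idx]  # [30 40] and [10 20 50 60]

out = np.empty_like(data)
for idx, d in zip(parts_idx, parts_data):
    out[idx] = d  # the dynamic_stitch step: scatter back by saved indices
print(np.array_equal(out, data))  # True: original order is restored
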
Example #11
  def _sample_n(self, n, seed=None):
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n],
            self.batch_shape_tensor(),
            self.event_shape_tensor()),
        seed=seed)   # shape: [n, B, e]
    x = [aff.forward(x) for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      batch_size = tf.reduce_prod(self.batch_shape_tensor())
    mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
    if mix_batch_size is None:
      mix_batch_size = tf.reduce_prod(
          self.mixture_distribution.batch_shape_tensor())
    ids = self.mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size // mix_batch_size])),
        seed=distribution_util.gen_new_seed(
            seed, "vector_diffeomixture"))
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `components * quadrature_size` for `batch_size` number of times.
    stride = self.grid.shape.with_rank_at_least(
        2)[-2:].num_elements()
    if stride is None:
      stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
    offset = tf.range(
        start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)

    weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
    # At this point, weight flattened all batch dims into one.
    # We also need to append a singleton to broadcast with event dims.
    if self.batch_shape.is_fully_defined():
      new_shape = [-1] + self.batch_shape.as_list() + [1]
    else:
      new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
    weight = tf.reshape(weight, shape=new_shape)

    if len(x) != 2:
      # We actually should have already triggered this exception. However as a
      # policy we're putting this exception wherever we exploit the bimixture
      # assumption.
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(x)))

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
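
The final line in both vector-diffeomixture samplers rewrites the convex combination weight * x[0] + (1 - weight) * x[1] as weight * (x[0] - x[1]) + x[1], which saves one broadcast multiply. A one-line NumPy check of the identity:

import numpy as np

rng = np.random.default_rng(0)
weight = rng.uniform(size=(5, 1))
x0, x1 = rng.standard_normal((5, 3)), rng.standard_normal((5, 3))
print(np.allclose(weight * x0 + (1.0 - weight) * x1,
                  weight * (x0 - x1) + x1))  # True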