Пример #1
0
  def _sample_n(self, n, seed):
    """Draws `n` samples from this Wishart distribution.

    For every sample and batch member this builds a lower-triangular matrix
    whose strictly-lower entries are standard normal draws and whose diagonal
    holds square roots of Gamma(alpha, beta=0.5) (i.e. Chi-squared) draws.
    That matrix is left-multiplied by `self.scale_operator`; unless
    `self.cholesky_input_output_matrices` is set, the result `x` is finally
    turned into `x @ adjoint(x)`. (This looks like the Bartlett decomposition;
    the per-diagonal shape parameters come from `_multi_gamma_sequence`, whose
    exact output is not visible here — confirm.)

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the normal draws; a derived seed (salt "wishart") is used
        for the Gamma draws.

    Returns:
      Tensor of shape `[n] + batch_shape + event_shape`.
    """
    batch_shape = self.batch_shape_tensor()
    event_shape = self.event_shape_tensor()
    batch_ndims = array_ops.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = array_ops.concat([[n], batch_shape, event_shape], 0)

    # Complexity: O(nbk**2)
    x = random_ops.random_normal(shape=shape,
                                 mean=0.,
                                 stddev=1.,
                                 dtype=self.dtype,
                                 seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    # `df` is broadcast to the scale operator's batch shape first, presumably
    # so that `g` carries the batch dims `matrix_set_diag` needs below.
    expanded_df = self.df * array_ops.ones(
        self.scale_operator.batch_shape_tensor(),
        dtype=self.df.dtype.base_dtype)
    g = random_ops.random_gamma(shape=[n],
                                alpha=self._multi_gamma_sequence(
                                    0.5 * expanded_df, self.dimension),
                                beta=0.5,
                                dtype=self.dtype,
                                seed=distribution_util.gen_new_seed(
                                    seed, "wishart"))

    # Complexity: O(nbk**2)
    x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    # Replace the diagonal of the normal draws with the sqrt(Gamma) draws.
    x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

    # Make batch-op ready.
    # Move the sample dim to the end ([n, B, k, k] -> [B, k, k, n]) and fold it
    # into the columns ([B, k, k*n]) so one matmul covers every sample.
    # Complexity: O(nbk**2)
    perm = array_ops.concat([math_ops.range(1, ndims), [0]], 0)
    x = array_ops.transpose(x, perm)
    shape = array_ops.concat([batch_shape, [event_shape[0]], [-1]], 0)
    x = array_ops.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system. E.g., for LinearOperatorDiag, each matmul is O(k**2), so
    # this complexity is O(nbk**2). For LinearOperatorLowerTriangular,
    # each matmul is O(k^3) so this step has complexity O(nbk^3).
    x = self.scale_operator.matmul(x)

    # Undo make batch-op ready.
    # Unfold back to [B, k, k, n] and move the sample dim to the front.
    # Complexity: O(nbk**2)
    shape = array_ops.concat([batch_shape, event_shape, [n]], 0)
    x = array_ops.reshape(x, shape)
    perm = array_ops.concat([[ndims - 1], math_ops.range(0, ndims - 1)], 0)
    x = array_ops.transpose(x, perm)

    if not self.cholesky_input_output_matrices:
      # Form x @ adjoint(x); otherwise `x` itself is returned.
      # Complexity: O(nbk^3)
      x = math_ops.matmul(x, x, adjoint_b=True)

    return x
Пример #2
0
 def _sample_n(self, n, seed=None):
   """Draws `n` samples from the Poisson / quadrature mixture.

   For every (sample, batch member) pair a quadrature index is drawn from
   `self._mixture_distribution`. The matching entry of the flattened
   `self.distribution.rate` is gathered and used as the Poisson rate.

   Args:
     n: Scalar number of samples to draw.
     seed: Seed for the Poisson draws; a derived seed (salt
       "poisson_lognormal_quadrature_compound") is used for the mixture ids.

   Returns:
     Tensor of shape `[n] + batch_shape` holding Poisson draws.
   """
   if self.batch_shape.is_fully_defined():
     batch_size = np.prod(self.batch_shape.as_list(), dtype=np.int32)
   else:
     batch_size = math_ops.reduce_prod(self.batch_shape_tensor())

   # Mixture ids are [n, batch_size]-shaped, or [n]-shaped for a scalar batch.
   per_draw_shape = distribution_util.pick_vector(
       self.is_scalar_batch(), np.int32([]), [batch_size])
   ids = self._mixture_distribution.sample(
       sample_shape=concat_vectors([n], per_draw_shape),
       seed=distribution_util.gen_new_seed(
           seed, "poisson_lognormal_quadrature_compound"))

   # The offset assumes the flattened rate holds `_quadrature_size`
   # consecutive entries per batch member; shift each id into its own run.
   size = self._quadrature_size
   ids += math_ops.range(start=0, limit=batch_size * size, delta=size,
                         dtype=ids.dtype)

   flat_rate = array_ops.reshape(self.distribution.rate, shape=[-1])
   lam = array_ops.reshape(
       array_ops.gather(flat_rate, ids),
       shape=concat_vectors([n], self.batch_shape_tensor()))
   return random_ops.random_poisson(lam=lam, shape=[], dtype=self.dtype,
                                    seed=seed)
Пример #3
0
 def _sample_n(self, n, seed=None):
   """Draws `n` Dirichlet-Multinomial count vectors.

   Class probabilities come from a Dirichlet draw. It is expressed as the log
   of unnormalized Gamma(alpha) draws, which serve directly as multinomial
   logits. `self.n` outcomes are then drawn and tallied per class.

   Raises:
     NotImplementedError: if `self.n` is statically known to be non-scalar.
   """
   error_message = "Sample only supported for scalar number of draws."
   n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
   static_rank = self.n.get_shape().ndims
   if static_rank is None:
     if self.validate_args:
       rank_check = check_ops.assert_rank(n_draws, 0, message=error_message)
       n_draws = control_flow_ops.with_dependencies([rank_check], n_draws)
   elif static_rank != 0:
     raise NotImplementedError(error_message)

   num_classes = self.event_shape()[0]
   gamma_draws = random_ops.random_gamma(
       shape=[n], alpha=self.alpha, dtype=self.dtype, seed=seed)
   logits = array_ops.reshape(math_ops.log(gamma_draws),
                              shape=[-1, num_classes])
   outcomes = random_ops.multinomial(
       logits=logits,
       num_samples=n_draws,
       seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
   # Summing one-hot encodings over the trial axis gives per-class counts.
   counts = math_ops.reduce_sum(
       array_ops.one_hot(outcomes, depth=num_classes), reduction_indices=-2)
   return array_ops.reshape(
       counts, array_ops.concat([[n], self.batch_shape(), [num_classes]], 0))
Пример #4
0
  def _sample_n(self, n, seed):
    """Draws `n` samples from this Wishart distribution.

    For every sample and batch member this builds a lower-triangular matrix
    whose strictly-lower entries are standard normal draws and whose diagonal
    holds square roots of Gamma(alpha, beta=0.5) (i.e. Chi-squared) draws.
    That matrix is multiplied by the square root of `self.scale_operator_pd`.
    Unless `self.cholesky_input_output_matrices` is set, the result `x` is
    finally turned into `x @ adjoint(x)`.

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the normal draws; a derived seed (salt "wishart") is used
        for the Gamma draws.

    Returns:
      Tensor of shape `[n] + batch_shape + event_shape`.
    """
    batch_shape = self.batch_shape()
    event_shape = self.event_shape()
    batch_ndims = array_ops.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = array_ops.concat(((n,), batch_shape, event_shape), 0)

    # Complexity: O(nbk^2)
    x = random_ops.random_normal(shape=shape,
                                 mean=0.,
                                 stddev=1.,
                                 dtype=self.dtype,
                                 seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    # Broadcast `df` to the full batch shape first. Without this, a scalar
    # `df` paired with a batched scale gives Gamma draws with no batch dims
    # (presumably [n, k]). `matrix_set_diag` below then rejects them, because
    # it requires the diagonal to match `x`'s [n] + batch_shape + [k] exactly.
    expanded_df = self.df * array_ops.ones(
        batch_shape, dtype=self.df.dtype.base_dtype)
    g = random_ops.random_gamma(shape=(n,),
                                alpha=self._multi_gamma_sequence(
                                    0.5 * expanded_df, self.dimension),
                                beta=0.5,
                                dtype=self.dtype,
                                seed=distribution_util.gen_new_seed(
                                    seed, "wishart"))

    # Complexity: O(nbk^2)
    x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    # Replace the diagonal of the normal draws with the sqrt(Gamma) draws.
    x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

    # Make batch-op ready.
    # Move the sample dim to the end and fold it into the columns so a single
    # sqrt_matmul covers every sample.
    # Complexity: O(nbk^2)
    perm = array_ops.concat((math_ops.range(1, ndims), (0,)), 0)
    x = array_ops.transpose(x, perm)
    shape = array_ops.concat((batch_shape, (event_shape[0], -1)), 0)
    x = array_ops.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
    # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
    # O(k^3) so this step has complexity O(nbk^3).
    x = self.scale_operator_pd.sqrt_matmul(x)

    # Undo make batch-op ready.
    # Unfold the columns back out and move the sample dim to the front.
    # Complexity: O(nbk^2)
    shape = array_ops.concat((batch_shape, event_shape, (n,)), 0)
    x = array_ops.reshape(x, shape)
    perm = array_ops.concat(((ndims - 1,), math_ops.range(0, ndims - 1)), 0)
    x = array_ops.transpose(x, perm)

    if not self.cholesky_input_output_matrices:
      # Form x @ adjoint(x); otherwise `x` itself is returned.
      # Complexity: O(nbk^3)
      x = math_ops.matmul(x, x, adjoint_b=True)

    return x
Пример #5
0
 def _sample_n(self, n, seed=None):
   """Draws `n` Beta samples as X / (X + Y), X ~ Gamma(a), Y ~ Gamma(b).

   `a` and `b` are first broadcast to the shape of `self.a_b_sum`.
   """
   broadcast_a = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.a
   broadcast_b = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.b
   x_draws = random_ops.random_gamma(
       [n], broadcast_a, dtype=self.dtype, seed=seed)
   y_seed = distribution_util.gen_new_seed(seed, "beta")
   y_draws = random_ops.random_gamma(
       [n], broadcast_b, dtype=self.dtype, seed=y_seed)
   return x_draws / (x_draws + y_draws)
Пример #6
0
 def _sample_n(self, n, seed=None):
   """Draws `n` Student's t samples, scaled by `sigma` and shifted by `mu`.

   If X ~ Normal(0, 1) and Z ~ Chi2(df), then X / sqrt(Z / df) ~ StudentT(df).
   Chi2(df) is drawn as Gamma(alpha=df / 2, beta=1 / 2).
   """
   sample_shape = array_ops.concat(0, ([n], self.batch_shape()))
   z = random_ops.random_normal(sample_shape, dtype=self.dtype, seed=seed)
   half = constant_op.constant(0.5, self.dtype)
   broadcast_df = self.df * array_ops.ones(self.batch_shape(), dtype=self.dtype)
   chi2_seed = distribution_util.gen_new_seed(seed, salt="student_t")
   chi2 = random_ops.random_gamma(
       [n], half * broadcast_df, beta=half, dtype=self.dtype, seed=chi2_seed)
   standardized = z / math_ops.sqrt(chi2 / broadcast_df)
   return standardized * self.sigma + self.mu
Пример #7
0
  def _sample_n(self, n, seed=None):
    """Draws `n` Negative Binomial samples via a Gamma-Poisson mixture.

    If lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs), then
    Poisson(lam) is Negative Binomially distributed. `exp(-logits)` is that
    Gamma rate, assuming logits = log(probs / (1 - probs)).
    """
    concentration = self.total_count
    gamma_rate = math_ops.exp(-self.logits)
    poisson_seed = distribution_util.gen_new_seed(seed, "negative_binom")
    lam = random_ops.random_gamma(
        shape=[n], alpha=concentration, beta=gamma_rate, dtype=self.dtype,
        seed=seed)
    return random_ops.random_poisson(
        lam, shape=[], dtype=self.dtype, seed=poisson_seed)
Пример #8
0
 def _sample_n(self, n, seed=None):
   """Draws `n` Student's t samples, scaled by `sigma` and shifted by `mu`.

   Uses the fact that if X ~ Normal(0, 1) and Z ~ Chi2(df), then
   X / sqrt(Z / df) ~ StudentT(df). Chi2(df) is drawn as
   Gamma(alpha=df / 2, beta=1 / 2).
   """
   sample_shape = array_ops.concat_v2([[n], self.batch_shape()], 0)
   z = random_ops.random_normal(sample_shape, dtype=self.dtype, seed=seed)
   broadcast_df = self.df * array_ops.ones(self.batch_shape(), dtype=self.dtype)
   chi2_seed = distribution_util.gen_new_seed(seed, salt="student_t")
   chi2 = random_ops.random_gamma(
       [n], 0.5 * broadcast_df, beta=0.5, dtype=self.dtype, seed=chi2_seed)
   standardized = z / math_ops.sqrt(chi2 / broadcast_df)
   return standardized * self.sigma + self.mu
 def _sample_n(self, n, seed=None):
   """Draws `n` Dirichlet-Multinomial count vectors.

   Class probabilities come from a Dirichlet draw. It is expressed as the log
   of unnormalized Gamma(concentration) draws, which serve directly as
   multinomial logits. `total_count` outcomes are then drawn and tallied.

   Returns:
     Tensor of shape `[n] + batch_shape + [k]` with per-class counts.
   """
   num_trials = math_ops.cast(self.total_count, dtype=dtypes.int32)
   num_classes = self.event_shape_tensor()[0]
   gamma_draws = random_ops.random_gamma(
       shape=[n], alpha=self.concentration, dtype=self.dtype, seed=seed)
   logits = array_ops.reshape(math_ops.log(gamma_draws),
                              shape=[-1, num_classes])
   outcomes = random_ops.multinomial(
       logits=logits,
       num_samples=num_trials,
       seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
   # Summing one-hot encodings over the trial axis gives per-class counts.
   counts = math_ops.reduce_sum(
       array_ops.one_hot(outcomes, depth=num_classes), -2)
   return array_ops.reshape(
       counts,
       array_ops.concat([[n], self.batch_shape_tensor(), [num_classes]], 0))
Пример #10
0
 def _sample_n(self, n, seed=None):
   """Draws `n` Beta samples as X / (X + Y) from independent Gamma draws.

   X ~ Gamma(concentration1) and Y ~ Gamma(concentration0). Both
   concentrations are first broadcast to the shape of `total_concentration`.
   """
   c1 = array_ops.ones_like(
       self.total_concentration, dtype=self.dtype) * self.concentration1
   c0 = array_ops.ones_like(
       self.total_concentration, dtype=self.dtype) * self.concentration0
   x_draws = random_ops.random_gamma(
       shape=[n], alpha=c1, dtype=self.dtype, seed=seed)
   y_seed = distribution_util.gen_new_seed(seed, "beta")
   y_draws = random_ops.random_gamma(
       shape=[n], alpha=c0, dtype=self.dtype, seed=y_seed)
   return x_draws / (x_draws + y_draws)
Пример #11
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples as a per-batch blend of two affine endpoints.

    Base samples are pushed through every entry of `self.endpoint_affine`. A
    mixture id per (sample, batch member) then selects an interpolation
    weight `w`, and the result is `w * x[0] + (1 - w) * x[1]`.

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the base samples; a derived seed (salt
        "vector_diffeomixture") is used for the mixture ids.

    Returns:
      Tensor of blended samples, shaped like the transformed base samples.

    Raises:
      NotImplementedError: if `self.endpoint_affine` does not yield exactly two
        transformed samples.
    """
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n],
            self.batch_shape_tensor(),
            self.event_shape_tensor()),
        seed=seed)   # shape: [n, B, e]
    x = [aff.forward(x) for aff in self.endpoint_affine]

    if len(x) != 2:
      # We actually should have already triggered this exception. However as a
      # policy we're putting this exception wherever we exploit the bimixture
      # assumption. Checked before the mixture ids/weights are built so no
      # graph ops are created for an unusable configuration.
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(x)))

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = reduce_prod(self.batch_shape_tensor())
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size])),
        seed=distribution_util.gen_new_seed(
            seed, "vector_diffeomixture"))

    # Stride `quadrature_degree` for `batch_size` number of times.
    offset = math_ops.range(start=0,
                            limit=batch_size * len(self.quadrature_probs),
                            delta=len(self.quadrature_probs),
                            dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.interpolate_weight, shape=[-1]),
        ids + offset)
    # Trailing singleton so `weight` broadcasts against the event dimension.
    weight = weight[..., array_ops.newaxis]

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
Пример #12
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples as a per-batch blend of two affine endpoints.

    Draws `n * batch_size` base samples and pushes them through each entry of
    `self.endpoint_affine`. It then picks a mixture id per (sample, batch
    member) and blends the first two transformed samples with the
    interpolation weight that id selects.

    NOTE(review): only `x[0]` and `x[1]` are used and there is no check that
    `endpoint_affine` has exactly two entries; extra entries are ignored.

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the base samples; a derived seed (salt
        "vector_diffeomixture") is used for the mixture ids.

    Returns:
      Tensor of blended samples; presumably `[n] + batch_shape + event_shape`
      (the reshape below leaves the leading dim as -1) — confirm.
    """
    batch_size = reduce_prod(self.batch_shape_tensor())
    # Base draws have shape [n * batch_size] + event_shape.
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n * batch_size],
            self.event_shape_tensor()),
        seed=seed)
    # Transform, then reshape each to [-1] + batch_shape + event_shape.
    x = [array_ops.reshape(
        aff.forward(x),
        shape=concat_vectors(
            [-1],
            self.batch_shape_tensor(),
            self.event_shape_tensor()))
         for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size])),
        seed=distribution_util.gen_new_seed(
            seed, "vector_diffeomixture"))

    # Stride `self._degree` for `batch_size` number of times.
    # Assumes the flattened `interpolate_weight` holds `self._degree`
    # consecutive entries per batch member — confirm.
    offset = math_ops.range(start=0,
                            limit=batch_size * self._degree,
                            delta=self._degree,
                            dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.interpolate_weight, shape=[-1]),
        ids + offset)
    # Trailing singleton so `weight` broadcasts against the event dimension.
    weight = weight[..., array_ops.newaxis]

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
Пример #13
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the Poisson / quadrature mixture.

    A quadrature index is drawn from `self._mixture_distribution` for every
    (sample, batch member) pair. The matching entry of the flattened
    `self.distribution.rate` is gathered, and a Poisson draw is taken with
    that rate.

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the Poisson draws; a derived seed (salt
        "poisson_lognormal_quadrature_compound") is used for the mixture ids.

    Returns:
      Tensor of shape `[n] + batch_shape` holding Poisson draws.
    """
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      # Batch shape not statically known; compute its size in-graph.
      batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
    # We need to "sample extra" from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.mixture_distribution.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = array_ops.reshape(ids, shape=concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(),
            np.int32([]),
            np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    # NOTE(review): assumes the flattened rate holds `_quadrature_size`
    # consecutive entries per batch member — confirm.
    offset = math_ops.range(start=0,
                            limit=batch_size * self._quadrature_size,
                            delta=self._quadrature_size,
                            dtype=ids.dtype)
    ids += offset
    rate = array_ops.gather(
        array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
    rate = array_ops.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return random_ops.random_poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=seed)
Пример #14
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the Poisson / quadrature mixture.

    A quadrature index per (sample, batch member) selects an entry of the
    flattened `self.distribution.rate`, which is then used as a Poisson rate.
    A mixture distribution with a scalar batch is sampled `batch_size` extra
    times so that every batch coordinate gets its own id.

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the Poisson draws; a derived seed (salt
        "poisson_lognormal_quadrature_compound") is used for the mixture ids.

    Returns:
      Tensor of shape `[n] + batch_shape` holding Poisson draws.
    """
    static_batch_size = self.batch_shape.num_elements()
    if static_batch_size is not None:
      batch_size = static_batch_size
    else:
      batch_size = math_ops.reduce_prod(self.batch_shape_tensor())

    # Only one probs vector shared by all batch dims, or one per batch
    # coordinate, is supported.
    extra_draws = distribution_util.pick_vector(
        self.mixture_distribution.is_scalar_batch(),
        [batch_size],
        np.int32([]))
    mixture_seed = distribution_util.gen_new_seed(
        seed, "poisson_lognormal_quadrature_compound")
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors([n], extra_draws), seed=mixture_seed)

    # Collapse any batch dims of the mixture into one: [n, -1], or [n] when
    # this distribution has a scalar batch.
    flat_batch = distribution_util.pick_vector(
        self.is_scalar_batch(), np.int32([]), np.int32([-1]))
    ids = array_ops.reshape(ids, shape=concat_vectors([n], flat_batch))

    # Shift each id into its batch member's run of `_quadrature_size` entries.
    stride = self._quadrature_size
    ids = ids + math_ops.range(start=0, limit=batch_size * stride,
                               delta=stride, dtype=ids.dtype)

    flat_rate = array_ops.reshape(self.distribution.rate, shape=[-1])
    lam = array_ops.reshape(
        array_ops.gather(flat_rate, ids),
        shape=concat_vectors([n], self.batch_shape_tensor()))
    return random_ops.random_poisson(
        lam=lam, shape=[], dtype=self.dtype, seed=seed)
Пример #15
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples as a per-batch blend of two affine endpoints.

        Base samples are pushed through every entry of
        `self.endpoint_affine`. A mixture id per (sample, batch member) then
        selects an interpolation weight `w`, and the result is
        `w * x[0] + (1 - w) * x[1]`.

        Args:
          n: Scalar number of samples to draw.
          seed: Seed for the base samples; a derived seed (salt
            "vector_diffeomixture") is used for the mixture ids.

        Returns:
          Tensor of blended samples, shaped like the transformed base samples.

        Raises:
          NotImplementedError: if `self.endpoint_affine` does not yield exactly
            two transformed samples.
        """
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=seed)  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption. Checked before the mixture ids/weights are built so
            # no graph ops are created for an unusable configuration.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = reduce_prod(self.batch_shape_tensor())
        ids = self._mixture_distribution.sample(
            sample_shape=concat_vectors([n],
                                        distribution_util.pick_vector(
                                            self.is_scalar_batch(),
                                            np.int32([]), [batch_size])),
            seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

        # Stride `quadrature_size` for `batch_size` number of times.
        offset = math_ops.range(start=0,
                                limit=batch_size * self._quadrature_size,
                                delta=self._quadrature_size,
                                dtype=ids.dtype)

        weight = array_ops.gather(
            array_ops.reshape(self.interpolate_weight, shape=[-1]),
            ids + offset)
        # Trailing singleton so `weight` broadcasts against the event dimension.
        weight = weight[..., array_ops.newaxis]

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Пример #16
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples as a per-batch blend of two affine endpoints.

        Draws `n * batch_size` base samples and pushes them through each entry
        of `self.endpoint_affine`. It then picks a mixture id per (sample,
        batch member) and blends the first two transformed samples with the
        interpolation weight that id selects.

        NOTE(review): only `x[0]` and `x[1]` are used and there is no check
        that `endpoint_affine` has exactly two entries; extras are ignored.

        Args:
          n: Scalar number of samples to draw.
          seed: Seed for the base samples; a derived seed (salt
            "vector_diffeomixture") is used for the mixture ids.

        Returns:
          Tensor of blended samples; presumably
          `[n] + batch_shape + event_shape` — confirm.
        """
        batch_size = reduce_prod(self.batch_shape_tensor())
        # Base draws have shape [n * batch_size] + event_shape.
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n * batch_size], self.event_shape_tensor()),
                                     seed=seed)
        # Transform, then reshape each to [-1] + batch_shape + event_shape.
        x = [
            array_ops.reshape(aff.forward(x),
                              shape=concat_vectors([-1],
                                                   self.batch_shape_tensor(),
                                                   self.event_shape_tensor()))
            for aff in self.endpoint_affine
        ]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        ids = self._mixture_distribution.sample(
            sample_shape=concat_vectors([n],
                                        distribution_util.pick_vector(
                                            self.is_scalar_batch(),
                                            np.int32([]), [batch_size])),
            seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

        # Stride `self._degree` for `batch_size` number of times.
        # Assumes the flattened `interpolate_weight` holds `self._degree`
        # consecutive entries per batch member — confirm.
        offset = math_ops.range(start=0,
                                limit=batch_size * self._degree,
                                delta=self._degree,
                                dtype=ids.dtype)

        weight = array_ops.gather(
            array_ops.reshape(self.interpolate_weight, shape=[-1]),
            ids + offset)
        # Trailing singleton so `weight` broadcasts against the event dimension.
        weight = weight[..., array_ops.newaxis]

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Пример #17
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples as a per-batch blend of two affine endpoints.

    Base samples are pushed through every entry of `self.endpoint_affine`.
    Mixture ids are drawn from `self.mixture_distribution` and used to gather
    weights from the flattened `self.grid`. The result is
    `w * x[0] + (1 - w) * x[1]`.

    Args:
      n: Scalar number of samples to draw.
      seed: Seed for the base samples; a derived seed (salt
        "vector_diffeomixture") is used for the mixture ids.

    Returns:
      Tensor of blended samples, shaped like the transformed base samples.

    Raises:
      NotImplementedError: if `self.endpoint_affine` does not yield exactly two
        transformed samples.
    """
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n],
            self.batch_shape_tensor(),
            self.event_shape_tensor()),
        seed=seed)   # shape: [n, B, e]
    x = [aff.forward(x) for aff in self.endpoint_affine]

    if len(x) != 2:
      # We actually should have already triggered this exception. However as a
      # policy we're putting this exception wherever we exploit the bimixture
      # assumption. Checked before the mixture ids/weights are built so no
      # graph ops are created for an unusable configuration.
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(x)))

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      # `reduce_prod` is a math_ops op (as used for `mix_batch_size` below).
      batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
    mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
    if mix_batch_size is None:
      mix_batch_size = math_ops.reduce_prod(
          self.mixture_distribution.batch_shape_tensor())
    ids = self.mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size // mix_batch_size])),
        seed=distribution_util.gen_new_seed(
            seed, "vector_diffeomixture"))
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = array_ops.reshape(ids, shape=concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(),
            np.int32([]),
            np.int32([-1]))))

    # Stride `components * quadrature_size` for `batch_size` number of times.
    stride = self.grid.shape.with_rank_at_least(
        2)[-2:].num_elements()
    if stride is None:
      stride = math_ops.reduce_prod(
          array_ops.shape(self.grid)[-2:])
    offset = math_ops.range(start=0,
                            limit=batch_size * stride,
                            delta=stride,
                            dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.grid, shape=[-1]),
        ids + offset)
    # At this point, weight flattened all batch dims into one.
    # We also need to append a singleton to broadcast with event dims.
    if self.batch_shape.is_fully_defined():
      new_shape = [-1] + self.batch_shape.as_list() + [1]
    else:
      new_shape = array_ops.concat(
          ([-1], self.batch_shape_tensor(), [1]), axis=0)
    weight = array_ops.reshape(weight, shape=new_shape)

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
Пример #18
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the mixture.

        Samples a component index per (sample, batch member) from `self.cat`
        and partitions the flattened sample positions by component. It then
        draws the required number of samples from each component, gathers the
        batch entries each position needs, and stitches the results back into
        sample order. All ops are built under `self._assertions`.

        Args:
          n: Number of samples. It is converted to a tensor, and to a Python
            int when its value is statically known.
          seed: Seed for `self.cat.sample`. Component `c` uses the seed
            obtained by applying `gen_new_seed(..., "mixture")` c + 1 times.

        Returns:
          Tensor of shape `samples_shape + event_shape`, where `samples_shape`
          is the shape of the categorical draws.
        """
        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            # Prefer static Python shapes/sizes when available, falling back to
            # dynamic tensors otherwise.
            static_samples_shape = cat_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(cat_samples)
                samples_size = array_ops.size(cat_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            for c in range(self.num_components):
                n_class = array_ops.size(partitioned_samples_indices[c])
                # The seed is re-derived (chained) on every iteration, so each
                # component is sampled with its own derived seed.
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples_class_c = self.components[c].sample(n_class, seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            # Restore static shape information lost through stitch/reshape.
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
Пример #19
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples as a per-batch blend of two affine endpoints.

        Base samples are pushed through every entry of
        `self.endpoint_affine`. Mixture ids are drawn from
        `self.mixture_distribution` and used to gather weights from the
        flattened `self.grid`. The result is `w * x[0] + (1 - w) * x[1]`.

        Args:
          n: Scalar number of samples to draw.
          seed: Seed for the base samples; a derived seed (salt
            "vector_diffeomixture") is used for the mixture ids.

        Returns:
          Tensor of blended samples, shaped like the transformed base samples.

        Raises:
          NotImplementedError: if `self.endpoint_affine` does not yield exactly
            two transformed samples.
        """
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=seed)  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption. Checked before the mixture ids/weights are built so
            # no graph ops are created for an unusable configuration.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = self.batch_shape.num_elements()
        if batch_size is None:
            # `reduce_prod` is a math_ops op (as used for `mix_batch_size`).
            batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
        if mix_batch_size is None:
            mix_batch_size = math_ops.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(
            sample_shape=concat_vectors(
                [n],
                distribution_util.pick_vector(self.is_scalar_batch(),
                                              np.int32([]),
                                              [batch_size // mix_batch_size])),
            seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = array_ops.reshape(ids,
                                shape=concat_vectors(
                                    [n],
                                    distribution_util.pick_vector(
                                        self.is_scalar_batch(), np.int32([]),
                                        np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
        if stride is None:
            stride = math_ops.reduce_prod(array_ops.shape(self.grid)[-2:])
        offset = math_ops.range(start=0,
                                limit=batch_size * stride,
                                delta=stride,
                                dtype=ids.dtype)

        weight = array_ops.gather(array_ops.reshape(self.grid, shape=[-1]),
                                  ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if self.batch_shape.is_fully_defined():
            new_shape = [-1] + self.batch_shape.as_list() + [1]
        else:
            new_shape = array_ops.concat(
                ([-1], self.batch_shape_tensor(), [1]), axis=0)
        weight = array_ops.reshape(weight, shape=new_shape)

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
 def testOnlyNoneReturnsNone(self):
   """gen_new_seed passes a None seed through and keeps a 0 seed non-None."""
   # A concrete seed (even 0, which is falsy) must yield a real derived seed.
   self.assertIsNotNone(distribution_util.gen_new_seed(0, "salt"))
   # A None seed means "unseeded" and must stay None.
   self.assertIsNone(distribution_util.gen_new_seed(None, "salt"))
Пример #21
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the mixture.

    A component id is drawn from `self.cat` for every sample/batch position.
    The flat positions are partitioned by component id. Each component `c` is
    sampled once per position assigned to it, and the batch entry each
    position needs is gathered out. `dynamic_stitch` then reassembles the
    results in the original position order.

    All ops are created under control dependencies on `self._assertions`.

    Args:
      n: Number of samples. Converted with `ops.convert_to_tensor`; when its
        value is statically known it is replaced by a Python `int` so that
        downstream shapes can stay static.
      seed: Optional seed. Used as-is for the categorical draw. Component `c`
        receives the seed obtained by applying
        `distribution_util.gen_new_seed(seed, "mixture")` `c + 1` times.

    Returns:
      A `Tensor` whose shape is the shape of the categorical draw followed by
      the event shape (presumably `[n] + batch_shape + event_shape`; TODO
      confirm against `self.cat`).
    """
    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      # Prefer static (Python / NumPy) shape values when fully known; fall
      # back to dynamic shape tensors otherwise.
      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(cat_samples)
        samples_size = array_ops.size(cat_samples)
      static_batch_shape = self.get_batch_shape()
      if static_batch_shape.is_fully_defined():
        batch_size = static_batch_shape.num_elements()
      else:
        batch_size = array_ops.reduce_prod(self.batch_shape())
      static_event_shape = self.get_event_shape()
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape()

      # Get indices into the raw cat sampling tensor.  We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [0 1] and [0 1 0 1]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 1, 0, 1
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None] * self.num_components

      for c in range(self.num_components):
        n_class = array_ops.size(partitioned_samples_indices[c])
        # Chain the seed so each component gets a distinct derived seed
        # (stays None when the caller passed None).
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index.  The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat(([n_class * batch_size], event_shape), 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape, reusing the
      # event shape resolved above (static when available).
      ret = array_ops.reshape(lhs_flat_ret,
                              array_ops.concat((samples_shape, event_shape), 0))
      ret.set_shape(
          tensor_shape.TensorShape(static_samples_shape).concatenate(
              self.get_event_shape()))
      return ret
 def testOnlyNoneReturnsNone(self):
   """gen_new_seed passes a None seed through and keeps a 0 seed non-None."""
   # A concrete seed (even 0, which is falsy) must yield a real derived seed.
   self.assertIsNotNone(distribution_util.gen_new_seed(0, "salt"))
   # A None seed means "unseeded" and must stay None.
   self.assertIsNone(distribution_util.gen_new_seed(None, "salt"))