def _mean(self):
  with ops.control_dependencies(self._runtime_assertions):
    probs = distribution_utils.pad_mixture_dimensions(
        self.mixture_distribution.probs, self, self.mixture_distribution,
        self._event_shape().ndims)                         # [B, k, [1]*e]
    return math_ops.reduce_sum(
        probs * self.components_distribution.mean(),
        axis=-1 - self._event_ndims)                       # [B, E]
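# A minimal NumPy sketch (illustrative only, not part of the library) of what
# _mean computes: the mixture mean is the probability-weighted sum of the
# component means, E[Y] = sum_k p_k * mu_k, with the mixture probabilities
# padded so they broadcast against the event dimension. The numbers below are
# made up for illustration.
import numpy as np

probs = np.array([0.3, 0.7])                    # [k]
component_means = np.array([[-1., 1.],          # [k, e]
                            [1., -1.]])
# Pad probs to [k, 1] so they broadcast over the event dimension, mirroring
# pad_mixture_dimensions, then reduce over the mixture axis k.
mixture_mean = np.sum(probs[:, np.newaxis] * component_means, axis=0)
print(mixture_mean)                             # [e] -> [0.4, -0.4]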
def _covariance(self):
  static_event_ndims = self.event_shape.ndims
  if static_event_ndims != 1:
    # Covariance is defined only for vector distributions.
    raise NotImplementedError("covariance is not implemented")

  with ops.control_dependencies(self._runtime_assertions):
    # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
    probs = distribution_utils.pad_mixture_dimensions(
        distribution_utils.pad_mixture_dimensions(
            self.mixture_distribution.probs, self, self.mixture_distribution,
            self._event_shape().ndims),
        self, self.mixture_distribution,
        self._event_shape().ndims)                         # [B, k, 1, 1]
    mean_cond_var = math_ops.reduce_sum(
        probs * self.components_distribution.covariance(),
        axis=-3)                                           # [B, e, e]
    var_cond_mean = math_ops.reduce_sum(
        probs * _outer_squared_difference(
            self.components_distribution.mean(),
            self._pad_sample_dims(self._mean())),
        axis=-3)                                           # [B, e, e]
    return mean_cond_var + var_cond_mean                   # [B, e, e]
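# A minimal NumPy sketch (illustrative only, not part of the library) of the
# law of total variance used above: Cov(Y) = E[Cov(Y|X)] + Cov(E[Y|X]), where
# the outer expectation is the probability-weighted sum over components. The
# component parameters below are made up for illustration.
import numpy as np

probs = np.array([0.3, 0.7])                      # [k]
mus = np.array([[-1., 1.], [1., -1.]])            # [k, e]
covs = np.array([np.eye(2), 0.25 * np.eye(2)])    # [k, e, e]

mixture_mean = np.sum(probs[:, None] * mus, axis=0)            # [e]
mean_cond_var = np.sum(probs[:, None, None] * covs, axis=0)    # E[Cov(Y|X)]
diffs = mus - mixture_mean                                     # [k, e]
var_cond_mean = np.sum(
    probs[:, None, None] * np.einsum('ki,kj->kij', diffs, diffs),
    axis=0)                                                    # Cov(E[Y|X])
mixture_cov = mean_cond_var + var_cond_mean                    # [e, e]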
def test_pad_mixture_dimensions_mixture_same_family(self):
  with self.test_session() as sess:
    gm = mixture_same_family.MixtureSameFamily(
        mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
        components_distribution=mvn_diag.MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5]))

    x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
    x_pad = distribution_util.pad_mixture_dimensions(
        x, gm, gm.mixture_distribution, gm.event_shape.ndims)
    x_out, x_pad_out = sess.run([x, x_pad])

    self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
    self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
def _variance(self):
  with ops.control_dependencies(self._runtime_assertions):
    # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
    probs = distribution_utils.pad_mixture_dimensions(
        self.mixture_distribution.probs, self, self.mixture_distribution,
        self._event_shape().ndims)                         # [B, k, [1]*e]
    mean_cond_var = math_ops.reduce_sum(
        probs * self.components_distribution.variance(),
        axis=-1 - self._event_ndims)                       # [B, E]
    var_cond_mean = math_ops.reduce_sum(
        probs * math_ops.squared_difference(
            self.components_distribution.mean(),
            self._pad_sample_dims(self._mean())),
        axis=-1 - self._event_ndims)                       # [B, E]
    return mean_cond_var + var_cond_mean                   # [B, E]
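# A minimal NumPy sketch (illustrative only, not part of the library) of the
# elementwise version used in _variance: the same total-variance
# decomposition as _covariance, but with per-dimension squared differences
# instead of outer products, so the result equals the diagonal of the mixture
# covariance. Numbers are made up for illustration.
import numpy as np

probs = np.array([0.3, 0.7])                    # [k]
mus = np.array([[-1., 1.], [1., -1.]])          # [k, e]
vars_ = np.array([[1., 1.], [0.25, 0.25]])      # [k, e] per-dimension variances

mixture_mean = np.sum(probs[:, None] * mus, axis=0)                  # [e]
mean_cond_var = np.sum(probs[:, None] * vars_, axis=0)               # E[Var(Y|X)]
var_cond_mean = np.sum(probs[:, None] * (mus - mixture_mean) ** 2,
                       axis=0)                                       # Var(E[Y|X])
mixture_var = mean_cond_var + var_cond_mean                          # [e]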
def _sample_n(self, n, seed):
  with ops.control_dependencies(self._runtime_assertions):
    x = self.components_distribution.sample(n)             # [n, B, k, E]
    # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
    npdt = x.dtype.as_numpy_dtype
    mask = array_ops.one_hot(
        indices=self.mixture_distribution.sample(n),        # [n, B]
        depth=self._num_components,                         # == k
        on_value=np.ones([], dtype=npdt),
        off_value=np.zeros([], dtype=npdt))                 # [n, B, k]
    mask = distribution_utils.pad_mixture_dimensions(
        mask, self, self.mixture_distribution,
        self._event_shape().ndims)                          # [n, B, k, [1]*e]
    return math_ops.reduce_sum(
        x * mask, axis=-1 - self._event_ndims)              # [n, B, E]
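# A minimal NumPy sketch (illustrative only, not part of the library) of the
# one-hot masking trick used in _sample_n: draw one sample from every
# component, draw the component index from the mixture distribution, then
# select the chosen component by multiplying with a one-hot mask padded over
# the event dimension and summing out the mixture axis k.
import numpy as np

rng = np.random.RandomState(0)
n, k, e = 4, 2, 3
x = rng.randn(n, k, e)                          # one sample per component
choices = rng.choice(k, size=n, p=[0.3, 0.7])   # mixture component per draw
mask = np.eye(k)[choices][..., np.newaxis]      # [n, k, 1], like the padded one_hot
samples = np.sum(x * mask, axis=1)              # [n, e]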
def test_pad_mixture_dimensions_mixture(self):
  with self.test_session() as sess:
    gm = mixture.Mixture(
        cat=categorical.Categorical(probs=[[0.3, 0.7]]),
        components=[
            normal.Normal(loc=[-1.0], scale=[1.0]),
            normal.Normal(loc=[1.0], scale=[0.5])
        ])

    x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
    x_pad = distribution_util.pad_mixture_dimensions(
        x, gm, gm.cat, gm.event_shape.ndims)
    x_out, x_pad_out = sess.run([x, x_pad])

    self.assertAllEqual(x_pad_out.shape, [2, 2])
    self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
def _sample_n(self, n, seed=None):
  if self._use_static_graph:
    # This sampling approach is almost the same as the approach used by
    # `MixtureSameFamily`. The differences are due to having a list of
    # `Distribution` objects rather than a single object, and maintaining
    # random seed management that is consistent with the non-static code
    # path.
    samples = []
    cat_samples = self.cat.sample(n, seed=seed)
    for c in range(self.num_components):
      seed = distribution_util.gen_new_seed(seed, "mixture")
      samples.append(self.components[c].sample(n, seed=seed))
    x = array_ops.stack(
        samples, -self._static_event_shape.ndims - 1)        # [n, B, k, E]
    npdt = x.dtype.as_numpy_dtype
    mask = array_ops.one_hot(
        indices=cat_samples,                                  # [n, B]
        depth=self._num_components,                           # == k
        on_value=np.ones([], dtype=npdt),
        off_value=np.zeros([], dtype=npdt))                   # [n, B, k]
    mask = distribution_utils.pad_mixture_dimensions(
        mask, self, self._cat,
        self._static_event_shape.ndims)                       # [n, B, k, [1]*e]
    return math_ops.reduce_sum(
        x * mask, axis=-1 - self._static_event_shape.ndims)   # [n, B, E]

  with ops.control_dependencies(self._assertions):
    n = ops.convert_to_tensor(n, name="n")
    static_n = tensor_util.constant_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    static_samples_shape = cat_samples.get_shape()
    if static_samples_shape.is_fully_defined():
      samples_shape = static_samples_shape.as_list()
      samples_size = static_samples_shape.num_elements()
    else:
      samples_shape = array_ops.shape(cat_samples)
      samples_size = array_ops.size(cat_samples)
    static_batch_shape = self.batch_shape
    if static_batch_shape.is_fully_defined():
      batch_shape = static_batch_shape.as_list()
      batch_size = static_batch_shape.num_elements()
    else:
      batch_shape = self.batch_shape_tensor()
      batch_size = math_ops.reduce_prod(batch_shape)
    static_event_shape = self.event_shape
    if static_event_shape.is_fully_defined():
      event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
    else:
      event_shape = self.event_shape_tensor()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = array_ops.reshape(
        math_ops.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = data_flow_ops.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = array_ops.reshape(
        array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   Suppose the batch indices are [0, 1, 0, 1, 0, 1]
    #   and the partitions are:
    #     [1 1 0 0 1 1]
    #   After partitioning, the batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    #   i.e.
    #     [1 1] and [0 0 0 0]
    #   Now we sample n=2 from part 0 and n=4 from part 1.
    #   For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
    #   and for part 1 we want samples from batch entries 0, 0, 0, 0
    #   (samples 0, 1, 2, 3).
    partitioned_batch_indices = data_flow_ops.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    for c in range(self.num_components):
      n_class = array_ops.size(partitioned_samples_indices[c])
      seed = distribution_util.gen_new_seed(seed, "mixture")
      samples_class_c = self.components[c].sample(n_class, seed=seed)

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.
      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * math_ops.range(n_class) +
          partitioned_batch_indices[c])
      samples_class_c = array_ops.reshape(
          samples_class_c,
          array_ops.concat([[n_class * batch_size], event_shape], 0))
      samples_class_c = array_ops.gather(
          samples_class_c, lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = data_flow_ops.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = array_ops.reshape(
        lhs_flat_ret,
        array_ops.concat([samples_shape, self.event_shape_tensor()], 0))
    ret.set_shape(
        tensor_shape.TensorShape(static_samples_shape).concatenate(
            self.event_shape))
    return ret
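# A minimal NumPy sketch (illustrative only, not part of the library) of the
# partition/stitch bookkeeping above, using the example from the comment:
# flat sample indices are partitioned by the categorical draw, per-component
# values are produced for each partition, and then scattered back to their
# original positions, which is what tf.dynamic_stitch does. The per-component
# "sample" values (10*c + j) are made up so the placement is easy to follow.
import numpy as np

cat_samples = np.array([1, 1, 0, 0, 1, 1])
raw_indices = np.arange(cat_samples.size)
partitions = [raw_indices[cat_samples == c] for c in range(2)]
# Pretend component c produces the values 10*c + j for its j-th draw.
values = [10 * c + np.arange(len(part)) for c, part in enumerate(partitions)]
stitched = np.empty(cat_samples.size)
for part, vals in zip(partitions, values):
  stitched[part] = vals                  # scatter back to original positions
print(stitched)                          # [10. 11.  0.  1. 12. 13.]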