def test_mixture_dev(self):
  """Checks `mixture_stddev` against two hand-computed mixtures."""
  # Each row below describes one independent three-component mixture:
  # mixing weights, component means and component standard deviations.
  weights = np.array([[1.0 / 3, 1.0 / 3, 1.0 / 3],
                      [0.750, 0.250, 0.000]])
  means = np.array([[1.0, 1.0, 1.0],
                    [-5, 0, 1.25]])
  devs = np.array([[1.0, 1.0, 1.0],
                   [0.01, 2.0, 0.1]])

  # Row 0: every component is identical with stddev 1.0, so the mixture's
  # stddev is trivially 1.0.
  # Row 1 (law of total variance; the third component has zero weight):
  #   mean = 0.75 * -5 + 0.25 * 0 = -3.75
  #   var  = sum_i w_i * (sigma_i**2 + mu_i**2) - mean**2
  #        = 0.75 * 25.0001 + 0.25 * 4.0 - 14.0625 = 5.687575
  #   std  = sqrt(5.687575) ~= 2.3848637277
  expected = np.array([1.0, 2.3848637277])

  mix_dev = distribution_util.mixture_stddev(
      array_ops.constant(weights),
      array_ops.constant(means),
      array_ops.constant(devs))
  with self.test_session() as sess:
    self.assertAllClose(sess.run(mix_dev), expected)
def _stddev(self):
  """Returns the standard deviation of the mixture.

  Collects each component's mean and standard deviation and pairs them with
  the categorical mixing probabilities. The per-element computation is
  delegated to `distribution_utils.mixture_stddev`, which is fed 2-D
  `[num_elements, num_components]` tensors.

  Returns:
    A tensor shaped like the broadcast mixing probabilities with the trailing
    component axis removed, i.e. list(batch_shape) + list(event_shape).
  """
  num_components = len(self.components)
  # Layout handed to mixture_stddev: one row per flattened element, one
  # column per mixture component.
  flat_shape = [-1, num_components]
  with ops.control_dependencies(self._assertions):
    component_means = [dist.mean() for dist in self.components]
    component_devs = [dist.stddev() for dist in self.components]
    probs = self._cat_probs(log_probs=False)

    means_stacked = array_ops.stack(component_means, axis=-1)
    devs_stacked = array_ops.stack(component_devs, axis=-1)

    # NOTE(review): _expand_to_event_rank presumably appends singleton event
    # dimensions so the mixing probabilities broadcast against the stacked
    # means; multiplying by ones_like materializes that full broadcast shape.
    expanded_probs = [self._expand_to_event_rank(p) for p in probs]
    probs_full = (array_ops.stack(expanded_probs, axis=-1) *
                  array_ops.ones_like(means_stacked))

    flat_dev = distribution_utils.mixture_stddev(
        array_ops.reshape(probs_full, flat_shape),
        array_ops.reshape(means_stacked, flat_shape),
        array_ops.reshape(devs_stacked, flat_shape))
    # Restore the leading dims: list(batch_shape) + list(event_shape).
    result_shape = array_ops.shape(probs_full)[:-1]
    return array_ops.reshape(flat_dev, result_shape)
def _stddev(self):
  """Returns the standard deviation of the mixture.

  Collects each component's mean and standard deviation and pairs them with
  the categorical mixing probabilities. The per-element computation is
  delegated to `distribution_utils.mixture_stddev`, which is fed 2-D
  `[num_elements, num_components]` tensors.

  Returns:
    A tensor shaped like the broadcast mixing probabilities with the trailing
    component axis removed, i.e. list(batch_shape) + list(event_shape).
  """
  num_components = len(self.components)
  # Layout handed to mixture_stddev: one row per flattened element, one
  # column per mixture component.
  flat_shape = [-1, num_components]
  with ops.control_dependencies(self._assertions):
    component_means = [dist.mean() for dist in self.components]
    component_devs = [dist.stddev() for dist in self.components]
    probs = self._cat_probs(log_probs=False)

    means_stacked = array_ops.stack(component_means, axis=-1)
    devs_stacked = array_ops.stack(component_devs, axis=-1)

    # NOTE(review): _expand_to_event_rank presumably appends singleton event
    # dimensions so the mixing probabilities broadcast against the stacked
    # means; multiplying by ones_like materializes that full broadcast shape.
    expanded_probs = [self._expand_to_event_rank(p) for p in probs]
    probs_full = (array_ops.stack(expanded_probs, axis=-1) *
                  array_ops.ones_like(means_stacked))

    flat_dev = distribution_utils.mixture_stddev(
        array_ops.reshape(probs_full, flat_shape),
        array_ops.reshape(means_stacked, flat_shape),
        array_ops.reshape(devs_stacked, flat_shape))
    # Restore the leading dims: list(batch_shape) + list(event_shape).
    result_shape = array_ops.shape(probs_full)[:-1]
    return array_ops.reshape(flat_dev, result_shape)
def test_mixture_dev(self):
  """Checks `mixture_stddev` against two hand-computed mixtures."""
  # Each row below describes one independent three-component mixture:
  # mixing weights, component means and component standard deviations.
  weights = np.array([[1.0 / 3, 1.0 / 3, 1.0 / 3],
                      [0.750, 0.250, 0.000]])
  means = np.array([[1.0, 1.0, 1.0],
                    [-5, 0, 1.25]])
  devs = np.array([[1.0, 1.0, 1.0],
                   [0.01, 2.0, 0.1]])

  # Row 0: every component is identical with stddev 1.0, so the mixture's
  # stddev is trivially 1.0.
  # Row 1 (law of total variance; the third component has zero weight):
  #   mean = 0.75 * -5 + 0.25 * 0 = -3.75
  #   var  = sum_i w_i * (sigma_i**2 + mu_i**2) - mean**2
  #        = 0.75 * 25.0001 + 0.25 * 4.0 - 14.0625 = 5.687575
  #   std  = sqrt(5.687575) ~= 2.3848637277
  expected = np.array([1.0, 2.3848637277])

  mix_dev = distribution_util.mixture_stddev(
      array_ops.constant(weights),
      array_ops.constant(means),
      array_ops.constant(devs))
  with self.test_session() as sess:
    self.assertAllClose(sess.run(mix_dev), expected)