Ejemplo n.º 1
0
  def test_mixture_dev(self):
    """Checks mixture_stddev on two independent 3-component mixtures.

    Row 0: all components are identical with stddev 1.0, so the mixture
    stddev is trivially 1.0.
    Row 1: reference value worked out by hand:
      mean     = 0.75 * -5 + 0.25 * 0 + 0.0 * 1.25 = -3.75
      variance = sum_k w_k * (sigma_k**2 + mu_k**2) - mean**2
               = 0.75 * 25.0001 + 0.25 * 4.0 - 14.0625 = 5.687575
      stddev   = sqrt(5.687575) ~= 2.3848637277
    """
    weights = np.array([[1.0 / 3, 1.0 / 3, 1.0 / 3],
                        [0.750, 0.250, 0.000]])
    means = np.array([[1.0, 1.0, 1.0],
                      [-5, 0, 1.25]])
    devs = np.array([[1.0, 1.0, 1.0],
                     [0.01, 2.0, 0.1]])
    expected = np.array([1.0, 2.3848637277])

    # Build the graph: one mixture per row, components along the last axis.
    mix_dev = distribution_util.mixture_stddev(
        array_ops.constant(weights),
        array_ops.constant(means),
        array_ops.constant(devs))

    with self.test_session() as sess:
      actual = sess.run(mix_dev)

    self.assertAllClose(actual, expected)
Ejemplo n.º 2
0
  def _stddev(self):
    """Returns the standard deviation of the mixture.

    Per-component means and stddevs are stacked with the categorical
    probabilities along a new trailing component axis. The stacked tensors
    are flattened to 2-D [-1, num_components] and passed to
    `distribution_utils.mixture_stddev`. The result is then reshaped back to
    the stacked shape without its trailing component axis (i.e.
    batch_shape + event_shape). All work runs under `self._assertions`.
    """
    with ops.control_dependencies(self._assertions):
      num_components = len(self.components)

      # All means are built first, then all stddevs, then the mixing probs.
      component_means = [dist.mean() for dist in self.components]
      component_devs = [dist.stddev() for dist in self.components]
      probs = self._cat_probs(log_probs=False)

      means_stacked = array_ops.stack(component_means, axis=-1)
      devs_stacked = array_ops.stack(component_devs, axis=-1)
      expanded_probs = [self._expand_to_event_rank(p) for p in probs]
      # Multiplying by ones_like(means) broadcasts the probabilities to the
      # same shape as the stacked means/stddevs.
      probs_full = (array_ops.stack(expanded_probs, axis=-1) *
                    array_ops.ones_like(means_stacked))

      # Collapse every leading dimension so mixture_stddev sees 2-D inputs
      # with components on the last axis.
      flat_shape = [-1, num_components]
      flat_dev = distribution_utils.mixture_stddev(
          array_ops.reshape(probs_full, flat_shape),
          array_ops.reshape(means_stacked, flat_shape),
          array_ops.reshape(devs_stacked, flat_shape))

      # Restore list(batch_shape) + list(event_shape) by dropping the
      # trailing component axis from the broadcast shape.
      return array_ops.reshape(flat_dev,
                               array_ops.shape(probs_full)[:-1])
Ejemplo n.º 3
0
  def _stddev(self):
    """Computes the mixture's standard deviation.

    The components' means and stddevs, and the (event-rank-expanded)
    categorical probabilities, are each stacked on a new last axis. They
    are then flattened to [-1, K] for `distribution_utils.mixture_stddev`,
    and the result is reshaped back to list(batch_shape) + list(event_shape).
    Everything is gated on `self._assertions`.
    """
    with ops.control_dependencies(self._assertions):
      k = len(self.components)

      def _flatten(tensor):
        # Fold all leading dims into one so each row is one mixture.
        return array_ops.reshape(tensor, [-1, k])

      mean_list = [component.mean() for component in self.components]
      dev_list = [component.stddev() for component in self.components]
      prob_list = self._cat_probs(log_probs=False)

      all_means = array_ops.stack(mean_list, axis=-1)
      all_devs = array_ops.stack(dev_list, axis=-1)
      prob_list = [self._expand_to_event_rank(prob) for prob in prob_list]
      # Broadcast the probabilities up to the full shape of the stacked means.
      all_probs = array_ops.stack(prob_list, axis=-1)
      all_probs = all_probs * array_ops.ones_like(all_means)

      flat_result = distribution_utils.mixture_stddev(
          _flatten(all_probs), _flatten(all_means), _flatten(all_devs))

      # Drop the trailing component axis to recover batch + event shape.
      output_shape = array_ops.shape(all_probs)[:-1]
      return array_ops.reshape(flat_result, output_shape)
Ejemplo n.º 4
0
    def test_mixture_dev(self):
        """Verifies mixture_stddev on two rows of 3-component mixtures.

        Row 0 has three identical unit-stddev components, so the answer is
        exactly 1.0. Row 1 was computed by hand:
        mean = -3.75, variance = 0.75 * 25.0001 + 0.25 * 4.0 - 14.0625
        = 5.687575, so the stddev is sqrt(5.687575) ~= 2.3848637277.
        """
        probs = np.array([
            [1.0 / 3, 1.0 / 3, 1.0 / 3],
            [0.750, 0.250, 0.000],
        ])
        locs = np.array([
            [1.0, 1.0, 1.0],
            [-5, 0, 1.25],
        ])
        scales = np.array([
            [1.0, 1.0, 1.0],
            [0.01, 2.0, 0.1],
        ])
        want = np.array([1.0, 2.3848637277])

        probs_t = array_ops.constant(probs)
        locs_t = array_ops.constant(locs)
        scales_t = array_ops.constant(scales)
        dev_t = distribution_util.mixture_stddev(probs_t, locs_t, scales_t)

        with self.test_session() as sess:
            got = sess.run(dev_t)

        self.assertAllClose(got, want)