def test_regression_normal_log_prob_means_and_stddevs_2d(
            self, model_output_type):
        """RegressionNormalLogProb agrees with a Normal built from the ensemble.

        Three members each emit two examples whose rows hold two means
        followed by two ``np.log``-transformed stddevs (the outputs are
        flagged with ``outputs_with_log_stddevs=True``). The expected NLL is
        the negated log-prob of the labels under a ``tfd.Normal`` whose loc
        is the ensemble mean and whose scale is the square root of the
        ensemble variance, both taken from ``stats.RegressionOutputs`` fed
        the same members.
        """
        raw_outputs = tf.constant([
            [[0.3, 0.4, np.log(0.01), np.log(0.02)],
             [1.6, 0.6, np.log(2.0), np.log(0.01)]],
            [[0.2, 0.2, np.log(0.1), np.log(0.2)],
             [0.8, 0.5, np.log(0.5), np.log(0.2)]],
            [[0.4, 0.6, np.log(1.0), np.log(1.5)],
             [2.4, 0.4, np.log(0.05), np.log(0.1)]],
        ])
        labels = tf.constant([[0.2, 0.4], [1.4, 1.0]])
        member_outputs = NewRegressionModelOutputs(
            raw_outputs, model_output_type, outputs_with_log_stddevs=True)
        member_indices = range(3)

        # Reference: a Normal parameterized by the ensemble mean/variance.
        moments = stats.RegressionOutputs(outputs_with_log_stddevs=True)
        for idx in member_indices:
            moments.update(member_outputs[idx])
        ens_mean, ens_var = moments.result()
        reference_dist = tfd.Normal(ens_mean, ens_var**0.5)
        expected_nll = -reference_dist.log_prob(labels)

        # Metric under test, fed the same members and labels.
        metric = stats.RegressionNormalLogProb(outputs_with_log_stddevs=True)
        for idx in member_indices:
            metric.update(member_outputs[idx], labels)
        self.assertAllClose(expected_nll, metric.result(), atol=TOL)
    def test_regression_outputs_only_means_1d(self, model_output_type):
        """Mean-only 1-D outputs from three members, no stddev supplied.

        The ensemble mean must be the member average. Each ensemble variance
        must match ``_get_mixture_variance`` for an equally weighted mixture
        whose components all have stddev 1.0; this is the value the test
        expects when neither constructor receives a stddev.
        """
        # member_means[m][e] == [mean] for member m and example e.
        member_means = [
            [[0.3], [0.6]],
            [[0.2], [0.5]],
            [[0.4], [0.4]],
        ]
        model_outputs = NewRegressionModelOutputs(tf.constant(member_means),
                                                  model_output_type)

        ensemble = stats.RegressionOutputs()
        for member_idx in range(len(member_means)):
            ensemble.update(model_outputs[member_idx])
        means, variances = ensemble.result()

        for example, want in ((0, 0.3), (1, 0.5)):
            self.assertAlmostEqual(want, float(means[example][0]),
                                   delta=TOL)

        for example in range(2):
            expected_variance = self._get_mixture_variance(
                probs=[1 / 3] * 3,
                means=[member[example][0] for member in member_means],
                stddevs=[1.0] * 3)
            self.assertAlmostEqual(float(expected_variance),
                                   float(variances[example][0]),
                                   delta=1e-5)
    def test_regression_outputs_only_means_2d_diff_stddev(
            self, model_output_type):
        """Mean-only 2-D outputs combined under a fixed stddev of 0.1.

        The members emit means only. Both ``NewRegressionModelOutputs`` and
        ``stats.RegressionOutputs`` are constructed with ``stddev=0.1``. The
        ensemble mean must be the member average, and every ensemble
        variance must match ``_get_mixture_variance`` with each member's
        stddev set to 0.1.
        """
        # member_means[m][e] == [mean_0, mean_1] for member m, example e.
        member_means = [
            [[0.3, 0.4], [1.6, 0.6]],
            [[0.2, 0.2], [0.8, 0.5]],
            [[0.4, 0.6], [2.4, 0.4]],
        ]
        model_outputs = NewRegressionModelOutputs(tf.constant(member_means),
                                                  model_output_type,
                                                  stddev=0.1)

        ensemble = stats.RegressionOutputs(stddev=0.1)
        for member_idx in range(len(member_means)):
            ensemble.update(model_outputs[member_idx])
        means, variances = ensemble.result()

        expected_means = [
            ((0, 0), 0.3),
            ((0, 1), 0.4),
            ((1, 0), 1.6),
            ((1, 1), 0.5),
        ]
        for (example, dim), want in expected_means:
            self.assertAlmostEqual(want, float(means[example][dim]),
                                   delta=TOL)

        # Expected mixture, does not have to use normal distributions
        for example in range(2):
            for dim in range(2):
                expected_variance = self._get_mixture_variance(
                    probs=[1 / 3] * 3,
                    means=[member[example][dim] for member in member_means],
                    stddevs=[0.1] * 3)
                self.assertAlmostEqual(float(expected_variance),
                                       float(variances[example][dim]),
                                       delta=1e-5)
    def test_regression_outputs_means_and_variances_2d(self,
                                                       model_output_type):
        """Per-member means and log-stddevs combine into mixture moments.

        For each example, every member emits
        ``[mean_0, mean_1, log(stddev_0), log(stddev_1)]``. The ensemble mean
        must be the member average. Each ensemble variance must match
        ``_get_mixture_variance`` for an equally weighted mixture of the
        members' (mean, stddev) pairs in that example and dimension.
        """
        # members[m][e] == ((mean_0, mean_1), (stddev_0, stddev_1)) for
        # member m and example e.
        members = [
            [((0.3, 0.4), (0.01, 0.02)), ((1.6, 0.6), (2.0, 0.01))],
            [((0.2, 0.2), (0.1, 0.2)), ((0.8, 0.5), (0.5, 0.2))],
            [((0.4, 0.6), (1.0, 1.5)), ((2.4, 0.4), (0.05, 0.1))],
        ]
        # Rows are means followed by log-stddevs; np.log is applied to each
        # scalar individually so the values equal a hand-written literal.
        tensor_model_outputs = tf.constant([
            [list(mus) + [np.log(sd) for sd in sds] for mus, sds in member]
            for member in members
        ])
        model_outputs = NewRegressionModelOutputs(
            tensor_model_outputs,
            model_output_type,
            outputs_with_log_stddevs=True)

        ensemble = stats.RegressionOutputs(outputs_with_log_stddevs=True)
        for member_idx in range(len(members)):
            ensemble.update(model_outputs[member_idx])
        means, variances = ensemble.result()

        expected_means = [
            ((0, 0), 0.3),
            ((0, 1), 0.4),
            ((1, 0), 1.6),
            ((1, 1), 0.5),
        ]
        for (example, dim), want in expected_means:
            self.assertAlmostEqual(want, float(means[example][dim]),
                                   delta=TOL)

        # Expected mixture, does not have to use normal distributions
        for example in range(2):
            for dim in range(2):
                expected_variance = self._get_mixture_variance(
                    probs=[1 / 3] * 3,
                    means=[m[example][0][dim] for m in members],
                    stddevs=[m[example][1][dim] for m in members])
                self.assertAlmostEqual(float(expected_variance),
                                       float(variances[example][dim]),
                                       delta=1e-5)