示例#1
0
    def start_bundle(self):
        """Builds the AstroWaveNet inference graph and restores its weights.

        Creates a fresh ``tf.Graph`` with a scalar string placeholder for one
        serialized example and parses it with
        ``kepler_light_curves.parse_example``. The ``example_id`` and ``time``
        tensors are popped off. Every remaining parsed feature is reshaped with
        ``[1, -1, 1]`` (``[length] -> [1, length, 1]``) and fed to an
        ``AstroWaveNet`` built in PREDICT mode. The variables are then restored
        from ``self.checkpoint_file`` into a new ``tf.Session``.

        On success, the placeholder, the parsed ``example_id`` and ``time``
        tensors, the model, and the session are stored on ``self``.

        Raises:
          Any error raised by ``saver.restore`` or by evaluating the global
          step. The newly created session is closed before the error is
          re-raised, so a failed bundle does not leak it.
        """
        # Build the model.
        g = tf.Graph()
        with g.as_default():
            example_placeholder = tf.placeholder(tf.string, shape=[])
            parsed_features = kepler_light_curves.parse_example(
                example_placeholder)
            parsed_example_id = parsed_features.pop("example_id")
            parsed_time = parsed_features.pop("time")
            features = {
                # Add extra dimensions: [length] -> [1, length, 1].
                feature_name: tf.reshape(value, [1, -1, 1])
                for feature_name, value in parsed_features.items()
            }
            model = astrowavenet_model.AstroWaveNet(
                features=features,
                hparams=self.config.hparams,
                mode=tf.estimator.ModeKeys.PREDICT)
            model.build()
            saver = tf.train.Saver()

        sess = tf.Session(graph=g)
        try:
            saver.restore(sess, self.checkpoint_file)
            global_step = sess.run(model.global_step)
        except Exception:
            # Release the session's resources before propagating the failure;
            # otherwise every failed bundle leaks an open session.
            sess.close()
            raise
        tf.logging.info("Successfully loaded checkpoint %s at global step %d.",
                        self.checkpoint_file, global_step)

        self.example_placeholder = example_placeholder
        self.parsed_example_id = parsed_example_id
        self.parsed_time = parsed_time
        self.model = model
        self.session = sess
    def test_causality(self):
        """Checks which network outputs a single input or context impulse affects.

        Each example in the batch carries at most one impulse, placed either in
        the autoregressive input or in the conditioning context. The test
        asserts exactly which output positions become nonzero. An input
        impulse at step t shows up at step t + 1, because input elements
        predict the next timestamp. An impulse at the last input step
        therefore shows up nowhere. A context impulse at step t shows up at
        step t itself.
        """
        seq_len = 7
        num_input_channels = 1
        num_context_channels = 1

        inputs_ph = tf.placeholder(
            dtype=tf.float32,
            shape=[None, seq_len, num_input_channels],
            name="input")
        context_ph = tf.placeholder(
            dtype=tf.float32,
            shape=[None, seq_len, num_context_channels],
            name="context")

        hparams = configdict.ConfigDict(dict(
            dilation_kernel_width=1,
            skip_output_dim=1,
            preprocess_output_size=1,
            preprocess_kernel_width=1,
            num_residual_blocks=1,
            dilation_rates=[1],
            output_distribution=dict(type="normal", min_scale=0.001),
        ))
        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": inputs_ph,
                "conditioning_stack": context_ph
            }, hparams, tf.estimator.ModeKeys.TRAIN)
        model.build()

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            batch_size = 7
            # Everything is zero except a single 1 at (example, step, 0) for
            # each pair listed below.
            inputs = np.zeros((batch_size, seq_len, num_input_channels),
                              dtype=np.float32)
            context = np.zeros((batch_size, seq_len, num_context_channels),
                               dtype=np.float32)
            for example, step in [(1, 0), (2, 3), (3, 6)]:
                inputs[example, step, 0] = 1
            for example, step in [(4, 0), (5, 3), (6, 6)]:
                context[example, step, 0] = 1

            expected_nonzero = np.zeros((batch_size, seq_len, 1), dtype=bool)
            # Input elements are used to predict the next timestamp; the
            # impulse at the final step (example 3) has no successor.
            expected_nonzero[1, 1, 0] = True
            expected_nonzero[2, 4, 0] = True
            # Context elements are used to predict the current timestamp.
            expected_nonzero[4, 0, 0] = True
            expected_nonzero[5, 3, 0] = True
            expected_nonzero[6, 6, 0] = True

            network_output = sess.run(
                model.network_output,
                feed_dict={inputs_ph: inputs, context_ph: context})
            np.testing.assert_array_equal(
                expected_nonzero, np.greater(np.abs(network_output), 0))
    def test_output_weighted(self):
        """Checks that per-element weights mask the normal-distribution losses.

        The distribution parameters ``loc`` and ``scale`` are fed directly, so
        no conditioning context is needed. The three examples have the
        following weights:
          * example 0: all weights 1;
          * example 1: mixed 0/1 weights;
          * example 2: all weights 0.
        Zero-weighted elements contribute 0 to ``batch_losses``. The expected
        per-example losses equal sum(weighted losses) / sum(weights). Only the
        two examples with a nonzero weight are counted in
        ``num_nonzero_weight_examples`` and averaged into ``total_loss``.
        """
        time_series_length = 6
        input_num_features = 2
        context_num_features = 7

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        # Previously also named "input" (copy-paste slip); TF silently renamed
        # it to "input_1", which made the graph misleading.
        weights_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="weights")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        features = {
            "autoregressive_input": input_placeholder,
            "weights": weights_placeholder,
            "conditioning_stack": context_placeholder
        }
        mode = tf.estimator.ModeKeys.TRAIN
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "normal",
                "min_scale": 0,
            }
        })

        model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
        model.build()

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            step = sess.run(model.global_step)
            self.assertEqual(0, step)

            feed_dict = {
                input_placeholder: [
                    [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
                    [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
                    [[3, 7], [3, 7], [3, 7], [3, 7], [3, 7], [3, 7]],
                ],
                weights_placeholder: [
                    [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]],
                    [[1, 0], [1, 1], [1, 1], [0, 1], [0, 1], [0, 0]],
                    [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
                ],
                # Context is not needed since we explicitly feed the dist params.
                model.dist_params["loc"]: [
                    [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
                    [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
                    [[3, 6], [3, 6], [3, 6], [3, 6], [3, 6], [3, 6]],
                ],
                model.dist_params["scale"]: [
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                ],
            }
            batch_losses, per_example_loss, num_examples, total_loss = sess.run(
                [
                    model.batch_losses, model.per_example_loss,
                    model.num_nonzero_weight_examples, model.total_loss
                ],
                feed_dict=feed_dict)
            np.testing.assert_array_almost_equal(
                [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
                  [0.22579135, 2.22579135], [0.91893853, 1.41893853],
                  [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
                 [[-1.38364656, 0], [-0.69049938, 11.80950062],
                  [0.22579135, 2.22579135], [0, 1.41893853], [0, 1.73708571],
                  [0, 0]], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]],
                batch_losses)
            np.testing.assert_array_almost_equal([5.96392435, 2.19185166, 0],
                                                 per_example_loss)
            # Example 2 has all-zero weights, so it is not counted.
            np.testing.assert_almost_equal(2, num_examples)
            # Mean of the two counted per-example losses.
            np.testing.assert_almost_equal(4.07788801, total_loss)
    def test_build_model(self):
        """Checks variable shapes, the parameter count, and loss output shapes.

        Builds a normal-output AstroWaveNet on placeholders. It then verifies
        the shapes of the preprocessing, residual-block, postprocessing, and
        distribution-parameter variables, and compares the total number of
        trainable parameters with a closed-form count derived from the
        hparams. Finally it runs the model on random data and checks the
        shapes of the three loss tensors.
        """
        time_series_length = 9
        input_num_features = 8
        context_num_features = 7

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "normal",
                "min_scale": 0.001,
            }
        })
        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": input_placeholder,
                "conditioning_stack": context_placeholder
            }, hparams, tf.estimator.ModeKeys.TRAIN)
        model.build()

        variables = {v.op.name: v for v in tf.trainable_variables()}

        # Expected shapes: the preprocessing layer, two of the residual
        # blocks, then the output layers (checked in this order).
        expected_shapes = [
            ("preprocess/causal_conv/kernel", (5, 8, 3)),
            ("preprocess/causal_conv/bias", (3,)),
        ]
        for scope in ("block_0/dilation_1", "block_1/dilation_4"):
            expected_shapes += [
                (scope + "/filter/causal_conv/kernel", (2, 3, 3)),
                (scope + "/filter/causal_conv/bias", (3,)),
                (scope + "/filter/conv1x1/kernel", (1, 7, 3)),
                (scope + "/filter/conv1x1/bias", (3,)),
                (scope + "/gate/causal_conv/kernel", (2, 3, 3)),
                (scope + "/gate/causal_conv/bias", (3,)),
                (scope + "/gate/conv1x1/kernel", (1, 7, 3)),
                (scope + "/gate/conv1x1/bias", (3,)),
                (scope + "/residual/conv1x1/kernel", (1, 3, 3)),
                (scope + "/residual/conv1x1/bias", (3,)),
                (scope + "/skip/conv1x1/kernel", (1, 3, 6)),
                (scope + "/skip/conv1x1/bias", (6,)),
            ]
        expected_shapes += [
            ("postprocess/conv1x1/kernel", (1, 6, 6)),
            ("postprocess/conv1x1/bias", (6,)),
            ("dist_params/conv1x1/kernel", (1, 6, 16)),
            ("dist_params/conv1x1/bias", (16,)),
        ]
        for var_name, shape in expected_shapes:
            self.assertShapeEquals(shape, variables[var_name])

        # Closed-form count of trainable parameters (kernels plus biases).
        channels = hparams.preprocess_output_size
        skip_dim = hparams.skip_output_dim
        preprocess_count = (
            hparams.preprocess_kernel_width * input_num_features * channels +
            channels)
        # Filter and gate each have a causal conv over the residual channels
        # and a 1x1 conv over the context, each with its own bias.
        gated_count = 2 * (
            hparams.dilation_kernel_width * channels * channels + channels +
            context_num_features * channels + channels)
        residual_count = channels * channels + channels
        skip_count = channels * skip_dim + skip_dim
        num_layers = len(hparams.dilation_rates) * hparams.num_residual_blocks
        blocks_count = (gated_count + residual_count + skip_count) * num_layers
        postprocess_count = skip_dim * skip_dim + skip_dim
        dist_count = skip_dim * 2 * input_num_features + 2 * input_num_features
        expected_total = (preprocess_count + blocks_count + postprocess_count +
                          dist_count)

        actual_total = sum(np.prod(v.shape) for v in tf.trainable_variables())
        self.assertEqual(expected_total, actual_total)

        # Run the model on random data and check the loss shapes.
        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            batch_size = 11
            inputs = np.random.random(
                (batch_size, time_series_length, input_num_features))
            context = np.random.random(
                (batch_size, time_series_length, context_num_features))
            batch_losses, per_example_loss, total_loss = sess.run(
                [model.batch_losses, model.per_example_loss, model.total_loss],
                feed_dict={
                    input_placeholder: inputs,
                    context_placeholder: context
                })
            self.assertShapeEquals(
                (batch_size, time_series_length, input_num_features),
                batch_losses)
            self.assertShapeEquals((batch_size,), per_example_loss)
            self.assertShapeEquals((), total_loss)
    def test_output_categorical(self):
        """Checks target quantization and losses for a categorical output.

        Uses 4 classes spanning [0, 1]. Inputs at and between bucket
        boundaries, at the range limits, and outside the range are quantized
        into the class indices listed in the expected targets. The logits are
        fed directly, so no conditioning context is needed.
        """
        time_series_length = 3
        input_num_features = 1
        context_num_features = 7
        num_classes = 4  # For quantized categorical output predictions.

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        output_distribution = {
            "type": "categorical",
            "min_scale": 0,
            "num_classes": num_classes,
            "min_quantization_value": 0,
            "max_quantization_value": 1
        }
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": output_distribution,
        })
        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": input_placeholder,
                "conditioning_stack": context_placeholder
            }, hparams, tf.estimator.ModeKeys.TRAIN)
        model.build()

        self.assertItemsEqual(["logits"], model.dist_params.keys())
        logits = model.dist_params["logits"]
        self.assertShapeEquals(
            (None, time_series_length, input_num_features, num_classes), logits)

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            inputs = [
                [[0], [0], [0]],  # min_quantization_value
                [[0.2], [0.2], [0.2]],  # Within bucket.
                [[0.25], [0.25], [0.25]],  # On bucket boundary.
                [[0.5], [0.5], [0.5]],  # On bucket boundary.
                [[0.8], [0.8], [0.8]],  # Within bucket.
                [[1], [1], [1]],  # max_quantization_value
                [[-0.1], [1.5], [200]],  # Outside range: will be clipped.
            ]
            # In every example the 1 sits on the target class at step 0 and on
            # some other class at steps 1 and 2, so all examples share the
            # same losses.
            fed_logits = [
                [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
                [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
                [[[0, 1, 0, 0]], [[1, 0, 0, 0]], [[0, 0, 1, 0]]],
                [[[0, 0, 1, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
                [[[0, 0, 0, 1]], [[1, 0, 0, 0]], [[1, 0, 0, 0]]],
                [[[0, 0, 0, 1]], [[0, 1, 0, 0]], [[0, 0, 1, 0]]],
                [[[1, 0, 0, 0]], [[0, 0, 1, 0]], [[0, 1, 0, 0]]],
            ]
            fetches = [
                model.autoregressive_target, model.batch_losses,
                model.per_example_loss, model.num_nonzero_weight_examples,
                model.total_loss
            ]
            results = sess.run(
                fetches,
                feed_dict={input_placeholder: inputs, logits: fed_logits})
            target, batch_losses, per_example_loss, num_examples, total_loss = (
                results)

            expected_target = [
                [[0], [0], [0]],
                [[0], [0], [0]],
                [[1], [1], [1]],
                [[2], [2], [2]],
                [[3], [3], [3]],
                [[3], [3], [3]],
                [[0], [3], [3]],
            ]
            np.testing.assert_array_almost_equal(expected_target, target)
            np.testing.assert_array_almost_equal(
                [[[0.74366838], [1.74366838], [1.74366838]]] * 7, batch_losses)
            np.testing.assert_array_almost_equal([1.41033504] * 7,
                                                 per_example_loss)
            np.testing.assert_almost_equal(7, num_examples)
            np.testing.assert_almost_equal(1.41033504, total_loss)
    def test_build_model_categorical(self):
        """Checks output-layer shapes and loss shapes for a categorical model.

        The distribution-parameter layer must emit ``num_classes`` values per
        input feature. Running on random data must yield per-element,
        per-example, and scalar losses.
        """
        time_series_length = 9
        input_num_features = 8
        context_num_features = 7

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "categorical",
                "num_classes": 256,
                "min_quantization_value": -1,
                "max_quantization_value": 1
            }
        })
        model = astrowavenet_model.AstroWaveNet(
            {
                "autoregressive_input": input_placeholder,
                "conditioning_stack": context_placeholder
            }, hparams, tf.estimator.ModeKeys.TRAIN)
        model.build()

        variables = {v.op.name: v for v in tf.trainable_variables()}
        num_logits = hparams.output_distribution.num_classes * input_num_features
        self.assertShapeEquals((1, hparams.skip_output_dim, num_logits),
                               variables["dist_params/conv1x1/kernel"])
        self.assertShapeEquals((num_logits,),
                               variables["dist_params/conv1x1/bias"])

        # Run the model on random data and check the loss shapes.
        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            self.assertEqual(0, sess.run(model.global_step))

            batch_size = 11
            inputs = np.random.random(
                (batch_size, time_series_length, input_num_features))
            context = np.random.random(
                (batch_size, time_series_length, context_num_features))
            batch_losses, per_example_loss, total_loss = sess.run(
                [model.batch_losses, model.per_example_loss, model.total_loss],
                feed_dict={
                    input_placeholder: inputs,
                    context_placeholder: context
                })
            self.assertShapeEquals(
                (batch_size, time_series_length, input_num_features),
                batch_losses)
            self.assertShapeEquals((batch_size,), per_example_loss)
            self.assertShapeEquals((), total_loss)
示例#7
0
    def test_output_normal_mixture(self):
        """Checks losses of a normal output with an outlier distribution.

        With ``predict_outlier_distribution=True`` the model exposes five
        distribution parameters: ``loc`` and ``scale`` per time step and
        feature, plus ``outlier_prob``, ``outlier_loc`` and ``outlier_scale``
        with shape (2,). Here 2 equals ``input_num_features``; presumably these
        are one value per feature, TODO confirm. All parameters are fed
        directly. The same ``feed_dict`` is reused in three phases, changing
        only ``outlier_prob``:
          1. [0, 0]: losses match the plain (non-outlier) normal distribution.
          2. [1, 1]: losses come entirely from the outlier distribution.
          3. [0.3, 0.5]: losses reflect a weighted mix of the two.
        """
        time_series_length = 6
        input_num_features = 2
        context_num_features = 7

        input_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, input_num_features],
            name="input")
        context_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, time_series_length, context_num_features],
            name="context")
        features = {
            "autoregressive_input": input_placeholder,
            "conditioning_stack": context_placeholder
        }
        mode = tf.estimator.ModeKeys.TRAIN
        hparams = configdict.ConfigDict({
            "dilation_kernel_width": 2,
            "skip_output_dim": 6,
            "preprocess_output_size": 3,
            "preprocess_kernel_width": 5,
            "num_residual_blocks": 2,
            "dilation_rates": [1, 2, 4],
            "output_distribution": {
                "type": "normal",
                "min_scale": 0,
                "predict_outlier_distribution": True
            }
        })

        model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
        model.build()

        # Model predicts the loc and scale of the outlier and non-outlier Gaussian
        # distributions, and the probability of being an outlier.
        self.assertItemsEqual(
            ["loc", "scale", "outlier_prob", "outlier_loc", "outlier_scale"],
            model.dist_params.keys())
        self.assertShapeEquals((None, time_series_length, input_num_features),
                               model.dist_params["loc"])
        self.assertShapeEquals((None, time_series_length, input_num_features),
                               model.dist_params["scale"])
        self.assertShapeEquals((2, ), model.dist_params["outlier_prob"])
        self.assertShapeEquals((2, ), model.dist_params["outlier_loc"])
        self.assertShapeEquals((2, ), model.dist_params["outlier_scale"])

        scaffold = tf.train.Scaffold()
        scaffold.finalize()
        with self.cached_session() as sess:
            sess.run([scaffold.init_op, scaffold.local_init_op])
            step = sess.run(model.global_step)
            self.assertEqual(0, step)

            # This dict is reused (and mutated) by all three phases below.
            feed_dict = {
                input_placeholder: [
                    [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
                    [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
                ],
                # Context is not needed since we explicitly feed the dist params.
                model.dist_params["loc"]: [
                    [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
                    [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
                ],
                model.dist_params["scale"]: [
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                    [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2],
                     [5, 5]],
                ],
                model.dist_params["outlier_prob"]: [0, 0],
                model.dist_params["outlier_loc"]: [1, 8],
                model.dist_params["outlier_scale"]: [1, 0.1],
            }
            batch_losses, per_example_loss, num_examples, total_loss = sess.run(
                [
                    model.batch_losses, model.per_example_loss,
                    model.num_nonzero_weight_examples, model.total_loss
                ],
                feed_dict=feed_dict)

            # Outlier probability is 0.0, so predictions are from the non-outlier
            # distribution.
            np.testing.assert_array_almost_equal(
                [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
                  [0.22579135, 2.22579135], [0.91893853, 1.41893853],
                  [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
                 [[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
                  [0.22579135, 2.22579135], [0.91893853, 1.41893853],
                  [1.61208571, 1.73708571], [2.52837645, 2.54837645]]],
                batch_losses)
            np.testing.assert_array_almost_equal([5.96392435, 5.96392435],
                                                 per_example_loss)
            np.testing.assert_almost_equal(2, num_examples)
            np.testing.assert_almost_equal(5.96392435, total_loss)

            # Outlier probability is 1.0, so predictions are from the outlier
            # distribution.
            feed_dict[model.dist_params["outlier_prob"]] = [1, 1]
            batch_losses, per_example_loss, num_examples, total_loss = sess.run(
                [
                    model.batch_losses, model.per_example_loss,
                    model.num_nonzero_weight_examples, model.total_loss
                ],
                feed_dict=feed_dict)
            # The outlier params are constant across time steps, so each
            # example's losses repeat identically over all 6 steps.
            np.testing.assert_array_almost_equal(
                [[[0.918939, 48.616352]] * 6, [[1.418939, -1.383647]] * 6],
                batch_losses)
            # Looser tolerances (5 and 6 decimals) are used for this phase.
            np.testing.assert_array_almost_equal([24.7676468, 0.017645916],
                                                 per_example_loss, 5)
            np.testing.assert_almost_equal(2, num_examples)
            np.testing.assert_almost_equal(12.392646358, total_loss, decimal=6)

            # Predictions are weighted from the non-outlier and outlier distributions.
            feed_dict[model.dist_params["outlier_prob"]] = [0.3, 0.5]
            batch_losses, per_example_loss, num_examples, total_loss = sess.run(
                [
                    model.batch_losses, model.per_example_loss,
                    model.num_nonzero_weight_examples, model.total_loss
                ],
                feed_dict=feed_dict)
            np.testing.assert_array_almost_equal(
                [[[-1.06893575, 48.61635208], [-0.41606259, 12.5026474],
                  [0.38831028, 2.91893864], [0.91893858, 2.11208582],
                  [1.34972155, 2.430233], [1.73991919, 3.24152374]],
                 [[-1.05263364, -0.69049942], [-0.38450652, -0.69050133],
                  [0.46027452, -0.71720666], [1.04454803, -0.74938428],
                  [1.55012715, -0.73367846], [2.05226898, -0.70991373]]],
                batch_losses)
            np.testing.assert_array_almost_equal([6.227806, -0.051759],
                                                 per_example_loss)
            np.testing.assert_almost_equal(2, num_examples)
            np.testing.assert_almost_equal(3.0880234, total_loss, decimal=6)