Ejemplo n.º 1
0
    def testBadLabelIdsRaisesValueError(self):
        """build_dataset should reject a label_map whose ids do not start at 0."""
        config = self._input_config
        config["label_feature"] = "label_str"
        # Ids 1..3 are contiguous but start at 1 instead of 0, which the test
        # expects build_dataset to treat as invalid.
        config["label_map"] = {"PC": 1, "AFP": 2, "NTP": 3}

        with self.assertRaises(ValueError):
            dataset_ops.build_dataset(
                file_pattern=self._file_pattern,
                input_config=config,
                batch_size=4)
Ejemplo n.º 2
0
    def testUnknownLabel(self):
        """Fetching a label that is missing from label_map should fail at run time."""
        cfg = self._input_config
        cfg["label_feature"] = "label_str"
        # "NTP" is deliberately left out of the mapping.
        cfg["label_map"] = {"PC": 1, "AFP": 0}

        ds = dataset_ops.build_dataset(
            file_pattern=self._file_pattern,
            input_config=cfg,
            batch_size=4)

        # Label ids come from a stateful hash table, so the iterator has to be
        # initializable rather than one-shot.
        it = ds.make_initializable_iterator()
        batch = it.get_next()
        table_init = tf.tables_initializer()

        # Both feature groups and the labels should be present.
        expected_keys = ["time_series_features", "aux_features", "labels"]
        self.assertItemsEqual(expected_keys, batch.keys())
        label_tensor = batch["labels"]

        with self.test_session() as session:
            session.run([table_init, it.initializer])

            # Looking up the unmapped "NTP" label is expected to error out.
            with self.assertRaises(tf.errors.InvalidArgumentError):
                session.run(label_tensor)
Ejemplo n.º 3
0
    def testLabels2(self):
        """Labels are mapped through a label_map where two names share id 0."""
        cfg = self._input_config
        cfg["label_feature"] = "label_str"
        cfg["label_map"] = {"PC": 1, "AFP": 0, "NTP": 0}

        ds = dataset_ops.build_dataset(
            file_pattern=self._file_pattern,
            input_config=cfg,
            batch_size=4)

        # Label ids come from a stateful hash table, so the iterator has to be
        # initializable rather than one-shot.
        it = ds.make_initializable_iterator()
        batch = it.get_next()
        table_init = tf.tables_initializer()

        # Both feature groups and the labels should be present.
        expected_keys = ["time_series_features", "aux_features", "labels"]
        self.assertItemsEqual(expected_keys, batch.keys())
        label_tensor = batch["labels"]

        with self.test_session() as session:
            session.run([table_init, it.initializer])

            # Three batches are expected: two full ones and a final one of 2.
            expected_batches = ([1, 0, 0, 1], [0, 0, 1, 0], [0, 1])
            for want in expected_batches:
                np.testing.assert_array_equal(want, session.run(label_tensor))

            # After the last batch the dataset should be exhausted.
            with self.assertRaises(tf.errors.OutOfRangeError):
                session.run(label_tensor)
Ejemplo n.º 4
0
    def input_fn(config, params):
        """Build the input dataset for an Estimator or TPUEstimator.

        All build_dataset arguments other than batch_size and use_tpu are taken
        from variables of the enclosing scope (file_pattern, input_config,
        include_labels, reverse_time_series_prob, shuffle_filenames,
        shuffle_values_buffer, repeat).

        Args:
          config: Run configuration supplied by the estimator. An instance of
            tf.contrib.tpu.RunConfig switches build_dataset into TPU mode.
          params: Mapping of hyperparameters; only "batch_size" is read.

        Returns:
          The dataset produced by dataset_ops.build_dataset.
        """
        # Which estimator invoked us is not passed explicitly, so infer it from
        # the type of the run config it handed over.
        running_on_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

        return dataset_ops.build_dataset(
            file_pattern=file_pattern,
            input_config=input_config,
            batch_size=params["batch_size"],
            include_labels=include_labels,
            reverse_time_series_prob=reverse_time_series_prob,
            shuffle_filenames=shuffle_filenames,
            shuffle_values_buffer=shuffle_values_buffer,
            repeat=repeat,
            use_tpu=running_on_tpu)
Ejemplo n.º 5
0
    def testTPU(self):
        """Build a label-free dataset and check every batch of features."""
        # NOTE(review): despite its name, this test does not pass use_tpu=True
        # to build_dataset (which accepts that keyword), so it exercises the
        # default non-TPU path. Confirm whether that is intended.
        ds = dataset_ops.build_dataset(
            file_pattern=self._file_pattern,
            input_config=self._input_config,
            batch_size=4,
            include_labels=False)

        # Without labels there is no stateful label-id hash table, so a
        # one-shot iterator is sufficient.
        it = ds.make_one_shot_iterator()
        features = it.get_next()

        # Only the two feature groups are expected; there is no "labels" key.
        self.assertItemsEqual(["time_series_features", "aux_features"],
                              features.keys())

        global_row = [0, 1, 2, 3, 4, 5, 6, 7]
        local_row = [0, 1, 2, 3]
        with self.test_session() as session:
            # Expected batches hold 4, 4, then a final 2 examples. The
            # aux_feature values count up from 100 across all batches.
            next_aux = 100
            for batch_len in (4, 4, 2):
                batch = session.run(features)
                series = batch["time_series_features"]
                np.testing.assert_array_almost_equal(
                    [global_row] * batch_len, series["global_view"])
                np.testing.assert_array_almost_equal(
                    [local_row] * batch_len, series["local_view"])
                np.testing.assert_array_almost_equal(
                    [[next_aux + j] for j in range(batch_len)],
                    batch["aux_features"]["aux_feature"])
                next_aux += batch_len

            # After the last batch the dataset should be exhausted.
            with self.assertRaises(tf.errors.OutOfRangeError):
                session.run(features)
Ejemplo n.º 6
0
 def testNonExistentFileRaisesValueError(self):
     """build_dataset should raise ValueError for an unmatched file pattern."""
     # "nonexistent" is presumably a pattern that matches no files.
     missing_pattern = "nonexistent"
     with self.assertRaises(ValueError):
         dataset_ops.build_dataset(
             file_pattern=missing_pattern,
             input_config=self._input_config,
             batch_size=4)