예제 #1
0
    def testBadLabelIdsRaisesValueError(self):
        """Non-zero-based label ids must be rejected by build_dataset."""
        self._input_config["label_feature"] = "label_str"

        # Ids start at 1 here; build_dataset requires contiguous ids from 0.
        self._input_config["label_map"] = {"PC": 1, "AFP": 2, "NTP": 3}

        with self.assertRaises(ValueError):
            dataset_ops.build_dataset(
                file_pattern=self._file_pattern,
                input_config=self._input_config,
                batch_size=4)
예제 #2
0
  def testBadLabelIdsRaisesValueError(self):
    """build_dataset raises when label ids are not contiguous from zero."""
    config = self._input_config
    config["label_feature"] = "label_str"
    # The smallest id below is 1, so the contiguous-from-0 check must fail.
    config["label_map"] = {"PC": 1, "AFP": 2, "NTP": 3}

    with self.assertRaises(ValueError):
      dataset_ops.build_dataset(file_pattern=self._file_pattern,
                                input_config=config,
                                batch_size=4)
예제 #3
0
  def input_fn(config, params):
    """Builds a TFRecord input pipeline and returns (features, labels)."""
    # TPUEstimator passes a tpu.RunConfig; a plain Estimator does not, so
    # the config type tells us which mode the pipeline should be built for.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    dataset = dataset_ops.build_dataset(
        file_pattern=file_pattern,
        input_config=input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=shuffle_filenames,
        shuffle_values_buffer=shuffle_values_buffer,
        repeat=repeat,
        use_tpu=use_tpu)

    # The pipeline contains a stateful label lookup table, so a one-shot
    # iterator cannot be used; register the initializer in the
    # TABLE_INITIALIZERS collection so it runs with the other tables.
    data_iter = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, data_iter.initializer)

    # Split the labels out of the feature dict before returning.
    features = data_iter.get_next()
    labels = features.pop("labels", None)
    return features, labels
예제 #4
0
    def testUnknownLabel(self):
        """Encountering a label string missing from label_map is an error."""
        self._input_config["label_feature"] = "label_str"
        # The map below omits "NTP", which does occur in the test data.
        self._input_config["label_map"] = {"PC": 1, "AFP": 0}

        dataset = dataset_ops.build_dataset(file_pattern=self._file_pattern,
                                            input_config=self._input_config,
                                            batch_size=4)

        # Labels go through a stateful hash table, so the iterator must be
        # initializable rather than one-shot.
        data_iter = dataset.make_initializable_iterator()
        batch = data_iter.get_next()
        init_op = tf.tables_initializer()

        # Features and labels are both expected in the output dict.
        expected_keys = ["time_series_features", "aux_features", "labels"]
        self.assertItemsEqual(expected_keys, batch.keys())

        with self.test_session() as sess:
            sess.run([init_op, data_iter.initializer])

            # Fetching labels hits the unmapped "NTP" entry and must raise.
            with self.assertRaises(tf.errors.InvalidArgumentError):
                sess.run(batch["labels"])
예제 #5
0
    def testLabels2(self):
        """Two label strings may share one id; batches come out as mapped."""
        self._input_config["label_feature"] = "label_str"
        self._input_config["label_map"] = {"PC": 1, "AFP": 0, "NTP": 0}

        dataset = dataset_ops.build_dataset(file_pattern=self._file_pattern,
                                            input_config=self._input_config,
                                            batch_size=4)

        # The label id table is stateful, so use an initializable iterator.
        data_iter = dataset.make_initializable_iterator()
        batch = data_iter.get_next()
        init_op = tf.tables_initializer()

        # Output should contain the feature dicts plus labels.
        self.assertItemsEqual(
            ["time_series_features", "aux_features", "labels"], batch.keys())
        labels = batch["labels"]

        with self.test_session() as sess:
            sess.run([init_op, data_iter.initializer])

            # Exactly three batches: 4 + 4 + 2 examples.
            for expected in ([1, 0, 0, 1], [0, 0, 1, 0], [0, 1]):
                np.testing.assert_array_equal(expected, sess.run(labels))

            # The dataset is exhausted after the third batch.
            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(labels)
예제 #6
0
  def testUnknownLabel(self):
    """A label string with no entry in label_map raises at run time."""
    self._input_config["label_feature"] = "label_str"
    # "NTP" appears in the data but is intentionally missing here.
    self._input_config["label_map"] = {"PC": 1, "AFP": 0}

    dataset = dataset_ops.build_dataset(file_pattern=self._file_pattern,
                                        input_config=self._input_config,
                                        batch_size=4)

    # An initializable iterator is needed because of the stateful label id
    # hash table.
    iterator = dataset.make_initializable_iterator()
    inputs = iterator.get_next()

    # The output dict carries features together with labels.
    self.assertItemsEqual(
        ["time_series_features", "aux_features", "labels"], inputs.keys())

    with self.test_session() as sess:
      sess.run([tf.tables_initializer(), iterator.initializer])
      # Fetching labels hits the unknown "NTP" entry.
      with self.assertRaises(tf.errors.InvalidArgumentError):
        sess.run(inputs["labels"])
예제 #7
0
  def testLabels2(self):
    """Distinct label strings mapped to the same id are batched correctly."""
    self._input_config["label_feature"] = "label_str"
    self._input_config["label_map"] = {"PC": 1, "AFP": 0, "NTP": 0}

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4)

    # The label table is stateful, hence the initializable iterator.
    iterator = dataset.make_initializable_iterator()
    inputs = iterator.get_next()
    init_op = tf.tables_initializer()

    # Expect the feature dicts plus a labels entry.
    self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
                          inputs.keys())
    labels = inputs["labels"]

    expected_batches = [[1, 0, 0, 1], [0, 0, 1, 0], [0, 1]]
    with self.test_session() as sess:
      sess.run([init_op, iterator.initializer])

      # Three batches of mapped ids, then the pipeline is exhausted.
      for want in expected_batches:
        np.testing.assert_array_equal(want, sess.run(labels))
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(labels)
  def input_fn(config, params):
    """Builds and returns a TFRecord-backed tf.data input pipeline."""
    # TPUEstimator hands this function a tpu.RunConfig; a plain Estimator
    # does not, so the config type tells us which mode to build for.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    return dataset_ops.build_dataset(
        file_pattern=file_pattern,
        input_config=input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=shuffle_filenames,
        shuffle_values_buffer=shuffle_values_buffer,
        repeat=repeat,
        use_tpu=use_tpu)
예제 #9
0
  def __call__(self, config, params):
    """Builds the input pipeline for the configured mode."""
    # The config type reveals whether we run under TPUEstimator (which
    # passes a tpu.RunConfig) or a plain Estimator.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    mode = self._mode
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    # Labels are needed for training and evaluation, but not for predict.
    include_labels = is_training or mode == tf.estimator.ModeKeys.EVAL
    # Randomly reverse time series (and shuffle files) only while training.
    reverse_time_series_prob = 0.5 if is_training else 0
    return dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=is_training,
        shuffle_values_buffer=self._shuffle_values_buffer,
        repeat=self._repeat,
        use_tpu=use_tpu)
예제 #10
0
  def __call__(self, config, params):
    """Constructs and returns the tf.data input pipeline."""
    # TPUEstimator is detected by the tpu.RunConfig type it passes in.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    mode = self._mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      # Training: labels on, random time-series reversal, shuffled files.
      include_labels = True
      reverse_time_series_prob = 0.5
      shuffle_filenames = True
    else:
      # Eval still needs labels; predict does not. No shuffling/reversal.
      include_labels = mode == tf.estimator.ModeKeys.EVAL
      reverse_time_series_prob = 0
      shuffle_filenames = False

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=shuffle_filenames,
        shuffle_values_buffer=self._shuffle_values_buffer,
        repeat=self._repeat,
        use_tpu=use_tpu)

    return dataset
예제 #11
0
    def testTPU(self):
        """Without labels, a one-shot iterator yields batches of 4, 4, 2."""
        dataset = dataset_ops.build_dataset(file_pattern=self._file_pattern,
                                            input_config=self._input_config,
                                            batch_size=4,
                                            include_labels=False)

        # No labels means no stateful hash table, so a one-shot iterator
        # works here.
        iterator = dataset.make_one_shot_iterator()
        features = iterator.get_next()

        # Only the feature dicts are expected in the output.
        self.assertItemsEqual(["time_series_features", "aux_features"],
                              features.keys())

        # Every example carries identical time series; only aux_feature
        # varies between examples.
        expected_aux = [
            [[100], [101], [102], [103]],  # batch 1
            [[104], [105], [106], [107]],  # batch 2
            [[108], [109]],                # final, short batch
        ]
        with self.test_session() as sess:
            for aux in expected_aux:
                f = sess.run(features)
                size = len(aux)
                np.testing.assert_array_almost_equal(
                    [[0, 1, 2, 3, 4, 5, 6, 7]] * size,
                    f["time_series_features"]["global_view"])
                np.testing.assert_array_almost_equal(
                    [[0, 1, 2, 3]] * size,
                    f["time_series_features"]["local_view"])
                np.testing.assert_array_almost_equal(
                    aux, f["aux_features"]["aux_feature"])

            # The pipeline must be exhausted after three batches.
            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(features)
예제 #12
0
 def testNonExistentFileRaisesValueError(self):
     """A file pattern matching no files should raise ValueError."""
     with self.assertRaises(ValueError):
         dataset_ops.build_dataset(
             file_pattern="nonexistent",
             input_config=self._input_config,
             batch_size=4)
예제 #13
0
  def testTPU(self):
    """Label-free datasets work with a one-shot iterator (TPU path)."""
    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4,
        include_labels=False)

    # With include_labels=False there is no stateful label table, so a
    # one-shot iterator suffices.
    one_shot = dataset.make_one_shot_iterator()
    features = one_shot.get_next()

    # Only the two feature dicts should be present.
    self.assertItemsEqual(["time_series_features", "aux_features"],
                          features.keys())

    global_row = [0, 1, 2, 3, 4, 5, 6, 7]
    local_row = [0, 1, 2, 3]
    aux_batches = [
        [[100], [101], [102], [103]],
        [[104], [105], [106], [107]],
        [[108], [109]],
    ]
    with self.test_session() as sess:
      # Two full batches of 4 followed by a final batch of 2.
      for expected_aux in aux_batches:
        fetched = sess.run(features)
        count = len(expected_aux)
        np.testing.assert_array_almost_equal(
            [global_row] * count,
            fetched["time_series_features"]["global_view"])
        np.testing.assert_array_almost_equal(
            [local_row] * count,
            fetched["time_series_features"]["local_view"])
        np.testing.assert_array_almost_equal(
            expected_aux, fetched["aux_features"]["aux_feature"])

      # Nothing is left after the third batch.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(features)
예제 #14
0
 def testNonExistentFileRaisesValueError(self):
   """build_dataset must fail fast when the file pattern matches nothing."""
   with self.assertRaises(ValueError):
     dataset_ops.build_dataset(file_pattern="nonexistent",
                               input_config=self._input_config,
                               batch_size=4)