# Tests for data_reader, excerpted from a tf.test.TestCase suite.
# Imports and the class line are reconstructed for this excerpt; the
# module path of data_reader is an assumption.
import numpy as np
import tensorflow as tf
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.utils import data_reader


class DataReaderTest(tf.test.TestCase):
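
  # The test fixture (setUp, not shown in this excerpt) is assumed to write
  # TFRecord files matched by self.filepatterns for self.problem. Inferred
  # from the assertions below: example i (i in 0..29) stores "inputs" as the
  # value i repeated i + 1 times (int64), plus int64 "targets" and float32
  # "floats" features.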

  def testTrainEvalBehavior(self):
    train_dataset = data_reader.read_examples(
        self.problem, self.filepatterns[0], 16)
    train_examples = train_dataset.make_one_shot_iterator().get_next()
    eval_dataset = data_reader.read_examples(
        self.problem, self.filepatterns[0], 16,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_examples = eval_dataset.make_one_shot_iterator().get_next()

    eval_idxs = []
    with tf.train.MonitoredSession() as sess:
      # Train should be shuffled and repeated indefinitely.
      for i in xrange(30):
        self.assertNotEqual(i, sess.run(train_examples)["inputs"][0])

      # Eval should not be shuffled and should run through only once.
      for i in xrange(30):
        self.assertEqual(i, sess.run(eval_examples)["inputs"][0])
        eval_idxs.append(i)

      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(eval_examples)
        # Should never run because the line above raises.
        eval_idxs.append(30)

      # Ensure the assertRaises block actually ran and we did not exit the
      # MonitoredSession context.
      eval_idxs.append(-1)

    self.assertAllEqual(list(range(30)) + [-1], eval_idxs)

  def testLengthFilter(self):
    max_len = 15
    dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32)
    # Despite its name, _example_too_big is used as a keep-predicate here:
    # the filtered dataset retains examples of length <= max_len, which is
    # what the final assertion checks.
    dataset = dataset.filter(
        lambda ex: data_reader._example_too_big(ex, max_len))
    examples = dataset.make_one_shot_iterator().get_next()
    with tf.train.MonitoredSession() as sess:
      ex_lens = []
      for _ in xrange(max_len):
        ex_lens.append(len(sess.run(examples)["inputs"]))

    self.assertAllEqual(list(range(1, max_len + 1)), sorted(ex_lens))
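
  # For reference, a minimal plain-tf.data sketch of the same length filter
  # (assuming the example length is the leading dimension of "inputs"; not
  # part of the original test):
  #
  #   dataset = dataset.filter(
  #       lambda ex: tf.shape(ex["inputs"])[0] <= max_len)
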
  def testPreprocess(self):
    dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32)
    examples = dataset.make_one_shot_iterator().get_next()
    examples = data_reader._preprocess(examples, self.problem, None, None, None)
    with tf.train.MonitoredSession() as sess:
      ex_val = sess.run(examples)
      # problem.preprocess_examples has been run
      self.assertAllClose([42.42], ex_val["new_field"])

      # int64 has been cast to int32
      self.assertEqual(np.int32, ex_val["inputs"].dtype)
      self.assertEqual(np.int32, ex_val["targets"].dtype)
      self.assertEqual(np.float32, ex_val["floats"].dtype)

  def testBasicExampleReading(self):
    dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32)
    examples = dataset.make_one_shot_iterator().get_next()
    with tf.train.MonitoredSession() as sess:
      # Check that there are multiple examples that have the right fields of
      # the right type (lists of int/float).
      for _ in xrange(10):
        ex_val = sess.run(examples)
        inputs, targets, floats = (ex_val["inputs"], ex_val["targets"],
                                   ex_val["floats"])
        self.assertEqual(np.int64, inputs.dtype)
        self.assertEqual(np.int64, targets.dtype)
        self.assertEqual(np.float32, floats.dtype)
        for field in [inputs, targets, floats]:
          self.assertGreater(len(field), 0)
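
  # Note: raw reading keeps the serialized int64 "inputs"/"targets"; the
  # cast to int32 asserted in testPreprocess above happens later, inside
  # data_reader._preprocess.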

  def testBucketBySeqLength(self):

    def example_len(ex):
      return tf.shape(ex["inputs"])[0]

    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    window_size = 40

    dataset = data_reader.read_examples(
        self.problem, self.filepatterns[0], 32,
        mode=tf.estimator.ModeKeys.EVAL)
    dataset = data_reader.bucket_by_sequence_length(
        dataset, example_len, boundaries, batch_sizes, window_size)
    batch = dataset.make_one_shot_iterator().get_next()

    input_vals = []
    obs_batch_sizes = []
    with tf.train.MonitoredSession() as sess:
      # Run until OutOfRangeError; MonitoredSession suppresses it on exit,
      # so the with block ends cleanly once the dataset is exhausted.
      while True:
        batch_val = sess.run(batch)
        batch_inputs = batch_val["inputs"]
        batch_size, max_len = batch_inputs.shape
        obs_batch_sizes.append(batch_size)
        for inputs in batch_inputs:
          input_val = inputs[0]
          input_vals.append(input_val)
          # The inputs were constructed so that value i is repeated i + 1
          # times (e.g. an example with value 7 holds 7 repeated 8 times).
          repeat = input_val + 1
          # Check padding.
          self.assertAllEqual(
              [input_val] * repeat + [0] * (max_len - repeat),
              inputs)

    # Check that all inputs came through.
    self.assertEqual(list(range(30)), sorted(input_vals))
    # Check that we saw variable batch sizes.
    self.assertGreater(len(set(obs_batch_sizes)), 1)
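
  # For reference, a sketch of equivalent bucketing with stock TF's
  # tf.data.experimental.bucket_by_sequence_length (TF >= 1.13; not what
  # data_reader uses here, and window_size has no direct analogue):
  #
  #   dataset = dataset.apply(
  #       tf.data.experimental.bucket_by_sequence_length(
  #           example_len, boundaries, batch_sizes))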