コード例 #1
0
    def testUnbatchSingleElementTupleDataset(self):
        """Unbatching a batched tuple-of-1-tuples dataset restores its elements.

        Three separate `range(10)` tensors are each wrapped in a one-element
        tuple and sliced into a dataset. The dataset is batched by 2 and then
        unbatched. The element type structure must be the same after both
        steps. The unbatched dataset must then yield ((i,), (i,), (i,)) for
        i in 0..9, followed by OutOfRangeError.
        """
        slices = tuple((math_ops.range(10),) for _ in range(3))
        dataset = dataset_ops.Dataset.from_tensor_slices(slices)
        int_types = ((dtypes.int32,),) * 3

        dataset = dataset.batch(2)
        # The batched dataset must report the same element types.
        self.assertEqual(int_types, dataset.output_types)
        dataset = dataset.apply(dataset_ops.unbatch())
        self.assertEqual(int_types, dataset.output_types)

        next_element = dataset.make_one_shot_iterator().get_next()

        with self.test_session() as sess:
            for expected in range(10):
                self.assertEqual(((expected,),) * 3, sess.run(next_element))

            # All ten elements have been consumed; one more pull must fail.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
コード例 #2
0
  def testUnbatchSingleElementTupleDataset(self):
    """Checks that unbatch() undoes batch() on a tuple of 1-tuples.

    Three `range(10)` components, each wrapped in a one-element tuple, are
    batched by 2 and then unbatched. Both steps must keep the element type
    structure, and the unbatched dataset must produce exactly ten elements.
    """
    slices = tuple((math_ops.range(10),) for _ in range(3))
    source = dataset_ops.Dataset.from_tensor_slices(slices)
    int_types = ((dtypes.int32,),) * 3

    batched = source.batch(2)
    self.assertEqual(int_types, batched.output_types)

    unbatched = batched.apply(dataset_ops.unbatch())
    self.assertEqual(int_types, unbatched.output_types)

    get_next = unbatched.make_one_shot_iterator().get_next()

    with self.test_session() as sess:
      for value in range(10):
        self.assertEqual(((value,),) * 3, sess.run(get_next))

      # Once all ten elements are consumed, the iterator is exhausted.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
コード例 #3
0
    def testUnbatchMultiElementTupleDataset(self):
        """Unbatching a batched tuple of (int32, string) pairs restores it.

        Component ``i`` (for i in 0..2) pairs ``range(10 * i, 10 * i + 10)``
        with ten ``"hi"`` strings. After batching by 2 and unbatching, the
        nested element type structure must be unchanged. The dataset must
        yield ((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")) for i in 0..9
        and then raise OutOfRangeError.
        """
        data = tuple(
            (math_ops.range(10 * i, 10 * i + 10), array_ops.fill([10], "hi"))
            for i in range(3))
        data = dataset_ops.Dataset.from_tensor_slices(data)
        expected_types = ((dtypes.int32, dtypes.string), ) * 3
        data = data.batch(2)
        # Compare the nested type structure with assertEqual, as the
        # single-element test does. assertAllEqual converts its arguments to
        # arrays, so it would not check the exact tuple nesting.
        self.assertEqual(expected_types, data.output_types)
        data = data.apply(dataset_ops.unbatch())
        self.assertEqual(expected_types, data.output_types)

        iterator = data.make_one_shot_iterator()
        op = iterator.get_next()

        with self.test_session() as sess:
            for i in range(10):
                self.assertEqual(
                    ((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
                    sess.run(op))

            # All ten elements have been consumed; one more pull must fail.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(op)
コード例 #4
0
  def testUnbatchMultiElementTupleDataset(self):
    """Unbatching a batched tuple of (int32, string) pairs restores it.

    Component ``i`` (for i in 0..2) pairs ``range(10 * i, 10 * i + 10)``
    with ten ``"hi"`` strings. After batching by 2 and unbatching, the nested
    element type structure must be unchanged. The dataset must yield
    ((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")) for i in 0..9 and then
    raise OutOfRangeError.
    """
    data = tuple((math_ops.range(10 * i, 10 * i + 10),
                  array_ops.fill([10], "hi"))
                 for i in range(3))
    data = dataset_ops.Dataset.from_tensor_slices(data)
    expected_types = ((dtypes.int32, dtypes.string),) * 3
    data = data.batch(2)
    # Compare the nested type structure with assertEqual, as the
    # single-element test does. assertAllEqual converts its arguments to
    # arrays, so it would not check the exact tuple nesting.
    self.assertEqual(expected_types, data.output_types)
    data = data.apply(dataset_ops.unbatch())
    self.assertEqual(expected_types, data.output_types)

    iterator = data.make_one_shot_iterator()
    op = iterator.get_next()

    with self.test_session() as sess:
      for i in range(10):
        self.assertEqual(((i, b"hi"),
                          (10 + i, b"hi"),
                          (20 + i, b"hi")),
                         sess.run(op))

      # All ten elements have been consumed; one more pull must fail.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(op)