コード例 #1
0
ファイル: datasets_test.py プロジェクト: Crazyonxh/tensorflow
 def testNestedOutputs(self):
   """Iterating a zip of nested datasets yields matching nested values.

   Every element has the structure (a, (b, c)) and each leaf is expected to
   equal the element's position in the iteration.
   """
   ds = Dataset.zip(
       (Dataset.range(4), Dataset.zip((Dataset.range(4), Dataset.range(4)))))
   count = 0
   for idx, element in enumerate(datasets.Iterator(ds)):
     # The leaves are Tensor objects; unwrap each with .numpy() so the
     # nested tuple can be compared against plain Python integers.
     head, pair = element[0], element[1]
     observed = (head.numpy(), (pair[0].numpy(), pair[1].numpy()))
     self.assertEqual(observed, (idx, (idx, idx)))
     count += 1
   self.assertEqual(4, count)
コード例 #2
0
 def testNestedOutputs(self):
   """Check that a nested zip of ranges produces (i, (i, i)) at step i."""
   left = Dataset.range(4)
   right = Dataset.zip((Dataset.range(4), Dataset.range(4)))
   ds = Dataset.zip((left, right))
   seen = 0
   # Each element is a nested structure of Tensors; convert every leaf with
   # .numpy() before comparing with plain integers.
   for position, item in enumerate(datasets.Iterator(ds)):
     expected = (position, (position, position))
     inner = item[1]
     actual = (item[0].numpy(), (inner[0].numpy(), inner[1].numpy()))
     self.assertEqual(actual, expected)
     seen += 1
   self.assertEqual(4, seen)
コード例 #3
0
ファイル: datasets_test.py プロジェクト: Crazyonxh/tensorflow
  def testMapAndFilter(self):
    """Square 0..7, keep only even results, and check the survivors."""

    def is_even(value):
      # True when value mod 2 equals zero.
      return math_ops.equal(math_ops.mod(value, 2), 0)

    squared = Dataset.range(8).map(math_ops.square)
    it = datasets.Iterator(squared.filter(is_even))
    values = [element.numpy() for element in it]
    self.assertAllEqual([0, 4, 16, 36], values)
コード例 #4
0
    def testMapAndFilter(self):
        """Verify map(square) followed by filter(even) over range(8)."""
        def keep_even(v):
            # Predicate for filter(): element is divisible by two.
            remainder = math_ops.mod(v, 2)
            return math_ops.equal(remainder, 0)

        pipeline = Dataset.range(8).map(math_ops.square).filter(keep_even)
        results = [tensor.numpy() for tensor in datasets.Iterator(pipeline)]
        self.assertAllEqual([0, 4, 16, 36], results)
コード例 #5
0
    def testPyFunc(self):
        """A py_func applied through Dataset.map adds one to each element."""
        def add_one(values):
            # Returns a one-entry outer list; Tout below is a single
            # dtypes.int64, so presumably one output per entry — confirm.
            incremented = [v + 1 for v in values]
            return [incremented]

        def wrap(elem):
            return script_ops.py_func(add_one, [[elem]], dtypes.int64)

        ds = Dataset.range(4).map(wrap)
        got = [t.numpy() for t in datasets.Iterator(ds)]
        self.assertAllEqual([[1], [2], [3], [4]], got)
コード例 #6
0
ファイル: datasets_test.py プロジェクト: Crazyonxh/tensorflow
  def testMultipleIteratorsOnTheSameDataset(self):
    """Two iterators over one dataset each see the full sequence."""
    ds = Dataset.range(4)
    # Both iterators are created before either is consumed, so they coexist.
    iterators = [datasets.Iterator(ds), datasets.Iterator(ds)]
    for iterator in iterators:
      self.assertAllEqual([0, 1, 2, 3], [x.numpy() for x in iterator])
コード例 #7
0
    def testMultipleIteratorsOnTheSameDataset(self):
        """Independent iterators over a shared dataset do not interfere."""
        source = Dataset.range(4)
        first = datasets.Iterator(source)
        second = datasets.Iterator(source)
        expected = [0, 1, 2, 3]
        # Drain the first iterator completely before touching the second.
        self.assertAllEqual(expected, [v.numpy() for v in first])
        self.assertAllEqual(expected, [v.numpy() for v in second])
コード例 #8
0
  def testPyFunc(self):
    """Dataset.map over a py_func yields [i + 1] for each i in range(4)."""

    def increment(batch):
      # Outer list holds a single entry; Tout below is a single dtypes.int64.
      return [[item + 1 for item in batch]]

    ds = Dataset.range(4).map(
        lambda elem: script_ops.py_func(increment, [[elem]], dtypes.int64))
    iterator = datasets.Iterator(ds)
    got = [value.numpy() for value in iterator]
    self.assertAllEqual([[1], [2], [3], [4]], got)
コード例 #9
0
  def testMapAndFilter(self):
    """map(square).filter(even) over range(8); currently skipped."""
    # TODO(ashankar): Address this
    # skipTest raises unittest.SkipTest, so nothing below runs until the
    # skip is removed.
    self.skipTest('Not working yet, requires function attribute support')

    def is_even(n):
      return math_ops.equal(math_ops.mod(n, 2), 0)

    squares = Dataset.range(8).map(math_ops.square)
    got = [t.numpy() for t in datasets.Iterator(squares.filter(is_even))]
    self.assertAllEqual([0, 4, 16, 36], got)
コード例 #10
0
ファイル: datasets_test.py プロジェクト: Crazyonxh/tensorflow
 def testBasic(self):
   """Iterating Dataset.range(4) yields the values 0, 1, 2, 3 in order."""
   # Build the list with a comprehension (as the sibling tests do) rather
   # than a manual append loop; each Tensor is unwrapped with .numpy().
   got = [t.numpy() for t in datasets.Iterator(Dataset.range(4))]
   self.assertAllEqual([0, 1, 2, 3], got)
コード例 #11
0
 def testBasic(self):
     """Iterating Dataset.range(4) yields the values 0, 1, 2, 3 in order."""
     it = datasets.Iterator(Dataset.range(4))
     # Comprehension instead of a for/append accumulation loop; each
     # Tensor is converted to a plain value via .numpy().
     got = [t.numpy() for t in it]
     self.assertAllEqual([0, 1, 2, 3], got)