コード例 #1
0
  def testSplitSources(self):
    """Checks `src.ids` of both halves returned by a two-way SplitInputBatch."""
    params = self._CreateNmtInputParams()
    shard_count = 2

    # Expected `src.ids` for split 0.
    want_first = [
        [228, 58, 854, 11, 392, 45, 77, 67, 1346, 30, 25, 10, 2283, 933, 14,
         3, 872, 4677, 5, 2],
        [328, 22, 463, 571, 134, 10, 3815, 6311, 8, 2203, 3, 654, 2724, 1064,
         5, 2, 0, 0, 0, 0],
    ]

    # Expected `src.ids` for split 1.
    want_second = [
        [16, 599, 11, 8, 113, 3, 145, 558, 489, 4373, 36, 55, 8988, 5, 2, 0,
         0, 0, 0, 0],
        [16, 343, 95, 296, 4550, 4786, 1798, 23019, 8, 10296, 3, 107, 6428,
         1812, 5, 2, 0, 0, 0, 0],
    ]

    with self.session(use_gpu=False) as sess:
      nmt_input = input_generator.NmtInput(params)
      shards = nmt_input.SplitInputBatch(shard_count)
      # Fetch both shards in a single run call.
      got_first, got_second = sess.run([shards[0].src.ids, shards[1].src.ids])
      self.assertAllEqual(want_first, got_first)
      self.assertAllEqual(want_second, got_second)
コード例 #2
0
 def testBasic(self):
   """Builds an NmtInput and runs its preprocessed input batch ten times."""
   num_steps = 10
   params = self._CreateNmtInputParams()
   with self.session(use_gpu=False) as sess:
     nmt_input = input_generator.NmtInput(params)
     # No assertions here: the test passes as long as every run succeeds.
     for _ in range(num_steps):
       sess.run(nmt_input.GetPreprocessedInputBatch())
コード例 #3
0
  def testSplitTargets(self):
    """Checks `tgt.ids` of both halves returned by a two-way SplitInputBatch."""
    # Expected `tgt.ids` for split 0.
    want_first = [
        [1, 400, 5548, 12, 583, 43, 61, 179, 1265, 22, 27, 7193, 16, 5, 782,
         14077, 6734, 4, 0],
        [1, 1639, 32, 1522, 93, 38, 6812, 2624, 9, 2440, 3, 39, 11, 2364,
         24238, 9, 317, 4, 0],
    ]

    # Expected `tgt.ids` for split 1.
    want_second = [
        [1, 53, 17787, 12, 3, 5, 1554, 871, 9, 1398, 3, 2784, 18, 25579, 942,
         29828, 5998, 77, 4],
        [1, 67, 4141, 11483, 2008, 6, 483, 46, 23, 14852, 3, 39, 5, 9732,
         495, 3176, 21523, 4, 0],
    ]

    params = self._CreateNmtInputParams()
    shard_count = 2

    with self.session(use_gpu=False) as sess:
      nmt_input = input_generator.NmtInput(params)
      # Run the whole split structure; comparisons happen after the session.
      shards = sess.run(nmt_input.SplitInputBatch(shard_count))

    self.assertAllEqual(want_first, shards[0].tgt.ids)
    self.assertAllEqual(want_second, shards[1].tgt.ids)
コード例 #4
0
    def testSplitTargets(self):
        """Checks `tgt.ids` of both halves of a two-way SplitInputBatch."""
        # Expected `tgt.ids` for split 0.
        want_first = [
            [1, 272, 7514, 10944, 2220, 815, 3, 39, 6, 3021, 4893, 10, 6693,
             23788, 3410, 0, 0, 0, 0],
            [1, 28, 18764, 6, 1413, 2338, 8068, 107, 431, 14, 6, 1083, 3,
             11, 782, 19664, 9, 3622, 4],
        ]

        # Expected `tgt.ids` for split 1.
        want_second = [
            [1, 15149, 12, 583, 43, 61, 179, 1265, 22, 27, 7193, 16, 5, 782,
             14077, 6734, 4, 0, 0],
            [1, 81, 90, 1397, 9207, 61, 241, 2102, 15, 3003, 424, 6, 483, 4,
             0, 0, 0, 0, 0],
        ]

        params = self._CreateNmtInputParams()
        shard_count = 2

        with self.session(use_gpu=False):
            nmt_input = input_generator.NmtInput(params)
            shards = self.evaluate(nmt_input.SplitInputBatch(shard_count))

        got_first = shards[0].tgt.ids
        got_second = shards[1].tgt.ids

        # Log the fetched ids before asserting so a failure is easy to debug.
        tf.logging.info('fetched[0].tgt.ids = %r', got_first)
        tf.logging.info('fetched[1].tgt.ids = %r', got_second)
        self.assertAllEqual(want_first, got_first)
        self.assertAllEqual(want_second, got_second)
コード例 #5
0
    def testSplitSources(self):
        """Checks `src.ids` of both halves of a two-way SplitInputBatch."""
        params = self._CreateNmtInputParams()
        shard_count = 2

        # Expected `src.ids` for split 0.
        want_first = [
            [93, 15027, 643, 8, 2985, 3, 27025, 6, 4569, 2, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0],
            [15027, 1668, 4125, 54, 139, 24, 3, 101, 8, 2031, 5545, 2962, 5,
             2, 0, 0, 0, 0, 0, 0],
        ]

        # Expected `src.ids` for split 1.
        want_second = [
            [626, 854, 11, 392, 45, 77, 67, 1346, 30, 25, 10, 2283, 933, 14,
             22255, 425, 872, 4677, 5, 2],
            [52, 21, 1034, 4, 3, 274, 30, 7203, 6275, 3, 967, 795, 142, 5,
             2, 0, 0, 0, 0, 0],
        ]

        with self.session(use_gpu=False):
            nmt_input = input_generator.NmtInput(params)
            shards = nmt_input.SplitInputBatch(shard_count)
            # Evaluate both shards together.
            got_first, got_second = self.evaluate(
                [shards[0].src.ids, shards[1].src.ids])
            # Log the fetched ids before asserting to ease debugging.
            tf.logging.info('split_ids[0] = %r', got_first)
            tf.logging.info('split_ids[1] = %r', got_second)
            self.assertAllEqual(want_first, got_first)
            self.assertAllEqual(want_second, got_second)
コード例 #6
0
  def testPadToMax(self):
    """Checks batch shape and tail values when pad_to_max_seq_length is set.

    The params use one bucket (upper bound 20, batch limit 4) and max source
    and target lengths of 30. Every checked tensor must have shape (4, 30),
    and columns 20 onward must hold the expected fill value.
    """
    params = self._CreateNmtInputParams()
    params.bucket_upper_bound = [20]
    params.bucket_batch_limit = [4]
    params.source_max_length = 30
    params.target_max_length = 30
    params.pad_to_max_seq_length = True
    with self.session(use_gpu=False) as sess:
      nmt_input = input_generator.NmtInput(params)
      raw_batch = sess.run(nmt_input.GetPreprocessedInputBatch())
      batch = py_utils.NestedMap(raw_batch)

    def _ExpectPadded(tensor, fill):
      # Shape must be (4, 30).
      self.assertEqual(tensor.shape, (4, 30))
      # The last 10 columns must all equal `fill`.
      tail = tensor[:, 20:]
      self.assertAllEqual(tail, np.full((4, 10), fill))

    _ExpectPadded(batch.src.ids, 0)
    _ExpectPadded(batch.src.paddings, 1)
    _ExpectPadded(batch.tgt.ids, 0)
    _ExpectPadded(batch.tgt.labels, 0)
    _ExpectPadded(batch.tgt.weights, 0)
    _ExpectPadded(batch.tgt.paddings, 1)