示例#1
0
    def test_gather_tree(self):
        # Runs gather_tree on a two-entry batch and compares the result with
        # a hard-coded expectation.
        #
        # Every literal below is written as (batch_size = 2, max_time = 3,
        # beam_width = 3) and then transposed to
        # (max_time = 3, batch_size = 2, beam_width = 3).
        batch_time_swap = [1, 0, 2]

        step_ids = np.array(
            [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
             [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
            dtype=np.int32).transpose(batch_time_swap)
        beam_parents = np.array(
            [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
             [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
            dtype=np.int32).transpose(batch_time_swap)

        # One length per batch entry (batch_size = 2); both equal max_time.
        full_lengths = [3, 3]

        expected = np.array(
            [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
             [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose(batch_time_swap)

        # end_token=11 does not appear anywhere in the step ids above.
        gathered = beam_search_ops.gather_tree(
            step_ids,
            beam_parents,
            max_sequence_lengths=full_lengths,
            end_token=11)

        with self.cached_session() as sess:
            gathered_value = sess.run(gathered)

        self.assertAllEqual(expected, gathered_value)
示例#2
0
 def testGatherTreeOne(self):
     # Single-batch gather_tree check against a hard-coded expectation.
     #
     # The nested lists are written as (batch_size = 1, max_time = 4,
     # beams = 3) and passed through _transpose_batch_time; presumably that
     # helper yields (max_time = 4, batch_size = 1, beams = 3) -- confirm
     # against its definition.
     end_token = 10
     max_sequence_lengths = [3]

     raw_steps = [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]
     raw_parents = [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]
     # The final time step lies past the sequence length of 3 and is
     # expected to come back filled with end_token.
     raw_expected = [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]]

     step_ids = _transpose_batch_time(raw_steps)
     parent_ids = _transpose_batch_time(raw_parents)
     expected_result = _transpose_batch_time(raw_expected)

     beams = beam_search_ops.gather_tree(
         step_ids=step_ids,
         parent_ids=parent_ids,
         max_sequence_lengths=max_sequence_lengths,
         end_token=end_token)

     with self.cached_session(use_gpu=True):
         beams_value = self.evaluate(beams)
         self.assertAllEqual(expected_result, beams_value)
示例#3
0
    def testGatherTreeBatch(self):
        # Property-style check of gather_tree on random inputs shaped
        # (max_time, batch_size, beam_width):
        #   * the output keeps the input shape,
        #   * every position at or past max_sequence_lengths[b] holds
        #     end_token,
        #   * within one beam, no -1 appears, every id before the first
        #     end_token is non-negative, and everything after it is end_token.
        batch_size = 10
        beam_width = 15
        max_time = 8
        # Some lengths are deliberately >= max_time; for those the
        # "past the length" slice below is simply empty.
        max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
        end_token = 5

        with self.cached_session(use_gpu=True):
            # Step ids are drawn from [0, end_token], so end_token itself can
            # show up in the random input.
            step_ids = np.random.randint(0,
                                         high=end_token + 1,
                                         size=(max_time, batch_size,
                                               beam_width))
            parent_ids = np.random.randint(0,
                                           high=beam_width - 1,
                                           size=(max_time, batch_size,
                                                 beam_width))

            beams = beam_search_ops.gather_tree(
                step_ids=step_ids.astype(np.int32),
                parent_ids=parent_ids.astype(np.int32),
                max_sequence_lengths=max_sequence_lengths,
                end_token=end_token)

            self.assertEqual((max_time, batch_size, beam_width), beams.shape)
            beams_value = self.evaluate(beams)
            for b in range(batch_size):
                # Past max_sequence_lengths[b], we emit all end tokens.
                b_value = beams_value[max_sequence_lengths[b]:, b, :]
                self.assertAllClose(b_value, end_token * np.ones_like(b_value))
            for batch, beam in itertools.product(range(batch_size),
                                                 range(beam_width)):
                v = np.squeeze(beams_value[:, batch, beam])
                if end_token in v:
                    found_bad = np.where(v == -1)[0]
                    self.assertEqual(0, len(found_bad))
                    found = np.where(v == end_token)[0]
                    found = found[0]  # First occurrence of end_token.
                    # If an end_token is found, everything before it should be
                    # a valid id and everything after it should also be
                    # end_token.
                    if found > 0:
                        # v[:found] covers every position strictly before the
                        # first end_token, including the one right before it.
                        before_end = v[:found]
                        self.assertAllEqual(
                            before_end >= 0,
                            np.ones_like(before_end, dtype=bool))
                    after_end = v[found + 1:]
                    self.assertAllClose(
                        after_end, end_token * np.ones_like(after_end))
示例#4
0
 def testBadParentValuesOnCPU(self):
     # With the op pinned to the CPU, a negative parent id is expected to
     # surface as an op error whose message names the offending position.
     #
     # The nested lists are written as (batch_size = 1, max_time = 4,
     # beams = 3) and passed through _transpose_batch_time.
     end_token = 10
     max_sequence_lengths = [3]
     step_ids = _transpose_batch_time(
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     # Bad parent (-1) sits in beam 1 at time 1.
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])

     expected_error = r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"
     with ops.device("/cpu:0"), self.assertRaisesOpError(expected_error):
         # Both building and evaluating the op stay inside the assertion
         # context so the error is caught wherever it is raised.
         beams = beam_search_ops.gather_tree(
             step_ids=step_ids,
             parent_ids=parent_ids,
             max_sequence_lengths=max_sequence_lengths,
             end_token=end_token)
         self.evaluate(beams)
示例#5
0
 def testBadParentValuesOnGPU(self):
     # Only want to run this test on CUDA devices, as gather_tree is not
     # registered for SYCL devices.
     cuda_available = test.is_gpu_available(cuda_only=True)
     if not cuda_available:
         return

     # The nested lists are written as (batch_size = 1, max_time = 4,
     # beams = 3) and passed through _transpose_batch_time.
     end_token = 10
     max_sequence_lengths = [3]
     step_ids = _transpose_batch_time(
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     # Bad parent (-1) in beam 1 at time 1.
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
     # On the GPU no error is expected; the bad parent appears as a -1 in
     # beam 1 at time 0 of the output.
     expected_result = _transpose_batch_time(
         [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]])

     with ops.device("/device:GPU:0"):
         beams = beam_search_ops.gather_tree(
             step_ids=step_ids,
             parent_ids=parent_ids,
             max_sequence_lengths=max_sequence_lengths,
             end_token=end_token)
         beams_value = self.evaluate(beams)
         self.assertAllEqual(expected_result, beams_value)