示例#1
0
  def test_sequence_with_none_permissions_multi_input(self):
    """Checks a three-item dense batch decoded without allowed_transitions.

    Every batch item is the same score matrix, so each decoded row must equal
    the sequence `_decode_greedy_sequence` returns for that single matrix.
    """
    decode_flags = dict(use_log_space=True, use_start_and_end_states=False)
    batch_size = 3
    scores = np.array([[10.0, 12.0, 6.0, 4.0], [13.0, 12.0, 11.0, 10.0]])
    # Keep one matrix row per source line.
    # pyformat: disable
    transition_weights = np.array(
        [
            [-1.0, 1.0, -2.0, 2.0],
            [3.0, -3.0, 4.0, -4.0],
            [5.0, 1.0, 10.0, 1.0],
            [-7.0, 7.0, -8.0, 8.0],
        ],
        dtype=np.float32)
    # pyformat: enable

    expected_row = self._decode_greedy_sequence(
        scores, transition_weights, **decode_flags)

    # Test a multi-item batch: replicate the example into a float32 batch.
    batched_scores = np.array([scores] * batch_size, dtype=np.float32)
    sequence_lengths = [len(scores)] * batch_size

    decode_op = tftext.greedy_constrained_sequence(
        batched_scores,
        sequence_lengths,
        transition_weights=transition_weights,
        **decode_flags)
    decoded = self.evaluate(decode_op)
    self.assertRaggedEqual(decoded, [expected_row] * batch_size)
示例#2
0
  def test_ragged_inputs(self):
    """Checks decoding of a ragged batch of a 3-row and a 2-row score matrix.

    The expected value is a ragged constant built from the per-item results of
    `_decode_greedy_sequence`, called with the same weights, mask and flags
    that are passed to the op under test.
    """
    decode_flags = dict(use_log_space=True, use_start_and_end_states=False)
    per_item_scores = (
        np.array([[10.0, 13.0, 6.0, 4.0],
                  [13.0, 12.0, 11.0, 10.0],
                  [13.0, 12.0, 11.0, 10.0]]),
        np.array([[10.0, 12.0, 6.0, 4.0],
                  [13.0, 12.0, 11.0, 10.0]]),
    )
    # TODO(momernick): Extend RT support to lists-of-ndarrays.
    scores = tf.ragged.constant([item.tolist() for item in per_item_scores])
    # Keep one matrix row per source line.
    # pyformat: disable
    transition_weights = np.array(
        [
            [-1.0, 1.0, -2.0, 2.0],
            [3.0, -3.0, 4.0, -4.0],
            [5.0, 1.0, 10.0, 1.0],
            [-7.0, 7.0, -8.0, 8.0],
        ],
        dtype=np.float32)
    # pyformat: enable

    # Boolean mask: every entry is True except (2, 1) and (2, 3).
    allowed_transitions = np.ones((4, 4), dtype=bool)
    allowed_transitions[2, [1, 3]] = False

    expected_sequence = tf.ragged.constant([
        self._decode_greedy_sequence(
            item, transition_weights, allowed_transitions, **decode_flags)
        for item in per_item_scores
    ])

    # Test a ragged batch.
    ragged_op = tftext.greedy_constrained_sequence(
        scores,
        allowed_transitions=allowed_transitions,
        transition_weights=transition_weights,
        **decode_flags)
    self.assertRaggedEqual(self.evaluate(ragged_op), expected_sequence)
示例#3
0
  def test_sequence_in_exp_space_with_start_end_states_multi_batch_item(self):
    """Checks a three-item batch with use_log_space=False and start/end states.

    The weights and mask are 5x5 while the scores have four columns. Each
    decoded row must equal the sequence `_decode_greedy_sequence` returns for
    the single score matrix.
    """
    decode_flags = dict(use_log_space=False, use_start_and_end_states=True)
    batch_size = 3
    scores = np.array([[10.0, 12.0, 6.0, 4.0], [13.0, 12.0, 11.0, 10.0]])
    # Keep one matrix row per source line.
    # pyformat: disable
    transition_weights = np.array(
        [
            [0.1, 0.2, 0.3, 0.4, 0.1],
            [0.5, 0.6, 0.7, 0.8, 0.1],
            [0.9, 1.0, 0.15, 1.0, 0.1],
            [0.25, 0.35, 0.45, 0.55, 0.5],
            [0.1, 0.5, 0.1, 0.1, 1.0],
        ],
        dtype=np.float32)
    # pyformat: enable

    # Boolean mask: every entry is True except (2, 1), (2, 3) and (4, 4).
    allowed_transitions = np.ones((5, 5), dtype=bool)
    allowed_transitions[2, [1, 3]] = False
    allowed_transitions[4, 4] = False

    expected_row = self._decode_greedy_sequence(
        scores, transition_weights, allowed_transitions, **decode_flags)

    # Test a multi-item batch: replicate the example into a float32 batch.
    batched_scores = np.array([scores] * batch_size, dtype=np.float32)
    sequence_lengths = [len(scores)] * batch_size

    decode_op = tftext.greedy_constrained_sequence(
        batched_scores,
        sequence_lengths,
        allowed_transitions=allowed_transitions,
        transition_weights=transition_weights,
        **decode_flags)
    decoded = self.evaluate(decode_op)
    self.assertRaggedEqual(decoded, [expected_row] * batch_size)
示例#4
0
  def test_sequence_in_log_space_with_start_end_states_single_batch_item(self):
    """Checks a one-item batch with use_log_space=True and start/end states.

    The weights and mask are 5x5 while the scores have four columns. The
    single decoded row must equal the sequence `_decode_greedy_sequence`
    returns for the same score matrix.
    """
    decode_flags = dict(use_log_space=True, use_start_and_end_states=True)
    scores = np.array([[10.0, 12.0, 7.0, 4.0], [13.0, 12.0, 11.0, 10.0]])
    # Keep one matrix row per source line.
    # pyformat: disable
    transition_weights = np.array(
        [
            [-1.0, 1.0, -2.0, 2.0, 0.0],
            [3.0, -3.0, 4.0, -4.0, 0.0],
            [5.0, 1.0, 10.0, 1.0, 1.0],
            [-7.0, 7.0, -8.0, 8.0, 0.0],
            [0.0, 1.0, 2.0, 3.0, 0.0],
        ],
        dtype=np.float32)
    # pyformat: enable

    # Boolean mask: True everywhere except (2, 1), (2, 3), (2, 4) and (4, 1).
    allowed_transitions = np.ones((5, 5), dtype=bool)
    allowed_transitions[2, [1, 3, 4]] = False
    allowed_transitions[4, 1] = False

    expected_row = self._decode_greedy_sequence(
        scores, transition_weights, allowed_transitions, **decode_flags)

    # Test a single-item batch: add a leading batch axis and cast to float32.
    single_input = np.expand_dims(scores, axis=0).astype(np.float32)
    sequence_lengths = [len(scores)]

    decode_op = tftext.greedy_constrained_sequence(
        single_input,
        sequence_lengths,
        allowed_transitions=allowed_transitions,
        transition_weights=transition_weights,
        **decode_flags)
    self.assertRaggedEqual(self.evaluate(decode_op), [expected_row])