def test_single_input_sequence_with_implicit_lengths(self):
        """Decodes a one-element batch without passing sequence lengths.

        The reference path is computed with `viterbi_decode.decode` on the
        bare `scores` matrix. The op under test receives the same scores
        wrapped in a batch of size one and no `lengths` argument. The op
        presumably infers each length from the dense input's time dimension
        (TODO confirm against the op). Its ragged output must equal a
        one-row batch containing the reference path.
        """
        use_log_space = True
        use_start_and_end_states = False
        scores = np.array([[10.0, 13.0, 6.0, 4.0], [13.0, 12.0, 11.0, 10.0],
                           [13.0, 12.0, 11.0, 10.0]])
        # pyformat: disable
        # pylint: disable=bad-whitespace
        # pylint: disable=bad-continuation
        transition_weights = np.array(
            [[-1.0, 1.0, -2.0, 2.0], [3.0, -3.0, 4.0, -4.0],
             [5.0, 1.0, 10.0, 1.0], [-7.0, 7.0, -8.0, 8.0]],
            dtype=np.float32)

        # pyformat: enable
        # pylint: enable=bad-whitespace
        # pylint: enable=bad-continuation
        sequence, _ = viterbi_decode.decode(
            scores,
            transition_weights,
            use_log_space=use_log_space,
            use_start_and_end_states=use_start_and_end_states)

        # Test a batch that holds exactly one sequence. No lengths are
        # passed, which is what makes them "implicit".
        single_input = np.array([scores], dtype=np.float32)

        single_sequence_op = sequence_op.viterbi_constrained_sequence(
            single_input,
            transition_weights=transition_weights,
            use_log_space=use_log_space,
            use_start_and_end_states=use_start_and_end_states)
        single_sequence_result = self.evaluate(single_sequence_op)
        self.assertRaggedEqual(single_sequence_result, [sequence])
    def test_ragged_input_sequence(self):
        """Decodes a ragged batch whose rows have different lengths.

        Each row is decoded on its own with the reference
        `viterbi_decode.decode`. The per-row paths are assembled into a
        ragged tensor, and the op's output for the whole ragged batch must
        equal it.
        """
        use_log_space = True
        use_start_and_end_states = False
        first_row_scores = np.array([[10.0, 13.0, 6.0, 4.0],
                                     [13.0, 12.0, 11.0, 10.0],
                                     [13.0, 12.0, 11.0, 10.0]])
        second_row_scores = np.array([[10.0, 12.0, 6.0, 4.0],
                                      [13.0, 12.0, 11.0, 10.0]])
        # TODO(b/122968457): Extend RT support to lists-of-ndarrays.
        ragged_scores = ragged_factory_ops.constant(
            [first_row_scores.tolist(), second_row_scores.tolist()])
        # pyformat: disable
        # pylint: disable=bad-whitespace
        # pylint: disable=bad-continuation
        transition_weights = np.array(
            [[-1.0,  1.0, -2.0,  2.0],
             [ 3.0, -3.0,  4.0, -4.0],
             [ 5.0,  1.0, 10.0,  1.0],
             [-7.0,  7.0, -8.0,  8.0]],
            dtype=np.float32)
        # pyformat: enable
        # pylint: enable=bad-whitespace
        # pylint: enable=bad-continuation

        # Reference paths, decoded one row at a time (scores are discarded).
        reference_paths = [
            viterbi_decode.decode(
                row_scores,
                transition_weights,
                use_log_space=use_log_space,
                use_start_and_end_states=use_start_and_end_states)[0]
            for row_scores in (first_row_scores, second_row_scores)
        ]
        expected_paths = ragged_factory_ops.constant(reference_paths)

        # Decode the whole ragged batch in one call to the op under test.
        ragged_decode_op = sequence_op.viterbi_constrained_sequence(
            ragged_scores,
            transition_weights=transition_weights,
            use_log_space=use_log_space,
            use_start_and_end_states=use_start_and_end_states)
        self.assertRaggedEqual(self.evaluate(ragged_decode_op), expected_paths)
コード例 #3
0
  def test_sequence_in_log_space_with_start_end_states_multi_input(self):
    """Decodes a batch of identical sequences in log space with start/end states.

    A single reference path is computed with `viterbi_decode.decode`. The op
    under test then decodes several copies of the same scores, with explicit
    per-row lengths and an `allowed_transitions` mask. It must return the
    reference path for every batch element.
    """
    batch_size = 3
    use_log_space = True
    use_start_and_end_states = True
    scores = np.array([[10.0, 12.0, 7.0, 4.0], [13.0, 12.0, 11.0, 10.0]])
    # The matrices are 5x5 for 4 tags. The extra row/column presumably
    # encode the start/end states enabled above -- confirm against the op.
    # pyformat: disable
    # pylint: disable=bad-whitespace
    # pylint: disable=bad-continuation
    transition_weights = np.array([[-1.0,  1.0, -2.0,  2.0, 0.0],
                                   [ 3.0, -3.0,  4.0, -4.0, 0.0],
                                   [ 5.0,  1.0, 10.0,  1.0, 1.0],
                                   [-7.0,  7.0, -8.0,  8.0, 0.0],
                                   [ 0.0,  1.0,  2.0,  3.0, 0.0]],
                                  dtype=np.float32)

    allowed_transitions = np.array([[True,  True,  True,  True,  True],
                                    [True,  True,  True,  True,  True],
                                    [True, False,  True, False, False],
                                    [True,  True,  True,  True,  True],
                                    [True, False,  True,  True,  True]])
    # pyformat: enable
    # pylint: enable=bad-whitespace
    # pylint: enable=bad-continuation
    reference_path, _ = viterbi_decode.decode(
        scores,
        transition_weights,
        allowed_transitions,
        use_log_space=use_log_space,
        use_start_and_end_states=use_start_and_end_states)

    # Replicate the same scores across the batch, each row at full length.
    batched_scores = np.array([scores] * batch_size, dtype=np.float32)
    row_lengths = [len(scores)] * batch_size

    batched_decode_op = sequence_op.viterbi_constrained_sequence(
        batched_scores, row_lengths,
        allowed_transitions=allowed_transitions,
        transition_weights=transition_weights,
        use_log_space=use_log_space,
        use_start_and_end_states=use_start_and_end_states)
    decoded_batch = self.evaluate(batched_decode_op)
    self.assertRaggedEqual(decoded_batch, [reference_path] * batch_size)
コード例 #4
0
  def test_sequence_in_exp_space_with_start_end_states_multi_input(self):
    """Decodes a batch of identical sequences in exp space with start/end states.

    The settings match the log-space variant, except that
    `use_log_space=False` and the transition weights are positive values
    rather than log-scores. Every batch element must decode to the single
    reference path from `viterbi_decode.decode`.
    """
    batch_size = 3
    use_log_space = False
    use_start_and_end_states = True
    scores = np.array([[10.0, 12.0, 6.0, 4.0], [13.0, 12.0, 11.0, 10.0]])
    # The matrices are 5x5 for 4 tags. The extra row/column presumably
    # encode the start/end states enabled above -- confirm against the op.
    # pyformat: disable
    # pylint: disable=bad-whitespace
    # pylint: disable=bad-continuation
    transition_weights = np.array([[ .1,  .2,  .3,  .4, .1],
                                   [ .5,  .6,  .7,  .8, .1],
                                   [ .9,   1, .15,   1, .1],
                                   [.25, .35, .45, .55, .5],
                                   [ .1,  .5,  .1,  .1,  1]], dtype=np.float32)

    allowed_transitions = np.array([[True,  True,  True,  True,  True],
                                    [True,  True,  True,  True,  True],
                                    [True, False,  True, False,  True],
                                    [True,  True,  True,  True,  True],
                                    [True,  True,  True,  True, False]])
    # pyformat: enable
    # pylint: enable=bad-whitespace
    # pylint: enable=bad-continuation
    reference_path, _ = viterbi_decode.decode(
        scores,
        transition_weights,
        allowed_transitions,
        use_log_space=use_log_space,
        use_start_and_end_states=use_start_and_end_states)

    # Replicate the same scores across the batch, each row at full length.
    batched_scores = np.array([scores] * batch_size, dtype=np.float32)
    row_lengths = [len(scores)] * batch_size

    batched_decode_op = sequence_op.viterbi_constrained_sequence(
        batched_scores, row_lengths,
        allowed_transitions=allowed_transitions,
        transition_weights=transition_weights,
        use_log_space=use_log_space,
        use_start_and_end_states=use_start_and_end_states)
    decoded_batch = self.evaluate(batched_decode_op)
    self.assertRaggedEqual(decoded_batch, [reference_path] * batch_size)