def test_sequence_with_none_permissions_multi_input(self):
    """Decodes a 3-item batch without passing an allowed_transitions mask.

    Every batch item is the same 2-step score matrix, so each item of the op's
    output is expected to equal the sequence `self._decode_greedy_sequence`
    returns for that matrix on its own.
    """
    # NOTE(review): this def starts at column 0 while sibling tests in this
    # chunk start at column 2; confirm it is a method of the test class.
    decode_options = dict(use_log_space=True, use_start_and_end_states=False)
    step_scores = np.array([[10.0, 12.0, 6.0, 4.0],
                            [13.0, 12.0, 11.0, 10.0]])
    # pyformat: disable
    # pylint: disable=bad-whitespace
    # pylint: disable=bad-continuation
    weights = np.array([[-1.0,  1.0, -2.0,  2.0],
                        [ 3.0, -3.0,  4.0, -4.0],
                        [ 5.0,  1.0, 10.0,  1.0],
                        [-7.0,  7.0, -8.0,  8.0]], dtype=np.float32)
    # pyformat: enable
    # pylint: enable=bad-whitespace
    # pylint: enable=bad-continuation
    expected = self._decode_greedy_sequence(step_scores, weights,
                                            **decode_options)

    # Test a multi-item batch built by repeating the same score matrix.
    batch_size = 3
    batch = np.array([step_scores] * batch_size, dtype=np.float32)
    lengths = [len(step_scores)] * batch_size
    batch_op = greedy_op.greedy_constrained_sequence(
        batch, lengths, transition_weights=weights, **decode_options)
    self.assertAllEqual(self.evaluate(batch_op), [expected] * batch_size)
    def test_ragged_inputs(self):
        """Decodes a ragged batch whose two items have different lengths.

        The first item has 3 steps and the second has 2. The op's ragged
        output is compared against a ragged tensor built from per-item calls
        to `self._decode_greedy_sequence` with the same weights and mask.
        """
        # NOTE(review): this def is indented deeper than sibling tests in this
        # chunk; if it is actually nested inside the preceding test it will
        # never be collected. Confirm it is a class-level method.
        decode_options = dict(use_log_space=True, use_start_and_end_states=False)
        first_item = np.array([[10.0, 13.0, 6.0, 4.0],
                               [13.0, 12.0, 11.0, 10.0],
                               [13.0, 12.0, 11.0, 10.0]])
        second_item = np.array([[10.0, 12.0, 6.0, 4.0],
                                [13.0, 12.0, 11.0, 10.0]])
        items = [first_item, second_item]
        # TODO(momernick): Extend RT support to lists-of-ndarrays.
        scores = ragged_factory_ops.constant([item.tolist() for item in items])
        # pyformat: disable
        # pylint: disable=bad-whitespace
        # pylint: disable=bad-continuation
        weights = np.array([[-1.0,  1.0, -2.0,  2.0],
                            [ 3.0, -3.0,  4.0, -4.0],
                            [ 5.0,  1.0, 10.0,  1.0],
                            [-7.0,  7.0, -8.0,  8.0]], dtype=np.float32)

        allowed = np.array([[True,  True, True,  True],
                            [True,  True, True,  True],
                            [True, False, True, False],
                            [True,  True, True,  True]])
        # pyformat: enable
        # pylint: enable=bad-whitespace
        # pylint: enable=bad-continuation
        # Decode each item separately, in batch order, to build the expectation.
        expected_sequence = ragged_factory_ops.constant([
            self._decode_greedy_sequence(item, weights, allowed,
                                         **decode_options)
            for item in items
        ])

        # Test a ragged batch.
        ragged_op = greedy_op.greedy_constrained_sequence(
            scores,
            allowed_transitions=allowed,
            transition_weights=weights,
            **decode_options)
        self.assertRaggedEqual(self.evaluate(ragged_op), expected_sequence)
  def test_sequence_in_exp_space_with_start_end_states_multi_batch_item(self):
    """Decodes a 3-item batch in exp space with start/end states enabled.

    The transition and permission matrices are 5x5 while the scores have 4
    columns. The extra row/column presumably encodes the start/end state;
    confirm against `greedy_constrained_sequence`.
    """
    decode_options = dict(use_log_space=False, use_start_and_end_states=True)
    step_scores = np.array([[10.0, 12.0, 6.0, 4.0],
                            [13.0, 12.0, 11.0, 10.0]])
    # pyformat: disable
    # pylint: disable=bad-whitespace
    # pylint: disable=bad-continuation
    weights = np.array([[ .1,  .2,  .3,  .4, .1],
                        [ .5,  .6,  .7,  .8, .1],
                        [ .9,   1, .15,   1, .1],
                        [.25, .35, .45, .55, .5],
                        [ .1,  .5,  .1,  .1,  1]], dtype=np.float32)

    allowed = np.array([[ True,  True,  True,  True,  True],
                        [ True,  True,  True,  True,  True],
                        [ True, False,  True, False,  True],
                        [ True,  True,  True,  True,  True],
                        [ True,  True,  True,  True, False]])
    # pyformat: enable
    # pylint: enable=bad-whitespace
    # pylint: enable=bad-continuation
    expected = self._decode_greedy_sequence(
        step_scores, weights, allowed, **decode_options)

    # Test a multi-item batch built by repeating the same score matrix.
    batch_size = 3
    batch = np.array([step_scores] * batch_size, dtype=np.float32)
    lengths = [len(step_scores)] * batch_size
    batch_op = greedy_op.greedy_constrained_sequence(
        batch,
        lengths,
        allowed_transitions=allowed,
        transition_weights=weights,
        **decode_options)
    self.assertAllEqual(self.evaluate(batch_op), [expected] * batch_size)
  def test_sequence_in_log_space_with_start_end_states_single_batch_item(self):
    """Decodes a 1-item batch in log space with start/end states enabled.

    The op's output for the lone batch item is expected to equal the
    sequence returned by `self._decode_greedy_sequence` for the same inputs.
    """
    decode_options = dict(use_log_space=True, use_start_and_end_states=True)
    step_scores = np.array([[10.0, 12.0, 7.0, 4.0],
                            [13.0, 12.0, 11.0, 10.0]])
    # pyformat: disable
    # pylint: disable=bad-whitespace
    # pylint: disable=bad-continuation
    weights = np.array([[-1.0,  1.0, -2.0,  2.0, 0.0],
                        [ 3.0, -3.0,  4.0, -4.0, 0.0],
                        [ 5.0,  1.0, 10.0,  1.0, 1.0],
                        [-7.0,  7.0, -8.0,  8.0, 0.0],
                        [ 0.0,  1.0,  2.0,  3.0, 0.0]],
                       dtype=np.float32)

    allowed = np.array([[ True,  True,  True,  True,  True],
                        [ True,  True,  True,  True,  True],
                        [ True, False,  True, False, False],
                        [ True,  True,  True,  True,  True],
                        [ True, False,  True,  True,  True]])
    # pyformat: enable
    # pylint: enable=bad-whitespace
    # pylint: enable=bad-continuation
    expected = self._decode_greedy_sequence(
        step_scores, weights, allowed, **decode_options)

    # Test a single-item batch: add a leading batch axis of size 1.
    batch = step_scores[np.newaxis, ...].astype(np.float32)
    lengths = [len(step_scores)]
    batch_op = greedy_op.greedy_constrained_sequence(
        batch,
        lengths,
        allowed_transitions=allowed,
        transition_weights=weights,
        **decode_options)
    self.assertAllEqual(self.evaluate(batch_op), [expected])