コード例 #1
0
ファイル: test_constraints.py プロジェクト: kahne/fairseq
    def test_next_tokens(self):
        """
        Check ``next_tokens()`` at every step of every constraint sequence.

        For each example, a fresh ``UnorderedConstraintState`` is walked along
        each constraint; before every advance, ``next_tokens()`` must equal the
        union of the root's child tokens and the current node's child tokens.
        """
        # Only the constraints are used here; the other two fields of each
        # example are unpacked (to keep the 3-tuple shape check) and ignored.
        for constraints, _expected, _gold_counts in self.examples:
            trie_root = ConstraintNode.create(constraints)
            top_level_tokens = set(trie_root.children)

            for constraint in constraints:
                # Each constraint is traversed from a brand-new state.
                state = UnorderedConstraintState(trie_root)
                for tok in constraint:
                    all_tokens = top_level_tokens.union(state.node.children)
                    assert all_tokens == state.next_tokens(), (
                        f"ALL {all_tokens} NEXT {state.next_tokens()}"
                    )
                    state = state.advance(tok)
コード例 #2
0
    def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
        """Build the initial constraint state for every beam of every batch entry.

        Populates ``self.constraint_states`` with one list per entry of
        ``batch_constraints``; each list holds ``beam_size`` references to the
        single initial state object created for that entry.

        Args:
            batch_constraints: Iterable yielding one constraint tensor per
                batch entry, or ``None`` when there are no constraints (in
                which case ``self.constraint_states`` is left empty).
            beam_size: Number of beam slots to create per batch entry.

        Raises:
            ValueError: If ``self.representation`` is neither ``"ordered"``
                nor ``"unordered"``.
        """
        self.constraint_states = []

        # The parameter is declared Optional: treat None as "no constraints"
        # instead of failing with a TypeError when iterating over it.
        if batch_constraints is None:
            return

        # Resolve the state class once, and fail with a clear message on an
        # unknown representation rather than hitting an unbound local below.
        if self.representation == "ordered":
            state_cls = OrderedConstraintState
        elif self.representation == "unordered":
            state_cls = UnorderedConstraintState
        else:
            raise ValueError(
                f"Unknown constraint representation: {self.representation!r} "
                "(expected 'ordered' or 'unordered')"
            )

        for constraint_tensor in batch_constraints:
            constraint_state = state_cls.create(constraint_tensor)
            # Every beam slot starts out referencing the same state object.
            self.constraint_states.append([constraint_state] * beam_size)
コード例 #3
0
ファイル: test_constraints.py プロジェクト: kahne/fairseq
    def test_sequences(self):
        """
        Check state attributes after consuming each token sequence.

        For every ``(constraints, tokens, expected)`` entry, a state is created
        from the packed constraints, advanced through ``tokens`` in order, and
        then each attribute named in ``expected`` is compared against its
        expected value.
        """
        for constraints, tokens, expected in self.sequences:
            # Pack the single constraint set and take its only entry.
            packed = pack_constraints([constraints])
            state = UnorderedConstraintState.create(packed[0])

            for token in tokens:
                state = state.advance(token)

            # Read back only the attributes the expectation mentions.
            result = {attr: getattr(state, attr) for attr in expected}

            assert result == expected, (
                f"TEST({tokens}) GOT: {result} WANTED: {expected}"
            )