def test_next_tokens(self):
    """
    Check next_tokens() at every step along every constraint sequence.

    For each example a trie is built from its constraints. Each constraint
    is then replayed token by token from a fresh state, and before every
    advance the state's next tokens must equal the root's child tokens
    combined with the child tokens of the state's current node.
    """
    # Each example is a 3-tuple; only the constraints are used here.
    for constraints, _expected, _gold_counts in self.examples:
        trie_root = ConstraintNode.create(constraints)
        top_level = set(trie_root.children.keys())

        for constraint in constraints:
            # Every constraint is walked starting from a brand-new state.
            state = UnorderedConstraintState(trie_root)
            for tok in constraint:
                all_tokens = top_level | set(state.node.children.keys())
                assert (
                    all_tokens == state.next_tokens()
                ), f"ALL {all_tokens} NEXT {state.next_tokens()}"
                state = state.advance(tok)
def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
    """
    Build the initial constraint state for every beam of every batch item.

    Resets ``self.constraint_states`` to a list with one entry per element
    of ``batch_constraints``. Each entry is a list of ``beam_size``
    references to the same freshly created state object.

    Args:
        batch_constraints: iterable of per-item constraint tensors, or
            ``None`` when there are no constraints. ``None`` leaves
            ``self.constraint_states`` as an empty list.
        beam_size: number of beam slots to fill per batch item.

    Raises:
        ValueError: if there is at least one constraint to process and
            ``self.representation`` is neither ``"ordered"`` nor
            ``"unordered"``.
    """
    self.constraint_states = []

    # The parameter is Optional; iterating None would raise TypeError.
    if batch_constraints is None:
        return

    for constraint_tensor in batch_constraints:
        if self.representation == "ordered":
            constraint_state = OrderedConstraintState.create(constraint_tensor)
        elif self.representation == "unordered":
            constraint_state = UnorderedConstraintState.create(constraint_tensor)
        else:
            # Without this branch the name below would be unbound and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                f"Unknown constraint representation: {self.representation!r} "
                "(expected 'ordered' or 'unordered')"
            )

        # All beams start out sharing the same initial state object.
        self.constraint_states.append(
            [constraint_state for _ in range(beam_size)]
        )
def test_sequences(self):
    """
    Replay token sequences through an unordered constraint state.

    Each case supplies constraints, the tokens to feed in, and a mapping of
    attribute name to expected value. After every token has been consumed,
    the listed attributes of the final state must match the mapping.
    """
    for constraints, tokens, expected in self.sequences:
        packed = pack_constraints([constraints])
        state = UnorderedConstraintState.create(packed[0])

        for tok in tokens:
            state = state.advance(tok)

        # Only the attributes named in `expected` are compared.
        result = {name: getattr(state, name) for name in expected.keys()}
        assert (
            result == expected
        ), f"TEST({tokens}) GOT: {result} WANTED: {expected}"