def test_batch_tensors_does_not_modify_list(self):
    """Check that ``batch_tensors`` returns the per-instance outputs and leaves its input list intact.

    Two fields (one per production rule) are indexed against ``self.vocab`` and
    converted with ``as_tensor``; the resulting outputs are then batched.
    """
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    field.index(self.vocab)
    tensor_tuple1 = field.as_tensor(field.get_padding_lengths())

    field = ProductionRuleField("NP -> test", is_global_rule=True)
    field.index(self.vocab)
    tensor_tuple2 = field.as_tensor(field.get_padding_lengths())

    tensor_list = [tensor_tuple1, tensor_tuple2]
    batched = field.batch_tensors(tensor_list)

    assert batched == tensor_list
    # Comparing the result with ``tensor_list`` alone cannot detect an in-place
    # mutation (the mutated list would simply be compared with itself), so also
    # verify the input list still holds exactly the original objects, in order.
    assert len(tensor_list) == 2
    assert tensor_list[0] is tensor_tuple1
    assert tensor_list[1] is tensor_tuple2
def test_as_tensor_produces_correct_output(self):
    """Verify the 4-tuple produced by ``as_tensor`` for global and non-global rules.

    For a global rule the third slot must hold the rule's index in
    ``self.vocab`` (``self.s_rule_index``); for a non-global rule it must be
    ``None``.
    """
    rule = "S -> [NP, VP]"

    def _indexed_output(is_global):
        # Build, index, and convert a field for ``rule`` in one step.
        rule_field = ProductionRuleField(rule, is_global_rule=is_global)
        rule_field.index(self.vocab)
        return rule_field.as_tensor(rule_field.get_padding_lengths())

    def _check_header(output, expected_global):
        # Shape and the first two slots are the same check for both cases.
        assert isinstance(output, tuple)
        assert len(output) == 4
        assert output[0] == rule
        assert output[1] is expected_global

    global_output = _indexed_output(True)
    _check_header(global_output, True)
    assert_almost_equal(global_output[2].detach().cpu().numpy(), [self.s_rule_index])

    local_output = _indexed_output(False)
    _check_header(local_output, False)
    assert local_output[2] is None
def test_padding_lengths_are_computed_correctly(self):
    """An indexed global rule field should report no padding dimensions (an empty dict)."""
    rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    rule_field.index(self.vocab)
    lengths = rule_field.get_padding_lengths()
    assert lengths == {}