def test_batch_tensors_does_not_modify_list(self):
        """Check that ``batch_tensors`` leaves the per-field outputs untouched.

        Two global-rule fields are indexed and converted with ``as_tensor``;
        their outputs are collected in a list and passed to ``batch_tensors``.
        The test asserts that the returned value equals the input as it was
        *before* the call, and that the input list itself is unchanged.
        """
        first_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
        first_field.index(self.vocab)
        first_output = first_field.as_tensor(first_field.get_padding_lengths())

        field = ProductionRuleField("NP -> test", is_global_rule=True)
        field.index(self.vocab)
        second_output = field.as_tensor(field.get_padding_lengths())

        tensor_list = [first_output, second_output]
        # Snapshot the list before batching: comparing the result only against
        # ``tensor_list`` itself could not detect an in-place mutation, because
        # a mutated-and-returned list would trivially equal itself.
        snapshot = list(tensor_list)

        batched = field.batch_tensors(tensor_list)

        assert batched == snapshot
        assert tensor_list == snapshot
    def test_as_tensor_produces_correct_output(self):
        """Verify the tuple produced by ``as_tensor`` for global and non-global rules.

        In both cases the output must be a 4-tuple whose first entry is the
        rule string and whose second entry is the ``is_global_rule`` flag.
        For a global rule the third entry is a tensor holding
        ``self.s_rule_index``; for a non-global rule it is ``None``.
        """
        rule_text = "S -> [NP, VP]"

        # Global rule: the third slot carries the indexed rule id.
        global_field = ProductionRuleField(rule_text, is_global_rule=True)
        global_field.index(self.vocab)
        global_output = global_field.as_tensor(global_field.get_padding_lengths())
        assert isinstance(global_output, tuple)
        assert len(global_output) == 4
        assert global_output[0] == "S -> [NP, VP]"
        assert global_output[1] is True
        rule_id_array = global_output[2].detach().cpu().numpy()
        assert_almost_equal(rule_id_array, [self.s_rule_index])

        # Non-global rule: no rule id is attached.
        local_field = ProductionRuleField(rule_text, is_global_rule=False)
        local_field.index(self.vocab)
        local_output = local_field.as_tensor(local_field.get_padding_lengths())
        assert isinstance(local_output, tuple)
        assert len(local_output) == 4
        assert local_output[0] == "S -> [NP, VP]"
        assert local_output[1] is False
        assert local_output[2] is None
 def test_padding_lengths_are_computed_correctly(self):
     """An indexed global-rule field reports an empty padding-lengths dict."""
     rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
     rule_field.index(self.vocab)
     reported_lengths = rule_field.get_padding_lengths()
     assert reported_lengths == {}
 def test_index_converts_field_correctly(self):
     """After ``index``, the field's ``_rule_id`` equals ``self.s_rule_index``."""
     rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
     rule_field.index(self.vocab)
     indexed_id = rule_field._rule_id
     assert indexed_id == self.s_rule_index