コード例 #1
0
    def test_batch_tensors_does_not_modify_list(self):
        """``batch_tensors`` hands back the per-instance outputs untouched.

        Two global production-rule fields are indexed and converted with
        ``as_tensor``. Their outputs are then passed to ``batch_tensors`` as a
        list. The batched result must equal the original outputs, and the input
        list itself must not have been mutated in place.
        """
        field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
        field.index(self.vocab)
        s_rule_tensors = field.as_tensor(field.get_padding_lengths())

        field = ProductionRuleField("NP -> test", is_global_rule=True)
        field.index(self.vocab)
        np_rule_tensors = field.as_tensor(field.get_padding_lengths())

        tensor_list = [s_rule_tensors, np_rule_tensors]
        # Compare against a shallow snapshot rather than ``tensor_list`` itself:
        # comparing to the very list that was passed in would also pass if
        # batch_tensors mutated that list in place and returned it.
        expected = list(tensor_list)
        batched = field.batch_tensors(tensor_list)
        assert batched == expected
        assert tensor_list == expected
コード例 #2
0
    def test_as_tensor_produces_correct_output(self):
        """Check the 4-tuple that ``as_tensor`` produces for a production rule.

        The same rule string is tensorized twice, once as a global rule and once
        as a non-global rule. In both cases slot 0 echoes the rule string and
        slot 1 echoes the ``is_global_rule`` flag. Slot 2 holds a tensor equal to
        ``[self.s_rule_index]`` for the global rule and ``None`` otherwise.
        """
        def _tensorize(is_global):
            # Build, index and tensorize the field, then run the structural
            # checks shared by both variants before returning the output.
            rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=is_global)
            rule_field.index(self.vocab)
            output = rule_field.as_tensor(rule_field.get_padding_lengths())
            assert isinstance(output, tuple)
            assert len(output) == 4
            assert output[0] == "S -> [NP, VP]"
            return output

        global_output = _tensorize(True)
        assert global_output[1] is True
        assert_almost_equal(global_output[2].detach().cpu().numpy(),
                            [self.s_rule_index])

        local_output = _tensorize(False)
        assert local_output[1] is False
        assert local_output[2] is None