Ejemplo n.º 1
0
    def test_inference_no_head(self):
        """Compare ``google/canine-s`` base-model outputs with reference values.

        A single padded question/passage pair is run through ``CanineModel``. The
        test checks the shapes of ``last_hidden_state`` and ``pooler_output`` and
        compares a small leading slice of each with hard-coded reference numbers
        (absolute tolerance ``1e-2``).
        """
        model = CanineModel.from_pretrained("google/canine-s")

        seq_length = 2048  # length the inputs are padded to; also asserted on the output below
        pad_id = 0
        sep_id = 57345  # value that closes the question and closes the passage in the fixture

        # this one corresponds to the first example of the TydiQA dev set (in Swahili)
        # Only the non-padding ids are written out; the padding is appended programmatically.
        # fmt: off
        input_ids = [
            57344, 57349, 85, 107, 117, 98, 119, 97, 32, 119, 97, 32, 82, 105, 106, 105, 108, 105, 32,
            75, 97, 110, 116, 111, 114, 105, 32, 110, 105, 32, 107, 105, 97, 115, 105, 32, 103, 97, 110, 105, 63, 57345,
            57350, 32, 82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 111, 114, 105, 32, 44, 32,
            82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 97, 114, 117, 115, 105, 32, 97, 117, 32, 105, 110, 103, 46, 32,
            65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32,
            40, 112, 105, 97, 58, 32, 84, 111, 108, 105, 109, 97, 110, 32, 97, 117, 32,
            82, 105, 103, 105, 108, 32, 75, 101, 110, 116, 97, 117, 114, 117, 115, 41, 32,
            110, 105, 32, 110, 121, 111, 116, 97, 32,
            105, 110, 97, 121, 111, 110, 103, 39, 97, 97, 32,
            115, 97, 110, 97, 32, 107, 97, 116, 105, 107, 97, 32, 97, 110, 103, 97, 32, 121, 97, 32,
            107, 117, 115, 105, 110, 105, 32, 107, 119, 101, 110, 121, 101, 32,
            107, 117, 110, 100, 105, 110, 121, 111, 116, 97, 32, 121, 97, 32,
            75, 97, 110, 116, 97, 114, 117, 115, 105, 32,
            40, 112, 105, 97, 58, 32, 105, 110, 103, 46, 32, 67, 101, 110, 116, 97, 117, 114, 117, 115, 41, 46, 32,
            78, 105, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32,
            107, 117, 110, 103, 97, 97, 32, 115, 97, 110, 97, 32, 121, 97, 32, 110, 110, 101, 32,
            97, 110, 103, 97, 110, 105, 32, 108, 97, 107, 105, 110, 105, 32,
            104, 97, 105, 111, 110, 101, 107, 97, 110, 105, 32,
            107, 119, 101, 110, 121, 101, 32, 110, 117, 115, 117, 100, 117, 110, 105, 97, 32, 121, 97, 32,
            107, 97, 115, 107, 97, 122, 105, 110, 105, 46, 32,
            57351, 32, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32,
            110, 105, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32, 112, 101, 107, 101, 101, 32,
            107, 119, 97, 32, 115, 97, 98, 97, 98, 117, 32,
            110, 105, 32, 110, 121, 111, 116, 97, 32, 121, 101, 116, 117, 32,
            106, 105, 114, 97, 110, 105, 32, 107, 97, 116, 105, 107, 97, 32, 97, 110, 103, 97, 32,
            105, 110, 97, 32, 117, 109, 98, 97, 108, 105, 32, 119, 97, 32,
            109, 105, 97, 107, 97, 32, 121, 97, 32, 110, 117, 114, 117, 32, 52, 46, 50, 46, 32,
            73, 110, 97, 111, 110, 101, 107, 97, 110, 97, 32, 97, 110, 103, 97, 110, 105, 32,
            107, 97, 114, 105, 98, 117, 32, 110, 97, 32,
            107, 117, 110, 100, 105, 110, 121, 111, 116, 97, 32, 121, 97, 32,
            83, 97, 108, 105, 98, 117, 32, 40, 67, 114, 117, 120, 41, 46, 32,
            57352, 32, 82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 97, 114, 117, 115, 105, 32,
            40, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 41, 32,
            105, 110, 97, 111, 110, 101, 107, 97, 110, 97, 32, 107, 97, 109, 97, 32,
            110, 121, 111, 116, 97, 32, 109, 111, 106, 97, 32, 108, 97, 107, 105, 110, 105, 32,
            107, 119, 97, 32, 100, 97, 114, 117, 98, 105, 110, 105, 32, 107, 117, 98, 119, 97, 32,
            105, 110, 97, 111, 110, 101, 107, 97, 110, 97, 32, 107, 117, 119, 97, 32,
            109, 102, 117, 109, 111, 32, 119, 97, 32, 110, 121, 111, 116, 97, 32, 116, 97, 116, 117, 32,
            122, 105, 110, 97, 122, 111, 107, 97, 97, 32,
            107, 97, 114, 105, 98, 117, 32, 110, 97, 32,
            107, 117, 115, 104, 105, 107, 97, 109, 97, 110, 97, 32,
            107, 97, 116, 105, 32, 121, 97, 111, 46, 32,
            78, 121, 111, 116, 97, 32, 109, 97, 112, 97, 99, 104, 97, 32, 122, 97, 32,
            65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 65, 32, 110, 97, 32,
            65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 66, 32,
            122, 105, 107, 111, 32, 109, 105, 97, 107, 97, 32, 121, 97, 32, 110, 117, 114, 117, 32, 52, 46, 51, 54, 32,
            107, 117, 116, 111, 107, 97, 32, 107, 119, 101, 116, 117, 32, 110, 97, 32,
            110, 121, 111, 116, 97, 32, 121, 97, 32, 116, 97, 116, 117, 32,
            65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 67, 32, 97, 117, 32,
            80, 114, 111, 120, 105, 109, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32,
            105, 110, 97, 32, 117, 109, 98, 97, 108, 105, 32, 119, 97, 32,
            109, 105, 97, 107, 97, 32, 121, 97, 32, 110, 117, 114, 117, 32, 52, 46, 50, 50, 46, 32,
            57353, 32, 80, 114, 111, 120, 105, 109, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32,
            40, 121, 97, 97, 110, 105, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32,
            75, 97, 110, 116, 97, 114, 117, 115, 105, 32, 105, 108, 105, 121, 111, 32,
            107, 97, 114, 105, 98, 117, 32, 122, 97, 105, 100, 105, 32, 110, 97, 115, 105, 41, 32,
            105, 109, 101, 103, 117, 110, 100, 117, 108, 105, 119, 97, 32,
            107, 117, 119, 97, 32, 110, 97, 32, 115, 97, 121, 97, 114, 105, 32, 109, 111, 106, 97, 46, 32,
            86, 105, 112, 105, 109, 111, 32,
            118, 105, 110, 97, 118, 121, 111, 112, 97, 116, 105, 107, 97, 110, 97, 32,
            104, 97, 100, 105, 32, 115, 97, 115, 97, 32,
            122, 105, 110, 97, 111, 110, 121, 101, 115, 104, 97, 32,
            117, 119, 101, 122, 101, 107, 97, 110, 111, 32,
            109, 107, 117, 98, 119, 97, 32, 121, 97, 32, 107, 119, 97, 109, 98, 97, 32,
            115, 97, 121, 97, 114, 105, 32, 104, 105, 105, 32, 110, 105, 32, 121, 97, 32,
            109, 119, 97, 109, 98, 97, 32,
            40, 107, 97, 109, 97, 32, 100, 117, 110, 105, 97, 32, 121, 101, 116, 117, 44, 32,
            77, 105, 114, 105, 104, 105, 32, 97, 117, 32, 90, 117, 104, 117, 114, 97, 41, 32,
            110, 97, 32, 105, 110, 97, 119, 101, 122, 97, 32, 107, 117, 119, 97, 32, 110, 97, 32,
            97, 110, 103, 97, 104, 101, 119, 97, 44, 32,
            116, 101, 110, 97, 32, 107, 97, 116, 105, 107, 97, 32,
            117, 112, 101, 111, 32, 119, 97, 32, 106, 111, 116, 111, 32,
            117, 110, 97, 111, 114, 117, 104, 117, 115, 117, 32,
            107, 117, 119, 101, 112, 111, 32, 107, 119, 97, 32, 117, 104, 97, 105, 46, 32,
            91, 49, 93, 57345,
        ]
        # fmt: on
        num_real = len(input_ids)
        # The first segment runs up to and including the first separator.
        first_segment_length = input_ids.index(sep_id) + 1
        num_padding = seq_length - num_real

        input_ids = input_ids + [pad_id] * num_padding
        # Guards against a fixture longer than seq_length (negative padding would be silently dropped).
        self.assertEqual(len(input_ids), seq_length)

        attention_mask = [1 if x != pad_id else 0 for x in input_ids]
        # NOTE(review): meant to reproduce the previous hand-written literal (0 up to and including the
        # first separator, 1 for the remaining real ids, 0 over the padding) - re-check if the fixture
        # is ever regenerated.
        token_type_ids = [0] * first_segment_length + [1] * (num_real - first_segment_length) + [0] * num_padding

        input_ids = torch.tensor([input_ids])
        attention_mask = torch.tensor([attention_mask])
        token_type_ids = torch.tensor([token_type_ids])

        # Inference only: skip autograd bookkeeping, and pass the optional inputs by keyword so the
        # call does not depend on the positional order of the forward signature.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # verify sequence output
        expected_shape = torch.Size((1, seq_length, 768))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [
                [-0.161433131, 0.395568609, 0.0407391489],
                [-0.108025983, 0.362060368, -0.544592619],
                [-0.141537309, 0.180541009, 0.076907],
            ]
        )

        self.assertTrue(
            torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-2),
            msg="last_hidden_state[0, :3, :3] deviates from the reference slice",
        )

        # verify pooled output
        expected_shape = torch.Size((1, 768))
        self.assertEqual(outputs.pooler_output.shape, expected_shape)

        expected_slice = torch.tensor([-0.884311497, -0.529064834, 0.723164916])

        self.assertTrue(
            torch.allclose(outputs.pooler_output[0, :3], expected_slice, atol=1e-2),
            msg="pooler_output[0, :3] deviates from the reference slice",
        )
Ejemplo n.º 2
0
 def create_and_check_model(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Build a CanineModel from ``config`` and check its hidden-state shape.

     The model is moved to ``torch_device``, switched to eval mode and called
     three times with progressively fewer optional inputs. Only the result of
     the last call (``input_ids`` alone) is inspected. ``sequence_labels``,
     ``token_labels`` and ``choice_labels`` are accepted but not used here.
     """
     canine = CanineModel(config=config)
     canine.to(torch_device)
     canine.eval()

     # Same three invocations as before, from the fullest to the barest set of inputs.
     optional_inputs = (
         {"attention_mask": input_mask, "token_type_ids": token_type_ids},
         {"token_type_ids": token_type_ids},
         {},
     )
     for extra in optional_inputs:
         result = canine(input_ids, **extra)

     expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
     self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
Ejemplo n.º 3
0
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
    """Load TF checkpoint weights into a CanineModel and save model plus tokenizer.

    Args:
        tf_checkpoint_path: checkpoint location handed to ``load_tf_weights_in_canine``.
        pytorch_dump_path: destination handed to ``save_pretrained`` for both the
            model and the tokenizer.
    """
    # Default-configured model, put in eval mode before the weights are loaded
    config = CanineConfig()
    pt_model = CanineModel(config)
    pt_model.eval()

    print(f"Building PyTorch model from configuration: {config}")

    # Copy the TF checkpoint weights into the PyTorch model
    load_tf_weights_in_canine(pt_model, config, tf_checkpoint_path)

    # Write out the model (weights and configuration)
    print(f"Save PyTorch model to {pytorch_dump_path}")
    pt_model.save_pretrained(pytorch_dump_path)

    # Write out the tokenizer files next to the model
    canine_tokenizer = CanineTokenizer()
    print(f"Save tokenizer files to {pytorch_dump_path}")
    canine_tokenizer.save_pretrained(pytorch_dump_path)
 def test_model_from_pretrained(self):
     """Check that the first archived CANINE checkpoint loads into a CanineModel."""
     # Only the first entry of the archive list is exercised.
     checkpoints = CANINE_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint in checkpoints:
         loaded = CanineModel.from_pretrained(checkpoint)
         self.assertIsNotNone(loaded)