def _get_token_positions_to_span(doc: Document, start_token: int, end_token: int, wrt_span):
    """Return, for each token index in [start_token, end_token), its token distance to wrt_span."""
    start, end, _ = wrt_span
    token_span = TokenSpan(start, end)
    return [
        token_span.token_distance_to(TokenSpan(idx, idx + 1))
        for idx in range(start_token, end_token)
    ]

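# Hedged usage sketch for _get_token_positions_to_span (illustrative values, not
# taken from the library): with the distance semantics exercised by
# test_token_distance_to below (0 for overlapping spans, otherwise the distance
# between the closest token indices), calling it over tokens 0..5 against a
# wrt_span of (2, 4, 0) would give
#
#     _get_token_positions_to_span(doc, 0, 6, (2, 4, 0))  # -> [2, 1, 0, 0, 1, 2]
#
# i.e. tokens 2 and 3 lie inside the span itself and get distance 0.
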
def test_leq(self):
    self.assertTrue(TokenSpan(1, 2) < TokenSpan(2, 3))
    self.assertFalse(TokenSpan(2, 3) < TokenSpan(0, 1))
    self.assertRaises(Exception, lambda: TokenSpan(0, 1) < Sentence(2, 3))
    self.assertRaises(Exception, lambda: Sentence(0, 1) < TokenSpan(2, 3))

def get_marked_tokens_on_root_path_for_span(doc: Document, span, *, add_distance=False):
    """Mark every token on the dependency path from the span's head token up to the sentence root.

    Returns a list with one entry per document token: True for tokens on the path
    (or their distance from the span head if add_distance is set), False otherwise.
    """
    start, end, span_sent_idx = span
    distances_to_root = doc.token_features["dt_head_distances"]
    main_span_token = find_span_head_token(doc, TokenSpan(start, end))
    res = [False] * len(doc.tokens)
    for sent_idx in range(len(doc.sentences)):
        if sent_idx != span_sent_idx:
            # Only the sentence containing the span is processed.
            continue
        distance_from_span = 0
        res[main_span_token] = distance_from_span if add_distance else True
        distance_to_parent = distances_to_root[main_span_token]
        current_idx = main_span_token
        # Follow signed head offsets upwards until the root (offset 0) is reached.
        while distance_to_parent != 0:
            current_idx += distance_to_parent
            distance_from_span += 1
            res[current_idx] = distance_from_span if add_distance else True
            distance_to_parent = distances_to_root[current_idx]
    return res

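# Hedged illustration (invented values, not from the library) of the walk above,
# assuming doc.token_features["dt_head_distances"] holds each token's signed
# offset to its dependency head, with 0 marking the sentence root:
#
#     dt_head_distances = [1, 0, -1, -2]   # heads: 0 -> 1, 1 = root, 2 -> 1, 3 -> 1
#
# For a span whose head token is 3 (sentence 0), the walk marks token 3, then
# jumps 3 + (-2) = 1 to the root and stops, giving
#     [False, True, False, True]   with add_distance=False
#     [False, 1,    False, 0]      with add_distance=True
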
def test_leq(self):
    self.assertTrue(Entity("_", 2, 3, "A") < Entity("_", 2, 4, "A"))
    self.assertTrue(Entity("_", 2, 3, "A") < Entity("_", 2, 3, "B"))
    self.assertTrue(Entity("1", 2, 3, "B") < Entity("11", 2, 3, "B"))
    self.assertTrue(Entity("1", 0, 1, "B") < Entity("1", 2, 3, "B"))
    self.assertRaises(Exception, lambda: Entity("_", 0, 1, "A") < TokenSpan(2, 3))

def test_token_distance_to(self):
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(5, 6)), 0)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(6, 7)), 0)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(0, 10)), 0)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(5, 11)), 0)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(4, 8)), 0)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(7, 9)), 1)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(10, 12)), 4)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(3, 4)), 1)
    self.assertEqual(TokenSpan(4, 7).token_distance_to(TokenSpan(0, 1)), 4)

def test_coincides(self):
    self.assertTrue(TokenSpan(4, 7).coincides(TokenSpan(4, 7)))
    self.assertFalse(TokenSpan(4, 7).coincides(TokenSpan(4, 5)))
    self.assertFalse(TokenSpan(4, 7).coincides(TokenSpan(7, 8)))

def test_intersects(self):
    self.assertTrue(TokenSpan(4, 7).intersects(TokenSpan(5, 6)))
    self.assertTrue(TokenSpan(4, 7).intersects(TokenSpan(6, 7)))
    self.assertTrue(TokenSpan(4, 7).intersects(TokenSpan(0, 10)))
    self.assertTrue(TokenSpan(4, 7).intersects(TokenSpan(5, 11)))
    self.assertTrue(TokenSpan(4, 7).intersects(TokenSpan(4, 8)))
    self.assertFalse(TokenSpan(4, 7).intersects(TokenSpan(7, 9)))
    self.assertFalse(TokenSpan(4, 7).intersects(TokenSpan(3, 4)))

def test_contains(self):
    self.assertTrue(TokenSpan(4, 7).contains(TokenSpan(5, 6)))
    self.assertTrue(TokenSpan(4, 7).contains(TokenSpan(6, 7)))
    self.assertFalse(TokenSpan(4, 7).contains(TokenSpan(7, 9)))
    self.assertFalse(TokenSpan(4, 7).contains(TokenSpan(3, 4)))
    self.assertFalse(TokenSpan(4, 7).contains(TokenSpan(0, 10)))
    self.assertFalse(TokenSpan(4, 7).contains(TokenSpan(5, 11)))
    self.assertFalse(TokenSpan(4, 7).contains(TokenSpan(4, 8)))

def test_hash(self):
    self.assertEqual(hash(TokenSpan(10, 15)), hash(TokenSpan(10, 15)))
    self.assertNotEqual(hash(TokenSpan(10, 15)), hash(TokenSpan(100, 150)))

def test_eq(self):
    self.assertEqual(TokenSpan(10, 15), TokenSpan(10, 15))
    self.assertNotEqual(TokenSpan(10, 15), TokenSpan(10, 11))
    self.assertNotEqual(TokenSpan(10, 15), Sentence(10, 15))
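

# The tests above pin down the TokenSpan semantics: half-open [start, end) token
# intervals, distance 0 for overlapping spans and otherwise the distance between
# the closest token indices, ordering by (start, end) that refuses cross-type
# comparison, and equality/hash that distinguish span types. Below is a minimal
# self-contained sketch written from those expectations; it is an illustration
# only, not the library's actual TokenSpan implementation.
class _TokenSpanSketch:
    def __init__(self, start: int, end: int):
        self.start = start
        self.end = end

    def intersects(self, other: "_TokenSpanSketch") -> bool:
        # Half-open intervals overlap iff each one starts before the other ends.
        return self.start < other.end and other.start < self.end

    def contains(self, other: "_TokenSpanSketch") -> bool:
        return self.start <= other.start and other.end <= self.end

    def coincides(self, other: "_TokenSpanSketch") -> bool:
        return self.start == other.start and self.end == other.end

    def token_distance_to(self, other: "_TokenSpanSketch") -> int:
        # 0 for overlapping spans; otherwise the distance between the closest
        # token indices, so adjacent spans are 1 apart.
        if self.intersects(other):
            return 0
        if other.start >= self.end:
            return other.start - (self.end - 1)
        return self.start - (other.end - 1)

    def __lt__(self, other):
        if type(self) is not type(other):
            raise TypeError("cannot order spans of different types")
        return (self.start, self.end) < (other.start, other.end)

    def __eq__(self, other):
        return type(self) is type(other) and (self.start, self.end) == (other.start, other.end)

    def __hash__(self):
        return hash((type(self).__name__, self.start, self.end))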