# Imports assumed by this module (AllenNLP 0.9.x-era paths; the exact location
# of BertPreTokenizer varies across AllenNLP versions):
import json
import re
from collections import defaultdict
from itertools import product
from typing import List

import pyparsing
import torch
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.tokenizers.pretrained_transformer_pre_tokenizer import BertPreTokenizer
from treelib import Tree

# `Span` (a record with fields `height`, `span`, `content`, `constant` and a
# `to_string` property) is defined elsewhere in this repo.


class ZeroShotExtractor:
    def __init__(self, labels, domain_utils):
        self._tokenizer = BertPreTokenizer()  # should align with reader's tokenizer
        self._token_indexers = PretrainedBertIndexer(
            pretrained_model='bert-base-uncased')  # should align with reader's tokenizer
        self._is_cuda = torch.cuda.device_count() > 0
        self._domain_utils = domain_utils
        self._set_labels_wordpieces(labels)

    def _get_wordpieces(self, text):
        # should match the reader method
        tokens = self._tokenizer.tokenize(text)
        do_lowercase = True
        tokens_out = (token.text.lower()
                      if do_lowercase and token.text not in self._tokenizer.never_split
                      else token.text
                      for token in tokens)
        wps = [[wordpiece for wordpiece in self._token_indexers.wordpiece_tokenizer(token)]
               for token in tokens_out]
        wps_flat = [wordpiece for token in wps for wordpiece in token]
        return tuple(wps_flat)

    def _set_labels_wordpieces(self, labels):
        # map each label's lexicon phrases (as wordpiece tuples) to the label index
        self._num_labels = len(labels)
        self._labels_wordpieces = defaultdict(list)
        for index, label in labels.items():
            if label == 'NO-LABEL' or label == 'span':
                continue
            lexicon_phrases = self._domain_utils.get_lexicon_phrase(label)
            for lexicon_phrase in lexicon_phrases:
                self._labels_wordpieces[self._get_wordpieces(lexicon_phrase)].append(index)

    def get_similarity_features(self, batch_tokens, batch_spans):
        device = 'cuda' if self._is_cuda else 'cpu'
        similarities = torch.zeros([batch_spans.shape[0], batch_spans.shape[1], self._num_labels],
                                   dtype=torch.float32, requires_grad=False, device=device)
        for k, (sentence, spans) in enumerate(zip(batch_tokens, batch_spans)):
            sent_len = len(sentence)
            span_to_ind = {}
            for i, span in enumerate(spans):
                span_to_ind[tuple(span.tolist())] = i
            for i in range(sent_len):
                for j in range(i + 1, i + 6):  # lexicon phrases of up to 5 wordpieces
                    if j > sent_len:
                        break
                    # lexicon keys are wordpiece tuples
                    labels = self._labels_wordpieces.get(tuple(sentence[i:j]))
                    if labels:
                        # +1 offsets for the [CLS] token; spans are inclusive
                        start = i + 1
                        end = j
                        for label in labels:
                            similarities[k, span_to_ind[(start, end)], label] = 1.0
        return similarities
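
# --- Illustrative usage sketch (not part of the original repo) ---
# A minimal driver, assuming a toy domain-utils object: `ToyDomainUtils` and the
# tiny label inventory below are invented for the example; the real domain utils
# supply lexicon phrases per label.
def _demo_zero_shot_extractor():
    class ToyDomainUtils:
        def get_lexicon_phrase(self, label):
            return {'state': ['states']}.get(label, [])

    extractor = ZeroShotExtractor({0: 'NO-LABEL', 1: 'state'}, ToyDomainUtils())
    wps = extractor._get_wordpieces('what states border texas ?')
    # candidate spans as (inclusive_start, inclusive_end), offset by 1 for [CLS]
    spans = torch.tensor([[[2, 2], [1, 1], [1, 5]]])
    feats = extractor.get_similarity_features([wps], spans)
    print(feats[0, 0, 1])  # 1.0: span (2, 2) ('states') matches label 1 ('state')
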
class SpanMapper(object):
    def __init__(self):
        self._tokenizer = BertPreTokenizer()  # should align with reader's tokenizer
        self._token_indexers = PretrainedBertIndexer(
            pretrained_model='bert-base-uncased')  # should align with reader's tokenizer
        self._synonyms = {'arguments': {}, 'predicates': {}}
        # self._load_argumets()
        # self._enrich_synonyms()
        # self._enrich_synonyms_by_hand()
        self._parser = pyparsing.nestedExpr('(', ')',
                                            ignoreExpr=pyparsing.dblQuotedString)
        # self._parser.setParseAction(self._parse_action)
        # Alignment state used by map_prog_to_tree and the alignment helpers.
        # The original relies on these being set elsewhere; they are initialized
        # here (an assumption) so the attributes exist before first use.
        self._tokens = ()
        self._decided_spans = []
        self._possible_chains = None
        self._possible_constants = {}
        self._filter_args = set()  # phrases signalling that a chain extends further

    def _parse_action(self, string, location, tokens) -> Tree:
        raise NotImplementedError

    def _get_wordpieces(self, text):
        # should match the reader method
        tokens = self._tokenizer.tokenize(text)
        do_lowercase = True
        tokens_out = (token.text.lower()
                      if do_lowercase and token.text not in self._tokenizer.never_split
                      else token.text
                      for token in tokens)
        wps = [[wordpiece for wordpiece in self._token_indexers.wordpiece_tokenizer(token)]
               for token in tokens_out]
        wps_flat = [wordpiece for token in wps for wordpiece in token]
        return tuple(wps_flat)

    def _align_to_text(self, constant, type):
        spans = []
        realizations = self._synonyms[type][constant]
        num_tokens_text = len(self._tokens)
        for realization in realizations:
            num_tokens_real = len(realization)
            for begin in range(num_tokens_text - num_tokens_real + 1):
                end = begin + num_tokens_real
                if self._tokens[begin:end] == realization:
                    span = Span(height=1, span=(begin, end),
                                content=self._tokens[begin:end], constant=constant)
                    spans.append(span)
        # assert len(spans) == 1
        return spans

    def _align_filter_to_text(self, constants, type, hint_span=None):
        spans_per_constant = []
        for i, constant in enumerate(constants):
            constant_spans = self._align_to_text(constant, type)  # find spans in text for a particular constant
            # if len(constants) == 1 and hint_span != None:
            # # if i == 0 and hint_span != None:
            #     constant_spans_ = self._choose_span_by_hint(constant_spans, hint_span)
            #     if constant_spans_ != constant_spans:
            #         print('here')
            #     constant_spans = constant_spans_
            # if len(constants) == 2 and hint_span != None and i == 1 and len(constant_spans) > 1:
            #     constant_spans__ = []
            #     for j, span in enumerate(constant_spans):
            #         # temporary bad way to disambiguate spans
            #         if span.span[0] < 4 and hint_span.span[0] <= 6:
            #             constant_spans__.append(span)
            #     if len(constant_spans__) > 0:
            #         print('here')  # risky
            #         constant_spans = constant_spans__
            # assert len(constant_spans) >= 1
            spans_per_constant.append(constant_spans)
        contiguous_chains = self._find_contiguous_chain(spans_per_constant)
        # if len(contiguous_chains) > 1 and hint_span:
        #     contiguous_chains_ = self._filter_chains_with_hint(contiguous_chains, hint_span)
        #     if contiguous_chains_ != contiguous_chains:
        #         print('here')  # risky
        #     contiguous_chains = contiguous_chains_
        if len(contiguous_chains) == 2:
            if not self._possible_chains:
                next_option = tuple([span.constant for span in contiguous_chains[1]])
                self._possible_chains = contiguous_chains[1]
                contiguous_chains = [contiguous_chains[0]]
            else:
                contiguous_chains = [self._possible_chains]
                self._possible_chains = None
        if len(contiguous_chains) != 1:
            print('here')
        assert len(contiguous_chains) == 1
        contiguous_chain = contiguous_chains[0]
        return self._contiguous_chain_to_subtree(contiguous_chain)

    def _filter_chains_with_hint(self, contiguous_chains, hint_span):
        # min delta from span edges. Sets a high value if the span is within the hint_span
        delta_from_hint = [
            min(abs(chain[0].span[0] - hint_span.span[1]),
                abs(chain[-1].span[1] - hint_span.span[0]))
            if not self._is_span_contained(chain[0], hint_span) else 10000
            for chain in contiguous_chains
        ]
        min_value = min(delta_from_hint)
        min_chains = []
        for i, chain in enumerate(contiguous_chains):
            if delta_from_hint[i] == min_value:
                min_chains.append(chain)
        # if min_value == 0 and len(min_spans) > 1:  # take the last span if all min_values are 0
        #     min_spans = [sorted(spans, key=lambda span: span.span[0], reverse=True)[0]]
        return min_chains

    def _filter_sub_chains(self, contiguous_chains):
        """Filter chains that are actually a part of a larger chain"""
        full_contiguous_chains = []
        tokens = self._tokens
        for chain in contiguous_chains:
            start = chain[0].span[0]
            end = chain[-1].span[1]
            if start > 0:  # there is at least one token before 'start'
                if tuple(tokens[start - 1:start]) in self._filter_args:
                    continue
            if start > 1:  # there are at least two tokens before 'start'
                if tuple(tokens[start - 2:start]) in self._filter_args:
                    continue
            if tuple(self._tokens[end:end + 1]) in self._filter_args:
                continue
            if tuple(self._tokens[end:end + 2]) in self._filter_args:
                continue
            full_contiguous_chains.append(chain)
        return full_contiguous_chains

    def _find_contiguous_chain(self, spans_per_constant: List[List[Span]]):
        contiguous_chains = []
        combinations = list(product(*spans_per_constant))
        for comb in combinations:
            if all([s.span[1] == comb[i + 1].span[0]
                    for i, s in enumerate(comb[:-1])]):  # check if contiguous
                if not any([self._is_span_contained(sub_span, span)
                            for sub_span in comb
                            for span in self._decided_spans]):  # check if span wasn't decided before
                    contiguous_chains.append(comb)
                else:
                    print('here')
        if len(contiguous_chains) > 1:
            contiguous_chains_ = self._filter_sub_chains(contiguous_chains)
            if contiguous_chains != contiguous_chains_:
                print('here')
            contiguous_chains = contiguous_chains_
        return contiguous_chains

    def _contiguous_chain_to_subtree(self, contiguous_chain: List[Span]):
        self._decided_spans += contiguous_chain
        tree = Tree()
        stack = []
        for i in range(len(contiguous_chain) - 1):
            # make parent
            start = contiguous_chain[i].span[0]
            end = contiguous_chain[-1].span[1]
            span = Span(height=len(contiguous_chain) - 1, span=(start, end),
                        content=self._tokens[start:end], constant=None)
            parent = stack[-1] if len(stack) > 0 else None
            identifier = '{}-{}'.format(start, end)
            tree.create_node(identifier=identifier, data=span, parent=parent)
            stack.append(identifier)
            # make left child
            span_lc = contiguous_chain[i]
            identifier_lc = '{}-{}'.format(span_lc.span[0], span_lc.span[1])
            tree.create_node(identifier=identifier_lc, data=span_lc, parent=identifier)
        # make last right child
        span_rc = contiguous_chain[-1]
        identifier_rc = '{}-{}'.format(span_rc.span[0], span_rc.span[1])
        tree.create_node(identifier=identifier_rc, data=span_rc,
                         parent=stack[-1] if stack else None)
        return tree  # return top most identifier

    def _join_trees(self, subtree_1: Tree, subtree_2: Tree):
        # assumes subtree_1 is left of subtree_2 in the text
        top_tree = Tree()
        arg_1_span = subtree_1.root.split('-')
        arg_2_span = subtree_2.root.split('-')
        start = int(arg_1_span[0])
        end = int(arg_2_span[1])
        identifier = '{}-{}'.format(start, end)
        span = Span(height=100, span=(start, end),
                    content=self._tokens[start:end], constant=None)
        top_tree.create_node(identifier=identifier, data=span)
        top_tree.paste(nid=identifier, new_tree=subtree_1)
        top_tree.paste(nid=identifier, new_tree=subtree_2)
        return top_tree

    def _combine_trees(self, subtree_1: Tree, subtree_2: Tree):
        subtree_2.paste(nid=subtree_2.root, new_tree=subtree_1)
        return subtree_2

    def _join_binary_predicate_tree(self, predicate: Tree, arg_1: Tree, arg_2: Tree,
                                    allow_arg_switch: bool = True):
        predicate_start = int(predicate.root.split('-')[0])
        arg_1_start = int(arg_1.root.split('-')[0])
        arg_2_start = int(arg_2.root.split('-')[0])
        # if arg_2 is not in the middle (between the predicate and arg_1)
        if (not (predicate_start < arg_2_start < arg_1_start)
                and not (arg_1_start < arg_2_start < predicate_start)):
            if predicate_start < arg_1_start:  # predicate is left of arg_1
                join_1 = self._join_trees(predicate, arg_1)
            else:
                join_1 = self._join_trees(arg_1, predicate)
            if predicate_start < arg_2_start:  # predicate is left of arg_2
                join_2 = self._join_trees(join_1, arg_2)
            else:
                join_2 = self._join_trees(arg_2, join_1)
            return join_2
        else:
            if not allow_arg_switch:
                # in this case we do not allow arg_2 to be in a span with the predicate
                raise Exception('Argument switch is not allowed={}'.format(predicate))
            if predicate_start < arg_1_start:
                # predicate is left of arg_1, and arg_2 is in the middle
                join_1 = self._join_trees(predicate, arg_2)
                join_2 = self._join_trees(join_1, arg_1)
            else:
                join_1 = self._join_trees(arg_2, predicate)
                join_2 = self._join_trees(arg_1, join_1)
            return join_2

    def _join_unary_predicate_tree(self, predicate: Tree, arg: Tree):
        predicate_start = int(predicate.root.split('-')[0])
        arg_start = int(arg.root.split('-')[0])
        if predicate_start < arg_start:  # predicate is left of arg
            join_tree = self._join_trees(predicate, arg)
        else:
            join_tree = self._join_trees(arg, predicate)
        return join_tree

    def _filter_contained_spans(self, spans):
        spans_ = sorted(spans, key=lambda span: span.span[1] - span.span[0],
                        reverse=True)  # sort by span length, longest first
        if spans_ != spans:
            print('here')
        spans = spans_
        filtered_spans = []
        for span in spans:
            for broad_span in filtered_spans:
                if self._is_span_contained(span, broad_span):  # span contained in broad_span
                    break
            else:
                filtered_spans.append(span)
        return filtered_spans

    def _is_span_contained(self, span_1: Span, span_2: Span):
        return set(range(span_1.span[0], span_1.span[1])).issubset(
            set(range(span_2.span[0], span_2.span[1])))

    def _is_spans_intersect(self, span_1: Span, span_2: Span):
        return len(set(range(span_1.span[0], span_1.span[1])).intersection(
            set(range(span_2.span[0], span_2.span[1])))) > 0

    def _choose_span_by_hint(self, spans, hint_span):
        """Chooses the span that is closest to the hint span (the span the
        selection should be near)."""
        # min delta from span edges. Sets a high value if the span is within the hint_span
        delta_from_hint = [
            min(abs(span.span[0] - hint_span.span[1]),
                abs(span.span[1] - hint_span.span[0]))
            if not self._is_span_contained(span, hint_span) else 10000
            for span in spans
        ]
        min_value = min(delta_from_hint)
        min_spans = []
        for i, span in enumerate(spans):
            if delta_from_hint[i] == min_value:
                min_spans.append(span)
        if min_value == 0 and len(min_spans) > 1:  # take the last span if all min_values are 0
            min_spans = [sorted(spans, key=lambda span: span.span[0],
                                reverse=True)[0]]  # risky
        return min_spans

    def _get_tree_from_constant(self, constant, type, hint_span=None, constant_prefix=None):
        spans = self._align_to_text(constant, type)
        spans_ = self._filter_contained_spans(spans)
        if spans != spans_:
            print('here')
        spans = spans_
        if len(spans) != 1:
            print('here')
        # if len(spans) > 1 and hint_span:
        #     spans__ = self._choose_span_by_hint(spans, hint_span)
        #     if spans != spans__:
        #         print('here')
        #     spans = spans__
        spans___ = []
        for span in spans:
            if not any([self._is_span_contained(span, decided_span)
                        for decided_span in self._decided_spans]):
                spans___.append(span)
        if spans___ != spans:
            print('here')
        spans = spans___
        if len(spans) == 2:
            if constant not in self._possible_constants:
                self._possible_constants[constant] = spans[1]
                spans = [spans[0]]
            else:
                spans = [self._possible_constants[constant]]
                del self._possible_constants[constant]
        if len(spans) != 1:
            print('spans for {} are {}'.format(constant, spans))
        assert len(spans) == 1
        span = spans[0]
        if constant_prefix:
            span.constant = '{}#{}'.format(constant_prefix, span.constant)
        self._decided_spans.append(span)
        identifier = '{}-{}'.format(span.span[0], span.span[1])
        constant_tree = Tree()
        constant_tree.create_node(identifier=identifier, data=span)
        return constant_tree

    def is_valid_tree(self, parse_tree: Tree):
        is_violating = [self._is_violating_node(node, parse_tree)
                        for node in parse_tree.expand_tree()]
        if any(is_violating):
            print('here')
        return not any(is_violating)

    def is_projective_tree(self, parse_tree: Tree):
        is_violating = [len(parse_tree.children(node)) > 2
                        for node in parse_tree.expand_tree()]
        if any(is_violating):
            print('here')
        return not any(is_violating)

    def _is_violating_node(self, node, parse_tree):
        """Checks if a node is violated: a child's span is not contained in its
        parent's span, or intersects another child's span."""
        node_span = parse_tree.get_node(node).data
        for child in parse_tree.children(node):
            child_span = child.data
            if not self._is_span_contained(child_span, node_span):  # not contained in parent's span
                print('node {} is not contained in parent {}'.format(
                    child_span.to_string, node_span.to_string))
                return True
            for child_other in parse_tree.children(node):
                child_other_span = child_other.data
                if child != child_other and self._is_spans_intersect(
                        child_span, child_other_span):  # intersects another span
                    print('node {} intersects node {}'.format(
                        child_span.to_string, child_other_span.to_string))
                    return True
        return False

    def map_prog_to_tree(self, question, program):
        # move function names inside their parentheses, e.g. 'answer (' -> '( answer',
        # so the program parses as an s-expression
        program = re.sub(r'(\w+) \(', r'( \1', program)
        self._program = program.replace(',', '')
        self._tokens = self._get_wordpieces(question)
        self._tree = Tree()
        self._decided_spans = []
        parse_result = self._parser.parseString(self._program)[0]
        return parse_result

    # # print the program tree
    # executor.parser.setParseAction(_parse_action_tree)
    # tree_parse = executor.parser.parseString(program)[0]
    # print('parse_tree=')
    # pprint(tree_parse)
    # def pprint(node, tab=""):
    #     if isinstance(node, str):
    #         print(tab + u"┗━ " + str(node))
    #         return
    #     print(tab + u"┗━ " + str(node.value))
    #     for child in node.children:
    #         pprint(child, tab + " ")

    @staticmethod
    def _parse_action_tree(string, location, tokens):
        # pyparsing parse action; builds a lightweight tree of namedtuples
        from collections import namedtuple
        Node = namedtuple("Node", ["value", "children"])
        node = Node(value=tokens[0][0], children=tokens[0][1:])
        return node

    def _get_aligned_span(self, subtree: Tree):
        """Gets hint span for aligning ambiguous constants (e.g., 'left' appears twice)"""
        return subtree.get_node(subtree.root).data

    def _get_first_argument_to_join(self, predicate_tree, arg1_tree, arg2_tree):
        predicate_span_start = int(predicate_tree.root.split('-')[0])
        arg1_span_start = int(arg1_tree.root.split('-')[0])
        arg2_span_start = int(arg2_tree.root.split('-')[0])
        if (predicate_span_start < arg2_span_start < arg1_span_start
                or predicate_span_start > arg2_span_start > arg1_span_start):
            return arg2_tree, arg1_tree
        else:
            return arg1_tree, arg2_tree

    def _get_details(self, child, span_labels):
        data = child.data
        start = data.span[0]
        end = data.span[1] - 1  # move to inclusive end
        span = (start, end)
        type = data.constant if data.constant else 'span'
        is_span = type == 'span'
        if not is_span:
            span_labels.append({'span': span, 'type': type})
        return start, end, is_span

    def _adjust_end(self, start, end, adjusted_end, is_span, span_labels):
        if end < adjusted_end - 1:
            span_labels.append({'span': (start, adjusted_end - 1), 'type': 'span'})
        else:
            if is_span:
                span_labels.append({'span': (start, end), 'type': 'span'})

    def _inner_write(self, span_labels, children, end, parse_tree):
        children.sort(key=lambda c: c.data.span[0])
        start_1, end_1, is_span_1 = self._get_details(children[0], span_labels)
        start_2, end_2, is_span_2 = self._get_details(children[1], span_labels)
        if len(children) > 2:
            start_3, end_3, is_span_3 = self._get_details(children[2], span_labels)
        self._adjust_end(start_1, end_1, start_2, is_span_1, span_labels)
        if len(children) > 2:
            self._adjust_end(start_2, end_2, start_3, is_span_2, span_labels)
            self._adjust_end(start_3, end_3, end + 1, is_span_3, span_labels)
            # if end_2 < start_3 - 1:
            #     span_labels.append({'span': (start_2, start_3 - 1), 'type': 'span'})
            # if end_3 < end:
            #     span_labels.append({'span': (start_3, end), 'type': 'span'})
        else:
            self._adjust_end(start_2, end_2, end + 1, is_span_2, span_labels)
            # if end_2 < end:
            #     span_labels.append({'span': (start_2, end), 'type': 'span'})
        children_1 = parse_tree.children(children[0].identifier)
        if len(children_1) > 0:
            self._inner_write(span_labels, children_1, start_2 - 1, parse_tree)
        children_2 = parse_tree.children(children[1].identifier)
        if len(children) > 2:
            children_3 = parse_tree.children(children[2].identifier)
            if len(children_2) > 0:
                self._inner_write(span_labels, children_2, start_3 - 1, parse_tree)
            if len(children_3) > 0:
                self._inner_write(span_labels, children_3, end, parse_tree)
        else:
            if len(children_2) > 0:
                self._inner_write(span_labels, children_2, end, parse_tree)

    def write_to_output(self, line, parse_tree, output_file):
        tokens = self._get_wordpieces(line['question'])
        if line['question'] == "what state borders michigan ?":
            print()  # leftover debugging hook
        len_sent = len(tokens)
        span_labels = []
        # type = 'span'
        # span_labels.append({'span': (0, len_sent-1), 'type': type})
        root = parse_tree.root
        root_node = parse_tree.get_node(root).data
        # root_start, root_end = root.split('-')
        # root_start = int(root_start)
        # root_end = int(root_end)
        # # if root_start > 0:
        start = 0
        end = len_sent - 1
        type = 'span'
        span_labels.append({'span': (start, end), 'type': type})
        children = parse_tree.children(root)
        if len(children) == 0:
            s, t = (int(root_node.span[0]), int(root_node.span[1]))
            type = root_node.constant
            span_labels.append({'span': (s, t), 'type': type})
            if s > 0:
                span_labels.append({'span': (s, end), 'type': 'span'})
        else:
            child_1_start = children[0].data.span[0]
            if child_1_start > 0:
                span_labels.append({'span': (child_1_start, end), 'type': 'span'})
            self._inner_write(span_labels, parse_tree.children(root), end, parse_tree)
        # while (len(parse_tree.children(root)) > 0):
        #     children = parse_tree.children(root)
        #     data_1 = children[0].data
        #     span_1 = (data_1.span[0], data_1.span[1] - 1)
        #     data_2 = children[1].data
        #     span_2 = (data_2.span[0], data_2.span[1] - 1)
        #     print()
        # for i, node in enumerate(parse_tree.expand_tree()):
        #     data = parse_tree.get_node(node).data
        #     span = (data.span[0], data.span[1]-1)  # move to inclusive spans
        #     if i == 0:
        #         left_extra = None
        #         if span[0] > 0:
        #             left_extra = (0, span[0]-1)
        #         right_extra = None
        #         if span[1] < len_sent-1:
        #             left_right = (span[1]+1, len_sent-1)
        #     type = data.constant if data.constant else 'span'
        #     span_labels.append({'span': span, 'type': type})
        line['gold_spans'] = span_labels
        json_str = json.dumps(line)
        output_file.write(json_str + '\n')
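
# --- Illustrative sketch of the program normalization in map_prog_to_tree ---
# Standalone, so it avoids constructing a full SpanMapper; the GeoQuery-style
# program string below is an assumed example.
def _demo_program_parse():
    program = 'answer ( state ( loc_2 ( texas ) ) )'
    program = re.sub(r'(\w+) \(', r'( \1', program)  # 'answer (' -> '( answer'
    program = program.replace(',', '')
    parser = pyparsing.nestedExpr('(', ')', ignoreExpr=pyparsing.dblQuotedString)
    parse_result = parser.parseString(program)[0]
    print(parse_result)  # ['answer', ['state', ['loc_2', 'texas']]]
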
# Imports assumed by this reader (AllenNLP 0.9.x paths; adjust for other versions):
import collections
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.coreference_resolution.conll import canonicalize_clusters
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_utils import Ontonotes, enumerate_spans
from allennlp.data.fields import (Field, ListField, MetadataField,
                                  SequenceLabelField, SpanField, TextField)
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import PretrainedBertIndexer, TokenIndexer
from allennlp.data.tokenizers import Token


class ConllCorefBertReader(DatasetReader):
    """
    Reads a single CoNLL-formatted file. This is the same file format as used in the
    :class:`~allennlp.data.dataset_readers.semantic_role_labelling.SrlReader`, but is
    preprocessed to dump all documents into a single file per train, dev and test split.
    See scripts/compile_coref_data.sh for more details of how to pre-process the
    Ontonotes 5.0 data into the correct format.

    Returns a ``Dataset`` where the ``Instances`` have four fields: ``text``, a
    ``TextField`` containing the full document text, ``spans``, a ``ListField[SpanField]``
    of inclusive start and end indices for span candidates, and ``metadata``, a
    ``MetadataField`` that stores the instance's original text. For data with gold
    cluster labels, we also include the original ``clusters`` (a list of list of index
    pairs) and a ``SequenceLabelField`` of cluster ids for every span candidate.

    Parameters
    ----------
    max_span_width: ``int``, required.
        The maximum width of candidate spans to consider.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        This is used to index the words in the document.  See :class:`TokenIndexer`.
        Default is ``{"tokens": SingleIdTokenIndexer()}``.
    """
    def __init__(self,
                 max_span_width: int,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._max_span_width = max_span_width
        # note: hard-coded BERT indexers take precedence over the token_indexers argument
        self._token_indexers = {
            "tokens": PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
        }
        self.token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)
        ontonotes_reader = Ontonotes()
        for sentences in ontonotes_reader.dataset_document_iterator(file_path):
            clusters: DefaultDict[int, List[Tuple[int, int]]] = collections.defaultdict(list)
            total_tokens = 0
            for sentence in sentences:
                for typed_span in sentence.coref_spans:
                    # Coref annotations are on a _per sentence_ basis, so we need
                    # to adjust them to be relative to the length of the document.
                    span_id, (start, end) = typed_span
                    clusters[span_id].append((start + total_tokens, end + total_tokens))
                total_tokens += len(sentence.words)
            canonical_clusters = canonicalize_clusters(clusters)
            new_sentences = [s.words for s in sentences]
            flattened_sentences = [self._normalize_word(word)
                                   for sentence in new_sentences
                                   for word in sentence]

            def tokenizer(s: str):
                return self.token_indexer.wordpiece_tokenizer(s)

            flattened_sentences = tokenizer(" ".join(flattened_sentences))
            if len(flattened_sentences) > 510:
                # skip documents that do not fit in BERT's 512-token window
                # (510 wordpieces plus [CLS] and [SEP])
                continue
            yield self.text_to_instance([s.words for s in sentences], canonical_clusters)

    def align_token(self, text, span):
        """
        Re-tokenize the text up to the span boundaries to find one individual
        span's position in the wordpiece sequence.
        """
        current = self.token_indexer.wordpiece_tokenizer(" ".join(text[:span[0]]))
        start_span = len(current)
        span_embedding = self.token_indexer.wordpiece_tokenizer(
            " ".join(text[span[0]:span[1]]))
        end_span = start_span + len(span_embedding)
        return start_span, end_span

    def align_clusters_to_tokens(self, text, clusters):
        new_clusters = []
        for cluster in clusters:
            new_cluster = []
            for span in cluster:
                new_cluster.append(self.align_token(text, span))
            new_clusters.append(new_cluster)
        return new_clusters

    @overrides
    def text_to_instance(self,  # type: ignore
                         sentences: List[List[str]],
                         gold_clusters: Optional[List[List[Tuple[int, int]]]] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``List[List[str]]``, required.
            A list of lists representing the tokenised words and sentences in the document.
        gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
            A list of all clusters in the document, represented as word spans. Each cluster
            contains some number of spans, which can be nested and overlap, but will never
            exactly match between clusters.

        Returns
        -------
        An ``Instance`` containing the following ``Fields``:
            text : ``TextField``
                The text of the full document.
            spans : ``ListField[SpanField]``
                A ListField containing the spans represented as ``SpanFields``
                with respect to the document text.
            span_labels : ``SequenceLabelField``, optional
                The id of the cluster which each possible span belongs to, or -1 if it does
                not belong to a cluster. As these labels have variable length (it depends
                on how many spans we are considering), we represent this as a
                ``SequenceLabelField`` with respect to the ``spans`` ``ListField``.
        """
        flattened_sentences = [self._normalize_word(word)
                               for sentence in sentences
                               for word in sentence]
        # align the gold clusters to the wordpiece tokenization
        if gold_clusters is not None:
            gold_clusters = self.align_clusters_to_tokens(flattened_sentences, gold_clusters)

        def tokenizer(s: str):
            return self.token_indexer.wordpiece_tokenizer(s)

        # we need to try this with the other one.
        flattened_sentences = tokenizer(" ".join(flattened_sentences))
        metadata: Dict[str, Any] = {"original_text": flattened_sentences}
        if gold_clusters is not None:
            metadata["clusters"] = gold_clusters
        text_field = TextField([Token(word) for word in flattened_sentences],
                               self._token_indexers)
        cluster_dict = {}
        if gold_clusters is not None:
            for cluster_id, cluster in enumerate(gold_clusters):
                for mention in cluster:
                    cluster_dict[tuple(mention)] = cluster_id
        spans: List[Field] = []
        span_labels: Optional[List[int]] = [] if gold_clusters is not None else None
        sentence_offset = 0
        normal = []
        for sentence in sentences:
            # enumerate candidate spans up to the maximum span width
            for start, end in enumerate_spans(sentence,
                                              offset=sentence_offset,
                                              max_span_width=self._max_span_width):
                if span_labels is not None:
                    if (start, end) in cluster_dict:
                        span_labels.append(cluster_dict[(start, end)])
                    else:
                        span_labels.append(-1)
                # align the spans to the BERT tokenization
                normal.append((start, end))
                # span field for Span, which needs to be a flattened sentence
                spans.append(SpanField(start, end, text_field))
            sentence_offset += len(sentence)
        span_field = ListField(spans)
        metadata_field = MetadataField(metadata)
        fields: Dict[str, Field] = {
            "text": text_field,
            "spans": span_field,
            "metadata": metadata_field
        }
        if span_labels is not None:
            fields["span_labels"] = SequenceLabelField(span_labels, span_field)
        return Instance(fields)

    @staticmethod
    def _normalize_word(word):
        if word == "/." or word == "/?":
            return word[1:]
        else:
            return word
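
# --- Illustrative usage (assumes AllenNLP 0.9.x; the file path is a placeholder
# for a CoNLL file produced by scripts/compile_coref_data.sh) ---
def _demo_coref_reader():
    reader = ConllCorefBertReader(max_span_width=10)
    instances = reader.read('/path/to/train.english.v4_gold_conll')
    instance = instances[0]
    print(instance.fields['text'].tokens[:10])       # first wordpieces of the document
    print(len(instance.fields['spans'].field_list))  # number of candidate spans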