def test_mapping_works_with_dict(self):
    field = MetadataField({"a": 1, "b": [0]})

    assert "a" in field
    assert field["a"] == 1
    assert len(field) == 2

    keys = {k for k in field}
    assert keys == {"a", "b"}

    values = [v for v in field.values()]
    assert len(values) == 2
    assert 1 in values
    assert [0] in values
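# The test above exercises MetadataField's mapping interface. As a complementary,
# hedged sketch (not taken from the snippets in this file), here is the typical way
# such a field is attached to an AllenNLP Instance; the field names "tokens" and
# "metadata" and the example data are illustrative only.
from allennlp.data import Instance
from allennlp.data.fields import MetadataField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

tokens = [Token(t) for t in ["a", "simple", "example"]]
fields = {
    "tokens": TextField(tokens, {"tokens": SingleIdTokenIndexer()}),
    # MetadataField carries un-tensorized, per-instance information (ids, raw
    # strings, offsets) straight through to the model's forward pass.
    "metadata": MetadataField({"source_id": 0, "raw_text": "a simple example"}),
}
instance = Instance(fields)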
def make_reading_comprehension_instance_quac(question_list_tokens: List[List[Token]],
                                             passage_tokens: List[Token],
                                             token_indexers: Dict[str, TokenIndexer],
                                             passage_text: str,
                                             token_span_lists: List[List[Tuple[int, int]]] = None,
                                             yesno_list: List[int] = None,
                                             followup_list: List[int] = None,
                                             additional_metadata: Dict[str, Any] = None,
                                             num_context_answers: int = 0) -> Instance:
    """
    Converts a question, a passage, and an optional answer (or answers) to an ``Instance`` for
    use in a reading comprehension model.

    Creates an ``Instance`` with at least these fields: ``question`` and ``passage``, both
    ``TextFields``; and ``metadata``, a ``MetadataField``.  Additionally, if both ``answer_texts``
    and ``char_span_starts`` are given, the ``Instance`` has ``span_start`` and ``span_end``
    fields, which are both ``IndexFields``.

    Parameters
    ----------
    question_list_tokens : ``List[List[Token]]``
        An already-tokenized list of questions. Each dialog has multiple questions.
    passage_tokens : ``List[Token]``
        An already-tokenized passage that contains the answer to the given question.
    token_indexers : ``Dict[str, TokenIndexer]``
        Determines how the question and passage ``TextFields`` will be converted into tensors
        that get input to a model. See :class:`TokenIndexer`.
    passage_text : ``str``
        The original passage text. We need this so that we can recover the actual span from the
        original passage that the model predicts as the answer to the question. This is used in
        official evaluation scripts.
    token_span_lists : ``List[List[Tuple[int, int]]]``, optional
        Indices into ``passage_tokens`` to use as the answer to the question for training. This
        is a list of lists, first because there are multiple questions per dialog, and second
        because there might be several possible correct answer spans in the passage. Currently,
        we just select the last span in this list (i.e., QuAC has multiple annotations on the dev
        set; this will select the last span, which was given by the original annotator).
    yesno_list : ``List[int]``
        List of the affirmation bit for each question-answer pair.
    followup_list : ``List[int]``
        List of the continuation bit for each question-answer pair.
    num_context_answers : ``int``, optional
        How many previous answers to encode into the passage.
    additional_metadata : ``Dict[str, Any]``, optional
        The constructed ``metadata`` field will by default contain ``original_passage``,
        ``token_offsets``, ``question_tokens``, ``passage_tokens``, and ``answer_texts`` keys. If
        you want any other metadata to be associated with each instance, you can pass that in
        here. This dictionary will get added to the ``metadata`` dictionary we already construct.
    """
    additional_metadata = additional_metadata or {}
    fields: Dict[str, Field] = {}
    passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
    # This is separate so we can reference it later with a known type.
    passage_field = TextField(passage_tokens, token_indexers)
    fields['passage'] = passage_field
    fields['question'] = ListField([TextField(q_tokens, token_indexers)
                                    for q_tokens in question_list_tokens])
    metadata = {'original_passage': passage_text,
                'token_offsets': passage_offsets,
                'question_tokens': [[token.text for token in question_tokens]
                                    for question_tokens in question_list_tokens],
                'passage_tokens': [token.text for token in passage_tokens]}
    p1_answer_marker_list: List[Field] = []
    p2_answer_marker_list: List[Field] = []
    p3_answer_marker_list: List[Field] = []

    def get_tag(i, i_name):
        # Generate a tag to mark a previous answer span in the passage.
        return "<{0:d}_{1:s}>".format(i, i_name)

    def mark_tag(span_start, span_end, passage_tags, prev_answer_distance):
        try:
            assert span_start >= 0
            assert span_end >= 0
        except:
            raise ValueError("Previous {0:d}th answer span should have been updated!".format(
                prev_answer_distance))
        # Modify "tags" to mark the previous answer span.
        if span_start == span_end:
            passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "")
        else:
            passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "start")
            passage_tags[prev_answer_distance][span_end] = get_tag(prev_answer_distance, "end")
            for passage_index in range(span_start + 1, span_end):
                passage_tags[prev_answer_distance][passage_index] = get_tag(prev_answer_distance,
                                                                            "in")

    if token_span_lists:
        span_start_list: List[Field] = []
        span_end_list: List[Field] = []
        p1_span_start, p1_span_end, p2_span_start = -1, -1, -1
        p2_span_end, p3_span_start, p3_span_end = -1, -1, -1
        # Loop over the answer spans for each question in the dialog.
        for question_index, answer_span_lists in enumerate(token_span_lists):
            span_start, span_end = answer_span_lists[-1]  # The last one is the original answer.
            span_start_list.append(IndexField(span_start, passage_field))
            span_end_list.append(IndexField(span_end, passage_field))
            prev_answer_marker_lists = [["O"] * len(passage_tokens),
                                        ["O"] * len(passage_tokens),
                                        ["O"] * len(passage_tokens),
                                        ["O"] * len(passage_tokens)]
            if question_index > 0 and num_context_answers > 0:
                mark_tag(p1_span_start, p1_span_end, prev_answer_marker_lists, 1)
                if question_index > 1 and num_context_answers > 1:
                    mark_tag(p2_span_start, p2_span_end, prev_answer_marker_lists, 2)
                    if question_index > 2 and num_context_answers > 2:
                        mark_tag(p3_span_start, p3_span_end, prev_answer_marker_lists, 3)
            p3_span_start = p2_span_start
            p3_span_end = p2_span_end
            p2_span_start = p1_span_start
            p2_span_end = p1_span_end
            p1_span_start = span_start
            p1_span_end = span_end
            if num_context_answers > 2:
                p3_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[3],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
            if num_context_answers > 1:
                p2_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[2],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
            if num_context_answers > 0:
                p1_answer_marker_list.append(SequenceLabelField(prev_answer_marker_lists[1],
                                                                passage_field,
                                                                label_namespace="answer_tags"))
        fields['span_start'] = ListField(span_start_list)
        fields['span_end'] = ListField(span_end_list)
        if num_context_answers > 0:
            fields['p1_answer_marker'] = ListField(p1_answer_marker_list)
            if num_context_answers > 1:
                fields['p2_answer_marker'] = ListField(p2_answer_marker_list)
                if num_context_answers > 2:
                    fields['p3_answer_marker'] = ListField(p3_answer_marker_list)
        fields['yesno_list'] = ListField([LabelField(yesno, label_namespace="yesno_labels")
                                          for yesno in yesno_list])
        fields['followup_list'] = ListField([LabelField(followup, label_namespace="followup_labels")
                                             for followup in followup_list])
    metadata.update(additional_metadata)
    fields['metadata'] = MetadataField(metadata)
    return Instance(fields)
def __getitem__(self, index):
    # if self.split == 'test':
    #     raise ValueError("blind test mode not supported quite yet")
    item = deepcopy(self.items[index])

    ###################################################################
    # Load questions and answers
    if self.mode == 'rationale':
        conditioned_label = item['answer_label'] if self.split != 'test' \
            else self.conditioned_answer_choice
        item['question'] += item['answer_choices'][conditioned_label]

    answer_choices = item['{}_choices'.format(self.mode)]
    dets2use, old_det_to_new_ind = self._get_dets_to_use(item)

    ###################################################################
    # Load in BERT. We'll get contextual representations of the context and the answer choices.
    # grp_items = {k: np.array(v, dtype=np.float16) for k, v in self.get_h5_group(index).items()}
    with h5py.File(self.h5fn, 'r') as h5:
        grp_items = {k: np.array(v, dtype=np.float16) for k, v in h5[str(index)].items()}

    # Essentially we need to condition on the right answer choice here, if we're doing QA->R.
    # We will always condition on the `conditioned_answer_choice`.
    condition_key = self.conditioned_answer_choice \
        if self.split == "test" and self.mode == "rationale" else ""

    instance_dict = {}
    if 'endingonly' not in self.embs_to_load:
        questions_tokenized, question_tags = zip(*[
            _fix_tokenization(item['question'],
                              grp_items[f'ctx_{self.mode}{condition_key}{i}'],
                              old_det_to_new_ind,
                              item['objects'],
                              token_indexers=self.token_indexers,
                              pad_ind=0 if self.add_image_as_a_box else -1)
            for i in range(4)
        ])
        instance_dict['question'] = ListField(questions_tokenized)
        instance_dict['question_tags'] = ListField(question_tags)

    answers_tokenized, answer_tags = zip(*[
        _fix_tokenization(answer,
                          grp_items[f'answer_{self.mode}{condition_key}{i}'],
                          old_det_to_new_ind,
                          item['objects'],
                          token_indexers=self.token_indexers,
                          pad_ind=0 if self.add_image_as_a_box else -1)
        for i, answer in enumerate(answer_choices)
    ])
    instance_dict['answers'] = ListField(answers_tokenized)
    instance_dict['answer_tags'] = ListField(answer_tags)

    if self.split != 'test':
        instance_dict['label'] = LabelField(item['{}_label'.format(self.mode)],
                                            skip_indexing=True)
    instance_dict['metadata'] = MetadataField({
        'annot_id': item['annot_id'],
        'ind': index,
        'movie': item['movie'],
        'img_fn': item['img_fn'],
        'question_number': item['question_number']
    })

    ###################################################################
    # Load the image now and rescale it. Might have to subtract the mean and whatnot here too.
    image = load_image(os.path.join(VCR_IMAGES_DIR, item['img_fn']))
    image, window, img_scale, padding = resize_image(image, random_pad=self.is_train)
    image = to_tensor_and_normalize(image)
    c, h, w = image.shape

    ###################################################################
    # Load boxes.
    with open(os.path.join(VCR_IMAGES_DIR, item['metadata_fn']), 'r') as f:
        metadata = json.load(f)

    # [nobj, 14, 14]
    segms = np.stack([
        make_mask(mask_size=14,
                  box=metadata['boxes'][i],
                  polygons_list=metadata['segms'][i]) for i in dets2use
    ])

    # Chop off the final dimension, that's the confidence.
    boxes = np.array(metadata['boxes'])[dets2use, :-1]
    # Possibly rescale them if necessary.
    boxes *= img_scale
    boxes[:, :2] += np.array(padding[:2])[None]
    boxes[:, 2:] += np.array(padding[:2])[None]
    obj_labels = [self.coco_obj_to_ind[item['objects'][i]] for i in dets2use.tolist()]
    if self.add_image_as_a_box:
        boxes = np.row_stack((window, boxes))
        segms = np.concatenate((np.ones((1, 14, 14), dtype=np.float32), segms), 0)
        obj_labels = [self.coco_obj_to_ind['__background__']] + obj_labels

    instance_dict['segms'] = ArrayField(segms, padding_value=0)
    instance_dict['objects'] = ListField([LabelField(x, skip_indexing=True) for x in obj_labels])

    if not np.all((boxes[:, 0] >= 0.) & (boxes[:, 0] < boxes[:, 2])):
        import ipdb
        ipdb.set_trace()
    assert np.all((boxes[:, 1] >= 0.) & (boxes[:, 1] < boxes[:, 3]))
    assert np.all((boxes[:, 2] <= w))
    assert np.all((boxes[:, 3] <= h))
    instance_dict['boxes'] = ArrayField(boxes, padding_value=-1)

    instance = Instance(instance_dict)
    instance.index_fields(self.vocab)
    return image, instance
def text_to_instance(self,  # type: ignore
                     sentences: List[List[str]],
                     gold_clusters: Optional[List[List[Tuple[int, int]]]] = None) -> Instance:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    sentences : ``List[List[str]]``, required.
        A list of lists representing the tokenised words and sentences in the document.
    gold_clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
        A list of all clusters in the document, represented as word spans. Each cluster
        contains some number of spans, which can be nested and overlap, but will never
        exactly match between clusters.

    Returns
    -------
    An ``Instance`` containing the following ``Fields``:
        text : ``TextField``
            The text of the full document.
        spans : ``ListField[SpanField]``
            A ListField containing the spans represented as ``SpanFields``
            with respect to the document text.
        span_labels : ``SequenceLabelField``, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster. As these labels have variable length (it depends on
            how many spans we are considering), we represent this as a ``SequenceLabelField``
            with respect to the ``spans`` ``ListField``.
    """
    flattened_sentences = [self._normalize_word(word)
                           for sentence in sentences
                           for word in sentence]
    # Align the clusters to the tokens.
    gold_clusters = self.align_clusters_to_tokens(flattened_sentences, gold_clusters)

    def tokenizer(s: str):
        return self.token_indexer.wordpiece_tokenizer(s)

    flattened_sentences = tokenizer(" ".join(flattened_sentences))
    metadata: Dict[str, Any] = {"original_text": flattened_sentences}
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters

    if len(flattened_sentences) > 512:
        # import pdb
        # pdb.set_trace()
        text_field = TextField([Token("[CLS]")] +
                               [Token(word) for word in flattened_sentences[:512]] +
                               [Token("[SEP]")],
                               self._token_indexers)
        total_list = [text_field]
        import math
        for i in range(math.ceil(float(len(flattened_sentences[512:])) / 100.0)):
            # Slide by 100 wordpieces per extra window.
            text_field = TextField([Token("[CLS]")] +
                                   [Token(word) for word in
                                    flattened_sentences[512 + (i * 100):512 + ((i + 1) * 100)]] +
                                   [Token("[SEP]")],
                                   self._token_indexers)
            total_list.append(text_field)
        text_field = ListField(total_list)  # wrap the sliding windows in a ListField
    else:
        text_field = TextField([Token("[CLS]")] +
                               [Token(word) for word in flattened_sentences] +
                               [Token("[SEP]")],
                               self._token_indexers)

    cluster_dict = {}
    if gold_clusters is not None:
        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    span_labels: Optional[List[int]] = [] if gold_clusters is not None else None

    sentence_offset = 0
    normal = []
    for sentence in sentences:
        # Enumerate the spans.
        for start, end in enumerate_spans(sentence,
                                          offset=sentence_offset,
                                          max_span_width=self._max_span_width):
            if span_labels is not None:
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)
            # Align the spans to the BERT tokenization.
            normal.append((start, end))
            # Sequence field for the SpanFields, which needs to be a flattened sentence.
            span_field = None
            # if len(flattened_sentences) > 512:
            #     span_field = TextField([Token("[CLS]")] +
            #                            [Token(word) for word in flattened_sentences] +
            #                            [Token("[SEP]")],
            #                            self._token_indexers)
            # else:
            #     span_field = text_field
            spans.append(SpanField(start, end, span_field))
        sentence_offset += len(sentence)

    span_field = ListField(spans)
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {"text": text_field,
                                "spans": span_field,
                                "metadata": metadata_field}
    if span_labels is not None:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field)

    return Instance(fields)
def make_coref_instance(
    sentences: List[List[str]],
    token_indexers: Dict[str, TokenIndexer],
    max_span_width: int,
    gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
    wordpiece_modeling_tokenizer: PretrainedTransformerTokenizer = None,
    max_sentences: int = None,
    remove_singleton_clusters: bool = True,
    desc_embeddings=None,
) -> Instance:
    """
    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenised words and sentences in the document.
    token_indexers : `Dict[str, TokenIndexer]`
        This is used to index the words in the document.  See :class:`TokenIndexer`.
    max_span_width : `int`, required.
        The maximum width of candidate spans to consider.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = `None`)
        A list of all clusters in the document, represented as word spans with absolute indices
        in the entire document. Each cluster contains some number of spans, which can be nested
        and overlap. If there are exact matches between clusters, they will be resolved using
        `_canonicalize_clusters`.
    wordpiece_modeling_tokenizer : `PretrainedTransformerTokenizer`, optional (default = `None`)
        If not None, this dataset reader does subword tokenization using the supplied tokenizer
        and distributes the labels to the resulting wordpieces. All the modeling will be based
        on wordpieces. If this is left as `None` (the default), the user is expected to use
        `PretrainedTransformerMismatchedIndexer` and `PretrainedTransformerMismatchedEmbedder`,
        and the modeling will be on the word level.
    max_sentences : `int`, optional (default = `None`)
        The maximum number of sentences in each document to keep. By default keeps all sentences.
    remove_singleton_clusters : `bool`, optional (default = `True`)
        Some datasets contain clusters that are singletons (i.e. no coreferents). This option
        allows the removal of them.
    desc_embeddings : optional (default = `None`)
        768-dimensional embeddings for the descriptions.

    # Returns

    An `Instance` containing the following `Fields`:

        text : `TextField`
            The text of the full document.
        spans : `ListField[SpanField]`
            A ListField containing the spans represented as `SpanFields`
            with respect to the document text.
        span_labels : `SequenceLabelField`, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster. As these labels have variable length (it depends on
            how many spans we are considering), we represent this as a `SequenceLabelField`
            with respect to the spans `ListField`.
    """
    if max_sentences is not None and len(sentences) > max_sentences:
        sentences = sentences[:max_sentences]
        total_length = sum(len(sentence) for sentence in sentences)

        if gold_clusters is not None:
            new_gold_clusters = []

            for cluster in gold_clusters:
                new_cluster = []
                for mention in cluster:
                    if mention[1] < total_length:
                        new_cluster.append(mention)
                if new_cluster:
                    new_gold_clusters.append(new_cluster)

            gold_clusters = new_gold_clusters

    flattened_sentences = [_normalize_word(word) for sentence in sentences for word in sentence]

    if wordpiece_modeling_tokenizer is not None:
        flat_sentences_tokens, offsets = wordpiece_modeling_tokenizer.intra_word_tokenize(
            flattened_sentences
        )
        flattened_sentences = [t.text for t in flat_sentences_tokens]
    else:
        flat_sentences_tokens = [Token(word) for word in flattened_sentences]

    text_field = TextField(flat_sentences_tokens, token_indexers)

    cluster_dict = {}
    if gold_clusters is not None:
        gold_clusters = _canonicalize_clusters(gold_clusters)
        if remove_singleton_clusters:
            gold_clusters = [cluster for cluster in gold_clusters if len(cluster) > 1]

        if wordpiece_modeling_tokenizer is not None:
            for cluster in gold_clusters:
                for mention_id, mention in enumerate(cluster):
                    start = offsets[mention[0]][0]
                    end = offsets[mention[1]][1]
                    cluster[mention_id] = (start, end)

        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    span_labels: Optional[List[int]] = [] if gold_clusters is not None else None

    sentence_offset = 0
    for sentence in sentences:
        for start, end in enumerate_spans(
            sentence, offset=sentence_offset, max_span_width=max_span_width
        ):
            if wordpiece_modeling_tokenizer is not None:
                start = offsets[start][0]
                end = offsets[end][1]

                # `enumerate_spans` uses a word-level width limit; here we apply it to wordpieces.
                # We have to do this check here because we use a span width embedding that has
                # only `max_span_width` entries, and since we are doing wordpiece modeling, the
                # span width embedding operates on wordpiece lengths. So a check here is
                # necessary or else we wouldn't know how many entries there would be.
                if end - start + 1 > max_span_width:
                    continue
                # We also don't generate spans that contain special tokens.
                if start < len(wordpiece_modeling_tokenizer.single_sequence_start_tokens):
                    continue
                if end >= len(flat_sentences_tokens) - len(
                    wordpiece_modeling_tokenizer.single_sequence_end_tokens
                ):
                    continue

            if span_labels is not None:
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)

            spans.append(SpanField(start, end, text_field))
        sentence_offset += len(sentence)

    span_field = ListField(spans)

    metadata: Dict[str, Any] = {"original_text": flattened_sentences}
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters
    if desc_embeddings:
        # Flatten the embeddings to get rid of the sentence splits.
        flat_embeddings = [wordlist for doclist in desc_embeddings for wordlist in doclist]
        metadata["desc_embeddings"] = flat_embeddings
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": span_field,
        "metadata": metadata_field,
    }
    if span_labels is not None:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field)

    # import pdb; pdb.set_trace()
    return Instance(fields)
def text_to_instance(self, sample: list) -> Instance:
    fields = {}
    text: str = sample['text'].strip()
    labels = sample['aspect_terms']
    pieces = []
    piece_labels = []
    last_label_end_index = 0
    for i, label in enumerate(labels):
        if label.from_index != last_label_end_index:
            pieces.append(text[last_label_end_index:label.from_index])
            piece_labels.append(0)
        pieces.append(text[label.from_index:label.to_index])
        piece_labels.append(1)
        last_label_end_index = label.to_index
        if i == len(labels) - 1 and label.to_index != len(text):
            pieces.append(text[label.to_index:])
            piece_labels.append(0)
    words_of_pieces = [self.tokenizer(piece.strip()) for piece in pieces]
    word_indices_of_aspect_terms = []
    start_index = 0
    for i in range(len(words_of_pieces)):
        words_of_piece = words_of_pieces[i]
        end_index = start_index + len(words_of_piece)
        if piece_labels[i] == 1:
            word_indices_of_aspect_terms.append([start_index, end_index])
        start_index = end_index
    sample['word_indices_of_aspect_terms'] = word_indices_of_aspect_terms

    words = []
    for words_of_piece in words_of_pieces:
        words.extend(words_of_piece)
    sample['words'] = words

    for i in range(len(word_indices_of_aspect_terms)):
        start_index = word_indices_of_aspect_terms[i][0]
        end_index = word_indices_of_aspect_terms[i][1]
        aspect_term_text = ' '.join(words[start_index:end_index])

    graph = self._build_graph(text)
    sample['graph'] = graph

    tokens = [Token(word) for word in words]
    sentence_field = TextField(tokens, self.token_indexers)
    fields['tokens'] = sentence_field

    position = [Token(str(i)) for i in range(len(tokens))]
    position_field = TextField(position, self.position_indexers)
    fields['position'] = position_field

    if self.configuration['sample_mode'] == 'single':
        max_aspect_term_num = 1
    else:
        max_aspect_term_num = self.configuration['max_aspect_term_num']
    polarity_labels = [-100] * max_aspect_term_num
    for i, aspect_term in enumerate(sample['aspect_terms']):
        polarity_labels[i] = self.polarities.index(aspect_term.polarity)
    label_field = ArrayField(np.array(polarity_labels))
    fields["label"] = label_field

    polarity_mask = [1 if polarity_labels[i] != -100 else 0 for i in range(max_aspect_term_num)]
    polarity_mask_field = ArrayField(np.array(polarity_mask))
    fields['polarity_mask'] = polarity_mask_field

    # stop_word_labels = [1 if word in english_stop_words else 0 for word in words]
    # stop_word_num = sum(stop_word_labels)
    # stop_word_labels = [label / stop_word_num for label in stop_word_labels]
    # sample.append(stop_word_labels)

    sample_field = MetadataField(sample)
    fields["sample"] = sample_field
    return Instance(fields)
def text_to_instance(self, sample: list) -> Instance:
    fields = {}
    text: str = sample['text'].strip()
    labels: List[data_object.AspectTerm] = sample['aspect_terms']
    pieces = []
    piece_labels = []
    last_label_end_index = 0
    for i, label in enumerate(labels):
        if label.from_index != last_label_end_index:
            pieces.append(text[last_label_end_index:label.from_index])
            piece_labels.append(0)
        if self.configuration['aspect_term_aware']:
            pieces.append('#')
            piece_labels.append(0)
        pieces.append(text[label.from_index:label.to_index])
        piece_labels.append(1)
        if self.configuration['aspect_term_aware']:
            if self.configuration['same_special_token']:
                pieces.append('#')
            else:
                pieces.append('$')
            piece_labels.append(0)
        last_label_end_index = label.to_index
        if i == len(labels) - 1 and label.to_index != len(text):
            pieces.append(text[label.to_index:])
            piece_labels.append(0)
    words_of_pieces = [self.tokenizer(piece.lower().strip()) for piece in pieces]
    word_indices_of_aspect_terms = []
    # The first token is '[CLS]'.
    start_index = 1
    for i in range(len(words_of_pieces)):
        words_of_piece = words_of_pieces[i]
        end_index = start_index + len(words_of_piece)
        if piece_labels[i] == 1:
            word_indices_of_aspect_terms.append([start_index, end_index])
        start_index = end_index
    sample['word_indices_of_aspect_terms'] = word_indices_of_aspect_terms

    words = []
    for words_of_piece in words_of_pieces:
        words.extend(words_of_piece)
    words = ['[CLS]'] + words + ['[SEP]']
    sample['words'] = words

    if self.configuration['pair']:
        assert self.configuration['sample_mode'] == 'single', \
            'While pair is True, sample_mode must be single'
        words_of_aspect_term = self.tokenizer(labels[0].term.lower().strip())
        words = words + words_of_aspect_term + ['[SEP]']

    graph = self._build_graph(text)
    sample['graph'] = graph

    tokens = [Token(word) for word in words]
    sentence_field = TextField(tokens, self.token_indexers)
    fields['tokens'] = sentence_field

    position = [Token(str(i)) for i in range(len(tokens))]
    position_field = TextField(position, self.position_indexers)
    fields['position'] = position_field

    if self.configuration['sample_mode'] == 'single':
        max_aspect_term_num = 1
    else:
        max_aspect_term_num = self.configuration['max_aspect_term_num']
    polarity_labels = [-100] * max_aspect_term_num
    for i, aspect_term in enumerate(sample['aspect_terms']):
        polarity_labels[i] = self.polarities.index(aspect_term.polarity)
    label_field = ArrayField(np.array(polarity_labels))
    fields["label"] = label_field

    polarity_mask = [1 if polarity_labels[i] != -100 else 0 for i in range(max_aspect_term_num)]
    polarity_mask_field = ArrayField(np.array(polarity_mask))
    fields['polarity_mask'] = polarity_mask_field

    # stop_word_labels = [1 if word in english_stop_words else 0 for word in words]
    # stop_word_num = sum(stop_word_labels)
    # stop_word_labels = [label / stop_word_num for label in stop_word_labels]
    # sample.append(stop_word_labels)

    sample_field = MetadataField(sample)
    fields["sample"] = sample_field
    return Instance(fields)
def text_to_instance(
    self,
    query: str,
    tokenized_query: List[Token],
    passage: str,
    tokenized_passage: List[Token],
    answers: List[str],
    token_answer_span: Optional[Tuple[int, int]] = None,
    additional_metadata: Optional[Dict[str, Any]] = None,
    always_add_answer_span: Optional[bool] = False,
) -> Instance:
    """
    A lot of this comes directly from the `transformer_squad.text_to_instance`.
    TODO: Improve docs
    """
    fields = {}

    # Create the query field from the tokenized question and context. Use
    # `self._tokenizer.add_special_tokens` function to add the necessary
    # special tokens to the query.
    query_field = TextField(
        self._tokenizer.add_special_tokens(
            # The `add_special_tokens` function automatically adds in the
            # separation token to mark the separation between the two lists of
            # tokens. Therefore, we can create the query field WITH context
            # through passing them both as arguments.
            tokenized_query,
            tokenized_passage,
        ),
        self._token_indexers,
    )

    # Add the query field to the fields dict that will be outputted as an
    # instance. Do it here rather than assign above so that we can use
    # attributes from `query_field` rather than continuously indexing
    # `fields`.
    fields['question_with_context'] = query_field

    # Calculate the index that marks the start of the context.
    start_of_context = (
        +len(tokenized_query)
        # Used getattr so I can test without having to load a
        # transformer model.
        - len(getattr(self._tokenizer, 'sequence_pair_start_tokens', []))
        - len(getattr(self._tokenizer, 'sequence_pair_mid_tokens', []))
    )

    # Make the answer span.
    if token_answer_span is not None:
        assert all(i >= 0 for i in token_answer_span)
        assert token_answer_span[0] <= token_answer_span[1]

        fields["answer_span"] = SpanField(
            token_answer_span[0] + start_of_context,
            token_answer_span[1] + start_of_context,
            query_field,
        )

    # Make the context span, i.e., the span of text from which possible
    # answers should be drawn.
    fields["context_span"] = SpanField(
        start_of_context, start_of_context + len(tokenized_passage) - 1, query_field
    )

    # Make the metadata.
    metadata = {
        "question": query,
        "question_tokens": tokenized_query,
        "context": passage,
        "context_tokens": tokenized_passage,
        "answers": answers or [],
    }
    if additional_metadata is not None:
        metadata.update(additional_metadata)
    fields["metadata"] = MetadataField(metadata)

    return Instance(fields)
def make_reading_comprehension_instance(
    question_tokens: List[Token],
    passage_tokens: List[Token],
    token_indexers: Dict[str, TokenIndexer],
    answer: List[bool],
    passage_text: str,
    additional_metadata: Dict[str, Any] = None,
) -> Instance:
    """
    Converts a question, a passage, and an answer to an ``Instance`` for use in a reading
    comprehension model.

    Creates an ``Instance`` with at least these fields: ``question`` and ``passage``, both
    ``TextFields``; ``answer``, an ``ArrayField``; and ``metadata``, a ``MetadataField``.

    Parameters
    ----------
    question_tokens : ``List[Token]``
        An already-tokenized question.
    passage_tokens : ``List[Token]``
        An already-tokenized passage that contains the answer to the given question.
    token_indexers : ``Dict[str, TokenIndexer]``
        Determines how the question and passage ``TextFields`` will be converted into tensors
        that get input to a model. See :class:`TokenIndexer`.
    answer : ``List[bool]``
        The boolean answer annotation(s) for this question, stored in the ``answer`` field as an
        array.
    passage_text : ``str``
        The original passage text. We need this so that we can recover the actual span from the
        original passage that the model predicts as the answer to the question. This is used in
        official evaluation scripts.
    additional_metadata : ``Dict[str, Any]``, optional
        The constructed ``metadata`` field will by default contain ``original_passage``,
        ``token_offsets``, ``question_tokens``, ``passage_tokens``, and ``answer_texts`` keys. If
        you want any other metadata to be associated with each instance, you can pass that in
        here. This dictionary will get added to the ``metadata`` dictionary we already construct.
    """
    additional_metadata = additional_metadata or {}
    fields: Dict[str, Field] = {}
    passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]

    # This is separate so we can reference it later with a known type.
    passage_field = TextField(passage_tokens, token_indexers)
    fields["passage"] = passage_field
    fields["question"] = TextField(question_tokens, token_indexers)
    fields["answer"] = ArrayField(np.array(answer))
    metadata = {
        # "original_passage": passage_text,
        # "token_offsets": passage_offsets,
        # "question_tokens": [token.text for token in question_tokens],
        # "passage_tokens": [token.text for token in passage_tokens],
    }

    metadata.update(additional_metadata)
    fields["metadata"] = MetadataField(metadata)
    return Instance(fields)
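# A minimal, hedged usage sketch for the helper above. The tokens, the single-id
# indexer, the boolean answer, and the "question_id" metadata key are illustrative
# assumptions, not values taken from any dataset reader in this file.
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

passage_text = "The cat sat on the mat ."
passage_tokens = []
offset = 0
for word in passage_text.split():
    # Record character offsets so `token.idx` is available for the metadata offsets.
    passage_tokens.append(Token(word, idx=offset))
    offset += len(word) + 1  # +1 for the single space between words

question_tokens = [Token(w) for w in "Did the cat sit ?".split()]

instance = make_reading_comprehension_instance(
    question_tokens=question_tokens,
    passage_tokens=passage_tokens,
    token_indexers={"tokens": SingleIdTokenIndexer()},
    answer=[True],  # boolean answer annotation, stored as an ArrayField
    passage_text=passage_text,
    additional_metadata={"question_id": "example-0"},
)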
def text_to_instance(self, source: str, target: str = None) -> Instance:
    def prepare_text(text, max_tokens):
        tokens = self._tokenizer.tokenize(text)[:max_tokens]
        tokens.insert(0, Token(START_SYMBOL))
        tokens.append(Token(END_SYMBOL))
        return tokens

    source_tokens = prepare_text(source, self._source_max_tokens)
    source_tokens_indexed = TextField(source_tokens, self._source_token_indexers)
    result = {"source_tokens": source_tokens_indexed}
    meta_fields = {}

    if self._save_copy_fields:
        source_to_target_field = NamespaceSwappingField(source_tokens[1:-1],
                                                        self._target_namespace)
        result["source_to_target"] = source_to_target_field
        meta_fields["source_tokens"] = [x.text for x in source_tokens[1:-1]]

    if self._save_pgn_fields:
        source_to_target_field = NamespaceSwappingField(source_tokens, self._target_namespace)
        result["source_to_target"] = source_to_target_field
        meta_fields["source_tokens"] = [x.text for x in source_tokens]

    if target:
        target_tokens = prepare_text(target, self._target_max_tokens)
        target_tokens_indexed = TextField(target_tokens, self._target_token_indexers)
        result["target_tokens"] = target_tokens_indexed

        if self._save_pgn_fields:
            meta_fields["target_tokens"] = [y.text for y in target_tokens]
            source_and_target_token_ids = self._tokens_to_ids(source_tokens + target_tokens)
            source_token_ids = source_and_target_token_ids[:len(source_tokens)]
            result["source_token_ids"] = ArrayField(np.array(source_token_ids, dtype="long"))
            target_token_ids = source_and_target_token_ids[len(source_tokens):]
            result["target_token_ids"] = ArrayField(np.array(target_token_ids, dtype="long"))

        if self._save_copy_fields:
            meta_fields["target_tokens"] = [y.text for y in target_tokens[1:-1]]
            source_and_target_token_ids = self._tokens_to_ids(source_tokens[1:-1] + target_tokens)
            source_token_ids = source_and_target_token_ids[:len(source_tokens) - 2]
            result["source_token_ids"] = ArrayField(np.array(source_token_ids))
            target_token_ids = source_and_target_token_ids[len(source_tokens) - 2:]
            result["target_token_ids"] = ArrayField(np.array(target_token_ids))
    elif self._save_copy_fields:
        source_token_ids = self._tokens_to_ids(source_tokens[1:-1])
        result["source_token_ids"] = ArrayField(np.array(source_token_ids))
    elif self._save_pgn_fields:
        source_token_ids = self._tokens_to_ids(source_tokens)
        result["source_token_ids"] = ArrayField(np.array(source_token_ids))

    if self._save_copy_fields or self._save_pgn_fields:
        result["metadata"] = MetadataField(meta_fields)
    return Instance(result)
def test_forward_works(self):
    # Setting up the model.
    transformer_name = "epwalsh/bert-xsmall-dummy"
    vocab = Vocabulary()
    backbone = PretrainedTransformerBackbone(vocab, transformer_name)
    head1 = ClassifierHead(vocab, seq2vec_encoder=ClsPooler(20), input_dim=20, num_labels=3)
    head2 = ClassifierHead(vocab, seq2vec_encoder=ClsPooler(20), input_dim=20, num_labels=4)
    # We'll start with one head, and add another later.
    model = MultiTaskModel(vocab, backbone, {"cls": head1})

    # Setting up the data.
    tokenizer = PretrainedTransformerTokenizer(model_name=transformer_name)
    token_indexers = PretrainedTransformerIndexer(model_name=transformer_name)
    tokens = tokenizer.tokenize("This is a test")
    text_field = TextField(tokens, {"tokens": token_indexers})
    label_field1 = LabelField(1, skip_indexing=True)
    label_field2 = LabelField(3, skip_indexing=True)
    instance = Instance({
        "text": text_field,
        "label": label_field1,
        "task": MetadataField("cls")
    })

    # Now we run some tests. First, the default.
    outputs = model.forward_on_instance(instance)
    assert "encoded_text" in outputs
    assert "cls_logits" in outputs
    assert "loss" in outputs
    assert "cls_loss" in outputs

    # When we don't have labels.
    instance = Instance({"text": text_field, "task": MetadataField("cls")})
    outputs = model.forward_on_instance(instance)
    assert "encoded_text" in outputs
    assert "cls_logits" in outputs
    assert "loss" not in outputs

    # Same in eval mode.
    model.eval()
    outputs = model.forward_on_instance(instance)
    assert "encoded_text" in outputs
    assert "loss" not in outputs  # no loss because we have no labels
    assert "cls_logits" in outputs  # but we can compute logits
    model.train()

    # Now for two-headed and other more complex tests.
    model = MultiTaskModel(
        vocab,
        backbone,
        {"cls1": head1, "cls2": head2},
        arg_name_mapping={
            "backbone": {"question": "text"},
        },
    )

    # Basic case where things should work, with two heads that both need label inputs.
    instance1 = Instance({
        "text": text_field,
        "label": label_field1,
        "task": MetadataField("cls1")
    })
    instance2 = Instance({
        "text": text_field,
        "label": label_field2,
        "task": MetadataField("cls2")
    })
    batch = Batch([instance1, instance2])
    outputs = model.forward(**batch.as_tensor_dict())
    assert "encoded_text" in outputs
    assert "cls1_logits" in outputs
    assert "cls1_loss" in outputs
    assert "cls2_logits" in outputs
    assert "cls2_loss" in outputs
    assert "loss" in outputs
    combined_loss = outputs["cls1_loss"].item() + outputs["cls2_loss"].item()
    assert abs(outputs["loss"].item() - combined_loss) <= 1e-6

    # This should fail, because we're using task 'cls1' with the labels for `cls2`, and the
    # sizes don't match. This shows up as an IndexError in this case. It'd be nice to catch
    # this kind of error more cleanly in the model class, but I'm not sure how.
    instance = Instance({
        "text": text_field,
        "label": label_field2,
        "task": MetadataField("cls1")
    })
    with pytest.raises(IndexError):
        outputs = model.forward_on_instance(instance)

    # This one should fail because we now have two things that map to "text" in the backbone,
    # and they would clobber each other. The name mapping that we have in the model is ok, as
    # long as our data loader is set up such that we don't batch instances that have both of
    # these fields at the same time.
    instance = Instance({
        "question": text_field,
        "text": text_field,
        "task": MetadataField("cls1")
    })
    with pytest.raises(ValueError, match="duplicate argument text"):
        outputs = model.forward_on_instance(instance)
def text_to_instance(self, dialog_idx, dialog_context, exact_match_feas, tags, utt_lens,
                     labels, spans, span_labels):
    token_indexers = self._token_indexers
    symbol_indexers = self._sys_user_symbol_indexers
    fields: Dict[str, Field] = {}
    fields['dialogs'] = TextField(dialog_context, token_indexers)
    fields['tags'] = TextField(tags, symbol_indexers)
    fields['utt_lens'] = ArrayField(np.array(utt_lens), dtype=np.int32)
    fields['exact_match'] = ListField(exact_match_feas)
    fields['metadata'] = MetadataField(dialog_context)
    fields['dialog_indices'] = MetadataField(dialog_idx)

    # Calculate labels.
    if labels is not None:
        expanded_value_labels = []
        for turn_label in labels:
            # 0 is the default, which is 'none' in the vocab for each domain-slot.
            turn_value_label = [-1 if self._ds_type[ds] == "span" else 0
                                for ds in self._ds_list]
            for each_label in turn_label:
                if each_label[2] == "":
                    continue
                ds = each_label[0] + " " + each_label[1]
                if ds in self._ds_text2id:
                    if self._ds_type[ds] == "classification":
                        if each_label[2] not in self._value_text2id[ds]:
                            # print(ds, each_label[2])
                            continue
                        turn_value_label[self._ds_text2id[ds]] = \
                            self._value_text2id[ds][each_label[2]]
                    if self._ds_type[ds] == "span" and self._ds_use_value_list[ds] == True:
                        if each_label[2] != "none" and each_label[2] != "dont care":
                            if each_label[2] not in self._value_text2id[ds]:
                                # print(ds, each_label[2])
                                continue
                            turn_value_label[self._ds_text2id[ds]] = \
                                self._value_text2id[ds][each_label[2]]
            expanded_value_labels.append(
                ListField([LabelField(l, skip_indexing=True) for l in turn_value_label]))
        fields['labels'] = ListField(expanded_value_labels)

        # Calculate spans.
        if len(self._ds_span_list) != 0:
            spans_start = []
            spans_end = []
            for turn_span in spans:
                cur_span_start = [-1] * len(self._ds_span_list)
                cur_span_end = [-1] * len(self._ds_span_list)
                for each_span in turn_span:
                    cur_ds = each_span[0] + " " + each_span[1]
                    cur_span_start[self._ds_span_text2id[cur_ds]] = each_span[2]
                    cur_span_end[self._ds_span_text2id[cur_ds]] = each_span[3]
                spans_start.append(
                    ListField([LabelField(l, skip_indexing=True) for l in cur_span_start]))
                spans_end.append(
                    ListField([LabelField(l, skip_indexing=True) for l in cur_span_end]))
            fields["spans_start"] = ListField(spans_start)
            fields["spans_end"] = ListField(spans_end)

            expanded_span_labels = []
            for turn_span_label in span_labels:
                cur_span_label = [0 for _ in self._ds_span_list]
                for each_span_label in turn_span_label:
                    cur_ds = each_span_label[0] + " " + each_span_label[1]
                    cur_span_label[self._ds_span_text2id[cur_ds]] = each_span_label[2]
                expanded_span_labels.append(
                    ListField([LabelField(l, skip_indexing=True) for l in cur_span_label]))
            fields["span_labels"] = ListField(expanded_span_labels)

    return Instance(fields)
def main(args):
    """
    CoronAI: Sequence2Vector
    ========================
    In many comment analysis and note analysis applications it is important to be able to
    represent portions of the string as numeric vectors. This application in the CoronAI project
    assists us with that. Using this application, one can read the clean CSV file that includes
    the text segments, and represent each and every one of them using a variant of BERT.

    Sample command to run:

    ```
    coronai_segment2vector --gpu=2 --input_csv=.../text_segment_dataset.csv --output_pkl=.../output.pkl --path_to_bert_weights=.../NCBI_BERT_pubmed_mimic_uncased_L-24_H-1024_A-16 --batch_size=400
    ```

    Parameters
    ----------
    args: `args.Namespace`, required
        The arguments needed for it.
    """
    # Reading the input files.
    input_dataframe = pandas.read_csv(os.path.abspath(args.input_csv))

    # todo: currently the system is for bert only. we will add arguments to control this and
    # make it modular for other embeddings such as elmo, etc.
    # todo: the parenthesis of wordpiece_tokenizer covering was an issue found in other places
    # as well, make sure to look for them and resolve the issue.
    source_token_indexers = get_bert_token_indexers(
        path_to_bert_weights=args.path_to_bert_weights,
        maximum_number_of_tokens=512,
        is_model_lowercase=True)
    source_tokenizer = lambda x: source_token_indexers.wordpiece_tokenizer(preprocess_text(x))[:510]

    source_token_embeddings = get_bert_token_embeddings(
        path_to_bert_weights=args.path_to_bert_weights,
        top_layer_only=True,
        indexer_id='source_tokens',
        token_to_embed='word',
    )
    # Finding the output dimension for the token embeddings, to be used later.
    source_embeddings_dimension = source_token_embeddings.get_output_dim()
    sequence_encoder = BertSequencePooler()

    if args.gpu > -1:
        sequence_encoder = sequence_encoder.to(torch.device('cuda:{}'.format(args.gpu)))
        try:
            source_token_embeddings = source_token_embeddings.cuda(
                torch.device('cuda:{}'.format(args.gpu)))
        except Exception as e:
            print(e)
            import pdb
            pdb.set_trace()

    input_text_sequence_instances = []
    print(">> (status): preparing data...\n\n")

    if os.path.isfile(os.path.join(os.path.dirname(args.output_pkl),
                                   'input_text_sequence_instances.pkl')):
        with open(os.path.join(os.path.dirname(args.output_pkl),
                               'input_text_sequence_instances.pkl'), 'rb') as handle:
            input_text_sequence_instances = pickle.load(handle)
    else:
        for i in tqdm(range(input_dataframe.shape[0])):
            fields = dict()
            row = input_dataframe.iloc[i, :]
            tokens = [Token(x) for x in source_tokenizer(row['text_segment'])]
            sequence_field = TextField(tokens, {'source_tokens': source_token_indexers})
            fields['dataset_index'] = MetadataField(i)
            fields['source_tokens'] = sequence_field
            fields['paper_id'] = MetadataField(row['paper_id'])
            fields['text_segment'] = MetadataField(row['text_segment'])
            fields['corresponding_section'] = MetadataField(row['corresponding_section'])
            input_text_sequence_instances.append(Instance(fields))

        with open(os.path.join(os.path.dirname(args.output_pkl),
                               'input_text_sequence_instances.pkl'), 'wb') as handle:
            pickle.dump(input_text_sequence_instances, handle)
        print(">> (info): the input_text_sequence_instances file is successfully saved in the storage.\n")

    # Now it's time to encode the tokenized instances.
    batch_size = args.batch_size
    # iterator = PassThroughIterator()
    iterator = BucketIterator(batch_size=batch_size,
                              sorting_keys=[('source_tokens', 'num_tokens')])
    vocabulary = Vocabulary()
    iterator.index_with(vocabulary)

    number_of_instances = len(input_text_sequence_instances)
    data_stream = iterator(iter(input_text_sequence_instances))

    output_corpora = dict()
    output_corpora['dataset_index'] = list()
    output_corpora['text_segment'] = list()
    output_corpora['vector_representation'] = list()
    output_corpora['paper_id'] = list()
    output_corpora['corresponding_section'] = list()

    for batches_processed_sofar in tqdm(range(0, number_of_instances // batch_size + 1)):
        sample = next(data_stream)
        if args.gpu > -1:
            sample = nn_util.move_to_device(sample, args.gpu)
        vector_representations = sequence_encoder(
            source_token_embeddings(sample['source_tokens'])).data.cpu().numpy()
        output_corpora['text_segment'] += sample['text_segment']
        output_corpora['paper_id'] += sample['paper_id']
        output_corpora['corresponding_section'] += sample['corresponding_section']
        output_corpora['dataset_index'] += sample['dataset_index']
        output_corpora['vector_representation'] += [
            numpy.array(e) for e in vector_representations.tolist()
        ]

        if batches_processed_sofar > 0 and batches_processed_sofar % 100 == 0:
            if not os.path.isdir(os.path.join(os.path.dirname(args.output_pkl), 'batches')):
                os.makedirs(os.path.join(os.path.dirname(args.output_pkl), 'batches'))
            with open(os.path.join(os.path.dirname(args.output_pkl),
                                   'batches/batch_{}.pkl'.format(batches_processed_sofar)),
                      'wb') as handle:
                pickle.dump(output_corpora, handle)
            for key in output_corpora.keys():
                output_corpora[key] = list()

    print('>> (info): all done.\n')
def text_to_instance(self,  # type: ignore
                     tokens: List[str],
                     pos_tags: List[str] = None,
                     gold_tree: Tree = None) -> Instance:
    """
    We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

    Parameters
    ----------
    tokens : ``List[str]``, required.
        The tokens in a given sentence.
    pos_tags : ``List[str]``, optional, (default = None).
        The POS tags for the words in the sentence.
    gold_tree : ``Tree``, optional (default = None).
        The gold parse tree to create span labels from.

    Returns
    -------
    An ``Instance`` containing the following fields:
        tokens : ``TextField``
            The tokens in the sentence.
        pos_tags : ``SequenceLabelField``
            The POS tags of the words in the sentence.
            Only returned if ``use_pos_tags`` is ``True``.
        spans : ``ListField[SpanField]``
            A ListField containing all possible subspans of the sentence.
        span_labels : ``SequenceLabelField``, optional.
            The constituency tags for each of the possible spans, with respect to a gold
            parse tree. If a span is not contained within the tree, a span will have a
            ``NO-LABEL`` label.
        gold_tree : ``MetadataField(Tree)``
            The gold NLTK parse tree for use in evaluation.
    """
    # pylint: disable=arguments-differ
    text_field = TextField([Token(x) for x in tokens], token_indexers=self._token_indexers)
    fields: Dict[str, Field] = {"tokens": text_field}

    if self._use_pos_tags and pos_tags is not None:
        pos_tag_field = SequenceLabelField(pos_tags, text_field, "pos_tags")
        fields["pos_tags"] = pos_tag_field
    elif self._use_pos_tags:
        raise ConfigurationError("use_pos_tags was set to True but no gold pos"
                                 " tags were passed to the dataset reader.")

    spans: List[Field] = []
    gold_labels = []

    if gold_tree is not None:
        gold_spans_with_pos_tags: Dict[Tuple[int, int], str] = {}
        self._get_gold_spans(gold_tree, 0, gold_spans_with_pos_tags)
        gold_spans = {span: label for (span, label) in gold_spans_with_pos_tags.items()
                      if "-POS" not in label}
    else:
        gold_spans = None

    for start, end in enumerate_spans(tokens):
        spans.append(SpanField(start, end, text_field))
        if gold_spans is not None:
            if (start, end) in gold_spans.keys():
                gold_labels.append(gold_spans[(start, end)])
            else:
                gold_labels.append("NO-LABEL")

    if gold_tree:
        fields["gold_tree"] = MetadataField(gold_tree)
    span_list_field: ListField = ListField(spans)
    fields["spans"] = span_list_field
    if gold_tree is not None:
        fields["span_labels"] = SequenceLabelField(gold_labels, span_list_field)
    return Instance(fields)
def text_to_instance(  # type: ignore
    self,
    tokens: List[Token],
    verb_label: List[int],
    frames: List[str] = None,
    lemmas: List[str] = None,
    tags: List[str] = None,
    sentence_id=None,
) -> Instance:
    """
    We take `pre-tokenized` input here, along with a verb label. The verb label should be a
    one-hot binary vector, the same length as the tokens, indicating the position of the verb
    to find arguments for.
    """
    metadata_dict: Dict[str, Any] = {}

    wordpieces, offsets, start_offsets = self._wordpiece_tokenize_input(
        [t.text for t in tokens])
    new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
    frame_indicator = _convert_frames_indices_to_wordpiece_indices(verb_label, offsets, True)

    # Add the verb as information to the model.
    # verb_tokens = [token for token, v in zip(wordpieces, new_verbs) if v == 1]
    # verb_tokens = verb_tokens + [self.tokenizer.sep_token]
    # if isinstance(self.tokenizer, XLMRobertaTokenizer):
    #     verb_tokens = [self.tokenizer.sep_token] + verb_tokens
    # wordpieces += verb_tokens
    # new_verbs += [0 for _ in range(len(verb_tokens))]
    # frame_indicator += [0 for _ in range(len(verb_tokens))]

    # In order to override the indexing mechanism, we need to set the `text_id`
    # attribute directly. This causes the indexing to use this id.
    text_field = TextField(
        [Token(t, text_id=self.tokenizer.convert_tokens_to_ids(t)) for t in wordpieces],
        token_indexers=self._token_indexers,
    )
    verb_indicator = SequenceLabelField(new_verbs, text_field)
    frame_indicator = SequenceLabelField(frame_indicator, text_field)

    metadata_dict["offsets"] = start_offsets

    fields: Dict[str, Field] = {
        "tokens": text_field,
        "verb_indicator": verb_indicator,
        "frame_indicator": frame_indicator,
    }

    if all(x == 0 for x in verb_label):
        verb = None
        verb_index = None
    else:
        verb_index = verb_label.index(1)
        verb = tokens[verb_index].text

    metadata_dict["words"] = [x.text for x in tokens]
    metadata_dict["lemmas"] = lemmas
    metadata_dict["verb"] = verb
    metadata_dict["verb_index"] = verb_index
    metadata_dict["sentence_id"] = sentence_id

    if tags:
        # roles
        new_tags = self._convert_tags_to_wordpiece_tags(tags, offsets)
        new_tags += ["O" for _ in range(len(wordpieces) - len(new_tags))]
        # frames
        new_frames = _convert_frames_indices_to_wordpiece_indices(frames, offsets)
        new_frames += ["O" for _ in range(len(wordpieces) - len(new_frames))]
        # for the model
        fields["tags"] = SequenceLabelField(new_tags, text_field)
        fields["frame_tags"] = SequenceLabelField(new_frames, text_field,
                                                  label_namespace="frames_labels")
        metadata_dict["gold_tags"] = tags
        metadata_dict["gold_frame_tags"] = frames

    fields["metadata"] = MetadataField(metadata_dict)
    return Instance(fields)
def text_to_instance(self, source_string: str,
                     target_string: str = None) -> Instance:  # type: ignore
    """
    Turn raw source string and target string into an ``Instance``.

    # Parameters

    source_string : ``str``, required
    target_string : ``str``, optional (default = None)

    # Returns

    Instance
        See the above for a description of the fields that the instance will contain.
    """
    tokenized_source = self._source_tokenizer.tokenize(source_string)
    tokenized_source.insert(0, Token(START_SYMBOL))
    tokenized_source.append(Token(END_SYMBOL))
    source_field = TextField(tokenized_source, self._source_token_indexers)

    # For each token in the source sentence, we keep track of the matching token
    # in the target sentence (which will be the OOV symbol if there is no match).
    source_to_target_field = NamespaceSwappingField(tokenized_source[1:-1],
                                                    self._target_namespace)

    meta_fields = {"source_tokens": [x.text for x in tokenized_source[1:-1]]}
    fields_dict = {"source_tokens": source_field,
                   "source_to_target": source_to_target_field}

    if target_string is not None:
        tokenized_target = self._target_tokenizer.tokenize(target_string)
        tokenized_target.insert(0, Token(START_SYMBOL))
        tokenized_target.append(Token(END_SYMBOL))
        target_field = TextField(tokenized_target, self._target_token_indexers)

        fields_dict["target_tokens"] = target_field
        meta_fields["target_tokens"] = [y.text for y in tokenized_target[1:-1]]
        source_and_target_token_ids = self._tokens_to_ids(tokenized_source[1:-1] +
                                                          tokenized_target)
        source_token_ids = source_and_target_token_ids[:len(tokenized_source) - 2]
        fields_dict["source_token_ids"] = ArrayField(np.array(source_token_ids))
        target_token_ids = source_and_target_token_ids[len(tokenized_source) - 2:]
        fields_dict["target_token_ids"] = ArrayField(np.array(target_token_ids))
    else:
        source_token_ids = self._tokens_to_ids(tokenized_source[1:-1])
        fields_dict["source_token_ids"] = ArrayField(np.array(source_token_ids))

    fields_dict["metadata"] = MetadataField(meta_fields)

    return Instance(fields_dict)
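# The `_tokens_to_ids` helper used above is not shown in this snippet. Below is a
# hedged sketch of one plausible implementation consistent with how it is used here
# (not necessarily the exact method of the reader being quoted): it assigns each
# distinct token text an id from a single shared space, so matching source and
# target tokens receive the same id, which is what copy-mechanism supervision needs.
from typing import Dict, List
from allennlp.data.tokenizers import Token

def tokens_to_shared_ids(tokens: List[Token]) -> List[int]:
    """Map tokens to ids in one shared space; identical texts share an id."""
    ids: Dict[str, int] = {}
    out: List[int] = []
    for token in tokens:
        out.append(ids.setdefault(token.text, len(ids)))
    return out

# Example: ["the", "cat", "the", "dog"] -> [0, 1, 0, 2]; "the" shares id 0
# between the source half and the target half of the concatenated list.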
def text_to_instance(
    self,  # type: ignore
    sentences: List[List[str]],
    gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
) -> Instance:
    """
    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenised words and sentences in the document.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
        A list of all clusters in the document, represented as word spans. Each cluster
        contains some number of spans, which can be nested and overlap, but will never
        exactly match between clusters.

    # Returns

    An `Instance` containing the following `Fields`:

        text : `TextField`
            The text of the full document.
        spans : `ListField[SpanField]`
            A ListField containing the spans represented as `SpanFields`
            with respect to the document text.
        span_labels : `SequenceLabelField`, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster. As these labels have variable length (it depends on
            how many spans we are considering), we represent this as a `SequenceLabelField`
            with respect to the spans `ListField`.
    """
    flattened_sentences = [
        self._normalize_word(word) for sentence in sentences for word in sentence
    ]

    if self._wordpiece_modeling_tokenizer is not None:
        flat_sentences_tokens, offsets = self._wordpiece_modeling_tokenizer.intra_word_tokenize(
            flattened_sentences)
        flattened_sentences = [t.text for t in flat_sentences_tokens]
    else:
        flat_sentences_tokens = [Token(word) for word in flattened_sentences]

    text_field = TextField(flat_sentences_tokens, self._token_indexers)

    cluster_dict = {}
    if gold_clusters is not None:
        if self._wordpiece_modeling_tokenizer is not None:
            for cluster in gold_clusters:
                for mention_id, mention in enumerate(cluster):
                    start = offsets[mention[0]][0]
                    end = offsets[mention[1]][1]
                    cluster[mention_id] = (start, end)

        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    span_labels: Optional[List[int]] = [] if gold_clusters is not None else None

    sentence_offset = 0
    for sentence in sentences:
        for start, end in enumerate_spans(
            sentence, offset=sentence_offset, max_span_width=self._max_span_width
        ):
            if self._wordpiece_modeling_tokenizer is not None:
                start = offsets[start][0]
                end = offsets[end][1]

                # `enumerate_spans` uses a word-level width limit; here we apply it to
                # wordpieces. We have to do this check here because we use a span width
                # embedding that has only `self._max_span_width` entries, and since we are
                # doing wordpiece modeling, the span width embedding operates on wordpiece
                # lengths. So a check here is necessary or else we wouldn't know how many
                # entries there would be.
                if end - start + 1 > self._max_span_width:
                    continue
                # We also don't generate spans that contain special tokens.
                if start < self._wordpiece_modeling_tokenizer.num_added_start_tokens:
                    continue
                if (end >= len(flat_sentences_tokens)
                        - self._wordpiece_modeling_tokenizer.num_added_end_tokens):
                    continue

            if span_labels is not None:
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)

            spans.append(SpanField(start, end, text_field))
        sentence_offset += len(sentence)

    span_field = ListField(spans)

    metadata: Dict[str, Any] = {"original_text": flattened_sentences}
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": span_field,
        "metadata": metadata_field,
    }
    if span_labels is not None:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field)

    return Instance(fields)
def build_instance(self,  # type: ignore
                   doc: List[List[str]],
                   clusters: List[List[Tuple[int, int]]] = None,
                   doc_relations: List[Dict[Tuple[Tuple[int, int], Tuple[int, int]], str]] = None,
                   doc_ner_labels: List[Dict[Tuple[int, int], str]] = None,
                   **kwargs) -> Instance:
    """
    Parameters
    ----------
    doc : ``List[List[str]]``, required.
        A list of lists representing the tokenized words and sentences in the document.
    clusters : ``Optional[List[List[Tuple[int, int]]]]``, optional (default = None)
        A list of all clusters in the document, represented as word spans. Each cluster
        contains some number of spans, which can be nested and overlap, but will never
        exactly match between clusters.
    doc_relations : TODO

    Returns
    -------
    An ``Instance`` containing the following ``Fields``:
        text : ``TextField``
            The text of the full document.
        spans : ``ListField[SpanField]``
            A ListField containing the spans represented as ``SpanFields``
            with respect to the document text.
        span_labels : ``SequenceLabelField``, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster. As these labels have variable length (it depends on
            how many spans we are considering), we represent this as a ``SequenceLabelField``
            with respect to the ``spans`` ``ListField``.

    Extra fields:

        spans : see docstring
            Shape: (num_spans)
            0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15

        sentences_span_indices : lists the spans (absolute indices) for every sentence; can be
            used in order to isolate spans between different sentences. By design, the RelEx
            part considers intra-sentence truth_relations and is able to extract inter-sentence
            truth_relations with the help of already predicted sets of coreferences.
            Shape: (sentences_padded, spans_in_sentence_padded)
                0  1  2  3  4  5  6  #
                7  8  9  #
                10 11 12 13 14 15  #  #
            Range: [0, ..., num_spans-1], # is padding

        TODO
        sentences_truth_spans : relative indices in sentence_spans that correspond to at least
            one relation from the truth. Intended to be used for effective packing and padding
            of the sparse matrix. PyTorch lacks (at least stable) support for sparse tensors,
            and we aim to implement it ourselves. The matrix is not going to be encoded using
            COO because the sparsity of the matrix is just an effect of the sparsity of the
            truth spans. This matrix is simply a compressed matrix w.r.t. COO-encoded spans
            that are going to be used for encoding the relation matrix.
            Shape: (sentences_padded, gold_spans_in_sentence_padded)
                1 3
                0 2
                0 1 2 #
                1 #
            Range: [0, ..., spans_in_sentence - 1], # is padding

        TODO
        sentences_spans_in_truth : simply the inverse of `sentences_truth_spans`.
            This matrix can also be interpreted as a boolean matrix of whether the span occurs
            as a truth span: if the element is not padded, it does, and the element points out
            where it occurs in the compressed matrix.
            Shape: (sentences_padded, spans_in_sentence_padded)
                # 0 # 1 0 # 1 #
                0 1 # # # # 0 #
                # 0 # #
            Range: [0, ..., gold_spans_in_sentence_padded - 1], # is padding

        TODO
        sentences_relations : TODO
            Shape: (sentences_padded, gold_spans_in_sentence_padded, gold_spans_in_sentence_padded)
            Range: [0, ..., num_classes - 1], # is padding

        sentences_ner_labels : TODO
            Shape: TODO
            Range: TODO
    """
    metadatas: Dict[str, Any] = {}

    flattened_doc = [self._normalize_word(word)
                     for sentence in doc
                     for word in sentence]
    metadatas["doc_tokens"] = doc
    metadatas["original_text"] = flattened_doc
    metadatas.update(kwargs)

    text_field = TextField([Token(word) for word in flattened_doc], self._token_indexers)

    spans: List[SpanField] = []
    doc_span_offsets: List[List[int]] = []

    # Construct spans and mappings.
    sentence_offset = 0
    for sentence in doc:
        sentence_spans: List[int] = []
        for start, end in enumerate_spans(sentence,
                                          offset=sentence_offset,
                                          max_span_width=self._max_span_width):
            absolute_index = len(spans)
            spans.append(SpanField(start, end, text_field))
            sentence_spans.append(absolute_index)
        sentence_offset += len(sentence)
        doc_span_offsets.append(sentence_spans)

    # Just making fields out of the lists.
    spans_field = OptionalListField(spans,
                                    empty_field=SpanField(-1, -1, text_field).empty_field())
    doc_span_offsets_field = ListField([
        OptionalListField([IndexField(span_offset, spans_field)
                           for span_offset in sentence_span_offsets],
                          empty_field=IndexField(-1, spans_field).empty_field())
        for sentence_span_offsets in doc_span_offsets
    ])

    # num_sentences = len(sentences)
    # num_spans = len(spans)
    # inverse_mapping = -np.ones(shape=(num_sentences, num_spans), dtype=int)
    # for sentence_id, indices in enumerate(sentences_span_indices):
    #     for gold_index, real_index in enumerate(indices.array):
    #         inverse_mapping[sentence_id, real_index] = gold_index
    # sentences_spans_field = ListField([
    #     ListField(spans) for spans in sentences_span_indices
    # ])
    # sentences_span_inverse_mapping_field = ArrayField(inverse_mapping, padding_value=-1)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": spans_field,
        "doc_span_offsets": doc_span_offsets_field,
    }

    # TODO TODO TODO rename sentences to doc, sentence to snt
    for key, value in metadatas.items():
        fields[key] = MetadataField(value)

    if clusters is None or doc_relations is None:
        return Instance(fields)

    # Here we can be sure both `clusters` and `doc_relations` are given.
    # However, we cannot be sure yet whether `doc_ner_labels` is given or not.

    #
    #   TRUTH AFTER THIS ONLY
    #

    fields["clusters"] = MetadataField(clusters)
    cluster_dict = {(start, end): cluster_id
                    for cluster_id, cluster in enumerate(clusters)
                    for start, end in cluster}

    truth_spans = {span
                   for sentence in doc_relations
                   for spans, label in sentence.items()
                   for span in spans}
    fields["truth_spans"] = MetadataField(truth_spans)

    span_labels: Optional[List[int]] = []
    doc_truth_spans: List[List[int]] = []
    doc_spans_in_truth: List[List[int]] = []

    for sentence, sentence_spans_field in zip(doc, doc_span_offsets_field):
        sentence_truth_spans: List[IndexField] = []
        sentence_spans_in_truth: List[int] = []

        for relative_index, span in enumerate(sentence_spans_field):
            absolute_index = cast(IndexField, span).sequence_index
            span_field: SpanField = cast(SpanField, spans_field[absolute_index])
            start = span_field.span_start
            end = span_field.span_end

            if (start, end) in cluster_dict:
                span_labels.append(cluster_dict[(start, end)])
            else:
                span_labels.append(-1)

            compressed_index = -1
            if (start, end) in truth_spans:
                compressed_index = len(sentence_truth_spans)
                sentence_truth_spans.append(IndexField(relative_index, sentence_spans_field))
            sentence_spans_in_truth.append(compressed_index)

        sentence_truth_spans_field = OptionalListField(
            sentence_truth_spans,
            empty_field=IndexField(-1, sentence_spans_field).empty_field())
        doc_truth_spans.append(sentence_truth_spans_field)

        sentence_spans_in_truth_field = OptionalListField(
            [IndexField(compressed_index, sentence_truth_spans_field)
             for compressed_index in sentence_spans_in_truth],
            empty_field=IndexField(-1, sentence_truth_spans_field).empty_field())
        doc_spans_in_truth.append(sentence_spans_in_truth_field)

    span_labels_field = SequenceLabelField(span_labels, spans_field)
    doc_truth_spans_field = ListField(doc_truth_spans)
    doc_spans_in_truth_field = ListField(doc_spans_in_truth)

    fields["span_labels"] = span_labels_field
    fields["doc_truth_spans"] = doc_truth_spans_field
    fields["doc_spans_in_truth"] = doc_spans_in_truth_field

    # "sentences_span_inverse_mapping": sentences_span_inverse_mapping_field,
    # "truth_relations": MetadataField(truth_relations)

    # our code

    # test code
    # sample_label = LabelField('foo')
    # sample_list = ListField([sample_label, sample_label])
    # sample_seq_labels = SequenceLabelField(labels=['bar', 'baz'],
    #                                        sequence_field=sample_list)
    # empty_seq_labels = sample_seq_labels.empty_field()

    # TODO reverse matrix generation tactic
    # TODO Add dummy

    doc_relex_matrices: List[AdjacencyField] = []
    for (sentence,
         truth_relations,
         sentence_spans,
         truth_spans_field,
         spans_in_truth) in zip(doc,
                                doc_relations,
                                doc_span_offsets,
                                doc_truth_spans_field,
                                doc_spans_in_truth_field):

        relations = collections.defaultdict(str)
        for (span_a, span_b), label in truth_relations.items():
            # Span absolute indices (document-wide indexing).
            try:
                a_absolute_index = spans.index(span_a)
                b_absolute_index = spans.index(span_b)
                # Fill the dict as a sparse matrix, padded with zeros.
                relations[a_absolute_index, b_absolute_index] = label
            except ValueError:
                logger.warning('Span not found')

        indices: List[Tuple[int, int]] = []
        labels: List[str] = []

        for span_a, span_b in itertools.product(enumerate(truth_spans_field), repeat=2):
            a_compressed_index, a_relative = cast(Tuple[int, IndexField], span_a)
            b_compressed_index, b_relative = cast(Tuple[int, IndexField], span_b)
            a_absolute = sentence_spans[a_relative.sequence_index]
            b_absolute = sentence_spans[b_relative.sequence_index]
            label = relations[a_absolute, b_absolute]
            indices.append((a_compressed_index, b_compressed_index))
labels.append(label) doc_relex_matrices.append( AdjacencyField(indices=indices, labels=labels, sequence_field=truth_spans_field, label_namespace="relation_labels") ) # TODO pad with zeros maybe? # fields["doc_relations"] = MetadataField(doc_relations) fields["doc_relation_labels"] = ListField(doc_relex_matrices) # gold_candidates = [] # gold_candidate_labels = [] # # for sentence in sentences_relations: # # candidates: List[ListField[SpanField]] = [] # candidate_labels: List[LabelField] = [] # # for label, (a_start, a_end), (b_start, b_end) in sentence: # a_span = SpanField(a_start, a_end, text_field) # b_span = SpanField(b_start, b_end, text_field) # candidate_field = ListField([a_span, b_span]) # label_field = OptionalLabelField(label, 'relation_labels') # # candidates.append(candidate_field) # candidate_labels.append(label_field) # # # if not candidates: # # continue # # # TODO very very tmp # # empty_text = text_field.empty_field() # empty_span = SpanField(-1, -1, empty_text).empty_field() # empty_candidate = ListField([empty_span, empty_span]).empty_field() # empty_candidates = ListField([empty_candidate]).empty_field() # empty_label = OptionalLabelField('', 'relation_labels') # .empty_field()? # empty_candidate_labels = ListField([empty_label]) # ? .empty_field() ? # # if candidates: # candidates_field = ListField(candidates) # candidate_labels_field = ListField(candidate_labels) # else: # candidates_field = empty_candidates # candidate_labels_field = empty_candidate_labels # # gold_candidates.append(candidates_field) # gold_candidate_labels.append(candidate_labels_field) # # fields["gold_candidates"] = ListField(gold_candidates) # fields["gold_candidate_labels"] = ListField(gold_candidate_labels) # # fields["sentences_relations"] = MetadataField(sentences_relations) if doc_ner_labels is None: return Instance(fields) # NER doc_ner: List[OptionalListField[LabelField]] = [] sentence_offset = 0 for sentence, sentence_ner_dict in zip(doc, doc_ner_labels): sentence_ner_labels: List[LabelField] = [] for start, end in enumerate_spans( sentence, offset=sentence_offset, max_span_width=self._max_span_width): if (start, end) in sentence_ner_dict: label = sentence_ner_dict[(start, end)] sentence_ner_labels.append(LabelField(label, 'ner_labels')) else: sentence_ner_labels.append(LabelField('O', 'ner_labels')) sentence_offset += len(sentence) sentence_ner_labels_field = OptionalListField( sentence_ner_labels, empty_field=LabelField('*', 'ner_tags').empty_field()) doc_ner.append(sentence_ner_labels_field) doc_ner_field = ListField(doc_ner) fields["doc_ner_labels"] = doc_ner_field return Instance(fields)
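# Hedged sketch of the dense/compressed index bookkeeping described in the
# docstring above (function and variable names here are hypothetical, not taken
# from the reader).  For one sentence we record, for every enumerated span, its
# position among the gold ("truth") spans, or -1 if it is not a gold span, plus
# the forward list of which relative span indices are gold.
def compress_truth_spans(sentence_spans, truth_spans):
    truth_indices = []    # like `sentences_truth_spans`: relative indices of gold spans
    spans_in_truth = []   # like `sentences_spans_in_truth`: inverse mapping, -1 = not gold
    for relative_index, span in enumerate(sentence_spans):
        if span in truth_spans:
            spans_in_truth.append(len(truth_indices))
            truth_indices.append(relative_index)
        else:
            spans_in_truth.append(-1)
    return truth_indices, spans_in_truth

spans = [(0, 0), (0, 1), (1, 1), (1, 2)]
truth = {(0, 1), (1, 2)}
assert compress_truth_spans(spans, truth) == ([1, 3], [-1, 0, -1, 1])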
def text_to_instance( self, # type: ignore question: str, table_info: Union[str, JsonDict], example_lisp_string: str = None, dpd_output: List[str] = None, tokenized_question: List[Token] = None) -> Instance: """ Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV files, which we use for training. For running a demo, we may want to provide tables in a JSON format. To make this method compatible with both, we take ``table_info``, which can either be a filename, or a dict. We check the argument's type and call the appropriate method in ``TableQuestionKnowledgeGraph``. Parameters ---------- question : ``str`` Input question table_info : ``str`` or ``JsonDict`` Table filename or the table content itself, as a dict. See ``TableQuestionKnowledgeGraph.read_from_json`` for the expected format. example_lisp_string : ``str``, optional The original (lisp-formatted) example string in the WikiTableQuestions dataset. This comes directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE for evaluating logical forms during training. It isn't otherwise used for anything. dpd_output : List[str], optional List of logical forms, produced by dynamic programming on denotations. Not required during test. tokenized_question : ``List[Token]``, optional If you have already tokenized the question, you can pass that in here, so we don't duplicate that work. You might, for example, do batch processing on the questions in the whole dataset, then pass the result in here. """ # pylint: disable=arguments-differ tokenized_question = tokenized_question or self._tokenizer.tokenize( question.lower()) question_field = TextField(tokenized_question, self._question_token_indexers) if isinstance(table_info, str): table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_file( table_info, tokenized_question) table_metadata = MetadataField(open(table_info).readlines()) else: table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_json( table_info) table_metadata = MetadataField(table_info) table_field = KnowledgeGraphField( table_knowledge_graph, tokenized_question, self._table_token_indexers, tokenizer=self._tokenizer, feature_extractors=self._linking_feature_extractors, include_in_vocab=self._use_table_for_vocab, max_table_tokens=self._max_table_tokens) world = WikiTablesWorld(table_knowledge_graph) world_field = MetadataField(world) production_rule_fields: List[Field] = [] for production_rule in world.all_possible_actions(): _, rule_right_side = production_rule.split(' -> ') is_global_rule = not world.is_table_entity(rule_right_side) field = ProductionRuleField(production_rule, is_global_rule) production_rule_fields.append(field) action_field = ListField(production_rule_fields) fields = { 'question': question_field, 'table': table_field, 'world': world_field, 'actions': action_field } if self._include_table_metadata: fields['table_metadata'] = table_metadata if example_lisp_string: fields['example_lisp_string'] = MetadataField(example_lisp_string) # We'll make each target action sequence a List[IndexField], where the index is into # the action list we made above. We need to ignore the type here because mypy doesn't # like `action.rule` - it's hard to tell mypy that the ListField is made up of # ProductionRuleFields. 
action_map = { action.rule: i for i, action in enumerate(action_field.field_list) } # type: ignore if dpd_output: action_sequence_fields: List[Field] = [] for logical_form in dpd_output: if not self._should_keep_logical_form(logical_form): logger.debug(f'Question was: {question}') logger.debug(f'Table info was: {table_info}') continue try: expression = world.parse_logical_form(logical_form) except ParsingError as error: logger.debug( f'Parsing error: {error.message}, skipping logical form' ) logger.debug(f'Question was: {question}') logger.debug(f'Logical form was: {logical_form}') logger.debug(f'Table info was: {table_info}') continue except: logger.error(logical_form) raise action_sequence = world.get_action_sequence(expression) try: index_fields: List[Field] = [] for production_rule in action_sequence: index_fields.append( IndexField(action_map[production_rule], action_field)) action_sequence_fields.append(ListField(index_fields)) except KeyError as error: logger.debug( f'Missing production rule: {error.args}, skipping logical form' ) logger.debug(f'Question was: {question}') logger.debug(f'Table info was: {table_info}') logger.debug(f'Logical form was: {logical_form}') continue if len(action_sequence_fields) >= self._max_dpd_logical_forms: break if not action_sequence_fields: # This is not great, but we're only doing it when we're passed logical form # supervision, so we're expecting labeled logical forms, but we can't actually # produce the logical forms. We should skip this instance. Note that this affects # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the # full test data. return None fields['target_action_sequences'] = ListField( action_sequence_fields) if self._output_agendas: agenda_index_fields: List[Field] = [] for agenda_string in world.get_agenda(): agenda_index_fields.append( IndexField(action_map[agenda_string], action_field)) if not agenda_index_fields: agenda_index_fields = [IndexField(-1, action_field)] fields['agenda'] = ListField(agenda_index_fields) return Instance(fields)
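# Minimal sketch of the supervision encoding used above, with toy production
# rules (not real WikiTableQuestions grammar rules): every rule gets an index
# into the global action list, and each supervised logical form becomes a list
# of those indices; sequences that use an unknown rule are skipped, mirroring
# the KeyError handling in the reader.
action_list = ['@start@ -> r', 'r -> [<c,r>, c]', 'c -> fb:cell.2001']
action_map = {rule: i for i, rule in enumerate(action_list)}

def encode_action_sequence(action_sequence, action_map):
    try:
        return [action_map[rule] for rule in action_sequence]
    except KeyError:
        # Mirrors the reader above: logical forms with missing rules are dropped.
        return None

assert encode_action_sequence(['@start@ -> r', 'r -> [<c,r>, c]'], action_map) == [0, 1]
assert encode_action_sequence(['r -> unknown_rule'], action_map) is None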
def text_to_instance(self, sample: list) -> Instance: fields = {} text: str = sample['text'].strip() labels = sample['aspect_terms'] pieces = [] piece_labels = [] last_label_end_index = 0 for i, label in enumerate(labels): if label.from_index != last_label_end_index: pieces.append(text[last_label_end_index:label.from_index]) piece_labels.append(0) if self.configuration[ 'model_name'] == 'aspect-term-aware-bert-syntax': pieces.append('#') piece_labels.append(2) pieces.append(text[label.from_index:label.to_index]) piece_labels.append(1) if self.configuration[ 'model_name'] == 'aspect-term-aware-bert-syntax': pieces.append('#') piece_labels.append(2) last_label_end_index = label.to_index if i == len(labels) - 1 and label.to_index != len(text): pieces.append(text[label.to_index:]) piece_labels.append(0) # words_of_pieces = [self.tokenizer(piece.lower().strip()) for piece in pieces] words_of_pieces = [] word_indices_of_aspect_terms = [] start_index = 1 # ,[word,word piecesindex,word piecesindex] word_and_word_pieces = [] for i in range(len(pieces)): piece = pieces[i] doc = self.spacy_nlp(piece.lower().strip()) words_of_piece = [word.text for word in doc] word_pieces_of_words = [] word_pieces_start_index = start_index for word in words_of_piece: word_pieces_of_word = self.tokenizer(word.lower().strip()) word_pieces_of_words.extend(word_pieces_of_word) word_pieces_end_index = word_pieces_start_index + len( word_pieces_of_word) if piece_labels[i] != 2: word_and_word_pieces.append([ word, word_pieces_start_index, word_pieces_end_index, word_pieces_of_word ]) word_pieces_start_index = word_pieces_end_index words_of_pieces.append(word_pieces_of_words) end_index = start_index + len(word_pieces_of_words) if piece_labels[i] == 1: word_indices_of_aspect_terms.append([start_index, end_index]) start_index = end_index sample['word_indices_of_aspect_terms'] = word_indices_of_aspect_terms words = [] for words_of_piece in words_of_pieces: words.extend(words_of_piece) words = ['[CLS]'] + words + ['[SEP]'] sample['words'] = words graph = self._build_graph(word_and_word_pieces, len(words)) sample['graph'] = graph tokens = [Token(word) for word in words] sentence_field = TextField(tokens, self.token_indexers) fields['tokens'] = sentence_field position = [Token(str(i)) for i in range(len(tokens))] position_field = TextField(position, self.position_indexers) fields['position'] = position_field if self.configuration['sample_mode'] == 'single': max_aspect_term_num = 1 else: max_aspect_term_num = self.configuration['max_aspect_term_num'] polarity_labels = [-100] * max_aspect_term_num for i, aspect_term in enumerate(sample['aspect_terms']): polarity_labels[i] = self.polarities.index(aspect_term.polarity) label_field = ArrayField(np.array(polarity_labels)) fields["label"] = label_field polarity_mask = [ 1 if polarity_labels[i] != -100 else 0 for i in range(max_aspect_term_num) ] polarity_mask_field = ArrayField(np.array(polarity_mask)) fields['polarity_mask'] = polarity_mask_field # stop_word_labels = [1 if word in english_stop_words else 0 for word in words] # stop_word_num = sum(stop_word_labels) # stop_word_labels = [label / stop_word_num for label in stop_word_labels] # sample.append(stop_word_labels) sample_field = MetadataField(sample) fields["sample"] = sample_field return Instance(fields)
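# Rough, simplified sketch of the first loop above (the optional '#' marker
# insertion for the syntax-aware model is omitted, and AspectTerm is a
# hypothetical stand-in for the sample's aspect-term objects): the text is cut
# into alternating non-aspect / aspect pieces using the character offsets of
# each aspect term, and each piece is labeled 0 (context) or 1 (aspect term).
from collections import namedtuple

AspectTerm = namedtuple('AspectTerm', ['from_index', 'to_index'])

def split_into_pieces(text, aspect_terms):
    pieces, piece_labels = [], []
    last_end = 0
    for i, term in enumerate(aspect_terms):
        if term.from_index != last_end:
            pieces.append(text[last_end:term.from_index])
            piece_labels.append(0)
        pieces.append(text[term.from_index:term.to_index])
        piece_labels.append(1)
        last_end = term.to_index
        if i == len(aspect_terms) - 1 and term.to_index != len(text):
            pieces.append(text[term.to_index:])
            piece_labels.append(0)
    return pieces, piece_labels

text = "the battery life is great"
terms = [AspectTerm(4, 16)]  # "battery life"
assert split_into_pieces(text, terms) == (["the ", "battery life", " is great"], [0, 1, 0])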
def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance: question_tokens = self._read_tokens_from_json_list( json_obj['question_tokens']) question_field = TextField(question_tokens, self._question_token_indexers) table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines( json_obj['table_lines'], question_tokens) entity_tokens = [ self._read_tokens_from_json_list(token_list) for token_list in json_obj['entity_texts'] ] table_field = KnowledgeGraphField( table_knowledge_graph, question_tokens, tokenizer=None, token_indexers=self._table_token_indexers, entity_tokens=entity_tokens, linking_features=json_obj['linking_features'], include_in_vocab=self._use_table_for_vocab, max_table_tokens=self._max_table_tokens) world = WikiTablesWorld(table_knowledge_graph) world_field = MetadataField(world) production_rule_fields: List[Field] = [] for production_rule in world.all_possible_actions(): _, rule_right_side = production_rule.split(' -> ') is_global_rule = not world.is_table_entity(rule_right_side) field = ProductionRuleField(production_rule, is_global_rule) production_rule_fields.append(field) action_field = ListField(production_rule_fields) example_string_field = MetadataField(json_obj['example_lisp_string']) fields = { 'question': question_field, 'table': table_field, 'world': world_field, 'actions': action_field, 'example_lisp_string': example_string_field } if 'target_action_sequences' in json_obj or 'agenda' in json_obj: action_map = { action.rule: i for i, action in enumerate(action_field.field_list) } # type: ignore if 'target_action_sequences' in json_obj: action_sequence_fields: List[Field] = [] for sequence in json_obj['target_action_sequences']: index_fields: List[Field] = [] for production_rule in sequence: index_fields.append( IndexField(action_map[production_rule], action_field)) action_sequence_fields.append(ListField(index_fields)) fields['target_action_sequences'] = ListField( action_sequence_fields) if 'agenda' in json_obj: agenda_index_fields: List[Field] = [] for agenda_action in json_obj['agenda']: agenda_index_fields.append( IndexField(action_map[agenda_action], action_field)) fields['agenda'] = ListField(agenda_index_fields) return Instance(fields)
def text_to_instance(self, utterance: str, db_id: str, sql: List[str] = None): fields: Dict[str, Field] = {} db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer, tables_file=self._tables_file, dataset_path=self._dataset_path) table_field = SpiderKnowledgeGraphField( db_context.knowledge_graph, db_context.tokenized_utterance, self._utterance_token_indexers, entity_tokens=db_context.entity_tokens, include_in_vocab=False, # TODO: self._use_table_for_vocab, max_table_tokens=None) # self._max_table_tokens) world = SpiderWorld(db_context, query=sql) fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers) action_sequence, all_actions = world.get_action_sequence_and_all_actions( ) if action_sequence is None and self._keep_if_unparsable: # print("Parse error") action_sequence = [] elif action_sequence is None: return None index_fields: List[Field] = [] production_rule_fields: List[Field] = [] for production_rule in all_actions: nonterminal, rhs = production_rule.split(' -> ') production_rule = ' '.join(production_rule.split(' ')) field = ProductionRuleField(production_rule, world.is_global_rule(rhs), nonterminal=nonterminal) production_rule_fields.append(field) valid_actions_field = ListField(production_rule_fields) fields["valid_actions"] = valid_actions_field action_map = { action.rule: i # type: ignore for i, action in enumerate(valid_actions_field.field_list) } for production_rule in action_sequence: index_fields.append( IndexField(action_map[production_rule], valid_actions_field)) if not action_sequence: index_fields = [IndexField(-1, valid_actions_field)] action_sequence_field = ListField(index_fields) fields["action_sequence"] = action_sequence_field fields["world"] = MetadataField(world) fields["schema"] = table_field return Instance(fields)
def _read(self, file_path: str): file_path = cached_path(file_path) logger.info("Reading file at %s", file_path) with open(file_path) as dataset_file: dataset = json.load(dataset_file) # load span data to json object span_file = open(self._span_file_path) span_json = json.load(span_file) # read dataset logger.info("Reading the dataset") for data, best_spans in zip(dataset, span_json): answer = data['answers'][0] question = data['query'] passages_json = data['passages'] passages = [ passages_json[i]['passage_text'] for i in range(len(passages_json)) ] passages_is_selected = [ passages_json[i]['is_selected'] for i in range(len(passages_json)) ] normalized_answer = util.normalize_text_msmarco(answer) normalized_question = util.normalize_text_msmarco(question) tokenized_question = self._tokenizer.tokenize(normalized_question) question_field = TextField(tokenized_question, self._token_indexers) fields = {'question': question_field} # choose span with score larger than 0.9 if best_spans['score'] > 0.9: # set passage field normalized_passage = util.normalize_text_msmarco( passages[best_spans['passage_idx']]) tokenized_passage = self._tokenizer.tokenize( normalized_passage) passage_field = TextField(tokenized_passage, self._token_indexers) fields['passage'] = passage_field # set span field span_start_list = [] span_end_list = [] for span in best_spans['best_spans']: span_start_field = IndexField(int(span[0]), passage_field) span_end_field = IndexField(int(span[1]), passage_field) span_start_list.append(span_start_field) span_end_list.append(span_end_field) fields['span_start'] = ListField(span_start_list) fields['span_end'] = ListField(span_end_list) # set metadata field passage_offsets = [(token.idx, token.idx + len(token.text)) for token in tokenized_passage] metadata = { 'original_passage': normalized_passage, 'token_offsets': passage_offsets, 'question_tokens': [token.text for token in tokenized_question], 'passage_tokens': [token.text for token in tokenized_passage] } if answer: metadata['answer_texts'] = [normalized_answer] fields['metadata'] = MetadataField(metadata) yield Instance(fields)
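# Quick illustration of the `token_offsets` metadata built above: each token keeps
# its (start, end) character offsets in the original passage so that a predicted
# token span can be mapped back to a passage substring.  `SimpleToken` is a
# stand-in for the tokenizer's Token, which exposes `.idx` and `.text` the same way.
from collections import namedtuple

SimpleToken = namedtuple('SimpleToken', ['text', 'idx'])

passage = "Rainbows form after rain."
tokens = [SimpleToken('Rainbows', 0), SimpleToken('form', 9),
          SimpleToken('after', 14), SimpleToken('rain', 20), SimpleToken('.', 24)]
offsets = [(token.idx, token.idx + len(token.text)) for token in tokens]

start, end = offsets[3]                 # token span covering just "rain"
assert passage[start:end] == 'rain'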
def text_to_instance( self, # type: ignore question: str, image_id: str, logical_form: str = None, gold_question_attentions: np.ndarray = None, objects: List[List[int]] = None, denotation: str = None, ) -> Instance: tokenized_sentence = self._tokenizer.tokenize(question.lower()) sentence_field = TextField(tokenized_sentence, self._token_indexers) if self.img_data is not None: img_info = self.img_data[image_id] else: img_info = pickle.load( open(os.path.join(self._image_feat_cache_dir, image_id), "rb")) visual_feat = img_info["features"].copy() boxes = img_info["boxes"].copy() fields = { "visual_feat": ArrayField(visual_feat), "pos": ArrayField(boxes), "question_field": sentence_field, "image_id": MetadataField(image_id), "actions": self._production_rule_field, } if denotation is not None: if denotation.lower() in {"yes", "true"}: fields["denotation"] = ArrayField(np.array(1)) else: fields["denotation"] = ArrayField(np.array(0)) if logical_form is not None: fields["logical_form"] = MetadataField(logical_form) actions = self._language.logical_form_to_action_sequence( logical_form) index_fields = [] for production_rule in actions: index_fields.append( IndexField(self._action_map[production_rule], self._production_rule_field)) fields["target_action_sequence"] = ListField(index_fields) if gold_question_attentions is not None: fields["gold_question_attentions"] = ArrayField( gold_question_attentions) if objects is not None and self.object_supervision: proposals = torch.from_numpy(boxes) gold_proposal_choices = np.zeros((len(objects), boxes.shape[0])) for i in range(len(objects)): if len(objects[i]) > 0: ious = box_iou( torch.from_numpy(np.array(objects[i])).float(), proposals.float(), ) ious = ious.numpy() for j in range(ious.shape[0]): if ious[j].max() > 0: gold_proposal_choices[ i, ious[j, :] >= self.positive_threshold] = 1 gold_proposal_choices[ i, ious[j, :] < self.negative_threshold] = -1 fields["gold_object_choices"] = ArrayField(gold_proposal_choices) if gold_proposal_choices.max() <= 0 and self.require_some_positive: return None return Instance(fields) if self.object_supervision: return None return Instance(fields)
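# Hedged, numpy-only sketch of the object-supervision labelling above (thresholds
# and the IoU matrix are illustrative, and the real reader computes IoUs with
# box_iou over torch tensors): proposals whose IoU with some gold box reaches
# `positive_threshold` get +1, those below `negative_threshold` get -1, and
# everything else stays 0.
import numpy as np

def label_proposals(ious, positive_threshold=0.5, negative_threshold=0.3):
    # ious: (num_gold_boxes, num_proposals) IoU matrix for one object set
    choices = np.zeros(ious.shape[1])
    for row in ious:
        if row.max() > 0:
            choices[row >= positive_threshold] = 1
            choices[row < negative_threshold] = -1
    return choices

ious = np.array([[0.7, 0.2, 0.4]])
assert label_proposals(ious).tolist() == [1.0, -1.0, 0.0]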
def text_to_instance( self, # type: ignore sent_tasks: Dict, full_data: List[str], col_idxs: Dict[str, int], is_train: bool, task2type: Dict[str, str], dataset: str) -> Instance: """ converts the previously read data into an AllenNLP Instance, containing mainly a TextField and one or more *LabelField's """ fields: Dict[str, Field] = {} tokens = TextField(sent_tasks['tokens'], self.token_indexers) for task in sent_tasks: if task == 'tokens': fields[task] = tokens fields['src_words'] = SequenceLabelField( [str(x) for x in sent_tasks[task]], tokens, label_namespace="src_tokens") elif task2type[task] == 'dependency': fields[task + '_rels'] = SequenceLabelField( [x[0] for x in sent_tasks[task]], tokens, label_namespace=task + '_rels') fields[task + '_head_indices'] = SequenceLabelField( [x[1] for x in sent_tasks[task]], tokens, label_namespace=task + '_head_indices') elif task2type[task] == "multiseq": label_sequence = [] # For each token label, check if it is a multilabel and handle it for raw_label in sent_tasks[task]: label_list = raw_label.split("$") label_sequence.append(label_list) fields[task] = SequenceMultiLabelField(label_sequence, tokens, label_namespace=task) elif task2type[task] == 'classification': fields[task] = LabelField(sent_tasks[task], label_namespace=task) elif task2type[task] == 'seq2seq': fields[task + '_target'] = TextField( sent_tasks[task], self._target_token_indexers) fields[task + '_target_words'] = SequenceLabelField( [str(x) for x in sent_tasks[task]], fields[task + '_target'], label_namespace=task + "_target_words") else: # seq labeling fields[task] = SequenceLabelField(sent_tasks[task], tokens, label_namespace=task) fields['dataset'] = LabelField(dataset, label_namespace='dataset') sent_tasks["full_data"] = full_data sent_tasks["col_idxs"] = col_idxs sent_tasks['is_train'] = is_train fields["metadata"] = MetadataField(sent_tasks) return Instance(fields)
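# Tiny sketch of the "multiseq" handling above: a token-level label string may
# carry several labels joined by "$", so each token ends up with a list of labels
# before being wrapped in the SequenceMultiLabelField.  Labels are made up.
raw_labels = ['B-PER', 'B-ORG$B-LOC', 'O']
label_sequence = [raw_label.split('$') for raw_label in raw_labels]
assert label_sequence == [['B-PER'], ['B-ORG', 'B-LOC'], ['O']]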
def text_to_instance( self, # type: ignore question: str, table_lines: List[List[str]], target_values: List[str] = None, offline_search_output: List[str] = None) -> Instance: """ Reads text inputs and makes an instance. We pass the ``table_lines`` to ``TableQuestionContext``, and that method accepts this field either as lines from CoreNLP processed tagged files that come with the dataset, or simply in a tsv format where each line corresponds to a row and the cells are tab-separated. Parameters ---------- question : ``str`` Input question table_lines : ``List[List[str]]`` The table content optionally preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines`` for the expected format. target_values : ``List[str]``, optional Target values for the denotations the logical forms should execute to. Not required for testing. offline_search_output : ``List[str]``, optional List of logical forms, produced by offline search. Not required during test. """ # pylint: disable=arguments-differ tokenized_question = self._tokenizer.tokenize(question.lower()) question_field = TextField(tokenized_question, self._question_token_indexers) metadata: Dict[str, Any] = { "question_tokens": [x.text for x in tokenized_question] } table_context = TableQuestionContext.read_from_lines( table_lines, tokenized_question) world = WikiTablesLanguage(table_context) world_field = MetadataField(world) # Note: Not passing any featre extractors when instantiating the field below. This will make # it use all the available extractors. table_field = KnowledgeGraphField( table_context.get_table_knowledge_graph(), tokenized_question, self._table_token_indexers, tokenizer=self._tokenizer, include_in_vocab=self._use_table_for_vocab, max_table_tokens=self._max_table_tokens) production_rule_fields: List[Field] = [] for production_rule in world.all_possible_productions(): _, rule_right_side = production_rule.split(' -> ') is_global_rule = not world.is_instance_specific_entity( rule_right_side) field = ProductionRuleField(production_rule, is_global_rule=is_global_rule) production_rule_fields.append(field) action_field = ListField(production_rule_fields) fields = { 'question': question_field, 'metadata': MetadataField(metadata), 'table': table_field, 'world': world_field, 'actions': action_field } if target_values is not None: target_values_field = MetadataField(target_values) fields['target_values'] = target_values_field # We'll make each target action sequence a List[IndexField], where the index is into # the action list we made above. We need to ignore the type here because mypy doesn't # like `action.rule` - it's hard to tell mypy that the ListField is made up of # ProductionRuleFields. 
action_map = { action.rule: i for i, action in enumerate(action_field.field_list) } # type: ignore if offline_search_output: action_sequence_fields: List[Field] = [] for logical_form in offline_search_output: try: action_sequence = world.logical_form_to_action_sequence( logical_form) index_fields: List[Field] = [] for production_rule in action_sequence: index_fields.append( IndexField(action_map[production_rule], action_field)) action_sequence_fields.append(ListField(index_fields)) except ParsingError as error: logger.debug( f'Parsing error: {error.message}, skipping logical form' ) logger.debug(f'Question was: {question}') logger.debug(f'Logical form was: {logical_form}') logger.debug(f'Table info was: {table_lines}') continue except KeyError as error: logger.debug( f'Missing production rule: {error.args}, skipping logical form' ) logger.debug(f'Question was: {question}') logger.debug(f'Table info was: {table_lines}') logger.debug(f'Logical form was: {logical_form}') continue except: logger.error(logical_form) raise if len(action_sequence_fields ) >= self._max_offline_logical_forms: break if not action_sequence_fields: # This is not great, but we're only doing it when we're passed logical form # supervision, so we're expecting labeled logical forms, but we can't actually # produce the logical forms. We should skip this instance. Note that this affects # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the # full test data. return None fields['target_action_sequences'] = ListField( action_sequence_fields) if self._output_agendas: agenda_index_fields: List[Field] = [] for agenda_string in world.get_agenda(conservative=True): agenda_index_fields.append( IndexField(action_map[agenda_string], action_field)) if not agenda_index_fields: agenda_index_fields = [IndexField(-1, action_field)] fields['agenda'] = ListField(agenda_index_fields) return Instance(fields)
def text_to_instance( self, agenda: List[str], event_context: List[str], text_context: List[str], indices: List[int], forthcoming_event: str, e_f_index: int, # set to none so it is compartible with previous models. same for model.forward. prev_reg_event: str = None, succ_reg_event: str = None, target: List[str] = None) -> Instance: """ ! a dummy end of story event is added to the end of the agenda. No text will be generated for it. :param event_context: :param text_context: :param indices: the indices for segmentation. collect indices of the first token of the successive segment, so slicing could be used directly. :param forthcoming_event: :param target: the target sequence, ending with <eos>. :return: """ ''' merge context event labels if configured. Forth coming event is always merged. ''' scenario = forthcoming_event[last_index_of(forthcoming_event, '_') + 1:] agenda_field = TextField([ Token(self.merged(event)) for event in agenda + ['{}_{}'.format(GLOBAL_CONSTANTS.ending_event, scenario)] ], self.event_indexers) if self.merge_irregular_events: event_context_field = TextField( [Token(self.merged(t)) for t in event_context], self.event_indexers) else: event_context_field = TextField([Token(t) for t in event_context], self.event_indexers) text_context_field = TextField([Token(t) for t in text_context], self.word_indexers) forthcoming_event_field = TextField( [Token(self.merged(forthcoming_event))], self.event_indexers) fields = { 'agenda': agenda_field, 'event_context': event_context_field, 'text_context': text_context_field, 'forthcoming_event': forthcoming_event_field, 'indices': MetadataField(indices), 'is_flesh_event': MetadataField(self.is_flesh_event(forthcoming_event)), 'e_f_index': MetadataField(e_f_index) } if target: target_field = TextField([Token(t) for t in target], self.word_indexers) fields['target'] = target_field if prev_reg_event: prev_reg_event_field = TextField( [Token(self.merged(prev_reg_event))], self.event_indexers) fields['previous_regular_event'] = prev_reg_event_field if succ_reg_event: succ_reg_event_field = TextField( [Token(self.merged(succ_reg_event))], self.event_indexers) fields['succesive_regular_event'] = succ_reg_event_field ''' an instance of data is constructed with a dict ''' return Instance(fields)
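# Plain-Python sketch of two details above, under stated assumptions: the event
# labels, the ending-event name, and the behaviour of the project helper
# `last_index_of` (assumed to act like str.rfind) are all illustrative.  The
# scenario suffix is peeled off the forthcoming event label, and a dummy
# end-of-story event for that scenario is appended to the agenda.
forthcoming_event = 'evoking_bath'
scenario = forthcoming_event[forthcoming_event.rfind('_') + 1:]
assert scenario == 'bath'

ending_event = 'story_ends'   # stand-in for GLOBAL_CONSTANTS.ending_event
agenda = ['evoking_bath', 'enter_bathroom_bath']
agenda_with_end = agenda + ['{}_{}'.format(ending_event, scenario)]
assert agenda_with_end[-1] == 'story_ends_bath'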
def make_reading_comprehension_instance( question_tokens: List[Token], passage_tokens: List[Token], token_indexers: Dict[str, TokenIndexer], passage_text: str, token_spans: List[Tuple[int, int]] = None, answer_texts: List[str] = None, additional_metadata: Dict[str, Any] = None) -> Instance: """ Converts a question, a passage, and an optional answer (or answers) to an ``Instance`` for use in a reading comprehension model. Creates an ``Instance`` with at least these fields: ``question`` and ``passage``, both ``TextFields``; and ``metadata``, a ``MetadataField``. Additionally, if both ``answer_texts`` and ``char_span_starts`` are given, the ``Instance`` has ``span_start`` and ``span_end`` fields, which are both ``IndexFields``. Parameters ---------- question_tokens : ``List[Token]`` An already-tokenized question. passage_tokens : ``List[Token]`` An already-tokenized passage that contains the answer to the given question. token_indexers : ``Dict[str, TokenIndexer]`` Determines how the question and passage ``TextFields`` will be converted into tensors that get input to a model. See :class:`TokenIndexer`. passage_text : ``str`` The original passage text. We need this so that we can recover the actual span from the original passage that the model predicts as the answer to the question. This is used in official evaluation scripts. token_spans : ``List[Tuple[int, int]]``, optional Indices into ``passage_tokens`` to use as the answer to the question for training. This is a list because there might be several possible correct answer spans in the passage. Currently, we just select the most frequent span in this list (i.e., SQuAD has multiple annotations on the dev set; this will select the span that the most annotators gave as correct). answer_texts : ``List[str]``, optional All valid answer strings for the given question. In SQuAD, e.g., the training set has exactly one answer per question, but the dev and test sets have several. TriviaQA has many possible answers, which are the aliases for the known correct entity. This is put into the metadata for use with official evaluation scripts, but not used anywhere else. additional_metadata : ``Dict[str, Any]``, optional The constructed ``metadata`` field will by default contain ``original_passage``, ``token_offsets``, ``question_tokens``, ``passage_tokens``, and ``answer_texts`` keys. If you want any other metadata to be associated with each instance, you can pass that in here. This dictionary will get added to the ``metadata`` dictionary we already construct. """ additional_metadata = additional_metadata or {} fields: Dict[str, Field] = {} passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens] # This is separate so we can reference it later with a known type. passage_field = TextField(passage_tokens, token_indexers) fields['passage'] = passage_field fields['question'] = TextField(question_tokens, token_indexers) metadata = { 'original_passage': passage_text, 'token_offsets': passage_offsets, 'question_tokens': [token.text for token in question_tokens], 'passage_tokens': [token.text for token in passage_tokens], } if answer_texts: metadata['answer_texts'] = answer_texts if token_spans: # There may be multiple answer annotations, so we pick the one that occurs the most. This # only matters on the SQuAD dev set, and it means our computed metrics ("start_acc", # "end_acc", and "span_acc") aren't quite the same as the official metrics, which look at # all of the annotations. 
This is why we have a separate official SQuAD metric calculation # (the "em" and "f1" metrics use the official script). candidate_answers: Counter = Counter() for span_start, span_end in token_spans: candidate_answers[(span_start, span_end)] += 1 span_start, span_end = candidate_answers.most_common(1)[0][0] fields['span_start'] = IndexField(span_start, passage_field) fields['span_end'] = IndexField(span_end, passage_field) metadata.update(additional_metadata) fields['metadata'] = MetadataField(metadata) return Instance(fields)
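# Tiny sketch of the span-voting step above: with several answer annotations we
# keep the (start, end) token span that annotators gave most often.  The spans
# below are made-up token indices.
from collections import Counter

token_spans = [(10, 12), (10, 12), (9, 12)]
candidate_answers = Counter(token_spans)
span_start, span_end = candidate_answers.most_common(1)[0][0]
assert (span_start, span_end) == (10, 12)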
def text_to_instance( self, # type: ignore utterances: List[str], sql_query_labels: List[str] = None) -> Instance: # pylint: disable=arguments-differ """ Parameters ---------- utterances: ``List[str]``, required. List of utterances in the interaction, the last element is the current utterance. sql_query_labels: ``List[str]``, optional The SQL queries that are given as labels during training or validation. """ if self._num_turns_to_concatenate: utterances[-1] = f' {END_OF_UTTERANCE_TOKEN} '.join( utterances[-self._num_turns_to_concatenate:]) utterance = utterances[-1] action_sequence: List[str] = [] if not utterance: return None world = AtisWorld(utterances=utterances) if sql_query_labels: # If there are multiple sql queries given as labels, we use the shortest # one for training. sql_query = min(sql_query_labels, key=len) try: action_sequence = world.get_action_sequence(sql_query) except ParseError: action_sequence = [] logger.debug(f'Parsing error') tokenized_utterance = self._tokenizer.tokenize(utterance.lower()) utterance_field = TextField(tokenized_utterance, self._token_indexers) production_rule_fields: List[Field] = [] for production_rule in world.all_possible_actions(): nonterminal, _ = production_rule.split(' ->') # The whitespaces are not semantically meaningful, so we filter them out. production_rule = ' '.join([ token for token in production_rule.split(' ') if token != 'ws' ]) field = ProductionRuleField(production_rule, self._is_global_rule(nonterminal)) production_rule_fields.append(field) action_field = ListField(production_rule_fields) action_map = { action.rule: i # type: ignore for i, action in enumerate(action_field.field_list) } index_fields: List[Field] = [] world_field = MetadataField(world) fields = { 'utterance': utterance_field, 'actions': action_field, 'world': world_field, 'linking_scores': ArrayField(world.linking_scores) } if sql_query_labels is not None: fields['sql_queries'] = MetadataField(sql_query_labels) if self._keep_if_unparseable or action_sequence: for production_rule in action_sequence: index_fields.append( IndexField(action_map[production_rule], action_field)) if not action_sequence: index_fields = [IndexField(-1, action_field)] action_sequence_field = ListField(index_fields) fields['target_action_sequence'] = action_sequence_field else: # If we are given a SQL query, but we are unable to parse it, and we do not specify explicitly # to keep it, then we will skip the it. return None return Instance(fields)
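# Sketch of two small details from the reader above, using made-up values for
# END_OF_UTTERANCE_TOKEN, the utterances, and the SQL strings: (1) the last few
# turns are concatenated into the final utterance, and (2) the shortest gold SQL
# string is used as the training target when several labels are given.
END_OF_UTTERANCE_TOKEN = '@@EOU@@'

utterances = ['show flights to boston', 'only morning ones', 'cheapest please']
num_turns_to_concatenate = 2
utterances[-1] = f' {END_OF_UTTERANCE_TOKEN} '.join(utterances[-num_turns_to_concatenate:])
assert utterances[-1] == 'only morning ones @@EOU@@ cheapest please'

sql_query_labels = ['SELECT * FROM flight WHERE 1 = 1', 'SELECT * FROM flight']
assert min(sql_query_labels, key=len) == 'SELECT * FROM flight'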
def text_to_instance(
        self,  # type: ignore
        sentences: List[List[str]],
        gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
) -> Instance:
    """
    # Parameters

    sentences : `List[List[str]]`, required.
        A list of lists representing the tokenized words and sentences in the document.
    gold_clusters : `Optional[List[List[Tuple[int, int]]]]`, optional (default = None)
        A list of all clusters in the document, represented as word spans. Each cluster
        contains some number of spans, which can be nested and overlap, but will never
        exactly match between clusters.

    # Returns

    An `Instance` containing the following `Fields`:
        text : `TextField`
            The text of the full document.
        spans : `ListField[SpanField]`
            A ListField containing the spans represented as `SpanFields`
            with respect to the document text.
        span_labels : `SequenceLabelField`, optional
            The id of the cluster which each possible span belongs to, or -1 if it does
            not belong to a cluster. As these labels have variable length (it depends on
            how many spans we are considering), we represent this as a
            `SequenceLabelField` with respect to the `spans` `ListField`.
    """
    flattened_sentences = [
        self._normalize_word(word) for sentence in sentences for word in sentence
    ]

    metadata: Dict[str, Any] = {"original_text": flattened_sentences}
    if gold_clusters is not None:
        metadata["clusters"] = gold_clusters

    text_field = TextField([Token(word) for word in flattened_sentences],
                           self._token_indexers)

    cluster_dict = {}
    if gold_clusters is not None:
        for cluster_id, cluster in enumerate(gold_clusters):
            for mention in cluster:
                cluster_dict[tuple(mention)] = cluster_id

    spans: List[Field] = []
    span_labels: Optional[List[int]] = [] if gold_clusters is not None else None

    sentence_offset = 0
    for sentence in sentences:
        for start, end in enumerate_spans(sentence,
                                          offset=sentence_offset,
                                          max_span_width=self._max_span_width):
            if span_labels is not None:
                if (start, end) in cluster_dict:
                    span_labels.append(cluster_dict[(start, end)])
                else:
                    span_labels.append(-1)
            spans.append(SpanField(start, end, text_field))
        sentence_offset += len(sentence)

    span_field = ListField(spans)
    metadata_field = MetadataField(metadata)

    fields: Dict[str, Field] = {
        "text": text_field,
        "spans": span_field,
        "metadata": metadata_field,
    }
    if span_labels is not None:
        fields["span_labels"] = SequenceLabelField(span_labels, span_field)

    return Instance(fields)
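# Plain-Python sketch of what `enumerate_spans` provides for the loop above (the
# real helper comes from AllenNLP's dataset_utils; this re-implementation is only
# illustrative): all inclusive (start, end) spans up to a maximum width, shifted
# by the running sentence offset so the indices are document-level.
def enumerate_spans_sketch(sentence, offset=0, max_span_width=3):
    spans = []
    for start in range(len(sentence)):
        for end in range(start, min(start + max_span_width, len(sentence))):
            spans.append((start + offset, end + offset))
    return spans

assert enumerate_spans_sketch(['a', 'b', 'c'], offset=5, max_span_width=2) == [
    (5, 5), (5, 6), (6, 6), (6, 7), (7, 7)]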
def text_to_instance( self, # type: ignore tokens: List[Token], verb_label: List[int], parseTree: Tree, tags: List[str] = None, fout=None) -> Instance: """ We take `pre-tokenized` input here, along with a verb label. The verb label should be a one-hot binary vector, the same length as the tokens, indicating the position of the verb to find arguments for. """ # pylint: disable=arguments-differ # Convert tags to BIOUL QUESTION - BIO or IOB1? # print(f"Tags before: {tags}") if (self.label_encoding == "BIOUL"): if (tags is not None): old_tags = deepcopy(tags) tags = to_bioul(tags, encoding="BIO") try: spans = bioul_tags_to_spans(tags) except InvalidTagSequence: print(f"Old tags: {old_tags}") print(f"New tags: {tags}\n") # Create span matrix from parse tree leftLabelsTree = leftMost(parseTree) rightLabelsTree = rightMost(parseTree) # leaves = [] # right_leaves = [] # get_leaves(parseTree, leaves) # get_leaves(parseTree, right_leaves) # assert(leaves == right_leaves) # leaf2idx = {} # for idx, leaf in enumerate(leaves): # leaf2idx[leaf] = idx leftList = [] rightList = [] addToList(leftLabelsTree, leftList) addToList(rightLabelsTree, rightList) if len(leftList) != len(rightList): raise Exception( f"For tree {parseTree}, leftList and rightList lengths do not match" ) span_matrix = np.zeros([len(tokens), len(tokens)]) for idx in range(len(leftList)): leftLabel, rightLabel = leftList[idx], rightList[idx] if (leftLabel == rightLabel): continue span_matrix[leftLabel, rightLabel] = 1 # print(f"Tags after: {tags}\n") # print(tokens) # print(verb_label) # print(tags) fields: Dict[str, Field] = {} text_field = TextField(tokens, token_indexers=self._token_indexers) fields['tokens'] = text_field fields['verb_indicator'] = SequenceLabelField(verb_label, text_field) if (self.label_encoding == "BIOUL"): fields['span_matrix'] = ArrayField(span_matrix) if all([x == 0 for x in verb_label]): verb = None else: verb = tokens[verb_label.index(1)].text metadata_dict = {"words": [x.text for x in tokens], "verb": verb} if tags: fields['tags'] = SequenceLabelField(tags, text_field) metadata_dict["gold_tags"] = tags fields["metadata"] = MetadataField(metadata_dict) if (fout is not None): srl_dict = {"parse_tree": parseTree, "span_matrix": span_matrix} pickle.dump(srl_dict, fout) return Instance(fields)
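# Small sketch of the span-matrix construction above: given matched lists of
# left-most and right-most token indices for each parse-tree constituent, mark
# cell (left, right) with 1, skipping single-token constituents just as the
# reader does.  The index lists below are made up.
import numpy as np

def build_span_matrix(left_list, right_list, num_tokens):
    span_matrix = np.zeros((num_tokens, num_tokens))
    for left, right in zip(left_list, right_list):
        if left == right:
            continue
        span_matrix[left, right] = 1
    return span_matrix

matrix = build_span_matrix([0, 0, 2], [3, 1, 2], num_tokens=4)
assert matrix[0, 3] == 1 and matrix[0, 1] == 1 and matrix[2, 2] == 0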