class BaseReader(DatasetReader):
    """Read a JSON-lines rationale dataset into AllenNLP ``Instance``s.

    Each input line is a JSON object with at least ``document`` and
    ``annotation_id``; ``query``, ``label`` and ``rationale`` (a list of
    ``(start, end)`` token spans) are optional.  A human rationale is kept
    with probability ``human_prob``; the RNG is seeded so the kept subset
    is reproducible across runs.
    """

    def __init__(
        self,
        token_indexers: Dict[str, TokenIndexer],
        human_prob: float = 1.0,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy=lazy)
        self._tokenizer = WhitespaceTokenizer()
        self._token_indexers = token_indexers
        self._human_prob = human_prob

    @overrides
    def _read(self, file_path):
        # Fixed seed: the same subset of rationales is revealed on every pass.
        rs = RandomState(seed=1000)
        with open(cached_path(file_path), "r") as data_file:
            # Iterate the file lazily instead of readlines(): reading the
            # whole dataset into memory defeats the point of lazy=True.
            for line in data_file:
                items = json.loads(line)
                document = items["document"]
                annotation_id = items["annotation_id"]
                query = items.get("query", None)
                label = items.get("label", None)
                # Drop the human rationale with probability (1 - human_prob).
                rationale = (
                    items.get("rationale", [])
                    if rs.random_sample() < self._human_prob
                    else []
                )

                if label is not None:
                    # Labels become vocabulary entries; spaces are unsafe there.
                    label = str(label).replace(" ", "_")

                yield self.text_to_instance(
                    annotation_id=annotation_id,
                    document=document,
                    query=query,
                    label=label,
                    rationale=rationale,
                )

    @overrides
    def text_to_instance(
        self,
        annotation_id: str,
        document: str,
        query: str = None,
        label: str = None,
        rationale: List[tuple] = None,
    ) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        """Build an ``Instance`` with document/query metadata, a 0/1
        per-token human-rationale array, and (optionally) a label field.

        ``rationale`` is a list of ``(start, end)`` token spans; ``None``
        is treated as "no rationale" (previously it raised ``TypeError``).
        """
        fields = {}

        document_tokens = [
            to_token(t.text) for t in self._tokenizer.tokenize(document)
        ]

        # Mark every token covered by a rationale span with 1.
        human_rationale_labels = [0] * len(document_tokens)
        for start, end in rationale or []:
            for i in range(start, end):
                human_rationale_labels[i] = 1

        if query is not None:
            query_tokens = [
                to_token(t.text) for t in self._tokenizer.tokenize(query)
            ]
        else:
            query_tokens = []

        # Give indexers that support it a chance to annotate the raw tokens.
        for index_name, indexer in self._token_indexers.items():
            if hasattr(indexer, "add_token_info"):
                indexer.add_token_info(document_tokens, index_name)
                indexer.add_token_info(query_tokens, index_name)

        # Tokens are carried as metadata (not TextFields) so batching is
        # deferred to combine_document_query() at model time.
        fields["document"] = MetadataField({
            "tokens": document_tokens,
            "reader_object": self
        })
        fields["query"] = MetadataField({"tokens": query_tokens})
        fields["rationale"] = ArrayField(np.array(human_rationale_labels))

        metadata = {
            "annotation_id": annotation_id,
            "human_rationale": rationale,
            "document": document,
            "label": label,
        }
        if query is not None:
            metadata["query"] = query

        fields["metadata"] = MetadataField(metadata)

        if label is not None:
            fields["label"] = LabelField(label, label_namespace="labels")

        return Instance(fields)

    def convert_tokens_to_instance(self, tokens: List[Token]):
        """Join a (document, query) token pair — separated by ``[DQSEP]``
        when a query is present — into a single-TextField ``Instance``."""
        fields = {}
        tokens = tokens[0] + (
            ([to_token("[DQSEP]")] + tokens[1]) if len(tokens[1]) > 0 else []
        )
        fields["document"] = TextField(tokens, self._token_indexers)
        return Instance(fields)

    def convert_documents_to_batch(
            self, documents: List[Tuple[List[Token], List[Token]]],
            vocabulary) -> Dict[str, Any]:
        """Index and tensorize a list of (document, query) token pairs and
        return the resulting ``document`` tensor dict."""
        batch = Batch(
            [self.convert_tokens_to_instance(tokens) for tokens in documents])
        batch.index_instances(vocabulary)
        batch = batch.as_tensor_dict()
        return batch["document"]

    def combine_document_query(self, document: List[MetadataField],
                               query: List[MetadataField], vocabulary):
        """Pair up document/query metadata fields and tensorize them."""
        document_tokens = [(doc["tokens"], qry["tokens"])
                           for doc, qry in zip(document, query)]
        return self.convert_documents_to_batch(document_tokens, vocabulary)
class RationaleReader(DatasetReader):
    """Read a JSON-lines rationale dataset into single-TextField
    ``Instance``s with per-token rationale labels.

    Unlike ``BaseReader``, the document (and, for BERT indexers, the query)
    is encoded as one ``TextField`` prefixed with ``<S>``; rationale
    supervision is a ``SequenceLabelField`` aligned to it.  With
    probability ``1 - human_prob`` the rationale is withheld (sentinel -1).
    """

    def __init__(
        self,
        token_indexers: Dict[str, TokenIndexer],
        max_sequence_length: int = None,
        human_prob: float = 1.0,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy=lazy)
        self._tokenizer = WhitespaceTokenizer()
        # NOTE(review): stored but not used in this class as visible here;
        # presumably consumed by a subclass — confirm before removing.
        self._max_sequence_length = max_sequence_length
        self._token_indexers = token_indexers
        self._human_prob = human_prob
        # BERT-style indexers get the query appended after a [SEP] token.
        self._bert = "bert" in token_indexers

    @overrides
    def _read(self, file_path):
        # Fixed seed: the same subset of rationales is withheld every pass.
        rs = RandomState(seed=1000)
        with open(cached_path(file_path), "r") as data_file:
            # Iterate the file lazily instead of readlines() so the whole
            # dataset is not held in memory at once.
            for line in data_file:
                items = json.loads(line)
                document = items["document"]
                query = items.get("query", None)
                label = items.get("label", None)
                rationale = items.get("rationale", [])
                annotation_id = items["annotation_id"]

                if label is not None:
                    # Labels become vocabulary entries; spaces are unsafe there.
                    label = str(label).replace(' ', '_')

                # Withhold the human rationale with probability 1 - human_prob;
                # -1 is the "no supervision" sentinel checked downstream.
                if rs.random_sample() > self._human_prob:
                    rationale = -1

                instance = self.text_to_instance(annotation_id=annotation_id,
                                                 document=document,
                                                 query=query,
                                                 label=label,
                                                 rationale=rationale)
                if instance is not None:
                    yield instance

    @overrides
    def text_to_instance(
            self,
            annotation_id: str,
            document: str,
            query: str = None,
            label: str = None,
            rationale: List[tuple] = None) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        """Build an ``Instance`` whose ``document`` TextField is
        ``<S> + document [+ [SEP] + query]`` with an aligned 0/1
        rationale SequenceLabelField.

        ``rationale`` is a list of ``(start, end)`` spans, or the
        sentinel ``-1`` meaning the rationale was withheld.
        """
        fields = {}

        tokens = [Token("<S>")]
        keep_tokens = [1]  # special tokens are always kept

        word_tokens = self._tokenizer.tokenize(document)
        rationale_tokens = [0] * len(word_tokens)
        if rationale != -1:
            for start, end in rationale:
                for i in range(start, end):
                    rationale_tokens[i] = 1

        tokens.extend(word_tokens)
        keep_tokens.extend([0] * len(word_tokens))
        rationale_tokens = [0] + rationale_tokens  # account for <S>

        if query is not None:
            if self._bert:
                # Append the query after [SEP]; query tokens (and the
                # separator) are always kept and always "in" the rationale.
                query_tokens = self._tokenizer.tokenize(query)
                tokens += [Token('[SEP]')] + query_tokens
                keep_tokens += [1] * (len(query_tokens) + 1)
                rationale_tokens += [1] * (len(query_tokens) + 1)
            else:
                # Non-BERT setups carry the query as its own field.
                fields["query"] = TextField(self._tokenizer.tokenize(query),
                                            self._token_indexers)

        fields["document"] = TextField(tokens, self._token_indexers)

        # Was `assert ..., breakpoint()` — a debugger leftover that dropped
        # into pdb on failure and vanished entirely under `python -O`.
        if len(rationale_tokens) != len(tokens):
            raise ValueError(
                f"rationale/token length mismatch for {annotation_id}: "
                f"{len(rationale_tokens)} labels vs {len(tokens)} tokens")

        fields['rationale'] = SequenceLabelField(rationale_tokens,
                                                 fields['document'],
                                                 'rationale_labels')

        metadata = {
            "annotation_id": annotation_id,
            "tokens": tokens,
            "keep_tokens": keep_tokens,
            "document": document,
            "query": query,
            "convert_tokens_to_instance": self.convert_tokens_to_instance,
            "label": label
        }
        fields["metadata"] = MetadataField(metadata)

        if label is not None:
            fields["label"] = LabelField(label, label_namespace="labels")

        return Instance(fields)

    def convert_tokens_to_instance(self, tokens):
        """Wrap an already-built token list in a single-field ``Instance``."""
        fields = {}
        fields["document"] = TextField(tokens, self._token_indexers)
        return Instance(fields)