Example #1

# Imports assumed for this example (AllenNLP 1.x API, in which DatasetReader
# still accepts the `lazy` flag); module paths may differ in other versions.
import json
from typing import Any, Dict, List, Tuple

import numpy as np
from numpy.random import RandomState
from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data import Batch, Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import (ArrayField, LabelField, MetadataField,
                                  SequenceLabelField, TextField)
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token, WhitespaceTokenizer


def to_token(text: str) -> Token:
    # `to_token` is a project-local helper in the original repository; its
    # definition is not shown in this example, so a minimal stand-in that wraps
    # the raw text in an AllenNLP Token is assumed here.
    return Token(text)


class BaseReader(DatasetReader):
    def __init__(
        self,
        token_indexers: Dict[str, TokenIndexer],
        human_prob: float = 1.0,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy=lazy)
        self._tokenizer = WhitespaceTokenizer()
        self._token_indexers = token_indexers
        self._human_prob = human_prob

    @overrides
    def _read(self, file_path):
        # Reads a JSONL file (one JSON object per line); with probability
        # 1 - human_prob the human rationale spans of a line are dropped and
        # replaced by an empty list.
        rs = RandomState(seed=1000)
        with open(cached_path(file_path), "r") as data_file:
            for line in data_file:
                items = json.loads(line)
                document = items["document"]
                annotation_id = items["annotation_id"]
                query = items.get("query", None)
                label = items.get("label", None)
                rationale = (items.get("rationale", [])
                             if rs.random_sample() < self._human_prob else [])

                if label is not None:
                    label = str(label).replace(" ", "_")

                instance = self.text_to_instance(
                    annotation_id=annotation_id,
                    document=document,
                    query=query,
                    label=label,
                    rationale=rationale,
                )
                yield instance

    @overrides
    def text_to_instance(
        self,
        annotation_id: str,
        document: str,
        query: str = None,
        label: str = None,
        rationale: List[tuple] = None,
    ) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        # The tokenized document and query are stored as MetadataFields (raw
        # tokens only); token indexing is deferred to convert_documents_to_batch.
        fields = {}

        document_tokens = [
            to_token(t.text) for t in self._tokenizer.tokenize(document)
        ]
        human_rationale_labels = [0] * len(document_tokens)
        for s, e in rationale:
            for i in range(s, e):
                human_rationale_labels[i] = 1

        if query is not None:
            query_tokens = [
                to_token(t.text) for t in self._tokenizer.tokenize(query)
            ]
        else:
            query_tokens = []

        for index_name, indexer in self._token_indexers.items():
            if hasattr(indexer, "add_token_info"):
                indexer.add_token_info(document_tokens, index_name)
                indexer.add_token_info(query_tokens, index_name)

        fields["document"] = MetadataField({
            "tokens": document_tokens,
            "reader_object": self
        })
        fields["query"] = MetadataField({"tokens": query_tokens})
        fields["rationale"] = ArrayField(np.array(human_rationale_labels))

        metadata = {
            "annotation_id": annotation_id,
            "human_rationale": rationale,
            "document": document,
            "label": label,
        }

        if query is not None:
            metadata["query"] = query

        fields["metadata"] = MetadataField(metadata)

        if label is not None:
            fields["label"] = LabelField(label, label_namespace="labels")

        return Instance(fields)

    def convert_tokens_to_instance(
            self, tokens: Tuple[List[Token], List[Token]]) -> Instance:
        # Joins (document_tokens, query_tokens) into a single TextField,
        # inserting a [DQSEP] marker token when the query is non-empty.
        fields = {}
        tokens = tokens[0] + (
            ([to_token("[DQSEP]")] + tokens[1]) if len(tokens[1]) > 0 else [])
        fields["document"] = TextField(tokens, self._token_indexers)

        return Instance(fields)

    def convert_documents_to_batch(self, documents: List[Tuple[List[Token],
                                                               List[Token]]],
                                   vocabulary) -> Dict[str, Any]:
        batch = Batch(
            [self.convert_tokens_to_instance(tokens) for tokens in documents])
        batch.index_instances(vocabulary)
        batch = batch.as_tensor_dict()
        return batch["document"]

    def combine_document_query(self, document: List[MetadataField],
                               query: List[MetadataField], vocabulary):
        document_tokens = [(x["tokens"], y["tokens"])
                           for x, y in zip(document, query)]
        return self.convert_documents_to_batch(document_tokens, vocabulary)
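

# Usage sketch for BaseReader (illustrative, not part of the original example):
# it assumes a JSONL input whose lines carry "annotation_id" and "document"
# plus optional "query", "label" and "rationale" (start, end) spans; the
# SingleIdTokenIndexer and the "train.jsonl" path are placeholder choices.
def _base_reader_demo() -> None:
    from allennlp.data.token_indexers import SingleIdTokenIndexer

    reader = BaseReader(token_indexers={"tokens": SingleIdTokenIndexer()})
    for instance in reader.read("train.jsonl"):
        # "document" and "query" are MetadataFields holding raw tokens; the
        # human rationale is a 0/1 ArrayField aligned with the document tokens.
        print(instance["metadata"]["annotation_id"],
              instance.fields.get("label"))

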
class RationaleReader(DatasetReader):
    def __init__(
        self,
        token_indexers: Dict[str, TokenIndexer],
        max_sequence_length: int = None,
        human_prob: float = 1.0,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy=lazy)
        self._tokenizer = WhitespaceTokenizer()
        self._max_sequence_length = max_sequence_length
        self._token_indexers = token_indexers
        self._human_prob = human_prob

        self._bert = "bert" in token_indexers

    @overrides
    def _read(self, file_path):
        # Same JSONL format as BaseReader._read; when the human rationale is
        # withheld (probability 1 - human_prob) it is marked with the
        # sentinel value -1.
        rs = RandomState(seed=1000)
        with open(cached_path(file_path), "r") as data_file:
            for line in data_file:
                items = json.loads(line)
                document = items["document"]
                query = items.get("query", None)
                label = items.get("label", None)
                rationale = items.get("rationale", [])
                annotation_id = items["annotation_id"]

                if label is not None:
                    label = str(label).replace(" ", "_")

                if rs.random_sample() > self._human_prob:
                    rationale = -1

                instance = self.text_to_instance(annotation_id=annotation_id,
                                                 document=document,
                                                 query=query,
                                                 label=label,
                                                 rationale=rationale)
                if instance is not None:
                    yield instance

    @overrides
    def text_to_instance(
            self,
            annotation_id: str,
            document: str,
            query: str = None,
            label: str = None,
            rationale: List[tuple] = None) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        # Builds a single "document" TextField: a leading <S> marker, the
        # document tokens, and (when a BERT indexer is present) a [SEP] token
        # followed by the query tokens; keep_tokens flags these special and
        # query positions with 1 and plain document positions with 0.
        fields = {}

        tokens = [Token("<S>")]
        keep_tokens = [1]

        word_tokens = self._tokenizer.tokenize(document)
        rationale_tokens = [0] * len(word_tokens)
        if rationale != -1:
            for s, e in rationale:
                for i in range(s, e):
                    rationale_tokens[i] = 1

        tokens.extend(word_tokens)
        keep_tokens.extend([0 for _ in range(len(word_tokens))])

        rationale_tokens = [0] + rationale_tokens

        if query is not None:
            if self._bert:
                query_tokens = self._tokenizer.tokenize(query)
                tokens += [Token("[SEP]")] + query_tokens
                keep_tokens += [1 for _ in range(len(query_tokens) + 1)]
                rationale_tokens += [1] * (len(query_tokens) + 1)
            else:
                fields["query"] = TextField(self._tokenizer.tokenize(query),
                                            self._token_indexers)

        fields["document"] = TextField(tokens, self._token_indexers)

        assert len(rationale_tokens) == len(tokens), (
            "rationale labels must align one-to-one with the document tokens")
        fields["rationale"] = SequenceLabelField(rationale_tokens,
                                                 fields["document"],
                                                 "rationale_labels")

        metadata = {
            "annotation_id": annotation_id,
            "tokens": tokens,
            "keep_tokens": keep_tokens,
            "document": document,
            "query": query,
            "convert_tokens_to_instance": self.convert_tokens_to_instance,
            "label": label
        }

        fields["metadata"] = MetadataField(metadata)

        if label is not None:
            fields["label"] = LabelField(label, label_namespace="labels")

        return Instance(fields)

    def convert_tokens_to_instance(self, tokens):
        fields = {}
        fields["document"] = TextField(tokens, self._token_indexers)
        return Instance(fields)
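

# Usage sketch for RationaleReader (illustrative, not part of the original
# example). With a "bert" entry in token_indexers the query would be appended
# to the document after a [SEP] token; here a plain SingleIdTokenIndexer and a
# placeholder "train.jsonl" path are assumed instead.
def _rationale_reader_demo() -> None:
    from allennlp.data.token_indexers import SingleIdTokenIndexer

    reader = RationaleReader(token_indexers={"tokens": SingleIdTokenIndexer()},
                             human_prob=0.5)
    for instance in reader.read("train.jsonl"):
        # "rationale" is a SequenceLabelField aligned with the "document"
        # TextField (every document gets a leading <S> marker token).
        print(len(instance["document"].tokens),
              instance["metadata"]["annotation_id"])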