Example #1
    def test_sample_spans(
        self,
        text: str,
        num_anchors: int,
        num_positives: int,
        sampling_strategy: Union[str, None],
    ):
        tokens = self.tokenize(text)
        num_tokens = len(tokens)
        # These represent sensible defaults
        max_span_len = num_tokens // 2
        min_span_len = randint(1, max_span_len) if max_span_len > 1 else 1

        # The sampling procedure often breaks if we don't have at least ten tokens, so we set
        # a strict lower bound.
        if num_tokens < 10:
            with pytest.raises(ValueError):
                _, _ = sample_anchor_positive_pairs(
                    text,
                    num_anchors=num_anchors,
                    num_positives=num_positives,
                    max_span_len=max_span_len,
                    min_span_len=min_span_len,
                    sampling_strategy=sampling_strategy,
                )
        else:
            anchors, positives = sample_anchor_positive_pairs(
                text,
                num_anchors=num_anchors,
                num_positives=num_positives,
                max_span_len=max_span_len,
                min_span_len=min_span_len,
                sampling_strategy=sampling_strategy,
            )
            assert len(anchors) == num_anchors
            assert len(positives) == num_positives
            for i, anchor in enumerate(anchors):
                # Several simple checks for valid anchors.
                anchor_tokens = self.tokenize(anchor)
                anchor_length = len(anchor_tokens)
                assert anchor_length <= max_span_len
                assert anchor_length >= min_span_len
                # The tokenization process may lead to certain characters (such as escape
                # characters) being dropped, so repeat the tokenization process before performing
                # this check (otherwise a bunch of tests fail).
                assert anchor in " ".join(tokens)
                for j in range(i, i + num_positives):
                    # Several simple checks for valid positives.
                    positive = positives[j]
                    positive_tokens = self.tokenize(positive)
                    positive_length = len(positive_tokens)
                    assert positive_length <= max_span_len
                    assert positive_length >= min_span_len
                    assert positive in " ".join(tokens)
                    # Test that specific sampling strategies are obeyed.
                    if sampling_strategy == "subsuming":
                        assert positive in " ".join(anchor_tokens)
                    elif sampling_strategy == "adjacent":
                        assert positive not in " ".join(anchor_tokens)
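For reference, here is a minimal, self-contained sketch of the contract this test exercises. It is not the library's implementation: the function name mirrors the real one, spans are plain whitespace tokens, and the "adjacent" strategy is omitted for brevity (the real sampler also supports it).

    import random
    from typing import List, Optional, Tuple

    def sample_anchor_positive_pairs_sketch(
        text: str,
        num_anchors: int,
        num_positives: int,
        max_span_len: int,
        min_span_len: int,
        sampling_strategy: Optional[str] = None,
    ) -> Tuple[List[str], List[str]]:
        tokens = text.split()
        if min_span_len > max_span_len:
            raise ValueError("min_span_len must be <= max_span_len")
        if len(tokens) < 10:
            raise ValueError("text is too short to sample spans from")
        anchors: List[str] = []
        positives: List[str] = []
        for _ in range(num_anchors):
            anchor_len = random.randint(min_span_len, max_span_len)
            start = random.randint(0, len(tokens) - anchor_len)
            anchors.append(" ".join(tokens[start : start + anchor_len]))
            for _ in range(num_positives):
                pos_len = random.randint(min_span_len, anchor_len)
                if sampling_strategy == "subsuming":
                    # A subsuming positive falls entirely within its anchor.
                    pos_start = start + random.randint(0, anchor_len - pos_len)
                else:
                    # Otherwise, sample anywhere in the text.
                    pos_start = random.randint(0, len(tokens) - pos_len)
                positives.append(" ".join(tokens[pos_start : pos_start + pos_len]))
        return anchors, positives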
Example #2
    def text_to_instance(self, text: str, label: str) -> Instance:  # type: ignore
        """
        # Parameters

        text : `str`, required.
            The text to process.
        label : `str`, required.
            The label associated with `text`.

        # Returns

        An `Instance` containing the following fields:
            - anchors (`Union[TextField, ListField[TextField]]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each anchor span sampled from `text`. Otherwise, this will be a `TextField` object
                containing the tokenized `text`.
            - positives (`ListField[TextField]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each positive span sampled from `text`. Otherwise this field will not be included
                in the returned `Instance`.
            - label (`LabelField`) :
                The label associated with `text`.
        """
        # Some very minimal preprocessing to remove whitespace, newlines and tabs.
        # We perform it here as it will cover both training and predicting with the model.
        # We DON'T lowercase by default, but rather allow `self._tokenizer` to decide.
        text = sanitize(text, lowercase=False)

        fields: Dict[str, Field] = {}
        if self.sample_spans:
            # Choose the anchor/positives at random.
            anchor_text, positive_text = sample_anchor_positive_pairs(
                text=text,
                num_anchors=self._num_anchors,
                num_positives=self._num_positives,
                max_span_len=self._max_span_len,
                min_span_len=self._min_span_len,
                sampling_strategy=self._sampling_strategy,
            )
            anchors: List[Field] = []
            for text in anchor_text:
                tokens = self._tokenizer.tokenize(text)
                anchors.append(TextField(tokens, self._token_indexers))
            fields["anchors"] = ListField(anchors)
        
            positives: List[Field] = []
            for text in positive_text:
                tokens = self._tokenizer.tokenize(text)
                positives.append(TextField(tokens, self._token_indexers))
            fields["positives"] = ListField(positives)
            fields["label"] = LabelField(str(label))
        else:
            tokens = self._tokenizer.tokenize(text)
            fields["anchors"] = TextField(tokens, self._token_indexers)
            fields["label"] = LabelField(str(label))
        return Instance(fields)
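A minimal sketch of what the `else` branch above produces. The tokenizer and indexer here (WhitespaceTokenizer, SingleIdTokenIndexer) are stock AllenNLP stand-ins for whatever the reader is actually configured with, and the text/label values are made up:

    from allennlp.data import Instance
    from allennlp.data.fields import LabelField, TextField
    from allennlp.data.token_indexers import SingleIdTokenIndexer
    from allennlp.data.tokenizers import WhitespaceTokenizer

    tokenizer = WhitespaceTokenizer()
    token_indexers = {"tokens": SingleIdTokenIndexer()}

    text, label = "They may take our lives ...", "braveheart"
    fields = {
        "anchors": TextField(tokenizer.tokenize(text), token_indexers),
        "label": LabelField(str(label)),
    }
    instance = Instance(fields)
    print(instance.fields.keys())  # dict_keys(['anchors', 'label'])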
Example #3
    def test_sample_spans_raises_value_error_invalid_min_span_length(
            self, num_anchors: int, num_positives: int):
        text = "They may take our lives, but they'll never take our freedom!"
        num_tokens = len(self.tokenize(text))

        max_span_len = num_tokens - 1  # This is guaranteed to be valid.
        min_span_len = max_span_len + 1  # This is guaranteed to be invalid.

        with pytest.raises(ValueError):
            _, _ = sample_anchor_positive_pairs(
                text,
                num_anchors=num_anchors,
                num_positives=num_positives,
                max_span_len=max_span_len,
                min_span_len=min_span_len,
            )
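The guard this test relies on presumably looks something like the following inside the sampler (a sketch with a hypothetical helper name, not the actual source):

    def _validate_spans(min_span_len: int, max_span_len: int, num_tokens: int) -> None:
        if min_span_len > max_span_len:
            raise ValueError(
                f"min_span_len ({min_span_len}) must be less than or equal to"
                f" max_span_len ({max_span_len})."
            )
        if max_span_len > num_tokens:
            raise ValueError(
                f"max_span_len ({max_span_len}) must not exceed the number of"
                f" tokens ({num_tokens})."
            )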
Example #4
    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        """
        # Parameters

        text : `str`, required.
            The text to process.

        # Returns

        An `Instance` containing the following fields:
            - anchors (`Union[TextField, ListField[TextField]]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each anchor span sampled from `text`. Otherwise, this will be a `TextField` object
                containing the tokenized `text`.
            - difficulty (`LabelField`) :
                The current curriculum difficulty step, stored with `skip_indexing=True`.
        """
        # Some very minimal preprocessing to remove whitespace, newlines and tabs.
        # We perform it here as it will cover both training and predicting with the model.
        # We DON'T lowercase by default, but rather allow `self._tokenizer` to decide.
        text = sanitize(text, lowercase=False)

        # Map the running instance count to a curriculum difficulty step; the divisor
        # appears to be tied to the size of the training set.
        difficulty_step = int(self.instance / 49784) + 1
        self.instance += 1

        fields: Dict[str, Field] = {}
        if self.sample_spans:
            # print("reading instance is", self.instance)
            # difficulty_step = int(self.instance / 40 ) + 1

            # # print("difficulty step is ",difficulty_step)
            # if difficulty_step > 5 :
            # # if difficulty_step > 2 :
            #     # self._num_anchors = 2
            #     # self._num_anchors = int(difficulty_step /2) + 1
            #     # self._num_anchors = int((difficulty_step - 1)/2) - 1
            #     self._num_anchors = difficulty_step_sample
            #     if self._num_anchors > 3:
            #         # print("over anchor!")
            #         self._num_anchors = 3
            #     # self._num_anchors = random.randint(1, self._num_anchors)
            #     # print("num_anchors", self._num_anchors, self.instance, difficulty_step)
            #     # sample_difficulty = difficulty_step
            #     sample_difficulty = 1
            # else:
            #     sample_difficulty = 1
            # self._num_anchors = difficulty_step_sample
            # if difficulty_step_sample <=0 :
            #     self._num_anchors = 1
            # if self._num_anchors > 3:
            #     # print("over anchor!")
            #     self._num_anchors = 3
            sample_difficulty = 1
            # print("anchor num is", self._num_anchors)

            # Choose the anchor spans at random.
            anchor_text = sample_anchor_positive_pairs(
                text=text,
                num_anchors=self._num_anchors,
                num_positives=self._num_positives,
                max_span_len=self._max_span_len,
                min_span_len=self._min_span_len,
                difficulty_step=sample_difficulty,
                sampling_strategy=self._sampling_strategy,
            )
            # print("anchor_text", anchor_text)
            # print("positive_text", positive_text)
            anchors: List[Field] = []
            for text in anchor_text:
                tokens = self._tokenizer.tokenize(text)
                anchors.append(TextField(tokens, self._token_indexers))
            print("number of token is", len(tokens))
            fields["anchors"] = ListField(anchors)
            fields["difficulty"] = LabelField(difficulty_step,
                                              skip_indexing=True)
        else:
            # print("no sampling")
            tokens = self._tokenizer.tokenize(text)
            print("number of token is", len(tokens))
            fields["anchors"] = TextField(tokens, self._token_indexers)
            fields["difficulty"] = LabelField(difficulty_step,
                                              skip_indexing=True)
        return Instance(fields)
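The curriculum step above advances once every 49784 instances (presumably one pass over the training set). A quick check of the arithmetic:

    for instance in (0, 49783, 49784, 99568):
        print(instance, int(instance / 49784) + 1)
    # 0 1
    # 49783 1
    # 49784 2
    # 99568 3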
Example #5
    def test_sample_spans(
        self,
        inputs: List[str],
        num_anchors: int,
        num_positives: int,
        sampling_strategy: Union[str, None],
    ) -> None:

        for text in inputs:
            tokens = self.tokenize(text)
            num_tokens = len(tokens)

            # Really short examples make the tests unreliable.
            if num_tokens < 7:
                continue

            # These represent sensible defaults
            max_span_len = num_tokens // 4
            min_span_len = random.randint(1, max_span_len) if max_span_len > 1 else 1

            if num_tokens < num_anchors * max_span_len * 2:
                with pytest.raises(ValueError):
                    _, _ = sample_anchor_positive_pairs(
                        text,
                        num_anchors=num_anchors,
                        num_positives=num_positives,
                        max_span_len=max_span_len,
                        min_span_len=min_span_len,
                        sampling_strategy=sampling_strategy,
                    )
            else:
                anchors, positives = sample_anchor_positive_pairs(
                    text,
                    num_anchors=num_anchors,
                    num_positives=num_positives,
                    max_span_len=max_span_len,
                    min_span_len=min_span_len,
                    sampling_strategy=sampling_strategy,
                )
                assert len(anchors) == num_anchors
                assert len(positives) == num_anchors * num_positives
                for i, anchor in enumerate(anchors):
                    # Several simple checks for valid anchors.
                    anchor_tokens = self.tokenize(anchor)
                    anchor_length = len(anchor_tokens)
                    assert anchor_length <= max_span_len
                    assert anchor_length >= min_span_len
                    # The tokenization process may lead to certain characters (such as escape
                    # characters) being dropped, so repeat the tokenization process before
                    # performing this check (otherwise a bunch of tests fail).
                    assert anchor in " ".join(tokens)
                    for j in range(i * num_positives, i * num_positives + num_positives):
                        # Several simple checks for valid positives.
                        positive = positives[j]
                        positive_tokens = self.tokenize(positive)
                        positive_length = len(positive_tokens)
                        assert positive_length <= max_span_len
                        assert positive_length >= min_span_len
                        assert positive in " ".join(tokens)
                        # Test that specific sampling strategies are obeyed.
                        if sampling_strategy == "subsuming":
                            assert positive in " ".join(anchor_tokens)
                        elif sampling_strategy == "adjacent":
                            assert positive not in " ".join(anchor_tokens)
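The index arithmetic in the inner loop reflects the flattened, anchor-major layout of the positives: anchor i owns the contiguous slice positives[i * num_positives : (i + 1) * num_positives]. For example:

    num_anchors, num_positives = 2, 3
    positives = ["a0p0", "a0p1", "a0p2", "a1p0", "a1p1", "a1p2"]
    for i in range(num_anchors):
        print(i, positives[i * num_positives : i * num_positives + num_positives])
    # 0 ['a0p0', 'a0p1', 'a0p2']
    # 1 ['a1p0', 'a1p1', 'a1p2']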
Example #6
    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        """
        # Parameters

        text : `str`, required.
            The text to process.

        # Returns

        An `Instance` containing the following fields:
            - anchors (`Union[TextField, ListField[TextField]]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each anchor span sampled from `text`. Otherwise, this will be a `TextField` object
                containing the tokenized `text`.
            - positives (`ListField[TextField]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each positive span sampled from `text`. Otherwise this field will not be included
                in the returned `Instance`.
        """
        # Some very minimal preprocessing to remove whitespace, newlines and tabs.
        # We perform it here as it will cover both training and predicting with the model.
        # We DON'T lowercase by default, but rather allow `self._tokenizer` to decide.
        text = sanitize_text(text, lowercase=False)

        fields: Dict[str, Field] = {}
        if self.sample_spans:
            if isinstance(self._tokenizer, PretrainedTransformerTokenizer):
                # We add a space in front of the text in order to achieve consistent
                # tokenization with certain tokenizers, e.g. the BPE tokenizer used by
                # RoBERTa, GPT and others.
                # See: https://github.com/huggingface/transformers/issues/1196
                text = f" {text.lstrip()}"
                tokenization_func = self._tokenizer.tokenizer.tokenize
                # A call to the `tokenize` method of the AllenNLP tokenizer causes
                # subsequent calls to the underlying HuggingFace Tokenizer (if `use_fast`)
                # to truncate text. Reset the truncation each time here.
                # Note this only appears to happen for transformers<3.1
                if self._tokenizer.tokenizer.is_fast:
                    self._tokenizer.tokenizer._tokenizer.no_truncation()
            else:
                tokenization_func = None
            # Choose the anchor/positives at random.
            anchor_spans, positive_spans = sample_anchor_positive_pairs(
                text=text,
                num_anchors=self._num_anchors,
                num_positives=self._num_positives,
                max_span_len=self._max_span_len,
                min_span_len=self._min_span_len,
                sampling_strategy=self._sampling_strategy,
                tokenizer=tokenization_func,
            )

            anchors: List[Field] = []
            for span in anchor_spans:
                # Sampled spans have already been tokenized and joined by whitespace, so
                # we need to convert them back to a string before using the AllenNLP
                # tokenizer. It would be simpler to use convert_tokens_to_string, but we
                # can't guarantee that method is implemented for all HuggingFace tokenizers.
                anchor_text = self._tokenizer.tokenizer.decode(
                    self._tokenizer.tokenizer.convert_tokens_to_ids(
                        span.split()))
                tokens = self._tokenizer.tokenize(anchor_text)
                anchors.append(TextField(tokens, self._token_indexers))
            fields["anchors"] = ListField(anchors)
            positives: List[Field] = []
            for span in positive_spans:
                positive_text = self._tokenizer.tokenizer.decode(
                    self._tokenizer.tokenizer.convert_tokens_to_ids(
                        span.split()))
                tokens = self._tokenizer.tokenize(positive_text)
                positives.append(TextField(tokens, self._token_indexers))
            fields["positives"] = ListField(positives)
        else:
            tokens = self._tokenizer.tokenize(text)
            fields["anchors"] = TextField(tokens, self._token_indexers)
        return Instance(fields)
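The decode(convert_tokens_to_ids(...)) round-trip above can be reproduced with a plain HuggingFace tokenizer (requires the transformers package; "roberta-base" is just an example checkpoint):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    span_tokens = tokenizer.tokenize(" the quick brown fox")  # note the leading space
    ids = tokenizer.convert_tokens_to_ids(span_tokens)
    print(span_tokens)            # ['Ġthe', 'Ġquick', 'Ġbrown', 'Ġfox']
    print(tokenizer.decode(ids))  # " the quick brown fox"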