Example #1
0
 def postprocess_output_spans(self, input_centers, output_spans,
                              input_dirs):
     """Clean up decoded spans in place: verbalize URLs and map Greek letters.

     Only runs when ``self.lang`` is English; other languages pass through.
     """
     written_greek = set(constants.EN_GREEK_TO_SPOKEN)
     spoken_greek = set(constants.EN_GREEK_TO_SPOKEN.values())
     # Character/token verbalizations applied to decoded URL text.
     url_rewrites = (
         ('http', ' h t t p '),
         ('/', ' slash '),
         ('.', ' dot '),
         (':', ' colon '),
         ('-', ' dash '),
         ('_', ' underscore '),
         ('%', ' percent '),
         ('www', ' w w w '),
         ('ftp', ' f t p '),
     )
     for ix, (span_in, span_out) in enumerate(zip(input_centers,
                                                  output_spans)):
         if self.lang != constants.ENGLISH:
             continue
         if is_url(span_in):
             for old, new in url_rewrites:
                 span_out = span_out.replace(old, new)
             # Re-segment the concatenated URL text into words.
             output_spans[ix] = ' '.join(wordninja.split(span_out))
             continue
         direction = input_dirs[ix]
         if span_in in written_greek and direction == constants.INST_FORWARD:
             output_spans[ix] = constants.EN_GREEK_TO_SPOKEN[span_in]
         if span_in in spoken_greek:
             if direction == constants.INST_FORWARD:
                 output_spans[ix] = span_in
             if direction == constants.INST_BACKWARD:
                 output_spans[ix] = constants.EN_SPOKEN_TO_GREEK[span_in]
     return output_spans
Example #2
0
 def postprocess_output_spans(self, input_centers, output_spans,
                              input_dirs):
     """Post-process decoded spans in place: re-segment URLs, map Greek letters."""
     spoken_greek = set(constants.GREEK_TO_SPOKEN.values())
     for ix, (span_in, span_out) in enumerate(zip(input_centers,
                                                  output_spans)):
         if is_url(span_in):
             # Re-segment concatenated URL text into words.
             output_spans[ix] = ' '.join(wordninja.split(span_out))
         elif span_in in spoken_greek:
             direction = input_dirs[ix]
             if direction == constants.INST_FORWARD:
                 output_spans[ix] = span_in
             if direction == constants.INST_BACKWARD:
                 output_spans[ix] = constants.SPOKEN_TO_GREEK[span_in]
     return output_spans
    def postprocess_output_spans(self, input_centers: List[str],
                                 generated_spans: List[str],
                                 input_dirs: List[str]):
        """
        Post processing of the generated texts

        Args:
            input_centers: Input str (no special tokens or context)
            generated_spans: Generated spans
            input_dirs: task direction: constants.INST_BACKWARD or constants.INST_FORWARD

        Returns:
            Processing texts (the ``generated_spans`` list, mutated in place)
        """
        written_greek = set(constants.EN_GREEK_TO_SPOKEN)
        spoken_greek = set(constants.EN_GREEK_TO_SPOKEN.values())
        # Verbalizations applied to decoded URL text before re-segmentation.
        url_rewrites = (('http', ' h t t p '), ('/', ' slash '),
                        ('.', ' dot '), (':', ' colon '), ('-', ' dash '),
                        ('_', ' underscore '), ('%', ' percent '),
                        ('www', ' w w w '), ('ftp', ' f t p '))
        for ix, (span_in, span_out) in enumerate(
                zip(input_centers, generated_spans)):
            # Only English spans get special handling.
            if self.lang != constants.ENGLISH:
                continue
            if is_url(span_in):
                for old, new in url_rewrites:
                    span_out = span_out.replace(old, new)
                generated_spans[ix] = ' '.join(wordninja.split(span_out))
                continue
            direction = input_dirs[ix]
            if span_in in written_greek and direction == constants.INST_FORWARD:
                generated_spans[ix] = constants.EN_GREEK_TO_SPOKEN[span_in]
            if span_in in spoken_greek:
                if direction == constants.INST_FORWARD:
                    generated_spans[ix] = span_in
                if direction == constants.INST_BACKWARD:
                    generated_spans[ix] = constants.EN_SPOKEN_TO_GREEK[span_in]
        return generated_spans
    def _infer(
        self,
        sents: List[List[str]],
        nb_spans: List[int],
        span_starts: List[List[int]],
        span_ends: List[List[int]],
        inst_directions: List[str],
    ):
        """ Main function for Inference
        Args:
            sents: A list of inputs tokenized by a basic tokenizer.
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

        Returns: A list of lists where each list contains the decoded spans for the corresponding input.
        """
        self.eval()

        if sum(nb_spans) == 0:
            # One fresh list per input: `[[]] * len(sents)` would alias a
            # single shared list across every entry, so a caller mutating
            # one result would corrupt all of them.
            return [[] for _ in sents]
        model, tokenizer = self.model, self._tokenizer
        try:
            model_max_len = model.config.n_positions
        except AttributeError:
            # Not every decoder config exposes n_positions; use a safe default.
            model_max_len = 512
        ctx_size = constants.DECODE_CTX_SIZE
        extra_id_0 = constants.EXTRA_ID_0
        extra_id_1 = constants.EXTRA_ID_1

        # Build all_inputs - extracted spans to be transformed by the decoder model.
        # Inputs for the TN direction get the TN prefix; the backward, ITN
        # direction gets the ITN prefix.
        # "input_centers" - List[str] - ground-truth labels for the span #TODO: rename
        input_centers, input_dirs, all_inputs = [], [], []
        for ix, sent in enumerate(sents):
            cur_inputs = []
            for jx in range(nb_spans[ix]):
                cur_start = span_starts[ix][jx]
                cur_end = span_ends[ix][jx]
                # Fixed-size token context on either side of the span.
                ctx_left = sent[max(0, cur_start - ctx_size):cur_start]
                ctx_right = sent[cur_end + 1:cur_end + 1 + ctx_size]
                span_words = sent[cur_start:cur_end + 1]
                span_words_str = ' '.join(span_words)
                if is_url(span_words_str):
                    span_words_str = span_words_str.lower()
                input_centers.append(span_words_str)
                input_dirs.append(inst_directions[ix])
                # Build cur_inputs: direction prefix + left ctx + delimited span + right ctx
                if inst_directions[ix] == constants.INST_BACKWARD:
                    cur_inputs = [constants.ITN_PREFIX]
                if inst_directions[ix] == constants.INST_FORWARD:
                    cur_inputs = [constants.TN_PREFIX]
                cur_inputs += ctx_left
                cur_inputs += [extra_id_0
                               ] + span_words_str.split(' ') + [extra_id_1]
                cur_inputs += ctx_right
                all_inputs.append(' '.join(cur_inputs))

        # Apply the decoding model
        batch = tokenizer(all_inputs, padding=True, return_tensors='pt')
        input_ids = batch['input_ids'].to(self.device)

        generated_texts, generated_ids, sequence_toks_scores = self._generate_predictions(
            input_ids=input_ids, model_max_len=model_max_len)

        # Use covering grammars (if enabled)
        if self.use_cg:
            # Compute sequence probabilities as the product of per-token
            # probabilities of the selected ids.
            sequence_probs = torch.ones(len(all_inputs)).to(self.device)
            for ix, cur_toks_scores in enumerate(sequence_toks_scores):
                cur_generated_ids = generated_ids[:, ix + 1].tolist()
                cur_toks_probs = torch.nn.functional.softmax(cur_toks_scores,
                                                             dim=-1)
                # Compute selected_toks_probs
                selected_toks_probs = []
                for jx, _id in enumerate(cur_generated_ids):
                    if _id != self._tokenizer.pad_token_id:
                        selected_toks_probs.append(cur_toks_probs[jx, _id])
                    else:
                        # Padding contributes probability 1 (no-op factor).
                        selected_toks_probs.append(1)
                selected_toks_probs = torch.tensor(selected_toks_probs).to(
                    self.device)
                sequence_probs *= selected_toks_probs

            # For TN cases where the neural model is not confident, use CGs
            neural_confidence_threshold = self.neural_confidence_threshold
            for ix, (_dir, _input, _prob) in enumerate(
                    zip(input_dirs, input_centers, sequence_probs)):
                if _dir == constants.INST_FORWARD and _prob < neural_confidence_threshold:
                    if is_url(_input):
                        _input = _input.replace(' ',
                                                '')  # Remove spaces in URLs
                    try:
                        cg_outputs = self.cg_normalizer.normalize(
                            text=_input, verbose=False, n_tagged=self.n_tagged)
                        generated_texts[ix] = list(cg_outputs)[0]
                    except Exception:
                        # Deliberate best-effort: if CG normalization fails
                        # for any reason, fall back to the raw input span.
                        # (Narrowed from a bare `except:` so that
                        # KeyboardInterrupt/SystemExit still propagate.)
                        generated_texts[ix] = _input

        # Post processing
        generated_texts = self.postprocess_output_spans(
            input_centers, generated_texts, input_dirs)

        # Prepare final_texts: regroup the flat decoded spans per input sentence.
        final_texts, span_ctx = [], 0
        for nb_span in nb_spans:
            cur_texts = []
            for i in range(nb_span):
                cur_texts.append(generated_texts[span_ctx])
                span_ctx += 1
            final_texts.append(cur_texts)

        return final_texts
Example #5
0
    def _infer(
        self,
        sents: List[List[str]],
        nb_spans: List[int],
        span_starts: List[List[int]],
        span_ends: List[List[int]],
        inst_directions: List[str],
    ):
        """ Main function for Inference
        Args:
            sents: A list of inputs tokenized by a basic tokenizer (e.g., using nltk.word_tokenize()).
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

        Returns: A list of lists where each list contains the decoded spans for the corresponding input.
        """
        self.eval()

        if sum(nb_spans) == 0:
            # One fresh list per input: `[[]] * len(sents)` would alias a
            # single shared list across every entry, so a caller mutating
            # one result would corrupt all of them.
            return [[] for _ in sents]
        model, tokenizer = self.model, self._tokenizer
        try:
            model_max_len = model.config.n_positions
        except AttributeError:
            # Not every decoder config exposes n_positions; use a safe default.
            model_max_len = 512
        ctx_size = constants.DECODE_CTX_SIZE
        extra_id_0 = constants.EXTRA_ID_0
        extra_id_1 = constants.EXTRA_ID_1

        # Build all_inputs: one decoder input per semiotic span, consisting of
        # direction prefix + left context + delimited span + right context.
        input_centers, input_dirs, all_inputs = [], [], []
        for ix, sent in enumerate(sents):
            cur_inputs = []
            for jx in range(nb_spans[ix]):
                cur_start = span_starts[ix][jx]
                cur_end = span_ends[ix][jx]
                ctx_left = sent[max(0, cur_start - ctx_size):cur_start]
                ctx_right = sent[cur_end + 1:cur_end + 1 + ctx_size]
                span_words = sent[cur_start:cur_end + 1]
                span_words_str = ' '.join(span_words)
                if is_url(span_words_str):
                    span_words_str = span_words_str.lower()
                input_centers.append(span_words_str)
                input_dirs.append(inst_directions[ix])
                # Build cur_inputs
                if inst_directions[ix] == constants.INST_BACKWARD:
                    cur_inputs = [constants.ITN_PREFIX]
                if inst_directions[ix] == constants.INST_FORWARD:
                    cur_inputs = [constants.TN_PREFIX]
                cur_inputs += ctx_left
                cur_inputs += [extra_id_0
                               ] + span_words_str.split(' ') + [extra_id_1]
                cur_inputs += ctx_right
                all_inputs.append(' '.join(cur_inputs))

        # Apply the decoding model
        batch = tokenizer(all_inputs, padding=True, return_tensors='pt')
        input_ids = batch['input_ids'].to(self.device)
        generated_ids = model.generate(input_ids, max_length=model_max_len)
        generated_texts = tokenizer.batch_decode(generated_ids,
                                                 skip_special_tokens=True)

        # Post processing
        generated_texts = self.postprocess_output_spans(
            input_centers, generated_texts, input_dirs)

        # Prepare final_texts: regroup the flat decoded spans per input sentence.
        final_texts, span_ctx = [], 0
        for nb_span in nb_spans:
            cur_texts = []
            for i in range(nb_span):
                cur_texts.append(generated_texts[span_ctx])
                span_ctx += 1
            final_texts.append(cur_texts)

        return final_texts