コード例 #1
0
    def writeLabels(self,
                    count,
                    batch_size,
                    pred_ids,
                    batch,
                    all_probabilities,
                    all_idxs,
                    error_probs,
                    max_len=50):
        """Apply predicted edits to a batch of sentences, logging the raw
        predicted label indices to ``outputFiles/fce-predictions.txt``.

        Args:
            count: zero-based batch counter, combined with ``batch_size`` and
                the per-sentence ``pred_id`` to form a global sentence id.
            batch_size: number of sentences per batch.
            pred_ids: per-sentence positions used for the global id.
            batch: list of token lists (source sentences).
            all_probabilities: per-sentence, per-position probabilities of the
                chosen labels.
            all_idxs: per-sentence, per-position predicted label indices.
            error_probs: per-sentence probability that the sentence is erroneous.
            max_len: maximum number of source tokens considered per sentence.

        Returns:
            List of token lists with the predicted edits applied (unchanged
            sentences are passed through as-is).
        """
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")

        for tokens, probabilities, idxs, error_prob, pred_id in zip(
                batch, all_probabilities, all_idxs, error_probs, pred_ids):
            length = min(len(tokens), max_len)
            edits = []

            # Trace output: "<global sentence id>: idx0,idx1,...".
            # Use a context manager so the handle is always closed — the
            # original opened a new file object per sentence and never
            # closed any of them (file-descriptor leak).
            with open("outputFiles/fce-predictions.txt", "a") as fa:
                fa.write(str(count * batch_size + pred_id) + ": ")
                for idx in idxs[:length + 1]:
                    fa.write(str(idx) + ",")
                fa.write("\n")

            # Skip the whole sentence if no edits were predicted
            # (index 0 everywhere means "no change").
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Skip the whole sentence if the model is not confident enough
            # that it actually contains an error.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # Position 0 corresponds to the prepended START token;
                # later positions shift into the source tokens by one.
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # $KEEP: nothing to do at this position.
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(
                    idxs[i], namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results
コード例 #2
0
    def postprocess_batch(self,
                          batch,
                          all_probabilities,
                          all_idxs,
                          error_probs,
                          max_len=50):
        """Apply predicted edits to a batch of sentences and collect
        (suggested label, edit position) pairs for each applied edit.

        Args:
            batch: list of token lists (source sentences).
            all_probabilities: per-sentence, per-position probabilities of the
                chosen labels.
            all_idxs: per-sentence, per-position predicted label indices.
            error_probs: per-sentence probability that the sentence is erroneous.
            max_len: maximum number of source tokens considered per sentence.

        Returns:
            Tuple ``(all_results, passed_info)`` where ``all_results`` is the
            list of corrected token lists and ``passed_info`` is a flat list of
            ``(suggested_label, edit_position)`` tuples across the batch.
        """
        all_results = []
        passed_info = []  # (suggested label, edit position) per applied edit

        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(
                batch, all_probabilities, all_idxs, error_probs):
            length = min(len(tokens), max_len)
            edits = []

            # Skip the whole sentence if no edits were predicted
            # (index 0 everywhere means "no change").
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Skip the whole sentence if the model is not confident enough
            # that it actually contains an error.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # Position 0 corresponds to the prepended START token.
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # $KEEP: nothing to do at this position.
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(
                    idxs[i], namespace='labels')

                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)

                # BUGFIX: guard BEFORE touching `action`. The original
                # indexed `action[0]` first, which raises TypeError whenever
                # get_token_action returns a falsy value (no applicable edit).
                if not action:
                    continue

                # Record what was suggested and where, for the caller.
                passed_info.append((sugg_token, action[0]))

                edits.append(action)

            all_results.append(get_target_sent_by_edits(tokens, edits))

        return all_results, passed_info
コード例 #3
0
ファイル: gec_model.py プロジェクト: rajaswa/gector
    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs,
                          max_len=50):
        """Apply predicted edits to a batch of sentences.

        Returns a pair ``(all_results, batch_output_tokens)``: the corrected
        token lists and, for every sentence, the space-joined string of
        predicted label tokens (positions 0..length, including the START slot),
        recorded whether or not the sentence was actually edited.
        """
        all_results = []
        batch_output_tokens = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")

        for tokens, probabilities, idxs, error_prob in zip(
                batch, all_probabilities, all_idxs, error_probs):
            length = min(len(tokens), max_len)

            # Always record the raw predicted labels for this sentence,
            # even when the edits below end up being skipped entirely.
            labels = [
                self.vocab.get_token_from_index(idxs[pos], namespace='labels')
                for pos in range(length + 1)
            ]
            batch_output_tokens.append(" ".join(labels))

            # No predicted edits at all -> pass the sentence through.
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Not confident enough the sentence is erroneous -> pass through.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            edits = []
            for pos in range(length + 1):
                # $KEEP: nothing to change at this position.
                if idxs[pos] == noop_index:
                    continue

                # Position 0 maps to the START token; others shift by one.
                token = START_TOKEN if pos == 0 else tokens[pos - 1]
                sugg_token = self.vocab.get_token_from_index(
                    idxs[pos], namespace='labels')
                action = self.get_token_action(token, pos, probabilities[pos],
                                               sugg_token)
                if action:
                    edits.append(action)

            all_results.append(get_target_sent_by_edits(tokens, edits))

        assert (len(all_results) == len(batch_output_tokens))
        return all_results, batch_output_tokens
コード例 #4
0
ファイル: gec_model.py プロジェクト: zhangbo2008/pachong2
    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs,
                          max_len=50):
        """Apply predicted edits to a batch of sentences.

        Args:
            batch: list of token lists (source sentences).
            all_probabilities: per-sentence, per-position probabilities of the
                chosen labels.
            all_idxs: per-sentence, per-position predicted label indices.
            error_probs: per-sentence probability that the sentence is erroneous.
            max_len: maximum number of source tokens considered per sentence.

        Returns:
            List of token lists with the predicted edits applied.
        """
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch,
                                                           all_probabilities,
                                                           all_idxs,
                                                           error_probs):
            length = min(len(tokens), max_len)
            edits = []

            # Skip the whole sentence if no edits were predicted
            # (index 0 everywhere means "no change").
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Skip the whole sentence if the model is not confident enough
            # that it actually contains an error.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # BUGFIX: position 0 is the prepended START token. The
                # previous code did `token = tokens[i - 1]` unconditionally,
                # which at i == 0 silently read tokens[-1] (the LAST source
                # token) instead of the START token.
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # $KEEP: nothing to do at this position.
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                             namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results