Example #1
    def calc_f1_batch(self, decoded_data, target_data):
        """
        update statistics for the F1 score

        args:
            decoded_data (seq_len, batch_size): prediction sequences
            target_data (batch_size, seq_len): ground-truth label sequences
        """
        batch_decoded = torch.unbind(decoded_data, 1)
        batch_targets = torch.unbind(target_data, 0)

        for decoded, target in zip(batch_decoded, batch_targets):
            gold = self.packer.convert_for_eval(target)

            # remove padding
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length]
            best_path = decoded[:length]

            # filter out removed labels
            best_path_filtered, gold_filtered = self.label_filter(
                best_path.numpy(), gold.numpy())
            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(
                best_path_filtered, gold_filtered)
            self.correct_labels += correct_labels_i
            self.total_labels += total_labels_i
            self.gold_count += gold_count_i
            self.guess_count += guess_count_i
            self.overlap_count += overlap_count_i
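
The five counters updated above are produced by self.eval_instance, whose implementation is not part of this listing. By convention such a helper returns token-level agreement plus gold/guess/overlap chunk counts extracted from BIO labels. The sketch below is only an assumption about that contract, not the repository's code: eval_instance_sketch and to_chunks are hypothetical names, and it assumes r_l_map maps label indices to strings such as 'B-PER'.

import numpy as np

def eval_instance_sketch(best_path, gold, r_l_map):
    # token-level agreement between prediction and gold
    total_labels = len(gold)
    correct_labels = int(np.sum(np.equal(best_path, gold)))

    def to_chunks(label_ids):
        # collect (start, end, type) spans from BIO-style label strings
        chunks, start, ctype = set(), None, None
        for i, idx in enumerate(list(label_ids) + [None]):
            tag = r_l_map[idx] if idx is not None else 'O'
            if start is not None and not tag.startswith('I'):
                chunks.add((start, i, ctype))
                start, ctype = None, None
            if tag.startswith('B'):
                start, ctype = i, tag.split('-', 1)[1]
        return chunks

    gold_chunks = to_chunks(gold)
    guess_chunks = to_chunks(best_path)
    return (correct_labels, total_labels, len(gold_chunks),
            len(guess_chunks), len(gold_chunks & guess_chunks))
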
Example #2
    def check_output(self,
                     ner_model,
                     dataset_loader,
                     out,
                     f_map,
                     emb,
                     word_to_id,
                     gpu,
                     knowledge_dict,
                     no_dict=False):
        """
        decode the dataset, write 'character gold prediction' triples to
        model_output.txt, and return the scores from calc_s()
        """
        ner_model.eval()
        self.reset()
        f = open('model_output.txt', 'w')
        for f_f, f_p, b_f, b_p, w_f, tg, mask_v, len_v in itertools.chain.from_iterable(
                dataset_loader):
            mask_v = mask_v.bool()

            f_f, f_p, b_f, b_p, w_f, _, mask_v = self.packer.repack_vb(
                f_f, f_p, b_f, b_p, w_f, tg, mask_v, len_v)
            w_f_word = utils.reconstruct_word_input(w_f, f_map, emb,
                                                    word_to_id, gpu)
            prior_prob = utils.generate_prior_prob(self.r_c_map, self.l_map,
                                                   f_f, knowledge_dict)
            scores = ner_model(f_f, f_p, b_f, b_p, w_f, w_f_word, prior_prob)
            decoded = self.decoder.decode(scores.data, mask_v.data, prior_prob,
                                          no_dict)

            self.eval_b(decoded, tg)
            batch_decoded = torch.unbind(decoded, 1)
            batch_targets = torch.unbind(tg, 0)
            batch_f_f = torch.unbind(f_f, 1)

            for decoded, target, character in zip(batch_decoded, batch_targets,
                                                  batch_f_f):

                gold = self.packer.convert_for_eval(target)
                # remove padding
                length = utils.find_length_from_labels(gold, self.l_map)
                gold = gold[:length].numpy()
                best_path = decoded[:length].numpy()
                # index 42 appears to be the padding symbol and is skipped
                character_filtered = [
                    c for c in character.cpu().numpy() if c != 42
                ]
                char = character_filtered[:length]
                for i in range(len(gold)):

                    f.write(self.r_c_map[char[i]] + ' ' +
                            self.r_l_map[gold[i]] + ' ' +
                            self.r_l_map[best_path[i]])
                    f.write('\n')

                f.write('\n')
        f.close()
        return self.calc_s()
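
check_output writes model_output.txt with one token per line in the form 'character gold_label predicted_label' and a blank line between sentences, as the write calls above show. A minimal sketch of reading that format back for later analysis; only the file layout is taken from the code, the reader itself is illustrative.

def read_model_output(path='model_output.txt'):
    # returns a list of sentences, each a list of (char, gold, pred) triples
    sentences, current = [], []
    with open(path) as fh:
        for line in fh:
            line = line.rstrip('\n')
            if not line:
                # blank line ends the current sentence
                if current:
                    sentences.append(current)
                    current = []
                continue
            char, gold, pred = line.split(' ')
            current.append((char, gold, pred))
    if current:
        sentences.append(current)
    return sentences
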
Example #3
    def calc_predict(self, ner_model, dataset_loader, test_features, file_out,
                     file_out_2, f_map):
        """
        decode the test set and write per-token predictions plus assembled
        'word_POS' sentences to the two output files

        args:
            ner_model: LSTM-CRF model
            dataset_loader: loader class for the test set
            test_features: original tokens of the test sentences
            file_out: receives one 'word predicted_label' pair per line
            file_out_2: receives one segmented 'word_POS ...' sentence per line
            f_map: word-to-index map used to build the reverse lookup
        """
        ner_model.eval()
        self.reset()
        idx2label = {v: k for k, v in self.l_map.items()}
        idx2word = {v: k for k, v in f_map.items()}
        for i in range(len(dataset_loader[0])):
            fea_v, tg_v, mask_v = self.packer.repack_vb(
                np.asarray(dataset_loader[0][i]),
                np.asarray(dataset_loader[1][i]),
                np.asarray(dataset_loader[2][i]))
            ner_model.zero_grad()
            scores, hidden = ner_model(fea_v, dataset_loader[3][i])
            decoded = self.decoder.decode(scores.data, mask_v.data)
            gold = [d % len(self.l_map) for d in dataset_loader[1][i]]
            # words = [idx2word[w] for w in dataset_loader[0][i]]
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length]
            words = test_features[i][:length]
            best_path = decoded.squeeze(1).tolist()[:length]
            gold = [idx2label[g] for g in gold]
            best_path = [idx2label[g] for g in best_path]
            for j in range(length):
                file_out.write("%s %s\n" % (words[j], best_path[j]))
            file_out.write("\n")

            sent = ''
            pos = None
            word = ''
            for j in range(length):
                if best_path[j].startswith('B'):
                    if pos is not None:
                        sent += word + '_' + pos + ' '
                        word = ''
                        pos = None
                    word += words[j]
                    pos = best_path[j].split('-')[1]
                else:
                    assert pos is not None
                    word += words[j]
            if len(word) > 0:
                sent += word + '_' + pos + ' '
            file_out_2.write("%s\n" % (sent))
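
The loop above converts BIO-labelled tokens into space-separated 'word_POS' units: a B- tag opens a new word and carries the POS after the hyphen, and each following token is concatenated onto it. A toy illustration of the same merging logic, with made-up tokens and labels:

def assemble(words, labels):
    # merge BIO-labelled tokens into 'word_POS' units, as in calc_predict above
    sent, word, pos = '', '', None
    for w, lab in zip(words, labels):
        if lab.startswith('B'):
            if pos is not None:
                sent += word + '_' + pos + ' '
            word, pos = w, lab.split('-')[1]
        else:
            word += w
    if word:
        sent += word + '_' + pos + ' '
    return sent.strip()

# hypothetical input: a two-syllable word followed by a one-syllable word
print(assemble(['sinh', 'vien', 'hoc'], ['B-N', 'I-N', 'B-V']))  # 'sinhvien_N hoc_V'
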
Example #4
    def calc_f1_batch(self, target_data, decoded_data_crfs, decode_data_scrfs,
                      decode_data_jnts):
        """
        update statistics for the F1 score

        args:
            target_data (batch_size, seq_len): ground-truth label sequences
            decoded_data_crfs, decode_data_scrfs, decode_data_jnts:
                prediction sequences from the CRF, SCRF and joint decoders
        """
        for target, decoded_data_crf, decode_data_scrf, decode_data_jnt in zip(
                target_data, decoded_data_crfs, decode_data_scrfs,
                decode_data_jnts):

            length = utils.find_length_from_labels(target, self.l_map)
            gold = target[:length]
            decoded_data_crf = decoded_data_crf[:length]
            decode_data_scrf = decode_data_scrf[:length]
            decode_data_jnt = decode_data_jnt[:length]

            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(
                decoded_data_crf, gold)
            self.correct_labels_crf += correct_labels_i
            self.total_labels_crf += total_labels_i
            self.gold_count_crf += gold_count_i
            self.guess_count_crf += guess_count_i
            self.overlap_count_crf += overlap_count_i

            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(
                decode_data_scrf, gold)
            self.correct_labels_scrf += correct_labels_i
            self.total_labels_scrf += total_labels_i
            self.gold_count_scrf += gold_count_i
            self.guess_count_scrf += guess_count_i
            self.overlap_count_scrf += overlap_count_i

            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(
                decode_data_jnt, gold)
            self.correct_labels_jnt += correct_labels_i
            self.total_labels_jnt += total_labels_i
            self.gold_count_jnt += gold_count_i
            self.guess_count_jnt += guess_count_i
            self.overlap_count_jnt += overlap_count_i
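
These per-branch counters are presumably turned into precision, recall and F1 by the class's scoring method (a calc_s-style routine), which is not shown here. The conventional final step is sketched below, applied to one triple of counters at a time; the function name is hypothetical.

def f1_from_counts(gold_count, guess_count, overlap_count):
    # chunk-level precision/recall/F1 with guards against empty counts
    if gold_count == 0 or guess_count == 0 or overlap_count == 0:
        return 0.0, 0.0, 0.0
    precision = overlap_count / guess_count
    recall = overlap_count / gold_count
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# e.g. one call per branch:
# f1_from_counts(self.gold_count_crf, self.guess_count_crf, self.overlap_count_crf)
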
Example #5
    def calc_acc_batch(self, decoded_data, target_data):
        """
        update statistics for accuracy

        args:
            decoded_data (seq_len, batch_size): prediction sequences
            target_data (batch_size, seq_len): ground-truth label sequences
        """
        batch_decoded = torch.unbind(decoded_data, 1)
        batch_targets = torch.unbind(target_data, 0)

        for decoded, target in zip(batch_decoded, batch_targets):
            gold = self.packer.convert_for_eval(target)
            # remove padding
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length].numpy()
            best_path = decoded[:length].numpy()

            self.total_labels += length
            self.correct_labels += np.sum(np.equal(best_path, gold))
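
Note the asymmetry between the two unbind calls: predictions are split along dim 1 and targets along dim 0, which implies decoded_data is laid out as (seq_len, batch_size) while target_data is (batch_size, seq_len). A small self-contained check of those assumed shapes (padding and the convert_for_eval step are ignored here):

import torch

seq_len, batch_size = 5, 2
decoded_data = torch.zeros(seq_len, batch_size, dtype=torch.long)  # assumed layout
target_data = torch.zeros(batch_size, seq_len, dtype=torch.long)   # as documented

# both unbinds then yield one length-seq_len sequence per batch element
assert len(torch.unbind(decoded_data, 1)) == batch_size
assert len(torch.unbind(target_data, 0)) == batch_size
assert torch.unbind(decoded_data, 1)[0].shape == (seq_len,)
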
Example #6
    def calc_acc_batch(self, decoded_data, target_data):
        """
        update statistics for accuracy

        args:
            decoded_data (seq_len, batch_size): prediction sequences
            target_data (batch_size, seq_len): ground-truth label sequences
        """
        batch_decoded = torch.unbind(decoded_data, 1)
        batch_targets = torch.unbind(target_data, 0)

        for decoded, target in zip(batch_decoded, batch_targets):
            gold = self.packer.convert_for_eval(target)
            # remove padding
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length].numpy()
            best_path = decoded[:length].numpy()

            self.total_labels += length
            self.correct_labels += np.sum(np.equal(best_path, gold))
Example #7
    def calc_f1_batch(self, decoded_data, target_data):
        """
        update statistics for the F1 score

        args:
            decoded_data (seq_len, batch_size): prediction sequences
            target_data (batch_size, seq_len): ground-truth label sequences
        """
        batch_decoded = torch.unbind(decoded_data, 1)
        batch_targets = torch.unbind(target_data, 0)

        for decoded, target in zip(batch_decoded, batch_targets):
            gold = self.packer.convert_for_eval(target)
            # remove padding
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length]
            best_path = decoded[:length]

            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(best_path.numpy(), gold.numpy())
            self.correct_labels += correct_labels_i
            self.total_labels += total_labels_i
            self.gold_count += gold_count_i
            self.guess_count += guess_count_i
            self.overlap_count += overlap_count_i