Example #1
    def predict(self, seq_lengths, word_ids, char_for_ids, char_rev_ids,
                word_lengths, cap_ids, pos_ids):
        feed_dict = dict()
        feed_dict[self.seq_lengths] = seq_lengths
        # Disable dropout at prediction time.
        feed_dict[self.dropout_keep_prob] = 1
        # Optional inputs are fed only when the corresponding feature is
        # enabled in the model configuration.
        if self.word_dim:
            feed_dict[self.word_ids] = word_ids
        if self.char_dim:
            feed_dict[self.char_for_ids] = char_for_ids
            feed_dict[self.word_lengths] = word_lengths
            if self.char_bidirect:
                feed_dict[self.char_rev_ids] = char_rev_ids
        if self.cap_dim:
            feed_dict[self.cap_ids] = cap_ids
        if self.pos_dim:
            feed_dict[self.pos_ids] = pos_ids
        # Run the network to get per-token tag scores and the learned CRF
        # transition matrix.
        tag_scores, transitions = self.session.run(
            [self.tag_scores, self.transitions], feed_dict)
        batch_viterbi_sequence = []
        batch_score = []
        for tag_score_, seq_length_ in zip(tag_scores, seq_lengths):
            # Remove padding from scores and tag sequence.
            tag_score_ = tag_score_[:seq_length_]

            # Compute the highest scoring sequence.
            viterbi_sequence, score = crf.viterbi_decode(
                tag_score_, transitions)
            batch_viterbi_sequence.append(viterbi_sequence)
            batch_score.append(score)

        return batch_viterbi_sequence, batch_score
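
For reference, crf.viterbi_decode here is tf.contrib.crf.viterbi_decode from TensorFlow 1.x: it runs outside the graph on plain NumPy arrays, taking a [seq_len, num_tags] score matrix and a [num_tags, num_tags] transition matrix and returning the highest-scoring tag sequence together with its score. A minimal sketch with made-up scores for two tags:

import numpy as np
from tensorflow.contrib import crf

# Hypothetical per-token tag scores for a 3-token sentence over 2 tags.
tag_scores = np.array([[2.0, 0.5],
                       [0.3, 1.7],
                       [1.2, 0.4]])
# Hypothetical transition scores: transitions[i, j] is the score of
# moving from tag i to tag j.
transitions = np.array([[0.1, 0.8],
                        [0.6, 0.2]])

viterbi_sequence, score = crf.viterbi_decode(tag_scores, transitions)
print(viterbi_sequence)  # [0, 1, 0]
print(score)             # 6.3, the unnormalized score of that path
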
Example #2
def run_evaluation(batches, extra_text=''):
    predictions = []
    batches_with_mask = []
    for batch in batches:
        token_batch, label_batch, shape_batch, char_batch, seq_len_batch, tok_len_batch, label_mask_batch = batch
        batch_seq_len, mask_batch, seq_len_batch = mask(batch)
        batches_with_mask.append(batch + (mask_batch, ))

        char_embedding_feed = char_feed(token_batch, char_batch, tok_len_batch)
        lstm_feed = {
            model.input_x1: token_batch,
            model.input_x2: shape_batch,
            model.input_y: label_batch,
            model.input_mask: mask_batch,
            model.sequence_lengths: seq_len_batch,
            model.max_seq_len: batch_seq_len,
            model.batch_size: batch_size,
        }
        lstm_feed.update(char_embedding_feed)

        if viterbi:
            # Fetch per-token unary scores and the learned CRF transition
            # matrix, then decode each sequence with Viterbi.
            run_list = [model.predictions, model.transition_params]
            preds, transition_params = sess.run(run_list, feed_dict=lstm_feed)

            viterbi_repad = np.empty((batch_size, batch_seq_len))
            for i, (unary_scores,
                    sequence_lens) in enumerate(zip(preds, seq_len_batch)):
                viterbi_sequence, _ = crf.viterbi_decode(
                    unary_scores, transition_params)
                viterbi_repad[i] = viterbi_sequence
            predictions.append(viterbi_repad)

        else:
            run_list = [model.predictions, model.unflat_scores]
            preds, scores = sess.run(run_list, feed_dict=lstm_feed)
            # Keep the non-Viterbi predictions so the evaluation below gets
            # one entry per batch, matching the Viterbi branch.
            predictions.append(preds)

    inv_label_map = dp.inv_label_map()

    f1_micro, precision = evaluation.segment_eval(
        batches_with_mask,
        predictions,
        type_set,
        type_int_int_map,
        inv_label_map,
        dp.inv_token_map(),
        outside_idx=map(
            lambda t: type_set[t]
            if t in type_set else type_set['O'], outside_set),
        pad_width=pad_width,
        start_end=False,
        extra_text='Segment evaluation %s:' % extra_text)

    print('')

    return f1_micro, precision
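
In both examples the transition matrix passed to crf.viterbi_decode is a tensor the model learns during training (exposed above as self.transitions and model.transition_params). A minimal sketch, with hypothetical placeholder names, of how such a matrix is typically obtained from tf.contrib.crf.crf_log_likelihood:

import tensorflow as tf
from tensorflow.contrib import crf

num_tags = 5
# Hypothetical inputs: per-token unary scores, gold tag ids, true lengths.
unary_scores = tf.placeholder(tf.float32, [None, None, num_tags])
gold_tags = tf.placeholder(tf.int32, [None, None])
seq_lengths = tf.placeholder(tf.int32, [None])

# crf_log_likelihood returns the per-example log-likelihood and the
# [num_tags, num_tags] transition matrix it creates and trains internally.
log_likelihood, transition_params = crf.crf_log_likelihood(
    unary_scores, gold_tags, seq_lengths)
loss = tf.reduce_mean(-log_likelihood)

# At prediction time, the learned transition_params is fetched with
# session.run() and handed to crf.viterbi_decode(), as in the examples above.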