Beispiel #1
0
def _count_matches(predicted, gold):
    """Return ``(true_pos, false_pos, false_neg)`` counts for two collections.

    An item of ``predicted`` that also occurs in ``gold`` is a true positive,
    otherwise a false positive; an item of ``gold`` missing from ``predicted``
    is a false negative.
    """
    true_pos = sum(1 for item in predicted if item in gold)
    false_pos = sum(1 for item in predicted if item not in gold)
    false_neg = sum(1 for item in gold if item not in predicted)
    return true_pos, false_pos, false_neg


def decode(data_feats, data_tags, data_class, output_path):
    """Run the tagger (and optional sentence classifier) over a dataset.

    Iterates over the data in batches of ``opt.test_batchSize``, computes the
    tagging loss (and the classification loss when ``opt.task_sc`` is set),
    and writes one line per sentence to ``output_path`` in the form
    ``word:gold_tag:pred_tag ... <=> gold_classes <=> pred_classes``
    (prefixed with ``line_num : `` when ``opt.testing`` is set).

    The whole loop runs under ``torch.no_grad()`` because nothing here calls
    ``backward()``; this avoids building autograd graphs during evaluation.

    Relies on module-level state: ``opt``, ``data_reader``, ``model_tag``,
    ``model_class``, the loss functions, the vocabulary maps and ``acc``.

    Args:
        data_feats: input features, indexed per sentence.
        data_tags: gold tag sequences aligned with ``data_feats``.
        data_class: gold sentence classes aligned with ``data_feats``.
        output_path: path of the text file the predictions are written to.

    Returns:
        A tuple ``(mean_losses, p, r, f1, class_f1)`` where ``mean_losses`` is
        the per-batch mean of ``[tag_loss / sum(lens), class_loss / len(lens)]``
        and the remaining values are chunk-level precision, recall and F1 for
        tags and F1 for sentence classes, all expressed in percent.
    """
    data_index = np.arange(len(data_feats))
    losses = []
    TP, FP, FN = 0.0, 0.0, 0.0
    TP2, FP2, FN2 = 0.0, 0.0, 0.0
    with open(output_path, 'w') as out_file, torch.no_grad():
        for j in range(0, len(data_index), opt.test_batchSize):
            batch = data_reader.get_minibatch_with_class(data_feats, data_tags, data_class, word_to_idx, tag_to_idx, class_to_idx, data_index, j, opt.test_batchSize, add_start_end=opt.bos_eos, multiClass=opt.multiClass, keep_order=opt.testing, enc_dec_focus=opt.enc_dec, device=opt.device)
            # With keep_order (= opt.testing) the reader also returns line numbers.
            if opt.testing:
                inputs, tags, raw_tags, classes, raw_classes, lens, line_nums = batch
            else:
                inputs, tags, raw_tags, classes, raw_classes, lens = batch

            # Optional extra features are forwarded to the model only when enabled.
            ext_kwargs = {}
            if opt.word_digit_features:
                word_seqs = [[idx_to_word[w_idx] for w_idx in word_seq] for word_seq in inputs.data.cpu().numpy()]
                ext_kwargs['extFeats'] = feature_extractor.get_digit_features(word_seqs, lens)

            if opt.enc_dec:
                # NOTE(review): greedy decoding is forced here, which makes the
                # beam-search branch below unreachable; kept as in the original.
                opt.greed_decoding = True
                if opt.greed_decoding:
                    tag_scores_1best, outputs_1best, encoder_info = model_tag.decode_greed(inputs, tags[:, 0:1], lens, with_snt_classifier=True, **ext_kwargs)
                    tag_loss = tag_loss_function(tag_scores_1best.contiguous().view(-1, len(tag_to_idx)), tags[:, 1:].contiguous().view(-1))
                    top_pred_slots = outputs_1best.cpu().numpy()
                else:
                    beam_size = 2
                    beam_scores_1best, top_path_slots, encoder_info = model_tag.decode_beam_search(inputs, lens, beam_size, tag_to_idx, with_snt_classifier=True, **ext_kwargs)
                    top_pred_slots = [[item[0].item() for item in seq] for seq in top_path_slots]
                    ppl = beam_scores_1best.cpu() / torch.tensor(lens, dtype=torch.float)
                    tag_loss = ppl.exp().sum()
            elif opt.crf:
                # Mask is 1 for real positions and 0 for padding up to max_len.
                max_len = max(lens)
                masks = [([1] * l) + ([0] * (max_len - l)) for l in lens]
                masks = torch.tensor(masks, dtype=torch.uint8, device=opt.device)
                crf_feats, encoder_info = model_tag._get_lstm_features(inputs, lens, with_snt_classifier=True, **ext_kwargs)
                tag_path_scores, tag_path = model_tag.forward(crf_feats, masks)
                tag_loss = model_tag.neg_log_likelihood(crf_feats, masks, tags)
                top_pred_slots = tag_path.data.cpu().numpy()
            else:
                tag_scores, encoder_info = model_tag(inputs, lens, with_snt_classifier=True, **ext_kwargs)
                tag_loss = tag_loss_function(tag_scores.contiguous().view(-1, len(tag_to_idx)), tags.view(-1))
                top_pred_slots = tag_scores.data.cpu().numpy().argmax(axis=-1)

            if opt.task_sc:
                class_scores = model_class(encoder_info_filter(encoder_info))
                class_loss = class_loss_function(class_scores, classes)
                if opt.multiClass:
                    # Keep the raw scores; they are thresholded per class below.
                    snt_probs = class_scores.data.cpu().numpy()
                else:
                    snt_probs = class_scores.data.cpu().numpy().argmax(axis=-1)
                losses.append([tag_loss.item()/sum(lens), class_loss.item()/len(lens)])
            else:
                losses.append([tag_loss.item()/sum(lens), 0])

            inputs = inputs.data.cpu().numpy()
            for idx, pred_line in enumerate(top_pred_slots):
                length = lens[idx]
                pred_seq = [idx_to_tag[tag] for tag in pred_line][:length]
                # Gold tags may be stored as indices or already as tag strings.
                lab_seq = [idx_to_tag[tag] if isinstance(tag, int) else tag for tag in raw_tags[idx]]
                # Sequences are wrapped in 'O' before chunk extraction
                # (presumably so chunks touching either end are closed).
                pred_chunks = acc.get_chunks(['O']+pred_seq+['O'])
                label_chunks = acc.get_chunks(['O']+lab_seq+['O'])
                tp, fp, fn = _count_matches(pred_chunks, label_chunks)
                TP += tp
                FP += fp
                FN += fn

                input_line = [idx_to_word[word] for word in inputs[idx]][:length]
                word_tag_line = [input_line[_idx]+':'+lab_seq[_idx]+':'+pred_seq[_idx] for _idx in range(len(input_line))]

                if opt.task_sc:
                    if opt.multiClass:
                        pred_classes = [idx_to_class[i] for i, p in enumerate(snt_probs[idx]) if p > 0.5]
                        gold_classes = [idx_to_class[i] for i in raw_classes[idx]]
                        tp, fp, fn = _count_matches(pred_classes, gold_classes)
                        TP2 += tp
                        FP2 += fp
                        FN2 += fn
                        gold_class_str = ';'.join(gold_classes)
                        pred_class_str = ';'.join(pred_classes)
                    else:
                        pred_class = idx_to_class[snt_probs[idx]]
                        if isinstance(raw_classes[idx], int):
                            gold_classes = {idx_to_class[raw_classes[idx]]}
                        else:
                            gold_classes = set(raw_classes[idx])
                        if pred_class in gold_classes:
                            TP2 += 1
                        else:
                            # A wrong single prediction counts as both a false
                            # positive and a false negative.
                            FP2 += 1
                            FN2 += 1
                        # Sorted so the output file is deterministic across runs
                        # (set iteration order is not).
                        gold_class_str = ';'.join(sorted(gold_classes))
                        pred_class_str = pred_class
                else:
                    gold_class_str = ''
                    pred_class_str = ''

                if opt.testing:
                    out_file.write(str(line_nums[idx])+' : '+' '.join(word_tag_line)+' <=> '+gold_class_str+' <=> '+pred_class_str+'\n')
                else:
                    out_file.write(' '.join(word_tag_line)+' <=> '+gold_class_str+' <=> '+pred_class_str+'\n')

    if TP == 0:
        p, r, f1 = 0, 0, 0
    else:
        p, r, f1 = 100*TP/(TP+FP), 100*TP/(TP+FN), 100*2*TP/(2*TP+FN+FP)

    class_denominator = 2*TP2+FN2+FP2
    class_f1 = 0 if class_denominator == 0 else 100*2*TP2/class_denominator

    mean_losses = np.mean(losses, axis=0)
    return mean_losses, p, r, f1, class_f1
Beispiel #2
0
 logger.info("Training starts at %s" % (time.asctime(time.localtime(time.time()))))
 train_data_index = np.arange(len(train_feats['data']))
 best_f1, best_result = -1, {}
 for i in range(opt.max_epoch):
     start_time = time.time()
     losses = []
     # training data shuffle
     np.random.shuffle(train_data_index)
     model_tag.train()
     if opt.task_sc:
         model_class.train()
     
     nsentences = len(train_data_index)
     piece_sentences = opt.batchSize if int(nsentences * 0.1 / opt.batchSize) == 0 else int(nsentences * 0.1 / opt.batchSize) * opt.batchSize
     for j in range(0, nsentences, opt.batchSize):
         inputs, tags, raw_tags, classes, raw_classes, lens = data_reader.get_minibatch_with_class(train_feats['data'], train_tags['data'], train_class['data'], word_to_idx, tag_to_idx, class_to_idx, train_data_index, j, opt.batchSize, add_start_end=opt.bos_eos, multiClass=opt.multiClass, enc_dec_focus=opt.enc_dec, device=opt.device)
         if opt.word_digit_features:
             word_seqs = [[idx_to_word[w_idx] for w_idx in word_seq] for word_seq in inputs.data.cpu().numpy()]
             ext_features = feature_extractor.get_digit_features(word_seqs, lens)
         optimizer.zero_grad()
         if opt.enc_dec:
             if opt.word_digit_features:
                 tag_scores, encoder_info = model_tag(inputs, tags[:, :-1], lens, with_snt_classifier=True, extFeats=ext_features)
             else:
                 tag_scores, encoder_info = model_tag(inputs, tags[:, :-1], lens, with_snt_classifier=True)
             tag_loss = tag_loss_function(tag_scores.contiguous().view(-1, len(tag_to_idx)), tags[:, 1:].contiguous().view(-1))
         elif opt.crf:
             max_len = max(lens)
             masks = [([1] * l) + ([0] * (max_len - l)) for l in lens]
             masks = torch.tensor(masks, dtype=torch.uint8, device=opt.device)
             if opt.word_digit_features: