import logging
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# load_and_cache_examples, ot2bieos_ts, bio2ot_ts, tag2ts and
# compute_metrics_absa are assumed to be importable from the surrounding project.

logger = logging.getLogger(__name__)


def predict(args, model, tokenizer):
    dataset, evaluate_label_ids, total_words = load_and_cache_examples(
        args, args.task_name, tokenizer)
    sampler = SequentialSampler(dataset)
    # process the incoming data one by one
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=1)
    print("***** Running prediction *****")

    total_preds, gold_labels, total_pred_labels = None, None, None
    idx = 0
    if args.tagging_schema == 'BIEOS':
        absa_label_vocab = {
            'O': 0,
            'EQ': 1,
            'B-POS': 2,
            'I-POS': 3,
            'E-POS': 4,
            'S-POS': 5,
            'B-NEG': 6,
            'I-NEG': 7,
            'E-NEG': 8,
            'S-NEG': 9,
            'B-NEU': 10,
            'I-NEU': 11,
            'E-NEU': 12,
            'S-NEU': 13
        }
    elif args.tagging_schema == 'BIO':
        absa_label_vocab = {
            'O': 0,
            'EQ': 1,
            'B-POS': 2,
            'I-POS': 3,
            'B-NEG': 4,
            'I-NEG': 5,
            'B-NEU': 6,
            'I-NEU': 7
        }
    elif args.tagging_schema == 'OT':
        absa_label_vocab = {'O': 0, 'T-POS': 1, 'T-NEG': 2, 'T-NEU': 3}
    else:
        raise Exception("Invalid tagging schema %s..." % args.tagging_schema)
    # invert the mapping: label id -> tag string
    absa_id2tag = {v: k for (k, v) in absa_label_vocab.items()}

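    # run the model one sentence at a time and map each prediction back to the
    # original words via evaluate_label_ids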
    for batch in tqdm(dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                # XLM doesn't use segment_ids
                'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                'labels': batch[3]
            }
            outputs = model(**inputs)
            # logits: (1, seq_len, label_size)
            logits = outputs[1]
            # preds: (1, seq_len)
            if model.tagger_config.absa_type != 'crf':
                preds = np.argmax(logits.detach().cpu().numpy(), axis=-1)
            else:
                mask = batch[1]
                preds = model.tagger.viterbi_tags(logits=logits, mask=mask)
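            # evaluate_label_ids[idx] stores the token positions that carry
            # word-level labels (one per original word), so the sub-token
            # predictions can be mapped back onto the input words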
            label_indices = evaluate_label_ids[idx]
            words = total_words[idx]
            pred_labels = preds[0][label_indices]
            assert len(words) == len(pred_labels)
            pred_tags = [absa_id2tag[label] for label in pred_labels]

            if args.tagging_schema == 'OT':
                pred_tags = ot2bieos_ts(pred_tags)
            elif args.tagging_schema == 'BIO':
                pred_tags = ot2bieos_ts(bio2ot_ts(pred_tags))
            else:
                # current tagging schema is BIEOS, do nothing
                pass
            p_ts_sequence = tag2ts(ts_tag_sequence=pred_tags)
            output_ts = []
            #print(p_ts_sequence)
            for t in p_ts_sequence:
                beg, end, sentiment = t
                # join the aspect term's words into a readable phrase
                aspect = ' '.join(words[beg:end + 1])
                output_ts.append('%s: %s' % (aspect, sentiment))
            #print("Input: %s, output: %s" % (' '.join(words), '\t'.join(output_ts)))
            # for evaluation: pad the per-word predictions back to the full
            # sequence length so that sentences of different lengths can be
            # stacked into a single array
            pred_labels = np.pad(pred_labels,
                                 (0, preds.shape[1] - len(pred_labels)),
                                 'constant').reshape((1, -1))
            if total_preds is None:
                total_preds = preds
                total_pred_labels = pred_labels
            else:
                total_preds = np.append(total_preds, preds, axis=0)
                total_pred_labels = np.append(total_pred_labels,
                                              pred_labels,
                                              axis=0)
            if inputs['labels'] is not None:
                # for unseen data there are no 'labels'
                if gold_labels is None:
                    gold_labels = inputs['labels'].detach().cpu().numpy()
                else:
                    gold_labels = np.append(
                        gold_labels,
                        inputs['labels'].detach().cpu().numpy(),
                        axis=0)
        idx += 1
    if gold_labels is not None:
        torch.save(gold_labels, 'gold_labels.pt')
        torch.save(total_pred_labels, 'total_pred_labels.pt')
        result = compute_metrics_absa(
            preds=total_preds,
            labels=gold_labels,
            all_evaluate_label_ids=evaluate_label_ids,
            tagging_schema=args.tagging_schema)
        for (k, v) in result.items():
            print("%s: %s" % (k, v))
Example #2
def evaluate(args, model, tokenizer, mode, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = (args.task_name, )
    eval_outputs_dirs = (args.output_dir, )

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset, eval_evaluate_label_ids = load_and_cache_examples(
            args, eval_task, tokenizer, mode=mode)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = (SequentialSampler(eval_dataset)
                        if args.local_rank == -1
                        else DistributedSampler(eval_dataset))
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Eval!
        #logger.info("***** Running evaluation on %s.txt *****" % mode)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        crf_logits, crf_mask = [], []
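        # accumulate logits and attention masks across batches so that, for the
        # CRF tagger, Viterbi decoding can be run over the whole eval set at once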
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    # XLM doesn't use segment_ids
                    'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                    'labels': batch[3]
                }
                outputs = model(**inputs)
                # logits: (bsz, seq_len, label_size)
                # here the loss is the masked loss
                tmp_eval_loss, logits = outputs[:2]
                eval_loss += tmp_eval_loss.mean().item()

                crf_logits.append(logits)
                crf_mask.append(batch[1])
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs['labels'].detach().cpu().numpy(),
                    axis=0)
        eval_loss = eval_loss / nb_eval_steps
        # argmax operation over the last dimension
        if model.tagger_config.absa_type != 'crf':
            # greedy decoding
            preds = np.argmax(preds, axis=-1)
        else:
            # viterbi decoding for CRF-based model
            crf_logits = torch.cat(crf_logits, dim=0)
            crf_mask = torch.cat(crf_mask, dim=0)
            preds = model.tagger.viterbi_tags(logits=crf_logits, mask=crf_mask)
        result = compute_metrics_absa(preds, out_label_ids,
                                      eval_evaluate_label_ids,
                                      args.tagging_schema)
        result['eval_loss'] = eval_loss
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir,
                                        "%s_results.txt" % mode)
        with open(output_eval_file, "w") as writer:
            #logger.info("***** %s results *****" % mode)
            for key in sorted(result.keys()):
                if 'eval_loss' in key:
                    logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results
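As with predict, a hedged sketch of an evaluate call follows. The attribute names come from the fields the function reads above, the concrete values are placeholders, and model, tokenizer and args are assumed to have been set up as in the sketch after the first example.

# Hypothetical invocation of evaluate(); the attribute names mirror the fields
# the function reads and the values below are placeholders.
args.output_dir = './checkpoints/absa'
args.per_gpu_eval_batch_size = 8
args.n_gpu = torch.cuda.device_count()
args.local_rank = -1          # -1 means no distributed evaluation
# results = evaluate(args, model, tokenizer, mode='test')
# print(results)              # metrics from compute_metrics_absa plus 'eval_loss'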