Example #1
def main():
    parser = argparse.ArgumentParser(
        description=
        """Computes rationale and final class classification scores""",
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        required=True,
        help='Which directory contains a {train,val,test}.jsonl file?')
    parser.add_argument('--split',
                        dest='split',
                        required=True,
                        help='Which of {train,val,test} are we scoring on?')
    parser.add_argument('--strict',
                        dest='strict',
                        required=False,
                        action='store_true',
                        default=False,
                        help='Do we perform strict scoring?')
    parser.add_argument('--results',
                        dest='results',
                        required=True,
                        help="""Results File
    Contents are expected to be jsonl of:
    {
        "annotation_id": str, required
        # these classifications *must not* overlap
        "rationales": List[
            {
                "docid": str, required
                "hard_rationale_predictions": List[{
                    "start_token": int, inclusive, required
                    "end_token": int, exclusive, required
                }], optional,
                # token level classifications, a value must be provided per-token
                # in an ideal world, these correspond to the hard-decoding above.
                "soft_rationale_predictions": List[float], optional.
                # sentence level classifications, a value must be provided for every
                # sentence in each document, or not at all
                "soft_sentence_predictions": List[float], optional.
            }
        ],
        # the classification the model made for the overall classification task
        "classification": str, optional
        # A probability distribution output by the model. We require this to be normalized.
        "classification_scores": Dict[str, float], optional
        # The next two fields are measures for how faithful your model is (the
        # rationales it predicts are in some sense causal of the prediction), and
        # how sufficient they are. We approximate a measure for comprehensiveness by
        # asking that you remove the top k%% of tokens from your documents,
        # running your models again, and reporting the score distribution in the
        # "comprehensiveness_classification_scores" field.
        # We approximate a measure of sufficiency by asking exactly the converse
        # - that you provide model distributions on the removed k%% tokens.
        # 'k' is determined by human rationales, and is documented in our paper.
        # You should determine which of these tokens to remove based on some kind
        # of information about your model: gradient based, attention based, other
        # interpretability measures, etc.
        # scores per class having removed k%% of the data, where k is determined by human comprehensive rationales
        "comprehensiveness_classification_scores": Dict[str, float], optional
        # scores per class having access to only k%% of the data, where k is determined by human comprehensive rationales
        "sufficiency_classification_scores": Dict[str, float], optional
        # the number of tokens required to flip the prediction - see "Is Attention Interpretable" by Serrano and Smith.
        "tokens_to_flip": int, optional
    }
    When providing one of the optional fields, it must be provided for *every* instance.
    The classification, classification_scores, and comprehensiveness_classification_scores
    must together be present for every instance or absent for every instance.
    """)
    parser.add_argument('--iou_thresholds',
                        dest='iou_thresholds',
                        required=False,
                        nargs='+',
                        type=float,
                        default=[0.5],
                        help='''Thresholds for IOU scoring.

    These are used for "soft" or partial match scoring of rationale spans.
    A span is considered a match if the size of the intersection of the prediction
    and the annotation, divided by the union of the two spans, is larger than
    the IOU threshold. This score can be computed for arbitrary thresholds.
    ''')
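    # Worked example for the default threshold of 0.5 (token indices are
    # illustrative): a predicted span [2, 6) and a gold span [4, 9) share 2 tokens
    # out of a union of 7, giving IOU ~= 0.29, so the prediction is not a match.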
    parser.add_argument('--score_file',
                        dest='score_file',
                        required=False,
                        default=None,
                        help='Where to write results?')
    args = parser.parse_args()
    results = load_jsonl(args.results)
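    # Collect every docid referenced by the predicted rationales so that only the
    # documents we actually need are loaded.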
    docids = set(
        chain.from_iterable([rat['docid'] for rat in res['rationales']]
                            for res in results))
    docs = load_flattened_documents(args.data_dir, docids)
    verify_instances(results, docs)
    # load truth
    annotations = annotations_from_jsonl(
        os.path.join(args.data_dir, args.split + '.jsonl'))
    docids |= set(
        chain.from_iterable((ev.docid
                             for ev in chain.from_iterable(ann.evidences))
                            for ann in annotations))

    has_final_predictions = _has_classifications(results)
    scores = dict()
    if args.strict:
        if not args.iou_thresholds:
            raise ValueError(
                "--iou_thresholds must be provided when running strict scoring"
            )
        if not has_final_predictions:
            raise ValueError(
                "We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!"
            )
    # TODO think about offering a sentence level version of these scores.
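    # Hard rationale scoring below: spans are compared to the gold rationales via
    # partial (IOU-based) matching, then via exact span-level precision/recall/F1,
    # and finally via token-level precision/recall/F1.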
    if _has_hard_predictions(results):
        truth = list(
            chain.from_iterable(
                Rationale.from_annotation(ann) for ann in annotations))
        pred = list(
            chain.from_iterable(
                Rationale.from_instance(inst) for inst in results))
        if args.iou_thresholds is not None:
            iou_scores = partial_match_score(truth, pred, args.iou_thresholds)
            scores['iou_scores'] = iou_scores
        # NER style scoring
        rationale_level_prf = score_hard_rationale_predictions(truth, pred)
        scores['rationale_prf'] = rationale_level_prf
        token_level_truth = list(
            chain.from_iterable(rat.to_token_level() for rat in truth))
        token_level_pred = list(
            chain.from_iterable(rat.to_token_level() for rat in pred))
        token_level_prf = score_hard_rationale_predictions(
            token_level_truth, token_level_pred)
        scores['token_prf'] = token_level_prf
    else:
        logging.info(
            "No hard predictions detected, skipping rationale scoring")

    if _has_soft_predictions(results):
        flattened_documents = load_flattened_documents(args.data_dir, docids)
        paired_scoring = PositionScoredDocument.from_results(
            results, annotations, flattened_documents, use_tokens=True)
        token_scores = score_soft_tokens(paired_scoring)
        scores['token_soft_metrics'] = token_scores
    else:
        logging.info(
            "No soft predictions detected, skipping rationale scoring")

    if _has_soft_sentence_predictions(results):
        documents = load_documents(args.data_dir, docids)
        paired_scoring = PositionScoredDocument.from_results(results,
                                                             annotations,
                                                             documents,
                                                             use_tokens=False)
        sentence_scores = score_soft_tokens(paired_scoring)
        scores['sentence_soft_metrics'] = sentence_scores
    else:
        logging.info(
            "No sentence level predictions detected, skipping sentence-level diagnostic"
        )

    if has_final_predictions:
        flattened_documents = load_flattened_documents(args.data_dir, docids)
        class_results = score_classifications(results, annotations,
                                              flattened_documents)
        scores['classification_scores'] = class_results
    else:
        logging.info(
            "No classification scores detected, skipping classification")

    pprint.pprint(scores)

    if args.score_file:
        with open(args.score_file, 'w') as of:
            json.dump(scores, of, indent=4, sort_keys=True)
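
# Example invocation of the scorer above (script name and paths are illustrative):
#   python metrics.py --data_dir data/movies --split test \
#       --results out/results.jsonl --iou_thresholds 0.5 --score_file out/scores.json
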
def load_and_cache_examples(args,
                            model_params,
                            tokenizer,
                            evaluate=False,
                            split="train",
                            output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training processes the dataset; the others will use the cache

    # only load one split
    input_file = os.path.join(args.data_dir, split)
    cached_features_file = os.path.join(
        os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
            split,
            list(filter(None,
                        model_params["tokenizer_name"].split('/'))).pop(),
            str(args.max_seq_length)))
    if args.gold_evidence:
        cached_features_file += "_goldevidence"
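    # Resulting name, e.g. "cached_train_bert-base-uncased_512" (tokenizer name is
    # illustrative); "_goldevidence" is appended only when --gold_evidence is set.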

    if os.path.exists(cached_features_file
                      ) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s",
                    cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        dataset = annotations_from_jsonl(
            os.path.join(args.data_dir, split + ".jsonl"))

        docids = set(e.docid
                     for e in chain.from_iterable(
                         chain.from_iterable(
                             map(lambda ann: ann.evidences, chain(dataset)))))
        documents = load_documents(args.data_dir, docids)

        if args.out_domain:
            examples = read_json(args)
        else:
            examples = read_examples(args, model_params, dataset, documents,
                                     split)

        features = convert_examples_to_features(
            args,
            model_params,
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            max_query_length=args.max_query_length,
            is_training=not evaluate)
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()

    # Tensorize all features
    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                   dtype=torch.long)
    all_cls_index = torch.tensor([f.cls_index for f in features],
                                 dtype=torch.long)
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    all_unique_ids = torch.tensor([f.unique_id for f in features],
                                  dtype=torch.float)

    if evaluate:
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)
        tensorized_dataset = TensorDataset(all_input_ids, all_input_mask,
                                           all_segment_ids, all_example_index,
                                           all_cls_index, all_p_mask,
                                           all_unique_ids)
    else:
        all_class_labels = torch.tensor([f.class_label for f in features],
                                        dtype=torch.long)
        all_evidence_labels = torch.tensor(
            [f.evidence_label for f in features], dtype=torch.long)
        tensorized_dataset = TensorDataset(all_input_ids, all_input_mask,
                                           all_segment_ids, all_class_labels,
                                           all_cls_index, all_p_mask,
                                           all_unique_ids, all_evidence_labels)

    if output_examples:
        return tensorized_dataset, examples, features
    return tensorized_dataset, features
Example #3
            "annotation_id": annotation_id,
            'classification': classification,
            "docids": docids,
            "query": query,
            "query_type": None,
            'evidences': [converted_evidences],
        }
        test_eraser_dataset.append(eraser_ex)
    print("Avg evidence sentences %f " % np.average(evs))
    train_eraser_file = open(os.path.join(args.data_path, "train.jsonl"), "w+")
    val_eraser_file = open(os.path.join(args.data_path, "val.jsonl"), "w+")
    test_eraser_file = open(os.path.join(args.data_path, "test.jsonl"), "w+")

    for ex in train_eraser_dataset:
        train_eraser_file.write(json.dumps(ex) + "\n")
    for ex in dev_eraser_dataset:
        val_eraser_file.write(json.dumps(ex) + "\n")
    for ex in test_eraser_dataset:
        test_eraser_file.write(json.dumps(ex) + "\n")
    train_eraser_file.close()
    val_eraser_file.close()
    test_eraser_file.close()

    # Verify that the generated BEER dataset files can be loaded with annotations_from_jsonl
    test_dataset = annotations_from_jsonl(
        os.path.join(args.data_path, "test.jsonl"))
    train_dataset = annotations_from_jsonl(
        os.path.join(args.data_path, "train.jsonl"))
    val_dataset = annotations_from_jsonl(
        os.path.join(args.data_path, "val.jsonl"))
Example #4
def evaluate(args,
             model_params,
             model,
             tokenizer,
             prefix="",
             output_examples=False,
             split="val"):

    # Keep all ERASER annotations ready for scoring
    annotations = annotations_from_jsonl(
        os.path.join(args.data_dir, split + '.jsonl'))
    true_rationales = list(
        chain.from_iterable(
            metrics.Rationale.from_annotation(ann) for ann in annotations))

    if output_examples:
        dataset, examples, features = load_and_cache_examples(
            args,
            model_params,
            tokenizer,
            evaluate=True,
            split=split,
            output_examples=output_examples)
    else:
        dataset, features = load_and_cache_examples(
            args,
            model_params,
            tokenizer,
            evaluate=True,
            split=split,
            output_examples=output_examples)
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(
        dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_results = []
    all_targets = []
    # all_rationale_targets = []
    all_rationale_results = []
    all_results_dictionary = {}
    results = []
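    # Map class names to integer ids; class_labels recovers the names in id order
    # for the classification report.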
    class_interner = dict(
        (y, x) for (x, y) in enumerate(model_params['classes']))
    class_labels = [
        k for k, v in sorted(class_interner.items(), key=lambda x: x[1])
    ]

    task_performances = []
    evidence_performances = []
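    # Evaluation is repeated 10 times; the per-run classification macro-F1 and the
    # evidence IOU-F1 (threshold 0.5) are averaged in the log messages at the end.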
    for run_no in range(10):
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    # 'token_type_ids': None if args.model_type == 'xlm' or args.model_type == 'roberta' else batch[2]  # XLM don't use segment_ids
                    'p_mask': batch[5]
                }
                example_indices = batch[3]
                if args.model_type in ['xlnet', 'xlm']:
                    inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
                inputs.update({"evaluate": True})
                outputs = model(**inputs)
                logits = outputs[0]
                hard_preds = to_list(torch.argmax(logits.float(), dim=-1))
                hard_rationale_pred = outputs[1]
                all_results.extend(hard_preds)
            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_targets.extend([eval_feature.class_label])

                # hard_rationale_binary = np.zeros(eval_feature.evidence_label.shape)
                # hard_rationale_binary[np.array(to_list(hard_rationale_pred[i]))] = 1
                # all_rationale_results.extend(hard_rationale_binary[:sum(eval_feature.input_mask)])
                # all_rationale_targets.extend(eval_feature.evidence_label[:sum(eval_feature.input_mask)])
                # all_rationale_dictionary[unique_id] = hard_rationale_binary[:sum(eval_feature.input_mask)]
                key = (eval_feature.annotation_id, eval_feature.doc_id)
                for chunk in hard_rationale_pred[i]:
                    chunk = to_list(chunk)
                    start_token = eval_feature.token_to_orig_map[chunk[0]]
                    chunk_end = chunk[-1]
                    if eval_feature.tokens[chunk[-1]] in ["[SEP]", "</s>"]:
                        # The last chunk ends with a SEP token, which has no entry in the map to the original tokens
                        chunk_end = chunk[-1] - 1
                    end_token = eval_feature.token_to_orig_map[
                        chunk_end] + 1  # The final token is still exclusive
                    # token_to_orig_map is harder to maintain when RoBERTa is used (token-level scoring does not do as well)
                    all_rationale_results.append(
                        metrics.Rationale(ann_id=key[0],
                                          docid=key[1],
                                          start_token=start_token,
                                          end_token=end_token))

                all_results_dictionary[key] = hard_preds[i]

        output_rationale_prediction_file = os.path.join(
            args.output_dir, "rationale_predictions.json")
        file_pointer = open(output_rationale_prediction_file, "w+")
        ann_to_rat = metrics._keyed_rationale_from_list(all_rationale_results)
        ann_to_gold = metrics._keyed_rationale_from_list(true_rationales)
        example_dict = dict()
        for ex in examples:
            example_dict[(ex.id, ex.doc_id)] = ex

        evidence_iou = metrics.partial_match_score(true_rationales,
                                                   all_rationale_results,
                                                   [0.5])
        hard_rationale_metrics = metrics.score_hard_rationale_predictions(
            true_rationales, all_rationale_results)
        # logger.info("Rationale Partial score metrics:")
        # logger.info(evidence_iou)
        # logger.info("Rationale hard metrics:")
        # logger.info(hard_rationale_metrics)
        results = (classification_report(
            all_targets,
            all_results,
            target_names=class_labels,
            output_dict=True)['macro avg']['f1-score'],
                   accuracy_score(all_targets, all_results))
        # logger.info('Classification Report: {}'.format(
        #     classification_report(all_targets, all_results, target_names=class_labels, output_dict=True)))
        task_performances.append(results[0])
        evidence_performances.append(evidence_iou[0]['macro']['f1'])

    # Only the predictions from the last evaluation run are written to file below
    for e, feature in enumerate(features):
        unique_id = (feature.annotation_id, feature.doc_id)
        tokens = feature.tokens
        token_level_rationales = ann_to_rat[unique_id]
        doc_tokens = example_dict[unique_id].doc_toks
        # This technique is slightly flawed
        # rationale = [tokens[i] for i in np.where(token_level_rationale == 1)[0]]
        # # Every chunk is separable
        # rationale_text = ""
        # for i in range(model_params["K"]):
        #     rationale_text += " ".join(rationale[i*model_params["chunk_size"]:(i+1)*model_params["chunk_size"]]) + "\n"
        class_label = class_labels[feature.class_label]
        predicted_class_label = all_results_dictionary[unique_id]
        # human_rationale = [tokens[i] for i in np.where(np.asarray(feature.evidence_label)== 1)[0]] # This is stale too since there is extra padding into what both models consider
        # query = " ".join([tok for tok in feature.tokens[:64] if tok not in ["[CLS]", "[PAD]", "[SEP]"]])
        rationales = []
        # {
        #     "docid": str, required
        #     "hard_rationale_predictions": List[{
        #                                            "start_token": int, inclusive, required
        # "end_token": int, exclusive, required
        # }]
        for rationale in token_level_rationales:
            start_token = rationale.start_token
            end_token = rationale.end_token
            rationale_text = doc_tokens[start_token:end_token]
            rationales.append({
                "start_token": start_token,
                "end_token": end_token
            })
        output_dict = {
            "annotation_id":
            feature.annotation_id,
            "classification":
            predicted_class_label,
            "rationales": [{
                "docid": feature.doc_id,
                "hard_rationale_predictions": rationales
            }],
            "gold_classification":
            class_label,
            # "gold_rationales" : ann_to_gold[unique_id],
            # "predicted_rationales" : ann_to_rat[unique_id],
            "query":
            example_dict[unique_id].query,
            "doc_tokens":
            doc_tokens
        }
        file_pointer.write(json.dumps(output_dict) + "\n")
        # Write the predicted rationale chunks to the file
    file_pointer.close()
    logger.info("Averaged classification results: %.4f" %
                np.average(task_performances))
    logger.info("Averaged Evidence results: %.4f" %
                np.average(evidence_performances))
    return results
Example #5
def train(args, model_params, train_dataset, model, tokenizer):
    if args.local_rank in [-1, 0] and args.tf_summary:
        #if os.path.isdir(os.path.join("runs", os.path.basename(args.output_dir))):
        #    shutil.rmtree(os.path.join("runs", os.path.basename(args.output_dir)))
        tb_writer = SummaryWriter("runs/" + os.path.basename(args.output_dir))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

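    # Total number of optimizer updates: either capped by --max_steps (deriving the
    # epoch count from it) or epochs * updates-per-epoch, where an update happens
    # once per gradient_accumulation_steps batches.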
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

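    # Exclude bias and LayerNorm weights from weight decay, as is customary when
    # fine-tuning BERT-style models.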
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    # TODO: additional (non-BERT) parameters may require a higher learning rate than the BERT weights
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    if args.evaluate_during_training:
        dataset, eval_features = load_and_cache_examples(args,
                                                         model_params,
                                                         tokenizer,
                                                         evaluate=True,
                                                         split="val",
                                                         output_examples=False)
        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(
            dataset) if args.local_rank == -1 else DistributedSampler(dataset)
        eval_dataloader = DataLoader(dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

    # Keep all ERASER annotations ready for scoring
    annotations = annotations_from_jsonl(
        os.path.join(args.data_dir, args.eval_split + '.jsonl'))
    true_rationales = list(
        chain.from_iterable(
            metrics.Rationale.from_annotation(ann) for ann in annotations))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    class_tr_loss, class_logging_loss = 0.0, 0.0
    info_tr_loss, info_logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    best_f1 = (-1, -1)  # Classification F1 and accuracy
    wait_step = 0
    stop_training = False
    metric_name = "F1"
    epoch = 0
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                # 'token_type_ids':  None if args.model_type == 'xlm' or args.model_type == 'roberta' else batch[2],
                'labels': batch[3],
                'p_mask': batch[5]
            }
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always a tuple in pytorch-transformers (see doc)
            class_loss = outputs[1]
            info_loss = outputs[2]
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
                class_loss = class_loss.mean()
                info_loss = info_loss.mean()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                class_loss = class_loss / args.gradient_accumulation_steps
                info_loss = info_loss / args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)

            tr_loss += loss.item()
            class_tr_loss += class_loss.item()
            info_tr_loss += info_loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate on a single GPU; otherwise metrics may not average well
                        results = predict(args, model_params, model, tokenizer,
                                          eval_features, eval_dataloader,
                                          true_rationales, tb_writer,
                                          global_step)
                        #for key, value in results.items():
                        #    tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                        if best_f1 < results:
                            logger.info("Saving model with best %s: %.2f (Acc %.2f) -> %.2f (Acc %.2f) on epoch=%d" % \
                                        (metric_name, best_f1[0] * 100, best_f1[1] * 100, results[0] * 100, results[1] * 100,
                                         epoch))
                            output_dir = os.path.join(args.output_dir,
                                                      'best_model')
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Take care of distributed/parallel training
                            model_to_save.save_pretrained(output_dir)
                            torch.save(
                                args,
                                os.path.join(output_dir, 'training_args.bin'))
                            best_f1 = results
                            wait_step = 0
                            stop_training = False
                        else:
                            wait_step += 1
                            if wait_step == args.wait_step:
                                logger.info("Loosing Patience")
                                stop_training = True
                if args.tf_summary and global_step % 20 == 0:
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    tb_writer.add_scalar('class_loss',
                                         (class_tr_loss - class_logging_loss) /
                                         20, global_step)
                    tb_writer.add_scalar(
                        'info_loss', (info_tr_loss - info_logging_loss) / 20,
                        global_step)
                    logging_loss = tr_loss
                    class_logging_loss = class_tr_loss
                    info_logging_loss = info_tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if stop_training or args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if stop_training or args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
        epoch += 1
    if args.local_rank in [-1, 0] and args.tf_summary:
        tb_writer.close()

    return global_step, tr_loss / global_step
from ib_utils import read_examples
import argparse
import json
from itertools import chain
from eraser.rationale_benchmark.utils import load_documents, annotations_from_jsonl
import logging, os

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', dest='data_dir', required=True)
parser.add_argument('--model_params', dest='model_params', required=True)
parser.add_argument("--split", type=str, default="val")
parser.add_argument("--truncate", default=False, action="store_true")
parser.add_argument("--max_seq_length", default=512, type=int)
parser.add_argument("--max_query_length", default=24, type=int)
parser.add_argument("--max_num_sentences", default=20, type=int)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument('--low_resource', action="store_true", default=False)
args = parser.parse_args()

# Parse model args json
with open(args.model_params, 'r') as fp:
    logging.debug(f'Loading model parameters from {args.model_params}')
    model_params = json.load(fp)

dataset = annotations_from_jsonl(
    os.path.join(args.data_dir, args.split + ".jsonl"))
docids = set(e.docid for e in chain.from_iterable(
    chain.from_iterable(map(lambda ann: ann.evidences, chain(dataset)))))
documents = load_documents(args.data_dir, docids)
examples = read_examples(args, model_params, dataset, documents, args.split)
def evaluate(args,
             model_params,
             model,
             tokenizer,
             prefix="",
             output_examples=False,
             split="val"):

    # Keep all ERASER annotations ready for scoring
    annotations = annotations_from_jsonl(
        os.path.join(args.data_dir, split + '.jsonl'))
    true_rationales = list(
        chain.from_iterable(
            metrics.Rationale.from_annotation(ann) for ann in annotations))

    if output_examples:
        dataset, examples, features = load_and_cache_examples(
            args,
            model_params,
            tokenizer,
            evaluate=True,
            split=split,
            output_examples=output_examples)
    else:
        dataset, features = load_and_cache_examples(
            args,
            model_params,
            tokenizer,
            evaluate=True,
            split=split,
            output_examples=output_examples)
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(
        dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    # all_rationale_targets = []
    all_rationale_results = []
    class_interner = dict(
        (y, x) for (x, y) in enumerate(model_params['classes']))
    class_labels = [
        k for k, v in sorted(class_interner.items(), key=lambda x: x[1])
    ]
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                # 'token_type_ids': None if args.model_type == 'xlm' or args.model_type == 'roberta' else batch[2]  # XLM don't use segment_ids
                'p_mask': batch[5]
            }
            example_indices = batch[3]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
            inputs.update({
                "sentence_starts": batch[7],
                "sentence_ends": batch[8],
                "sentence_mask": batch[9],
                "evidence_labels": batch[10],
            })
            inputs.update({"evaluate": True})
            outputs = model(**inputs)
            logits = outputs[0]
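            # Sentence-level evidence prediction: apply a sigmoid to the logits and
            # keep every sentence whose probability exceeds 0.5.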
            hard_rationale_pred = to_list(torch.sigmoid(logits))
            hard_rationale_pred = [
                np.where(np.asarray(l) > 0.5)[0] for l in hard_rationale_pred
            ]
        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)

            # For each selected sentence index, look up its start and end tokens and add the span to the rationale list
            key = (eval_feature.annotation_id, eval_feature.doc_id)
            last_sentence = sum(eval_feature.sentence_mask)
            for sentence in hard_rationale_pred[i]:
                if sentence >= last_sentence:
                    continue
                start_token = eval_feature.sentence_starts[sentence]
                end_token = eval_feature.sentence_ends[sentence]

                orig_start_token = eval_feature.token_to_orig_map[start_token]
                if eval_feature.tokens[end_token] in ["[SEP]", "</s>"]:
                    # For the last sentence, end_token points to SEP; move it back to the last word
                    end_token -= 1
                orig_end_token = eval_feature.token_to_orig_map[
                    end_token] + 1  # +1 since it needs to be exclusive

                all_rationale_results.append(
                    metrics.Rationale(ann_id=key[0],
                                      docid=key[1],
                                      start_token=orig_start_token,
                                      end_token=orig_end_token))

    output_rationale_prediction_file = os.path.join(
        args.output_dir, "rationale_predictions.json")
    file_pointer = open(output_rationale_prediction_file, "w+")
    ann_to_rat = metrics._keyed_rationale_from_list(all_rationale_results)
    ann_to_gold = metrics._keyed_rationale_from_list(true_rationales)
    example_dict = dict()
    for ex in examples:
        example_dict[(ex.id, ex.doc_id)] = ex

    evidence_iou = metrics.partial_match_score(true_rationales,
                                               all_rationale_results,
                                               [0.1, 0.5, 0.9])

    for e, feature in enumerate(features):
        unique_id = (feature.annotation_id, feature.doc_id)
        tokens = feature.tokens
        token_level_rationales = ann_to_rat[unique_id]
        doc_tokens = example_dict[unique_id].doc_toks
        class_label = class_labels[feature.class_label]
        # human_rationale = [tokens[i] for i in np.where(np.asarray(feature.evidence_label)== 1)[0]] # This is stale too since there is extra padding into what both models consider
        # query = " ".join([tok for tok in feature.tokens[:64] if tok not in ["[CLS]", "[PAD]", "[SEP]"]])
        rationales = []
        for rationale in token_level_rationales:
            start_token = rationale.start_token
            end_token = rationale.end_token
            rationale_text = doc_tokens[start_token:end_token]
            rationales.append({
                "start_token": start_token,
                "end_token": end_token
            })
        output_dict = {
            "annotation_id":
            feature.annotation_id,
            "rationales": [{
                "docid": feature.doc_id,
                "hard_rationale_predictions": rationales
            }],
            "gold_classification":
            class_label,
            # "gold_rationales" : ann_to_gold[unique_id],
            # "predicted_rationales" : ann_to_rat[unique_id],
            "query":
            example_dict[unique_id].query,
            "doc_tokens":
            doc_tokens
        }
        file_pointer.write(json.dumps(output_dict) + "\n")
        # Write the predicted rationale chunks to the file
    result = evidence_iou
    file_pointer.close()
    return result