Example #1
def predictions(in_files, out_folder, model_paths, model_types, no_cuda,
                per_gpu_eval_batch_size, do_not_lower_case, lang_id, v2,
                n_best_size, max_answer_length, verbose_logging,
                null_score_diff_threshold, do_evaluate, **kwargs):
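    """Generate predictions (and optionally evaluation scores) for several models.

    For each (model_path, model_type) pair the model and tokenizer are loaded,
    every input file is converted to features via load_or_convert, and
    evaluate() is run. With do_evaluate the resulting scores are written as
    JSON into the output folder; otherwise only predictions are produced.
    """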
    assert len(model_paths) == len(model_types)
    for model_path, model_type in zip(model_paths, model_types):
        model = get_model(model_path)
        args = Args(model_path=model_path,
                    model_type=model_type,
                    predictions_folder=out_folder,
                    no_cuda=no_cuda,
                    do_not_lower_case=do_not_lower_case,
                    per_gpu_eval_batch_size=per_gpu_eval_batch_size,
                    lang_id=lang_id,
                    v2=v2,
                    n_best_size=n_best_size,
                    max_answer_length=max_answer_length,
                    verbose_logging=verbose_logging,
                    null_score_diff_threshold=null_score_diff_threshold,
                    **kwargs)
        tokenizer = get_tokenizer(model_path, args.do_lower_case)
        for in_file in in_files:
            args.eval_file = in_file
            logger.debug(args)
            dataset, examples, features = load_or_convert(args,
                                                          tokenizer,
                                                          evaluate=True)
            if do_evaluate:
                out_path = args.predictions_folder
                args.predictions_folder = None
                suffix = os.path.basename(os.path.normpath(model_path))
                score = evaluate(args,
                                 model,
                                 tokenizer,
                                 dataset,
                                 examples,
                                 features,
                                 suffix=suffix,
                                 return_raw=False)
                file_name = get_output_predictions_file_name(
                    args.eval_file, out_path, suffix)
                write_json(score, file_name)
                args.predictions_folder = out_path
            else:
                evaluate(args,
                         model,
                         tokenizer,
                         dataset,
                         examples,
                         features,
                         suffix=os.path.basename(os.path.normpath(model_path)))
Example #2
def train(**kwargs):
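    """Fine-tune an extractive QA model.

    Sets up CUDA/distributed training, loads the pretrained model and
    tokenizer, trains on the converted dataset via do_train, saves the final
    model and tokenizer, and, if requested, evaluates the checkpoints saved
    during training, writing the scores to dev-results.json in the model folder.
    """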
    # doc_stride = kwargs.pop("doc_stride")
    # max_query_length = kwargs.pop('max_query_length')
    # max_seq_length = kwargs.pop("max_seq_length")
    # num_workers = kwargs.pop('num_workers')
    # debug_features = kwargs.pop('debug_features')
    # do_lower_case = not kwargs.pop('do_not_lower_case')
    # kwargs['logging_steps'] = [int(i) for i in kwargs['logging_steps'].split(',')] if kwargs['logging_steps'] else []
    args = Args(**kwargs)
    args.local_rank = int(os.environ.get('LOCAL_RANK', -1))
    logger.debug(args)
    if (os.path.exists(args.save_model_folder)
            and os.listdir(args.save_model_folder)
            and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.save_model_folder))
    # os.makedirs(args.predictions_folder, exist_ok=True)
    os.makedirs(args.save_model_folder, exist_ok=True)

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    if args.local_rank not in [-1, 0]:
        logger.remove()
        logger.add(sys.stdout, level="WARNING")
    logger.warning(
        f"Process rank: {args.local_rank}, device: {device}, n_gpu: "
        f"{args.n_gpu}, distributed training: "
        f"{bool(args.local_rank != -1)}, 16-bits training: {args.fp16}", )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()

    tokenizer = get_tokenizer(args.model_path, args.do_lower_case)
    model = get_model(args.model_path)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    # logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum
    # if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
    # Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    train_dataset = load_or_convert(args, tokenizer, dataset_only=True)
    # train_dataset, e, f = load_examples(args.train_file)
    logger.info("loaded dataset")
    global_step, tr_loss = do_train(args, train_dataset, model, tokenizer)
    logger.info(f"global_step = {global_step}, average loss = {tr_loss}")

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        logger.info(f"Saving model checkpoint to {args.save_model_folder}")
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        # Take care of distributed/parallel training
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(args.save_model_folder)
        tokenizer.save_pretrained(args.save_model_folder)

        # Good practice: save your training arguments together with the trained model
        torch.save(args,
                   os.path.join(args.save_model_folder, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = AutoModelForQuestionAnswering.from_pretrained(
            args.save_model_folder)  # , force_download=True)
        tokenizer = AutoTokenizer.from_pretrained(
            args.save_model_folder, do_lower_case=args.do_lower_case)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval_after_training and args.local_rank in [-1, 0]:
        logger.info("Loading checkpoints saved during training for evaluation")
        checkpoints = [args.save_model_folder]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.save_model_folder + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            # logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info(f"Evaluate the following checkpoints: {checkpoints}")
        dataset, examples, features = load_or_convert(args,
                                                      tokenizer,
                                                      evaluate=True)
        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split(
                "-")[-1] if len(checkpoints) > 1 else ""
            model = AutoModelForQuestionAnswering.from_pretrained(
                checkpoint)  # , force_download=True)
            model.to(args.device)

            # Evaluate

            result = evaluate(args,
                              model,
                              tokenizer,
                              dataset,
                              examples,
                              features,
                              suffix=global_step)

            result = dict(
                (k + ("_{}".format(global_step) if global_step else ""), v)
                for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))
    write_json(results, os.path.join(args.save_model_folder,
                                     'dev-results.json'))
    return results
Example #3
def debug_eval(model_path, model_type, baseline_gold_file, no_cuda,
               do_not_lower_case, per_gpu_eval_batch_size, verbose_logging,
               max_answer_length, max_seq_length, doc_stride, max_query_length,
               num_workers, stfu, predictions_folder):
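    """Debug helper: evaluate one model on the baseline, intervention and control sets.

    Converts the three eval files derived from the baseline gold file into
    SQuAD features, obtains raw predictions for each, compares them with
    evaluate_intervention, and prints summary statistics about how the
    intervention changed the model's answers.
    """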
    eval_files = get_baseline_intervention_control_from_baseline(
        baseline_gold_file)
    model = get_model(model_path)
    do_lower_case = not do_not_lower_case
    tokenizer = get_tokenizer(model_path, do_lower_case)
    processor = SquadV1Processor()
    defs = []
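    # each entry will be a (dataset, examples, features) triple, one per eval file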

    for eval_file in eval_files:
        data_dir = os.path.dirname(eval_file)
        file_name = os.path.basename(eval_file)
        examples = processor.get_dev_examples(data_dir, filename=file_name)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=False,
            return_dataset="pt",
            threads=num_workers,
        )
        defs.append((dataset, examples, features))
    args = Args(model_path,
                model_type,
                per_gpu_eval_batch_size=per_gpu_eval_batch_size,
                max_answer_length=max_answer_length,
                predictions_folder=predictions_folder)

    if args.local_rank == -1 or no_cuda:
        args.device = torch.device(
            "cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
        args.n_gpu = 0 if no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)

    model.to(device=args.device)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        args.device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )
    baseline_dataset, intervention_dataset, control_dataset = defs
    if not stfu:
        debug_features_examples_dataset(*baseline_dataset, tokenizer)
    args.eval_file = f'debug-{model_path}-baseline'
    baseline_predictions = evaluate(args,
                                    model,
                                    tokenizer,
                                    *baseline_dataset,
                                    return_raw=True)
    args.eval_file = f'debug-{model_path}-intervention'
    intervention_predictions = evaluate(args,
                                        model,
                                        tokenizer,
                                        *intervention_dataset,
                                        return_raw=True)

    args.eval_file = f'debug-{model_path}-control'
    control_predictions = evaluate(args,
                                   model,
                                   tokenizer,
                                   *control_dataset,
                                   return_raw=True)
    golds = tuple(load_json(g) for g in eval_files)
    aligneds = align(*golds)
    # obtain predictions on all three of them
    (overall_results, results_baseline, results_intervention, results_control,
     correct_before_intervention, correct_change_correct, correct_keep_wrong,
     correct_change_wrong, wrong_change_right, wrong_keep_right,
     correct_baseline_control,
     correct_baseline_control_intervention) = evaluate_intervention(
         *aligneds, baseline_predictions, intervention_predictions,
         control_predictions)
    print_examples(correct_baseline_control,
                   correct_baseline_control_intervention,
                   correct_change_correct, correct_keep_wrong,
                   correct_change_wrong, wrong_change_right, wrong_keep_right)
    click.echo(f"Got {sum(results_baseline)} correct for baseline.")
    click.echo(f"Got {sum(results_intervention)} correct for intervention.")
    click.echo(
        f"Out of {sum(results_baseline)} correct baseline results, got {len(correct_change_correct)} "
        f"correct after intervention.")
    click.echo(
        f"Out of {len(correct_baseline_control)} correct for both baseline and control "
        f"got {len(correct_baseline_control_intervention)} correct after intervention."
    )
    click.echo(
        f"Interventions that the model 'ignored': {len(correct_keep_wrong)}")
    click.echo(
        f"Interventions that left the model 'confused': {len(correct_change_wrong)}"
    )
    click.echo(
        f"Wrong predictions that the model changed to correct: {len(wrong_change_right)}"
    )
    click.echo(
        f"Wrong predictions that the model didn't change but that became correct: {len(wrong_keep_right)}"
    )
Example #4
def main():
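    """Train and/or evaluate a T5 model (optionally model-parallel) for QA.

    Parses model/data/training arguments with HfArgumentParser, trains with
    Trainer (or MyTrainer when model parallelism is requested), evaluates the
    saved checkpoints, promotes the best one into the output directory, and
    writes the results or predictions there.
    """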
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataTrainingArguments
    training_args: TrainingArguments
    if training_args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    # if training_args.do_eval and not training_args.do_train and not data_args.predictions_folder:
    #     raise ValueError("Supply predictions folder destination to save the predictions!")
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )

    logger.debug(model_args)
    logger.debug(training_args)
    logger.debug(data_args)
    # raise NotImplementedError
    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            f"Use --overwrite_output_dir to overcome.")

    # Set seed
    set_seed(training_args.seed)
    if training_args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
    tokenizer = get_tokenizer(model_args.model_name_or_path,
                              do_lower_case=False)
    if data_args.model_parallel == 4:
        model = T5ForConditionalGeneration4WayParallel.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    elif data_args.model_parallel == 2:
        model = T5ForConditionalGeneration2WayParallel.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    elif data_args.model_parallel is None:
        model = T5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    else:
        raise ValueError(
            f"Can only have no, 2way or 4way model parallelism! (expected: {data_args.model_parallel})"
        )
    if training_args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
    # Get datasets
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        eval_dataset, examples = get_dataset(data_args.eval_file_path,
                                             tokenizer,
                                             data_args,
                                             evaluate=True)
    else:
        eval_dataset, examples = None, None
    # Training
    if training_args.do_train:
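        # rank 0 (or a single-process run) builds the training features and
        # caches them to disk; the other ranks wait at the barrier, then load
        # the cached features instead of re-tokenizing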
        if training_args.local_rank in [-1, 0]:
            train_dataset, _ = get_dataset(data_args.train_file_path,
                                           tokenizer, data_args)
            torch.save(train_dataset, 'features.bin')
        else:
            torch.distributed.barrier()
            train_dataset = None

        if training_args.local_rank == 0:
            torch.distributed.barrier()
        else:
            train_dataset = torch.load('features.bin')
        # Initialize our Trainer
        if data_args.model_parallel:
            trainer = MyTrainer(model=model,
                                args=training_args,
                                train_dataset=train_dataset,
                                eval_dataset=eval_dataset,
                                data_collator=collate_training,
                                prediction_loss_only=True)
            model.set_parallel()
        else:
            trainer = Trainer(model=model,
                              args=training_args,
                              train_dataset=train_dataset,
                              eval_dataset=eval_dataset,
                              data_collator=collate_training,
                              prediction_loss_only=True)
        trainer.train(model_path=model_args.model_name_or_path if os.path.
                      isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        if training_args.do_train:
            model_path = os.path.basename(training_args.output_dir)
        else:
            model_path = os.path.basename(model_args.model_name_or_path)
        checkpoints = [training_args.output_dir]
        if data_args.eval_all_checkpoints and training_args.do_train:
            logger.info(
                "Loading checkpoints saved during training for evaluation")
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(training_args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            # logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info(f"Evaluate the following checkpoints: {checkpoints}")
        results = {}

        logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)
        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1]
            if not all(s in string.digits for s in global_step):
                global_step = ''
            # no model parallelism here (didn't check model.generate)
            model = T5ForConditionalGeneration.from_pretrained(checkpoint)
            device = torch.device("cuda" if torch.cuda.is_available()
                                  and not training_args.no_cuda else "cpu")
            model.to(device)
            model_str = f'{model_path}-{global_step}' if global_step else model_path
            # Note that DistributedSampler samples
            click.echo(
                f"Generating predictions for model {click.style(model_str, fg='blue')}, "
                f"running on {click.style(str(training_args.device), fg='green')}"
            )
            predictions = generate_predictions(eval_dataset, examples, model,
                                               tokenizer, training_args)
            final_metric = squad_evaluate(examples, predictions)

            if is_wandb_available():
                if training_args.do_train:
                    step = int(
                        global_step) if global_step else trainer.global_step
                else:
                    step = 0
                # for now WANDB cannot 'log back in time'
                wandb.log(final_metric, step=step)
            print(f"GLOBAL STEP: {global_step}")
            result = dict(
                (k + ("_{}".format(global_step) if global_step else '_final'),
                 v) for k, v in final_metric.items())

            logger.info(f"Result for {model_str}: {result}")
            results.update(result)

        # sort results by best
        checkpoint_scores = {
            c.split('_')[-1]: v
            for c, v in results.items()
            if any(c.endswith(digit)
                   for digit in string.digits) and c.startswith('exact')
        }
        sorted_checkpoint_scores = {
            k: v
            for k, v in sorted(checkpoint_scores.items(),
                               key=lambda k_v: k_v[1],
                               reverse=True)
        }
        best_cp = next((c for c, v in sorted_checkpoint_scores.items()
                        if v > results['exact_final']), None)

        if best_cp:
            click.echo(f"Best checkpoint is: {best_cp}")
            # copy over best results
            best_cp_folder = f'checkpoint-{best_cp}'

            click.echo(
                f"Copying over files: from {os.path.join(training_args.output_dir, best_cp_folder)} "
                f"to {training_args.output_dir}")
            files_to_copy = glob.glob(
                os.path.join(training_args.output_dir, best_cp_folder, '*'))
            for file in files_to_copy:
                shutil.copy(file, training_args.output_dir)
        else:
            click.echo("best checkpoint is the last step...")
        # remove the checkpoint sub-folders from the output directory
        folders_to_remove = [
            p for p in glob.glob(os.path.join(training_args.output_dir, '*'))
            if os.path.isdir(p)
        ]
        click.echo('Folders to remove: ')
        for folder in folders_to_remove:
            click.echo(f"Removing {folder}")
            shutil.rmtree(folder)
        if training_args.do_train:
            logger.info(results)
            write_json(
                results,
                os.path.join(training_args.output_dir, 'dev-results.json'))
        else:
            write_json(
                predictions,
                get_output_predictions_file_name(
                    data_args.eval_file_path, training_args.output_dir,
                    os.path.basename(
                        os.path.normpath(model_args.model_name_or_path))))
Example #5
def finetune(optimize_consistency, evaluate_on, original_dev_dataset,
             runs_per_trial, hyperparam_opt_runs, out_file, mute,
             baseline_gold_file, hyperparams, keep_predictions,
             original_ans_length, **kwargs):
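    """Hyperparameter search with Ax around train_and_eval_single_step.

    Loads the baseline/intervention/control gold files and their features,
    evaluates the model once before training ('pre_eval'), then runs
    hyperparam_opt_runs Ax trials, reporting each trial's mean score back to
    the AxClient and finally writing the best parameters and all trial
    results to out_file.
    """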
    gold_files = get_baseline_intervention_control_from_baseline(
        baseline_gold_file)

    golds = tuple(load_json(g) for g in gold_files)
    # load eval gold for evaluation
    aligneds = align(*golds, assert_same=True)

    hyper_params = [{
        'name': hp['name'],
        'type': hp.get("type", 'range'),
        'bounds': hp['bounds'],
        'value_type': hp.get('value_type', 'float'),
        'log_scale': hp.get('log_scale', True)
    } for hp in json.loads(hyperparams)]

    logger.info(hyper_params)

    args = Args(**kwargs)

    args.debug_features = not mute
    tokenizer = get_tokenizer(args.model_path, args.do_lower_case)
    features = []
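    # despite its name, each entry is the (dataset, examples, features) triple
    # returned by load_or_convert, one per gold file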
    for f in gold_files:
        args.eval_file = f
        features.append(load_or_convert(args, tokenizer, evaluate=True))
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        kwargs['n_gpu'] = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        kwargs['n_gpu'] = 1
    kwargs['device'] = device
    args.n_gpu = kwargs['n_gpu']
    args.device = kwargs['device']
    if args.seed:
        set_seed(args)
    logger.debug(args)

    if args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    # load train dataset

    train_dataset, train_examples, train_features = load_or_convert(
        args, tokenizer)
    if not mute:
        debug_features_examples_dataset(train_dataset, train_examples,
                                        train_features, tokenizer)
    if original_dev_dataset:
        args.eval_file = original_dev_dataset
        original_dev_dataset = load_or_convert(args, tokenizer, evaluate=True)
    ax_client = AxClient()
    ax_client.create_experiment(
        name=f'{args.model_path}@{args.train_file}',
        parameters=hyper_params,
        objective_name=evaluate_on,
        minimize=False,
    )
    result = {
        "trials": [],
        "tried_params":
        defaultdict(list),
        "best_params":
        ...,
        'pre_eval':
        train_and_eval_single_step(args,
                                   train_dataset,
                                   *aligneds,
                                   *features,
                                   original_dev_dataset,
                                   *gold_files,
                                   run_nr='eval',
                                   train=False,
                                   evaluate_on=evaluate_on,
                                   original_ans_length=original_ans_length)
    }
    # first, evaluate and record the performance before any training

    click.echo(f"Results: {json.dumps(result['pre_eval'], indent=4)}")
    # run hyperparam optimisation
    predictions_folder = keep_predictions
    for i in trange(hyperparam_opt_runs):
        parameters, trial_index = ax_client.get_next_trial()
        logger.info(f"Trying parameters: {parameters}")
        single_step_args = deepcopy(kwargs)
        single_step_args.update(parameters)
        args = Args(**single_step_args)
        args.predictions_folder = str(predictions_folder)
        trial_result = train_and_eval_single_step(
            args,
            train_dataset,
            *aligneds,
            *features,
            original_dev_dataset,
            *gold_files,
            run_nr=i,
            num_runs=runs_per_trial,
            evaluate_on=evaluate_on,
            original_ans_length=original_ans_length)
        if optimize_consistency:
            assert evaluate_on == 'eoi'
            mean = trial_result['consistency']
        else:
            mean = trial_result['overall' if evaluate_on ==
                                'eoi' else 'EMRelaxed']
        if runs_per_trial > 1:
            mean, var, ci = mean
        if original_dev_dataset:
            logger.info(f"Mean: ({mean} * 100 + {trial_result['original']})/2")
            mean = (mean * 100 + trial_result['original']) / 2

        trial_result["mean"] = mean

        logger.info(f"Result: {mean}")
        logger.info(f"Results: {json.dumps(trial_result, indent=4)}")
        result["trials"].append(trial_result)
        result['tried_params'][i].append(parameters)
        ax_client.complete_trial(trial_index=trial_index, raw_data=mean)
    best_params, metrics = ax_client.get_best_parameters()
    result['best_params'] = best_params
    result['best_metrics'] = metrics
    click.echo(f"What is metrics? {metrics}")
    click.echo(json.dumps(result, indent=4))
    write_json(result, out_file)
Example #6
def train_and_eval_single_step(args: Args,
                               train_dataset,
                               aligned_baseline,
                               aligned_intervention,
                               aligned_control,
                               baseline_dataset,
                               intervention_dataset,
                               control_dataset,
                               original_dev_dataset,
                               baseline_gold_path,
                               intervention_gold_path,
                               control_gold_path,
                               run_nr=0,
                               train=True,
                               num_runs=1,
                               evaluate_on='eoi',
                               original_ans_length=30):
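    """Train (optionally) and evaluate the model num_runs times.

    Depending on evaluate_on, evaluates on the baseline/intervention/control
    datasets (and, if given, the original dev set), computes consistency and
    accuracy statistics via evaluate_intervention or EMRelaxed, and returns a
    single result dict when num_runs == 1, otherwise mean/variance/CI
    aggregates across runs.
    """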
    results = []
    for i in range(num_runs):
        set_seed(args)
        # load model
        model = get_model(args.model_path)
        tokenizer = get_tokenizer(args.model_path,
                                  do_lower_case=args.do_lower_case)
        model.to(args.device)
        # train

        if train:
            step, loss = do_train(args, train_dataset, model, tokenizer)

        args.eval_file = baseline_gold_path
        if evaluate_on == 'eoi' or evaluate_on == 'baseline':
            baseline_predictions = evaluate(args,
                                            model,
                                            tokenizer,
                                            *baseline_dataset,
                                            f'baseline-{run_nr}',
                                            return_raw=True)
        args.eval_file = intervention_gold_path

        if evaluate_on == 'eoi' or evaluate_on == 'intervention':
            intervention_predictions = evaluate(args,
                                                model,
                                                tokenizer,
                                                *intervention_dataset,
                                                f'intervention-{run_nr}',
                                                return_raw=True)

        args.eval_file = control_gold_path
        if evaluate_on == 'eoi':
            control_predictions = evaluate(args,
                                           model,
                                           tokenizer,
                                           *control_dataset,
                                           f'control-{run_nr}',
                                           return_raw=True)

            # obtain predictions on all three of them
            (overall_results, results_baseline, results_intervention,
             results_control, correct_before_intervention,
             correct_change_correct, correct_keep_wrong, correct_change_wrong,
             wrong_change_right, wrong_keep_right, correct_baseline_control,
             correct_baseline_control_intervention) = evaluate_intervention(
                 aligned_baseline, aligned_intervention, aligned_control,
                 baseline_predictions, intervention_predictions,
                 control_predictions)

            mean, *_ = get_mean_var_ci_bernoulli(overall_results)
            # there's no point to evaluate multiple times if not training

            result = {
                "overall":
                mean,
                'consistency':
                len(correct_change_correct) / len(aligned_baseline),
                'consistency+control':
                len(correct_baseline_control_intervention) /
                len(aligned_baseline),
                "acc_baseline":
                sum(results_baseline) / len(results_baseline),
                "acc_intervention":
                sum(results_intervention) / len(results_intervention),
                "acc_control":
                sum(results_control) / len(results_control),
                'correct->change->correct':
                len(correct_change_correct),
                'correct(baseline+control)/correct(baseline)':
                len(correct_baseline_control) / sum(results_baseline),
                'correct+control->change->correct':
                len(correct_baseline_control_intervention),
            }
        elif evaluate_on == 'baseline':
            metric_results = EMRelaxed(max_length=args.max_answer_length)(
                aligned_baseline, baseline_predictions)
            result = {
                "EMRelaxed": get_mean_var_ci_bernoulli(metric_results)[0]
            }
        elif evaluate_on == 'intervention':
            metric_results = EMRelaxed(max_length=args.max_answer_length)(
                aligned_intervention, intervention_predictions)
            result = {
                "EMRelaxed": get_mean_var_ci_bernoulli(metric_results)[0]
            }
        else:
            raise NotImplementedError()
        if original_dev_dataset is not None:
            # there is surely a better way to do this
            ans_length = args.max_answer_length
            args.max_answer_length = original_ans_length
            original_dev_result = evaluate(args,
                                           model,
                                           tokenizer,
                                           *original_dev_dataset,
                                           f'original-dev-{run_nr}',
                                           return_raw=False)
            args.max_answer_length = ans_length
            result['original'] = original_dev_result['exact']
        results.append(result)
    if num_runs == 1:
        return results[0]

    if evaluate_on == 'eoi':
        final_result = {
            "overall":
            get_mean_var_ci([r['overall'] for r in results]),
            "acc_baseline":
            get_mean_var_ci([r['acc_baseline'] for r in results]),
            "acc_intervention":
            get_mean_var_ci([r['acc_intervention'] for r in results]),
            "acc_control":
            get_mean_var_ci([r['acc_control'] for r in results]),
            'correct->change->correct':
            get_mean_var_ci([r['correct->change->correct'] for r in results]),
            'correct(baseline+control)/correct(baseline):':
            get_mean_var_ci([
                r['correct(baseline+control)/correct(baseline)']
                for r in results
            ]),
            'correct+control->change->correct':
            get_mean_var_ci(
                [r['correct+control->change->correct'] for r in results]),
        }
    else:
        final_result = {
            key: get_mean_var_ci([r[key] for r in results])
            for key in results[0].keys()
        }
    logger.info(final_result)

    return final_result