Code Example #1
import json


def compute_gen_metrics_wrapper(output_file_path):
    """ Computes evaluation metrics for generative models directly on model output. """
    def _read_jsonl(input_file):
        """ Reads a .jsonl file. """
        records = []
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                records.append(json.loads(line))
        return records

    # Read-in results file
    file_records = _read_jsonl(output_file_path)
    inputs = list()
    preds = list()
    targets = list()

    for frec in file_records:
        inputs.append(frec['prefix'])
        preds.append(frec['prediction'])
        targets.append(frec['target'])

    # Compute metrics
    gen_metrics = compute_gen_metrics(preds, targets)

    # Report
    print('***** Test results *****')
    for key in sorted(gen_metrics.keys()):
        print('  %s = %s' % (key, str(gen_metrics[key])))
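
A minimal usage sketch (the file name and record values below are hypothetical, and compute_gen_metrics is assumed to be importable from the surrounding module). Each line of the input .jsonl file must carry the 'prefix', 'target', and 'prediction' keys read by the wrapper:

import json

# Write a single hypothetical prediction record and score it
sample_record = {
    'prefix': 'Story prefix describing the situation ...',
    'target': 'Gold reference action.',
    'prediction': 'Model-generated action.'
}
with open('sample_predictions.jsonl', 'w', encoding='utf-8') as f:
    f.write(json.dumps(sample_record) + '\n')

compute_gen_metrics_wrapper('sample_predictions.jsonl')  # prints one line per metric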
Code Example #2
def action_gen_with_ranking(args, split):
    """ Generates moral / immoral actions by ranking a set of hypotheses. """

    # Use the pre-trained action generation model with nucleus sampling and a high p (e.g. 0.9) to generate N (e.g. 10)
    # actions from the same story prefix
    # Instantiate model and tokenizer
    args.p = 0.90
    args.action_draft_model_type = args.action_draft_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.action_draft_model_type]
    global_step = int(args.action_generator_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.action_generator_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(args.action_generator_checkpoint)
    model.to(args.device)

    specified_batch_size = args.per_gpu_eval_batch_size

    # Set up data-serving pipeline
    eval_dataset = load_and_cache_examples(args, tokenizer, split, 'action|context_gen', args.action_draft_model_type)
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)
    args.per_gpu_eval_batch_size = 1  # set batch size to 1
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Generate action predictions
    logging.info('\n' + '*' * 20)
    logging.info('Generating initial action predictions using the model from checkpoint {}'
                 .format(args.action_generator_checkpoint))
    logging.info('*' * 20 + '\n')

    action_predictions, generation_inputs, generation_targets = test_gen(args,
                                                                         model,
                                                                         tokenizer,
                                                                         eval_dataset,
                                                                         eval_dataloader,
                                                                         'action|context_gen',
                                                                         args.action_draft_model_type,
                                                                         split,
                                                                         global_step,
                                                                         gen_per_prefix=args.num_actions)

    # Sort action predictions
    action_pass_generations = {pass_id: list() for pass_id in range(args.num_actions)}
    for batch_generations in action_predictions:
        for pi in range(args.num_actions):
            action_pass_generations[pi] += batch_generations[pi]

    # Rank predicted actions using a pre-trained classifier
    logging.info('\n' + '*' * 20)
    logging.info('Ranking action predictions using a pre-trained classifier')
    logging.info('*' * 20 + '\n')
    # Load classifier
    args.action_classifier_model_type = args.action_classifier_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.action_classifier_model_type]
    global_step = int(args.action_classifier_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.action_classifier_checkpoint)
    try:
        tokenizer = tokenizer_class.from_pretrained(args.action_classifier_checkpoint)
    except Exception:
        tokenizer = tokenizer_class.from_pretrained('roberta-large')  # hack
    model.to(args.device)

    action_consequence_table = {ex_id: list() for ex_id in range(len(action_pass_generations[0]))}
    for pass_id in range(args.num_actions):
        initial_action_predictions = action_pass_generations[pass_id]
        # Set up data-serving pipeline
        eval_dataset = load_and_cache_examples(args, tokenizer, split, 'action+context_cls',
                                               args.action_classifier_model_type,
                                               predictions=['actions', initial_action_predictions])
        eval_sampler = SequentialSampler(eval_dataset)
        args.per_gpu_eval_batch_size = specified_batch_size
        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Obtain classifier predictions
        results, mean_eval_loss, softmax_scores = \
            evaluate(args, model, eval_dataset, eval_dataloader, 'action+context_cls',
                     args.action_classifier_model_type, split, global_step)
        # Assign scores to actions
        for act_id, action in enumerate(initial_action_predictions):
            score_id = 1 if act_id % 2 == 0 else 0
            action_consequence_table[act_id].append((action, None, softmax_scores[act_id][score_id]))

    # Return the action corresponding to the 'most good' consequence as generation output
    logging.info('\n' + '*' * 20)
    logging.info('Picking the best actions from the predicted alternatives')
    logging.info('*' * 20 + '\n')
    # Log predictions
    best_predictions = list()
    output_pred_file = \
        os.path.join(args.output_dir, 'best_ranked_actions_{}.lst'.format(split))
    with open(output_pred_file, 'w') as writer:
        logger.info('***** Write predictions *****')
        for story_id, ac_list in action_consequence_table.items():
            # Sort action predictions
            sorted_actions = sorted(ac_list, reverse=True, key=lambda x: x[2])
            best_predictions.append(sorted_actions[0][0])
            # Write to file
            writer.write(json.dumps({'prefix': generation_inputs[story_id],
                                     'target': generation_targets[story_id],
                                     'prediction': sorted_actions[0][0]}) + '\n')
            if story_id < 10:
                logging.info('***** Example ranking *****')
                logging.info('Story prefix: {:s}'.format(generation_inputs[story_id]))
                logging.info('Gold reference action: {:s}'.format(generation_targets[story_id]))
                logging.info('Ranked action predictions by the generator:')
                for tpl_id, tpl in enumerate(sorted_actions):
                    logging.info('-' * 10)
                    logging.info('Rank {:d}'.format(tpl_id))
                    logging.info('Predicted action: {:s}'.format(tpl[0]))
                    logging.info('Anticipated consequence: {:s}'.format(str(tpl[1])))
                    logging.info('Score: {:.4f}'.format(tpl[2]))

    # Compute and update evaluation metric values
    best_result = compute_gen_metrics(best_predictions, generation_targets)
    # Log metrics
    output_eval_file = os.path.join(args.output_dir, 'generation_test_results_{}_{}_{}.txt'.format(
        'action_refinement', split, global_step))
    with open(output_eval_file, 'w') as writer:
        logger.info('***** Test results (best actions) *****')
        writer.write('STEP: {:s}\n'.format(str(global_step)))
        for key in sorted(best_result.keys()):
            logger.info('  %s = %s', key, str(best_result[key]))
            writer.write('%s = %s\n' % (key, str(best_result[key])))
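
The selection logic above reduces to a per-story sort over (action, anticipated_consequence, classifier_score) tuples; the alternating score_id appears to assume that evaluation examples alternate between moral and immoral action targets, so even-indexed examples are scored against one class and odd-indexed ones against the other. A self-contained sketch with made-up scores:

# Hypothetical candidate table: story id -> [(action, anticipated_consequence, score), ...]
example_table = {
    0: [('keep the wallet', None, 0.08),
        ('return the wallet', None, 0.91),
        ('ignore the wallet', None, 0.23)],
}
best_actions = {story_id: sorted(candidates, key=lambda x: x[2], reverse=True)[0][0]
                for story_id, candidates in example_table.items()}
print(best_actions)  # {0: 'return the wallet'}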
Code Example #3
def test_gen(args, model, tokenizer, dataset, dataloader, task_name, model_type, split, step, gen_per_prefix=1):
    """ Evaluates generative models. """

    # Test!
    logger.info('***** Testing generation on the test set *****')
    logger.info('  Num examples = %d', len(dataset))
    logger.info('  Batch size = %d', args.eval_batch_size)
    generation_inputs = list()
    generation_targets = list()
    generated_sequences = list()

    # Iterate through the test corpus
    model.eval()
    for batch_id, batch in enumerate(tqdm(dataloader, desc='Testing', mininterval=10, ncols=100)):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'labels': batch[3],
                      'gen_prompt': batch[5]}

            # Modify inputs (assumes a batch size of 1)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            if model_type == 'gpt2' or 'action' in task_name:
                pad_token_id = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
                if pad_token_id is None:
                    pad_token_id = tokenizer.convert_tokens_to_ids([tokenizer.eos_token])[0]
                try:
                    input_ids = inputs['input_ids'].tolist()
                    first_pad_idx = input_ids[0].index(pad_token_id)
                    input_ids = torch.tensor([input_ids[0][: first_pad_idx]], dtype=torch.long).to(args.device)
                    attention_mask = inputs['attention_mask'].tolist()
                    attention_mask = \
                        torch.tensor([attention_mask[0][: first_pad_idx]], dtype=torch.long).to(args.device)
                except ValueError:
                    input_ids = inputs['input_ids']
                    attention_mask = inputs['attention_mask']

            max_gen_length = args.max_gen_length
            if model_type == 'gpt2':
                max_gen_length += torch.max(torch.sum(attention_mask, axis=-1)).item()
            batch_generations = list()
            for _ in range(gen_per_prefix):
                if model_type == 'gpt2':
                    outputs = model.generate(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             min_length=5,
                                             max_length=max_gen_length,
                                             temperature=args.temperature,
                                             top_k=args.k if args.k > 0 else None,
                                             top_p=args.p if args.p > 0 else None,
                                             num_beams=args.num_beams if args.num_beams > 0 else None,
                                             do_sample=args.do_sample,
                                             early_stopping=True,
                                             no_repeat_ngram_size=3)
                else:
                    gen_prompt = \
                        inputs['gen_prompt'].item() if 'action|' in task_name else inputs['gen_prompt'][0].item()
                    outputs = model.generate(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             min_length=5,
                                             max_length=max_gen_length,
                                             temperature=args.temperature,
                                             top_k=args.k if args.k > 0 else None,
                                             top_p=args.p if args.p > 0 else None,
                                             num_beams=args.num_beams if args.num_beams > 0 else None,
                                             do_sample=args.do_sample,
                                             early_stopping=True,
                                             no_repeat_ngram_size=3,
                                             decoder_start_token_id=gen_prompt)

                # Remove the batch dimension when returning multiple sequences
                if len(outputs.shape) > 2:
                    outputs.squeeze_()
                batch_generations.append(outputs)

        # Convert model predictions to text sequences
        input_ids = inputs['input_ids'].tolist()
        target_ids = inputs['labels'].tolist()
        # Post-process model predictions and prediction targets
        batch_predictions = list()
        len_gen_input = 0
        for pass_id, pass_output in enumerate(batch_generations):
            pass_predictions = list()
            for generated_sequence_idx, generated_sequence in enumerate(pass_output):
                generated_sequence = generated_sequence.tolist()
                # GPT2
                if model_type == 'gpt2':
                    if pass_id == 0:
                        # Prepare inputs
                        gen_input = input_ids[generated_sequence_idx]
                        try:
                            gen_input = gen_input[: gen_input.index(tokenizer.eos_token_id)]
                        except ValueError:
                            pass
                        generation_inputs.append(tokenizer.decode(gen_input, clean_up_tokenization_spaces=True))
                        len_gen_input = len(gen_input)

                    # Prepare predictions
                    generated_sequence = generated_sequence[len_gen_input:]
                    try:
                        if generated_sequence.index(tokenizer.eos_token_id) == 0:
                            generated_sequence = generated_sequence[1:]
                        generated_sequence = generated_sequence[: generated_sequence.index(tokenizer.eos_token_id)]
                    except ValueError:
                        pass
                    pass_predictions.append(tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True))

                    if pass_id == 0:
                        # Prepare generation targets
                        gen_target = target_ids[generated_sequence_idx][len_gen_input:]
                        try:
                            gen_target = gen_target[: gen_target.index(tokenizer.eos_token_id)]
                        except ValueError:
                            pass
                        generation_targets.append(tokenizer.decode(gen_target, clean_up_tokenization_spaces=True))

                # For T5 / BART, split off the initial decoder start (<pad>) token
                if model_type in ['t5', 'bart']:
                    # Prepare predictions
                    try:
                        generated_sequence = generated_sequence[: generated_sequence.index(tokenizer.eos_token_id)]
                    except ValueError:
                        pass
                    pass_predictions.append(
                        tokenizer.decode(generated_sequence[1:], clean_up_tokenization_spaces=True))

                    if pass_id == 0:
                        # Prepare inputs
                        gen_input = input_ids[generated_sequence_idx]
                        try:
                            if model_type == 't5':
                                gen_input = gen_input[: gen_input.index(tokenizer.eos_token_id)]
                            else:
                                gen_input = gen_input[: gen_input.index(tokenizer.pad_token_id)]
                        except ValueError:
                            pass
                        generation_inputs.append(tokenizer.decode(gen_input, clean_up_tokenization_spaces=True))

                        # Prepare generation targets
                        gen_target = target_ids[generated_sequence_idx]
                        try:
                            gen_target = gen_target[: gen_target.index(tokenizer.eos_token_id)]
                        except ValueError:
                            pass
                        generation_targets.append(tokenizer.decode(gen_target[1:], clean_up_tokenization_spaces=True))

            batch_predictions.append(pass_predictions)
        generated_sequences.append(batch_predictions)

    # Report sample generation results
    first_pass_predictions = list()
    for bp in generated_sequences:
        first_pass_predictions += bp[0]
    logging.info('***** Example generations (first pass only) *****')
    for s_id, gen_input in enumerate(generation_inputs):
        if s_id >= 10:
            break
        logging.info('  Inputs: {:s}'.format(gen_input))
        logging.info('  Reference: {:s}'.format(generation_targets[s_id]))
        logging.info('  Prediction: {:s}'.format(first_pass_predictions[s_id]))

    # Compute and update evaluation metric values
    curr_result = compute_gen_metrics(first_pass_predictions, generation_targets)

    # Log metrics
    output_eval_file = \
        os.path.join(args.output_dir, 'generation_test_results_{}_{}_{}.txt'.format(task_name, split, step))
    with open(output_eval_file, 'w') as writer:
        logger.info('***** Test results (first pass only) *****')
        writer.write('STEP: {:s}\n'.format(str(step)))
        for key in sorted(curr_result.keys()):
            logger.info('  %s = %s', key, str(curr_result[key]))
            writer.write('%s = %s\n' % (key, str(curr_result[key])))

    # Log predictions
    output_pred_file = \
        os.path.join(args.output_dir, 'generation_test_predictions_{}_{}_{}.lst'.format(task_name, split, step))
    with open(output_pred_file, 'w') as writer:
        logger.info('***** Write predictions (first pass only) *****')
        for gsi, gs in enumerate(first_pass_predictions):
            writer.write(json.dumps({'prefix': generation_inputs[gsi],
                                     'target': generation_targets[gsi],
                                     'prediction': gs}) + '\n')

    # For convenience, return a flat list of predictions when only a single generation pass was requested
    if gen_per_prefix == 1:
        generated_sequences = first_pass_predictions

    return generated_sequences, generation_inputs, generation_targets
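
# Note on shapes (inferred from the loops above, not stated in the original code):
# with gen_per_prefix > 1, test_gen returns its first value as a nested list indexed
# [batch][pass][example_in_batch]; with gen_per_prefix == 1 it instead returns the flat
# list of first-pass predictions. The ranking functions above and below regroup the
# nested output into one flat list per sampling pass, e.g.:
#   pass_generations = {p: [] for p in range(num_passes)}
#   for batch_generations in predictions:
#       for p in range(num_passes):
#           pass_generations[p] += batch_generations[p]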


def action_refinement_with_ranking(args, split):
    """ Generates moral / immoral actions by taking into account their anticipated consequences. """

    # Generate action draft hypotheses
    args.p = 0.90
    args.action_generator_model_type = args.action_generator_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[
        args.action_generator_model_type]
    global_step = int(args.action_generator_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.action_generator_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(
        args.action_generator_checkpoint)
    model.to(args.device)

    specified_batch_size = args.per_gpu_eval_batch_size

    # Set up data-serving pipeline
    eval_dataset = load_and_cache_examples(args, tokenizer, split,
                                           'action|context_gen',
                                           args.action_generator_model_type)
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)
    args.per_gpu_eval_batch_size = 1  # set batch size to 1
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Generate action predictions
    logging.info('\n' + '*' * 20)
    logging.info(
        'Generating initial action predictions using the model from checkpoint {}'
        .format(args.action_generator_checkpoint))
    logging.info('*' * 20 + '\n')

    action_predictions, generation_inputs, generation_targets = test_gen(
        args,
        model,
        tokenizer,
        eval_dataset,
        eval_dataloader,
        'action|context_gen',
        args.action_generator_model_type,
        split,
        global_step,
        gen_per_prefix=args.num_actions)

    # ==================================================================================================================

    # Sort action predictions
    action_pass_generations = {
        pass_id: list()
        for pass_id in range(args.num_actions)
    }
    for batch_generations in action_predictions:
        for pi in range(args.num_actions):
            action_pass_generations[pi] += batch_generations[pi]

    # Rank predicted actions using a pre-trained classifier
    logging.info('\n' + '*' * 20)
    logging.info(
        'Ranking initial action predictions using a pre-trained classifier')
    logging.info('*' * 20 + '\n')
    # Load classifier
    args.action_classifier_model_type = args.action_classifier_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[
        args.action_classifier_model_type]
    global_step = int(args.action_classifier_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.action_classifier_checkpoint)
    try:
        tokenizer = tokenizer_class.from_pretrained(
            args.action_classifier_checkpoint)
    except Exception:
        tokenizer = tokenizer_class.from_pretrained('roberta-large')  # hack
    model.to(args.device)

    action_score_table = {
        ex_id: list()
        for ex_id in range(len(action_pass_generations[0]))
    }
    for pass_id in range(args.num_actions):
        initial_action_predictions = action_pass_generations[pass_id]
        # Set up data-serving pipeline
        eval_dataset = load_and_cache_examples(
            args,
            tokenizer,
            split,
            'action+context_cls',
            args.action_classifier_model_type,
            predictions=['actions', initial_action_predictions])
        eval_sampler = SequentialSampler(eval_dataset)
        args.per_gpu_eval_batch_size = specified_batch_size
        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Obtain classifier predictions
        results, mean_eval_loss, softmax_scores = \
            evaluate(args, model, eval_dataset, eval_dataloader, 'action+context_cls',
                     args.action_classifier_model_type, split, global_step)
        # Assign scores to actions
        for act_id, action in enumerate(initial_action_predictions):
            score_id = 1 if act_id % 2 == 0 else 0
            action_score_table[act_id].append(
                (action, None, softmax_scores[act_id][score_id]))

    # ==================================================================================================================

    # Identify best predicted action
    logging.info('\n' + '*' * 20)
    logging.info('Picking the best actions from the predicted alternatives')
    logging.info('*' * 20 + '\n')
    # Log predictions
    best_action_predictions = list()
    for story_id, ac_list in action_score_table.items():
        # Sort action predictions
        sorted_actions = sorted(ac_list, reverse=True, key=lambda x: x[2])
        best_action_predictions.append(sorted_actions[0][0])

    # ==================================================================================================================

    # Include model generations into the input data for the consequence|story+action generator
    args.consequence_generator_model_type = args.consequence_generator_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[
        args.consequence_generator_model_type]
    global_step = int(args.consequence_generator_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.consequence_generator_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(
        args.consequence_generator_checkpoint)
    model.to(args.device)

    # Set up data-serving pipeline
    eval_dataset = load_and_cache_examples(
        args,
        tokenizer,
        split,
        'consequence|action+context_gen',
        args.consequence_generator_model_type,
        predictions=['actions', best_action_predictions])
    eval_sampler = SequentialSampler(eval_dataset)
    if args.consequence_generator_model_type == 'gpt2':
        args.per_gpu_eval_batch_size = 1
    else:
        args.per_gpu_eval_batch_size = specified_batch_size
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Generate consequences for previously predicted actions using the consequence|story+action generator
    logging.info('\n' + '*' * 20)
    logging.info(
        'Generating consequences for the initial actions using the model from checkpoint {}'
        .format(args.consequence_generator_checkpoint))
    logging.info('*' * 20 + '\n')

    consequence_predictions, generation_inputs, generation_targets = test_gen(
        args,
        model,
        tokenizer,
        eval_dataset,
        eval_dataloader,
        'consequence|action+context_gen',
        args.consequence_generator_model_type,
        split,
        global_step,
        gen_per_prefix=args.num_actions)

    # ==================================================================================================================

    # Sort consequence predictions
    consequence_pass_generations = {
        pass_id: list()
        for pass_id in range(args.num_actions)
    }
    for batch_generations in consequence_predictions:
        for pi in range(args.num_actions):
            consequence_pass_generations[pi] += batch_generations[pi]

    # Rank predicted consequences using a pre-trained classifier
    logging.info('\n' + '*' * 20)
    logging.info(
        'Ranking synthetic consequence predictions using a pre-trained classifier'
    )
    logging.info('*' * 20 + '\n')
    # Load classifier
    args.consequence_classifier_model_type = args.consequence_classifier_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[
        args.consequence_classifier_model_type]
    global_step = int(args.consequence_classifier_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.consequence_classifier_checkpoint)
    try:
        tokenizer = tokenizer_class.from_pretrained(
            args.consequence_classifier_checkpoint)
    except Exception:
        tokenizer = tokenizer_class.from_pretrained('roberta-large')  # hack
    model.to(args.device)

    consequence_score_table = {
        ex_id: list()
        for ex_id in range(len(consequence_pass_generations[0]))
    }
    for pass_id in range(args.num_actions):
        initial_consequence_predictions = consequence_pass_generations[pass_id]
        # Set up data-serving pipeline
        eval_dataset = load_and_cache_examples(
            args,
            tokenizer,
            split,
            'consequence+action+context_cls',
            args.consequence_classifier_model_type,
            predictions=[('actions', best_action_predictions),
                         ('consequences', initial_consequence_predictions)])
        eval_sampler = SequentialSampler(eval_dataset)
        args.per_gpu_eval_batch_size = specified_batch_size
        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Obtain classifier predictions
        results, mean_eval_loss, softmax_scores = \
            evaluate(args, model, eval_dataset, eval_dataloader, 'consequence+action+context_cls',
                     args.consequence_classifier_model_type, split, global_step)
        # Assign scores to consequences
        for csq_id, consequence in enumerate(initial_consequence_predictions):
            consequence_score_table[csq_id].append(
                (consequence, None, softmax_scores[csq_id][1]))

    # ==================================================================================================================

    # Identify best predicted consequence
    logging.info('\n' + '*' * 20)
    logging.info(
        'Picking the best consequences from the predicted alternatives')
    logging.info('*' * 20 + '\n')
    # Log predictions
    best_consequence_predictions = list()
    for story_id, ac_list in consequence_score_table.items():
        # Sort consequence predictions
        sorted_consequences = sorted(ac_list, reverse=True, key=lambda x: x[2])
        best_consequence_predictions.append(sorted_consequences[0][0])

    # ==================================================================================================================

    # Include model generations into the input data for the action|story+consequence generator
    args.action_refiner_model_type = args.action_refiner_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[
        args.action_refiner_model_type]
    global_step = int(args.action_refiner_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.action_refiner_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(args.action_refiner_checkpoint)
    model.to(args.device)

    # Set up data-serving pipeline
    eval_dataset = load_and_cache_examples(
        args,
        tokenizer,
        split,
        'action|context+consequence_gen',
        args.action_refiner_model_type,
        predictions=['consequences', best_consequence_predictions])
    eval_sampler = SequentialSampler(eval_dataset)
    args.per_gpu_eval_batch_size = 1  # set batch size to 1
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Generate action predictions
    logging.info('\n' + '*' * 20)
    logging.info(
        'Generating refined action predictions using the model from checkpoint {}'
        .format(args.action_refiner_checkpoint))
    logging.info('*' * 20 + '\n')

    action_refinements, generation_inputs, generation_targets = test_gen(
        args,
        model,
        tokenizer,
        eval_dataset,
        eval_dataloader,
        'action|context+consequence_gen',
        args.action_refiner_model_type,
        split,
        global_step,
        gen_per_prefix=args.num_actions)

    # ==================================================================================================================

    # Sort action predictions
    action_pass_generations = {
        pass_id: list()
        for pass_id in range(args.num_actions)
    }
    for batch_generations in action_refinements:
        for pi in range(args.num_actions):
            action_pass_generations[pi] += batch_generations[pi]

    # Rank predicted actions using a pre-trained classifier
    logging.info('\n' + '*' * 20)
    logging.info(
        'Ranking refined action predictions using a pre-trained classifier')
    logging.info('*' * 20 + '\n')
    # Load classifier
    args.action_classifier_model_type = args.action_classifier_model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[
        args.action_classifier_model_type]
    global_step = int(args.action_classifier_checkpoint.split('-')[-1])
    model = model_class.from_pretrained(args.action_classifier_checkpoint)
    try:
        tokenizer = tokenizer_class.from_pretrained(
            args.action_classifier_checkpoint)
    except Exception:
        tokenizer = tokenizer_class.from_pretrained('roberta-large')  # hack
    model.to(args.device)

    refined_action_score_table = {
        ex_id: list()
        for ex_id in range(len(action_pass_generations[0]))
    }
    for pass_id in range(args.num_actions):
        refined_action_predictions = action_pass_generations[pass_id]
        # Set up data-serving pipeline
        eval_dataset = load_and_cache_examples(
            args,
            tokenizer,
            split,
            'action+context_cls',
            args.action_classifier_model_type,
            predictions=['actions', refined_action_predictions])
        eval_sampler = SequentialSampler(eval_dataset)
        args.per_gpu_eval_batch_size = specified_batch_size
        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Obtain classifier predictions
        results, mean_eval_loss, softmax_scores = \
            evaluate(args, model, eval_dataset, eval_dataloader, 'action+context_cls',
                     args.action_classifier_model_type, split, global_step)
        # Assign scores to actions
        for act_id, action in enumerate(refined_action_predictions):
            score_id = 1 if act_id % 2 == 0 else 0
            refined_action_score_table[act_id].append(
                (action, None, softmax_scores[act_id][score_id]))

    # ==================================================================================================================

    # Identify best predicted action
    logging.info('\n' + '*' * 20)
    logging.info(
        'Picking the best refined actions from the predicted alternatives')
    logging.info('*' * 20 + '\n')
    # Log predictions
    best_refined_action_predictions = list()
    for story_id, ac_list in refined_action_score_table.items():
        # Sort action predictions
        sorted_actions = sorted(ac_list, reverse=True, key=lambda x: x[2])
        best_refined_action_predictions.append(sorted_actions[0][0])

    # Compute and update evaluation metric values
    best_result = compute_gen_metrics(best_refined_action_predictions,
                                      generation_targets)
    # Log metrics
    output_eval_file = os.path.join(
        args.output_dir,
        'generation_test_results_{}_{}_{}.txt'.format('action_refinement',
                                                      split, global_step))
    with open(output_eval_file, 'w') as writer:
        logger.info('***** Test results (best actions) *****')
        writer.write('STEP: {:s}\n'.format(str(global_step)))
        for key in sorted(best_result.keys()):
            logger.info('  %s = %s', key, str(best_result[key]))
            writer.write('%s = %s\n' % (key, str(best_result[key])))

    # ==================================================================================================================

    output_pred_file = os.path.join(
        args.output_dir, 'best_ranked_refined_actions_{}.lst'.format(split))
    with open(output_pred_file, 'w') as writer:
        logger.info('***** Write predictions *****')
        for pred_id, pred in enumerate(best_refined_action_predictions):
            # Write to file
            writer.write(
                json.dumps({
                    'prefix': generation_inputs[pred_id],
                    'target': generation_targets[pred_id],
                    'prediction': pred
                }) + '\n')

    # Compare initial and refined action samples
    logging.info('***** Action refinement outcomes *****')
    for ia_id, ia in enumerate(best_action_predictions):
        if ia_id >= 10:
            break
        logging.info('  Initial action: {:s}'.format(ia))
        logging.info('  Consequence of initial action: {:s}'.format(
            best_consequence_predictions[ia_id]))
        logging.info('  Refined action: {:s}'.format(
            best_refined_action_predictions[ia_id]))
        logging.info('-' * 20 + '\n')

    return best_action_predictions, best_refined_action_predictions
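
A hypothetical driver sketch for the full refinement pipeline. The attribute names mirror those accessed in the functions above; the model types, checkpoint paths, and decoding settings are placeholders rather than the authors' actual configuration, and helpers such as load_and_cache_examples and evaluate may require additional fields not shown here:

from argparse import Namespace

# Placeholder configuration; checkpoint paths must end in '-<step>' because the
# functions above parse the global step from the path suffix.
args = Namespace(
    device='cuda', n_gpu=1, local_rank=-1, output_dir='output/',
    per_gpu_eval_batch_size=8, num_actions=10,
    max_gen_length=60, temperature=1.0, k=0, p=0.9, num_beams=0, do_sample=True,
    action_generator_model_type='bart',
    action_generator_checkpoint='checkpoints/action_generator/checkpoint-1000',
    action_classifier_model_type='roberta',
    action_classifier_checkpoint='checkpoints/action_classifier/checkpoint-2000',
    consequence_generator_model_type='bart',
    consequence_generator_checkpoint='checkpoints/consequence_generator/checkpoint-1500',
    consequence_classifier_model_type='roberta',
    consequence_classifier_checkpoint='checkpoints/consequence_classifier/checkpoint-2500',
    action_refiner_model_type='bart',
    action_refiner_checkpoint='checkpoints/action_refiner/checkpoint-3000',
)
initial_actions, refined_actions = action_refinement_with_ranking(args, 'test')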