Example #1
# Imports assumed by this example. The stdlib and core ParlAI paths are
# standard; the remaining helpers (TimeLogger, normalize_answer,
# get_word_stats, initialize_control_information, ConvAI2History,
# ATTR2SENTSCOREFN, update_sent_attr_stats) come from ParlAI and the
# controllable_dialogue project and have version-dependent module paths,
# so they are assumed to be in scope.
import copy
import json
import os
import random
import time
from collections import Counter

from parlai.core.agents import create_agent
from parlai.core.dict import DictionaryAgent
from parlai.core.worlds import create_task


def eval_wordstat(opt):
    """
    Evaluates a model, collecting word statistics and sentence-level
    attribute scores for its generated responses, and writes the results
    to a JSON file.

    :param opt: tells the evaluation function how to run
    """
    random.seed(42)

    # Setup control information
    initialize_control_information(opt)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    data = {}  # This will be written to the output json file
    data['opt'] = agent.opt  # Save the opt to json

    # Determine the output filename
    if opt['gold_response']:  # Special output file for gold response
        model_dir, _ = os.path.split(opt.get('model_file'))
        outfile = os.path.join(model_dir, 'goldresponse')
        if opt['use_reply'] != 'label':
            raise ValueError(
                'You should set --use-reply label (not --use-reply model) '
                'when measuring goldresponse stats'
            )
    else:
        outfile = "%s.%s.%s.%s" % (
            opt.get('model_file'),
            opt.get('datatype'),
            "use%sreply" % agent.opt['use_reply'],
            "beam%i" % agent.opt['beam_size'],
        )
        if agent.opt['beam_size'] > 1:
            outfile += ".beamminnbest%i" % agent.opt['beam_min_n_best']
        if len(agent.control_settings) > 0:
            outfile += ".setcontrols:" + "_".join(
                [
                    "%s%s" % (c, str(agent.control_settings[c]['set_value']))
                    for c in sorted(agent.control_settings.keys())
                ]
            )
        if agent.opt['beam_reorder'] not in ['none', False]:
            outfile += ".beamreorder_%s" % agent.opt['beam_reorder']
        if len(agent.wd_features) > 0:
            sorted_bfw = sorted(
                list(zip(agent.wd_features, agent.wd_wts)), key=lambda x: x[0]
            )
            outfile += ".WDfeatures:" + "_".join(
                ["%s%s" % (f, str(w)) for f, w in sorted_bfw]
            )
    if opt['num_examples'] != -1:
        outfile += ".numex%i" % opt['num_examples']
    outfile += ".wordstats.json"
    print("\nOutfile: %s\n" % outfile)

    cnt = 0
    word_statistics = {
        'mean_wlength': [],  # list of length (in words) of utterances
        'mean_clength': [],  # list of length (in chars) of utterances
        'freqs_cnt': Counter(),  # Counter for word frequencies, bucketed
        'word_cnt': 0,  # total number of words in all utterances
        'pred_list': [],  # list of generated utterances after applying normalize_answer
        'pure_pred_list': [],  # list of generated utterances
        'context_list': [],  # list of text inputs (persona and conversation history)
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]
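    # get_word_stats (used below) buckets each predicted word by its frequency
    # in the dictionary against these thresholds; freqs_cnt accumulates those
    # bucket counts across all predictions.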

    # This dictionary records all the sentence-level controllable attributes
    # For each attribute, we have a list of all the values
    sent_attrs = {attr: [] for attr in ATTR2SENTSCOREFN.keys()}  # str to list of floats

    # histories will be a list of ConvAI2History objects
    histories = []

    def process_prediction(prediction, word_statistics):
        word_statistics['pred_list'].append(normalize_answer(prediction))
        freqs, _cnt, wlength, clength = get_word_stats(
            prediction, dictionary, bins=bins
        )
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        return word_statistics

    t0 = time.time()
    while not world.epoch_done():
        world.parley()
        # The original eval_wordstat.py handles bsz=1, but for simplicity we
        # assume bsz>1 here
        assert batch_size != 1
        for w in world.worlds:
            try:
                try:
                    response_act = w.acts[-1]
                    prediction = response_act['text']
                except KeyError:
                    continue
                if opt['gold_response']:
                    # If we're measuring gold response, use eval_label as prediction
                    prediction = w.acts[0]['eval_labels'][0]
                    response_act = {'text': prediction}
                word_statistics['context_list'].append(w.acts[0]['text'])
                word_statistics['pure_pred_list'].append(prediction)
            except IndexError:
                continue
            cnt += 1
            word_statistics = process_prediction(prediction, word_statistics)

            # Compute and record sentence-level attributes
            history = ConvAI2History(w.acts[0]['text'])
            histories.append(history)
            sent_attrs = update_sent_attr_stats(sent_attrs, history, prediction)

        # Periodically log some info
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(), report)
            print(text)

        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    print("Time to process %i examples: %f seconds" % (cnt, time.time() - t0))

    # Compute percent unique
    # Note this is w.r.t. the normalized pred_list, not the original pure_pred_list
    cntr = Counter(word_statistics['pred_list'])
    unique_list = [k for k, v in cntr.items() if v == 1]
    unique_percent = len(unique_list) / len(word_statistics['pred_list']) * 100

    # Print a final report
    report = world.report()
    if opt['gold_response']:
        report['ppl'] = 0.0  # For gold responses, overwrite the perplexity
    print(report)

    # Put all information in data dict
    data['unique_percent'] = unique_percent  # percent of all responses that are unique
    data['word_statistics'] = word_statistics  # word stats, as in orig eval_wordstat
    data['report'] = report  # the final report
    data['histories'] = [
        (hist.persona_lines, hist.partner_utts, hist.own_utts) for hist in histories
    ]  # history for each example
    data['sent_attrs'] = sent_attrs  # all sentence attribute values for responses

    # Write data to outfile
    print("Writing to %s..." % outfile)
    with open(outfile, 'w') as f:
        json.dump(data, f)
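
A minimal invocation sketch for the function above. The flag names are assumptions that mirror the opt keys eval_wordstat reads ('external_dict', 'gold_response', 'use_reply', 'freq_bins', 'num_examples', 'log_every_n_secs'); the real controllable_dialogue project registers its flags in its own setup_args, which is not shown here.

from parlai.core.params import ParlaiParser

if __name__ == '__main__':
    parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
    # Hypothetical flag registrations; the project's actual parser may differ.
    parser.add_argument('--external-dict', type=str, default=None)
    parser.add_argument('--gold-response', type='bool', default=False)
    parser.add_argument('--use-reply', type=str, default='label')
    parser.add_argument('--freq-bins', type=str, default='0,100,1000,10000')
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--log-every-n-secs', type=float, default=10)
    eval_wordstat(parser.parse_args())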
Example #2
# Imports assumed by this example. RepeatLabelAgent and create_task use
# standard ParlAI paths; TimeLogger, msg_to_str, and the controllable_dialogue
# names (initialize_control_information, ConvAI2History, eval_attr) have
# version-dependent module paths and are assumed to be in scope.
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task


def make_dataset(opt):
    """
    Reads examples from the given task, computes the requested control
    values for each response, and writes the annotated examples to
    opt['outfile'] in ParlAI text format.

    :param opt: tells the dataset creation how to run
    """
    # Initialize control information so we can compute sentence attributes.
    # Here we set build_task=False so we don't download data/controllable_dialogue
    # (because we're trying to create it instead).
    initialize_control_information(opt, build_task=False)

    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    ignorefields = opt.get('ignore_fields', '')
    outfile = opt['outfile']

    # Number of examples to process
    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']

    # List of controls to include:
    controls = opt['controls'].split(',') if opt['controls'] != '' else []
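    # e.g. (hypothetical) --controls avg_nidf,question,lastuttsim gives
    # controls = ['avg_nidf', 'question', 'lastuttsim']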

    print('[ starting to convert.. ]')
    print('[ saving output to {} ]'.format(outfile))
    fw = open(outfile, 'w')
    log_timer = TimeLogger()

    for _ in range(num_examples):
        world.parley()
        world.acts[0]['labels'] = world.acts[0].get(
            'labels', world.acts[0].pop('eval_labels', None))

        # Need to get history in order to compute control values
        hist = ConvAI2History(world.acts[0]['text'], assume_persontokens=False)
        response = world.acts[0]['labels'][0]

        # Compute control values
        for ctrl in controls:
            ctrl_val = eval_attr(response, hist, ctrl)
            if ctrl == 'avg_nidf':
                assert 0 <= ctrl_val <= 1
            elif ctrl == 'question':
                assert ctrl_val in [0, 1]
            elif ctrl == 'lastuttsim':
                if ctrl_val is not None:
                    assert -1 <= ctrl_val <= 1
            else:
                raise ValueError('unexpected ctrl name: %s' % ctrl)
            world.acts[0][ctrl] = ctrl_val  # add control value to act

        # Write to file
        txt = msg_to_str(world.acts[0], ignore_fields=ignorefields)
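        # txt is one line of tab-separated field:value pairs, e.g. (hypothetical):
        # "text:hi there\tlabels:hello !\tavg_nidf:0.37\tquestion:0\tlastuttsim:0.12"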
        fw.write(txt + '\n')
        if world.acts[0].get('episode_done', False):
            fw.write('\n')

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            print(text)

        if world.epoch_done():
            print('EPOCH DONE')
            break
    fw.close()
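
A matching invocation sketch for make_dataset, under the same caveat: the flag names below are assumptions that mirror the opt keys the function reads, not the project's actual setup_args.

from parlai.core.params import ParlaiParser

if __name__ == '__main__':
    parser = ParlaiParser()
    # Hypothetical flag registrations; the project's actual parser may differ.
    parser.add_argument('--outfile', type=str, required=True)
    parser.add_argument('--controls', type=str, default='')
    parser.add_argument('--num-examples', type=int, default=-1)
    parser.add_argument('--ignore-fields', type=str, default='')
    parser.add_argument('--log-every-n-secs', type=float, default=10)
    make_dataset(parser.parse_args())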