def eval_wordstat(opt):
    """Evaluates a model.

    :param opt: tells the evaluation function how to run
    """
    random.seed(42)

    # Setup control information
    initialize_control_information(opt)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    data = {}  # This will be written to the output json file
    data['opt'] = agent.opt  # Save the opt to json

    # Determine the output filename
    if opt['gold_response']:
        # Special output file for gold response
        model_dir, _ = os.path.split(opt.get('model_file'))
        outfile = os.path.join(model_dir, 'goldresponse')
        if opt['use_reply'] != 'label':
            raise ValueError(
                'You should set --use-reply label (not --use-reply model) '
                'when measuring goldresponse stats'
            )
    else:
        outfile = "%s.%s.%s.%s" % (
            opt.get('model_file'),
            opt.get('datatype'),
            "use%sreply" % agent.opt['use_reply'],
            "beam%i" % agent.opt['beam_size'],
        )
        if agent.opt['beam_size'] > 1:
            outfile += ".beamminnbest%i" % agent.opt['beam_min_n_best']
        if len(agent.control_settings) > 0:
            outfile += ".setcontrols:" + "_".join(
                [
                    "%s%s" % (c, str(agent.control_settings[c]['set_value']))
                    for c in sorted(agent.control_settings.keys())
                ]
            )
        if agent.opt['beam_reorder'] not in ['none', False]:
            outfile += ".beamreorder_%s" % agent.opt['beam_reorder']
        if len(agent.wd_features) > 0:
            sorted_bfw = sorted(
                list(zip(agent.wd_features, agent.wd_wts)), key=lambda x: x[0]
            )
            outfile += ".WDfeatures:" + "_".join(
                ["%s%s" % (f, str(w)) for f, w in sorted_bfw]
            )
    if opt['num_examples'] != -1:
        outfile += ".numex%i" % opt['num_examples']
    outfile += ".wordstats.json"
    print("\nOutfile: %s\n" % outfile)

    cnt = 0
    word_statistics = {
        'mean_wlength': [],  # list of length (in words) of utterances
        'mean_clength': [],  # list of length (in chars) of utterances
        'freqs_cnt': Counter(),  # Counter for word frequencies, bucketed
        'word_cnt': 0,  # total number of words in all utterances
        'pred_list': [],  # list of generated utterances after applying normalize_answer
        'pure_pred_list': [],  # list of generated utterances
        'context_list': [],  # list of text inputs (persona and conversation history)
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

    # This dictionary records all the sentence-level controllable attributes
    # For each attribute, we have a list of all the values
    sent_attrs = {attr: [] for attr in ATTR2SENTSCOREFN.keys()}  # str to list of floats

    # histories will be a list of ConvAI2History objects
    histories = []

    def process_prediction(prediction, word_statistics):
        word_statistics['pred_list'].append(normalize_answer(prediction))
        freqs, _cnt, wlength, clength = get_word_stats(
            prediction, dictionary, bins=bins
        )
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        return word_statistics

    t0 = time.time()
    while not world.epoch_done():
        world.parley()
        # orig eval_wordstat.py handles bsz=1 but for simplicity we assume bsz>1
        assert batch_size != 1
        for w in world.worlds:
            try:
                try:
                    response_act = w.acts[-1]
                    prediction = response_act['text']
                except KeyError:
                    continue
                if opt['gold_response']:
                    # If we're measuring gold response, use eval_label as prediction
                    prediction = w.acts[0]['eval_labels'][0]
                    response_act = {'text': prediction}
                word_statistics['context_list'].append(w.acts[0]['text'])
                word_statistics['pure_pred_list'].append(prediction)
            except IndexError:
                continue
            cnt += 1
            word_statistics = process_prediction(prediction, word_statistics)

            # Compute and record sentence-level attributes
            history = ConvAI2History(w.acts[0]['text'])
            histories.append(history)
            sent_attrs = update_sent_attr_stats(sent_attrs, history, prediction)

        # Periodically log some info
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(), report)
            print(text)

        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    print("Time to process %i examples: %f seconds" % (cnt, time.time() - t0))

    # Compute percent unique
    # Note this is w.r.t. normalized pred_list not original pure_pred_list
    unique_list = []
    cntr = Counter(word_statistics['pred_list'])
    for k, v in cntr.items():
        if v == 1:
            unique_list.append(k)
    unique_percent = len(unique_list) / len(word_statistics['pred_list']) * 100

    # Print a final report
    report = world.report()
    if opt['gold_response']:
        report['ppl'] = 0.0  # For gold responses, overwrite the perplexity
    print(report)

    # Put all information in data dict
    data['unique_percent'] = unique_percent  # percent of all responses that are unique
    data['word_statistics'] = word_statistics  # word stats, as in orig eval_wordstat
    data['report'] = report  # the final report
    data['histories'] = [
        (hist.persona_lines, hist.partner_utts, hist.own_utts) for hist in histories
    ]  # history for each example
    data['sent_attrs'] = sent_attrs  # all sentence attribute values for responses

    # Write data to outfile
    print("Writing to %s..." % outfile)
    with open(outfile, 'w') as f:
        json.dump(data, f)
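
# Illustrative sketch, not part of the original script: how the
# .wordstats.json file written by eval_wordstat() might be read back and
# summarized. The helper name summarize_wordstats is hypothetical; it relies
# only on the keys eval_wordstat() actually writes ('unique_percent',
# 'word_statistics', 'sent_attrs') and on the module-level json import.
def summarize_wordstats(outfile):
    """Load a .wordstats.json file and print a few headline numbers."""
    with open(outfile) as f:
        data = json.load(f)
    ws = data['word_statistics']
    n = len(ws['pred_list'])
    print("responses: %i" % n)
    print("unique responses (after normalize_answer): %.2f%%" % data['unique_percent'])
    if n > 0:
        print("mean words/utterance: %.2f" % (sum(ws['mean_wlength']) / n))
        print("mean chars/utterance: %.2f" % (sum(ws['mean_clength']) / n))
    # Average each sentence-level controllable attribute, skipping missing values
    for attr, vals in data['sent_attrs'].items():
        vals = [v for v in vals if v is not None]
        if vals:
            print("mean %s: %.4f" % (attr, sum(vals) / len(vals)))
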
def make_dataset(opt):
    # Initialize control information so we can compute sentence attributes.
    # Here we set build_task=False so we don't download data/controllable_dialogue
    # (because we're trying to create it instead).
    initialize_control_information(opt, build_task=False)

    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    ignorefields = opt.get('ignore_fields', '')
    outfile = opt['outfile']

    # Number of examples to process
    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']

    # List of controls to include:
    controls = opt['controls'].split(',') if opt['controls'] != '' else []

    print('[ starting to convert.. ]')
    print('[ saving output to {} ]'.format(outfile))
    fw = open(outfile, 'w')
    log_timer = TimeLogger()

    for _ in range(num_examples):
        world.parley()
        world.acts[0]['labels'] = world.acts[0].get(
            'labels', world.acts[0].pop('eval_labels', None)
        )

        # Need to get history in order to compute control values
        hist = ConvAI2History(world.acts[0]['text'], assume_persontokens=False)
        response = world.acts[0]['labels'][0]

        # Compute control values
        for ctrl in controls:
            ctrl_val = eval_attr(response, hist, ctrl)
            if ctrl == 'avg_nidf':
                assert ctrl_val >= 0
                assert ctrl_val <= 1
            elif ctrl == 'question':
                assert ctrl_val in [0, 1]
            elif ctrl == 'lastuttsim':
                if ctrl_val is not None:
                    assert ctrl_val >= -1
                    assert ctrl_val <= 1
            else:
                raise Exception('unexpected ctrl name: %s' % ctrl)
            world.acts[0][ctrl] = ctrl_val  # add control value to act

        # Write to file
        txt = msg_to_str(world.acts[0], ignore_fields=ignorefields)
        fw.write(txt + '\n')
        if world.acts[0].get('episode_done', False):
            fw.write('\n')

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys, world.num_examples())
            print(text)

        if world.epoch_done():
            print('EPOCH DONE')
            break
    fw.close()
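
# Illustrative sketch, not part of the original script: scanning the file
# written by make_dataset() to inspect the distribution of one control value.
# The helper name control_value_summary is hypothetical, and the parsing
# assumes the flat format msg_to_str produces (one message per line,
# tab-separated "key:value" fields, blank line between episodes).
def control_value_summary(datafile, ctrl='avg_nidf'):
    """Collect all values of one control field and print min/mean/max."""
    vals = []
    with open(datafile) as f:
        for line in f:
            for field in line.rstrip('\n').split('\t'):
                key, _, val = field.partition(':')
                if key == ctrl and val not in ('', 'None'):
                    vals.append(float(val))
    if vals:
        print("%s: n=%i min=%.4f mean=%.4f max=%.4f" % (
            ctrl, len(vals), min(vals), sum(vals) / len(vals), max(vals)
        ))
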