def update_sent_attr_stats(sent_attrs, history, prediction):
    """
    Record the attribute scores of one prediction in the running stats.

    For every attribute name already present in sent_attrs, the attribute is
    evaluated on the prediction (given the history) and the resulting score is
    appended to that attribute's list.

    Inputs:
        sent_attrs: dictionary mapping each attr (a string) to a list of floats
            (the scores). Modified in place.
        history: a ConvAI2History
        prediction: string. the response text for which we measure sent
            attributes

    Returns:
        The same sent_attrs dictionary, with one more score per attribute.
    """
    for attr_name, scores in sent_attrs.items():
        scores.append(eval_attr(prediction, history, attr_name))
    return sent_attrs
def make_dataset(opt):
    """
    Convert a task to a text file, annotating each example with control values.

    Steps through the task given by opt with a RepeatLabelAgent, computes the
    requested control values for each example's first label, stores them on
    the act, and writes the act (serialized via msg_to_str) to the output file.
    A blank line is written after each act whose 'episode_done' flag is set.

    Keys read from opt:
        'outfile': path of the file to write.
        'num_examples': number of examples to process; -1 means all examples
            in the task (world.num_examples()).
        'controls': comma-separated control names, or '' for none. Supported
            names are 'avg_nidf', 'question' and 'lastuttsim'.
        'ignore_fields' (optional): forwarded to msg_to_str as ignore_fields.
        'log_every_n_secs': minimum number of seconds between progress logs.

    Raises:
        Exception: if 'controls' contains an unsupported control name.
        AssertionError: if a computed control value is outside its expected
            range.
    """
    # Initialize control information so we can compute sentence attributes.
    # Here we set build_task=False so we don't download data/controllable_dialogue
    # (because we're trying to create it instead).
    initialize_control_information(opt, build_task=False)

    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    ignorefields = opt.get('ignore_fields', '')
    outfile = opt['outfile']

    # Number of examples to process
    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']

    # List of controls to include:
    controls = opt['controls'].split(',') if opt['controls'] != '' else []

    print('[ starting to convert.. ]')
    print('[ saving output to {} ]'.format(outfile))

    # The context manager guarantees the output file is flushed and closed even
    # if a sanity check or an unexpected control name aborts the conversion.
    with open(outfile, 'w') as fw:
        log_timer = TimeLogger()
        for _ in range(num_examples):
            world.parley()
            act = world.acts[0]

            # Store the labels under 'labels'. Note that 'eval_labels' is
            # always removed from the act, whether or not 'labels' exists.
            act['labels'] = act.get('labels', act.pop('eval_labels', None))

            # Need to get history in order to compute control values
            hist = ConvAI2History(act['text'], assume_persontokens=False)
            response = act['labels'][0]

            # Compute control values
            for ctrl in controls:
                ctrl_val = eval_attr(response, hist, ctrl)
                if ctrl == 'avg_nidf':
                    assert ctrl_val >= 0
                    assert ctrl_val <= 1
                elif ctrl == 'question':
                    assert ctrl_val in [0, 1]
                elif ctrl == 'lastuttsim':
                    # A value of None is accepted here and written as-is.
                    if ctrl_val is not None:
                        assert ctrl_val >= -1
                        assert ctrl_val <= 1
                else:
                    raise Exception('unexpected ctrl name: %s' % ctrl)
                act[ctrl] = ctrl_val  # add control value to act

            # Write to file
            txt = msg_to_str(act, ignore_fields=ignorefields)
            fw.write(txt + '\n')
            if act.get('episode_done', False):
                # Blank line after the last act of an episode
                fw.write('\n')

            if log_timer.time() > opt['log_every_n_secs']:
                text, _log = log_timer.log(
                    world.total_parleys, world.num_examples()
                )
                print(text)

            if world.epoch_done():
                print('EPOCH DONE')
                break