def recall_at_k(cls, labels, scores, k=1, summary_map=None):
    """Recall of the k highest-scoring candidates, pooled over the batch.

    Args:
        labels: binary (batch_size, num_candidates), 1 means good candidate
        scores: ranking scores (batch_size, num_candidates); candidates are
            ranked by descending score
        k (int): number of top-ranked candidates kept per row
        summary_map: if not None, the true-positive and positive counts are
            passed to logstats.update_summary_map under the keys
            'recall_at_<k>_tp' and 'recall_at_<k>_p'.

    Returns:
        safe_div(number of good candidates ranked in the top k,
        total number of good candidates).
    """
    # The slice keeps fewer than k columns when num_candidates < k.
    top_indices = np.argsort(-1. * scores, axis=1)[:, :k]
    num_rows, _ = top_indices.shape

    # A column of row numbers broadcasts against top_indices, so every
    # (row, top candidate) cell is flagged in a single assignment.
    selected = np.zeros_like(labels, dtype=np.int32)
    row_column = np.arange(num_rows)[:, np.newaxis]
    selected[row_column, top_indices] = 1

    is_good = labels == 1
    agrees_with_label = labels == selected
    num_true_positive = np.sum(np.logical_and(is_good, agrees_with_label))
    num_positive_example = np.sum(labels)
    recall = safe_div(float(num_true_positive), num_positive_example)

    if summary_map is not None:
        prefix = 'recall_at_{}'.format(k)
        counts = {
            '{}_tp'.format(prefix): num_true_positive,
            '{}_p'.format(prefix): num_positive_example,
        }
        logstats.update_summary_map(summary_map, counts)
    return recall
def eval_joint(self, kb, span):
    """Check a statement that pairs a count with two entities against a KB.

    Args:
        kb: object whose `items` attribute is iterated; each item is passed
            to self.item_entities.
        span: 4-element sequence (num, ent1, _, ent2). Element 1 of ent1 and
            ent2 is the value that gets matched; num is converted with
            self.str_to_num.

    Statistics go to self.summary_map through logstats.update_summary_map:
    'joint_fact' on every call; 'repeated' or 'same_col' (each followed by
    an early return) when the two entities are equal or share element 1;
    otherwise 'correct_joint_ent', plus 'correct_joint' when the number of
    items containing both entities equals the stated count.
    """
    stats = self.summary_map
    logstats.update_summary_map(stats, {'joint_fact': 1})

    num, first, _, second = span
    first, second = first[1], second[1]

    if first == second:
        logstats.update_summary_map(stats, {'repeated': 1})
        return
    # Same type, i.e. in the same column
    if first[1] == second[1]:
        logstats.update_summary_map(stats, {'same_col': 1})
        return

    expected = self.str_to_num(num)
    matches = 0
    for item in kb.items:
        entities = list(self.item_entities(item))
        if first in entities and second in entities:
            matches += 1

    # NOTE(review): recorded once per call for any pair that passes the two
    # checks above -- confirm this counter is not meant to be per item.
    logstats.update_summary_map(stats, {'correct_joint_ent': 1})
    if matches == expected:
        logstats.update_summary_map(stats, {'correct_joint': 1})
def __init__(self):
    """Create self.summary_map and register each tracked statistic with 0."""
    self.summary_map = {}
    stat_names = (
        'undecided', 'fact', 'single_fact', 'joint_fact', 'coref',
        'correct_single', 'correct_joint', 'correct_joint_ent',
        'repeated', 'same_col',
    )
    # One update per name, in the order listed above.
    for stat_name in stat_names:
        logstats.update_summary_map(self.summary_map, {stat_name: 0})
def update_entity_stats(self, summary_map, batch_preds, batch_targets, prefix=''):
    """Accumulate entity-level counts for paired predictions and targets.

    Args:
        summary_map: passed to logstats.update_summary_map.
        batch_preds: iterable of predicted sequences, paired element-wise
            with batch_targets.
        batch_targets: iterable of target sequences.
        prefix (str): prepended to the keys 'pos_target', 'pos_pred', 'tp'.

    For each pair, the elements accepted by is_entity are collected into
    sets. Pairs whose target set is empty are skipped; otherwise the set
    sizes and the number of target entities that also occur among the
    predicted entities are recorded.
    """
    target_key = prefix + 'pos_target'
    pred_key = prefix + 'pos_pred'
    tp_key = prefix + 'tp'

    for pred_seq, target_seq in izip(batch_preds, batch_targets):
        pred_entities = set(tok for tok in pred_seq if is_entity(tok))
        target_entities = set(tok for tok in target_seq if is_entity(tok))
        # Don't record cases where no entity is presented
        if len(target_entities) == 0:
            continue
        logstats.update_summary_map(summary_map, {
            target_key: len(target_entities),
            pred_key: len(pred_entities),
        })
        num_hits = sum(1 for ent in target_entities if ent in pred_entities)
        logstats.update_summary_map(summary_map, {tp_key: num_hits})
def eval_single(self, kb, span):
    """Check a statement that pairs a count with a single entity against a KB.

    Args:
        kb: object whose `items` attribute is iterated; each item is passed
            to self.item_entities.
        span: pair (num, ent). Element 1 of ent is the value that gets
            matched; num is converted with self.str_to_num.

    Records 'single_fact' on every call, and 'correct_single' when the
    number of matching entities over all items equals the stated count.
    """
    logstats.update_summary_map(self.summary_map, {'single_fact': 1})

    num_token, mention = span
    canonical = mention[1]  # take the canonical form
    expected = self.str_to_num(num_token)

    # Every occurrence counts, including repeats inside one item.
    occurrences = sum(
        1
        for item in kb.items
        for entity in self.item_entities(item)
        if entity == canonical)

    if expected == occurrences:
        logstats.update_summary_map(self.summary_map, {'correct_single': 1})
def update_selection_stats(self, summary_map, scores, targets, prefix=''):
    """Accumulate true-positive / predicted / target counts for node selection.

    Args:
        summary_map: passed to logstats.update_summary_map.
        scores: 2-D array (batch_size, num_nodes); a score > 0 counts as a
            positive prediction.
        targets: 2-D array whose columns beyond num_nodes are dropped; a
            value of 1 counts as a positive target.
        prefix (str): prepended to the keys 'tp', 'pos_pred', 'pos_target'.
    """
    # NOTE: targets are from ground truth response and may contain new entities.
    # Ideally this would not happen as a mentioned entity is either from the agent's
    # KB or from partner's mentions (which is added to the graph), so during decoding
    # there shouldn't be new entities. However, the lexicon may "create" an entity.
    _, num_nodes = scores.shape
    predicted = scores > 0
    actual = targets[:, :num_nodes] == 1
    stats = {
        prefix + 'tp': np.sum(np.logical_and(predicted, actual)),
        prefix + 'pos_pred': np.sum(predicted),
        prefix + 'pos_target': np.sum(actual),
    }
    logstats.update_summary_map(summary_map, stats)
def _run_batch_basic(self, dialogue_batch, sess, summary_map, test=False):
    '''
    Run truncated RNN through a sequence of batch examples.

    Args:
        dialogue_batch: mapping providing 'batch_seq' (the batches to run in
            order) and 'matched_items'.
        sess: object whose run(fetches, feed_dict=...) is called once per batch.
        summary_map: passed to logstats.update_summary_map after every batch.
        test: when true, self.train_op is not fetched and 'total_loss' /
            'num_tokens' are recorded; otherwise 'loss' and 'grad_norm' are.
    '''
    matched_items = dialogue_batch['matched_items']
    # Final state of one batch is fed as the initial state of the next one.
    prev_state = None
    for batch in dialogue_batch['batch_seq']:
        feed_dict = self._get_feed_dict(batch, prev_state, matched_items=matched_items)

        # Fetches needed in both modes, in the order they are unpacked below.
        common_fetches = [
            self.model.decoder.output_dict['logits'],
            self.model.final_state,
            self.model.loss,
            self.model.seq_loss,
        ]
        if test:
            logits, prev_state, loss, seq_loss, total_loss = sess.run(
                common_fetches + [self.model.total_loss],
                feed_dict=feed_dict)
        else:
            _, logits, prev_state, loss, seq_loss, gn = sess.run(
                [self.train_op] + common_fetches + [self.grad_norm],
                feed_dict=feed_dict)

        if self.verbose:
            self._print_batch(batch, np.argmax(logits, axis=2), seq_loss)

        if test:
            logstats.update_summary_map(summary_map, {
                'total_loss': total_loss[0],
                'num_tokens': total_loss[1],
            })
            continue
        for stat_name, value in (('loss', loss), ('grad_norm', gn)):
            logstats.update_summary_map(summary_map, {stat_name: value})
def _run_batch_graph(self, dialogue_batch, sess, summary_map, test=False):
    '''
    Run truncated RNN through a sequence of batch examples with knowledge graphs.

    Args:
        dialogue_batch: mapping providing 'batch_seq' (the batches to run in
            order), 'graph' and 'matched_items'.
        sess: object whose run(fetches, feed_dict=...) is called once per batch.
        summary_map: passed to logstats.update_summary_map after every batch.
        test: when true, self.train_op is not fetched and 'total_loss' /
            'num_tokens' are recorded; otherwise 'loss', 'sel_loss' and
            'grad_norm' are.
    '''
    graphs = dialogue_batch['graph']
    matched_items = dialogue_batch['matched_items']
    # Final state of one batch is fed as the initial state of the next one.
    prev_state = None
    # Utterances fetched for the previous batch; None before the first run.
    utterances = None
    for batch in dialogue_batch['batch_seq']:
        graph_data = graphs.get_batch_data(
            batch['encoder_tokens'], batch['decoder_tokens'],
            batch['encoder_entities'], batch['decoder_entities'],
            utterances, self.vocab)
        init_checklists = graphs.get_zero_checklists(1)
        feed_dict = self._get_feed_dict(
            batch, prev_state, graph_data, graphs, self.data.copy,
            init_checklists, graph_data['encoder_nodes'],
            graph_data['decoder_nodes'], matched_items)

        # Fetches needed in both modes, in the order they are unpacked below.
        common_fetches = [
            self.model.decoder.output_dict['logits'],
            self.model.final_state,
            self.model.decoder.output_dict['utterances'],
            self.model.loss,
            self.model.seq_loss,
        ]
        if test:
            (logits, prev_state, utterances, loss, seq_loss,
             total_loss, sel_loss) = sess.run(
                common_fetches + [self.model.total_loss, self.model.select_loss],
                feed_dict=feed_dict)
        else:
            (_, logits, prev_state, utterances, loss, seq_loss,
             sel_loss, gn) = sess.run(
                [self.train_op] + common_fetches
                + [self.model.select_loss, self.grad_norm],
                feed_dict=feed_dict)

        if self.verbose:
            preds = np.argmax(logits, axis=2)
            if self.data.copy:
                preds = graphs.copy_preds(preds, self.data.mappings['vocab'].size)
            self._print_batch(batch, preds, seq_loss)

        if test:
            logstats.update_summary_map(summary_map, {
                'total_loss': total_loss[0],
                'num_tokens': total_loss[1],
            })
            continue
        train_stats = (('loss', loss), ('sel_loss', sel_loss), ('grad_norm', gn))
        for stat_name, value in train_stats:
            logstats.update_summary_map(summary_map, {stat_name: value})
def inc_coref(self):
    """Record one 'coref' observation (value 1) in self.summary_map."""
    observation = {'coref': 1}
    logstats.update_summary_map(self.summary_map, observation)
def inc_fact(self):
    """Record one 'fact' observation (value 1) in self.summary_map."""
    observation = {'fact': 1}
    logstats.update_summary_map(self.summary_map, observation)
def inc_undecided(self):
    """Record one 'undecided' observation (value 1) in self.summary_map."""
    observation = {'undecided': 1}
    logstats.update_summary_map(self.summary_map, observation)
def _run_batch(self, dialogue_batch, sess, summary_map, test=False):
    '''
    Run truncated RNN through a sequence of batch examples.

    Args:
        dialogue_batch: mapping providing 'batch_seq', the batches to run in order.
        sess: object whose run(fetches, feed_dict=...) is called once per
            batch with a dict of fetches.
        summary_map: passed to logstats.update_summary_map after every batch.
        test: when true, no train op, gradient norm or merged summary is
            fetched and 'total_loss' / 'num_tokens' are recorded; otherwise
            self.global_step is advanced and 'loss' / 'grad_norm' are recorded.

    Raises:
        ValueError: if self.model.name is neither 'encdec' nor 'selector'.
    '''
    training = not test
    # Decoder output holding the raw predictions, keyed by model name.
    raw_pred_keys = {'encdec': 'logits', 'selector': 'scores'}

    encoder_init_state = None
    init_price_history = None
    for batch in dialogue_batch['batch_seq']:
        decoder = self.model.decoder
        # TODO: hacky
        if init_price_history is None and hasattr(decoder, 'price_predictor'):
            init_price_history = decoder.price_predictor.zero_init_price(
                batch['encoder_inputs'].shape[0])
        feed_dict = self._get_feed_dict(
            batch, encoder_init_state, test=test,
            init_price_history=init_price_history)

        model_name = self.model.name
        raw_pred_key = raw_pred_keys.get(model_name)
        if raw_pred_key is None:
            raise ValueError
        fetches = {
            'loss': self.model.loss,
            'raw_preds': decoder.output_dict[raw_pred_key],
        }
        if training:
            fetches['train_op'] = self.train_op
            fetches['gn'] = self.grad_norm
        else:
            fetches['total_loss'] = self.model.total_loss
        if self.model.stateful:
            fetches['final_state'] = self.model.final_state
        if hasattr(decoder, 'price_predictor'):
            fetches['price_history'] = decoder.output_dict['price_history']
        if training:
            fetches['merged'] = self.merged_summary

        results = sess.run(fetches, feed_dict=feed_dict)

        if training:
            self.global_step += 1
            # The merged summary is written on every 100th training step.
            if self.global_step % 100 == 0:
                self.train_writer.add_summary(results['merged'], self.global_step)

        # Only stateful models carry the final state into the next batch.
        encoder_init_state = results['final_state'] if self.model.stateful else None
        if 'price_history' in results:
            init_price_history = results['price_history']

        if self.verbose:
            self._print_batch(
                batch,
                self.model.output_to_preds(results['raw_preds']),
                results['loss'])

        if training:
            logstats.update_summary_map(summary_map, {'loss': results['loss']})
            logstats.update_summary_map(summary_map, {'grad_norm': results['gn']})
        else:
            total_loss = results['total_loss']
            logstats.update_summary_map(summary_map, {
                'total_loss': total_loss[0],
                'num_tokens': total_loss[1],
            })

        # TODO: refactor
        if model_name == 'selector':
            labels = batch['decoder_args']['candidate_labels']
            candidate_scores = results['raw_preds']
            for k in (1, 5):
                recall = self.evaluator.recall_at_k(
                    labels, candidate_scores, k=k, summary_map=summary_map)
                logstats.update_summary_map(
                    summary_map, {'recall_at_{}'.format(k): recall})