Esempio n. 1
0
    def recall_at_k(cls, labels, scores, k=1, summary_map=None):
        """Recall of the top-k candidates.

        Candidates are ranked by descending score within each row. Recall is
        the number of good candidates (``labels == 1``) that fall in the top k
        of their row, divided (via ``safe_div``) by the total number of good
        candidates.

        Args:
            labels: binary (batch_size, num_candidates), 1 means good candidate.
                Sequences are accepted and converted with ``np.asarray``.
            scores: ranking scores (batch_size, num_candidates)
            k (int): number of top-ranked candidates considered per row; if a
                row has fewer than k candidates, all of them are used.
            summary_map (dict or None): if not None, the true-positive and
                positive counts are accumulated into it with
                ``logstats.update_summary_map`` under the keys
                ``recall_at_{k}_tp`` and ``recall_at_{k}_p``.

        Returns:
            ``safe_div(num_true_positive, num_positive)``, i.e. the fraction of
            good candidates that appear in the top-k candidates.
        """
        # Normalise inputs: on a plain list, ``labels == 1`` would be a single
        # False and silently yield zero true positives.
        labels = np.asarray(labels)
        scores = np.asarray(scores)

        # Column indices of the k best-scoring candidates in each row. Slicing
        # past the number of candidates simply keeps every column.
        topk_candidates = np.argsort(-1. * scores, axis=1)[:, :k]
        batch_size = topk_candidates.shape[0]

        # Gather the label of every selected candidate: a (batch_size, 1)
        # column of row indices broadcasts against the top-k column indices.
        # argsort yields each column at most once per row, so no hit is
        # counted twice.
        row_inds = np.arange(batch_size)[:, np.newaxis]
        topk_labels = labels[row_inds, topk_candidates]

        num_true_positive = np.sum(topk_labels == 1)
        num_positive_example = np.sum(labels)
        recall = safe_div(float(num_true_positive), num_positive_example)

        if summary_map is not None:
            prefix = 'recall_at_{}'.format(k)
            logstats.update_summary_map(summary_map, {
                '{}_tp'.format(prefix): num_true_positive,
                '{}_p'.format(prefix): num_positive_example,
                })

        return recall
Esempio n. 2
0
 def eval_joint(self, kb, span):
     """Check a joint-count fact (a number and two entities) against the KB.

     Always increments the ``joint_fact`` counter. The fact is then skipped,
     bumping only ``repeated`` or ``same_col``, when the two entities are
     equal or share the same second component. Otherwise it counts the KB
     items whose entities contain both, increments ``correct_joint_ent``, and
     increments ``correct_joint`` if that count equals the stated number.

     Args:
         kb: knowledge base; only its ``items`` attribute is read.
         span: 4-tuple ``(num, ent1, _, ent2)``. Element ``[1]`` of each
             entity is used (presumably its canonical form). ``num`` is
             converted with ``self.str_to_num``.
     """
     logstats.update_summary_map(self.summary_map, {'joint_fact': 1})
     num, ent1, _, ent2 = span
     ent1 = ent1[1]
     ent2 = ent2[1]
     if ent1 == ent2:
         logstats.update_summary_map(self.summary_map, {'repeated': 1})
         return
     # Same type, i.e. in the same column
     if ent1[1] == ent2[1]:
         logstats.update_summary_map(self.summary_map, {'same_col': 1})
         return
     num = self.str_to_num(num)
     count = 0
     for item in kb.items:
         # Materialise once: the entities are probed twice below.
         entities = list(self.item_entities(item))
         if ent1 in entities and ent2 in entities:
             count += 1
     # NOTE(review): incremented regardless of `count`; confirm this counter
     # is meant to track "entity pair passed the repeat/column filters".
     logstats.update_summary_map(self.summary_map, {'correct_joint_ent': 1})
     if count == num:
         logstats.update_summary_map(self.summary_map, {'correct_joint': 1})
Esempio n. 3
0
 def __init__(self):
     """Create ``self.summary_map`` with every evaluation counter set to 0."""
     self.summary_map = {}
     counter_names = [
         'undecided', 'fact', 'single_fact', 'joint_fact', 'coref',
         'correct_single', 'correct_joint', 'correct_joint_ent',
         'repeated', 'same_col',
     ]
     # Seed each counter through logstats, one call per key, in this order.
     for name in counter_names:
         logstats.update_summary_map(self.summary_map, {name: 0})
Esempio n. 4
0
 def update_entity_stats(self, summary_map, batch_preds, batch_targets, prefix=''):
     """Accumulate entity-level precision/recall counts for a batch.

     For each aligned (preds, targets) pair, the tokens accepted by
     ``is_entity`` are collected into sets. Pairs whose targets contain no
     entity are skipped. Otherwise the following counts are added to
     ``summary_map`` via ``logstats.update_summary_map``, each key carrying
     ``prefix``:

     - ``pos_target``: number of distinct target entities.
     - ``pos_pred``: number of distinct predicted entities.
     - ``tp``: number of target entities that were also predicted.

     Args:
         summary_map: stats dict updated by logstats.
         batch_preds: iterable of predicted token sequences.
         batch_targets: iterable of target token sequences aligned with
             ``batch_preds``; iteration stops at the shorter of the two.
         prefix: string prepended to every summary key.
     """
     pos_target = prefix + 'pos_target'
     pos_pred = prefix + 'pos_pred'
     tp = prefix + 'tp'
     # Builtin zip works on both Python 2 and 3; itertools.izip is Py2-only.
     for preds, targets in zip(batch_preds, batch_targets):
         preds = {e for e in preds if is_entity(e)}
         targets = {e for e in targets if is_entity(e)}
         # Don't record cases where no entity is present
         if targets:
             logstats.update_summary_map(summary_map, {pos_target: len(targets), pos_pred: len(preds)})
             # Both are sets, so true positives are exactly their intersection.
             logstats.update_summary_map(summary_map, {tp: len(preds & targets)})
Esempio n. 5
0
 def eval_single(self, kb, span):
     """Check a single-entity count fact against the KB.

     Always increments the ``single_fact`` counter. It then counts how often
     the entity occurs among the entities of all KB items, and increments
     ``correct_single`` if that total equals the stated number.

     Args:
         kb: knowledge base; only its ``items`` attribute is read.
         span: 2-tuple ``(num, ent)``. ``num`` is converted with
             ``self.str_to_num`` and ``ent[1]`` is the entity to count.
     """
     logstats.update_summary_map(self.summary_map, {'single_fact': 1})
     num, ent = span
     ent = ent[1]  # take the canonical form
     num = self.str_to_num(num)
     # Every occurrence counts, so an item listing the entity twice adds 2.
     count = sum(1 for item in kb.items
                 for entity in self.item_entities(item)
                 if entity == ent)
     if num == count:
         logstats.update_summary_map(self.summary_map,
                                     {'correct_single': 1})
Esempio n. 6
0
    def update_selection_stats(self, summary_map, scores, targets, prefix=''):
        """Accumulate node-selection precision/recall counts for a batch.

        A node counts as predicted when its score is > 0, and as a target
        when its target value equals 1. Target columns beyond the number of
        scored nodes are ignored. The true-positive count and both positive
        totals are added to ``summary_map`` under ``prefix + 'tp'``,
        ``prefix + 'pos_pred'`` and ``prefix + 'pos_target'``.
        """
        # NOTE: targets come from the ground-truth response and may contain
        # new entities. Ideally this would not happen, because a mentioned
        # entity comes either from the agent's KB or from the partner's
        # mentions (which are added to the graph), so decoding should never
        # produce new entities. However, the lexicon may "create" an entity.
        _, num_nodes = scores.shape
        predicted = scores > 0
        expected = targets[:, :num_nodes] == 1
        hits = np.logical_and(predicted, expected)
        logstats.update_summary_map(summary_map, {
            prefix + 'tp': np.sum(hits),
            prefix + 'pos_pred': np.sum(predicted),
            prefix + 'pos_target': np.sum(expected),
        })
Esempio n. 7
0
    def _run_batch_basic(self, dialogue_batch, sess, summary_map, test=False):
        '''
        Run truncated RNN through a sequence of batch examples.

        The final state of each batch seeds the encoder of the next batch.
        In test mode only the forward pass runs, and ``total_loss`` /
        ``num_tokens`` are accumulated into ``summary_map``. Otherwise
        ``self.train_op`` also runs, and ``loss`` / ``grad_norm`` are
        accumulated.
        '''
        matched_items = dialogue_batch['matched_items']
        encoder_init_state = None
        for batch in dialogue_batch['batch_seq']:
            feed_dict = self._get_feed_dict(batch,
                                            encoder_init_state,
                                            matched_items=matched_items)
            fetches = {
                'logits': self.model.decoder.output_dict['logits'],
                'final_state': self.model.final_state,
                'loss': self.model.loss,
                'seq_loss': self.model.seq_loss,
            }
            if test:
                fetches['total_loss'] = self.model.total_loss
            else:
                fetches['train_op'] = self.train_op
                fetches['grad_norm'] = self.grad_norm
            # All fetches are evaluated in one step, exactly as with a list.
            results = sess.run(fetches, feed_dict=feed_dict)
            encoder_init_state = results['final_state']

            if self.verbose:
                preds = np.argmax(results['logits'], axis=2)
                self._print_batch(batch, preds, results['seq_loss'])

            if not test:
                logstats.update_summary_map(summary_map, {'loss': results['loss']})
                logstats.update_summary_map(summary_map, {'grad_norm': results['grad_norm']})
                continue

            total_loss = results['total_loss']
            logstats.update_summary_map(summary_map, {
                'total_loss': total_loss[0],
                'num_tokens': total_loss[1]
            })
Esempio n. 8
0
    def _run_batch_graph(self, dialogue_batch, sess, summary_map, test=False):
        '''
        Run truncated RNN through a sequence of batch examples with knowledge graphs.

        For each batch in ``dialogue_batch['batch_seq']``, graph inputs are
        built from ``dialogue_batch['graph']`` and the session is run. The
        final state and the decoder ``utterances`` output of one batch are
        fed into the next batch. In test mode ``total_loss`` / ``num_tokens``
        are accumulated into ``summary_map``. Otherwise ``self.train_op`` is
        run and ``loss``, ``sel_loss`` and ``grad_norm`` are accumulated.
        '''
        encoder_init_state = None
        # Decoder utterances from the previous batch; None for the first batch.
        utterances = None
        graphs = dialogue_batch['graph']
        matched_items = dialogue_batch['matched_items']
        # NOTE(review): the index `i` is never used.
        for i, batch in enumerate(dialogue_batch['batch_seq']):
            graph_data = graphs.get_batch_data(batch['encoder_tokens'],
                                               batch['decoder_tokens'],
                                               batch['encoder_entities'],
                                               batch['decoder_entities'],
                                               utterances, self.vocab)
            # Zero checklists are requested anew for every batch (the meaning
            # of the argument 1 is defined by the graph batch; TODO confirm).
            init_checklists = graphs.get_zero_checklists(1)
            # Positional arguments: their order must match _get_feed_dict's
            # signature, which is defined elsewhere.
            feed_dict = self._get_feed_dict(batch, encoder_init_state,
                                            graph_data, graphs, self.data.copy,
                                            init_checklists,
                                            graph_data['encoder_nodes'],
                                            graph_data['decoder_nodes'],
                                            matched_items)
            if test:
                # sel_loss is fetched here but only logged in training mode.
                logits, final_state, utterances, loss, seq_loss, total_loss, sel_loss = sess.run(
                    [
                        self.model.decoder.output_dict['logits'],
                        self.model.final_state,
                        self.model.decoder.output_dict['utterances'],
                        self.model.loss, self.model.seq_loss,
                        self.model.total_loss, self.model.select_loss
                    ],
                    feed_dict=feed_dict)
            else:
                _, logits, final_state, utterances, loss, seq_loss, sel_loss, gn = sess.run(
                    [
                        self.train_op,
                        self.model.decoder.output_dict['logits'],
                        self.model.final_state,
                        self.model.decoder.output_dict['utterances'],
                        self.model.loss, self.model.seq_loss,
                        self.model.select_loss, self.grad_norm
                    ],
                    feed_dict=feed_dict)
            # Carry the state into the next batch (truncated RNN).
            encoder_init_state = final_state

            if self.verbose:
                preds = np.argmax(logits, axis=2)
                if self.data.copy:
                    # Remap predictions when the copy mechanism is on, using
                    # the vocab size as passed to copy_preds.
                    preds = graphs.copy_preds(preds,
                                              self.data.mappings['vocab'].size)
                self._print_batch(batch, preds, seq_loss)

            if test:
                logstats.update_summary_map(summary_map, {
                    'total_loss': total_loss[0],
                    'num_tokens': total_loss[1]
                })
            else:
                logstats.update_summary_map(summary_map, {'loss': loss})
                logstats.update_summary_map(summary_map,
                                            {'sel_loss': sel_loss})
                logstats.update_summary_map(summary_map, {'grad_norm': gn})
Esempio n. 9
0
 def inc_coref(self):
     """Increment the ``coref`` counter in ``self.summary_map`` by one."""
     increment = dict(coref=1)
     logstats.update_summary_map(self.summary_map, increment)
Esempio n. 10
0
 def inc_fact(self):
     """Increment the ``fact`` counter in ``self.summary_map`` by one."""
     increment = dict(fact=1)
     logstats.update_summary_map(self.summary_map, increment)
Esempio n. 11
0
 def inc_undecided(self):
     """Increment the ``undecided`` counter in ``self.summary_map`` by one."""
     increment = dict(undecided=1)
     logstats.update_summary_map(self.summary_map, increment)
Esempio n. 12
0
    def _run_batch(self, dialogue_batch, sess, summary_map, test=False):
        '''
        Run truncated RNN through a sequence of batch examples.

        Args:
            dialogue_batch: dict whose 'batch_seq' entry is the sequence of
                batches to run; each batch is passed to ``self._get_feed_dict``.
            sess: session used to evaluate the fetches.
            summary_map: stats dict accumulated via ``logstats``.
            test (bool): if True, only evaluate, and accumulate
                ``total_loss`` / ``num_tokens``. Otherwise run
                ``self.train_op``, advance ``self.global_step``, write a
                TensorBoard summary every 100 steps, and accumulate ``loss`` /
                ``grad_norm``. For 'selector' models, recall@1 and recall@5
                are accumulated in both modes.

        Raises:
            ValueError: if ``self.model.name`` is neither 'encdec' nor
                'selector' (raised when the first batch is processed).
        '''
        encoder_init_state = None
        init_price_history = None
        for batch in dialogue_batch['batch_seq']:
            has_price_predictor = hasattr(self.model.decoder, 'price_predictor')
            # TODO: hacky -- lazily create a zero price history sized by the
            # first batch that needs one.
            if init_price_history is None and has_price_predictor:
                batch_size = batch['encoder_inputs'].shape[0]
                init_price_history = self.model.decoder.price_predictor.zero_init_price(
                    batch_size)
            feed_dict = self._get_feed_dict(
                batch,
                encoder_init_state,
                test=test,
                init_price_history=init_price_history)

            fetches = {
                'loss': self.model.loss,
            }

            if self.model.name == 'encdec':
                fetches['raw_preds'] = self.model.decoder.output_dict['logits']
            elif self.model.name == 'selector':
                fetches['raw_preds'] = self.model.decoder.output_dict['scores']
            else:
                raise ValueError(
                    'Unsupported model name: {!r}'.format(self.model.name))

            if test:
                fetches['total_loss'] = self.model.total_loss
            else:
                fetches['train_op'] = self.train_op
                fetches['gn'] = self.grad_norm
                fetches['merged'] = self.merged_summary

            if self.model.stateful:
                fetches['final_state'] = self.model.final_state

            if has_price_predictor:
                fetches['price_history'] = self.model.decoder.output_dict[
                    'price_history']

            results = sess.run(fetches, feed_dict=feed_dict)
            if not test:
                self.global_step += 1
                if self.global_step % 100 == 0:
                    self.train_writer.add_summary(results['merged'],
                                                  self.global_step)

            # Only stateful models carry the final state into the next batch.
            if self.model.stateful:
                encoder_init_state = results['final_state']
            else:
                encoder_init_state = None

            if 'price_history' in results:
                init_price_history = results['price_history']

            if self.verbose:
                preds = self.model.output_to_preds(results['raw_preds'])
                self._print_batch(batch, preds, results['loss'])

            if test:
                total_loss = results['total_loss']
                logstats.update_summary_map(summary_map, {
                    'total_loss': total_loss[0],
                    'num_tokens': total_loss[1]
                })
            else:
                logstats.update_summary_map(summary_map,
                                            {'loss': results['loss']})
                logstats.update_summary_map(summary_map,
                                            {'grad_norm': results['gn']})

            # TODO: refactor
            if self.model.name == 'selector':
                labels = batch['decoder_args']['candidate_labels']
                candidate_scores = results['raw_preds']
                for k in (1, 5):
                    recall = self.evaluator.recall_at_k(
                        labels, candidate_scores, k=k, summary_map=summary_map)
                    logstats.update_summary_map(
                        summary_map, {'recall_at_{}'.format(k): recall})