# Assumed module-level imports for this excerpt (exact paths may differ in the full repo):
# import torch
# import onir
# from onir import util
# from onir.rankers.trivial import Trivial

def __init__(self, config, ranker, logger, train_ds, vocab, random):
    super().__init__(config, ranker, vocab, train_ds, logger, random)
    # Select the pairwise loss function by name from the config.
    self.loss_fn = {
        'softmax': self.softmax,
        'cross_entropy': self.cross_entropy,
        'nogueira_cross_entropy': self.nogueira_cross_entropy,
        'hinge': self.hinge,
    }[config['lossfn']]
    self.dataset = train_ds
    self.input_spec = ranker.input_spec()
    # Fields the ranker needs, plus the first-stage run score.
    self.iter_fields = self.input_spec['fields'] | {'runscore'}
    # Infinite iterator over positive/negative training pairs.
    self.train_iter_core = onir.datasets.pair_iter(
        train_ds,
        fields=self.iter_fields,
        pos_source=self.config['pos_source'],
        neg_source=self.config['neg_source'],
        sampling=self.config['sampling'],
        pos_minrel=self.config['pos_minrel'],
        unjudged_rel=self.config['unjudged_rel'],
        num_neg=self.config['num_neg'],
        random=self.random,
        inf=True)
    # Batch the pair iterator and prefetch batches on a background thread.
    self.train_iter = util.background(self.iter_batches(self.train_iter_core))
    self.numneg = config['num_neg']
    self._ewc = False
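# Illustrative sketch only, not this class's actual implementation: the loss
# functions mapped above (self.softmax, self.cross_entropy, ...) are defined
# elsewhere. Assuming scores arrive as a (batch, 1 + num_neg) tensor with the
# positive document's score in column 0, a pairwise softmax loss could look
# like this.
def _softmax_loss_sketch(self, scores):
    # Log-softmax over each row, then the negative log-probability of the
    # positive document (column 0), averaged over the batch.
    log_probs = torch.nn.functional.log_softmax(scores, dim=1)
    return -log_probs[:, 0].mean()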
def iter_scores(self, ranker, datasource, device):
    # Trivial ranker with neg/qsum/max all disabled: replay the first-stage run scores.
    if isinstance(ranker, Trivial) and not ranker.neg and not ranker.qsum and not ranker.max:
        for qid, values in self.dataset.run().items():
            for did, score in values.items():
                yield qid, did, score
        return
    # Trivial ranker in "max" (oracle) mode: score by qrels, -1 for unjudged documents.
    if isinstance(ranker, Trivial) and not ranker.neg and not ranker.qsum and ranker.max:
        qrels = self.dataset.qrels()
        for qid, values in self.dataset.run().items():
            q_qrels = qrels.get(qid, {})
            for did in values:
                yield qid, did, q_qrels.get(did, -1)
        return
    with torch.no_grad():
        ranker.eval()
        ds = next(datasource, None)
        # Work out how many documents will be scored, for the progress bar.
        total = None
        if isinstance(ds, list):
            total = sum(len(d['query_id']) for d in ds)
        elif self.source == 'run':
            if self.run_threshold > 0:
                total = sum(min(len(v), self.run_threshold) for v in self.dataset.run().values())
            else:
                total = sum(len(v) for v in self.dataset.run().values())
        elif self.source == 'qrels':
            total = sum(len(v) for v in self.dataset.qrels().values())
        with self.logger.pbar_raw(total=total, desc='pred', quiet=True) as pbar:
            for batch in util.background(ds):
                # Move tensor fields to the target device; leave ids and other fields as-is.
                batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}
                rel_scores = ranker(**batch).cpu()
                # Some rankers return (batch, k) scores; keep only the first column.
                if len(rel_scores.shape) == 2:
                    rel_scores = rel_scores[:, 0]
                triples = list(zip(batch['query_id'], batch['doc_id'], rel_scores))
                for qid, did, score in triples:
                    yield qid, did, score.item()
                pbar.update(len(batch['query_id']))
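# Hedged usage sketch (not from this file): the helper name is hypothetical, and
# `predictor` stands for any object exposing the iter_scores method above, with
# ranker, datasource, and device built elsewhere in the pipeline.
def _collect_run_sketch(predictor, ranker, datasource, device):
    # Gather the streamed (qid, did, score) triples into the nested run dict
    # shape most evaluation utilities expect: {qid: {did: score}}.
    run = {}
    for qid, did, score in predictor.iter_scores(ranker, datasource, device):
        run.setdefault(qid, {})[did] = score
    return run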