    def run(self):
        self._load_ranker_weights(self.ranker, self.vocab, self.trainer,
                                  self.valid_pred, self.train_ds)
        device = util.device(self.config, self.logger)

        unsupervised_run = self.test_ds.run_dict()
        supervised_run = self.test_pred.rerank_dict(self.ranker, device)
        unsupervised_run_indexed = run_indexed(unsupervised_run)
        qrels = self.test_ds.qrels()
        measures = list(self.test_pred.config['measures'].split(','))

        top_threshold, top_metric = 0, 0
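        # sweep the re-rank cutoff depth and keep the one that maximizes the first configured measure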
        for threshold in range(1, self.test_ds.config['ranktopk'] + 1):
            threshold_run = rerank_cutoff(threshold, unsupervised_run_indexed,
                                          supervised_run)
            metrics_by_query = metrics.calc(qrels, threshold_run,
                                            set(measures))
            metrics_mean = metrics.mean(metrics_by_query)
            message = ' '.join(
                [f'threshold={threshold}'] +
                [f'{k}={v:.4f}' for k, v in metrics_mean.items()])
            if metrics_mean[measures[0]] > top_metric:
                top_threshold = threshold
                top_metric = metrics_mean[measures[0]]
                message += ' <--'
            self.logger.debug(message)
        self.logger.info('top_threshold={} {}={:.4f}'.format(
            top_threshold, measures[0], top_metric))
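
The rerank_cutoff and run_indexed helpers are not shown in this snippet. A minimal, hypothetical sketch of the merge they appear to perform, assuming runs are {qid: {did: score}} dicts and that unsupervised_run_indexed maps each qid to its docids sorted by decreasing first-stage score:

def rerank_cutoff_sketch(threshold, unsupervised_run_indexed, supervised_run):
    merged = {}
    for qid, ranked_dids in unsupervised_run_indexed.items():
        sup = supervised_run.get(qid, {})
        floor = min(sup.values(), default=0.0)
        merged[qid] = {}
        for rank, did in enumerate(ranked_dids):
            if rank < threshold and did in sup:
                merged[qid][did] = sup[did]            # neural score inside the cutoff
            else:
                merged[qid][did] = floor - rank - 1.0  # preserve first-stage order below it
    return merged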
Example #2
    def run(self):
        if self.config['output_vecs'] == '':
            raise ValueError('must provide pipeline.output_vecs setting (name of the output vectors file)')
        vecs_file = self.config['output_vecs']
        if os.path.exists(vecs_file) and not self.config['overwrite']:
            raise ValueError(f'{vecs_file} already exists. Please rename pipeline.output_vecs or set pipeline.overwrite=True')
        self._load_ranker_weights(self.ranker, self.vocab, self.trainer, self.valid_pred, self.train_ds)

        device = util.device(self.config, self.logger)
        doc_iter = self._iter_doc_vectors(self.ranker, device)
        docno = 0
        PRUNE = self.config['prune']
        LEX_SIZE = self.vocab.lexicon_size()
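        # PRUNE == 0 writes each full dense vector; otherwise only the top-PRUNE dimensions (int16 indices followed by their values)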
        with util.finialized_file(vecs_file, 'wb') as outf:
            doc_iter = self.logger.pbar(doc_iter, desc='documents', total=self.test_ds.num_docs())
            for did, dvec in doc_iter:
                assert docno == int(did) # these must go in sequence! Only works for MS-MARCO dataset. TODO: include some DID-to-docno mapping
                if PRUNE == 0:
                    outf.write(dvec.tobytes())
                else:
                    idxs = np.argpartition(dvec, LEX_SIZE - PRUNE)[-PRUNE:].astype(np.int16)
                    idxs.sort()
                    vals = dvec[idxs]
                    outf.write(idxs.tobytes())
                    outf.write(vals.tobytes())
                docno += 1
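
For reference, a hedged sketch of reading one pruned record back, mirroring the write format above. It assumes dvec is float32 (so a record is PRUNE int16 indices followed by PRUNE float32 values) and that the lexicon fits in the int16 range, as the astype(np.int16) cast implies:

import numpy as np

def read_pruned_record(f, prune, lex_size):
    # each record: `prune` int16 indices, then `prune` float32 values
    idxs = np.frombuffer(f.read(prune * 2), dtype=np.int16)
    vals = np.frombuffer(f.read(prune * 4), dtype=np.float32)
    dvec = np.zeros(lex_size, dtype=np.float32)
    dvec[idxs.astype(np.int64)] = vals  # scatter the kept dims back into a dense vector
    return dvec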
Example #3
    def pred_ctxt(self):
        device = util.device(self.config, self.logger)

        if self.config['preload']:
            datasource = self._preload_batches(device)
        else:
            datasource = self._reload_batches(device)

        return PredictorContext(self, datasource, device)
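
A hypothetical illustration of the preload/reload distinction (these names and signatures are not the project's): preloading materializes every batch up front and replays the same list, while reloading rebuilds the iterator on each pass.

def preload_batches_sketch(build_batches, device):
    batches = list(build_batches(device))  # pay the batching cost once, replay forever
    while True:
        yield from batches

def reload_batches_sketch(build_batches, device):
    while True:
        yield from build_batches(device)   # rebuild the iterator on every pass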
Example #4
    def __init__(self, config, ranker, vocab, train_ds, logger, random):
        self.config = config
        self.ranker = ranker
        self.vocab = vocab
        self.logger = logger
        self.dataset = train_ds
        self.random = random

        self.batch_size = self.config['batch_size']
        if self.config['grad_acc_batch'] > 0:
            assert self.config['batch_size'] % self.config['grad_acc_batch'] == 0, \
                "batch_size must be a multiple of grad_acc_batch"
            self.batch_size = self.config['grad_acc_batch']

        self.device = util.device(self.config, self.logger)
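
A generic sketch of gradient accumulation, not this project's trainer loop: when grad_acc_batch > 0, each optimizer step is assembled from batch_size / grad_acc_batch smaller forward/backward passes, so the effective batch size stays at the configured batch_size.

def accumulated_step(model, loss_fn, optimizer, micro_batches):
    # one optimizer step built from several smaller (micro) batches of torch tensors
    optimizer.zero_grad()
    for inputs, targets in micro_batches:
        loss = loss_fn(model(inputs), targets) / len(micro_batches)
        loss.backward()  # gradients accumulate in .grad across micro batches
    optimizer.step()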
Example #5
    def run(self):
        if self.config['queries']:
            logger.debug(
                'loading queries from {queries}'.format(**self.config))
            query_iter = plaintext.read_tsv(self.config['queries'])
        else:
            logger.debug('loading queries from test_ds')
            query_iter = self.test_ds.all_queries_raw()

        if self.config['rerank']:
            if not self.config['dvec_file']:
                raise ValueError('must provide dvec_file')
            self._load_ranker_weights(self.ranker, self.vocab, self.trainer,
                                      self.valid_pred, self.train_ds)
            self.ranker.eval()
            input_spec = self.ranker.input_spec()
            fields = {f for f in input_spec['fields'] if f.startswith('query_')}
            device = util.device(self.config, logger)
            vocab_size = self.vocab.lexicon_size()
            num_docs = self.test_ds.num_docs()
            dvec_cache = EpicCacheReader(self.config['dvec_file'],
                                         self.config['prune'], num_docs,
                                         vocab_size, self.config['dvec_inmem'],
                                         self.config['gpu'])
        else:
            # initial retrieval only: the re-ranking inputs below are unused in this path,
            # but bind them so the unconditional time()/predict() calls do not hit a NameError
            dvec_cache, fields, input_spec = None, None, None
            device = util.device(self.config, logger)

        self.timer = util.DurationTimer(gpu_sync=self.config['gpu'])
        with torch.no_grad():
            if self.config['mode'] == 'time':
                self.time(query_iter, dvec_cache, fields, input_spec, device)
            if self.config['mode'] == 'predict':
                self.predict(query_iter, dvec_cache, fields, input_spec,
                             device)
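
The util.DurationTimer API itself is not shown here; a hedged sketch of a gpu_sync-aware timing helper (synchronizing before reading the clock matters because CUDA kernels launch asynchronously):

import time
from contextlib import contextmanager

import torch

@contextmanager
def timed(durations, name, gpu_sync=False):
    if gpu_sync and torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    yield
    if gpu_sync and torch.cuda.is_available():
        torch.cuda.synchronize()
    durations[name] = durations.get(name, 0.0) + time.perf_counter() - start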