コード例 #1
0
ファイル: CEDRKNRM.py プロジェクト: larryli1999/capreolus
import math

import torch
from torch import nn
from transformers import BertModel, ElectraModel, AutoModel

from capreolus import ConfigOption, Dependency, get_logger
from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBank

# Module-level logger for this file, created through capreolus's get_logger
# helper and named after the current module via __name__.
logger = get_logger(__name__)


class CEDRKNRM_Class(nn.Module):
    def __init__(self, extractor, config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.extractor = extractor
        self.config = config

        if config["pretrained"] == "electra-base-msmarco":
            self.bert = ElectraModel.from_pretrained(
                "Capreolus/electra-base-msmarco",
                hidden_dropout_prob=config["hidden_dropout_prob"],
                output_hidden_states=True)
        elif config["pretrained"] == "electra-base":
            self.bert = ElectraModel.from_pretrained(
                "google/electra-base-discriminator",
                hidden_dropout_prob=config["hidden_dropout_prob"],
                output_hidden_states=True)
        elif config["pretrained"] == "bert-base-msmarco":
            self.bert = BertModel.from_pretrained(
コード例 #2
0
ファイル: __init__.py プロジェクト: vrdn-23/capreolus
import os
import json

import numpy as np
from capreolus import ModuleBase, get_logger

# Module-level logger for the trainer package, created through capreolus's
# get_logger helper and named after the current module via __name__.
# The pylint pragma silences invalid-name, which would otherwise flag this
# lowercase module-level name.
logger = get_logger(__name__)  # pylint: disable=invalid-name


class Trainer(ModuleBase):
    """Base class for Trainer modules. The purpose of a Trainer is to train a :class:`~capreolus.reranker.Reranker` module and use it to make predictions. Capreolus provides two trainers: :class:`~capreolus.trainer.pytorch.PytorchTrainer` and :class:`~capreolus.trainer.tensorflow.TensorFlowTrainer`

    Modules should provide:
        - a ``train`` method that trains a reranker on training and dev (validation) data
        - a ``predict`` method that uses a reranker to make predictions on data
    """

    module_type = "trainer"
    requires_random_seed = True

    @staticmethod
    def load_loss_file(fn):
        """Loads loss history from fn

        Args:
           fn (Path): path to a loss.txt file

        Returns:
            a list of losses ordered by iterations

        """