def construct_seq_class_transformer(options:
                                    KaggleEvaluationOptions) -> Reranker:
    """Build a MonoBERT reranker for the Kaggle evaluation.

    Model loading falls back through three stages:

    1. ``MonoBERT.get_model`` with default arguments;
    2. on ``OSError``, ``MonoBERT.get_model`` with ``from_tf=True``;
    3. on ``AttributeError`` from stage 2, a manual TensorFlow load
       (``_load_tf_classifier_hotfix``), originally added for BioBERT
       MS MARCO.

    Args:
        options: Evaluation options. ``model_name``, ``device``,
            ``tokenizer_name`` and ``do_lower_case`` are read.

    Returns:
        A ``MonoBERT`` reranker wrapping the loaded model and tokenizer.
    """
    try:
        model = MonoBERT.get_model(options.model_name, device=options.device)
    except OSError:
        try:
            model = MonoBERT.get_model(
                        options.model_name,
                        from_tf=True,
                        device=options.device)
        except AttributeError:
            model = _load_tf_classifier_hotfix(options.model_name,
                                               options.device)
    tokenizer = MonoBERT.get_tokenizer(
                    options.tokenizer_name, do_lower_case=options.do_lower_case)
    return MonoBERT(model, tokenizer)


def _load_tf_classifier_hotfix(model_name, device_name):
    """Load a TF BERT classifier checkpoint using temporary class attributes.

    Hotfix for BioBERT MS MARCO (refactor candidate).

    Placeholder parameters are attached to ``BertForSequenceClassification``
    as class attributes while ``from_pretrained(..., from_tf=True)`` runs:
    ``weight`` has shape (2, 768) and ``bias`` has shape (2,). The same
    Parameter objects are then installed as ``model.classifier.weight`` and
    ``model.classifier.bias``.

    NOTE(review): this presumably works because the TF loader resolves the
    checkpoint's top-level classifier variables via ``getattr`` on the model
    and fills these placeholders in place. Confirm this against the
    transformers version in use. The 768 width assumes a BERT-base hidden
    size, and the 2 assumes two output labels.

    The class attributes are always removed afterwards, even if loading
    fails, so the patch does not leak into later uses of the class.

    Args:
        model_name: Name or path passed to ``from_pretrained``.
        device_name: Device spec passed to ``torch.device``.

    Returns:
        The loaded model, moved to the device and switched to eval mode.
    """
    weight = torch.nn.Parameter(torch.zeros(2, 768))
    bias = torch.nn.Parameter(torch.zeros(2))
    BertForSequenceClassification.bias = bias
    BertForSequenceClassification.weight = weight
    try:
        model = BertForSequenceClassification.from_pretrained(
                    model_name, from_tf=True)
    finally:
        # Undo the monkeypatch; the Parameter objects stay alive through the
        # local names and are handed to the classifier below.
        del BertForSequenceClassification.weight
        del BertForSequenceClassification.bias
    model.classifier.weight = weight
    model.classifier.bias = bias
    device = torch.device(device_name)
    return model.to(device).eval()
def construct_seq_class_transformer(
        options: DocumentRankingEvaluationOptions) -> Reranker:
    """Create a MonoBERT reranker from document-ranking evaluation options.

    ``options.model`` is the model passed to ``MonoBERT.get_model``.
    ``options.from_tf`` and ``options.device`` are forwarded as the
    ``from_tf`` and ``device`` keywords. ``options.tokenizer_name`` selects
    the tokenizer. The model is loaded before the tokenizer.
    """
    return MonoBERT(
        MonoBERT.get_model(options.model,
                           from_tf=options.from_tf,
                           device=options.device),
        MonoBERT.get_tokenizer(options.tokenizer_name),
    )
# Example #3
0
def build_bert_reranker(
    name_or_path: str = "castorini/monobert-large-msmarco-finetune-only",
    device: str = None,
):
    """Create a MonoBERT reranker whose model and tokenizer share one source.

    Args:
        name_or_path: Model name or local path. The same value is used to
            load both the model and the tokenizer.
        device: Forwarded unchanged to ``MonoBERT.get_model``.

    Returns:
        A ``MonoBERT`` instance wrapping the loaded model and tokenizer.
    """
    source = name_or_path
    bert_model = MonoBERT.get_model(source, device=device)
    return MonoBERT(bert_model, MonoBERT.get_tokenizer(source))