Пример #1
0
def define_model(name, config=None, location=None):
    """Create a sequence-classification model for the backbone ``name``.

    If ``config`` is given, this is a first training run: the pretrained
    weights published under ``name`` are loaded with that config.  Otherwise
    the model is loaded from the checkpoint at ``location``.

    Args:
        name: Hugging Face model identifier.  Must be one of the names
            listed below; each maps to a specific
            ``*ForSequenceClassification`` class.
        config: Config for a fresh training run, or ``None`` to load a
            checkpoint instead.
        location: Checkpoint path/identifier; required when ``config`` is
            ``None``.

    Returns:
        The model returned by the matching class's ``from_pretrained``.

    Raises:
        ValueError: If ``name`` is not a supported model, or if ``config``
            and ``location`` are both ``None``.
    """
    # Resolve the architecture class first.  Previously an unsupported name
    # fell through every branch and the function silently returned None.
    if name in (
        "bert-base-multilingual-cased",
        "sangrimlee/bert-base-multilingual-cased-korquad",
        "kykim/bert-kor-base",
        "monologg/kobert",
    ):
        model_cls = BertForSequenceClassification
    elif name in (
        "monologg/koelectra-base-v3-discriminator",
        "kykim/electra-kor-base",
    ):
        model_cls = ElectraForSequenceClassification
    elif name == "xlm-roberta-large":
        model_cls = XLMRobertaForSequenceClassification
    elif name == "kykim/funnel-kor-base":
        model_cls = FunnelForSequenceClassification
    else:
        raise ValueError(f"Unsupported model name: {name!r}")

    if config is not None:
        # First training run: start from the hub weights for `name`.
        return model_cls.from_pretrained(name, config=config)
    if location is None:
        raise ValueError("`location` is required when `config` is not given")
    # Not a first run: load from the checkpoint at `location`.
    return model_cls.from_pretrained(location)
 def create_and_check_for_sequence_classification(
     self,
     config,
     input_ids,
     token_type_ids,
     input_mask,
     sequence_labels,
     token_labels,
     choice_labels,
     fake_token_labels,
 ):
     """Run a Funnel sequence-classification forward pass and check logits.

     Sets ``config.num_labels`` from ``self.num_labels`` (mutating the
     passed config).  Builds a ``FunnelForSequenceClassification``, moves it
     to ``torch_device`` and switches it to eval mode.  Feeds the inputs plus
     ``sequence_labels``, then asserts through ``self.parent`` that the
     logits have shape ``(self.batch_size, self.num_labels)``.

     ``token_labels``, ``choice_labels`` and ``fake_token_labels`` are not
     used here.  Presumably they are accepted so the signature matches
     sibling ``create_and_check_*`` helpers; confirm against the caller.
     """
     config.num_labels = self.num_labels
     classifier = FunnelForSequenceClassification(config)
     classifier.to(torch_device)
     classifier.eval()

     outputs = classifier(
         input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
         labels=sequence_labels,
     )

     # self.parent is presumably the owning unittest.TestCase.
     expected_shape = (self.batch_size, self.num_labels)
     self.parent.assertEqual(outputs.logits.shape, expected_shape)