Example #1
import paddle
import paddle.nn as nn
from paddlenlp.transformers import BigBirdForSequenceClassification, BigBirdTokenizer

# args, set_seed, create_dataloader, do_train and do_evalute are defined
# elsewhere in the fine-tuning script this excerpt is taken from.
def main():
    # Initialization for the parallel environment
    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available devices are cpu, gpu, or xpu."

    paddle.set_device(args.device)
    set_seed(args)
    # Define the model and metric
    model = BigBirdForSequenceClassification.from_pretrained(
        args.model_name_or_path)
    criterion = nn.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

    # Define the tokenizer and dataloader
    tokenizer = BigBirdTokenizer.from_pretrained(args.model_name_or_path)
    global config
    config = getattr(model,
                     BigBirdForSequenceClassification.base_model_prefix).config
    train_data_loader, test_data_loader = create_dataloader(
        args.batch_size, args.max_encoder_length, tokenizer)

    # Define the Adam optimizer
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=args.learning_rate,
                                      epsilon=1e-6)

    # Fine-tune the classification model
    do_train(model, criterion, metric, optimizer, train_data_loader, tokenizer)

    # Evaluate the fine-tuned model
    do_evalute(model, criterion, metric, test_data_loader)
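The helpers set_seed, create_dataloader, do_train and do_evalute above come from the rest of the script and are not shown. As a rough reference, a minimal set_seed compatible with this excerpt might look like the sketch below; the args.seed field is an assumption, not something the excerpt confirms.

import random

import numpy as np
import paddle

def set_seed(args):
    # Seed every RNG the training loop touches so runs are reproducible.
    # args.seed is a hypothetical argparse field.
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
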
Example #2
def main():
    # Initialization for the parallel environment
    paddle.set_device(args.device)
    set_seed(args)
    # Define the model and metric
    # For the fine-tuning task, BigBird performs better when dropout is set to zero.
    model = BigBirdForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        attn_dropout=args.attn_dropout,
        hidden_dropout_prob=args.hidden_dropout_prob)

    criterion = nn.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

    # Define the tokenizer and dataloader
    tokenizer = BigBirdTokenizer.from_pretrained(args.model_name_or_path)
    config = getattr(model,
                     BigBirdForSequenceClassification.base_model_prefix).config
    train_data_loader, test_data_loader = create_dataloader(
        args.batch_size, args.max_encoder_length, tokenizer, config)

    # Define the Adam optimizer
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=args.learning_rate,
                                      epsilon=1e-6)

    # Fine-tune the classification model
    do_train(model, criterion, metric, optimizer, train_data_loader, tokenizer)

    # Evaluate the fine-tuned model
    do_evalute(model, criterion, metric, test_data_loader)
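The attn_dropout and hidden_dropout_prob values above are read from args; a plausible argparse setup is sketched below. The default values and help strings are assumptions for illustration only.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="bigbird-base-uncased")
parser.add_argument("--attn_dropout", type=float, default=0.0,
                    help="Attention dropout; zero tends to work best when fine-tuning.")
parser.add_argument("--hidden_dropout_prob", type=float, default=0.0,
                    help="Hidden-layer dropout; zero tends to work best when fine-tuning.")
args = parser.parse_args()
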
Example #3
def setUp(self):
    np.random.seed(102)
    self.tokenizer = BigBirdTokenizer.from_pretrained(
        'bigbird-base-uncased')
    self.set_text()
    self.set_input()
    self.set_output()
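This setUp comes from a tokenizer test class whose set_text, set_input and set_output hooks are defined per test case and are not shown. A hypothetical skeleton of such a class, with a made-up sample sentence and a minimal assertion, could look like this:

import unittest

import numpy as np
from paddlenlp.transformers import BigBirdTokenizer

class TestBigBirdTokenizer(unittest.TestCase):
    def setUp(self):
        np.random.seed(102)
        self.tokenizer = BigBirdTokenizer.from_pretrained('bigbird-base-uncased')
        self.set_text()

    def set_text(self):
        # Hypothetical sample text; the real hooks also set the expected
        # inputs and outputs for the case under test.
        self.text = "An extremely powerful film."

    def test_tokenize(self):
        tokens = self.tokenizer.tokenize(self.text)
        self.assertIsInstance(tokens, list)
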
Example #4
def main():
    # Initialization for the parallel environment
    paddle.set_device(args.device)
    set_seed(args)
    # Define the model and metric
    model = BigBirdForSequenceClassification.from_pretrained(
        args.model_name_or_path)
    criterion = nn.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

    # Define the tokenizer and dataloader
    tokenizer = BigBirdTokenizer.from_pretrained(args.model_name_or_path)
    global config
    config = BigBirdModel.pretrained_init_configuration[
        args.model_name_or_path]
    train_data_loader, test_data_loader = create_dataloader(
        args.batch_size, args.max_encoder_length, tokenizer)

    # Define the Adam optimizer
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=args.learning_rate,
                                      epsilon=1e-6)

    # Fine-tune the classification model
    do_train(model, criterion, metric, optimizer, train_data_loader,
             test_data_loader)

    # Evaluate the fine-tuned model
    do_evalute(model, criterion, metric, test_data_loader)
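Unlike Example #1, this variant reads the model hyperparameters from BigBirdModel.pretrained_init_configuration, a class-level dict keyed by pretrained model name. A quick way to inspect it (the printed keys are assumptions; exact names can vary between PaddleNLP versions):

from paddlenlp.transformers import BigBirdModel

config = BigBirdModel.pretrained_init_configuration["bigbird-base-uncased"]
# config is a plain dict of architecture hyperparameters.
for key in ("num_layers", "nhead", "hidden_size"):
    print(key, config.get(key))
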
Example #5
def setUp(self):
    self.tokenizer = BigBirdTokenizer.from_pretrained(
        'bigbird-base-uncased')
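Once loaded, the tokenizer can be exercised directly. A small usage sketch (the sample sentence is made up):

from paddlenlp.transformers import BigBirdTokenizer

tokenizer = BigBirdTokenizer.from_pretrained('bigbird-base-uncased')
tokens = tokenizer.tokenize("BigBird handles long sequences.")
print(tokens)  # SentencePiece subword pieces
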
Example #6
def test_not_exist_file(self):
    self.tokenizer = BigBirdTokenizer(sentencepiece_model_file='')
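The excerpt ends before any assertion, but the test name suggests that constructing a tokenizer from a missing SentencePiece model file is expected to fail. A hedged completion, using the broad Exception class since the exact exception type is not shown:

import unittest

from paddlenlp.transformers import BigBirdTokenizer

class TestNotExistFile(unittest.TestCase):
    def test_not_exist_file(self):
        # An empty model path should fail while loading the vocab;
        # the precise exception type is an assumption.
        with self.assertRaises(Exception):
            BigBirdTokenizer(sentencepiece_model_file='')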