def do_train(args):
    """Fine-tune ERNIE-Doc for token classification (sequence labeling).

    Builds the tokenizer, dataset iterators and dataloaders, trains with a
    layerwise-decayed AdamW optimizer under a linear-warmup schedule,
    evaluates on the dev set every ``args.save_steps`` steps (checkpointing
    the best model by chunk F1), and finally evaluates on the test set.

    Args:
        args: argparse namespace. Fields used here include ``dataset``,
            ``model_name_or_path``, ``batch_size``, ``max_seq_length``,
            ``seed``, ``device``, ``epochs``, ``learning_rate``,
            ``warmup_proportion``, ``weight_decay``, ``layerwise_decay``,
            ``memory_length``, ``logging_steps``, ``save_steps``,
            ``max_steps`` and ``output_dir``.
    """
    set_seed(args)
    # DATASET_INFO maps a dataset name to (tokenizer class, dev split name,
    # test split name).
    tokenizer_class, eval_name, test_name, = DATASET_INFO[args.dataset]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    train_ds, eval_ds, test_ds = load_dataset(
        args.dataset, splits=["train", eval_name, test_name])
    num_classes = len(train_ds.label_list)
    # Convention here: the last label id is the "no entity" (outside) tag.
    no_entity_id = num_classes - 1

    paddle.set_device(args.device)
    trainer_num = paddle.distributed.get_world_size()
    if trainer_num > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()
    if rank == 0:
        if os.path.exists(args.model_name_or_path):
            logger.info("init checkpoint from %s" % args.model_name_or_path)
    model = ErnieDocForTokenClassification.from_pretrained(
        args.model_name_or_path, num_classes=num_classes)
    model_config = model.ernie_doc.config
    if trainer_num > 1:
        model = paddle.DataParallel(model)

    # Each iterator shards the data across trainers and yields batches that
    # carry the recurrence-memory bookkeeping ERNIE-Doc requires.
    train_ds_iter = SequenceLabelingIterator(
        train_ds,
        args.batch_size,
        tokenizer,
        trainer_num,
        trainer_id=rank,
        memory_len=model_config["memory_len"],
        max_seq_length=args.max_seq_length,
        random_seed=args.seed,
        no_entity_id=no_entity_id)
    eval_ds_iter = SequenceLabelingIterator(
        eval_ds,
        args.batch_size,
        tokenizer,
        trainer_num,
        trainer_id=rank,
        memory_len=model_config["memory_len"],
        max_seq_length=args.max_seq_length,
        mode="eval",
        no_entity_id=no_entity_id)
    test_ds_iter = SequenceLabelingIterator(
        test_ds,
        args.batch_size,
        tokenizer,
        trainer_num,
        trainer_id=rank,
        memory_len=model_config["memory_len"],
        max_seq_length=args.max_seq_length,
        mode="test",
        no_entity_id=no_entity_id)

    train_dataloader = paddle.io.DataLoader.from_generator(capacity=70,
                                                           return_list=True)
    train_dataloader.set_batch_generator(train_ds_iter, paddle.get_device())
    eval_dataloader = paddle.io.DataLoader.from_generator(capacity=70,
                                                          return_list=True)
    eval_dataloader.set_batch_generator(eval_ds_iter, paddle.get_device())
    test_dataloader = paddle.io.DataLoader.from_generator(capacity=70,
                                                          return_list=True)
    test_dataloader.set_batch_generator(test_ds_iter, paddle.get_device())

    num_training_examples = train_ds_iter.get_num_examples()
    num_training_steps = args.epochs * num_training_examples // args.batch_size // trainer_num
    logger.info("Device count: %d, trainer_id: %d" % (trainer_num, rank))
    logger.info("Num train examples: %d" % num_training_examples)
    logger.info("Max train steps: %d" % num_training_steps)
    logger.info("Num warmup steps: %d" %
                int(num_training_steps * args.warmup_proportion))

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    # Map framework-generated parameter names back to structured names so
    # AdamWDL can infer each parameter's depth for layerwise lr decay.
    name_dict = dict()
    for n, p in model.named_parameters():
        name_dict[p.name] = n

    optimizer = AdamWDL(learning_rate=lr_scheduler,
                        parameters=model.parameters(),
                        weight_decay=args.weight_decay,
                        apply_decay_param_fun=lambda x: x in decay_params,
                        n_layers=model_config["num_hidden_layers"],
                        layerwise_decay=args.layerwise_decay,
                        name_dict=name_dict)

    criterion = paddle.nn.loss.CrossEntropyLoss()
    metric = ChunkEvaluator(label_list=train_ds.label_list)

    global_steps = 0

    # Factory for a zero-initialized recurrence memory; called again before
    # each evaluation so evaluation never disturbs the training memory.
    create_memory = partial(init_memory, args.batch_size, args.memory_length,
                            model_config["hidden_size"],
                            model_config["num_hidden_layers"])
    # Copy the memory
    memories = create_memory()
    tic_train = time.time()
    best_f1 = 0
    for epoch in range(args.epochs):
        # Reshuffle each epoch and re-attach the generator to the loader.
        train_ds_iter.shuffle_sample()
        train_dataloader.set_batch_generator(train_ds_iter,
                                             paddle.get_device())
        for step, batch in enumerate(train_dataloader, start=1):
            global_steps += 1
            input_ids, position_ids, token_type_ids, attn_mask, labels, lengths, qids, \
                gather_idx, need_cal_loss = batch
            # The model both consumes and returns the recurrence memory.
            logits, memories = model(input_ids, memories, token_type_ids,
                                     position_ids, attn_mask)
            # Keep only the positions selected by gather_idx.
            logits, labels = list(
                map(lambda x: paddle.gather(x, gather_idx), [logits, labels]))

            # need_cal_loss zeroes out batches that should not contribute
            # to the loss.
            loss = criterion(logits, labels) * need_cal_loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_steps % args.logging_steps == 0:
                logger.info(
                    "train: global step %d, epoch: %d, loss: %f, lr: %f, speed: %.2f step/s"
                    % (global_steps, epoch, loss, lr_scheduler.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            if global_steps % args.save_steps == 0:
                # Evaluate
                logger.info("Eval:")
                precision, recall, f1_score = evaluate(model, metric,
                                                       eval_dataloader,
                                                       create_memory())
                # Save; only the rank-0 trainer writes checkpoints.
                if rank == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % (global_steps))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Unwrap DataParallel so the checkpoint loads in
                    # single-card mode too.
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    if f1_score > best_f1:
                        logger.info("Save best model......")
                        best_f1 = f1_score
                        best_model_dir = os.path.join(args.output_dir,
                                                      "best_model")
                        if not os.path.exists(best_model_dir):
                            os.makedirs(best_model_dir)
                        model_to_save.save_pretrained(best_model_dir)
                        tokenizer.save_pretrained(best_model_dir)

            # NOTE(review): returning here skips the final test evaluation
            # below whenever max_steps is reached — confirm this is intended.
            if args.max_steps > 0 and global_steps >= args.max_steps:
                return

    logger.info("Final test result:")
    eval_acc = evaluate(model, metric, test_dataloader, create_memory())
# Exemplo n.º 2 (example separator left by the code scraper)
def train():
    """Fine-tune ERNIE-for-generation on the Poetry dataset.

    Reads configuration from the module-level ``args``. Each step performs
    three forward passes: encode the source (caching keys/values), encode
    the target against the source cache, then score the [ATTN] query
    positions against the target labels to obtain the loss. Evaluates with
    ROUGE-1/2 and saves a checkpoint every ``save_steps`` on rank 0.
    """
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # Choose the tokenizer class by substring of the pretrained model name;
    # BERT tokenizer is the fallback.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    if args.init_checkpoint:
        # Incremental training: warm-start from a previously saved state dict.
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    train_dataset, dev_dataset = Poetry.get_datasets(['train', 'dev'])
    # Use the dedicated [ATTN] token when the vocab has one, otherwise fall
    # back to [MASK] as the attention-query token.
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # Last row of the sentence-type embedding marks target segments.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len,
                                 noise_prob=args.noise_prob,
                                 use_random_noice=args.use_random_noice)

    train_dataset = train_dataset.apply(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    # Pad every field to the batch max length; after_padding then derives
    # the attention-bias masks from the padded tensors.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_dataset,
                                 batch_sampler=dev_batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=0,
                                 return_list=True)

    # Vocabulary size, used as the class count for label smoothing below.
    label_num = model.word_emb.weight.shape[0]
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    max_steps = len(train_data_loader) * args.num_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps,
                                         args.warmup_proportion)

    # Weight decay applies to every parameter except biases and LayerNorm.
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(1.0),
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    global_step = 1
    tic_train = time.time()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            # NOTE(review): unpack order (ids, sids, pids) differs from the
            # collate-fn field comments — presumably after_padding reorders;
            # verify against its implementation.
            (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids,
             attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
             mask_attn_2_srctgtattn, tgt_labels, _) = batch
            # Pass 1: encode the source only, caching its keys/values.
            _, __, info = model(src_ids,
                                sent_ids=src_sids,
                                pos_ids=src_pids,
                                attn_bias=mask_src_2_src,
                                encode_only=True)
            cached_k, cached_v = info['caches']
            # Pass 2: encode the target attending to the source cache.
            _, __, info = model(tgt_ids,
                                sent_ids=tgt_sids,
                                pos_ids=tgt_pids,
                                attn_bias=mask_tgt_2_srctgt,
                                past_cache=(cached_k, cached_v),
                                encode_only=True)
            cached_k2, cached_v2 = info['caches']
            # Concatenate source and target caches along the sequence axis.
            past_cache_k = [
                paddle.concat([k, k2], 1)
                for k, k2 in zip(cached_k, cached_k2)
            ]
            past_cache_v = [
                paddle.concat([v, v2], 1)
                for v, v2 in zip(cached_v, cached_v2)
            ]
            if args.label_smooth > 0.:
                tgt_labels = nn.functional.label_smooth(
                    nn.functional.one_hot(tgt_labels, label_num),
                    epsilon=args.label_smooth)
            # Pass 3: score the [ATTN] query positions to get the loss.
            loss, _, __ = model(attn_ids,
                                sent_ids=tgt_sids,
                                pos_ids=tgt_pids,
                                attn_bias=mask_attn_2_srctgtattn,
                                past_cache=(past_cache_k, past_cache_v),
                                tgt_labels=tgt_labels,
                                tgt_pos=paddle.nonzero(attn_ids == attn_id))
            if global_step % args.logging_steps == 0:
                # Only rank 0 logs when running multi-GPU.
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train), lr_scheduler.get_lr()))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # NOTE(review): legacy API name; newer Paddle code calls
            # optimizer.clear_grad() instead — confirm against the pinned
            # Paddle version.
            optimizer.clear_gradients()
            if global_step % args.save_steps == 0 and (
                (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0):
                evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
                         attn_id, tgt_type_id, args)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                # Unwrap DataParallel so the checkpoint also loads in
                # single-card mode.
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
            global_step += 1
# Exemplo n.º 3 (example separator left by the code scraper)
def train():
    """Fine-tune ERNIE-for-generation on the poetry dataset (StackModel variant).

    Reads configuration from the module-level ``args``. Unlike the cached
    three-pass variant, the multi-pass forward is folded into StackModel so
    DataParallel sees a single loss-producing forward. Evaluates with
    ROUGE-1/2 and saves a checkpoint every ``save_steps`` on rank 0.
    """
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # Choose the tokenizer class by substring of the pretrained model name;
    # BERT tokenizer is the fallback.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    if args.init_checkpoint:
        # Incremental training: warm-start from a previously saved state dict.
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    train_dataset, dev_dataset = load_dataset(
        'poetry', splits=('train', 'dev'), lazy=False)
    # Use the dedicated [ATTN] token when the vocab has one, otherwise fall
    # back to [MASK] as the attention-query token.
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # Last row of the sentence-type embedding marks target segments.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(
        tokenizer=tokenizer,
        attn_id=attn_id,
        tgt_type_id=tgt_type_id,
        max_encode_len=args.max_encode_len,
        max_decode_len=args.max_decode_len,
        noise_prob=args.noise_prob,
        use_random_noice=args.use_random_noice)

    train_dataset = train_dataset.map(trans_func)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    # Pad every field to the batch max length; after_padding then derives
    # the attention-bias masks from the padded tensors.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # src_tids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # tgt_tids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    dev_dataset = dev_dataset.map(trans_func)
    dev_data_loader = DataLoader(
        dataset=dev_dataset,
        batch_size=args.batch_size,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    # Vocabulary size, used as the class count for label smoothing below.
    label_num = model.word_emb.weight.shape[0]
    train_model = StackModel(model)
    if paddle.distributed.get_world_size() > 1:
        # All 'forward' outputs derived from the module parameters using in DataParallel
        # must participate in the calculation of losses and subsequent gradient calculations.
        # So we use StackModel here to make the model only output loss in its 'forward' function.
        train_model = paddle.DataParallel(train_model)

    max_steps = len(train_data_loader) * args.num_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps,
                                         args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(1.0),
        apply_decay_param_fun=lambda x: x in decay_params)

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    global_step = 1
    tic_train = time.time()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids,
             mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
             tgt_labels, _) = batch
            if args.label_smooth > 0.:
                tgt_labels = nn.functional.label_smooth(
                    nn.functional.one_hot(tgt_labels, label_num),
                    epsilon=args.label_smooth)
            # Positions of the [ATTN] query tokens to be predicted.
            tgt_pos = paddle.nonzero(attn_ids == attn_id)
            # StackModel wraps the multi-pass encode/decode and returns only
            # the loss (see the DataParallel note above).
            loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids,
                               tgt_pids, attn_ids, mask_src_2_src,
                               mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
                               tgt_labels, tgt_pos)
            if global_step % args.logging_steps == 0:
                # Only rank 0 logs.
                if paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train), lr_scheduler.get_lr()))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.save_steps == 0 and paddle.distributed.get_rank(
            ) == 0:
                # Evaluate and checkpoint the underlying (unwrapped) model.
                evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
                         attn_id, tgt_type_id, args)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
            global_step += 1
# Exemplo n.º 4 (example separator left by the code scraper)
def do_train(args):
    """Fine-tune ERNIE-Doc for extractive question answering.

    Builds the tokenizer, MRC iterators and dataloaders, trains with a
    layerwise-decayed AdamW optimizer under a linear-warmup schedule,
    evaluates with EM/F1 every ``args.save_steps`` steps (checkpointing the
    best model by the averaged metric), then evaluates on the test set and
    saves a final checkpoint.

    Args:
        args: argparse namespace. Fields used here include ``dataset``,
            ``model_name_or_path``, ``dropout``, ``batch_size``,
            ``max_seq_length``, ``seed``, ``device``, ``epochs``,
            ``learning_rate``, ``warmup_proportion``, ``weight_decay``,
            ``layerwise_decay``, ``memory_length``, ``logging_steps``,
            ``save_steps``, ``max_steps`` and ``output_dir``.
    """
    set_seed(args)

    # DATASET_INFO maps a dataset name to (dev split, test split, tokenizer class).
    DEV, TEST, TOKENIZER_CLASS = DATASET_INFO[args.dataset]
    tokenizer = TOKENIZER_CLASS.from_pretrained(args.model_name_or_path)

    train_ds, eval_ds, test_ds = load_dataset(args.dataset,
                                              splits=['train', DEV, TEST])

    paddle.set_device(args.device)
    trainer_num = paddle.distributed.get_world_size()
    if trainer_num > 1:
        paddle.distributed.init_parallel_env()
    rank = paddle.distributed.get_rank()
    if rank == 0:
        if os.path.exists(args.model_name_or_path):
            logger.info("init checkpoint from %s" % args.model_name_or_path)

    model = ErnieDocForQuestionAnswering.from_pretrained(
        args.model_name_or_path, dropout=args.dropout)
    model_config = model.ernie_doc.config
    if trainer_num > 1:
        model = paddle.DataParallel(model)

    # Each iterator shards the data across trainers and yields batches that
    # carry the recurrence-memory bookkeeping ERNIE-Doc requires.
    train_ds_iter = MRCIterator(train_ds,
                                args.batch_size,
                                tokenizer,
                                trainer_num,
                                trainer_id=rank,
                                memory_len=model_config["memory_len"],
                                max_seq_length=args.max_seq_length,
                                random_seed=args.seed)

    eval_ds_iter = MRCIterator(eval_ds,
                               args.batch_size,
                               tokenizer,
                               trainer_num,
                               trainer_id=rank,
                               memory_len=model_config["memory_len"],
                               max_seq_length=args.max_seq_length,
                               mode="eval",
                               random_seed=args.seed)

    test_ds_iter = MRCIterator(test_ds,
                               args.batch_size,
                               tokenizer,
                               trainer_num,
                               trainer_id=rank,
                               memory_len=model_config["memory_len"],
                               max_seq_length=args.max_seq_length,
                               mode="test",
                               random_seed=args.seed)

    train_dataloader = paddle.io.DataLoader.from_generator(capacity=70,
                                                           return_list=True)
    train_dataloader.set_batch_generator(train_ds_iter, paddle.get_device())

    eval_dataloader = paddle.io.DataLoader.from_generator(capacity=70,
                                                          return_list=True)
    eval_dataloader.set_batch_generator(eval_ds_iter, paddle.get_device())

    test_dataloader = paddle.io.DataLoader.from_generator(capacity=70,
                                                          return_list=True)
    test_dataloader.set_batch_generator(test_ds_iter, paddle.get_device())

    num_training_examples = train_ds_iter.get_num_examples()
    num_training_steps = args.epochs * num_training_examples // args.batch_size // trainer_num
    logger.info("Device count: %d, trainer_id: %d" % (trainer_num, rank))
    logger.info("Num train examples: %d" % num_training_examples)
    logger.info("Max train steps: %d" % num_training_steps)
    logger.info("Num warmup steps: %d" %
                int(num_training_steps * args.warmup_proportion))

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    # Map framework-generated parameter names back to structured names so
    # AdamWDL can infer each parameter's depth for layerwise lr decay.
    name_dict = dict()
    for n, p in model.named_parameters():
        name_dict[p.name] = n

    optimizer = AdamWDL(learning_rate=lr_scheduler,
                        parameters=model.parameters(),
                        weight_decay=args.weight_decay,
                        apply_decay_param_fun=lambda x: x in decay_params,
                        n_layers=model_config["num_hidden_layers"],
                        layerwise_decay=args.layerwise_decay,
                        name_dict=name_dict)

    global_steps = 0
    # Factory for a zero-initialized recurrence memory; called again before
    # each evaluation so evaluation never disturbs the training memory.
    create_memory = partial(init_memory, args.batch_size, args.memory_length,
                            model_config["hidden_size"],
                            model_config["num_hidden_layers"])

    criterion = CrossEntropyLossForQA()

    memories = create_memory()
    tic_train = time.time()
    best_avg_metric = -1
    for epoch in range(args.epochs):
        # Reshuffle each epoch and re-attach the generator to the loader.
        train_ds_iter.shuffle_sample()
        train_dataloader.set_batch_generator(train_ds_iter,
                                             paddle.get_device())
        for step, batch in enumerate(train_dataloader, start=1):
            global_steps += 1
            input_ids, position_ids, token_type_ids, attn_mask, start_position, \
                end_position, qids, gather_idx, need_cal_loss = batch
            # The model both consumes and returns the recurrence memory.
            start_logits, end_logits, memories = model(input_ids, memories,
                                                       token_type_ids,
                                                       position_ids, attn_mask)

            # Keep only the positions selected by gather_idx.
            start_logits, end_logits, qids, start_position, end_position = list(
                map(lambda x: paddle.gather(x, gather_idx), [
                    start_logits, end_logits, qids, start_position,
                    end_position
                ]))
            # need_cal_loss zeroes out batches that should not contribute.
            loss = criterion([start_logits, end_logits],
                             [start_position, end_position]) * need_cal_loss

            mean_loss = loss.mean()
            mean_loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_steps % args.logging_steps == 0:
                logger.info(
                    "train: global step %d, epoch: %d, loss: %f, lr: %f, speed: %.2f step/s"
                    % (global_steps, epoch, mean_loss, lr_scheduler.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            if global_steps % args.save_steps == 0:
                # Evaluate
                logger.info("Eval:")
                EM, F1, AVG = evaluate(args, model, criterion,
                                       EM_AND_F1(), eval_dataloader,
                                       create_memory(), tokenizer)
                # Only the rank-0 trainer writes checkpoints.
                if rank == 0:
                    output_dir = os.path.join(args.output_dir,
                                              "model_%d" % (global_steps))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Unwrap DataParallel so the checkpoint loads in
                    # single-card mode too.
                    model_to_save = model._layers if isinstance(
                        model, paddle.DataParallel) else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    # Track the best model by the averaged EM/F1 metric.
                    # NOTE(review): best_avg_metric is never updated here,
                    # so every later evaluation overwrites "best_model" —
                    # confirm whether the update was omitted intentionally.
                    if best_avg_metric < AVG:
                        output_dir = os.path.join(args.output_dir,
                                                  "best_model")
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)

            # NOTE(review): returning here skips the final test evaluation
            # and final save below whenever max_steps is reached.
            if args.max_steps > 0 and global_steps >= args.max_steps:
                return
    logger.info("Test:")
    evaluate(args, model, criterion, EM_AND_F1(), test_dataloader,
             create_memory(), tokenizer)
    if rank == 0:
        # Save a final checkpoint after the test evaluation.
        output_dir = os.path.join(args.output_dir, "model_%d" % (global_steps))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model_to_save = model._layers if isinstance(
            model, paddle.DataParallel) else model
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
# Exemplo n.º 5 (example separator left by the code scraper)
    def finetune(
        self,
        train_path,
        dev_path=None,
        save_dir="ernie_gen_result",
        init_ckpt_path=None,
        use_gpu=True,
        max_steps=500,
        batch_size=8,
        max_encode_len=50,
        max_decode_len=50,
        learning_rate=5e-5,
        warmup_proportion=0.1,
        weight_decay=0.1,
        noise_prob=0,
        label_smooth=0,
        beam_width=5,
        length_penalty=1.0,
        log_interval=100,
        save_interval=200,
    ):
        """
        finetune with the specified dataset.

        Args:
            train_path(str): the train dataset path.
            dev_path(str): the dev dataset path.
            save_dir(str): the model params and dev dataset predict result save path.
            init_ckpt_path(str): incremental training load path.
            use_gpu(bool): use gpu or not.
            max_steps(int): max training steps.
            batch_size(int): the batch size.
            max_encode_len(int): the max encode length.
            max_decode_len(int): the max decode length.
            learning_rate(float): the learning rate.
            warmup_proportion(float): the warmup proportion.
            weight_decay(float): the weight decay magnitude.
            noise_prob(float): the nosie probability. see the ernie gen paper for details.
            label_smooth(float): the label smooth magnitude.
            beam_width(int): the beam size during evaluating the dev dataset.
            length_penalty(float): the length penalty during evaluating the dev dataset.
            log_interval(int): the log interval.
            save_interval(int): the save interval. dev set will be evaluated after saving.

        Return:
            result(dict): A Dictionary of shape::
                {
                    last_save_path(str): last model save path.
                    last_ppl(float): last model ppl.
                }
        """
        paddle.disable_static()
        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')

        if init_ckpt_path is not None:
            logger.info('loading checkpoint from %s' % init_ckpt_path)
            sd = paddle.load(init_ckpt_path)
            self.model.set_state_dict(sd)

        train_dataset = self._load_dataset(train_path)
        attn_id = self.tokenizer.vocab['[MASK]']
        trans_func = convert_example(tokenizer=self.tokenizer,
                                     attn_id=attn_id,
                                     tgt_type_id=1,
                                     max_encode_len=max_encode_len,
                                     max_decode_len=max_decode_len,
                                     noise_prob=noise_prob)

        train_dataset = train_dataset.map(trans_func)
        train_batch_sampler = paddle.io.BatchSampler(train_dataset, batch_size=batch_size, shuffle=True)
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # src_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # src_pids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # src_tids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_pids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # tgt_tids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # attn_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_labels
        ): after_padding(fn(samples))
        train_data_loader = DataLoader(dataset=train_dataset,
                                       batch_sampler=train_batch_sampler,
                                       collate_fn=batchify_fn,
                                       num_workers=0,
                                       return_list=True)

        if dev_path:
            dev_dataset = self._load_dataset(dev_path)
            dev_dataset = dev_dataset.map(trans_func)
            dev_data_loader = DataLoader(dataset=dev_dataset,
                                         batch_size=batch_size,
                                         collate_fn=batchify_fn,
                                         num_workers=0,
                                         return_list=True)

        label_num = self.model.word_emb.weight.shape[0]
        train_model = StackModel(self.model)
        lr_scheduler = LinearDecayWithWarmup(learning_rate, max_steps, warmup_proportion)
        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = [p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
        optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,
                                           parameters=self.model.parameters(),
                                           weight_decay=weight_decay,
                                           grad_clip=nn.ClipGradByGlobalNorm(1.0),
                                           apply_decay_param_fun=lambda x: x in decay_params)

        rouge1 = Rouge1()
        rouge2 = Rouge2()
        global_step = 1
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir)
        while True:
            for batch in train_data_loader:
                (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                 mask_attn_2_srctgtattn, tgt_labels, _) = batch
                if label_smooth > 0.:
                    tgt_labels = nn.functional.label_smooth(nn.functional.one_hot(tgt_labels, label_num),
                                                            epsilon=label_smooth)

                tgt_pos = paddle.nonzero(attn_ids == attn_id)
                loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src,
                                   mask_tgt_2_srctgt, mask_attn_2_srctgtattn, tgt_labels, tgt_pos)

                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % log_interval == 0 and paddle.distributed.get_rank() == 0:
                    loss_np = loss.numpy()
                    ppl = np.exp(loss_np)
                    logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' %
                                (global_step, max_steps, loss_np, ppl, lr_scheduler.get_lr()))
                if save_dir and global_step % save_interval == 0 and global_step > 0:
                    loss_np = loss.numpy()
                    ppl = np.exp(loss_np)
                    save_name = "step_%s_ppl_%.5f.params" % (global_step, ppl)
                    save_path = os.path.join(save_dir, save_name)
                    logger.info("save the model in %s" % save_path)
                    paddle.save(self.model.state_dict(), save_path)

                    if dev_path:
                        self._evaluate(self.model, dev_data_loader, self.tokenizer, rouge1, rouge2, attn_id,
                                       max_decode_len, max_encode_len, beam_width, length_penalty)

                if global_step >= max_steps:
                    break
                global_step += 1

            if global_step >= max_steps:
                break

        if global_step % save_interval != 0:
            loss_np = loss.numpy()
            ppl = np.exp(loss_np)
            logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' %
                        (global_step, loss_np, ppl, lr_scheduler.get_lr()))
            if save_dir:
                save_name = "step_%s_ppl_%.5f.pdparams" % (global_step, ppl)
                save_path = os.path.join(save_dir, save_name)
                logger.info("save the model in %s" % save_path)
                paddle.save(self.model.state_dict(), save_path)

                if dev_path:
                    self._evaluate(self.model, dev_data_loader, self.tokenizer, rouge1, rouge2, attn_id, max_decode_len,
                                   max_encode_len, beam_width, length_penalty)

        result = {
            "last_save_path": "%s" % save_path,
            "last_ppl": ppl[0],
        }

        return result
# Exemplo n.º 6 (score: 0) — corpus example separator, kept as a comment so the file stays valid Python
def do_train(args) -> None:
    """Run distributed pretraining for an ERNIE-style masked-LM model.

    Initializes the fleet (hybrid data-parallel) environment, builds the
    model/criterion/optimizer and data loaders, optionally resumes step
    count and weights from the ``model_last`` checkpoint, then trains until
    ``args.max_steps``, periodically logging, evaluating and checkpointing.

    Args:
        args: parsed argument namespace. Fields read here include device,
            dp_degree, sharding_degree, output_dir, model_type,
            model_name_or_path, global_batch_size, use_amp, use_recompute,
            max_lr, max_steps, warmup_rate, decay_steps, grad_clip,
            adam_beta1, adam_beta2, adam_epsilon, weight_decay, scale_loss,
            hidden_dropout_prob, attention_probs_dropout_prob, max_seq_len,
            accumulate_steps, logging_freq, eval_freq, eval_iters,
            save_steps, checkpoint_steps and test_iters.

    Returns:
        None. Side effects: writes logs, VisualDL scalars, and model /
        optimizer / tokenizer checkpoints under ``args.output_dir``.
    """
    paddle.set_device(args.device)

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    if worker_num > 1:
        paddle.distributed.init_parallel_env()

    # Default parallel layout: pure data parallelism across all workers.
    if args.dp_degree * args.sharding_degree == 1:
        args.dp_degree = worker_num
        args.sharding_degree = 1

    args_post_process(args, worker_num)

    # Dump the paddle build and the full resolved configuration for the run log.
    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    # Only dp_degree is taken from args here; mp/pp/sharding are pinned to 1,
    # so sharding_degree participates only in the worker-count assertion below.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 1
    }

    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    # NOTE(review): these three values were already computed above; presumably
    # re-queried in case fleet.init changed rank/world-size — confirm whether
    # this duplication (and the unused hcg/local_rank) is intentional.
    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # Create the random seed for the worker
    set_seed(args)

    assert args.dp_degree * args.sharding_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    # Create log write,
    # Per-worker VisualDL writer; the directory name encodes the key hparams.
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size, args.use_amp,
            args.use_recompute, worker_index).lower())
    log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    # Resume bookkeeping (step counter / consumed samples) from the last
    # checkpoint's config.yml, if one exists; weights are loaded further down.
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                # Resuming with a different global batch size would corrupt the
                # consumed_samples accounting, so refuse it outright.
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    # Build from a named config (fresh weights) or load pretrained weights
    # from a local/remote path, overriding the dropout probabilities either way.
    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model = model_class(base_class(**model_config))
    else:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    criterion = criterion_class()

    # Create the learning_rate sheduler and optimizer
    if args.decay_steps is None:
        args.decay_steps = args.max_steps

    # last_epoch=global_step resumes the LR schedule from the checkpointed step.
    lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                         args.max_steps,
                                         args.warmup_rate,
                                         last_epoch=global_step)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.fluid.clip.GradientClipByGlobalNorm(
            clip_norm=args.grad_clip)

    # Exclude all bias and LayerNorm parameters from weight decay, matched by
    # parameter *name* substring.
    decay_param = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    logger.info("Using paddle.optimizer.AdamW.")
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        # multi_precision keeps FP32 master weights when training under AMP.
        apply_decay_param_fun=lambda x: x in decay_param,
        multi_precision=args.use_amp)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        # O2: weights cast to FP16, but checkpoints saved back as FP32.
        model = paddle.amp.decorate(models=model,
                                    level='O2',
                                    save_dtype='float32')

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    data_file = get_train_data_file(args)

    # Each worker reads its own shard (rank/world-size) of the dataset; the
    # loaders start from `global_step` so resumed runs skip consumed data.
    train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
        args,
        data_file,
        tokenizer,
        data_world_size=worker_num,
        data_world_rank=worker_index,
        max_seq_len=args.max_seq_len,
        current_step=global_step)

    # load checkpoint vars
    # Restore model + optimizer state only when the optimizer state file is
    # present; otherwise warn and continue from pretrained/random weights.
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            logger.info("Try to load checkpoint from %s " % checkpoint_dir)
            opt_path = os.path.join(checkpoint_dir, "model_state.pdopt")
            params_path = os.path.join(checkpoint_dir, "model_state.pdparams")

            if os.path.exists(opt_path):
                opt_dict = paddle.load(opt_path)
                optimizer.set_state_dict(opt_dict)
                model_dict = paddle.load(params_path)
                model.set_state_dict(model_dict)
            else:
                logger.warning("No optimizer checkpoint file found in %s." %
                               opt_path)
            logger.info(
                "Checkpoint loaded from global step: {}".format(global_step))

    tic_train = time.time()
    while True:
        # If not call valid_data_loader, the enumerate will call valid_data_loader
        # many times. and start a new random dataloader.
        # NOTE(review): this rebinds the names to the *called* objects; if the
        # outer `while` ever loops a second time (train loader exhausted before
        # max_steps), calling them again would fail unless the returned objects
        # are themselves callable — verify against create_pretrained_dataset.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        # time count
        train_reader_cost = 0.0
        train_run_cost = 0.0
        reader_start = time.time()

        for step, batch in enumerate(train_data_loader()):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            # 0. input_ids,
            # 1. segment_ids,
            # 2. input_mask,
            # 3. masked_lm_positions,
            # 4. masked_lm_labels,
            # 5. next_sentence_labels

            input_ids, segment_ids, input_mask, masked_lm_positions, \
            masked_lm_labels, next_sentence_labels = batch

            # Keep numerically sensitive ops in FP32 under AMP O2.
            with paddle.amp.auto_cast(args.use_amp,
                                      custom_black_list=[
                                          "reduce_sum",
                                          "c_softmax_with_cross_entropy",
                                          "elementwise_div"
                                      ],
                                      level='O2'):

                # Create the model for the ernie pretrain
                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    position_ids=None,
                    attention_mask=input_mask,
                    masked_positions=masked_lm_positions)

                # Total loss = masked-LM loss + sentence-order-prediction loss.
                lm_loss, sop_loss = criterion(prediction_scores,
                                              seq_relationship_score,
                                              masked_lm_labels,
                                              next_sentence_labels)
                loss = lm_loss + sop_loss

            if args.use_amp:
                scaler.scale(loss).backward()
                # scaler.minimize unscales, skips on inf/nan, and steps.
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()

            optimizer.clear_grad()
            train_run_cost += time.time() - train_start

            # Skip for accumulate_steps in global step
            # NOTE(review): optimizer.step()/clear_grad() run every micro-step
            # above, so this only throttles the global_step counter — it is not
            # true gradient accumulation. Confirm that is the intent.
            if (step + 1) % args.accumulate_steps != 0:
                continue

            global_step += 1

            if global_step % args.logging_freq == 0:
                speed = args.logging_freq / (time.time() - tic_train)
                common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                    global_step, loss.item(), lm_loss.item(),
                    sop_loss.item(), speed, speed * args.global_batch_size,
                    lr_scheduler.get_lr())
                addition_info = ""
                if args.use_amp:
                    # Reaches into GradScaler private fields for loss-scale
                    # diagnostics; may break across paddle versions.
                    addition_info = " loss_scaling: %.1f, incr_count: %d, decr_count: %d" % (
                        scaler._scale.numpy(), scaler._incr_count,
                        scaler._decr_count)
                logger.info(common_loginfo + addition_info)
                log_writer.add_scalar("loss", loss.item(), global_step)
                log_writer.add_scalar("lm_loss", lm_loss.item(), global_step)
                log_writer.add_scalar("sop_loss", sop_loss.item(), global_step)

                tic_train = time.time()

            if lr_scheduler is not None:
                lr_scheduler.step()

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation

                run_evaluate(valid_data_loader,
                             model,
                             criterion,
                             args.eval_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="valid")
                tic_train = time.time()

            def save_ckpt(output_dir, model, tokenizer, args, global_step):
                # Persist model weights, tokenizer, optimizer state and a
                # config.yml with resume bookkeeping into output_dir.
                # Closes over `optimizer` from the enclosing scope.
                step_config = {
                    "model_name": args.model_name_or_path,
                    "global_step": global_step,
                    "global_batch_size": args.global_batch_size,
                    "consumed_samples": global_step * args.global_batch_size,
                }

                logger.debug("saving models to {}".format(output_dir))
                # Unwrap DataParallel so save_pretrained sees the raw model.
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model

                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                paddle.save(optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))

                with open(os.path.join(output_dir, "config.yml"), "w") as f:
                    yaml.dump(step_config,
                              f,
                              encoding='utf-8',
                              allow_unicode=True)

            # Periodic numbered snapshot (model_<step>); only rank 0 writes,
            # other workers wait at the barrier.
            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if worker_index == 0:
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()
                tic_train = time.time()

            # Rolling "model_last" checkpoint with a one-deep backup
            # (model_last_bak) so a crash mid-save cannot lose both copies.
            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    # NOTE(review): output_dir always exists here (created just
                    # above if missing), so this check is redundant — confirm.
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()

            # Final test-set evaluation, then exit the whole function.
            if global_step >= args.max_steps:
                run_evaluate(test_data_loader,
                             model,
                             criterion,
                             args.test_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="test")
                del train_data_loader
                return