Esempio n. 1
0
def train(args, train_dataset, model, tokenizer):
    """Train ``model`` on ``train_dataset`` and return training statistics.

    Builds a dataloader, an AdamW optimizer with a linear warmup/decay
    schedule, optionally enables apex fp16, DataParallel and
    DistributedDataParallel, then runs the optimisation loop with optional
    FGM adversarial training, periodic evaluation and checkpointing.

    Args:
        args: Run configuration. Read here: ``per_gpu_train_batch_size``,
            ``n_gpu``, ``local_rank``, ``max_steps``,
            ``gradient_accumulation_steps``, ``num_train_epochs``,
            ``weight_decay``, ``warmup_proportion``, ``learning_rate``,
            ``adam_epsilon``, ``model_name_or_path``, ``fp16``,
            ``fp16_opt_level``, ``do_adv``, ``adv_name``, ``adv_epsilon``,
            ``seed``, ``device``, ``model_type``, ``max_grad_norm``,
            ``logging_steps``, ``save_steps`` and ``output_dir``.
            Written here: ``train_batch_size``, ``warmup_steps`` and, when
            ``max_steps > 0``, ``num_train_epochs``.
        train_dataset: Dataset handed to ``DataLoader``; each collated batch
            is indexed as ``[0]`` input ids, ``[1]`` attention mask,
            ``[2]`` token type ids, ``[3]`` labels.
        model: Model to optimise; its first output is used as the loss.
        tokenizer: Saved next to every checkpoint and passed to ``evaluate``.

    Returns:
        Tuple ``(global_step, mean_loss)`` where ``mean_loss`` is the
        accumulated loss divided by ``global_step`` (``0.0`` when no
        optimizer update was performed).

    Raises:
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)
    # Number of optimizer updates performed in one pass over the data.
    updates_per_epoch = (len(train_dataloader) //
                         args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    args.warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    # Restore optimizer and scheduler states when both were saved alongside
    # the model that is being loaded.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted location.
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path
                      ) and "checkpoint" in args.model_name_or_path:
        try:
            # The step count is the suffix of a "checkpoint-<step>" directory.
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            # The path mentions "checkpoint" but carries no numeric suffix,
            # so there is nothing to fast-forward to.
            logger.info(
                "  Could not read a global step from %s, starting from step 0",
                args.model_name_or_path)
        else:
            epochs_trained = global_step // updates_per_epoch
            steps_trained_in_current_epoch = global_step % updates_per_epoch
            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)

    tr_loss = 0.0
    if args.do_adv:
        fgm = FGM(model, emb_name=args.adv_name, epsilon=args.adv_epsilon)
    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproducibility (even between python 2 and 3)
    # Start at ``epochs_trained`` so a resumed run does not repeat the
    # epochs that the checkpoint already covers.
    for _ in range(epochs_trained, int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                # Only "bert" and "xlnet" receive segment ids; every other
                # non-distilbert model type gets None.
                inputs["token_type_ids"] = (batch[2] if args.model_type
                                            in ["bert", "xlnet"] else None)
            outputs = model(**inputs)
            # The first element of the model output is the loss.
            loss = outputs[0]
            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if args.do_adv:
                # Adversarial pass: perturb, accumulate the gradient of the
                # loss on the perturbed input, then restore.
                fgm.attack()
                loss_adv = model(**inputs)[0]
                if args.n_gpu > 1:
                    loss_adv = loss_adv.mean()
                # Treat the adversarial loss exactly like the clean loss so
                # both gradients are accumulated on the same scale.
                if args.gradient_accumulation_steps > 1:
                    loss_adv = loss_adv / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss_adv,
                                        optimizer) as scaled_loss_adv:
                        scaled_loss_adv.backward()
                else:
                    loss_adv.backward()
                fgm.restore()
            pbar(step, {'loss': loss.item()})
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # The optimizer must step before the scheduler; otherwise
                # the first value of the learning rate schedule is skipped.
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    print(" ")
                    if args.local_rank == -1:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        evaluate(args, model, tokenizer)

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    tokenizer.save_vocabulary(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                # Honour an explicit step budget: the schedule was built for
                # exactly ``max_steps`` updates.
                if args.max_steps > 0 and global_step >= args.max_steps:
                    break
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        if args.max_steps > 0 and global_step >= args.max_steps:
            break
    # Guard against a run in which no optimizer update was performed.
    return global_step, tr_loss / max(1, global_step)
Esempio n. 2
0
def _decay_split_groups(named_params,
                        lr,
                        weight_decay,
                        no_decay=("bias", "LayerNorm.weight")):
    """Split named parameters into a decayed and a non-decayed optimizer group.

    Returns two parameter-group dicts, both using ``lr``: first the
    parameters whose name contains none of the ``no_decay`` substrings
    (with ``weight_decay``), then the remaining ones (with decay ``0.0``).
    """
    named_params = list(named_params)
    decayed = [
        p for n, p in named_params if not any(nd in n for nd in no_decay)
    ]
    exempt = [p for n, p in named_params if any(nd in n for nd in no_decay)]
    return [
        {
            'params': decayed,
            'weight_decay': weight_decay,
            'lr': lr
        },
        {
            'params': exempt,
            'weight_decay': 0.0,
            'lr': lr
        },
    ]


def train(args, train_features, model, tokenizer, use_crf):
    """Train ``model`` on ``train_features`` and keep the best dev checkpoint.

    The encoder (``model.bert``) is optimised with ``args.learning_rate``;
    the task-specific head modules selected by ``args.model_encdec`` use
    ``args.crf_learning_rate`` (``bert2crf``, ``bert2gru``, ``bert2soft``)
    or ``args.point_learning_rate`` (``multi2point``). Every
    ``args.logging_steps`` updates (single-process runs only) the model is
    evaluated on "dev"; on a new best ``span_f`` it is also evaluated on
    "test" and saved to ``<args.output_dir>/checkpoint-bestf``.

    Args:
        args: Run configuration (learning rates, ``batch_size``,
            ``num_train_epochs``, ``device``, ``output_dir``, ...).
            ``args.warmup_steps`` is written here.
        train_features: List of training features. NOTE: it is modified in
            place; it is shuffled after every epoch and the "dev" features
            are appended to it before the last epoch.
        model: Model exposing ``bert`` plus the head modules required by
            ``args.model_encdec``; its first output is used as the loss.
        tokenizer: Passed to ``evaluate`` / ``load_and_cache_examples`` and
            saved next to the best checkpoint.
        use_crf: Forwarded to ``batch_generator`` and ``evaluate``.

    Returns:
        Tuple ``(global_step, mean_loss, test_results)``. ``test_results``
        is the "test" evaluation taken at the best dev score, or ``{}`` if
        no such evaluation ran. ``mean_loss`` is ``0.0`` when no update was
        performed.

    Raises:
        ValueError: If ``args.model_encdec`` is not one of ``bert2crf``,
            ``bert2gru``, ``bert2soft`` or ``multi2point``.
    """
    # Select the head modules and their learning rate. Fail early with a
    # clear message instead of a NameError further down.
    if args.model_encdec == 'bert2crf':
        head_modules = (model.crf, model.classifier)
        head_lr = args.crf_learning_rate
    elif args.model_encdec == 'bert2gru':
        head_modules = (model.decoder, model.clsdense)
        head_lr = args.crf_learning_rate
    elif args.model_encdec == 'bert2soft':
        head_modules = (model.classifier, )
        head_lr = args.crf_learning_rate
    elif args.model_encdec == 'multi2point':
        head_modules = (model.pointer, )
        head_lr = args.point_learning_rate
    else:
        raise ValueError("Unsupported model_encdec: {!r}".format(
            args.model_encdec))
    uses_span_labels = args.model_encdec == 'multi2point'

    # Prepare optimizer and schedule (linear warmup and decay).
    # Group order is: encoder (decay, no decay), then each head module
    # (decay, no decay); saved optimizer states depend on this order.
    optimizer_grouped_parameters = _decay_split_groups(
        model.bert.named_parameters(), args.learning_rate, args.weight_decay)
    for head in head_modules:
        optimizer_grouped_parameters.extend(
            _decay_split_groups(head.named_parameters(), head_lr,
                                args.weight_decay))

    # NOTE(review): the dev features appended before the last epoch are not
    # counted here, so the schedule may end before training does - confirm.
    t_total = len(train_features) // args.batch_size * args.num_train_epochs
    args.warmup_steps = int(t_total * args.warmup_proportion)

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    # Restore optimizer and scheduler states when both were saved alongside
    # the model that is being loaded.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_features))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.batch_size *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproducibility (even between python 2 and 3)
    best_spanf = -1

    test_results = {}
    for ep in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_features) // args.batch_size,
                           desc='Training')
        if ep == int(args.num_train_epochs) - 1:
            # The last epoch additionally trains on the "dev" split.
            eval_features = load_and_cache_examples(args,
                                                    args.data_type,
                                                    tokenizer,
                                                    data_type='dev')
            train_features.extend(eval_features)

        # Each batch holds 11 fields, in order: input_ids, input_mask,
        # segment_ids, label_ids, multi_span_label, context_mask,
        # start_position, end_position, raw_labels, <unused>, example.
        # Only the leading tensors are consumed here.
        for step, batch in enumerate(
                batch_generator(features=train_features,
                                batch_size=args.batch_size,
                                use_crf=use_crf,
                                answer_seq_len=args.answer_seq_len)):
            model.train()
            if uses_span_labels:
                batch_inputs = tuple(t.to(args.device) for t in batch[0:5])
                inputs = {
                    "input_ids": batch_inputs[0],
                    "attention_mask": batch_inputs[1],
                    "token_type_ids": batch_inputs[2],
                    "span_label": batch_inputs[4],
                    "testing": False
                }
            else:
                batch_inputs = tuple(t.to(args.device) for t in batch[0:6])
                inputs = {
                    "input_ids": batch_inputs[0],
                    "attention_mask": batch_inputs[1],
                    "token_type_ids": batch_inputs[2],
                    "context_mask": batch_inputs[5],
                    "labels": batch_inputs[3],
                    "testing": False
                }

            outputs = model(**inputs)
            # The first element of the model output is the loss.
            loss = outputs[0]
            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            loss.backward()
            if step % 15 == 0:
                pbar(step, {'epoch': ep, 'loss': loss.item()})
            tr_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            # The optimizer must step before the scheduler; otherwise the
            # first value of the learning rate schedule is skipped.
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1
            if args.local_rank in [
                    -1, 0
            ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                # Log metrics
                print("start evalue")
                if args.local_rank == -1:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    results = evaluate(args=args,
                                       model=model,
                                       tokenizer=tokenizer,
                                       prefix="dev",
                                       use_crf=use_crf)
                    span_f = results['span_f']
                    if span_f > best_spanf:
                        output_dir = os.path.join(args.output_dir,
                                                  "checkpoint-bestf")
                        if os.path.exists(output_dir):
                            # Drop the previous best checkpoint; report the
                            # directory that is actually removed.
                            shutil.rmtree(output_dir)
                            print('remove file', output_dir)
                        print('\n\n eval results:', results)
                        test_results = evaluate(args=args,
                                                model=model,
                                                tokenizer=tokenizer,
                                                prefix="test",
                                                use_crf=use_crf)
                        print('\n\n test results', test_results)
                        print('\n epoch = :', ep)

                        best_spanf = span_f
                        os.makedirs(output_dir)
                        # Take care of distributed/parallel training
                        model_to_save = (model.module if hasattr(
                            model, "module") else model)
                        model_to_save.save_pretrained(output_dir)
                        torch.save(
                            args, os.path.join(output_dir,
                                               "training_args.bin"))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)
                        tokenizer.save_vocabulary(output_dir)
                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   os.path.join(output_dir, "scheduler.pt"))
                        logger.info(
                            "Saving optimizer and scheduler states to %s",
                            output_dir)

        # NOTE(review): seed() without an argument re-seeds NumPy from fresh
        # entropy, so the shuffle order is not governed by args.seed -
        # confirm this is intended.
        np.random.seed()
        np.random.shuffle(train_features)
        logger.info("\n")
        torch.cuda.empty_cache()
    # Guard against a run in which no optimizer update was performed.
    return global_step, tr_loss / max(1, global_step), test_results