コード例 #1
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer, orig) -> Tuple[int, float]:
    """Train ``model`` while iteratively pruning it and rewinding its weights.

    ``pruning_model`` is called once before training starts and again after
    every checkpoint save (every ``args.save_steps`` optimisation steps, on
    the main process only). After each in-loop pruning round:

    * the model's state dict is overwritten with the entries of ``orig``
      (weight rewinding), and
    * a fresh AdamW optimizer and linear warmup/decay scheduler are built.

    Pruning round ``k`` (``k = 0, 1, ...``) passes ``1 / (10 - k)`` to
    ``pruning_model``. Presumably that is the fraction of the *remaining*
    weights to prune, so each round removes ~10% of the original weights.
    TODO: confirm against ``pruning_model``.

    Training stops when any of these happens:

    * ``pruning_model`` has been called 10 times (the initial call plus nine
      in-loop rounds);
    * ``global_step`` exceeds ``args.max_steps`` (when ``max_steps > 0``);
    * the epoch budget runs out.

    Args:
        args: Training namespace. This function also sets
            ``args.train_batch_size`` and, when ``args.max_steps > 0``,
            overwrites ``args.num_train_epochs``.
        train_dataset: Dataset of token-id tensors that ``pad_sequence``
            can batch. Presumably 1-D per example; confirm with the caller.
        model: The language model to train. It may be wrapped in
            DataParallel/DistributedDataParallel inside this function.
        tokenizer: Used for padding, masking (``mask_tokens``) and resizing
            the token embeddings.
        orig: Mapping merged into ``model.state_dict()`` on every rewind.
            Its keys must match the (possibly wrapped) model's state-dict
            keys.

    Returns:
        ``(global_step, mean_loss)``. ``mean_loss`` is the loss accumulated
        in this call divided by ``global_step``, or 0.0 if no optimisation
        step ran.
    """
    record_result = []  # evaluation results gathered at each logging step
    max_pruning_rounds = 10  # training stops after this many pruning calls

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Pad variable-length examples (at the end) to the longest in batch.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    def build_optimizer_and_scheduler():
        # Fresh AdamW + linear warmup/decay schedule over t_total steps.
        # Bias and LayerNorm weights are excluded from weight decay.
        # `model` is read at call time, so any DataParallel/DDP wrapper
        # applied later in train() is picked up.
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        new_optimizer = AdamW(grouped_parameters,
                              lr=args.learning_rate,
                              eps=args.adam_epsilon)
        new_scheduler = get_linear_schedule_with_warmup(
            new_optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
        return new_optimizer, new_scheduler

    train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1
                     else DistributedSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    # Optimisation steps per epoch. It is clamped to 1 wherever it is used
    # as a divisor, so a dataloader shorter than the accumulation window
    # cannot raise ZeroDivisionError.
    updates_per_epoch = (len(train_dataloader)
                         // args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = (args.max_steps
                                 // max(1, updates_per_epoch) + 1)
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer, scheduler = build_optimizer_and_scheduler()

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    pruning_step = 0

    # Initial pruning round, before any training.
    print('starting pruning')
    pruning_model(model, 1 / (max_pruning_rounds - pruning_step))
    rate_weight_equal_zero = see_weight_rate(model)
    pruning_step += 1
    print('zero_rate = ', rate_weight_equal_zero)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # ".../checkpoint-500/" -> global_step 500. A non-numeric suffix
            # raises ValueError and training starts from step 0.
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // max(1, updates_per_epoch)
            steps_trained_in_current_epoch = global_step % max(
                1, updates_per_epoch)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
            # The skip loop below consumes one *batch* per decrement, and
            # each optimisation step spans gradient_accumulation_steps
            # batches.
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    # Take care of distributed/parallel training
    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = (mask_tokens(batch, tokenizer, args)
                              if args.mlm else (batch, batch))
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            if args.mlm:
                outputs = model(inputs, masked_lm_labels=labels)
            else:
                outputs = model(inputs, labels=labels)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Log metrics. Only evaluate when single GPU, otherwise
                    # metrics may not average well.
                    if args.local_rank == -1 and args.evaluate_during_training:
                        rate_weight_equal_zero = see_weight_rate(model)
                        print('zero_rate = ', rate_weight_equal_zero)

                        results = evaluate(args, model, tokenizer)
                        print(results)

                        record_result.append(results)

                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # NOTE(review): checkpointing, pruning, rewinding and the
                # optimizer rebuild below run on the main process only.
                # Under DistributedDataParallel the other ranks are never
                # pruned or rewound and keep their old optimizer. Confirm
                # whether distributed training is meant to be supported.
                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)

                    # Pickle the whole unwrapped module, not just a state dict.
                    model_to_save = (model.module if hasattr(model, "module")
                                     else model)
                    torch.save(model_to_save,
                               os.path.join(output_dir, "model.pt"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    # Next pruning round.
                    print('starting pruning')
                    print('saving model for ', 100 - pruning_step * 10)
                    pruning_model(model,
                                  1 / (max_pruning_rounds - pruning_step))
                    rate_weight_equal_zero = see_weight_rate(model)
                    pruning_step += 1
                    print('zero_rate = ', rate_weight_equal_zero)

                    # Rewind: overwrite the current values of every key
                    # present in `orig`.
                    print('rewinding')
                    model_dict = model.state_dict()
                    model_dict.update(orig)
                    model.load_state_dict(model_dict)

                    # Restart optimisation for the rewound weights.
                    # NOTE(review): with args.fp16 the rebuilt optimizer is
                    # not passed through amp.initialize. Confirm that apex
                    # accepts it in amp.scale_loss / amp.master_params.
                    print('optimizer rewinding')
                    optimizer, scheduler = build_optimizer_and_scheduler()

            if pruning_step == max_pruning_rounds:
                epoch_iterator.close()
                break

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if pruning_step == max_pruning_rounds:
            train_iterator.close()
            break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()
        # Only the main process writes the evaluation history. Other ranks
        # would race on the same file.
        os.makedirs(args.output_dir, exist_ok=True)
        torch.save(record_result, os.path.join(args.output_dir, "result.pt"))

    # max(..., 1) avoids ZeroDivisionError when no optimisation step ran.
    return global_step, tr_loss / max(global_step, 1)
コード例 #2
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer):
    """Fine-tune ``model`` as a language model and evaluate it every epoch.

    Batch handling:

    * When ``args.mlm`` is set, each batch is masked with ``mask_tokens``.
    * Otherwise the batch serves as both inputs and labels.
    * Gradients are accumulated over ``args.gradient_accumulation_steps``
      batches before each optimizer/scheduler step.

    After every epoch ``evaluate`` is called and its ``'perplexity'`` entry is
    logged. The lowest value seen is printed when training finishes.

    Args:
        args: Training namespace. Reads ``train_batch_size``,
            ``num_workers``, ``num_train_epochs`` (must be an int),
            ``gradient_accumulation_steps``, ``weight_decay``,
            ``learning_rate``, ``adam_epsilon``, ``warmup_steps``, ``fp16``,
            ``fp16_opt_level``, ``n_gpu``, ``device``, ``mlm``,
            ``max_grad_norm`` and ``per_gpu_train_batch_size``.
        train_dataset: Dataset of token-id tensors that ``pad_sequence``
            can batch. Presumably 1-D per example; confirm with the caller.
        model: The language model to fine-tune.
        tokenizer: Used for padding, masking and resizing the embeddings.

    Returns:
        The model after the last epoch. It is wrapped in ``nn.DataParallel``
        when ``args.n_gpu > 1``.

    Raises:
        ImportError: If ``args.fp16`` is set but apex is not installed.
    """
    since = time.time()

    writer = SummaryWriter()

    def collate(examples: List[torch.Tensor]):
        # Pad variable-length examples (at the end) to the longest in batch.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  args.train_batch_size,
                                  sampler=train_sampler,
                                  collate_fn=collate,
                                  num_workers=args.num_workers)

    t_total = (len(train_dataloader) // args.gradient_accumulation_steps
               * args.num_train_epochs)

    # Prepare optimizer and schedule (linear warmup and decay).
    # Bias and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                args.warmup_steps,
                                                num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        # The opt level comes from `args` (the old code read an undefined
        # name `params`, a NameError).
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # TODO: Loading checkpoint for AMP
    # Train!
    logger.info('***** Running training *****')
    logger.info(f'  Num examples = {len(train_dataset)}')
    logger.info(f'  Num Epochs = {args.num_train_epochs}')
    logger.info(
        f'  Instantaneous batch size per GPU = {args.per_gpu_train_batch_size}'
    )
    logger.info('  Total train batch size (w. parallel, & accumulation) = %d',
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info(
        f'  Gradient Accumulation steps = {args.gradient_accumulation_steps}')
    logger.info(f'  Total optimization steps = {t_total}')

    global_step = 0
    best_perplexity = float('inf')  # lower perplexity is better
    training_loss = 0.0  # loss summed over the current accumulation window

    # Take care of distributed/parallel training
    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))

    # Gradients are cleared once here and then only after each optimizer
    # step. Clearing them before every batch would discard the accumulated
    # gradients.
    optimizer.zero_grad()

    for epoch in range(args.num_train_epochs):
        print(f'Epoch {epoch}/{args.num_train_epochs - 1}')
        print('-' * 10)

        # evaluate() at the end of the previous epoch may have left the model
        # in eval mode, so re-enable training mode every epoch.
        model.train()
        running_loss = 0.0  # batch-size-weighted loss for this epoch only

        for step, batch in enumerate(tqdm(train_dataloader)):
            inputs, labels = (mask_tokens(batch, tokenizer, args)
                              if args.mlm else (batch, batch))
            inputs, labels = inputs.to(args.device), labels.to(args.device)

            outputs = (model(inputs, masked_lm_labels=labels)
                       if args.mlm else model(inputs, labels=labels))
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()
            # Use the per-batch loss before scaling for accumulation, so the
            # epoch average is independent of gradient_accumulation_steps.
            running_loss += loss.item() * inputs.size(0)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            training_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                             args.max_grad_norm)
                else:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # TODO: args.evaluate_during_training
                writer.add_scalar('learning_rate',
                                  scheduler.get_lr()[0], global_step)
                writer.add_scalar('loss/training', training_loss, global_step)
                training_loss = 0.0

        epoch_loss = running_loss / len(train_dataset)
        # TODO: Evaluates and saves checkpoint after every epoch
        result = evaluate(args, model, tokenizer)
        epoch_perplexity = result.get('perplexity')

        if epoch_perplexity < best_perplexity:
            best_perplexity = epoch_perplexity

        writer.add_scalar('perplexity per epoch', epoch_perplexity, epoch)
        print(f'Loss: {epoch_loss:.4f} perplexity:{epoch_perplexity}')

    writer.close()

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print(f'Perplexity: {best_perplexity}')

    return model