Example 1
# Imports used by this excerpt; the `fine_tune.*` modules are project-local.
import torch
import torch.utils.data
import transformers
from tqdm import tqdm

import fine_tune.config
import fine_tune.model
import fine_tune.task

def amp_gen_logits(config: fine_tune.config.BaseConfig,
                   dataset: fine_tune.task.Dataset,
                   model: fine_tune.model.Model,
                   tokenizer: transformers.PreTrainedTokenizer):
    r"""Generate fine-tuned model logits with `torch.cuda.amp` on task specific dataset.

    Args:
        config:
            `fine_tune.config.BaseConfig` subclass whose attributes are used
            for experiment setup.
        dataset:
            Task-specific dataset.
        model:
            Model which will generate logits on `dataset`.
        tokenizer:
            Tokenizer paired with `model`.
    """
    # Evaluation mode.
    model.eval()

    # Model running device.
    device = config.device

    # Get experiment name.
    experiment_name = fine_tune.config.BaseConfig.experiment_name(
        experiment=config.experiment, model=config.model, task=config.task)

    # Create dataloader.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        collate_fn=dataset.create_collate_fn(max_seq_len=config.max_seq_len,
                                             tokenizer=tokenizer),
        shuffle=False)

    # Sample index for updating logits.
    sample_index = 0

    # Generate logits through mini-batch loop.
    for (input_ids, attention_mask, token_type_ids, _, _) in tqdm(dataloader):

        # Enable autocast.
        with torch.cuda.amp.autocast():
            # Get mini-batch logits.
            batch_logits = model(
                input_ids=input_ids.to(device),
                token_type_ids=token_type_ids.to(device),
                attention_mask=attention_mask.to(device)).to('cpu').tolist()

        # Update logits.
        for index, logits in enumerate(batch_logits):
            dataset.update_logits(index=index + sample_index, logits=logits)

        # Shift sample index.
        sample_index += len(batch_logits)

    # Save logits.
    dataset.save_for_distill(experiment_name)
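
The core pattern in the function above is inference under `torch.cuda.amp.autocast` with logits moved back to the CPU. Below is a standalone sketch of that pattern using a plain Hugging Face classifier; it does not depend on the `fine_tune` package, and the checkpoint name and sample sentences are illustrative only.

import torch
import transformers

# Standalone autocast-inference sketch; 'bert-base-uncased' and the sample
# sentences are illustrative, not taken from the excerpt above.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2).to(device).eval()

batch = tokenizer(['first example sentence', 'second example sentence'],
                  padding=True, truncation=True, return_tensors='pt')

with torch.no_grad(), torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
    # Mixed-precision forward pass; logits come back to the CPU as plain
    # Python lists, mirroring the `.to('cpu').tolist()` call above.
    logits = model(input_ids=batch['input_ids'].to(device),
                   attention_mask=batch['attention_mask'].to(device),
                   token_type_ids=batch['token_type_ids'].to(device)
                   ).logits.to('cpu').tolist()

print(logits)  # one list of class scores per input sentence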
Example 2
# Imports used by this excerpt; the `fine_tune.*` modules are project-local.
import os

import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.tensorboard
import transformers
from tqdm import tqdm

import fine_tune.config
import fine_tune.model
import fine_tune.path
import fine_tune.task

def train(
    config: fine_tune.config.BaseConfig,
    dataset: fine_tune.task.Dataset,
    model: fine_tune.model.Model,
    optimizer: torch.optim.AdamW,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    tokenizer: transformers.PreTrainedTokenizer,
):
    r"""Fine-tune or distill model on task specific dataset.

    Args:
        config:
            `fine_tune.config.BaseConfig` subclass whose attributes are used
            for experiment setup.
        dataset:
            Task-specific dataset.
        model:
            Model which will be fine-tuned on `dataset`.
        optimizer:
            `torch.optim.AdamW` optimizer.
        scheduler:
            Linear warmup scheduler provided by the `transformers` package.
        tokenizer:
            Tokenizer paired with `model`.
    """
    # Training mode.
    model.train()

    # Model running device.
    device = config.device

    # Clear all gradients.
    optimizer.zero_grad()

    # Get experiment name and path.
    experiment_name = fine_tune.config.BaseConfig.experiment_name(
        experiment=config.experiment, model=config.model, task=config.task)
    experiment_dir = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT,
                                  experiment_name)

    # Create dataloader.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size // config.accum_step,
        collate_fn=dataset.create_collate_fn(max_seq_len=config.max_seq_len,
                                             tokenizer=tokenizer),
        shuffle=True)

    # Create tensorboard's `SummaryWriter`.
    writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(fine_tune.path.LOG, experiment_name))

    # Use cross-entropy as objective.
    objective = nn.CrossEntropyLoss()

    # Step and accumulation step counters.
    step = 0
    accum_step = 0
    total_accum_step = config.total_step * config.accum_step

    # Mini-batch loss and accumulated loss.
    # Parameters are updated once gradients accumulate over `config.batch_size` samples.
    loss = 0
    accum_loss = 0

    # `tqdm` CLI logger. The progress bar is updated manually.
    cli_logger = tqdm(desc=f'loss: {loss:.6f}', total=config.total_step)

    # Total number of parameter updates: `config.total_step`.
    while accum_step < total_accum_step:

        # Mini-batch loop.
        for (input_ids, attention_mask, token_type_ids, label,
             _) in dataloader:

            # Forward pass; scale the cross-entropy loss by the number of
            # accumulation steps so gradients match the effective batch size.
            accum_loss = objective(
                input=model(input_ids=input_ids.to(device),
                            token_type_ids=token_type_ids.to(device),
                            attention_mask=attention_mask.to(device)),
                target=label.to(device)) / config.accum_step

            # Mini-batch cross-entropy loss, tracked only for logging.
            loss += accum_loss.item()

            # Backward pass on the scaled accumulation loss.
            accum_loss.backward()

            # Increment accumulation step.
            accum_step += 1

            # Perform a parameter update once the effective mini-batch size is reached.
            if accum_step % config.accum_step == 0:
                # Gradient clipping.
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_norm)

                # Gradient descent step.
                optimizer.step()

                # Update learning rate.
                scheduler.step()

                # Log on CLI.
                cli_logger.update()
                cli_logger.set_description(f'loss: {loss:.6f}')

                # Increment actual step.
                step += 1

                # Log loss and learning rate every `config.log_step` steps.
                if step % config.log_step == 0:
                    writer.add_scalar(f'{config.task}/{config.dataset}/loss',
                                      loss, step)
                    writer.add_scalar(
                        f'{config.task}/{config.dataset}/lr',
                        optimizer.state_dict()['param_groups'][0]['lr'], step)

                # Reset mini-batch loss.
                loss = 0

                # Clear gradients.
                optimizer.zero_grad()

                # Save a model checkpoint every `config.ckpt_step` steps.
                if step % config.ckpt_step == 0:
                    torch.save(
                        model.state_dict(),
                        os.path.join(experiment_dir, f'model-{step}.pt'))

            # Stop once the total number of accumulation steps is reached.
            if accum_step >= total_accum_step:
                break

    # Release IO resources.
    writer.flush()
    writer.close()
    cli_logger.close()

    # Save the latest model.
    torch.save(model.state_dict(),
               os.path.join(experiment_dir, f'model-{step}.pt'))
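
The `optimizer` and `scheduler` arguments expected by `train` pair `torch.optim.AdamW` with the linear warmup schedule from the `transformers` package, which is implemented as a `torch.optim.lr_scheduler.LambdaLR`. A construction sketch follows; the config attribute names `lr`, `weight_decay`, and `warmup_step` are assumptions for illustration and may differ in the actual `BaseConfig` subclass.

import torch
import transformers

def build_optimizer_and_scheduler(config, model):
    # `lr`, `weight_decay`, and `warmup_step` are assumed attribute names;
    # only `total_step` appears in the `train` function above.
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config.lr,
                                  weight_decay=config.weight_decay)

    # Linear warmup followed by linear decay; returns a `LambdaLR`, matching
    # the `scheduler` type annotation in `train`.
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_step,
        num_training_steps=config.total_step)

    return optimizer, scheduler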