Ejemplo n.º 1
0
def train_model(model: torch.nn.Module, train_dl: torch.utils.data.DataLoader,
                dev_dl: torch.utils.data.DataLoader,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int) -> (Dict, Dict):
    """Train ``model`` for ``n_epochs`` and keep the weights with best dev F1.

    Each epoch performs one pass over ``train_dl`` with a cross-entropy
    loss, evaluates on ``dev_dl`` through ``eval_model`` and then steps the
    scheduler with the validation loss.

    Args:
        model: Network called as ``model(batch[0])``; its output is fed to
            ``CrossEntropyLoss`` as logits.
        train_dl: Training batches; ``batch[0]`` is the model input and
            ``batch[1]`` the target.
        dev_dl: Validation batches, handed to ``eval_model`` unchanged.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler stepped once per epoch with the validation loss.
        n_epochs: Number of epochs to run.

    Returns:
        ``(best_model_weights, best_val)``. ``best_model_weights`` is an
        independent snapshot of the state dict taken at the epoch with the
        highest validation F1, or ``None`` if no epoch scored above 0.
        ``best_val`` is the metrics dict of that epoch (``{'val_f1': 0}`` if
        none qualified).
    """
    # Imported locally so the function does not rely on a module-level import.
    import copy

    loss_f = torch.nn.CrossEntropyLoss()
    best_val, best_model_weights = {'val_f1': 0}, None

    for ep in range(n_epochs):
        model.train()
        for batch in tqdm(train_dl, desc='Training'):
            optimizer.zero_grad()
            logits = model(batch[0])
            loss = loss_f(logits, batch[1])
            loss.backward()
            optimizer.step()

        val_p, val_r, val_f1, val_loss, _, _ = eval_model(model, dev_dl)
        current_val = {
            'val_p': val_p,
            'val_r': val_r,
            'val_f1': val_f1,
            'val_loss': val_loss,
            'ep': ep
        }

        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            # state_dict() hands back references to the live parameter
            # tensors, which the optimizer keeps updating in place. Deep-copy
            # so the snapshot really holds this epoch's weights.
            best_model_weights = copy.deepcopy(model.state_dict())

        # NOTE(review): passing a metric to step() is the ReduceLROnPlateau
        # calling convention; confirm it matches the scheduler actually used.
        scheduler.step(val_loss)

    return best_model_weights, best_val
Ejemplo n.º 2
0
def restore_state(optimizer, scheduler: torch.optim.lr_scheduler.LambdaLR,
                  args, device):
    """Reload optimizer and scheduler state from a checkpoint when available.

    The checkpoint is the latest one found under ``args.output``; when none
    exists but ``args.warm_start`` is set, ``args.pre_model`` is loaded
    instead. Status messages are printed only when ``args.local_rank`` is
    ``None`` or ``0``.

    Args:
        optimizer: Object whose ``load_state_dict`` receives the stored
            ``'optimizer_state_dict'`` entry, if that entry is truthy.
        scheduler: Object whose ``load_state_dict`` receives the stored
            ``'scheduler_state_dict'`` entry, if that entry is truthy.
        args: Namespace providing ``output``, ``warm_start``, ``local_rank``
            and ``pre_model``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(iteration, loss)`` read from the checkpoint (each defaulting to
        ``0`` when the key is absent), or ``(0, 0.0)`` when there is nothing
        to restore.
    """
    resume_path = get_latest_checkpoint_path(args.output)

    # Neither a checkpoint nor a warm start: begin from scratch.
    if not (resume_path or args.warm_start):
        return 0, 0.0

    is_reporting_rank = args.local_rank in [None, 0]

    def report(message):
        # Only one process should talk, to avoid duplicated output.
        if is_reporting_rank:
            print(message)

    report('Trying to load optimizer and scheduler state from checkpoint.')

    state = torch.load(resume_path or args.pre_model, map_location=device)
    start_iteration = state.get('iteration', 0)

    scheduler_state = state.get('scheduler_state_dict', None)
    if scheduler_state:
        scheduler.load_state_dict(scheduler_state)
        report("Scheduler restored.")

    optimizer_state = state.get('optimizer_state_dict', None)
    if optimizer_state:
        optimizer.load_state_dict(optimizer_state)
        report("Optimizer restored.")

    return start_iteration, state.get('loss', 0)
Ejemplo n.º 3
0
def train_one_epoch(model, train_loader: DataLoader, optimizer: AdamW,
                    scheduler: torch.optim.lr_scheduler.LambdaLR, epoch: int,
                    max_epoch: int):
    """Run a single training pass over ``train_loader``.

    For every batch the gradients are cleared, the model is called with the
    labels so that it returns the loss itself, gradients are clipped to a
    norm of 1.0, and both the optimizer and the scheduler are stepped.

    Args:
        model: Model called with input ids, attention mask and labels; the
            first two entries of its output are taken as ``(loss, logits)``.
        train_loader: Yields batches indexed as ``[0]`` input ids,
            ``[1]`` attention masks, ``[2]`` labels.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        epoch: Current epoch number, shown in the progress bar only.
        max_epoch: Total number of epochs, shown in the progress bar only.

    Returns:
        ``(model, optimizer, scheduler, avg_train_loss)`` where the last
        item is the summed batch loss divided by ``len(train_loader)``.
    """
    model.train()

    progress = tqdm(train_loader)
    progress.desc = f"Epoch {epoch}/{max_epoch} Loss: 0"

    running_loss = 0
    for batch in progress:
        # Move the three tensors of the batch to the module-level `device`.
        input_ids, input_mask, labels = (batch[k].to(device)
                                         for k in range(3))

        model.zero_grad()

        # Passing `labels` makes the model compute the loss internally.
        loss, _ = model(input_ids,
                        token_type_ids=None,
                        attention_mask=input_mask,
                        labels=labels)[:2]

        batch_loss = loss.item()
        running_loss += batch_loss

        loss.backward()
        utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Note cross entropy loss does not increase with batch size
        progress.desc = f"Epoch {epoch}/{max_epoch} Batch Loss: {batch_loss}"

    return model, optimizer, scheduler, running_loss / len(train_loader)
def train_model(model: torch.nn.Module,
                train_dl: BatchSampler, dev_dl: BatchSampler,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int,
                early_stopping: EarlyStopping) -> (Dict, Dict):
    """Train ``model`` with early stopping and keep the best-F1 weights.

    The loss is ``BCEWithLogitsLoss`` when the module-level ``args.labels``
    equals 2 and ``CrossEntropyLoss`` otherwise. After every epoch the mean
    training loss is printed, the model is evaluated with ``eval_model``,
    the scheduler is stepped with the validation loss, and training stops
    as soon as ``early_stopping.step(val_loss)`` returns a truthy value.

    Args:
        model: Network called as ``model(batch[0])``.
        train_dl: Training batches; ``batch[0]`` is the input and
            ``batch[1]`` the target.
        dev_dl: Validation batches, handed to ``eval_model`` unchanged.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler stepped once per epoch with the validation loss.
        n_epochs: Maximum number of epochs.
        early_stopping: Object whose ``step(val_loss)`` signals when to stop.

    Returns:
        ``(best_model_weights, best_val)``. ``best_model_weights`` is an
        independent snapshot of the state dict from the epoch with the
        highest validation F1, or ``None`` if no epoch scored above 0.
    """
    # Imported locally so the function does not rely on a module-level import.
    import copy

    if args.labels == 2:
        loss_f = torch.nn.BCEWithLogitsLoss()
    else:
        loss_f = torch.nn.CrossEntropyLoss()

    best_val, best_model_weights = {'val_f1': 0}, None

    for ep in range(n_epochs):
        losses = []
        model.train()
        for batch in tqdm(train_dl, desc='Training'):
            optimizer.zero_grad()
            prediction = model(batch[0])
            loss = loss_f(prediction, batch[1])
            loss.backward()

            optimizer.step()
            losses.append(loss.item())

        print('Training loss:', np.mean(losses))
        val_p, val_r, val_f1, val_loss = eval_model(model, dev_dl)
        current_val = {
            'val_p': val_p, 'val_r': val_r, 'val_f1': val_f1,
            'val_loss': val_loss, 'ep': ep
        }

        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            # state_dict() returns references to the live parameter tensors;
            # without a copy, later epochs would overwrite the "best" weights.
            best_model_weights = copy.deepcopy(model.state_dict())

        # NOTE(review): passing a metric to step() is the ReduceLROnPlateau
        # calling convention; confirm it matches the scheduler actually used.
        scheduler.step(val_loss)
        if early_stopping.step(val_loss):
            print('Early stopping...')
            break

    return best_model_weights, best_val
Ejemplo n.º 5
0
def train_model(model: torch.nn.Module,
                train_dl: BatchSampler,
                dev_dl: BatchSampler,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int,
                labels: int = 3,
                early_stopping: EarlyStopping = None) -> (Dict, Dict):
    """Train a model that returns its own loss; keep the best-F1 weights.

    The model is called with input ids, an attention mask and labels, and
    the first element of its output is used as the loss. The scheduler is
    stepped after every batch. After each epoch the model is evaluated with
    ``eval_model`` and, if ``early_stopping`` is given, training stops once
    ``early_stopping.step(val_f1)`` returns a truthy value.

    Args:
        model: Called as ``model(batch[0], attention_mask=batch[1],
            labels=batch[2].long())``; the first two outputs are unpacked.
        train_dl: Training batches indexed as described above.
        dev_dl: Validation batches, handed to ``eval_model``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler stepped once per training batch.
        n_epochs: Maximum number of epochs.
        labels: Forwarded to ``eval_model`` as its third argument.
        early_stopping: Optional object whose ``step(val_f1)`` signals when
            to stop.

    Returns:
        ``(best_model_weights, best_val)``. ``best_model_weights`` is an
        independent snapshot of the state dict from the epoch with the
        highest validation F1, or ``None`` if no epoch scored above 0.
    """
    # Imported locally so the function does not rely on a module-level import.
    import copy

    best_val, best_model_weights = {'val_f1': 0}, None

    for ep in range(n_epochs):
        # Nothing inside the batch loop leaves training mode, so switching
        # once per epoch is enough.
        model.train()
        for batch in tqdm(train_dl, desc='Training'):
            optimizer.zero_grad()
            loss, _ = model(batch[0],
                            attention_mask=batch[1],
                            labels=batch[2].long())[:2]

            loss.backward()
            optimizer.step()
            scheduler.step()

        val_p, val_r, val_f1, val_loss, _, _ = eval_model(
            model, dev_dl, labels)
        current_val = {
            'val_f1': val_f1,
            'val_p': val_p,
            'val_r': val_r,
            'val_loss': val_loss,
            'ep': ep
        }
        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            # state_dict() returns references to the live parameter tensors;
            # without a copy, later epochs would overwrite the "best" weights.
            best_model_weights = copy.deepcopy(model.state_dict())

        if early_stopping and early_stopping.step(val_f1):
            print('Early stopping...')
            break

    return best_model_weights, best_val
Ejemplo n.º 6
0
def train(classifier_model: BertForMultipleChoice,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler.LambdaLR,
          dataloader: TensorDataset, device: torch.device,
          logger: logging.Logger, logging_steps: int) -> float:
    """Train the model for one pass over ``dataloader``.

    Args:
        classifier_model: Model called with ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``labels`` keyword
            arguments; the first element of its output is the loss.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        dataloader: Iterable of batches whose first four items are, in
            order, the four model inputs listed above.
        device: Device the batch tensors are moved to.
        logger: Receives a step/loss/lr line every ``logging_steps`` batches.
        logging_steps: Logging interval, in batches.

    Returns:
        The mean of the per-batch losses.
    """
    # Position in the batch -> keyword argument name expected by the model.
    input_names = ("input_ids", "attention_mask", "token_type_ids", "labels")

    classifier_model.train()

    loss_sum = 0
    steps_done = 0

    for step, raw_batch in enumerate(dataloader):
        tensors = tuple(raw_batch)
        model_kwargs = {
            name: tensors[position].to(device)
            for position, name in enumerate(input_names)
        }

        classifier_model.zero_grad()

        # Forward pass: the loss is the first output.
        loss = classifier_model(**model_kwargs)[0]

        # Backward pass with gradient clipping.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(classifier_model.parameters(), 1.0)

        # Parameter and learning-rate update.
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        steps_done += 1
        loss_sum += batch_loss

        if step % logging_steps == 0:
            current_lr = optimizer.param_groups[0]["lr"]
            logger.info("Step: {}\tLoss: {}\tlr: {}".format(
                step, batch_loss, current_lr))

    return loss_sum / steps_done
def train_model(model: torch.nn.Module,
                train_dl: BatchSampler, dev_dl: BatchSampler,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int) -> (Dict, Dict):
    """Train a model that returns its own loss; keep the best-F1 weights.

    The attention mask is derived from the inputs by comparing them with the
    module-level ``tokenizer.pad_token_id``. The scheduler is stepped after
    every batch. Besides the end-of-epoch evaluation, the model is also
    evaluated (and the result printed) at batch indices 100, 600 and 700 of
    each epoch.

    Args:
        model: Called as ``model(batch[0], attention_mask=..., labels=batch[1])``
            and expected to return exactly two values, the first being the
            loss.
        train_dl: Training batches; ``batch[0]`` is the input ids and
            ``batch[1]`` the labels.
        dev_dl: Validation batches, handed to ``eval_model``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler stepped once per training batch.
        n_epochs: Number of epochs to run.

    Returns:
        ``(best_model_weights, best_val)``. ``best_model_weights`` is an
        independent snapshot of the state dict from the epoch with the
        highest validation F1, or ``None`` if no epoch scored above 0.
    """
    # Imported locally so the function does not rely on a module-level import.
    import copy

    best_val, best_model_weights = {'val_f1': 0}, None

    for ep in range(n_epochs):
        for i, batch in enumerate(tqdm(train_dl, desc='Training')):
            # Re-enter training mode on every batch: the mid-epoch
            # eval_model() calls below presumably switch the model to eval
            # mode (TODO confirm against eval_model).
            model.train()
            optimizer.zero_grad()
            loss, _ = model(batch[0],
                            attention_mask=batch[0] != tokenizer.pad_token_id,
                            labels=batch[1])

            loss.backward()

            optimizer.step()
            scheduler.step()

            # Intermediate evaluation at a few fixed batch indices.
            if i in [600, 700, 100]:
                print(eval_model(model, dev_dl), flush=True)

        val_p, val_r, val_f1, val_loss = eval_model(model, dev_dl)
        current_val = {
            'val_f1': val_f1,
            'val_p': val_p,
            'val_r': val_r,
            'val_loss': val_loss,
            'ep': ep
        }
        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            # state_dict() returns references to the live parameter tensors;
            # without a copy, later epochs would overwrite the "best" weights.
            best_model_weights = copy.deepcopy(model.state_dict())

    return best_model_weights, best_val
Ejemplo n.º 8
0
def train(
    config: fine_tune.config.BaseConfig,
    dataset: fine_tune.task.Dataset,
    model: fine_tune.model.Model,
    optimizer: torch.optim.AdamW,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    tokenizer: transformers.PreTrainedTokenizer,
):
    r"""Fine-tune model on task specific dataset.

    Performs `config.total_step` optimizer updates using gradient
    accumulation: every update consumes `config.accum_step` mini-batches of
    size `config.batch_size // config.accum_step`. The dataloader is
    re-iterated as many times as needed to reach the requested number of
    updates. Loss and learning rate are written to tensorboard every
    `config.log_step` updates, and the model weights are saved every
    `config.ckpt_step` updates and once more at the end.

    Args:
        config:
            `fine_tune.config.BaseConfig` subclass which attributes are used
            for experiment setup.
        dataset:
            Task specific dataset.
        model:
            Model which will be fine-tuned on `dataset`.
        optimizer:
            `torch.optim.AdamW` optimizer.
        scheduler:
            Learning rate scheduler, stepped once per optimizer update.
        tokenizer:
            Tokenizer paired with `model`.
    """
    # Training mode.
    model.train()

    # Model running device.
    device = config.device

    # Clean all gradient.
    optimizer.zero_grad()

    # Get experiment name and path.
    experiment_name = fine_tune.config.BaseConfig.experiment_name(
        experiment=config.experiment, model=config.model, task=config.task)
    experiment_dir = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT,
                                  experiment_name)

    # Create dataloader. Each mini-batch is 1 / `config.accum_step` of the
    # effective batch size.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size // config.accum_step,
        collate_fn=dataset.create_collate_fn(max_seq_len=config.max_seq_len,
                                             tokenizer=tokenizer),
        shuffle=True)

    # Use cross-entropy as objective.
    objective = nn.CrossEntropyLoss()

    # Step and accumulation step counter.
    step = 0
    accum_step = 0
    total_accum_step = config.total_step * config.accum_step

    # Loss summed over the mini-batches of the current effective batch.
    # Only used for logging; reset after every optimizer update.
    loss = 0

    # Create tensorboard's `SummaryWriter`.
    writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(fine_tune.path.LOG, experiment_name))

    # `tqdm` CLI Logger. We will manually update progress bar.
    cli_logger = tqdm(desc=f'loss: {loss:.6f}', total=config.total_step)

    try:
        # Total update times: `config.total_step`.
        while accum_step < total_accum_step:

            # Mini-batch loop.
            for (input_ids, attention_mask, token_type_ids, label,
                 _) in dataloader:

                # Cross-entropy loss of this mini-batch, scaled so that the
                # accumulated gradient matches one effective batch.
                # Use `model(...)` to do forward pass.
                accum_loss = objective(
                    input=model(input_ids=input_ids.to(device),
                                token_type_ids=token_type_ids.to(device),
                                attention_mask=attention_mask.to(device)),
                    target=label.to(device)) / config.accum_step

                # Mini-batch cross-entropy loss. Only used as log.
                loss += accum_loss.item()

                # Backward pass accumulation loss.
                accum_loss.backward()

                # Increment accumulation step.
                accum_step += 1

                # Perform gradient descent when the effective batch size is
                # reached.
                if accum_step % config.accum_step == 0:
                    # Gradient clipping.
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   config.max_norm)

                    # Gradient descent.
                    optimizer.step()

                    # Update learning rate.
                    scheduler.step()

                    # Log on CLI.
                    cli_logger.update()
                    cli_logger.set_description(f'loss: {loss:.6f}')

                    # Increment actual step.
                    step += 1

                    # Log loss and learning rate for each `config.log_step`
                    # step.
                    if step % config.log_step == 0:
                        writer.add_scalar(
                            f'{config.task}/{config.dataset}/loss', loss,
                            step)
                        writer.add_scalar(
                            f'{config.task}/{config.dataset}/lr',
                            optimizer.state_dict()['param_groups'][0]['lr'],
                            step)

                    # Clean up mini-batch loss.
                    loss = 0

                    # Clean up gradient.
                    optimizer.zero_grad()

                    # Save model for each `config.ckpt_step` step.
                    if step % config.ckpt_step == 0:
                        torch.save(
                            model.state_dict(),
                            os.path.join(experiment_dir, f'model-{step}.pt'))

                # Stop training condition.
                if accum_step >= total_accum_step:
                    break
    finally:
        # Release IO resources even when training is interrupted or fails,
        # so buffered tensorboard events are not lost.
        writer.flush()
        writer.close()
        cli_logger.close()

    # Save the latest model.
    torch.save(model.state_dict(),
               os.path.join(experiment_dir, f'model-{step}.pt'))
Ejemplo n.º 9
0
    def train_epoch(self, model: Reader, optimizer: torch.optim.Optimizer,
                    scaler: GradScaler, train: DataLoader, val: DataLoader,
                    scheduler: torch.optim.lr_scheduler.LambdaLR) -> float:
        """
        Performs one training epoch.

        Each batch is run through the model (under autocast when
        ``config["mixed_precision"]`` is set), the loss variant is chosen
        from the config, and one optimizer update is done per batch.
        Batches that raise a CUDA out-of-memory ``RuntimeError`` are logged
        and skipped. Validation runs every ``config["validate_after_steps"]``
        updates, and a checkpoint is written after a validation once the
        update counter exceeds ``config["first_save_after_updates_K"]``.
        The epoch ends early when ``config["max_steps"]`` updates are reached.

        :param model: The model you are training.
        :type model: Reader
        :param optimizer: Use this optimizer for training.
        :type optimizer: torch.optim.Optimizer
        :param scaler: Scaler for gradients when the mixed precision is used.
        :type scaler: GradScaler
        :param train: The train dataset loader.
        :type train: DataLoader
        :param val: The validation dataset loader.
        :type val: DataLoader
        :param scheduler: Learning rate scheduler. May be None, in which case no scheduler step is done.
        :type scheduler: torch.optim.lr_scheduler.LambdaLR
        :return: Best achieved exact match among validations.
        :rtype: float
        """

        model.train()
        loss_sum = 0
        samples = 0
        startTime = time.time()

        total_tokens = 0
        optimizer.zero_grad()

        # When resuming, start the progress bar at the stored position.
        # resumeSkip is consumed here so it is applied only once.
        initStep = 0
        if self.resumeSkip is not None:
            initStep = self.resumeSkip
            self.resumeSkip = None

        # `initial` only offsets the progress bar; enumerate still counts from 0.
        # NOTE(review): presumably the sampler itself resumes at the right position - confirm.
        iterator = tqdm(enumerate(train), total=len(train), initial=initStep)

        bestExactMatch = 0.0

        for current_it, batch in iterator:
            batch: ReaderBatch
            # Remember the scale so we can tell below whether scaler.update() changed it.
            lastScale = scaler.get_scale()
            self.n_iter += 1

            batchOnDevice = batch.to(self.device)
            # Counted before the forward pass, so batches skipped on OOM
            # still count in the loss_sum / samples average.
            samples += 1

            try:
                with torch.cuda.amp.autocast(
                        enabled=self.config["mixed_precision"]):
                    startScores, endScores, jointScore, selectionScore = self._useModel(
                        model, batchOnDevice)

                    # according to the config we can get following loss combinations
                    # join components
                    # independent components
                    # join components with HardEM
                    # independent components with HardEM

                    logSpanProb = None
                    if not self.config["independent_components_in_loss"]:
                        # joined components in loss
                        logSpanProb = Reader.scores2logSpanProb(
                            startScores, endScores, jointScore, selectionScore)

                    # User may want to use hardEMLoss with certain probability.
                    # In the original article it is not written clearly and it seems like it is the other way around.
                    # After I had consulted it with authors the idea became clear.

                    # The HardEM probability grows linearly with the number of updates
                    # until it is capped at max_hard_em_prob.
                    if self.config["hard_em_steps"] > 0 and \
                            random.random() <= min(self.update_it/self.config["hard_em_steps"], self.config["max_hard_em_prob"]):
                        # loss is calculated for the max answer span with max probability
                        if self.config["independent_components_in_loss"]:
                            loss = Reader.hardEMIndependentComponentsLoss(
                                startScores, endScores, jointScore,
                                selectionScore, batchOnDevice.answersMask)
                        else:
                            loss = Reader.hardEMLoss(logSpanProb,
                                                     batchOnDevice.answersMask)
                    else:
                        # loss is calculated for all answer spans
                        if self.config["independent_components_in_loss"]:
                            loss = Reader.marginalCompoundLossWithIndependentComponents(
                                startScores, endScores, jointScore,
                                selectionScore, batchOnDevice.answersMask)
                        else:
                            loss = Reader.marginalCompoundLoss(
                                logSpanProb, batchOnDevice.answersMask)

                    if self.config[
                            "use_auxiliary_loss"] and batch.isGroundTruth:
                        # we must be sure that user wants it and that the true passage is ground truth
                        loss += Reader.auxiliarySelectedLoss(selectionScore)
                    loss_sum += loss.item()

                # Backward runs outside autocast, on the scaled loss.
                scaler.scale(loss).backward()

            # Catch out-of-memory errors
            except RuntimeError as e:
                if "CUDA out of memory." in str(e):
                    # Free cached memory, log the failure and skip this batch
                    # without doing an optimizer update.
                    torch.cuda.empty_cache()
                    logging.error(e)
                    tb = traceback.format_exc()
                    logging.error(tb)
                    continue
                else:
                    raise e

            # update parameters

            # Unscale first so that clipping is applied to the unscaled gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                self.config["max_grad_norm"])

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()
            self.update_it += 1

            if math.isclose(lastScale, scaler.get_scale(),
                            rel_tol=1e-6) and scheduler is not None:
                # we should not perform scheduler step when the optimizer step was omitted due to the
                # change of scale factor
                scheduler.step()

            # Periodic validation, counted in parameter updates.
            if self.update_it % self.config["validate_after_steps"] == 0:
                valLoss, exactMatch, passageMatch, samplesWithLoss = self.validate(
                    model, val)

                logging.info(
                    f"Steps:{self.update_it}, Training loss: {loss_sum / samples:.5f}, Validation loss: {valLoss} (samples with loss {samplesWithLoss} [{samplesWithLoss / len(val):.1%}]), Exact match: {exactMatch:.5f}, Passage match: {passageMatch:.5f}"
                )

                bestExactMatch = max(exactMatch, bestExactMatch)
                # Checkpoints are written only after the configured number of updates.
                if self.update_it > self.config["first_save_after_updates_K"]:
                    # Save the wrapped module when the model is a DataParallel wrapper.
                    checkpoint = Checkpoint(
                        model.module if isinstance(model, DataParallel) else
                        model, optimizer, scheduler, train.sampler.actPerm,
                        current_it + 1, self.config, self.update_it)
                    checkpoint.save(f"{self.config['save_dir']}/Reader_train"
                                    f"_{get_timestamp()}"
                                    f"_{socket.gethostname()}"
                                    f"_{valLoss}"
                                    f"_S_{self.update_it}"
                                    f"_E_{current_it}.pt")

                # Back to training mode; presumably validate() switches the model to eval mode (TODO confirm).
                model.train()

            # statistics & logging
            total_tokens += batch.inputSequences.numel()
            # Refresh the progress bar text every 50 iterations and on the last batch.
            if (self.n_iter + 1) % 50 == 0 or current_it == len(iterator) - 1:
                iterator.set_description(
                    f"Steps: {self.update_it} Tokens/s: {total_tokens / (time.time() - startTime)}, Training loss: {loss_sum / samples}"
                )

            # Stop the epoch once the configured number of updates is reached.
            if self.config["max_steps"] <= self.update_it:
                break

        logging.info(
            f"End of epoch training loss: {loss_sum / samples:.5f}, best validation exact match: {bestExactMatch}"
        )

        return bestExactMatch