Example 1
def train_and_validate(
    args: Arguments,
    model: nn.Module,
    loader: TensorDataLoader,
    optimizer: optim.Optimizer | None,
    criterion: nn.Module,
    metrics: MetricTracker,
    mode: Mode,
) -> None:
    if mode == Mode.TRAIN:
        model.train()
    else:
        model.eval()

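    # Enable gradient tracking only in training mode and reset the metric tracker for the new epoch.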
    torch.set_grad_enabled(mode == Mode.TRAIN)
    metrics.reset_hard()
    num_batches = len(loader)
    with tqdm(desc=str(mode), total=num_batches, ncols=120) as pbar:
        for i, (data, target) in enumerate(loader):
            # If you have multiple optimizers, use model.zero_grad().
            # If you want to freeze layers, use optimizer.zero_grad().
            if mode == Mode.TRAIN and optimizer is not None:
                optimizer.zero_grad(set_to_none=True)

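            # Multi-input models receive the batch as a list/tuple of tensors; unpack it accordingly.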
            if isinstance(data, (list, tuple)):
                output = model(*data)
                batch_size = data[0].size(args.batch_dim)
            else:
                output = model(data)
                batch_size = data.size(args.batch_dim)

            loss = criterion(output, target)
            if mode == Mode.TRAIN and optimizer is not None:
                loss.backward()
                optimizer.step()

            val_dict = {
                "data": data,
                "loss": loss,
                "output": output,
                "target": target,
                "batch_size": batch_size,
            }
            tqdm_dict = metrics.batch_update(SimpleNamespace(**val_dict), i,
                                             num_batches, mode)
            pbar.set_postfix(tqdm_dict)
            pbar.update()
    metrics.epoch_update(mode)
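
A minimal usage sketch for the function above, assuming an outer epoch loop. The args.epochs attribute, the train/validation loaders, and the Mode.VAL member are assumptions for illustration; only Mode.TRAIN appears in the example itself. Passing None as the optimizer on the validation pass leaves the optimizer-guarded blocks inert.

# Hypothetical driver loop; every name except train_and_validate and Mode.TRAIN is assumed.
for epoch in range(args.epochs):
    # Training pass: gradients enabled, optimizer steps taken.
    train_and_validate(args, model, train_loader, optimizer, criterion, metrics, Mode.TRAIN)
    # Validation pass: gradients disabled, no backward() or step() is performed.
    train_and_validate(args, model, val_loader, None, criterion, metrics, Mode.VAL)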
Example 2
def load_state_dict(
    checkpoint: dict[str, Any],
    model: nn.Module,
    optimizer: optim.Optimizer | None = None,
    scheduler: lr_scheduler._LRScheduler | None = None,
) -> None:
    """
    Loads model parameters (state_dict) from checkpoint. If optimizer or scheduler are
    provided, loads state_dict of optimizer assuming it is present in checkpoint.
    Args:
        checkpoint: () checkpoint object
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
    """
    if checkpoint:
        print("Loading checkpoint...")
        model.load_state_dict(checkpoint["model_state_dict"])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
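
A minimal resume sketch for the function above; the checkpoint path and the call to torch.load are illustrative assumptions, not part of the example. If only the model is passed, just the weights are restored.

# Hypothetical resume logic; the file name and map_location are assumptions.
checkpoint = torch.load("checkpoints/last.pt", map_location="cpu")
load_state_dict(checkpoint, model, optimizer, scheduler)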