import pathlib
import tempfile
import typing

import pytorch_lightning as pl
import torch
from tqdm import tqdm


def lr_find(module: pl.LightningModule,
            gpu_id: typing.Optional[typing.Union[torch.device, int]] = None,
            init_value: float = 1e-8,
            final_value: float = 10.,
            beta: float = 0.98,
            max_steps: typing.Optional[int] = None
            ) -> typing.Tuple[typing.List[float], typing.List[float]]:
    """Exponential learning-rate range test; returns (lrs, losses)."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Snapshot the weights so they can be restored after the sweep
        save_path = pathlib.Path(tmpdir) / 'model.pth'
        torch.save(module.state_dict(), save_path)
        train_dataloader = module.train_dataloader()

        # Number of sweep steps and the per-step multiplicative lr factor
        if max_steps is None:
            num = len(train_dataloader) - 1
        else:
            num = min(len(train_dataloader) - 1, max_steps)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value

        avg_loss = 0.
        best_loss = 0.
        losses = []
        lrs = []

        # initialize_optimizers, step_optimizers, zero_grad_optimizers,
        # set_optimizer_lr and transfer_batch_to_gpu are helpers defined
        # elsewhere in this module.
        optimizers = initialize_optimizers(module, lr)

        if gpu_id is not None:
            module = module.to(gpu_id)

        for batch_num, batch in enumerate(tqdm(train_dataloader, total=num),
                                          start=1):
            if gpu_id is not None:
                batch = transfer_batch_to_gpu(batch, gpu_id)
            # Forward pass through the module's own training_step
            loss = module.training_step(batch, batch_num)['loss']

            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta**batch_num)

            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break

            # Stop once the sweep has covered the full lr range
            if lr >= final_value:
                break

            # Track the lowest smoothed loss seen so far
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            losses.append(smoothed_loss)
            lrs.append(lr)

            # Backpropagate, step every optimizer, then clear the gradients
            loss.backward()
            optimizers = step_optimizers(optimizers)
            optimizers = zero_grad_optimizers(optimizers)

            # Update the lr for the next step
            lr *= mult
            optimizers = set_optimizer_lr(optimizers, lr)
        # Restore the module's original weights
        module.load_state_dict(torch.load(save_path))
    return lrs, losses
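
A minimal usage sketch for the function above. The `MyModel` class and the plotting step are illustrative assumptions, not part of the snippet: sweep the learning rate, then pick a value just below the point where the loss curve starts to rise.

import matplotlib.pyplot as plt

# MyModel is a hypothetical LightningModule; any module whose
# train_dataloader() and training_step() match lr_find's expectations works.
model = MyModel()

lrs, losses = lr_find(model, gpu_id=0, max_steps=200)

# Plot smoothed loss against lr on a log axis; a common heuristic is to
# choose an lr roughly an order of magnitude below the curve's minimum.
plt.plot(lrs, losses)
plt.xscale('log')
plt.xlabel('learning rate')
plt.ylabel('smoothed loss')
plt.show()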
Example #2
from typing import Optional

from pytorch_lightning import LightningModule


def load_pretrained(
        model: LightningModule,
        class_name: Optional[str] = None) -> None:  # pragma: no cover
    """Copy pretrained checkpoint weights into ``model`` in place."""
    if class_name is None:
        class_name = model.__class__.__name__
    # ``urls`` is a module-level mapping from class name to checkpoint URL,
    # defined elsewhere in the source module.
    ckpt_url = urls[class_name]
    weights_model = model.__class__.load_from_checkpoint(ckpt_url)
    model.load_state_dict(weights_model.state_dict())
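
A short usage sketch, assuming `ImageClassifier` is a LightningModule subclass and that the module-level `urls` mapping contains an entry for it (both are assumptions for illustration):

# ImageClassifier is hypothetical; load_pretrained looks up
# urls['ImageClassifier'] by default since class_name is None.
model = ImageClassifier()
load_pretrained(model)
model.eval()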