import pathlib
import tempfile
import typing

import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningModule
from tqdm import tqdm
from typing import Optional

# `initialize_optimizers`, `step_optimizers`, `zero_grad_optimizers`,
# `set_optimizer_lr`, and `transfer_batch_to_gpu` are helpers assumed to be
# defined elsewhere in this module.


def lr_find(
    module: pl.LightningModule,
    gpu_id: typing.Optional[typing.Union[torch.device, int]] = None,
    init_value: float = 1e-8,
    final_value: float = 10.,
    beta: float = 0.98,
    max_steps: typing.Optional[int] = None,
) -> typing.Tuple[typing.List[float], typing.List[float]]:
    """Run an exponential learning-rate range test.

    The learning rate is swept geometrically from ``init_value`` to
    ``final_value`` over one pass of the training dataloader (capped at
    ``max_steps`` batches), recording an exponentially smoothed loss at
    each step. The sweep stops early once the smoothed loss exceeds four
    times the best loss seen so far. Returns the recorded ``(lrs, losses)``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Snapshot the weights so the sweep leaves the module unchanged.
        save_path = pathlib.Path(tmpdir) / 'model.pth'
        torch.save(module.state_dict(), save_path)
        train_dataloader = module.train_dataloader()
        if max_steps is None:
            num = len(train_dataloader) - 1
        else:
            num = min(len(train_dataloader) - 1, max_steps)
        # Multiplicative factor that takes the LR from init_value to
        # final_value in `num` steps.
        mult = (final_value / init_value) ** (1 / num)
        lr = init_value
        avg_loss = 0.
        best_loss = 0.
        losses = []
        lrs = []
        optimizers = initialize_optimizers(module, lr)
        if gpu_id is not None:
            module = module.to(gpu_id)
        for batch_num, batch in enumerate(tqdm(train_dataloader, total=num), start=1):
            if gpu_id is not None:
                batch = transfer_batch_to_gpu(batch, gpu_id)
            loss = module.training_step(batch, batch_num)['loss']
            # Compute the bias-corrected exponential moving average of the loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break
            if lr >= final_value:
                break
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            losses.append(smoothed_loss)
            lrs.append(lr)
            loss.backward()
            optimizers = step_optimizers(optimizers)
            optimizers = zero_grad_optimizers(optimizers)
            # Update the lr for the next step
            lr *= mult
            optimizers = set_optimizer_lr(optimizers, lr)
        # Restore the original weights.
        module.load_state_dict(torch.load(save_path))
    return lrs, losses
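# --- Illustrative demo (an addition, not part of the original module) -------
# A minimal sketch showing how `lr_find` might be used, assuming the optimizer
# helpers referenced above are defined in this module and that
# `initialize_optimizers` builds optimizers via `configure_optimizers`.
# `_ToyModule` and `_demo_lr_find` are hypothetical names introduced only to
# exercise the sweep end-to-end on CPU.
def _demo_lr_find() -> float:  # pragma: no cover
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(16, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            # `lr_find` reads the loss out of the returned dict.
            return {'loss': torch.nn.functional.mse_loss(self.layer(x), y)}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            x = torch.randn(512, 16)
            y = x.sum(dim=1, keepdim=True)
            return DataLoader(TensorDataset(x, y), batch_size=8)

    lrs, losses = lr_find(_ToyModule())
    # A common heuristic: pick a learning rate roughly an order of magnitude
    # below the point of minimum smoothed loss.
    return lrs[losses.index(min(losses))] / 10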
def load_pretrained(
    model: LightningModule, class_name: Optional[str] = None
) -> None:  # pragma: no cover
    """Load pretrained weights into ``model`` from a checkpoint URL.

    ``urls`` is a module-level mapping from LightningModule class names to
    checkpoint URLs; by default the entry matching the model's own class
    name is used.
    """
    if class_name is None:
        class_name = model.__class__.__name__
    ckpt_url = urls[class_name]
    weights_model = model.__class__.load_from_checkpoint(ckpt_url)
    model.load_state_dict(weights_model.state_dict())
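# Usage sketch (hypothetical names, not part of the original module): assuming
# `urls` contains an entry for a `MyEncoder` LightningModule class, e.g.
# urls = {'MyEncoder': 'https://...'}, the pretrained weights can be pulled in
# with:
#
#     model = MyEncoder()
#     load_pretrained(model)               # looks up urls['MyEncoder']
#     load_pretrained(model, 'MyEncoder')  # or name the entry explicitly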