Example #1
import logging

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector


def test_attach_model_callbacks_override_info(caplog):
    """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
    model = LightningModule()
    model.configure_callbacks = lambda: [
        LearningRateMonitor(),
        EarlyStopping(monitor="foo")
    ]
    trainer = Trainer(enable_checkpointing=False,
                      callbacks=[
                          EarlyStopping(monitor="foo"),
                          LearningRateMonitor(),
                          TQDMProgressBar()
                      ])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    with caplog.at_level(logging.INFO):
        cb_connector._attach_model_callbacks()

    assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text
Example #2
 def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
     if trainer.current_epoch % self.snapshot_interval == 0:
         train_dataloader = pl_module.train_dataloader()
         for x, y, global_example_indexes in tqdm(
                 train_dataloader,
                 desc="Snapshotting: ",
                 file=sys.stdout,
                 position=0,
                 leave=True,
         ):
             x.requires_grad = True
             logits = pl_module.forward(x)
             loss = F.nll_loss(logits, y)
             loss.backward()
             for global_index, grad, class_label in zip(
                     global_example_indexes.tolist(), x.grad, y):
                 self._x_idx_to_grads.setdefault(global_index,
                                                 []).append(grad)
                 self._x_idx_to_y[global_index] = class_label
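Note the data contract this hook assumes: the training dataloader must yield ``(x, y, global_example_indexes)`` triples, so that per-example input gradients can be accumulated in ``self._x_idx_to_grads`` across snapshot epochs and matched to the labels stored in ``self._x_idx_to_y``.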
Example #3
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        # show images only every 20 batches
        if (batch_idx + 1) % self.logging_batch_interval != 0:
            return

        # pick the last batch and logits
        x, y = batch
        try:
            logits = pl_module.last_logits
        except AttributeError as err:
            m = """please track the last_logits in the training_step like so:
                def training_step(...):
                    self.last_logits = your_logits
            """
            raise AttributeError(m) from err

        # only check when the model has opinions (i.e. the max logit exceeds min_logit_value)
        if logits.max() > self.min_logit_value:
            # pick the top two confused probs
            (values, idxs) = torch.topk(logits, k=2, dim=1)

            # care about only the ones that are at most eps close to each other
            eps = self.max_logit_difference
            mask = (values[:, 0] - values[:, 1]).abs() < eps

            if mask.sum() > 0:
                # pull out the ones we care about
                confusing_x = x[mask, ...]
                confusing_y = y[mask]

                mask_idxs = idxs[mask]

                pl_module.eval()
                self._plot(confusing_x, confusing_y, trainer, pl_module, mask_idxs)
                pl_module.train()
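The hook above matches the attribute names used by ``pl_bolts``' ``ConfusedLogitCallback`` (``logging_batch_interval``, ``min_logit_value``, ``max_logit_difference``); a minimal wiring sketch, assuming that class is the source of this snippet:

from pl_bolts.callbacks import ConfusedLogitCallback
from pytorch_lightning import Trainer

# Plot inputs whose top-2 logits are nearly tied, checked every few batches.
trainer = Trainer(callbacks=[ConfusedLogitCallback(top_k=2)])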
Example #4
    def _train_log(self, trainer: Trainer, pl_module: LightningModule):
        if self.training_config.evaluate_metrics:
            self.train_combined_report.metrics = pl_module.metrics(
                self.train_combined_report, self.train_combined_report)

        pl_module.train_meter.update_from_report(self.train_combined_report)

        extra = {}
        if "cuda" in str(trainer.model.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        optimizer = self.get_optimizer(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        current_iteration = self._get_iterations_for_logging(trainer)
        extra.update({
            "epoch": self._get_current_epoch_for_logging(trainer),
            "iterations": current_iteration,
            "num_updates": num_updates,
            "max_updates": trainer.max_steps,
            "lr": "{:.5f}".format(optimizer.param_groups[0]["lr"]).rstrip("0"),
            "ups": "{:.2f}".format(
                self.trainer_config.log_every_n_steps
                / self.train_timer.unix_time_since_start()
            ),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=trainer.max_steps,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.trainer_config.log_every_n_steps,
                eval_interval=self.trainer_config.val_check_interval,
            ),
        })
        self.train_timer.reset()
        summarize_report(
            current_iteration=current_iteration,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.train_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )
Example #5
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        x, y = self.to_device(batch, pl_module.device)

        with torch.no_grad():
            representations = self.get_representations(pl_module, x)

        representations = representations.detach()

        # forward pass
        mlp_preds = pl_module.non_linear_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_preds, y)

        # log metrics
        val_acc = accuracy(mlp_preds, y)
        pl_module.log('online_val_acc',
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log('online_val_loss',
                      mlp_loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        self.confusion_matrix(mlp_preds, y)

        if self.time_to_sample:
            N, C, H, W = batch[0][2].shape
            num = min(N, 16)
            self.images = batch[0][2][0:num]
            self.mlp_preds = torch.argmax(mlp_preds[0:num], dim=1)
            self.labels = y[0:num]
            self.time_to_sample = False
Example #6
        def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx,
                         on_steps=[], on_epochs=[], prob_bars=[]):
            self.funcs_called_count[func_name] += 1
            for idx, (on_step, on_epoch, prog_bar) in enumerate(itertools.product(on_steps, on_epochs, prob_bars)):
                # run logging
                custom_func_name = f"{func_idx}_{idx}_{func_name}"
                pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step,
                              on_epoch=on_epoch, prog_bar=prog_bar)

                # catch information for verification

                # on_train_start fires outside the main loop, so the epoch check below won't catch it
                if func_name == "on_train_start":
                    self.callback_funcs_called[func_name].append([self.count * func_idx])

                # save only values from the second epoch, so we can compute their mean or latest
                if pl_module.trainer.current_epoch == 1:
                    self.callback_funcs_called[func_name].append([self.count * func_idx])

                forked = on_step and on_epoch

                self.funcs_attr[custom_func_name] = {
                    "on_step": on_step,
                    "on_epoch": on_epoch,
                    "prog_bar": prog_bar,
                    "forked": forked,
                    "func_name": func_name}

                if on_step and on_epoch:
                    self.funcs_attr[f"{custom_func_name}_step"] = {
                        "on_step": True,
                        "on_epoch": False,
                        "prog_bar": prog_bar,
                        "forked": False,
                        "func_name": func_name}

                    self.funcs_attr[f"{custom_func_name}_epoch"] = {
                        "on_step": False,
                        "on_epoch": True,
                        "prog_bar": prog_bar,
                        "forked": False,
                        "func_name": func_name}
Example #7
    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Make sure we have a quantized version.

        This handles the edge case where a user does .test() without .fit() first.
        """
        if hasattr(pl_module, "_quantized"):
            return
        pl_module._quantized = self.convert(
            pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
        )
        self.quantized = pl_module._quantized
    def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

        pl_module.non_linear_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(), lr=1e-4)
    def fit(
        self,
        model: pl.LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[DataLoader] = None,
    ):
        self.model = model

        # try to get hparams from module.
        if self.tune_hparams is None:
            if hasattr(model, 'get_pbt_hparams'):
                self.tune_hparams = model.get_pbt_hparams()

        mp = _mp.get_context('forkserver')

        global_epoch = mp.Value('i', 0)

        # initialize population tasks
        population_tasks = mp.Queue(maxsize=self.population_size)
        for i in range(self.population_size):
            population_tasks.put({
                'id': i,
                self.pbt_monitor: 0,
                'checkpoint_path': None,
            })

        logger_info = dict(name=self.trainer.logger.name,
                           version=self.trainer.logger.version,
                           save_dir=self.trainer.logger.save_dir)

        workers = [
            Worker(
                pl_trainer=copy.deepcopy(self.trainer),
                model=copy.deepcopy(model),
                population_tasks=population_tasks,
                tune_hparams=self.tune_hparams,
                process_position=i,
                global_epoch=global_epoch,
                max_epoch=10,
                full_parallel=self.num_workers == self.population_size,
                logger_info=logger_info,
                dataloaders=dict(
                    train_dataloader=copy.deepcopy(train_dataloader),
                    val_dataloaders=copy.deepcopy(val_dataloaders),
                ),
            ) for i in range(self.num_workers)
        ]

        for w in workers:
            w.start()
        for w in workers:
            w.join()
        task = []

        while not population_tasks.empty():
            task.append(population_tasks.get())
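In this population-based-training ``fit``, each ``Worker`` receives deep copies of the trainer, model, and dataloaders, and the shared ``population_tasks`` queue holds one entry per population member (its id, monitored score, and checkpoint path); after all workers join, the queue is drained back into ``task``.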
Example #10
 def on_validation_batch_end(
     self,
     trainer: Trainer,
     pl_module: LightningModule,
     outputs: Sequence,
     batch: Sequence,
     batch_idx: int,
     dataloader_idx: int,
 ) -> None:
     val_acc, mlp_loss = self.shared_step(pl_module, batch)
     pl_module.log("online_val_acc",
                   val_acc,
                   on_step=False,
                   on_epoch=True,
                   sync_dist=True)
     pl_module.log("online_val_loss",
                   mlp_loss,
                   on_step=False,
                   on_epoch=True,
                   sync_dist=True)
Example #11
    def test_forward(self, model: pl.LightningModule, data: torch.Tensor, training: bool):
        r"""Calls ``model.forward()`` and tests that the output is not ``None``.

        Because of the size of some models, this test is only run when a GPU is available.
        """
        if isinstance(model, (LightningDistributedDataParallel, LightningDistributedModule)):
            pytest.skip()

        if torch.cuda.is_available():
            model = model.cuda()  # type: ignore
            data = data.cuda()

        if training:
            model.train()
        else:
            model.eval()

        output = model(data)

        assert output is not None
def test_invalid_weights_summary():
    """Test that invalid value for weights_summary raises an error."""
    model = LightningModule()

    with pytest.raises(
        MisconfigurationException, match="`weights_summary` can be None, .* got temp"
    ), pytest.deprecated_call(match="weights_summary=temp)` is deprecated"):
        Trainer(weights_summary="temp")

    with pytest.raises(ValueError, match="max_depth` can be .* got temp"):
        ModelSummary(model, max_depth="temp")
Example #13
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - number of model parameters
    """

    if not trainer.logger:
        return

    hparams = {}

    # choose which parts of hydra config will be saved to loggers
    hparams["model"] = config["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(p.numel()
                                            for p in model.parameters()
                                            if p.requires_grad)
    hparams["model/params/non_trainable"] = sum(p.numel()
                                                for p in model.parameters()
                                                if not p.requires_grad)

    hparams["datamodule"] = config["datamodule"]
    hparams["trainer"] = config["trainer"]

    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
Example #14
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """This method controls which parameters from Hydra config are saved by Lightning loggers.

    Additionally saves:
        - number of trainable model parameters
    """

    hparams = {}

    # choose which parts of hydra config will be saved to loggers
    hparams["trainer"] = config["trainer"]
    hparams["model"] = config["model"]
    hparams["datamodule"] = config["datamodule"]
    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # save number of model parameters
    hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params_trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params_not_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)

    # disable logging any more hyperparameters for all loggers;
    # this is just a trick to prevent the trainer from logging the model's
    # hparams again, since we already did that above
    trainer.logger.log_hyperparams = empty
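The ``empty`` helper assigned above is not shown in this snippet; in the lightning-hydra-template it is simply a no-op that swallows all arguments. A minimal sketch under that assumption:

def empty(*args, **kwargs):
    # Deliberately do nothing: binding this over trainer.logger.log_hyperparams
    # blocks any further hparam logging after the call above.
    pass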
Example #15
 def common_epoch_end(self,
                      step_outputs,
                      prefix="train/",
                      exclude_keys={"pred", "target"}):
     keys = list(step_outputs[0].keys())
     mean_step_metrics = {
         k: torch.mean(torch.stack([x[k] for x in step_outputs]))
         for k in keys if k not in exclude_keys
     }
     preds, targets = zip(*[(s["pred"], s["target"]) for s in step_outputs])
     preds = cat_steps(preds)
     targets = cat_steps(targets)
     epoch_metrics = prefix_keys(
         prefix,
         self.collect_epoch_metrics(preds, targets, prefix.replace("/", "")),
     )
     epoch_metrics, epoch_figures = sort_out_figures(epoch_metrics)
     all_metrics = {**mean_step_metrics, **epoch_metrics}
     LightningModule.log_dict(self, all_metrics, sync_dist=self._sync_dist)
     log_figures(self, epoch_figures)
Example #16
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Override the model with a quantized-aware version on setup.

        This is the earliest place we can override this model which allows for
        appropriate behavior when restoring from checkpoints, as well as connecting
        to accelerators, etc.

        The model is only prepared once.
        """
        # Only prepare the model once.
        if hasattr(pl_module, "_prepared"):
            return

        with mode(pl_module, training=True) as train:
            pl_module._prepared = self.prepare(
                _deepcopy(train),
                configs=self.qconfig_dicts,
                attrs=self.preserved_attrs,
            )
        pl_module.forward = MethodType(_quantized_forward, pl_module)
        self.prepared = pl_module._prepared
Example #17
 def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs):
     if pl_module.current_epoch % self.save_step == 0:
         ckpt = {
             'net_model': pl_module.model.state_dict(),
             'ema_model': pl_module.ema_model.state_dict(),
             'optim': pl_module.optimizers().state_dict(),
             'epoch': pl_module.current_epoch,
             # 'x_T': self.x_T
         }
         Path(self.save_dir).mkdir(parents=True, exist_ok=True)
         path = os.path.join(self.save_dir, 'ckpt.pt')
         torch.save(ckpt, path)
 def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
     super().on_train_start(trainer, pl_module)
     submodule_dict = dict(pl_module.named_modules())
     self._hook_handles = []
     for name in self._get_submodule_names(pl_module):
         if name not in submodule_dict:
             rank_zero_warn(
                 f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__},"
                 " skipping this key.")
             continue
         handle = self._register_hook(name, submodule_dict[name])
         self._hook_handles.append(handle)
def lr_find(module: pl.LightningModule,
            gpu_id: typing.Optional[typing.Union[torch.device, int]] = None,
            init_value: float = 1e-8,
            final_value: float = 10.,
            beta: float = 0.98,
            max_steps: typing.Optional[int] = None
            ) -> typing.Tuple[typing.List[float], typing.List[float]]:
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = pathlib.Path(tmpdir) / 'model.pth'
        torch.save(module.state_dict(), save_path)
        train_dataloader = module.train_dataloader()

        if max_steps is None:
            num = len(train_dataloader) - 1
        else:
            num = min(len(train_dataloader) - 1, max_steps)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value

        avg_loss = 0.
        best_loss = 0.
        losses = []
        lrs = []

        optimizers = initialize_optimizers(module, lr)

        if gpu_id is not None:
            module = module.to(gpu_id)

        for batch_num, batch in enumerate(tqdm(train_dataloader, total=num),
                                          start=1):
            if gpu_id is not None:
                batch = transfer_batch_to_gpu(batch, gpu_id)
            loss = module.training_step(batch, batch_num)['loss']

            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta**batch_num)

            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break

            if lr >= final_value:
                break

            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            losses.append(smoothed_loss)
            lrs.append(lr)

            loss.backward()
            optimizers = step_optimizers(optimizers)
            optimizers = zero_grad_optimizers(optimizers)

            # Update the lr for the next step
            lr *= mult
            optimizers = set_optimizer_lr(optimizers, lr)
        module.load_state_dict(torch.load(save_path))
    return lrs, losses
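A hypothetical usage sketch for ``lr_find``: run the exponential sweep, then inspect the smoothed-loss curve on a log-scaled axis (the learning rate grows by the constant factor ``mult`` each step):

import matplotlib.pyplot as plt

lrs, losses = lr_find(module, gpu_id=0, max_steps=200)  # `module` is your LightningModule
plt.plot(lrs, losses)
plt.xscale("log")  # the lr schedule is geometric, so use a log axis
plt.xlabel("learning rate")
plt.ylabel("smoothed loss")
plt.show()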
def test_checkpoint_callbacks_are_last(tmpdir):
    """Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
    checkpoint1 = ModelCheckpoint(tmpdir)
    checkpoint2 = ModelCheckpoint(tmpdir)
    model_summary = ModelSummary()
    early_stopping = EarlyStopping()
    lr_monitor = LearningRateMonitor()
    progress_bar = ProgressBar()

    # no model reference
    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]

    # no model callbacks
    model = LightningModule()
    model.configure_callbacks = lambda: []
    trainer.model = model
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]

    # with model-specific callbacks that substitute ones in Trainer
    model = LightningModule()
    model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2]
    trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, model_summary, checkpoint1, checkpoint2]
Example #21
 def on_validation_batch_end(self, trainer: pl.Trainer,
                             pl_module: pl.LightningModule, outputs: Any,
                             batch: BatchType, batch_idx: int,
                             dataloader_idx: int) -> None:  # type: ignore
     """
     Get and log validation metrics.
     """
     ids_linear_head = tuple(
         batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
     if ids_linear_head not in self.visited_ids:
         self.visited_ids.add(ids_linear_head)
         loss = self.shared_step(batch, pl_module, is_training=False)
         pl_module.log('ssl_online_evaluator/val/loss',
                       loss,
                       on_step=False,
                       on_epoch=True,
                       sync_dist=False)
         for metric in self.val_metrics:
             pl_module.log(f"ssl_online_evaluator/val/{metric.name}",
                           metric,
                           on_epoch=True,
                           on_step=False)  # type: ignore
Example #22
def override_unsupported_nud(lm: pl.LightningModule, context: PyTorchTrialContext) -> None:
    writer = pytorch.TorchWriter()

    def lm_print(*args: Any, **kwargs: Any) -> None:
        if context.distributed.get_rank() == 0:
            print(*args, **kwargs)

    def lm_log_dict(a_dict: Dict, *args: Any, **kwargs: Any) -> None:
        if len(args) != 0 or len(kwargs) != 0:
            raise InvalidModelException(
                f"unsupported arguments to LightningModule.log {args} {kwargs}"
            )
        for metric, value in a_dict.items():
            if type(value) == int or type(value) == float:
                writer.add_scalar(metric, value, context.current_train_batch())

    def lm_log(name: str, value: Any, *args: Any, **kwargs: Any) -> None:
        lm_log_dict({name: value}, *args, **kwargs)

    lm.print = lm_print  # type: ignore
    lm.log = lm_log  # type: ignore
    lm.log_dict = lm_log_dict  # type: ignore
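Once ``override_unsupported_nud(lm, context)`` has been applied, a call such as ``lm.log("train_loss", 0.3)`` is routed through ``lm_log_dict`` and written as a scalar via the ``TorchWriter``; passing unsupported extras (e.g. ``on_step=True``) raises ``InvalidModelException``, non-numeric values are silently skipped, and ``lm.print`` only emits on rank 0.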
Example #23
    def on_train_epoch_end(self, trainer: pl.Trainer,
                           pl_module: pl.LightningModule) -> None:
        if not _TORCHVISION_AVAILABLE:
            return

        images, _ = next(
            iter(
                DataLoader(trainer.datamodule.mnist_val,
                           batch_size=self.num_samples)))
        images_flattened = images.view(images.size(0), -1)

        # generate images
        with torch.no_grad():
            pl_module.eval()
            images_generated = pl_module(images_flattened.to(pl_module.device))
            pl_module.train()

        if trainer.current_epoch == 0:
            save_image(self._to_grid(images),
                       f"grid_ori_{trainer.current_epoch}.png")
        save_image(self._to_grid(images_generated.reshape(images.shape)),
                   f"grid_generated_{trainer.current_epoch}.png")
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        if trainer.current_epoch >= self.start_epoch:
            x, y = self.to_device(batch, pl_module.device)

            with torch.no_grad():
                representations = self.get_representations(pl_module, x)

            representations = representations.detach()

            # forward pass
            mlp_preds = pl_module.non_linear_evaluator(
                representations)  # type: ignore[operator]
            weight = (None if self.loss_weight is None else
                      self.loss_weight.to(pl_module.device).float())
            mlp_loss = F.cross_entropy(mlp_preds, y, weight=weight)

            # update finetune weights
            mlp_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # log metrics
            train_acc = accuracy(mlp_preds, y)
            pl_module.log("online_train_acc",
                          train_acc,
                          on_step=True,
                          on_epoch=False)
            pl_module.log("online_train_loss",
                          mlp_loss,
                          on_step=True,
                          on_epoch=False)
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        logger.info("***** Test results *****")

        if pl_module.is_logger():
            metrics = trainer.callback_metrics

            # Log and save results to file
            output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
            with open(output_test_results_file, "w") as writer:
                for key in sorted(metrics):
                    if key not in ["log", "progress_bar"]:
                        logger.info("{} = {}\n".format(key, str(metrics[key])))
                        writer.write("{} = {}\n".format(key, str(metrics[key])))
Example #26
    def create_figure(
        self, task: pl.LightningModule, x: Tensor, year: Optional[np.ndarray] = None
    ) -> plt.Figure:
        """
        Creates a matplotlib figure of generated album covers.

        Args:
            task: The CoverGANTask that contains the generator that should be
                used to generate the images from the latent vectors
            x: Latent vectors
            year: Standardized release year of the artificial album
        """
        if year is not None:
            year = Tensor(year).to(task.device)
        idx = 1
        figsize = (np.array(self.target_size) * [10, 1]).astype(int) / 300
        fig = plt.figure(figsize=figsize, dpi=300)
        with torch.no_grad():
            if isinstance(task.generator, ProGANGenerator):
                output = task.generator(
                    x, year=year, block=task.block, alpha=task.alpha
                )
            else:
                output = task.generator(x)
        generated_images = output.detach().cpu().numpy()
        generated_images = np.moveaxis(generated_images, 1, -1)
        for img in generated_images:
            if img.shape[-1] == 1:
                img = np.tile(img, (1, 1, 3))
            img = array_to_img(img, scale=True)
            img = img.resize(size=self.target_size)
            plt.subplot(1, 10, idx)
            plt.axis("off")
            plt.imshow(img)
            idx += 1
            plt.subplots_adjust(
                left=0, bottom=0, right=1, top=1, wspace=0, hspace=0.1
            )
        return fig
Example #27
def _deepcopy(pl_module: LightningModule) -> LightningModule:
    """Copy a LightningModule. Some properties need to be ignored. """
    # Remove _result before call to deepcopy since it store non-leaf Tensors.
    # If not removed, you'll see this error on deepcopy() attempts: P150283141.
    if hasattr(pl_module, "_results"):
        result = pl_module._results
        delattr(pl_module, "_results")
        copy = deepcopy(pl_module)

        # Set back.
        pl_module._results = result
    else:
        copy = deepcopy(pl_module)
    return copy
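Note that ``_results`` is detached only for the duration of ``deepcopy`` and restored immediately afterwards, so the original module is left exactly as it was.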
Example #28
def custom_lr_finder(runner: LightningModule, omegaConf: DictConfig) -> LightningModule:
    """
        LR finder
        The prepare_data/setup does not work well with the lr finder
        To handle the situation we manually search for it before the training

        The situation is being handled:
        https://github.com/PyTorchLightning/pytorch-lightning/issues/2485
    """
    del omegaConf['trainer']['auto_lr_find']
    tmp_trainer = pl.Trainer(**omegaConf['trainer'])
    runner.prepare_data()
    runner.setup('lr_finder')
    lr_finder = tmp_trainer.lr_find(runner)
    # fig = lr_finder.plot(suggest=True)
    new_lr = lr_finder.suggestion()
    omegaConf['runner']['lr'] = new_lr
    runner = make_runner(omegaConf['runner'])

    if omegaConf['runner'].get('verbose', False) is True:
        print('Learning rate found: {}'.format(new_lr))

    return runner
Example #29
def wrap_training_step(wrapped, instance: LightningModule, args, kwargs):
    """
    Wraps the training step of the LightningModule.

    Parameters
    ----------
    wrapped: The wrapped function.
    instance: The LightningModule instance.
    args: The arguments passed to the wrapped function.
    kwargs: The keyword arguments passed to the wrapped function.

    Returns
    -------
    The return value of the wrapped function.
    """
    output_dict = wrapped(*args, **kwargs)

    if isinstance(output_dict, dict) and "log" in output_dict:
        log_dict = output_dict.pop("log")
        instance.log_dict(log_dict, on_step=True)

    return output_dict
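The ``(wrapped, instance, args, kwargs)`` signature follows the ``wrapt`` wrapper convention; a minimal sketch of attaching it with ``wrapt`` (an assumed mechanism, not shown in the snippet):

import wrapt

# Patch LightningModule.training_step so a returned {"log": ...} dict
# is forwarded to self.log_dict automatically.
wrapt.wrap_function_wrapper(
    "pytorch_lightning", "LightningModule.training_step", wrap_training_step
)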
Example #30
    def test_validation_step(self, model: pl.LightningModule):
        r"""Runs a validation step based on the data returned from ``model.val_dataloader()``.
        Tests that the dictionary returned from ``validation_step()`` is as required by PyTorch
        Lightning.

        Because of the size of some models, this test is only run when a GPU is available.
        """
        if isinstance(model, LightningDistributedDataParallel):
            pytest.skip()

        check_overriden(model, "val_dataloader")
        check_overriden(model, "validation_step")

        dl = model.val_dataloader()
        # TODO this can't handle multiple optimizers
        batch = next(iter(dl))

        if torch.cuda.is_available():
            batch = [x.cuda() for x in batch]
            model = model.cuda()

        model.eval()
        output = model.validation_step(batch, 0)

        assert isinstance(output, dict)

        if "loss" in output.keys():
            assert isinstance(output["loss"], Tensor)
            assert output["loss"].shape == (1, )

        if "log" in output.keys():
            assert isinstance(output["log"], dict)

        if "progress_bar" in output.keys():
            assert isinstance(output["progress_bar"], dict)

        return batch, output