Example #1
    def interpolate_latent_space(self, pl_module: LightningModule,
                                 latent_dim: int) -> List[Tensor]:
        images = []
        with torch.no_grad():
            pl_module.eval()
            for z1 in np.linspace(self.range_start, self.range_end,
                                  self.steps):
                for z2 in np.linspace(self.range_start, self.range_end,
                                      self.steps):
                    # set all dims to zero
                    z = torch.zeros(self.num_samples,
                                    latent_dim,
                                    device=pl_module.device)

                    # set the first 2 dims to the value
                    z[:, 0] = torch.tensor(z1)
                    z[:, 1] = torch.tensor(z2)

                    # generate images
                    img = pl_module(z)

                    if len(img.size()) == 2:
                        img = img.view(self.num_samples, *pl_module.img_dim)

                    img = img[0]
                    img = img.unsqueeze(0)
                    images.append(img)

        pl_module.train()
        return images
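
Every snippet in this collection toggles eval()/train() by hand around a forward pass. A minimal reusable sketch of that pattern (the eval_mode helper is an illustration, not part of any snippet) that also restores the module's original mode if an exception is raised:

from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def eval_mode(module: nn.Module):
    # remember the current mode so it is restored even on error
    was_training = module.training
    module.eval()
    try:
        with torch.no_grad():
            yield module
    finally:
        module.train(was_training)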
Example #2
    def test_test_epoch_end(self, model: pl.LightningModule):
        r"""Tests that ``test_epoch_end()`` runs and outputs a dict as required by PyTorch
        Lightning.

        Because of the size of some models, this test is only run when a GPU is available.
        """
        if isinstance(model, (LightningDistributedDataParallel, LightningDistributedModule)):
            pytest.skip()

        check_overriden(model, "test_dataloader")
        check_overriden(model, "test_step")
        check_overriden(model, "test_epoch_end")
        dl = model.test_dataloader()

        if torch.cuda.is_available():
            model = model.cuda()  # type: ignore
            model.eval()
            outputs = [model.test_step([x.cuda() for x in batch], 0) for batch in dl]
        else:
            model.eval()
            outputs = [model.test_step(batch, 0) for batch in dl]

        result = model.test_epoch_end(outputs)  # type: ignore
        assert isinstance(result, dict)
        return outputs, result
Example #3
    def on_epoch_end(self, trainer: Trainer,
                     pl_module: LightningModule) -> None:
        dim = (self.num_samples, pl_module.hparams.latent_dim
               )  # type: ignore[union-attr]
        z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device)

        # generate images
        with torch.no_grad():
            pl_module.eval()
            images = pl_module(z)
            pl_module.train()

        if len(images.size()) == 2:
            img_dim = pl_module.img_dim
            images = images.view(self.num_samples, *img_dim)

        grid = torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            range=self.norm_range,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
        )
        str_title = f"{pl_module.__class__.__name__}_images"
        trainer.logger.experiment.add_image(str_title,
                                            grid,
                                            global_step=trainer.global_step)
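
A callback like this only runs once it is registered with the Trainer, which then invokes on_epoch_end at the end of each epoch. A hedged wiring sketch (the ImageSampler name stands in for the callback class above and is an assumption, as is model):

import pytorch_lightning as pl

# hypothetical: ImageSampler is the callback class above, model any LightningModule
trainer = pl.Trainer(max_epochs=10, callbacks=[ImageSampler()])
trainer.fit(model)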
Example #4
def generate_samples(
    samples: List[List[torch.Tensor]],
    pl_module: pl.LightningModule,
    normalize: Callable,
) -> torch.Tensor:
    images = []
    for sample in samples:
        img_a, img_b = sample
        with torch.no_grad():
            pl_module.eval()
            for direction, img in zip(['ab', 'ba'], [img_a, img_b]):
                fake, rec = pl_module(
                    img.unsqueeze(0).to(device=pl_module.device),
                    direction,
                )
                fake = fake.squeeze(0)
                rec = rec.squeeze(0)
                imgs = [img, fake, rec]
                imgs = [normalize(img.detach().cpu()) for img in imgs]
                images.append(imgs)
            pl_module.train()

    grid = build_grid(images)

    return grid
Example #5
    def on_train_epoch_end(self, trainer: pl.Trainer,
                           pl_module: pl.LightningModule) -> None:
        pl_module.eval()

        with torch.no_grad():
            y_hat = pl_module(self._inputs)

        ans = y_hat.reshape(self._image.shape[0], self._image.shape[1],
                            self._image.shape[2])
        ans = 0.5 * (ans + 1.0)

        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(ans.detach().cpu().numpy())
        axarr[0].set_title("fit")
        axarr[1].imshow(self._image)
        axarr[1].set_title("original")

        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        image = PIL.Image.open(buf)
        image = transforms.ToTensor()(image)
        plt.close(f)  # release the figure so it does not accumulate across epochs

        trainer.logger.experiment.add_image("image",
                                            image,
                                            global_step=trainer.global_step)
        pl_module.train()  # restore training mode
Example #6
def test_submit(model: pl.LightningModule, test_loader, output_path):
    with torch.no_grad():
        model.eval()

        predicts = []
        cur_id = 0
        for nbatch, batch in enumerate(test_loader):
            batch = to_cuda(batch)
            outputs = model(*batch[:-1])
            if outputs['label_logits'].shape[-1] == 1:
                prob = torch.sigmoid(
                    outputs['label_logits'][:, 0]).detach().cpu().tolist()
            else:
                prob = torch.softmax(outputs['label_logits'],
                                     dim=-1)[:, 1].detach().cpu().tolist()
            sample_ids = batch[-1].cpu().tolist()

            for pb, sample_id in zip(prob, sample_ids):
                predicts.append({
                    'id': int(sample_id),
                    'proba': float(pb),
                    'label': int(pb > 0.5)
                })

        result_pd = pd.DataFrame.from_dict(predicts)
        result_pd.to_csv(output_path, index=False)
        model.train()
        return result_pd
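
The snippet above depends on a to_cuda helper that is not shown. A minimal sketch consistent with how it is used here (assumption: every batch element is a tensor, with the sample ids last):

from typing import List, Sequence

import torch


def to_cuda(batch: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    # move every tensor in the batch to the default CUDA device
    return [t.cuda(non_blocking=True) for t in batch]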
Example #7
    def on_train_epoch_end(self, trainer: pl.Trainer,
                           pl_module: pl.LightningModule) -> None:
        pl_module.eval()

        for e in range(self._samples):
            images_list = []
            this_image = (torch.rand(
                3,
                self._output_size[0],
                self._output_size[1],
                device=pl_module.device,
            ) * 2 - 1)
            for _ in range(self._frames):
                new_image = neighborhood_sample_generator(
                    model=pl_module,
                    image=this_image,
                    width=self._width,
                    outside=self._outside,
                    device=pl_module.device,
                    batch_size=self._batch_size,
                )
                this_image = 0.5 * (this_image + new_image)
                images_list.append(this_image.cpu().clone().detach())

            all_images = torch.stack(images_list, dim=0).detach()
            all_images = 0.5 * (all_images + 1)

            # make_grid already returns a CHW tensor that add_image accepts,
            # so the numpy round-trip is unnecessary
            grid = make_grid(all_images).cpu()

            trainer.logger.experiment.add_image(
                f"sample {e}",
                grid,
                global_step=trainer.global_step,
            )

        pl_module.train()  # restore training mode
Example #8
def draw_samples(
    samples: List[List[torch.Tensor]],
    mode: str,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    normalize: Callable,
):
    images = []
    for sample in samples:
        img_a, img_b = sample
        img_a = img_a.to(device=pl_module.device)
        with torch.no_grad():
            pl_module.eval()
            img_b_fakes = pl_module(img_a.unsqueeze(0))
            img_b_fake = img_b_fakes.squeeze(0)
            pl_module.train()
            imgs = [img_a, img_b_fake, img_b]
            imgs = [normalize(img.detach().cpu()) for img in imgs]
            images.append(imgs)

    grid = build_grid(images)

    str_title = f'{pl_module.__class__.__name__}_images_{mode}'
    trainer.logger.experiment.add_image(
        str_title,
        grid,
        global_step=trainer.global_step,
    )
Example #9
    def on_train_epoch_end(self, trainer, pl_module: LightningModule,
                           outputs: Any) -> None:
        self.env_loop.seed(self.seed)
        was_in_training_mode = pl_module.training
        if self.to_eval:
            pl_module.eval()

        returns: List[float] = []
        lengths: List[float] = []

        while len(returns) < self.n_eval_episodes:
            self.env_loop.reset()
            _lengths, _returns = self._eval_env_run()
            returns = returns + _returns
            lengths = lengths + _lengths

        returns_arr = np.array(returns)
        lengths_arr = np.array(lengths)

        if self.to_eval and was_in_training_mode:
            pl_module.train()

        for k, mapper in self.return_mappers.items():
            v: Any = mapper(returns_arr)
            pl_module.log(self.logging_prefix + "/" + k, v, prog_bar=False)

        for k, mapper in self.length_mappers.items():
            v: Any = mapper(lengths_arr)  # type: ignore
            pl_module.log(self.logging_prefix + "/" + k, v, prog_bar=False)

        if self.mean_return_in_progress_bar:
            pl_module.log("return", np.mean(returns), prog_bar=True)
Example #10
    def on_epoch_end(self, trainer: Trainer,
                     pl_module: LightningModule) -> None:
        if (trainer.current_epoch == 0
                or trainer.current_epoch % self.save_every_epochs == 0):
            x, y = next(iter(pl_module.val_dataloader()))
            x, y = x.to(pl_module.device), y.to(pl_module.device)
            pl_module.eval()
            with torch.no_grad():
                if isinstance(pl_module, VanillaVAE):
                    x_hat, p_x_z, z, q_z_x, p_z = pl_module._run_step(x)
                elif isinstance(pl_module, V3AE):
                    x_hat, p_x_z, λ, q_λ_z, p_λ, z, q_z_x, p_z = pl_module._run_step(
                        x)
            pl_module.train()  # restore training mode after sampling

            x_mean = batch_reshape(p_x_z.mean, pl_module.input_dims)
            x_var = batch_reshape(p_x_z.variance, pl_module.input_dims)

            fig = plt.figure()
            n_display = 4
            gs = fig.add_gridspec(4,
                                  n_display,
                                  width_ratios=[1] * n_display,
                                  height_ratios=[1, 1, 1, 1])
            gs.update(wspace=0, hspace=0)
            for n in range(n_display):
                for k in range(4):
                    ax = plt.subplot(gs[k, n])
                    ax = disable_ticks(ax)
                    # Original
                    if k == 0:
                        ax.imshow(x[n, :][0].cpu(),
                                  cmap="binary",
                                  vmin=0,
                                  vmax=1)
                    # Mean
                    elif k == 1:
                        ax.imshow(x_mean[n, :][0].cpu(),
                                  cmap="binary",
                                  vmin=0,
                                  vmax=1)
                    # Variance
                    elif k == 2:
                        ax.imshow(x_var[n, :][0].cpu(), cmap="binary")
                    # Sample
                    elif k == 3:
                        ax.imshow(x_hat[n, :][0].cpu(),
                                  cmap="binary",
                                  vmin=0,
                                  vmax=1)

            str_title = f"{pl_module.__class__.__name__}_images"
            trainer.logger.experiment.add_image(
                str_title,
                plot_to_image(fig),
                global_step=trainer.global_step,
                dataformats="CHW",
            )
Example #11
def get_misclassified(
    module: pl.LightningModule, data_loader: DataLoader, device: str = "cpu"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
        Gets the information of the misclassified data items in the given dataset.

        Args:
            module (pl.LightningModule): The module to use.
            test_loader (DataLoader): The ``DataLoader`` to use.
            device (str): A valid pytorch device string.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple consisting of the \
                image information, predicted, and actual class of the misclassified images.
        """

    # defining variables
    misclassified = []
    misclassified_pred = []
    misclassified_target = []

    module.to(device)

    # put the model to evaluation mode
    module.eval()

    with torch.no_grad():
        for data, target in data_loader:
            # casting data to device
            data, target = data.to(device), target.to(device)

            # forward prop
            output = module(data)

            # get the predicted class
            pred = output.argmax(dim=1, keepdim=True)

            # get the current misclassified in this batch
            list_misclassified = ~pred.eq(target.view_as(pred))
            batch_misclassified = data[list_misclassified.squeeze()]
            batch_mis_pred = pred[list_misclassified]
            batch_mis_target = target.view_as(pred)[list_misclassified]

            # add data to function variables
            misclassified.append(batch_misclassified)
            misclassified_pred.append(batch_mis_pred)
            misclassified_target.append(batch_mis_target)

    # group all the batched together
    misclassified = torch.cat(misclassified)
    misclassified_pred = torch.cat(misclassified_pred)
    misclassified_target = torch.cat(misclassified_target)

    return misclassified, misclassified_pred, misclassified_target
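
A short usage sketch for get_misclassified (module and data_loader are assumed to exist already; single-channel images and at least four misclassifications are assumptions made for the plot):

import matplotlib.pyplot as plt

images, preds, targets = get_misclassified(module, data_loader, device="cpu")

fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for ax, img, pred, target in zip(axes, images, preds, targets):
    ax.imshow(img.squeeze().cpu(), cmap="gray")
    ax.set_title(f"pred {int(pred)} / true {int(target)}")
    ax.axis("off")
plt.show()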
Example #12
    def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if not _TORCHVISION_AVAILABLE:
            return

        images, _ = next(iter(DataLoader(trainer.datamodule.mnist_val, batch_size=self.num_samples)))
        images_flattened = images.view(images.size(0), -1)

        # generate images
        with torch.no_grad():
            pl_module.eval()
            images_generated = pl_module(images_flattened.to(pl_module.device))
            pl_module.train()

        if trainer.current_epoch == 0:
            save_image(self._to_grid(images), f"grid_ori_{trainer.current_epoch}.png")
        save_image(self._to_grid(images_generated.reshape(images.shape)), f"grid_generated_{trainer.current_epoch}.png")
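
The _to_grid method is not shown; a plausible sketch built on torchvision (an assumption, not the snippet's actual implementation):

import torch
import torchvision


def _to_grid(self, images: torch.Tensor) -> torch.Tensor:
    # assumed: tile the batch into one grid image for save_image
    return torchvision.utils.make_grid(images, nrow=4, normalize=True)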
Example #13
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        # only log images every `logging_batch_interval` batches
        if (batch_idx + 1) % self.logging_batch_interval != 0:  # type: ignore[attr-defined]
            return

        # pick the last batch and logits
        x, y = batch
        try:
            logits = pl_module.last_logits
        except AttributeError as err:
            m = """please track the last_logits in the training_step like so:
                def training_step(...):
                    self.last_logits = your_logits
            """
            raise AttributeError(m) from err

        # only check when the model has opinions (i.e. max logit above min_logit_value)
        if logits.max() > self.min_logit_value:  # type: ignore[operator]
            # pick the top two confused probs
            (values, idxs) = torch.topk(logits, k=2,
                                        dim=1)  # type: ignore[arg-type]

            # care about only the ones that are at most eps close to each other
            eps = self.max_logit_difference
            mask = (values[:, 0] - values[:, 1]).abs() < eps

            if mask.sum() > 0:
                # pull out the ones we care about
                confusing_x = x[mask, ...]
                confusing_y = y[mask]

                mask_idxs = idxs[mask]

                pl_module.eval()
                self._plot(confusing_x, confusing_y, trainer, pl_module,
                           mask_idxs)
                pl_module.train()
Example #14
    def on_train_end(self, trainer: pl.Trainer,
                     pl_module: pl.LightningModule) -> None:
        r"""Called after training to export a model using TorchScript.

        Args:
            trainer:
                The :class:`pytorch_lightning.Trainer` instance

            pl_module:
                The :class:`pytorch_lightning.LightningModule` to export.
        """
        # check _device annotation is not ...
        # scripting will fail if _device type annotation is not overridden
        device = pl_module.__annotations__.get("_device")
        if device is None or device == ...:
            raise RuntimeError(
                "Please override type annotation for pl_module._device for scripting to work. "
                "Using _deivce: torch.device seems to work.")

        # get training state of model so it can be restored later
        training = pl_module.training
        if training:
            pl_module.eval()

        path = self.path if self.path is not None else self._get_default_save_path(
            trainer)

        if self.trace and self.sample_input is None:
            if not hasattr(pl_module, "example_input_array"):
                raise RuntimeError(
                    "Trace export was requested, but sample_input was not given and "
                    "module.example_input_array was not set.")
            self.sample_input = pl_module.example_input_array

        if self.trace:
            log.debug("Tracing %s", pl_module.__class__.__name__)
            script = self._get_trace(pl_module)
        else:
            log.debug("Scripting %s", pl_module.__class__.__name__)
            script = self._get_script(pl_module)
        torch.jit.save(script, path)
        log.info("Exported ScriptModule to %s", path)

        # restore training state
        if training:
            pl_module.train()
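
The RuntimeError above asks the user to override the _device annotation; a minimal sketch of a module that satisfies the check (only the annotation itself comes from the error message, the rest of the class body is illustrative):

import torch
import pytorch_lightning as pl


class ScriptableModule(pl.LightningModule):
    # overriding the annotation lets torch.jit.script resolve the attribute type
    _device: torch.device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x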
Example #15
    def on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        dim = (self.num_samples, pl_module.hparams.latent_dim)
        z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device)

        # generate images
        with torch.no_grad():
            pl_module.eval()
            images = pl_module(z)
            pl_module.train()

        grid = torchvision.utils.make_grid(images, nrow=self.nrows)
        str_title = f'{pl_module.__class__.__name__}_images_{trainer.current_epoch}'
        image_file = tensor_to_file_like_object(grid, img_size=self.image_size)
        self.tg_logger.write_image(image_file, caption=str_title)
Example #16
    def test_forward(self, model: pl.LightningModule, data: torch.Tensor, training: bool):
        r"""Calls ``model.forward()`` and tests that the output is not ``None``.

        Because of the size of some models, this test is only run when a GPU is available.
        """
        if isinstance(model, (LightningDistributedDataParallel, LightningDistributedModule)):
            pytest.skip()

        if torch.cuda.is_available():
            model = model.cuda()  # type: ignore
            data = data.cuda()

        if training:
            model.train()
        else:
            model.eval()

        output = model(data)

        assert output is not None
Example #17
    def test_validation_step(self, model: pl.LightningModule):
        r"""Runs a validation step based on the data returned from ``model.val_dataloader()``.
        Tests that the dictionary returned from ``validation_step()`` is as required by PyTorch
        Lightning.

        Because of the size of some models, this test is only run when a GPU is available.
        """
        if isinstance(model, LightningDistributedDataParallel):
            pytest.skip()

        check_overriden(model, "val_dataloader")
        check_overriden(model, "validation_step")

        dl = model.val_dataloader()
        # TODO this can't handle multiple optimizers
        batch = next(iter(dl))

        if torch.cuda.is_available():
            batch = [x.cuda() for x in batch]
            model = model.cuda()

        model.eval()
        output = model.validation_step(batch, 0)

        assert isinstance(output, dict)

        if "loss" in output.keys():
            assert isinstance(output["loss"], Tensor)
            assert output["loss"].shape == (1, )

        if "log" in output.keys():
            assert isinstance(output["log"], dict)

        if "progress_bar" in output.keys():
            assert isinstance(output["progress_bar"], dict)

        return batch, output
Example #18
    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        pl_module.eval()
        logger.info("Generating sample")
        all_images_list = generate_sample_radial(
            model=pl_module,
            features=self._features,
            targets=self._targets,
            iterations=self._iterations,
            image_size=self._image_size,
            rotations=self._rotations,
            batch_size=self._batch_size,
            device=pl_module.device,
        )

        all_images = torch.stack(all_images_list, dim=0).detach()
        all_images = 0.5 * (all_images + 1)

        # make_grid already returns a CHW tensor that add_image accepts directly
        grid = make_grid(all_images).cpu()

        trainer.logger.experiment.add_image(
            "img", grid, global_step=trainer.global_step
        )
        pl_module.train()  # restore training mode