Пример #1
0
    def on_train_epoch_end(self, trainer, pl_module: pl.LightningModule,
                           outputs: Any) -> None:
        """Print the epoch's mean loss and, on enough improvement, dump predictions.

        The mean of the ``"minimize"`` value of every training step is printed
        and compared with ``self.last_loss_checkpoint``. When the relative
        improvement exceeds ``self.loss_change_threshold``, the checkpoint is
        updated and the full ``tensor_x`` is run through ``pl_module.forward``
        in chunks of ``args.batch_size``. Each resulting image is then written
        as ``output/<name>_predicted_<epoch>.png``.

        NOTE(review): ``tensor_x``, ``tensor_y``, ``args``, ``num_images``,
        ``img_height`` and ``img_width`` are not defined here and are read from
        the enclosing (module) scope.

        Args:
            trainer: Lightning trainer; only ``current_epoch`` is read.
            pl_module: Module whose ``forward`` produces the predictions.
            outputs: Step outputs; ``outputs[0]`` is iterated and
                ``step[0]["minimize"]`` is read from each step.
        """
        step_losses = [step[0]["minimize"].item() for step in outputs[0]]
        avg_loss = np.mean(step_losses)
        print("Loss: {:.6f}".format(avg_loss))
        # Relative improvement over the loss at which images were last dumped.
        loss_change = 1 - avg_loss / self.last_loss_checkpoint

        if loss_change > self.loss_change_threshold:
            self.last_loss_checkpoint = avg_loss

            predicted_pixels = np.zeros(
                shape=tensor_y.shape,
                dtype=np.float32,
            )
            with torch.no_grad():
                for offset in range(0, tensor_x.shape[0], args.batch_size):
                    batch = tensor_x[offset:offset + args.batch_size]
                    pred = pl_module.forward(batch).cpu().numpy()
                    # The last chunk may be shorter than args.batch_size.
                    predicted_pixels[offset:offset + pred.shape[0]] = pred

            predicted_images = predicted_pixels.reshape(
                (num_images, img_height, img_width))
            # Image.save does not create parent directories; make sure the
            # target directory exists instead of crashing on the first save.
            os.makedirs("output", exist_ok=True)
            for img_idx in range(num_images):
                output_file_path = os.path.join(
                    "output",
                    "{0}_predicted_{1:04d}.png".format(
                        args.image_filenames[img_idx], trainer.current_epoch),
                )
                # Presumably predictions lie in [0, 1]; scale and clip them
                # into the uint8 range before saving.
                pixels = np.clip(predicted_images[img_idx] * 256, 0, 255)
                Image.fromarray(pixels.astype(np.uint8)).save(output_file_path)
Пример #2
0
    def setup(self, trainer: Trainer, pl_module: LightningModule,
              stage: str) -> None:
        """Override the model with a quantization-aware version on setup.

        This is the earliest place we can override this model which allows for
        appropriate behavior when restoring from checkpoints, as well as connecting
        to accelerators, etc.

        The model is only prepared once: the prepared model is stored on
        ``pl_module._prepared``, and the presence of that attribute makes any
        later call return immediately. In that case ``self.prepared`` is not
        (re)assigned.

        Args:
            trainer: The Lightning trainer (not used by this method).
            pl_module: The module to prepare. It is mutated in place: it gains a
                ``_prepared`` attribute and its ``forward`` is rebound to
                ``_quantized_forward``.
            stage: The setup stage name (not used by this method).
        """
        # Only prepare the model once.
        if hasattr(pl_module, "_prepared"):
            return

        # NOTE(review): `mode(..., training=True)` presumably switches the module
        # to training mode for the duration of the block and restores it on exit,
        # and `_deepcopy` presumably returns an independent copy, so `self.prepare`
        # does not mutate pl_module itself. Confirm both.
        # Keep the `_prepared` assignment inside the `with`: on an nn.Module,
        # assigning a Module attribute registers it as a submodule, so the mode
        # restore on exit may also apply to it.
        with mode(pl_module, training=True) as train:
            pl_module._prepared = self.prepare(_deepcopy(train),
                                               configs=self.qconfig_dicts)
        # Route pl_module's forward through the quantized implementation.
        pl_module.forward = MethodType(_quantized_forward, pl_module)
        # Also keep a handle on the prepared model on the callback itself.
        self.prepared = pl_module._prepared
Пример #3
0
 def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
     """Every ``snapshot_interval`` epochs, record per-example input gradients.

     Iterates the module's training dataloader, computes the NLL loss of each
     batch, and records the gradient of that loss with respect to each input
     example. Gradients are appended to ``self._x_idx_to_grads[global_index]``,
     and the example's label is stored in ``self._x_idx_to_y[global_index]``.

     The gradient is taken with ``torch.autograd.grad`` rather than
     ``loss.backward()``, so the model parameters' ``.grad`` buffers are not
     modified by this snapshot pass.

     Args:
         trainer: Lightning trainer; only ``current_epoch`` is read.
         pl_module: Module providing ``train_dataloader()`` and ``forward``.
     """
     import torch  # local import: only needed for torch.autograd.grad

     if trainer.current_epoch % self.snapshot_interval == 0:
         train_dataloader = pl_module.train_dataloader()
         for x, y, global_example_indexes in tqdm(
                 train_dataloader,
                 desc="Snapshotting: ",
                 file=sys.stdout,
                 position=0,
                 leave=True,
         ):
             x.requires_grad_(True)
             logits = pl_module.forward(x)
             # F.nll_loss uses the default 'mean' reduction, so each
             # per-example gradient is scaled by 1 / batch size.
             loss = F.nll_loss(logits, y)
             # d(loss)/dx only; unlike loss.backward(), this does not
             # accumulate gradients into the model parameters.
             (x_grad,) = torch.autograd.grad(loss, x)
             for global_index, grad, class_label in zip(
                     global_example_indexes.tolist(), x_grad, y):
                 self._x_idx_to_grads.setdefault(global_index,
                                                 []).append(grad)
                 self._x_idx_to_y[global_index] = class_label
Пример #4
0
    def on_epoch_end(self, trainer: pl.Trainer,
                     pl_module: pl.LightningModule) -> None:
        """Every ``self.n`` epochs, log generated images and save G/D weights.

        Images are generated from ``pl_module.global_z_for_validation`` and
        logged through ``trainer.logger.experiment`` together with the epoch
        number. The generator and discriminator ``state_dict``s are then saved
        to ``{self.file_path}/{self.filename_prefix}_epoch_{epoch}.ckpt``.

        Args:
            trainer: Lightning trainer; ``current_epoch`` and ``logger`` are used.
            pl_module: Module exposing ``forward``, ``global_z_for_validation``,
                ``generator`` and ``discriminator``.
        """
        epoch = trainer.current_epoch
        if epoch % self.n == 0:

            # Generate images; no autograd graph is needed for logging only.
            with torch.no_grad():
                fake_images = pl_module.forward(
                    pl_module.global_z_for_validation)
            # Log the images just generated (previously a stale
            # `pl_module.fake_images` attribute was logged instead).
            trainer.logger.experiment.log({
                "images": [wandb.Image(fake_images, caption="fake")],
                "epoch": epoch,
            })
            # Save generator and discriminator weights.
            filename = f"{self.filename_prefix}_epoch_{epoch}.ckpt"
            ckpt_path = f"{self.file_path}/{filename}"
            torch.save(
                {
                    "generator": pl_module.generator.state_dict(),
                    "discriminator": pl_module.discriminator.state_dict(),
                },
                ckpt_path,
            )