Code Example #1
File: train.py Project: iver56/image-regression
    def on_train_epoch_end(self, trainer, pl_module: pl.LightningModule,
                           outputs: Any) -> None:
        # Average the per-step losses that Lightning collected for this epoch.
        step_losses = [step[0]["minimize"].item() for step in outputs[0]]
        avg_loss = np.mean(step_losses)
        print("Loss: {:.6f}".format(avg_loss))
        loss_change = 1 - avg_loss / self.last_loss_checkpoint

        # Only render predictions when the loss has improved by more than the
        # threshold since the last checkpoint. tensor_x, tensor_y, args,
        # num_images, img_height and img_width are module-level globals in the
        # project's train.py.
        if loss_change > self.loss_change_threshold:
            self.last_loss_checkpoint = avg_loss

            predicted_pixels = np.zeros(shape=tensor_y.shape, dtype=np.float32)
            with torch.no_grad():
                # Run inference in batches and collect the predicted pixels.
                for offset in range(0, tensor_x.shape[0], args.batch_size):
                    batch = tensor_x[offset:offset + args.batch_size]
                    pred = pl_module.forward(batch).cpu().numpy()
                    predicted_pixels[offset:offset + pred.shape[0]] = pred

            # Reshape the flat pixel predictions back into images and save
            # each one as an 8-bit grayscale PNG.
            predicted_images = predicted_pixels.reshape(
                (num_images, img_height, img_width))
            for img_idx in range(num_images):
                output_file_path = os.path.join(
                    "output",
                    "{0}_predicted_{1:04d}.png".format(
                        args.image_filenames[img_idx], trainer.current_epoch),
                )
                pixel_values = np.clip(predicted_images[img_idx] * 256, 0, 255)
                Image.fromarray(pixel_values.astype(np.uint8)).save(
                    output_file_path)
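
Note that the `outputs` argument follows the hook signature of older PyTorch Lightning releases, which still passed the epoch's collected step outputs to `on_train_epoch_end`. A minimal sketch of wiring such a callback into a trainer, assuming a hypothetical `ImageProgressCallback` class containing the method above and a `model` LightningModule:

    import pytorch_lightning as pl

    # ImageProgressCallback is a hypothetical name for the callback class
    # holding the on_train_epoch_end hook shown above.
    callback = ImageProgressCallback()
    trainer = pl.Trainer(max_epochs=100, callbacks=[callback])
    trainer.fit(model)  # model is the pl.LightningModule being trained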
Code Example #2
    def setup(self, trainer: Trainer, pl_module: LightningModule,
              stage: str) -> None:
        """Override the model with a quantization-aware version on setup.

        This is the earliest place we can override this model, which allows
        for appropriate behavior when restoring from checkpoints, as well as
        connecting to accelerators, etc.

        The model is only prepared once.
        """
        # Only prepare the model once.
        if hasattr(pl_module, "_prepared"):
            return

        # Prepare a deep copy of the module in training mode so the original
        # weights are left untouched.
        with mode(pl_module, training=True) as train:
            pl_module._prepared = self.prepare(_deepcopy(train),
                                               configs=self.qconfig_dicts)
        # Route forward() through the prepared (fake-quantized) model.
        pl_module.forward = MethodType(_quantized_forward, pl_module)
        self.prepared = pl_module._prepared
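
The `mode` helper is not shown in this excerpt. Judging from its use here, it likely forces the module into the requested train/eval state for the duration of the block and restores the previous state afterward; a minimal sketch of such a context manager (an assumption, not the project's actual implementation):

    from contextlib import contextmanager

    import torch.nn as nn

    @contextmanager
    def mode(module: nn.Module, training: bool):
        # Temporarily force the module into train (or eval) mode,
        # then restore whatever mode it was in before.
        original = module.training
        module.train(training)
        try:
            yield module
        finally:
            module.train(original)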
Code Example #3
File: vog.py Project: mttcnnff/vog-optimizer
    def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        if trainer.current_epoch % self.snapshot_interval == 0:
            train_dataloader = pl_module.train_dataloader()
            for x, y, global_example_indexes in tqdm(
                    train_dataloader,
                    desc="Snapshotting: ",
                    file=sys.stdout,
                    position=0,
                    leave=True,
            ):
                # Track gradients with respect to the inputs so that x.grad
                # holds one gradient snapshot per example after backward().
                x.requires_grad = True
                logits = pl_module.forward(x)
                loss = F.nll_loss(logits, y)
                loss.backward()
                for global_index, grad, class_label in zip(
                        global_example_indexes.tolist(), x.grad, y):
                    # Accumulate this epoch's input gradient for the example
                    # and remember its class label.
                    self._x_idx_to_grads.setdefault(global_index,
                                                    []).append(grad)
                    self._x_idx_to_y[global_index] = class_label
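
This callback only gathers raw gradient snapshots. To turn them into variance-of-gradients (VOG) scores, each example's stack of snapshots still has to be reduced; a minimal sketch of that reduction with a hypothetical `compute_vog_scores` helper, assuming the usual definition of per-pixel variance across snapshots averaged over pixels (not code from the source project):

    import torch

    def compute_vog_scores(x_idx_to_grads):
        # Hypothetical helper: collapse each example's list of input-gradient
        # snapshots into a single variance-of-gradients score.
        scores = {}
        for idx, grads in x_idx_to_grads.items():
            stacked = torch.stack(grads)               # (num_snapshots, *input_shape)
            per_pixel_var = stacked.var(dim=0)         # variance across snapshots
            scores[idx] = per_pixel_var.mean().item()  # average over all pixels
        return scores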
Code Example #4
    def on_epoch_end(self, trainer: pl.Trainer,
                     pl_module: pl.LightningModule) -> None:
        epoch = trainer.current_epoch
        if epoch % self.n == 0:
            # Generate images from the fixed validation latent vector and log
            # them to Weights & Biases. (The original logged the unused
            # pl_module.fake_images attribute instead of the freshly generated
            # fake_images, which is corrected here.)
            fake_images = pl_module.forward(pl_module.global_z_for_validation)
            trainer.logger.experiment.log({
                "images": [wandb.Image(fake_images, caption="fake")],
                "epoch": epoch,
            })
            # Save the generator and discriminator weights.
            filename = f"{self.filename_prefix}_epoch_{epoch}.ckpt"
            ckpt_path = f"{self.file_path}/{filename}"
            torch.save(
                {
                    "generator": pl_module.generator.state_dict(),
                    "discriminator": pl_module.discriminator.state_dict(),
                },
                ckpt_path,
            )
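
Since the checkpoint stores plain state dicts rather than a full Lightning checkpoint, it would be reloaded with `torch.load` and `load_state_dict`; a minimal sketch, assuming hypothetical `Generator` and `Discriminator` classes matching the module's submodules and an illustrative file path:

    import torch

    checkpoint = torch.load("checkpoints/gan_epoch_10.ckpt")  # illustrative path
    generator = Generator()          # hypothetical class matching pl_module.generator
    generator.load_state_dict(checkpoint["generator"])
    discriminator = Discriminator()  # hypothetical class matching pl_module.discriminator
    discriminator.load_state_dict(checkpoint["discriminator"])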