Example #1
File: trainer.py Project: kahartma/eeggan
    def train_batch(self, engine, batch) -> BatchOutput:
        with torch.no_grad():
            batch_real: Data[torch.Tensor] = Data(batch[0], batch[1], batch[2])

            # Sample latent codes and fake labels on the real batch's device,
            # then generate a matching batch of fakes.
            latent, y_fake, y_onehot_fake = to_device(
                batch_real.X.device,
                *self.generator.create_latent_input(self.rng,
                                                    len(batch_real.X)))
            X_fake = self.generator(latent, y=y_fake, y_onehot=y_onehot_fake)
            batch_fake: Data[torch.Tensor] = Data(X_fake, y_fake,
                                                  y_onehot_fake)

        # Detach everything so each training sub-step sees constant inputs.
        batch_real, batch_fake, latent = detach_all(batch_real, batch_fake,
                                                    latent)
        loss_d = self.train_discriminator(batch_real, batch_fake, latent)
        batch_real = detach_all(batch_real)
        loss_g = self.train_generator(batch_real)

        # Draw a fresh fake batch after the updates for reporting/metrics.
        with torch.no_grad():
            latent, y_fake, y_onehot_fake = to_device(
                batch_real.X.device,
                *self.generator.create_latent_input(self.rng,
                                                    len(batch_real.X)))
            X_fake = self.generator(latent, y=y_fake, y_onehot=y_onehot_fake)
            batch_fake: Data[torch.Tensor] = Data(X_fake, y_fake,
                                                  y_onehot_fake)

        batch_real, batch_fake, latent = detach_all(batch_real, batch_fake,
                                                    latent)
        return BatchOutput(engine.state.iteration, engine.state.epoch,
                           batch_real, batch_fake, latent, loss_d, loss_g)
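
The (engine, batch) signature of train_batch matches pytorch-ignite's process-function convention, so one plausible way to drive it is through an ignite Engine. A minimal usage sketch, not taken from the project; `trainer` and `dataloader` are hypothetical names:

from ignite.engine import Engine

# Hypothetical setup: `trainer` is an instance of the class above and
# `dataloader` yields (X, y, y_onehot) batches.
engine = Engine(trainer.train_batch)  # each iteration returns a BatchOutput
engine.run(dataloader, max_epochs=100)
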
Example #2
File: metrics.py Project: kahartma/eeggan
    def update(self, batch_output: BatchOutput) -> None:
        with torch.no_grad():
            # Upsample real and fake signals along the time axis so they
            # match the input resolution expected by the deep4 classifiers.
            X_real, = to_device(
                batch_output.batch_real.X.device,
                Tensor(
                    upsample(batch_output.batch_real.X.data.cpu().numpy(),
                             self.upsample_factor,
                             axis=2)))
            X_real = X_real[:, :, :, None]
            X_fake, = to_device(
                batch_output.batch_fake.X.device,
                Tensor(
                    upsample(batch_output.batch_fake.X.data.cpu().numpy(),
                             self.upsample_factor,
                             axis=2)))
            X_fake = X_fake[:, :, :, None]
            epoch = batch_output.i_epoch
            dists = []
            # Frechet distance between activation statistics of real and
            # fake data, averaged over the ensemble of deep4 models.
            for deep4 in self.deep4s:
                mu_real, sig_real = calculate_activation_statistics(
                    deep4(X_real)[0])
                mu_fake, sig_fake = calculate_activation_statistics(
                    deep4(X_fake)[0])
                dist = calculate_frechet_distances(
                    mu_real[None, :, :], sig_real[None, :, :],
                    mu_fake[None, :, :], sig_fake[None, :, :]).item()
                dists.append(dist)
            self.append((epoch, (np.mean(dists).item(), np.std(dists).item())))
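
calculate_frechet_distances is project code and not shown in these examples. For reference, a self-contained NumPy/SciPy sketch of the standard (FID-style) Fréchet distance it presumably computes between two Gaussians:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sig1, mu2, sig2):
    # Squared Frechet distance between N(mu1, sig1) and N(mu2, sig2):
    # d^2 = ||mu1 - mu2||^2 + Tr(sig1 + sig2 - 2 * sqrtm(sig1 @ sig2))
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sig1 @ sig2)
    if np.iscomplexobj(covmean):  # sqrtm may return tiny imaginary parts
        covmean = covmean.real
    return float(diff @ diff + np.trace(sig1 + sig2 - 2.0 * covmean))
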
Example #3
    def train_generator(self, batch_real: Data[torch.Tensor]):
        # Reset gradients and put only the generator in training mode.
        self.generator.zero_grad()
        self.optim_generator.zero_grad()
        self.generator.train(True)
        self.discriminator.train(False)

        with torch.no_grad():
            latent, y_fake, y_onehot_fake = to_device(
                batch_real.X.device,
                *self.generator.create_latent_input(self.rng,
                                                    len(batch_real.X)))
            latent, y_fake, y_onehot_fake = detach_all(latent, y_fake,
                                                       y_onehot_fake)

        X_fake = self.generator(latent.requires_grad_(False),
                                y=y_fake.requires_grad_(False),
                                y_onehot=y_onehot_fake.requires_grad_(False))

        batch_fake: Data[torch.Tensor] = Data(X_fake, y_fake, y_onehot_fake)
        fx_fake = sigmoid(
            self.discriminator(
                batch_fake.X.requires_grad_(True),
                y=batch_fake.y.requires_grad_(True),
                y_onehot=batch_fake.y_onehot.requires_grad_(True)))
        # Non-saturating GAN loss: push the discriminator's output on fakes
        # toward 1, then step only the generator's optimizer.
        loss = self.loss(fx_fake, torch.ones_like(fx_fake))
        loss.backward()

        self.optim_generator.step()

        return loss.item()
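
A toy, self-contained sketch of the same objective with hypothetical stand-in modules, to show the gradient flow: the generator is updated toward making sigmoid(D(G(z))) equal 1.

import torch
from torch import nn

g = nn.Linear(16, 8)                      # stands in for self.generator
d = nn.Linear(8, 1)                       # stands in for self.discriminator
opt_g = torch.optim.Adam(g.parameters())

z = torch.randn(4, 16)                    # latent batch
fx_fake = torch.sigmoid(d(g(z)))
# Non-saturating generator loss: BCE against an all-ones target.
loss = nn.functional.binary_cross_entropy(fx_fake, torch.ones_like(fx_fake))
loss.backward()  # d also receives gradients, but only opt_g steps,
opt_g.step()     # mirroring the method above.
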
Example #4
File: metrics.py Project: kahartma/eeggan
    def update(self, batch_output: BatchOutput) -> None:
        X_fake, = to_device(
            batch_output.batch_fake.X.device,
            Tensor(
                upsample(batch_output.batch_fake.X.data.cpu().numpy(),
                         self.upsample_factor,
                         axis=2)))
        X_fake = X_fake[:, :, :, None]
        epoch = batch_output.i_epoch
        accuracies = []
        for deep4 in self.deep4s:
            with torch.no_grad():
                preds: Tensor = deep4(X_fake)[1].squeeze()
                class_pred = preds.argmax(dim=1)
                accuracy = (class_pred == batch_output.batch_fake.y).type(
                    torch.float).mean()
                accuracies.append(accuracy.item())
        self.append(
            (epoch, (np.mean(accuracies).item(), np.std(accuracies).item())))
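
A worked example of the accuracy computation above, with illustrative values:

import torch

preds = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
class_pred = preds.argmax(dim=1)                         # tensor([1, 0, 1])
y = torch.tensor([1, 1, 1])
accuracy = (class_pred == y).type(torch.float).mean()    # tensor(0.6667)
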
Example #5
File: metrics.py Project: kahartma/eeggan
    def update(self, batch_output: BatchOutput) -> None:
        X_fake, = to_device(
            batch_output.batch_fake.X.device,
            Tensor(
                upsample(batch_output.batch_fake.X.data.cpu().numpy(),
                         self.upsample_factor,
                         axis=2)))
        X_fake = X_fake[:, :, :, None]
        epoch = batch_output.i_epoch
        score_means = []
        score_stds = []
        # Inception score of the fakes under each deep4 classifier,
        # averaged over the ensemble.
        for deep4 in self.deep4s:
            with torch.no_grad():
                preds = deep4(X_fake)[1]
                preds = logsoftmax_act_to_softmax(preds)
                score_mean, score_std = calculate_inception_score(
                    preds, self.splits, self.repetitions)
            score_means.append(score_mean)
            score_stds.append(score_std)
        self.append(
            (epoch, (np.mean(score_means).item(), np.mean(score_stds).item())))
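
calculate_inception_score is project code and not shown here. A minimal NumPy sketch of the usual definition it presumably follows, ignoring the splits/repetitions arguments (which typically average the score over random subsets):

import numpy as np

def inception_score(probs):
    # probs: (N, C) array of softmax outputs p(y|x), assumed strictly > 0.
    # IS = exp(mean over x of KL(p(y|x) || p(y))), with p(y) the marginal.
    py = probs.mean(axis=0, keepdims=True)
    kl = (probs * (np.log(probs) - np.log(py))).sum(axis=1)
    return float(np.exp(kl.mean()))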