コード例 #1
0
class Visualiser(object):
    """Collect training losses and submit evaluation/plotting jobs to a queue.

    Intended to be used as a context manager: ``__enter__`` creates the
    ``JobQueue`` that ``test``, ``test_autoencoder`` and ``plot_training``
    submit to, and ``__exit__`` waits for it with ``join()``.

    :param base_directory: Directory passed to every submitted job; it is
        created on demand by ``_get_directory``.
    """

    def __init__(self, base_directory):
        # Per-step loss histories; step_autoencoder also appends to g_loss.
        self.d_loss_real = []
        self.d_loss_fake = []
        self.g_loss = []
        self.base_directory = base_directory

    def __enter__(self):
        # One worker process for the jobs submitted below.
        self.queue = JobQueue(num_processes=1)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Wait for all submitted jobs. Returning None lets any exception raised
        # inside the ``with`` body propagate to the caller.
        self.queue.join()

    def _get_directory(self):
        """Create ``base_directory`` if it does not exist and return it."""
        os.makedirs(self.base_directory, exist_ok=True)
        return self.base_directory

    def step(self, d_loss_real, d_loss_fake, g_loss):
        """Record the discriminator (real/fake) and generator losses of one step."""
        self.d_loss_real.append(d_loss_real)
        self.d_loss_fake.append(d_loss_fake)
        self.g_loss.append(g_loss)

    def step_autoencoder(self, loss):
        """Record one autoencoder training loss (stored alongside ``g_loss``)."""
        self.g_loss.append(loss)

    def test(self, epoch, size_first, discriminator, generator, noise, real):
        """Evaluate generator and discriminator and submit a ``GANTest`` job.

        The generator is run on ``noise``. The discriminator is run on both the
        generated output and ``real``. All results are converted to NumPy
        arrays before submission.

        Both models are put in eval mode for the duration of the call and
        are always switched back to training mode afterwards, even if a
        forward pass or the submission raises.
        """
        generator.eval()
        discriminator.eval()
        try:
            out = generator(noise)
            self.queue.submit(
                GANTest(directory=self._get_directory(),
                        epoch=epoch,
                        size_first=size_first,
                        gen_out=out.cpu().data.numpy(),
                        real_out=real.cpu().data.numpy(),
                        discriminator_out=discriminator(out).cpu().data.numpy(),
                        discriminator_real=discriminator(real).cpu().data.numpy()))
        finally:
            # Restore training mode so a failed evaluation cannot silently
            # leave the models in eval mode for the rest of training.
            generator.train()
            discriminator.train()

    def test_autoencoder(self, epoch, generator, real, num_samples=10):
        """Reconstruct the first ``num_samples`` inputs and submit an ``AutoEncoderTest``.

        :param epoch: Epoch number passed through to the job.
        :param generator: Model whose output is compared against ``real``.
        :param real: Batch to reconstruct; only ``real[:num_samples]`` is used.
        :param num_samples: Number of leading samples to evaluate (default 10).

        The generator is always returned to training mode, even if the
        forward pass or the submission raises.
        """
        generator.eval()
        try:
            # Slice once so the model input and the reference are the same data.
            sample = real[:num_samples]
            out = generator(sample)
            self.queue.submit(
                AutoEncoderTest(directory=self._get_directory(),
                                epoch=epoch,
                                out=out.cpu().data.numpy(),
                                real=sample.cpu().data.numpy()))
        finally:
            generator.train()

    def plot_training(self, epoch):
        """Submit a ``PlotLearning`` job with the losses recorded so far.

        Copies of the loss lists are submitted so that subsequent ``step``
        calls cannot change what the job sees. This matters if JobQueue
        serialises the job after ``submit`` returns, as
        ``multiprocessing.Queue.put`` does (TODO confirm for JobQueue).
        """
        self.queue.submit(
            PlotLearning(directory=self._get_directory(),
                         epoch=epoch,
                         d_loss_real=list(self.d_loss_real),
                         d_loss_fake=list(self.d_loss_fake),
                         g_loss=list(self.g_loss)))
コード例 #2
0
ファイル: plots.py プロジェクト: lao19881213/rfi_ml
                # Create merged spectrograms for this p
                merged = self.merge_spectrograms(spectrograms)
                merged_normalised = self.merge_spectrograms(
                    spectrograms, normalise_local=True)
                self.save_spectrogram(merged, "merged",
                                      "spectrogram_merged.png")
                self.save_spectrogram(merged_normalised,
                                      "merged local normalisation",
                                      "spectrogram_merged_normalised.png")


if __name__ == "__main__":
    queue = JobQueue(8)

    # Each .lba recording gets its own LBAPlotter job; the queue (presumably
    # 8 worker processes - confirm against JobQueue) runs them in parallel.
    num_samples = 102400
    # NOTE(review): the Pa file uses the prefix "vt255ae" while the others use
    # "v255ae" - confirm this matches the real filename on disk.
    plot_jobs = (
        ("../data/v255ae_At_072_060000.lba", "./At_out/"),
        ("../data/v255ae_Mp_072_060000.lba", "./Mp_out/"),
        ("../data/vt255ae_Pa_072_060000.lba", "./Pa_out/"),
    )
    for lba_path, out_dir in plot_jobs:
        queue.submit(LBAPlotter(lba_path, out_dir, num_samples=num_samples))

    queue.join()