Example #1
def run(opt):
    img_shape = (1, opt.img_size, opt.img_size)  # one channel only
    global LOGGER
    global RUN

    generator = Generator(opt.batch_size)
    discriminator = Discriminator(opt.batch_size)
    adversarial_loss = torch.nn.BCELoss()

    # Initialize optimizers for generator and discriminator
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.g_lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.d_lr)

    # Map to CUDA if necessary
    if CUDA:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Create checkpoint handler and load state if required
    current_epoch = 0
    checkpoint_g = Checkpoint(CWD, generator, optimizer_g)
    checkpoint_d = Checkpoint(CWD, discriminator, optimizer_d)

    # Randomize weights
    # self.G.model.weight_init(mean=0, std=0.5)
    # self.D.model.weight_init(mean=0, std=0.5)

    if opt.resume:
        RUN, current_epoch = checkpoint_g.load()
        _, _ = checkpoint_d.load()
        LOGGER = Logger(CWD, RUN)
        print('Loaded models from disk. Starting at epoch {}.'.format(
            current_epoch + 1))
    else:
        LOGGER = Logger(CWD, RUN)

    # Configure data loader
    mnist_loader = data_loader.mnist(opt, True)

    for epoch in range(current_epoch, opt.n_epochs):
        for i, imgs in enumerate(mnist_loader):

            # Adversarial ground truths
            valid = Variable(TENSOR(imgs.shape[0], 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(TENSOR(imgs.shape[0], 1).fill_(0.0),
                            requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_g.zero_grad()

            # Sample noise as generator input
            z = Variable(
                torch.randn(imgs.shape[0], 100, 1, 1).type(TENSOR))
            #z = Variable(TENSOR(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

            # Generate a batch of images
            fake_images = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(fake_images), valid)

            g_loss.backward()
            optimizer_g.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_d.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            #if (previous_d_G_z > threshold and previous_d_z < threshold2) or epoch_range > epoch:
            real_scores = discriminator(real_imgs)
            real_loss = adversarial_loss(real_scores, valid)
            fake_scores = discriminator(fake_images.detach())
            fake_loss = adversarial_loss(fake_scores, fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_d.step()

            batches_done = epoch * len(mnist_loader) + i + 1
            if batches_done % opt.sample_interval == 0:
                LOGGER.log_generated_sample(fake_images, batches_done)

                LOGGER.log_batch_statistics(epoch, opt.n_epochs, i + 1,
                                            len(mnist_loader), d_loss, g_loss,
                                            real_scores, fake_scores)

                LOGGER.log_tensorboard_basic_data(g_loss, d_loss, real_scores,
                                                  fake_scores, batches_done)

                if opt.log_details:
                    LOGGER.save_image_grid(real_imgs, fake_images,
                                           batches_done)
                    LOGGER.log_tensorboard_parameter_data(
                        discriminator, generator, batches_done)
        # -- Save model checkpoints after each epoch -- #
        checkpoint_g.save(RUN, epoch)
        checkpoint_d.save(RUN, epoch)
    LOGGER.close_writers()
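
All six examples read the module-level names CUDA, TENSOR, CWD, LOGGER and RUN without defining them. The following is a minimal sketch of plausible definitions, inferred only from how the snippets use these names; the original module may define them differently.

import os

import torch

# Hedged sketch: globals assumed by the run() functions in this section.
CUDA = torch.cuda.is_available()
TENSOR = torch.cuda.FloatTensor if CUDA else torch.FloatTensor  # default float tensor type
CWD = os.path.dirname(os.path.abspath(__file__))  # root for checkpoints and logs
LOGGER = None  # created inside run()
RUN = 0  # run id; overwritten when resuming from a checkpoint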
Example #2
def run(args: Namespace):
    global LOGGER
    global RUN

    img_shape = (1, args.img_size, args.img_size)

    # noinspection PyMethodMayBeStatic
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()

            def block(in_feat, out_feat, normalize=True):
                layers = [nn.Linear(in_feat, out_feat)]
                if normalize:
                    layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers

            self.model = nn.Sequential(
                *block(args.latent_dim, 128, normalize=False),
                *block(128, 256), *block(256, 512), *block(512, 1024),
                nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())

        def forward(self, z_batch):
            img = st_heaviside.straight_through(self.model(z_batch))

            return img.view(img.size(0), *img_shape)

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),
                                       nn.LeakyReLU(0.2, inplace=True),
                                       nn.Linear(512, 256),
                                       nn.LeakyReLU(0.2, inplace=True),
                                       nn.Linear(256, 1), nn.Sigmoid())

        def forward(self, img):
            img_flat = img.view(img.shape[0], -1)
            validity = self.model(img_flat)
            return validity

    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize optimizers for generator and discriminator
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=args.g_lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=args.d_lr)

    # Map to CUDA if necessary
    if CUDA:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_xavier)
    discriminator.apply(weights_init_xavier)

    # Create checkpoint handler and load state if required
    current_epoch = 0
    checkpoint_g = Checkpoint(CWD, generator, optimizer_g)
    checkpoint_d = Checkpoint(CWD, discriminator, optimizer_d)
    if args.resume:
        RUN, current_epoch = checkpoint_g.load()
        _, _ = checkpoint_d.load()
        LOGGER = Logger(CWD, RUN, args)
        print('Loaded models from disk. Starting at epoch {}.'.format(
            current_epoch + 1))
    else:
        LOGGER = Logger(CWD, RUN, args)

    # Configure data loader
    opts = {'binary': True, 'crop': 20}
    mnist_loader = data_loader.load(args, opts)

    for epoch in range(current_epoch, args.n_epochs):
        for i, imgs in enumerate(mnist_loader):

            # Adversarial ground truths with noise
            valid = 0.8 + torch.rand(imgs.shape[0], 1).type(TENSOR) * 0.3
            valid = Variable(valid, requires_grad=False)
            fake = torch.rand(imgs.shape[0], 1).type(TENSOR) * 0.3
            fake = Variable(fake, requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_g.zero_grad()

            # Sample noise as generator input
            z = Variable(
                torch.randn(imgs.shape[0], args.latent_dim).type(TENSOR))

            # Generate a batch of images
            fake_images = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(fake_images), valid)

            g_loss.backward()
            optimizer_g.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_d.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_scores = discriminator(real_imgs)
            real_loss = adversarial_loss(real_scores, valid)
            fake_scores = discriminator(fake_images.detach())
            fake_loss = adversarial_loss(fake_scores, fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_d.step()

            batches_done = epoch * len(mnist_loader) + i + 1
            if batches_done % args.sample_interval == 0:
                LOGGER.log_generated_sample(fake_images, batches_done)

                LOGGER.log_batch_statistics(epoch, args.n_epochs, i + 1,
                                            len(mnist_loader), d_loss, g_loss,
                                            real_scores, fake_scores)

                LOGGER.log_tensorboard_basic_data(g_loss, d_loss, real_scores,
                                                  fake_scores, batches_done)

                if args.log_details:
                    if batches_done == args.sample_interval:
                        LOGGER.save_image_grid(real_imgs, fake_images,
                                               batches_done)
                    else:
                        LOGGER.save_image_grid(None, fake_images, batches_done)
        # -- Save model checkpoints after each epoch -- #
        checkpoint_g.save(RUN, epoch)
        checkpoint_d.save(RUN, epoch)
    LOGGER.close_writers()
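
Example #2 binarizes the Tanh output of its generator with st_heaviside.straight_through, a module that is not shown here. The following is a common formulation of a straight-through Heaviside estimator, given as an assumed sketch rather than the project's actual implementation.

import torch

class _STHeaviside(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Hard-threshold at zero in the forward pass.
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient through unchanged.
        return grad_output

def straight_through(x):
    return _STHeaviside.apply(x)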
Example #3
def run(args: argparse.Namespace):
    global LOGGER
    global RUN

    img_shape = (1, args.img_size, args.img_size)

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()

            def block(in_feat, out_feat, normalize=True):
                layers = [nn.Linear(in_feat, out_feat)]
                if normalize:
                    layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
                layers.append(nn.LeakyReLU(0.2, inplace=True))

                return layers

            self.model = nn.Sequential(
                *block(args.latent_dim, 128, normalize=False),
                *block(128, 256),
                *block(256, 512),
                *block(512, 1024),
                nn.Linear(1024, int(np.prod(img_shape))),
            )

            self.out = nn.LogSigmoid()

        def forward(self, z_batch):
            linear = self.model(z_batch)

            white_prob = self.out(linear).view(z_batch.size(0),
                                               args.img_size**2, 1)
            black_prob = self.out(-linear).view(z_batch.size(0),
                                                args.img_size**2, 1)
            probs = torch.cat([black_prob, white_prob], dim=-1)
            img = st_gumbel_softmax.straight_through(probs, args.temp, True)

            return img.view(img.shape[0], *img_shape)

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),
                                       nn.LeakyReLU(0.2, inplace=True),
                                       nn.Linear(512, 256),
                                       nn.LeakyReLU(0.2, inplace=True),
                                       nn.Linear(256, 1))

        def forward(self, img):
            img_flat = img.view(img.shape[0], -1)
            validity = self.model(img_flat)

            return validity

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Optimizers
    optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=args.g_lr)
    optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=args.d_lr)

    if CUDA:
        generator.cuda()
        discriminator.cuda()

    # Initialize weights
    generator.apply(weights_init_xavier)
    discriminator.apply(weights_init_xavier)

    # Create checkpoint handler and load state if required
    current_epoch = 0
    checkpoint_g = Checkpoint(CWD, generator, optimizer_g)
    checkpoint_d = Checkpoint(CWD, discriminator, optimizer_d)
    if args.resume:
        RUN, current_epoch = checkpoint_g.load()
        _, _ = checkpoint_d.load()
        LOGGER = Logger(CWD, RUN, args)
        print('Loaded models from disk. Starting at epoch {}.'.format(
            current_epoch + 1))
    else:
        LOGGER = Logger(CWD, RUN, args)

    # Configure data loader
    opts = {
        'binary': True,
    }
    batched_data = data_loader.load(args, opts)

    # ----------
    #  Training
    # ----------

    for epoch in range(current_epoch, args.n_epochs):
        for i, imgs in enumerate(batched_data):
            # Configure input
            real_imgs = Variable(imgs.type(TENSOR))

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_d.zero_grad()

            # Sample noise as generator input
            z = Variable(
                torch.randn(imgs.shape[0], args.latent_dim).type(TENSOR))

            # Generate a batch of images
            fake_images = generator(z).detach()
            # Adversarial loss
            loss_d = -torch.mean(discriminator(real_imgs)) + torch.mean(
                discriminator(fake_images))

            loss_d.backward()
            optimizer_d.step()

            # Clip weights of discriminator
            for p in discriminator.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)

            batches_done = epoch * len(batched_data) + i + 1
            # Train the generator every n_critic iterations
            if batches_done % args.n_critic == 0:
                # -----------------
                #  Train Generator
                # -----------------

                optimizer_g.zero_grad()

                # Generate a batch of images
                fake_images = generator(z)
                # Adversarial loss
                loss_g = -torch.mean(discriminator(fake_images))

                loss_g.backward()
                optimizer_g.step()

                if batches_done % args.sample_interval == 0:
                    LOGGER.log_generated_sample(fake_images, batches_done)

                    LOGGER.log_batch_statistics(epoch, args.n_epochs, i + 1,
                                                len(batched_data), loss_d,
                                                loss_g)

                    LOGGER.log_tensorboard_basic_data(loss_g,
                                                      loss_d,
                                                      step=batches_done)

                    if args.log_details:
                        if batches_done == args.sample_interval:
                            LOGGER.save_image_grid(real_imgs, fake_images,
                                                   batches_done)
                        else:
                            LOGGER.save_image_grid(None, fake_images,
                                                   batches_done)
        # -- Save model checkpoints after each epoch -- #
        checkpoint_g.save(RUN, epoch)
        checkpoint_d.save(RUN, epoch)
    LOGGER.close_writers()
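
Examples #3 and #6 sample binary pixels through st_gumbel_softmax.straight_through(probs, temp, hard), which is likewise not shown. One plausible implementation on top of torch.nn.functional.gumbel_softmax, assuming the last dimension holds the two class log-probabilities and the hard sample for class 1 ("white") is returned:

import torch.nn.functional as F

def straight_through(log_probs, temperature, hard=True):
    # log_probs: (..., 2) log-probabilities for {black, white}.
    # gumbel_softmax treats its input as unnormalized logits, so
    # log-probabilities are acceptable here.
    sample = F.gumbel_softmax(log_probs, tau=temperature, hard=hard)
    # Keep the mass on class 1, yielding a {0, 1} pixel value whose
    # gradient flows through the soft relaxation.
    return sample[..., 1]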
Example #4
def run(opt):
    global LOGGER
    global RUN

    # noinspection PyMethodMayBeStatic
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()

            self.init_size = opt.maze_size**2 // 4
            self.l1 = nn.Sequential(
                nn.Linear(opt.latent_dim, 128 * self.init_size))
            self.model = nn.Sequential(
                nn.BatchNorm1d(128), nn.Upsample(scale_factor=2),
                nn.Conv1d(128, 128, 3, stride=1, padding=1),
                nn.BatchNorm1d(128, momentum=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Upsample(scale_factor=2),
                nn.Conv1d(128, 64, 3, stride=1, padding=1),
                nn.BatchNorm1d(64, momentum=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv1d(64, 1, 3, stride=1, padding=1), nn.Tanh())

        def forward(self, z):
            out = self.l1(z)
            out = out.view(out.shape[0], 128, self.init_size)
            fake_mazes = self.model(out)

            fake_mazes = straight_through(fake_mazes)
            return fake_mazes

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            def discriminator_block(in_filters, out_filters, bn=True):
                block = [
                    nn.Conv1d(in_filters, out_filters, 3, 2, 1),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout(0.25)
                ]
                if bn:
                    block.append(nn.BatchNorm1d(out_filters, momentum=0.8))
                return block

            self.model = nn.Sequential(
                *discriminator_block(1, 16, bn=False),
                *discriminator_block(16, 32),
                *discriminator_block(32, 64),
                *discriminator_block(64, 128),
            )

            # The height and width of downsampled image
            ds_size = math.ceil((opt.maze_size**2) / 4**2)
            self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size, 1),
                                           nn.Sigmoid())

        def forward(self, maze):
            out = self.model(maze)
            out = out.view(out.shape[0], -1)
            validity = self.adv_layer(out)

            return validity

    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize optimizers for generator and discriminator
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.g_lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.d_lr)

    # Map to CUDA if necessary
    if CUDA:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_xavier)
    discriminator.apply(weights_init_xavier)

    # Create checkpoint handler and load state if required
    current_epoch = 0
    checkpoint_g = Checkpoint(CWD, generator, optimizer_g)
    checkpoint_d = Checkpoint(CWD, discriminator, optimizer_d)
    if opt.resume:
        RUN, current_epoch = checkpoint_g.load()
        _, _ = checkpoint_d.load()
        LOGGER = Logger(CWD, RUN, opt)
        print('Loaded models from disk. Starting at epoch {}.'.format(
            current_epoch + 1))
    else:
        LOGGER = Logger(CWD, RUN, opt)

    # Configure data loader
    opts = {
        'binary': True,
    }
    maze_loader = data_loader.load(opt, opts)

    for epoch in range(current_epoch, opt.n_epochs):
        for i, mazes in enumerate(maze_loader):
            mazes = mazes.reshape(mazes.shape[0], -1).type(TENSOR)

            # Adversarial ground truths
            #            valid = Variable(torch.ones(mazes.shape[0], 1).type(TENSOR), requires_grad=False)
            #            fake = Variable(torch.zeros(mazes.shape[0], 1).type(TENSOR), requires_grad=False)
            # Adversarial ground truths with noise
            valid = 0.8 + torch.rand(mazes.shape[0], 1).type(TENSOR) * 0.3
            valid = Variable(valid, requires_grad=False)
            fake = torch.rand(mazes.shape[0], 1).type(TENSOR) * 0.3
            fake = Variable(fake, requires_grad=False)

            # Configure input
            real_mazes = Variable(mazes)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_g.zero_grad()

            # Sample noise as generator input
            z = Variable(
                torch.randn(mazes.shape[0], opt.latent_dim).type(TENSOR))

            # Generate a batch of images
            fake_mazes = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(fake_mazes), valid)

            g_loss.backward()
            optimizer_g.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_d.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_mazes = real_mazes.unsqueeze(1)
            real_scores = discriminator(real_mazes)
            real_loss = adversarial_loss(real_scores, valid)
            fake_scores = discriminator(fake_mazes.detach())
            fake_loss = adversarial_loss(fake_scores, fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_d.step()

            batches_done = epoch * len(maze_loader) + i + 1
            if batches_done % opt.sample_interval == 0:
                fake_mazes = fake_mazes.reshape(fake_mazes.size(0),
                                                opt.maze_size, opt.maze_size)
                fake_mazes[fake_mazes < 0.5] = 0
                fake_mazes[fake_mazes >= 0.5] = 1
                #correct = 0
                #for maze in fake_mazes:
                #    correct += int(check_maze(maze.detach()))
                #print(correct)
                real_mazes = real_mazes.reshape(real_mazes.size(0),
                                                opt.maze_size, opt.maze_size)
                LOGGER.log_generated_sample(fake_mazes, batches_done)

                LOGGER.log_batch_statistics(epoch, opt.n_epochs, i + 1,
                                            len(maze_loader), d_loss, g_loss,
                                            real_scores, fake_scores)

                LOGGER.log_tensorboard_basic_data(g_loss, d_loss, real_scores,
                                                  fake_scores, batches_done)

                if opt.log_details:
                    LOGGER.save_image_grid(real_mazes, fake_mazes,
                                           batches_done)
                    # LOGGER.log_tensorboard_parameter_data(discriminator, generator, batches_done)
        # -- Save model checkpoints after each epoch -- #
        checkpoint_g.save(RUN, epoch)
        checkpoint_d.save(RUN, epoch)
    LOGGER.close_writers()
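
weights_init_xavier is applied via Module.apply in Examples #2 through #6 but never defined in this section. A typical version, assuming Xavier-normal initialization for convolutional and linear layers:

import torch.nn as nn

def weights_init_xavier(m):
    # Called once per submodule by Module.apply; only layers with a
    # conventional weight matrix are touched.
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)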
Example #5
def run(args: Namespace):
    global LOGGER
    global RUN

    # noinspection PyMethodMayBeStatic
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()

            self.init_size = args.img_size // 4
            self.filters = 32
            self.l1 = nn.Linear(args.latent_dim, self.filters * self.init_size ** 2)

            self.conv_blocks = nn.Sequential(
                nn.BatchNorm2d(self.filters),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(self.filters, self.filters, 3, stride=1, padding=1),
                nn.BatchNorm2d(self.filters, momentum=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(self.filters, self.filters // 2, 3, stride=1, padding=1),
                nn.BatchNorm2d(self.filters // 2, momentum=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(self.filters // 2, 1, 3, stride=1, padding=1),
                nn.Tanh()
            )

        def forward(self, z_batch):
            out = self.l1(z_batch)
            out = out.view(out.shape[0], self.filters, self.init_size, self.init_size)
            img = self.conv_blocks(out)
            return img

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            self.filters = 8

            def discriminator_block(in_filters, out_filters, step, bn=True):
                block = [nn.Conv2d(in_filters, out_filters, 3, step, 1),
                         nn.LeakyReLU(0.2, inplace=True),
                         nn.Dropout2d(0.25)]
                if bn:
                    block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
                return block

            self.model = nn.Sequential(
                *discriminator_block(1, self.filters, 1, bn=False),
                *discriminator_block(self.filters, self.filters * 2, 2),
                *discriminator_block(self.filters * 2, self.filters * 4, 1),
                *discriminator_block(self.filters * 4, self.filters * 8, 2),
            )

            # The height and width of downsampled image
            ds_size = args.img_size // 2 ** 2
            self.adv_layer = nn.Sequential(
                nn.Linear(self.filters * 8 * ds_size ** 2, 1),
                nn.Sigmoid()
            )

        def forward(self, img):
            out = self.model(img)
            out = out.view(out.shape[0], -1)
            validity = self.adv_layer(out)

            return validity

    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize optimizers for generator and discriminator
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=args.g_lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=args.d_lr)

    # Map to CUDA if necessary
    if CUDA:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_xavier)
    discriminator.apply(weights_init_xavier)

    # Create checkpoint handler and load state if required
    current_epoch = 0
    checkpoint_g = Checkpoint(CWD, generator, optimizer_g)
    checkpoint_d = Checkpoint(CWD, discriminator, optimizer_d)
    if args.resume:
        RUN, current_epoch = checkpoint_g.load()
        _, _ = checkpoint_d.load()
        LOGGER = Logger(CWD, RUN, args)
        print('Loaded models from disk. Starting at epoch {}.'.format(current_epoch + 1))
    else:
        LOGGER = Logger(CWD, RUN, args)

    # Configure data loader
    opts = {
        'binary': False,
        'crop': 20
    }
    mnist_loader = data_loader.load(args, opts)

    for epoch in range(current_epoch, args.n_epochs):
        for i, imgs in enumerate(mnist_loader):

            # Adversarial ground truths with noise
            valid = 0.8 + torch.rand(imgs.shape[0], 1).type(TENSOR) * 0.3
            valid = Variable(valid, requires_grad=False)
            fake = torch.rand(imgs.shape[0], 1).type(TENSOR) * 0.3
            fake = Variable(fake, requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_g.zero_grad()

            # Sample noise as generator input
            z = Variable(torch.randn(imgs.shape[0], args.latent_dim).type(TENSOR))

            # Generate a batch of images
            fake_images = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(fake_images), valid)

            g_loss.backward()
            optimizer_g.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_d.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_scores = discriminator(real_imgs)
            real_loss = adversarial_loss(real_scores, valid)
            fake_scores = discriminator(fake_images.detach())
            fake_loss = adversarial_loss(fake_scores, fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_d.step()

            batches_done = epoch * len(mnist_loader) + i + 1
            if batches_done % args.sample_interval == 0:
                LOGGER.log_generated_sample(fake_images, batches_done)

                LOGGER.log_batch_statistics(epoch, args.n_epochs, i + 1, len(mnist_loader), d_loss, g_loss, real_scores,
                                            fake_scores)

                LOGGER.log_tensorboard_basic_data(g_loss, d_loss, real_scores, fake_scores, batches_done)

                if args.log_details:
                    if batches_done == args.sample_interval:
                        LOGGER.save_image_grid(real_imgs, fake_images, batches_done)
                    else:
                        LOGGER.save_image_grid(None, fake_images, batches_done)
        # -- Save model checkpoints after each epoch -- #
        checkpoint_g.save(RUN, epoch)
        checkpoint_d.save(RUN, epoch)
    LOGGER.close_writers()
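
Every example persists state through Checkpoint(CWD, model, optimizer), checkpoint.save(RUN, epoch) and checkpoint.load(). The following minimal sketch is consistent with that call pattern; the real handler presumably does more (e.g. per-run directories).

import os

import torch

class Checkpoint:
    def __init__(self, cwd, model, optimizer):
        self.path = os.path.join(
            cwd, 'checkpoints', model.__class__.__name__.lower() + '.pt')
        self.model = model
        self.optimizer = optimizer

    def save(self, run, epoch):
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        torch.save({'run': run, 'epoch': epoch + 1,
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()}, self.path)

    def load(self):
        # Returns (run_id, epoch_to_resume_from), matching the call sites.
        state = torch.load(self.path)
        self.model.load_state_dict(state['model'])
        self.optimizer.load_state_dict(state['optimizer'])
        return state['run'], state['epoch']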
Example #6
def run(args: Namespace):
    global LOGGER
    global RUN

    # noinspection PyMethodMayBeStatic
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()

            self.init_size = args.maze_size**2 // 4
            self.l1 = nn.Sequential(
                nn.Linear(args.latent_dim, 128 * self.init_size))
            self.model = nn.Sequential(
                nn.BatchNorm1d(128),
                nn.Upsample(scale_factor=2),
                nn.Conv1d(128, 128, 3, stride=1, padding=1),
                nn.BatchNorm1d(128, momentum=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Upsample(scale_factor=2),
                nn.Conv1d(128, 64, 3, stride=1, padding=1),
                nn.BatchNorm1d(64, momentum=0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv1d(64, 1, 3, stride=1, padding=1),
            )

            self.out = nn.LogSigmoid()

        def forward(self, z):
            map1 = self.l1(z)
            map1 = map1.view(map1.size(0), 128, self.init_size)
            conv = self.model(map1).view(z.size(0), args.img_size**2, 1)

            white_prob = self.out(conv).view(z.size(0), args.img_size**2, 1)
            black_prob = self.out(-conv).view(z.size(0), args.img_size**2, 1)

            probs = torch.cat([black_prob, white_prob], dim=-1)
            img = st_gumbel_softmax.straight_through(probs, args.temp, True)

            return img.view(z.size(0), 1, args.img_size**2)

    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            def discriminator_block(in_filters, out_filters, bn=True):
                block = [
                    nn.Conv1d(in_filters, out_filters, 3, 2, 1),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout(0.25)
                ]
                if bn:
                    block.append(nn.BatchNorm1d(out_filters, momentum=0.8))
                return block

            self.model = nn.Sequential(
                *discriminator_block(1, 16, bn=False),
                *discriminator_block(16, 32, bn=False),
                *discriminator_block(32, 64, bn=False),
                *discriminator_block(64, 128, bn=False),
            )

            # The height and width of down sampled image
            ds_size = math.ceil((args.img_size**2) / 4**2)
            self.adv_layer = nn.Linear(128 * ds_size, 1)

        def forward(self, maze):
            out = self.model(maze)
            out = out.view(out.shape[0], -1)
            validity = self.adv_layer(out)

            return validity

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize optimizers for generator and discriminator
    optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=args.g_lr)
    optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=args.d_lr)

    # Map to CUDA if necessary
    if CUDA:
        generator.cuda()
        discriminator.cuda()

    # Initialize weights
    generator.apply(weights_init_xavier)
    discriminator.apply(weights_init_xavier)

    # Create checkpoint handler and load state if required
    current_epoch = 0
    checkpoint_g = Checkpoint(CWD, generator, optimizer_g)
    checkpoint_d = Checkpoint(CWD, discriminator, optimizer_d)
    if args.resume:
        RUN, current_epoch = checkpoint_g.load()
        _, _ = checkpoint_d.load()
        LOGGER = Logger(CWD, RUN, args)
        print('Loaded models from disk. Starting at epoch {}.'.format(
            current_epoch + 1))
    else:
        LOGGER = Logger(CWD, RUN, args)

    # Configure data loader
    opts = {
        'binary': True,
    }
    batched_data = data_loader.load(args, opts)

    for epoch in range(current_epoch, args.n_epochs):
        for i, mazes in enumerate(batched_data):
            batches_done = epoch * len(batched_data) + i + 1

            mazes = mazes.reshape(mazes.size(0), 1, -1).type(TENSOR)

            # Configure input
            real_images = Variable(mazes)

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_d.zero_grad()

            z = Variable(
                torch.randn(real_images.size(0), args.latent_dim).type(TENSOR))
            fake_images = generator(z).detach()
            # Adversarial loss
            loss_d = -torch.mean(discriminator(real_images)) + torch.mean(
                discriminator(fake_images))

            loss_d.backward()
            optimizer_d.step()

            # Clip weights of discriminator
            for p in discriminator.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)

            # Train the generator every n_critic iterations
            if batches_done % args.n_critic == 0:
                # -----------------
                #  Train Generator
                # -----------------

                optimizer_g.zero_grad()

                # Generate a batch of images
                fake_images = generator(z)
                # Adversarial loss
                loss_g = -torch.mean(discriminator(fake_images))

                loss_g.backward()
                optimizer_g.step()

                if batches_done % args.sample_interval == 0:
                    fake_mazes = fake_images.reshape(fake_images.size(0),
                                                     args.img_size,
                                                     args.img_size)
                    fake_mazes[fake_mazes < 0.5] = 0
                    fake_mazes[fake_mazes >= 0.5] = 1
                    real_mazes = real_images.reshape(real_images.size(0),
                                                     args.img_size,
                                                     args.img_size)

                    LOGGER.log_generated_sample(fake_mazes, batches_done)

                    LOGGER.log_batch_statistics(epoch, args.n_epochs, i + 1,
                                                len(batched_data), loss_d,
                                                loss_g)

                    LOGGER.log_tensorboard_basic_data(loss_g,
                                                      loss_d,
                                                      step=batches_done)

                    if args.log_details:
                        if batches_done == args.sample_interval:
                            LOGGER.save_image_grid(real_mazes, fake_mazes,
                                                   batches_done)
                        else:
                            LOGGER.save_image_grid(None, fake_mazes,
                                                   batches_done)

        # -- Save model checkpoints after each epoch -- #
        checkpoint_g.save(RUN, epoch)
        checkpoint_d.save(RUN, epoch)
    LOGGER.close_writers()
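
Each run() expects an argparse-style namespace. A hypothetical invocation covering the attribute names read across the examples (the values are illustrative, not taken from the original project):

from argparse import Namespace

args = Namespace(
    img_size=28, maze_size=28, latent_dim=100, batch_size=64,
    n_epochs=200, g_lr=5e-5, d_lr=5e-5, n_critic=5, clip_value=0.01,
    temp=1.0, sample_interval=400, log_details=False, resume=False,
)
run(args)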