Esempio n. 1
0
# Wrap the discriminator for multi-GPU data parallelism and move it to the GPU.
# NOTE(review): `D` and `G` are created before this snippet and are not visible here.
D = nn.DataParallel(D)
D = D.cuda()

# Auxiliary feature extractor, kept in eval mode (presumably used for a
# feature/perceptual loss inside the truncated loop body — TODO confirm).
# NOTE(review): the name `F` shadows the common alias for torch.nn.functional.
F = FeatureExtractor().cuda()
F.eval()

# settings
# NOTE(review): a, b, r_c and r_p are not used in the visible part of this
# snippet; they look like loss weights, but confirm against the loop body.
a = 1
b = 1
r_c = 1e-3
r_p = 1e-1
learning_rate = 2.5e-4
criterion_MSE = nn.MSELoss()
criterion_BCE = nn.BCELoss()
# Generator and discriminator both use RMSprop with the same learning rate.
optimizer_G = optim.RMSprop(G.parameters(), lr=learning_rate)
optimizer_D = optim.RMSprop(D.parameters(), lr=learning_rate)

# outputs
# Each run writes to a fresh timestamped directory, e.g. ./outputsGAN_2024-01-31-12-00-00/
output_dir = './outputsGAN_{:%Y-%m-%d-%H-%M-%S}/'.format(datetime.now())
checkp_dir = os.path.join(output_dir, '_checkpoints')
# Despite the `_dir` suffix, this is the path of the log *file* (log.txt).
logtxt_dir = os.path.join(output_dir, 'log.txt')
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkp_dir, exist_ok=True)

# train & valid
num_epoch = 100
for epoch_idx in range(1, num_epoch + 1):

    # train
    # Per-epoch loss histories for the discriminator and the generator.
    losses_D = []; losses_G = []
    # NOTE(review): the loop body is cut off in this snippet. By their names the
    # batch items are (low-res image, high-res image, hmaps, pmaps) — confirm.
    for batch_idx, (image_lr, image_hr, hmaps, pmaps) in enumerate(trn_dloader, start=1):
Esempio n. 2
0
class _3DGAN(object):
    """Train or test a 3D voxel GAN made of a generator ``G`` and a discriminator ``D``.

    ``__init__`` builds the dataset and both networks, moves them to the
    requested GPUs and optionally restores a checkpoint. ``train()`` then runs
    the adversarial training loop.

    Args:
        args: parsed arguments providing ``attribute``, ``gpu`` (list of GPU
            ids, or an empty list / None for CPU), ``mode`` ('train' or
            'test') and ``restore`` (iteration to resume from, or None).
        config: configuration object; the code reads ``model_dir``,
            ``img_dir``, ``log_dir``, ``G_lr``, ``D_lr``, ``step_size``,
            ``gamma``, ``nchw`` and ``max_iter`` from it.
    """

    def __init__(self, args, config=config):
        self.args = args
        self.attribute = args.attribute
        self.gpu = args.gpu
        self.mode = args.mode
        self.restore = args.restore

        # init dataset and networks
        self.config = config
        self.dataset = ShapeNet(self.attribute)
        self.G = Generator()
        self.D = Discriminator()

        self.adv_criterion = torch.nn.BCELoss()

        self.set_mode_and_gpu()
        self.restore_from_file()

    def set_mode_and_gpu(self):
        """Put G/D in train or eval mode and move them to the configured GPU(s).

        With more than one GPU id, both networks are wrapped in ``DataParallel``.

        Raises:
            NotImplementedError: if ``self.mode`` is neither 'train' nor 'test'.
        """
        if self.mode == 'train':
            self.G.train()
            self.D.train()
        elif self.mode == 'test':
            self.G.eval()
            self.D.eval()
        else:
            # The original raised the misspelled ``NotImplementationError``,
            # which itself failed with a NameError.
            raise NotImplementedError('unsupported mode: {!r}'.format(self.mode))

        # ``self.gpu`` may be None or empty for CPU runs. The original called
        # ``len(self.gpu)`` unguarded, which raises TypeError for None.
        if self.gpu:
            with torch.cuda.device(self.gpu[0]):
                self.G.cuda()
                self.D.cuda()
                if self.mode == 'train':
                    self.adv_criterion.cuda()

            if len(self.gpu) > 1:
                self.G = torch.nn.DataParallel(self.G, device_ids=self.gpu)
                self.D = torch.nn.DataParallel(self.D, device_ids=self.gpu)

    def restore_from_file(self):
        """Restore weights saved at iteration ``self.restore`` and set ``self.start_step``.

        G is always restored. D is only restored in train mode.

        Raises:
            FileNotFoundError: if a required checkpoint file does not exist.
        """
        if self.restore is None:
            self.start_step = 1
            return

        self._load_checkpoint(self.G, 'G')
        if self.mode == 'train':
            self._load_checkpoint(self.D, 'D')
        self.start_step = self.restore + 1

    def _load_checkpoint(self, net, prefix):
        """Load ``<model_dir>/<prefix>_iter_<restore>.pth`` into ``net``."""
        ckpt_file = os.path.join(
            self.config.model_dir,
            '{}_iter_{:06d}.pth'.format(prefix, self.restore))
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if not os.path.exists(ckpt_file):
            raise FileNotFoundError(ckpt_file)
        net.load_state_dict(torch.load(ckpt_file))

    def save_log(self):
        """Write the current losses and learning rates to TensorBoard at ``self.step``."""
        scalar_info = {
            'loss_D': self.loss_D,
            'loss_G': self.loss_G,
            # get_last_lr() reports the LR computed by the last scheduler step.
            # Calling get_lr() directly can return an already-decayed value.
            'G_lr': self.G_lr_scheduler.get_last_lr()[0],
            'D_lr': self.D_lr_scheduler.get_last_lr()[0],
        }
        scalar_info.update({'G_loss/' + key: value for key, value in self.G_loss.items()})
        scalar_info.update({'D_loss/' + key: value for key, value in self.D_loss.items()})

        for tag, value in scalar_info.items():
            self.writer.add_scalar(tag, value, self.step)

    def save_img(self, save_num=5):
        """Save up to ``save_num`` generated volumes from ``self.fake_X`` as .mat files.

        Channel 0 of each sample is written under the key ``'instance'``.
        """
        # Guard against batches smaller than save_num (previously an IndexError).
        count = min(save_num, self.fake_X.size(0))
        for i in range(count):
            mdict = {'instance': self.fake_X[i, 0].detach().cpu().numpy()}
            sio.savemat(
                os.path.join(self.config.img_dir,
                             '{:06d}_{:02d}.mat'.format(self.step, i)), mdict)

    def save_model(self):
        """Save CPU copies of the G and D weights as ``{G,D}_iter_<step>.pth`` in ``model_dir``."""
        for prefix, net in (('G', self.G), ('D', self.D)):
            torch.save(
                {key: val.cpu() for key, val in net.state_dict().items()},
                os.path.join(self.config.model_dir,
                             '{}_iter_{:06d}.pth'.format(prefix, self.step)))

    def train(self):
        """Run adversarial training from ``self.start_step`` to ``config.max_iter``.

        Every 100 steps, scalars are logged to TensorBoard. Every 1000 steps,
        sample volumes and G/D checkpoints are saved.
        """
        self.writer = SummaryWriter(self.config.log_dir)
        self.opt_G = torch.optim.Adam(self.G.parameters(),
                                      lr=self.config.G_lr,
                                      betas=(0.5, 0.999))
        self.opt_D = torch.optim.Adam(self.D.parameters(),
                                      lr=self.config.D_lr,
                                      betas=(0.5, 0.999))
        # NOTE(review): when resuming (start_step > 1), the schedulers still
        # start from epoch 0, so the decay is not aligned with the restored
        # iteration. Confirm whether that is intended.
        self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.opt_G,
            step_size=self.config.step_size,
            gamma=self.config.gamma)
        self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.opt_D,
            step_size=self.config.step_size,
            gamma=self.config.gamma)

        # start training
        for step in range(self.start_step, 1 + self.config.max_iter):
            self.step = step

            self.real_X = next(self.dataset.gen(True))
            self.noise = torch.randn(self.config.nchw[0], 200)
            if self.gpu:
                with torch.cuda.device(self.gpu[0]):
                    self.real_X = self.real_X.cuda()
                    self.noise = self.noise.cuda()

            self.fake_X = self.G(self.noise)

            # update D (fake_X is detached so this update does not reach G)
            self.D_real = self.D(self.real_X)
            self.D_fake = self.D(self.fake_X.detach())
            self.D_loss = {
                'adv_real':
                self.adv_criterion(self.D_real, torch.ones_like(self.D_real)),
                'adv_fake':
                self.adv_criterion(self.D_fake, torch.zeros_like(self.D_fake)),
            }
            self.loss_D = sum(self.D_loss.values())

            self.opt_D.zero_grad()
            self.loss_D.backward()
            self.opt_D.step()

            # update G (non-saturating loss: push D(G(z)) towards 1)
            self.D_fake = self.D(self.fake_X)
            self.G_loss = {
                'adv_fake':
                self.adv_criterion(self.D_fake, torch.ones_like(self.D_fake))
            }
            self.loss_G = sum(self.G_loss.values())
            self.opt_G.zero_grad()
            self.loss_G.backward()
            self.opt_G.step()

            print('step: {:06d}, loss_D: {:.6f}, loss_G: {:.6f}'.format(
                self.step, self.loss_D.item(), self.loss_G.item()))

            if self.step % 100 == 0:
                self.save_log()

            if self.step % 1000 == 0:
                self.save_img()
                self.save_model()

            # Advance the LR schedules only after the optimizers have stepped.
            # Since PyTorch 1.1 the reverse order skips the first LR value.
            # Stepping last also makes save_log() report this step's LR.
            self.G_lr_scheduler.step()
            self.D_lr_scheduler.step()

        print('Finished training!')
        self.writer.close()
Esempio n. 3
0
class GAN_CLS(object):
    """Text-conditional GAN trainer (GAN-CLS) with a matching-aware discriminator.

    The discriminator is trained on three pairs:
    (real image, matching text), (generated image, text) and
    (mismatched real image, text).
    """

    def __init__(self, args, data_loader, SUPERVISED=True):
        """
        Arguments :
        ----------
        args : Arguments defined in Argument Parser
        data_loader : iterable of batches, each a dict with the keys
            'true_imgs', 'true_embed' and 'false_imgs'
        SUPERVISED : stored as ``self.SUPERVISED``; not otherwise used in this class
        """

        self.data_loader = data_loader
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size

        self.log_step = config.log_step
        self.sample_step = config.sample_step

        self.log_dir = args.log_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = config.sample_dir
        self.final_model = args.final_model

        self.dataset = args.dataset
        self.model_name = args.model_name

        self.img_size = args.img_size
        self.z_dim = args.z_dim
        self.text_embed_dim = args.text_embed_dim
        self.text_reduced_dim = args.text_reduced_dim
        self.learning_rate = args.learning_rate
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1_coeff = args.l1_coeff
        self.resume_epoch = args.resume_epoch
        self.SUPERVISED = SUPERVISED

        # Logger setting. Use the module's name; the original passed the
        # literal string '__name__'. ``log_dir`` is used as the log *file* path.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        self.file_handler = logging.FileHandler(self.log_dir)
        self.file_handler.setFormatter(self.formatter)
        self.logger.addHandler(self.file_handler)

        self.build_model()

    def build_model(self):
        """Create the networks, their optimizers and loss functions, and move them to the GPU.

        Defines ``gen``, ``disc``, ``gen_optim``, ``disc_optim``,
        ``cls_gan_optim``, ``criterion`` (BCE) and ``l1_criterion`` (L1).
        """

        # ---------------------------------------------------------------------
        #                       1. Network Initialization
        # ---------------------------------------------------------------------
        self.gen = Generator(batch_size=self.batch_size,
                             img_size=self.img_size,
                             z_dim=self.z_dim,
                             text_embed_dim=self.text_embed_dim,
                             text_reduced_dim=self.text_reduced_dim)

        self.disc = Discriminator(batch_size=self.batch_size,
                                  img_size=self.img_size,
                                  text_embed_dim=self.text_embed_dim,
                                  text_reduced_dim=self.text_reduced_dim)

        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.learning_rate,
                                    betas=(self.beta1, self.beta2))

        self.disc_optim = optim.Adam(self.disc.parameters(),
                                     lr=self.learning_rate,
                                     betas=(self.beta1, self.beta2))

        # NOTE(review): joint optimizer; its step() call is commented out in train_model.
        self.cls_gan_optim = optim.Adam(itertools.chain(self.gen.parameters(),
                                                        self.disc.parameters()),
                                        lr=self.learning_rate,
                                        betas=(self.beta1, self.beta2))

        print('-------------  Generator Model Info  ---------------')
        self.print_network(self.gen, 'G')
        print('------------------------------------------------')

        print('-------------  Discriminator Model Info  ---------------')
        self.print_network(self.disc, 'D')
        print('------------------------------------------------')

        self.gen.cuda()
        self.disc.cuda()
        self.criterion = nn.BCELoss().cuda()
        # L1 reconstruction loss between generated and real images.
        self.l1_criterion = nn.L1Loss().cuda()
        self.gen.train()
        self.disc.train()

    def print_network(self, model, name):
        """Print the model, its name and its total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())

        print(model)
        print(name)
        print("Total number of parameters: {}".format(num_params))

    @staticmethod
    def _checkpoint_paths(checkpoint_dir, epoch):
        """Return the (generator, discriminator) checkpoint paths for ``epoch``."""
        return (os.path.join(checkpoint_dir, '{}-G.ckpt'.format(epoch)),
                os.path.join(checkpoint_dir, '{}-D.ckpt'.format(epoch)))

    @staticmethod
    def _format_log(elapsed, epoch, num_epochs, idx, losses):
        """Build one progress line from the elapsed time, epoch, batch index and losses."""
        log = "Elapsed [{}], Epoch [{}/{}], Idx [{}]".format(elapsed, epoch, num_epochs, idx)
        for name, value in losses.items():
            log += ", {}: {:.4f}".format(name, value)
        return log

    def load_checkpoints(self, resume_epoch):
        """Restore the generator and discriminator saved after ``resume_epoch`` epochs."""
        print('Loading the trained models from step {}...'.format(resume_epoch))
        G_path, D_path = self._checkpoint_paths(self.checkpoint_dir, resume_epoch)
        # Load onto the CPU first; load_state_dict copies into the model's device.
        self.gen.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.disc.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def train_model(self):
        """Train G and D for ``num_epochs`` epochs, resuming from ``resume_epoch`` if set.

        For each batch, G is updated first (adversarial loss plus ``l1_coeff``
        times the L1 distance to the real image), then D. Progress is logged
        every ``log_step`` batches and sample images are saved every
        ``sample_step`` batches. Checkpoints are saved after every epoch.
        """
        data_loader = self.data_loader

        start_epoch = 0
        if self.resume_epoch:
            start_epoch = self.resume_epoch
            self.load_checkpoints(self.resume_epoch)

        print('---------------  Model Training Started  ---------------')
        start_time = time.time()

        for epoch in range(start_epoch, self.num_epochs):
            for idx, batch in enumerate(data_loader):
                true_imgs = batch['true_imgs'].float().cuda()
                true_embed = batch['true_embed'].float().cuda()
                false_imgs = batch['false_imgs'].float().cuda()

                batch_size = true_imgs.size(0)
                real_labels = torch.ones(batch_size)
                fake_labels = torch.zeros(batch_size)

                # One-sided label smoothing for D's real targets.
                smooth_real_labels = torch.FloatTensor(
                    Utils.smooth_label(real_labels.numpy(), -0.1)).cuda()
                real_labels = real_labels.cuda()
                fake_labels = fake_labels.cuda()

                # ---------------------------------------------------------------
                #                     2. Training the generator
                # ---------------------------------------------------------------
                self.gen.zero_grad()
                z = torch.randn(batch_size, self.z_dim).cuda()
                fake_imgs = self.gen(true_embed, z)
                fake_out, _ = self.disc(fake_imgs, true_embed)

                # The original called ``nn.L1Loss(fake_imgs, true_imgs)``, which
                # constructs a loss module instead of computing the loss.
                gen_loss = (self.criterion(fake_out, real_labels) +
                            self.l1_coeff * self.l1_criterion(fake_imgs, true_imgs))

                gen_loss.backward()
                self.gen_optim.step()

                # ---------------------------------------------------------------
                #                   3. Training the discriminator
                # ---------------------------------------------------------------
                # Fresh forward passes are needed: the graph behind the G-step
                # ``fake_out`` was freed by ``gen_loss.backward()``. The fake
                # images are detached so this update does not reach G.
                self.disc.zero_grad()
                true_out, _ = self.disc(true_imgs, true_embed)
                fake_out, _ = self.disc(fake_imgs.detach(), true_embed)
                false_out, _ = self.disc(false_imgs, true_embed)
                disc_loss = (self.criterion(true_out, smooth_real_labels) +
                             self.criterion(fake_out, fake_labels) +
                             self.criterion(false_out, fake_labels))

                disc_loss.backward()
                self.disc_optim.step()

                # self.cls_gan_optim.step()

                loss = {'G_loss': gen_loss.item(), 'D_loss': disc_loss.item()}

                # ---------------------------------------------------------------
                #                   4. Logging INFO into log_dir
                # ---------------------------------------------------------------
                if (idx + 1) % self.log_step == 0:
                    elapsed = datetime.timedelta(seconds=time.time() - start_time)
                    log = self._format_log(elapsed, epoch + 1, self.num_epochs, idx, loss)
                    self.logger.info(log)
                    print(log)

                # ---------------------------------------------------------------
                #                   5. Saving generated images
                # ---------------------------------------------------------------
                if (idx + 1) % self.sample_step == 0:
                    # Real and generated images are concatenated along dim 2
                    # (presumably height for NCHW tensors — TODO confirm).
                    concat_imgs = torch.cat((true_imgs, fake_imgs.detach()), 2)
                    # Map [-1, 1] to [0, 1] (assumes a tanh-ranged generator output).
                    concat_imgs = (concat_imgs + 1) / 2
                    # The epoch is included so samples are not overwritten every epoch.
                    save_path = os.path.join(
                        self.sample_dir, '{}-{}-images.jpg'.format(epoch + 1, idx + 1))
                    save_image(concat_imgs.cpu(), save_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(self.sample_dir))

            # -------------------------------------------------------------------
            #                 6. Saving the checkpoints
            # -------------------------------------------------------------------
            # Checkpoints are keyed by the number of completed epochs, matching
            # load_checkpoints(resume_epoch). The original keyed them by batch
            # index and read ``model_save_step``, which was never set.
            G_path, D_path = self._checkpoint_paths(self.checkpoint_dir, epoch + 1)
            torch.save(self.gen.state_dict(), G_path)
            torch.save(self.disc.state_dict(), D_path)
            print('Saved model checkpoints into {}...'.format(self.checkpoint_dir))
Esempio n. 4
0
def train(dataloader,
          num_epochs,
          net,
          run_settings,
          learning_rate=0.0002,
          optimizerD='Adam'):
    """Train a DCGAN-style generator/discriminator pair and save progress plots.

    Args:
        dataloader: iterable of batches whose first element is the real image batch.
        num_epochs: number of passes over ``dataloader``.
        net: configuration object passed to both ``Generator`` and ``Discriminator``.
        run_settings: string suffix for the final sample-grid and loss-plot filenames.
        learning_rate: learning rate for both optimizers.
        optimizerD: ``'SGD'`` trains D with plain SGD; any other value uses Adam.

    Side effects: writes ``generated_images_<i>.png`` for every fixed-noise
    snapshot, plus ``generated_images_<run_settings>.png`` (last snapshot) and
    ``loss_graph_<run_settings>.png`` to the current directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create the nets
    generator = Generator(net).to(device)
    discriminator = Discriminator(net).to(device)

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Batch of latent vectors used to visualize the generator's progression.
    # NOTE(review): `nz` (latent size) is not defined in this function;
    # presumably a module-level constant — confirm.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    beta1 = 0.5

    # Set up the optimizers. ``optimizerD`` only names D's optimizer, so the
    # optimizer objects get their own locals instead of reusing that name.
    if optimizerD == 'SGD':
        opt_d = optim.SGD(discriminator.parameters(), lr=learning_rate)
    else:
        opt_d = optim.Adam(discriminator.parameters(),
                           lr=learning_rate,
                           betas=(beta1, 0.999))
    opt_g = optim.Adam(generator.parameters(),
                       lr=learning_rate,
                       betas=(beta1, 0.999))

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            discriminator.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)
            output = discriminator(real_cpu).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = generator(noise)
            label.fill_(fake_label)
            # Detached so D's update does not backpropagate into G.
            output = discriminator(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Gradients from the real and fake passes have accumulated; the
            # sum is only kept for logging.
            errD = errD_real + errD_fake
            opt_d.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # D was just updated, so re-run the fake batch through it.
            output = discriminator(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            opt_g.step()

            # Output training stats
            if i % 3 == 0:
                print(
                    '[%d/%d][%d/%d]\t\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch + 1, num_epochs, i + 1, len(dataloader),
                       errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Snapshot the generator on fixed_noise every 50 epochs' worth of
            # iterations, and on the very last iteration.
            if (iters %
                (len(dataloader) * 50) == 0) or ((epoch == num_epochs - 1) and
                                                 (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

    print("finished")

    def save_grid(grid, path):
        # One figure per image, closed after saving, so figures do not pile up.
        fig = plt.figure()
        plt.imshow(np.transpose(grid, (1, 2, 0)))
        plt.savefig(path)
        plt.close(fig)

    for i, grid in enumerate(img_list):
        save_grid(grid, 'generated_images_' + str(i) + '.png')

    # img_list is empty when num_epochs == 0 or the dataloader yields nothing;
    # the original then crashed with IndexError on img_list[-1].
    if img_list:
        save_grid(img_list[-1], 'generated_images_' + run_settings + '.png')

    fig = plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig('loss_graph_' + run_settings + '.png')
    plt.close(fig)
Esempio n. 5
0
def train(args):
    """Train the 800-sample spectrum GAN for ``args.iters`` iterations.

    Reads ``nz``, ``data_path``, ``normalized``, ``out``, ``bs``, ``lr`` and
    ``iters`` from ``args``. Every 5000 iterations it prints the losses,
    overwrites ``gen_0.pkl`` / ``disc_0.pkl`` in ``args.out``, and saves a
    generated sample (``gen_sample_<i>.png``) and a real sample (``real_0.png``).
    """
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)

    gen = Generator(args.nz, 800)
    gen = gen.to(device)
    gen.apply(weights_init)

    discriminator = Discriminator(800)
    discriminator = discriminator.to(device)
    discriminator.apply(weights_init)

    bce = nn.BCELoss()
    bce = bce.to(device)

    galaxy_dataset = GalaxySet(args.data_path,
                               normalized=args.normalized,
                               out=args.out)
    # drop_last keeps every batch at exactly args.bs, matching the fixed label tensors below.
    loader = DataLoader(galaxy_dataset,
                        batch_size=args.bs,
                        shuffle=True,
                        num_workers=2,
                        drop_last=True)
    loader_iter = iter(loader)

    d_optimizer = Adam(discriminator.parameters(),
                       betas=(0.5, 0.999),
                       lr=args.lr)
    g_optimizer = Adam(gen.parameters(), betas=(0.5, 0.999), lr=args.lr)

    real_labels = to_var(torch.ones(args.bs), device_str)
    fake_labels = to_var(torch.zeros(args.bs), device_str)
    fixed_noise = to_var(torch.randn(1, args.nz), device_str)

    for i in tqdm(range(args.iters)):
        # Cycle through the loader indefinitely, restarting it when exhausted.
        # ``next(it)`` replaces ``it.next()``, which Python 3 iterators (and
        # current DataLoader iterators) do not provide.
        try:
            batch_data = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)
            batch_data = next(loader_iter)

        # NOTE(review): to_var is called with a torch.device here but with a
        # device string above; confirm it accepts both.
        batch_data = to_var(batch_data, device).unsqueeze(1)

        # Keep every second value of the first 1600 → 800 features per sample.
        batch_data = batch_data[:, :, :1600:2]
        batch_data = batch_data.view(-1, 800)

        ### Train Discriminator ###

        d_optimizer.zero_grad()

        # train Infer with real
        pred_real = discriminator(batch_data)
        d_loss = bce(pred_real, real_labels)

        # train infer with fakes (detached so G gets no gradient here)
        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes.detach())
        d_loss += bce(pred_fake, fake_labels)

        d_loss.backward()

        d_optimizer.step()

        ### Train Gen ###

        g_optimizer.zero_grad()

        z = to_var(torch.randn((args.bs, args.nz)), device)
        fakes = gen(z)
        pred_fake = discriminator(fakes)
        gen_loss = bce(pred_fake, real_labels)

        gen_loss.backward()
        g_optimizer.step()

        if i % 5000 == 0:
            print("Iteration %d >> g_loss: %.4f., d_loss: %.4f." %
                  (i, gen_loss.item(), d_loss.item()))
            # The "% 0" file names mean only the latest checkpoint is kept.
            torch.save(gen.state_dict(),
                       os.path.join(args.out, 'gen_%d.pkl' % 0))
            torch.save(discriminator.state_dict(),
                       os.path.join(args.out, 'disc_%d.pkl' % 0))
            gen.eval()
            # No autograd graph is needed for a visualization sample.
            with torch.no_grad():
                fixed_fake = gen(fixed_noise).cpu().numpy()
            real_data = batch_data[0].detach().cpu().numpy()
            gen.train()
            display_noise(fixed_fake.squeeze(),
                          os.path.join(args.out, "gen_sample_%d.png" % i))
            display_noise(real_data.squeeze(),
                          os.path.join(args.out, "real_%d.png" % 0))
Esempio n. 6
0
class GAN3DTrainer(object):
    """Train a 3D voxel GAN (``Generator`` / ``Discriminator``) with optional TensorBoard logging.

    Models, optimizer states, progress counters and training statistics are
    saved to and restored from ``logDir``.
    """

    # Lower bound applied inside ``torch.log`` so a discriminator output of
    # exactly 0 or 1 does not turn the loss into inf/NaN.
    _LOG_EPS = 1e-8

    def __init__(self,
                 logDir,
                 printEvery=1,
                 resume=False,
                 useTensorboard=True):
        super(GAN3DTrainer, self).__init__()

        self.logDir = logDir

        # Progress counters, persisted by save() / load().
        self.currentEpoch = 0
        self.totalBatches = 0

        # Per-batch loss/accuracy history, appended in train().
        self.trainStats = {'lossG': [], 'lossD': [], 'accG': [], 'accD': []}

        self.printEvery = printEvery

        self.G = Generator()
        self.D = Discriminator()

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

            self.G = self.G.to(self.device)
            self.D = self.D.to(self.device)

            # Parallelize over all visible GPUs (splitting the input on the
            # batch dimension) when there is more than one. The original
            # hard-coded device_ids=[0, 1], which fails on a single-GPU host.
            gpu_ids = list(range(torch.cuda.device_count()))
            if len(gpu_ids) > 1:
                self.G = torch.nn.DataParallel(self.G, device_ids=gpu_ids)
                self.D = torch.nn.DataParallel(self.D, device_ids=gpu_ids)

        # optim params direct from paper
        self.optimG = torch.optim.Adam(self.G.parameters(),
                                       lr=0.0025,
                                       betas=(0.5, 0.999))

        self.optimD = torch.optim.Adam(self.D.parameters(),
                                       lr=0.00005,
                                       betas=(0.5, 0.999))

        if resume:
            self.load()

        self.useTensorboard = useTensorboard
        self.tensorGraphInitialized = False
        self.writer = None
        if useTensorboard:
            self.writer = SummaryWriter(
                os.path.join(self.logDir, 'tensorboard'))

    @staticmethod
    def _unwrap(model):
        """Return the module inside a ``DataParallel`` wrapper, or ``model`` itself."""
        return model.module if isinstance(model, torch.nn.DataParallel) else model

    @staticmethod
    def _strip_module_prefix(state_dict, prefix='module.'):
        """Return a copy of ``state_dict`` without the key prefix added by ``DataParallel``.

        Module-version ``_metadata`` (if present) is carried over with the same
        renaming, where the bare wrapper key ``'module'`` maps to the root ``''``.
        """
        def rename(key):
            return key[len(prefix):] if key.startswith(prefix) else key

        stripped = collections.OrderedDict(
            (rename(key), value) for key, value in state_dict.items())
        metadata = getattr(state_dict, '_metadata', None)
        if metadata is not None:
            bare = prefix.rstrip('.')
            stripped._metadata = collections.OrderedDict(
                ('' if key == bare else rename(key), value)
                for key, value in metadata.items())
        return stripped

    def _load_model_state(self, model, path):
        """Load weights from ``path`` into ``model``, whether or not either side used DataParallel."""
        state = torch.load(path, map_location=self.device)
        self._unwrap(model).load_state_dict(self._strip_module_prefix(state))

    def train(self, trainData: torch.utils.data.DataLoader):
        """Run one epoch over ``trainData``.

        Each batch is a dict whose ``'data'`` entry maps ``'62'`` to the voxel
        grids. They are zero-padded by one voxel per side (62 → 64 per spatial
        dimension). D is only updated while its accuracy is below 0.8; G is
        updated every batch.
        """
        numBatches = 0

        self.G.train()
        self.D.train()

        for i, sample in enumerate(tqdm(trainData)):
            data = sample['data']
            batchSize = data['62'].shape[0]

            self.optimG.zero_grad()
            self.G.zero_grad()

            self.optimD.zero_grad()
            self.D.zero_grad()

            realVoxels = torch.zeros(batchSize, 64, 64, 64).to(self.device)
            realVoxels[:, 1:-1, 1:-1, 1:-1] = data['62'].to(self.device)

            # discriminator train
            z = torch.normal(torch.zeros(batchSize, 200),
                             torch.ones(batchSize, 200) * 0.33).to(self.device)

            # Detached: the D loss must not backpropagate into G.
            fakeVoxels = self.G(z).detach()
            fakeD = self.D(fakeVoxels)
            realD = self.D(realVoxels)

            lossD = -torch.mean(
                torch.log(realD.clamp(min=self._LOG_EPS)) +
                torch.log((1. - fakeD).clamp(min=self._LOG_EPS)))
            accD = ((realD >= .5).float().mean() +
                    (fakeD < .5).float().mean()) / 2.
            accG = (fakeD > .5).float().mean()

            # Only train D while it is not yet too accurate (keeps G/D balanced).
            if accD < .8:
                self.D.zero_grad()
                lossD.backward()
                self.optimD.step()

            # gen train
            z = torch.normal(torch.zeros(batchSize, 200),
                             torch.ones(batchSize, 200) * 0.33).to(self.device)

            fakeVoxels = self.G(z)
            fakeD = self.D(fakeVoxels)

            # https://arxiv.org/pdf/1706.05170.pdf (IV. Methods, A. Training the gen model)
            lossG = -torch.mean(torch.log(fakeD.clamp(min=self._LOG_EPS)))

            self.D.zero_grad()
            self.G.zero_grad()
            lossG.backward()
            self.optimG.step()

            # log
            numBatches += 1
            self.trainStats['lossG'].append(lossG.item())
            self.trainStats['lossD'].append(lossD.item())
            self.trainStats['accG'].append(accG.item())
            self.trainStats['accD'].append(accD.item())

            if i % self.printEvery == 0:
                tqdm.write(
                    f'[TRAIN] Epoch {self.currentEpoch:03d}, Batch {i:03d}: '
                    f'gen: {float(accG.item()):2.3f}, dis = {float(accD.item()):2.3f}'
                )

                if self.useTensorboard:
                    step = numBatches + self.totalBatches
                    self.writer.add_scalar('GenLoss/train', lossG, step)
                    self.writer.add_scalar('DisLoss/train', lossD, step)
                    self.writer.add_scalar('GenAcc/train', accG, step)
                    self.writer.add_scalar('DisAcc/train', accD, step)
                    self.writer.flush()

                    if not self.tensorGraphInitialized:
                        # Trace with the inputs actually used in training, on
                        # the models' own device. The original traced G with a
                        # (B, 200, 1, 1, 1) tensor on hard-coded cuda:1, which
                        # differs from the (B, 200) latent batch used here in
                        # both shape and device.
                        self.writer.add_graph(self._unwrap(self.G), z)
                        self.writer.flush()

                        self.writer.add_graph(self._unwrap(self.D), fakeVoxels.detach())
                        self.writer.flush()

                        self.tensorGraphInitialized = True

        self.currentEpoch += 1
        self.totalBatches += numBatches

    def save(self):
        """Save models, optimizer states, progress counters and stats to ``logDir``."""
        logTable = {
            'epoch': self.currentEpoch,
            'totalBatches': self.totalBatches
        }

        torch.save(self.G.state_dict(),
                   os.path.join(self.logDir, 'generator.pth'))
        torch.save(self.D.state_dict(), os.path.join(self.logDir,
                                                     'discrim.pth'))
        torch.save(self.optimG.state_dict(),
                   os.path.join(self.logDir, 'optimG.pth'))
        torch.save(self.optimD.state_dict(),
                   os.path.join(self.logDir, 'optimD.pth'))

        with open(os.path.join(self.logDir, 'recent.log'), 'w') as f:
            f.write(json.dumps(logTable))

        # Use a context manager so the file is flushed and closed.
        with open(os.path.join(self.logDir, 'trainStats.pkl'), 'wb') as f:
            pickle.dump(self.trainStats, f)

        tqdm.write('======== SAVED RECENT MODEL ========')

    def load(self):
        """Restore everything written by ``save()`` from ``logDir``.

        Model checkpoints load regardless of whether they were written with or
        without ``DataParallel``. Tensors are mapped onto ``self.device``, so a
        GPU checkpoint also loads on a CPU-only host.
        """
        self._load_model_state(self.G, os.path.join(self.logDir, 'generator.pth'))
        self._load_model_state(self.D, os.path.join(self.logDir, 'discrim.pth'))
        self.optimG.load_state_dict(
            torch.load(os.path.join(self.logDir, 'optimG.pth'),
                       map_location=self.device))
        self.optimD.load_state_dict(
            torch.load(os.path.join(self.logDir, 'optimD.pth'),
                       map_location=self.device))

        with open(os.path.join(self.logDir, 'recent.log'), 'r') as f:
            runData = json.load(f)

        # NOTE: pickle can execute arbitrary code; only resume from a trusted logDir.
        with open(os.path.join(self.logDir, 'trainStats.pkl'), 'rb') as f:
            self.trainStats = pickle.load(f)

        self.currentEpoch = runData['epoch']
        self.totalBatches = runData['totalBatches']
Esempio n. 7
0
def gan_augment(x, y, seed, n_samples=None):
    """Generate class-conditional synthetic samples with a cached conditional GAN.

    If ``./gan_checkpoint_<seed>.pth`` exists, its generator is loaded.
    Otherwise a GAN is trained on ``(x, y)`` for 300 epochs and saved there.
    Then ``n_samples`` images are drawn with uniformly random labels in ``[0, 10)``.

    Args:
        x: training images, convertible with ``torch.Tensor(x)``.
        y: integer class labels, convertible with ``torch.LongTensor(y)``.
        seed: only used to name the checkpoint file (it does not seed any RNG).
        n_samples: number of samples to generate; defaults to ``len(x)``.

    Returns:
        ``(samples, labels)``: ``samples`` is the generator output reshaped to
        ``(-1, 28, 28, 1)`` and ``labels`` the matching integer labels.

    Raises:
        ValueError: if training is required but ``len(x)`` is smaller than the
            batch size. The original crashed with a NameError in that case.
    """
    if n_samples is None:
        n_samples = len(x)

    lr = 3e-4
    num_ep = 300
    z_dim = 100
    n_classes = 10  # labels of generated samples are drawn from [0, n_classes)
    model_path = "./gan_checkpoint_%d.pth" % seed

    device = "cuda" if torch.cuda.is_available() else "cpu"
    G = Generator(z_dim).to(device)
    D = Discriminator(z_dim).to(device)
    bce_loss = nn.BCELoss()
    # G uses a 3x larger learning rate than D.
    G_optim = optim.Adam(G.parameters(), lr=lr * 3, betas=(0.5, 0.999))
    D_optim = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    batch = 64
    train_x = torch.Tensor(x)
    train_labels = torch.LongTensor(y)

    if os.path.exists(model_path):
        print("load trained GAN...")
        # map_location lets a checkpoint written on a GPU load on a CPU-only host.
        state = torch.load(model_path, map_location=device)
        G.load_state_dict(state["G"])
    else:
        n_batches = len(train_x) // batch
        if n_batches == 0:
            raise ValueError(
                "need at least %d training samples to train the GAN, got %d"
                % (batch, len(train_x)))
        print("training a new GAN...")
        # Targets never change, so build them once.
        y_real = torch.ones(batch).to(device)
        y_fake = torch.zeros(batch).to(device)
        for epoch in range(num_ep):
            for _ in range(n_batches):
                # Draw a random batch (with replacement).
                idx = np.random.choice(len(train_x), batch)
                batch_x = train_x[idx].to(device)
                batch_labels = train_labels[idx].to(device)

                # train D with real images
                D.zero_grad()
                D_real_out = D(batch_x, batch_labels).squeeze()
                D_real_loss = bce_loss(D_real_out, y_real)

                # train D with fake images
                z_ = torch.randn((batch, z_dim)).view(-1, z_dim, 1,
                                                      1).to(device)
                fake_labels = torch.randint(0, n_classes, (batch, )).to(device)
                G_out = G(z_, fake_labels)

                # NOTE(review): G_out is not detached, so D_loss.backward()
                # also fills G's grads; G.zero_grad() below clears them.
                D_fake_out = D(G_out, fake_labels).squeeze()
                D_fake_loss = bce_loss(D_fake_out, y_fake)
                D_loss = D_real_loss + D_fake_loss
                D_loss.backward()
                D_optim.step()

                # train G
                G.zero_grad()
                z_ = torch.randn((batch, z_dim)).view(-1, z_dim, 1,
                                                      1).to(device)
                fake_labels = torch.randint(0, n_classes, (batch, )).to(device)
                G_out = G(z_, fake_labels)
                D_out = D(G_out, fake_labels).squeeze()
                G_loss = bce_loss(D_out, y_real)
                G_loss.backward()
                G_optim.step()

            # Detach so the plotting helper gets a plain tensor without autograd history.
            plot2img(G_out[:50].detach().cpu())
            print("epoch: %d G_loss: %.2f D_loss: %.2f" %
                  (epoch, G_loss.item(), D_loss.item()))
        state = {"G": G.state_dict(), "D": D.state_dict()}
        torch.save(state, model_path)

    # Inference mode: normalization/dropout layers (if any) use eval
    # behavior, and sampling does not update running statistics.
    G.eval()
    with torch.no_grad():
        z_ = torch.randn((n_samples, z_dim)).view(-1, z_dim, 1, 1).to(device)
        fake_labels = torch.randint(0, n_classes, (n_samples, )).to(device)
        G_samples = G(z_, fake_labels)
        samples = G_samples.cpu().numpy().reshape((-1, 28, 28, 1))
    return samples, fake_labels.cpu().numpy()