Example #1
    def test(self):
        name = path.split('.')[0]
        max_value = -100.
        max_index = 0
        max_picture = None
        last = 1.
        sim_range = [1., 1.]
        for index in range(10000):
            #self.x.data = self.feed_interpolated_input(self.loader.get_batch())
            temp, suc = self.loader.load_batch_picture(self.batch, self.imsize)

            if (not suc):
                break
            if (temp.size()[0] != self.batch):  # end of video
                break
            self.x.data = temp
            self.fx = self.D(self.x)
            self.fx = self.fx.cpu().data[0][0]
            print(index, "data", self.fx)
            if (self.fx > max_value and self.fx <= sim_range[1]
                    and self.fx >= sim_range[0]):
                max_value = self.fx
                max_index = index
                max_picture = temp
            utils.save_image_single(
                temp, "test/test_index_score_{}_{}.jpg".format(name, index))
            sim_range[0] = min(last, self.fx) - abs(last - self.fx)
            sim_range[1] = max(last, self.fx) + abs(last - self.fx)
            last = self.fx
        print(max_index, "max", max_value)
        utils.save_image_single(max_picture,
                                "test/max_index_score_{}.jpg".format(name))
Example #2
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
        
        for step in range(2, self.max_resl+1+5):
            for iter in tqdm(range(0,(self.trns_tick*2+self.stab_tick*2)*self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter+1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack%(ceil(len(self.loader.dataset))))

                # resolution scheduler.
                self.resl_scheduler()
                
                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)
               
                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())
                
                loss_d = self.mse(self.fx.squeeze(), self.real_label) +  self.mse(self.fx_tilde, self.fake_label)
                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()
                
                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.item(), loss_g.item(), self.resl, int(pow(2,floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo/model')

                # save image grid.
                if self.globalIter%self.config.save_img_every == 0:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    utils.mkdir('repo/save/grid')
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))
                    utils.mkdir('repo/save/resl_{}'.format(int(floor(self.resl))))
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)),int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))

                # tensorboard visualization.
                if self.use_tb:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.item(), self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.item(), self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.globalIter)
                    '''IMAGE GRID
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)
                    '''
print('load checkpoint from ... {}'.format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
test_model.module.load_state_dict(checkpoint['state_dict'])

# create folder.
for i in range(1000):
    name = os.path.join(config.output_dir, 'interpolation/try_{}'.format(i))
    if not os.path.exists(name):
        os.makedirs(name)
        break

# interpolate between two noise vectors (z1, z2).
z_intp = torch.FloatTensor(1, config.nz)
z1 = torch.FloatTensor(1, config.nz).normal_(0.0, 1.0)
z2 = torch.FloatTensor(1, config.nz).normal_(0.0, 1.0)
if use_cuda:
    z_intp = z_intp.cuda()
    z1 = z1.cuda()
    z2 = z2.cuda()
    test_model = test_model.cuda()

z_intp = Variable(z_intp)

for i in range(1, n_intp + 1):
    alpha = i / float(n_intp + 1)
    # out-of-place blend so z1 and z2 are not mutated between iterations
    z_intp.data = z1 * alpha + z2 * (1.0 - alpha)
    fake_im = test_model.module(z_intp)
    fname = os.path.join(name, '_intp{}.jpg'.format(i))
    utils.save_image_single(fake_im.data, fname, imsize=pow(2, config.max_resl))
    print('saved {}-th interpolated image ...'.format(i))
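For reference, the interpolation schedule above, decoupled from the model, is just a sequence of convex blends of the two latent vectors. A minimal sketch (with a hypothetical n_intp and nz, not taken from the original config) could be:

import torch

n_intp = 4                                # hypothetical number of in-between steps
z1 = torch.randn(1, 512)                  # assuming nz = 512 purely for illustration
z2 = torch.randn(1, 512)
for i in range(1, n_intp + 1):
    alpha = i / float(n_intp + 1)         # 1/5, 2/5, 3/5, 4/5
    z = alpha * z1 + (1.0 - alpha) * z2   # out-of-place blend keeps z1/z2 intact
    # feed z through the generator here, e.g. test_model.module(z)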
    def train(self):
        # noise for test.
        sample_batch = self.loader.get_batch()
        print(sample_batch)
        self.z_test = sample_batch['encods']
        print("0self.z_test")
        print(self.z_test)
        print("1self.z_test")
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=False)

        self.z_test.data.resize_(self.loader.batchsize, self.nz)

        for step in range(2, self.max_resl + 1 + 5):
            for iter in tqdm(
                    range(0, (self.trns_tick * 2 + self.stab_tick * 2) *
                          self.TICK, self.loader.batchsize)):
                sample_batch = self.loader.get_batch()

                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # resolution scheduler.
                self.resl_scheduler()

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(
                    sample_batch['image'])
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z = sample_batch['encods']
                print("2self.z")
                print(self.z)
                print("3self.z")
                if self.use_cuda:
                    self.z = self.z.cuda()
                self.z = Variable(self.z, volatile=False)
                self.z.data.resize_(self.loader.batchsize, self.nz)
                self.x_tilde = self.G(self.z.float())

                self.fx = self.D(self.x.float())
                self.fx_tilde = self.D(self.x_tilde.detach())
                loss_d = self.mse(self.fx, self.real_label) + self.mse(
                    self.fx_tilde, self.fake_label)

                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde, self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.globalTick, self.stack,
                    len(self.loader.dataset), loss_d.data[0], loss_g.data[0],
                    self.resl, int(pow(2, floor(self.resl))), self.phase,
                    self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo_enco/model')

                # save image grid.
                if self.globalIter % self.config.save_img_every == 0:
                    x_test = self.G(self.z_test.float())
                    os.system('mkdir -p repo_enco/save/grid')
                    utils.save_image_grid(
                        x_test.data,
                        'repo_enco/save/grid/{}_{}_G{}_D{}.jpg'.format(
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))
                    os.system('mkdir -p repo_enco/save/resl_{}'.format(
                        int(floor(self.resl))))
                    utils.save_image_single(
                        x_test.data,
                        'repo_enco/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))

                # tensorboard visualization.
                if self.use_tb:
                    x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl',
                                       int(pow(2, floor(self.resl))),
                                       self.globalIter)
                    self.tb.add_image_grid(
                        'grid/x_test', 4,
                        utils.adjust_dyn_range(x_test.data.float(), [-1, 1],
                                               [0, 1]), self.globalIter)
                    self.tb.add_image_grid(
                        'grid/x_tilde', 4,
                        utils.adjust_dyn_range(self.x_tilde.data.float(),
                                               [-1, 1], [0, 1]),
                        self.globalIter)
                    self.tb.add_image_grid(
                        'grid/x_intp', 4,
                        utils.adjust_dyn_range(self.x.data.float(), [-1, 1],
                                               [0, 1]), self.globalIter)
Example #5
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()

        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        for step in range(2, self.max_resl + 1 + 5):
            for iter in tqdm(
                    range(
                        0,
                        (self.trns_tick * 2 + self.stab_tick * 2) * self.TICK,
                        self.loader.batchsize,
                    )):
                if self.just_passed:
                    continue
                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # resolution scheduler.
                self.resl_scheduler()
                if self.skip and self.previous_phase == self.phase:
                    continue
                self.skip = False
                if self.globalIter % self.accelerate != 0:
                    continue

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(
                    self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize,
                                    self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())

                loss_d = self.mse(self.fx.squeeze(),
                                  self.real_label) + self.mse(
                                      self.fx_tilde, self.fake_label)

                ### gradient penalty
                gradients = torch_grad(
                    outputs=self.fx,
                    inputs=self.x,
                    grad_outputs=torch.ones(self.fx.size()).cuda()
                    if self.use_cuda else torch.ones(self.fx.size()),
                    create_graph=True,
                    retain_graph=True,
                )[0]
                gradient_penalty = self._gradient_penalty(gradients)
                loss_d += gradient_penalty

                ### epsilon penalty
                epsilon_penalty = (self.fx**2).mean()
                loss_d += epsilon_penalty * self.wgan_epsilon
                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                if (iter - 1) % 10 == 0:  # log every 10 iterations
                    log_msg = " [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]".format(
                        self.epoch,
                        self.globalTick,
                        self.stack,
                        len(self.loader.dataset),
                        loss_d.item(),
                        loss_g.item(),
                        self.resl,
                        int(pow(2, floor(self.resl))),
                        self.phase,
                        self.complete["gen"],
                        self.complete["dis"],
                        self.lr,
                    )
                    tqdm.write(log_msg)

                # save model.
                self.snapshot("repo/model")

                # save image grid.
                if self.globalIter % self.config.save_img_every == 0:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    utils.mkdir("repo/save/grid")
                    utils.mkdir("repo/save/grid_real")
                    utils.save_image_grid(
                        x_test.data,
                        "repo/save/grid/{}_{}_G{}_D{}.jpg".format(
                            int(self.globalIter / self.config.save_img_every),
                            self.phase,
                            self.complete["gen"],
                            self.complete["dis"],
                        ),
                    )
                    if self.globalIter % (self.config.save_img_every * 10) == 0:
                        utils.save_image_grid(
                            self.x.data,
                            "repo/save/grid_real/{}_{}_G{}_D{}.jpg".format(
                                int(self.globalIter /
                                    self.config.save_img_every),
                                self.phase,
                                self.complete["gen"],
                                self.complete["dis"],
                            ),
                        )
                    utils.mkdir("repo/save/resl_{}".format(
                        int(floor(self.resl))))
                    utils.mkdir("repo/save/resl_{}_real".format(
                        int(floor(self.resl))))
                    utils.save_image_single(
                        x_test.data,
                        "repo/save/resl_{}/{}_{}_G{}_D{}.jpg".format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every),
                            self.phase,
                            self.complete["gen"],
                            self.complete["dis"],
                        ),
                    )
                    if self.globalIter % (self.config.save_img_every * 10) == 0:
                        utils.save_image_single(
                            self.x.data,
                            "repo/save/resl_{}_real/{}_{}_G{}_D{}.jpg".format(
                                int(floor(self.resl)),
                                int(self.globalIter /
                                    self.config.save_img_every),
                                self.phase,
                                self.complete["gen"],
                                self.complete["dis"],
                            ),
                        )

                # tensorboard visualization.
                if self.use_tb:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    self.tb.add_scalar("data/loss_g", loss_g.item(),
                                       self.globalIter)
                    self.tb.add_scalar("data/loss_d", loss_d.item(),
                                       self.globalIter)
                    self.tb.add_scalar("tick/lr", self.lr, self.globalIter)
                    self.tb.add_scalar("tick/cur_resl",
                                       int(pow(2, floor(self.resl))),
                                       self.globalIter)
                    """IMAGE GRID
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)
                    """
            self.just_passed = False
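The snippet above calls a self._gradient_penalty(gradients) helper that is not shown. A minimal WGAN-GP-style sketch of what such a helper could compute (the gp_lambda coefficient and its default are assumptions, not the original code) is:

import torch

def gradient_penalty(gradients: torch.Tensor, gp_lambda: float = 10.0) -> torch.Tensor:
    """Standard WGAN-GP penalty: push each sample's gradient norm toward 1."""
    flat = gradients.view(gradients.size(0), -1)
    return gp_lambda * ((flat.norm(2, dim=1) - 1.0) ** 2).mean()

# usage with the 'gradients' tensor computed via torch_grad above:
# loss_d = loss_d + gradient_penalty(gradients)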
Example #6
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        for step in range(2, self.max_resl + 1 + 5):
            for iter in tqdm(
                    range(0, (self.trns_tick * 2 + self.stab_tick * 2) *
                          self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # resolution scheduler.
                self.resl_scheduler()

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(
                    self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize,
                                    self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())

                loss_d = self.mse(self.fx.squeeze(), self.real_label) + \
                                self.mse(self.fx_tilde.squeeze(), self.fake_label)

                # GP
                r = torch.rand_like(self.x)
                self.x_hat = torch.autograd.Variable(
                    r * self.x + (1 - r) * self.x_tilde.detach(),
                    requires_grad=True)
                self.fx_hat = self.D(self.x_hat)
                gradients = torch.autograd.grad(outputs=self.fx_hat,
                                                inputs=self.x_hat,
                                                grad_outputs=torch.ones_like(
                                                    self.fx_hat),
                                                create_graph=True,
                                                retain_graph=True,
                                                only_inputs=True)[0]
                gradients = gradients.view(gradients.size(0), -1)
                gradient_penalty = (
                    (gradients.norm(2, dim=1) - 1)**2).mean() * self.gp_lambda

                # DP
                drift_penalty = (self.fx.norm(2, dim=1)**
                                 2).mean() * self.dp_epsilon

                if self.config.loss == 'WGAN':
                    pass
                elif self.config.loss == 'WGAN-GP':
                    loss_d += gradient_penalty
                elif self.config.loss == 'WGAN-DP':
                    loss_d += drift_penalty
                elif self.config.loss == 'PG-GAN':
                    loss_d += (gradient_penalty + drift_penalty)
                else:
                    raise NotImplementedError

                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.globalTick, self.stack,
                    len(self.loader.dataset), loss_d.item(), loss_g.item(),
                    self.resl, int(pow(2, floor(self.resl))), self.phase,
                    self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('log/model')

                # save image grid.
                if self.globalIter % self.config.save_img_every == 0:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    utils.save_image_grid(
                        x_test.data, 'log/save/grid/{}_{}_G{}_D{}.jpg'.format(
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))
                    utils.save_image_single(
                        x_test.data,
                        'log/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))

                # tensorboard visualization.
                if self.use_tb:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.item(),
                                       self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.item(),
                                       self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl',
                                       int(pow(2, floor(self.resl))),
                                       self.globalIter)
                    '''IMAGE GRID
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)
                    '''
Example #7
    def train(self):
        # noise for test
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)

        for step in range(0, self.max_resl + 1 + 5):
            for iter in tqdm(range(0, (self.trns_tick * 2 + self.stab_tick * 2) * self.TICK, self.loader.batchsize)):
                self.global_iter = self.global_iter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack % (ceil(len(self.loader.dataset))))

                # Resolution scheduler
                self.resl_scheduler()

                # Zero the gradients
                self.G.zero_grad()
                self.D.zero_grad()

                # Update discriminator
                self.x.data = self.feed_interpolated_input(self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())
                real_loss = self.criterion(torch.squeeze(self.fx), self.real_label)
                fake_loss = self.criterion(torch.squeeze(self.fx_tilde), self.fake_label)
                loss_d = real_loss + fake_loss

                # Compute gradients and apply update to parameters
                loss_d.backward()
                self.opt_d.step()

                # Update generator
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.criterion(torch.squeeze(fx_tilde), self.real_label.detach())
                
                # Compute gradients and apply update to parameters
                loss_g.backward()
                self.opt_g.step()

                # Log information
                log_msg = ' [epoch:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch,
                    self.global_tick,
                    self.stack,
                    len(self.loader.dataset),
                    loss_d.data[0],
                    loss_g.data[0],
                    self.resl,
                    int(pow(2, floor(self.resl))),
                    self.phase,
                    self.complete['gen'],
                    self.complete['dis'],
                    self.lr)
                tqdm.write(log_msg)

                # Save the model
                self.snapshot('./repo/model')

                # Save the image grid
                if self.global_iter % self.config.save_img_every == 0:
                    x_test = self.G(self.z_test)
                    os.system('mkdir -p repo/save/grid')
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.global_iter / self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))
                    os.system('mkdir -p repo/save/resl_{}'.format(int(floor(self.resl))))
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)), int(self.global_iter / self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))

                # Tensorboard visualization
                if self.use_tb:
                    x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0], self.global_iter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0], self.global_iter)
                    self.tb.add_scalar('tick/lr', self.lr, self.global_iter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.global_iter)
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1, 1], [0, 1]), self.global_iter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1, 1], [0, 1]), self.global_iter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1, 1], [0, 1]), self.global_iter)
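The method above relies on self.criterion, self.real_label, and self.fake_label being prepared elsewhere. Under an LSGAN-style reading (consistent with the MSE loss used in the other examples), a minimal sketch of that setup, with hypothetical values, could be:

import torch
import torch.nn as nn

# Hypothetical setup matching how criterion/real_label/fake_label are used above.
batchsize = 4                                   # assumed; normally self.loader.batchsize
criterion = nn.MSELoss()                        # LSGAN-style least-squares loss
real_label = torch.ones(batchsize)              # targets for real samples / generator updates
fake_label = torch.zeros(batchsize)             # targets for generated samples

fx = torch.randn(batchsize, 1, 1, 1)            # stand-in for D(x)
loss_real = criterion(torch.squeeze(fx), real_label)
print(loss_real.item())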
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
        if self.use_captions:
            test_caps_set = False
            self.caps_test = torch.FloatTensor(self.loader.batchsize, self.ncap)
            if self.use_cuda:
                self.caps_test = self.caps_test.cuda()
            self.caps_test = Variable(self.caps_test, volatile=True)
        
        
        for step in range(2, self.max_resl+1+5):
            for iter in tqdm(range(0,(self.trns_tick*2+self.stab_tick*2)*self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter+1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack%(ceil(len(self.loader.dataset))))

                # resolution scheduler.
                self.resl_scheduler()
                
                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                if self.use_captions:
                    batch_imgs, batch_caps = self.loader.get_batch()
                    if self.use_cuda:
                        batch_caps = batch_caps.cuda()
                    self.caps.data = batch_caps
                    if not test_caps_set:
                        self.caps_test.data = batch_caps
                        test_caps_set = True
                else:
                    batch_imgs, _ = self.loader.get_batch()
                self.x.data = self.feed_interpolated_input(batch_imgs)
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
                if not self.use_captions:
                    self.x_tilde = self.G(self.z)
                else:
                    self.x_tilde = self.G(self.z, self.caps)
                if not self.use_captions:
                    self.fx = self.D(self.x)
                    self.fx_tilde = self.D(self.x_tilde.detach())
                else:
                    self.fx = self.D(self.x, self.caps)
                    self.fx_tilde = self.D(self.x_tilde.detach(), self.caps)

                if self.gan_type == 'lsgan':
                    loss_d = self.mse(self.fx, self.real_label) + self.mse(self.fx_tilde, self.fake_label)
                elif self.gan_type == 'wgan-gp':
                    D_real_loss = -torch.mean(self.fx)
                    D_fake_loss = torch.mean(self.fx_tilde)

                    if self.use_cuda:
                        alpha = torch.rand(self.x.size()).cuda()
                    else:
                        alpha = torch.rand(self.x.size())

                    x_hat = Variable(alpha * self.x.data + (1 - alpha) * self.x_tilde.data, requires_grad=True)

                    pred_hat = self.D(x_hat)

                    if self.use_cuda:
                        gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(),
                                     create_graph=True, retain_graph=True, only_inputs=True)[0]
                    else:
                        gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()),
                                         create_graph=True, retain_graph=True, only_inputs=True)[0]

                    gradient_penalty = self.gp_lambda * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()  # 'lambda' is a reserved word in Python; coefficient attribute name assumed

                    loss_d = D_real_loss + D_fake_loss + gradient_penalty

                loss_d.backward()
                self.opt_d.step()

                # update generator.
                if not self.use_captions:
                    fx_tilde = self.D(self.x_tilde)
                else:
                    fx_tilde = self.D(self.x_tilde, self.caps)

                if self.gan_type == 'lsgan':
                    loss_g = self.mse(fx_tilde, self.real_label.detach())
                elif self.gan_type == 'wgan-gp':
                    loss_g = -torch.mean(fx_tilde)

                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.data[0], loss_g.data[0], self.resl, int(pow(2,floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo/model')

                # save image grid.
                if self.globalIter%self.config.save_img_every == 0:
                    if not self.use_captions:
                        x_test = self.G(self.z_test)
                    else:
                        x_test = self.G(self.z_test, self.caps_test)
                    os.system('mkdir -p repo/save/grid')
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))
                    os.system('mkdir -p repo/save/resl_{}'.format(int(floor(self.resl))))
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)),int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))


                # tensorboard visualization.
                if self.use_tb:
                    if not self.use_captions:
                        x_test = self.G(self.z_test)
                    else:
                        x_test = self.G(self.z_test, self.caps_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0], self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0], self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.globalIter)
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)
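utils.adjust_dyn_range(x, [-1, 1], [0, 1]) appears throughout the TensorBoard branches in these examples. Assuming it is a simple linear range remap (the original utils module is not shown here), a sketch could be:

import torch

def adjust_dyn_range(x: torch.Tensor, drange_in, drange_out) -> torch.Tensor:
    """Assumed behaviour: linearly remap values from drange_in to drange_out,
    e.g. generator output in [-1, 1] to a displayable [0, 1] range."""
    in_min, in_max = drange_in
    out_min, out_max = drange_out
    scale = (out_max - out_min) / float(in_max - in_min)
    return (x - in_min) * scale + out_min

print(adjust_dyn_range(torch.tensor([-1.0, 0.0, 1.0]), [-1, 1], [0, 1]))  # tensor([0.0000, 0.5000, 1.0000])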
Example #9
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        # for step in range(2, self.max_resolution+1+5):
        # for iter in range(0,(self.transition_tick*2+self.stablize_tick*2)*self.TICK, self.loader.batchsize):
        final_step = 0
        while True:
            self.globalIter = self.globalIter + 1
            self.stack = self.stack + self.loader.batchsize
            if self.stack > ceil(len(self.loader.dataset)):
                self.epoch = self.epoch + 1
                self.stack = int(self.stack % (ceil(len(self.loader.dataset))))

            # resolution scheduler.
            sched_results = self.resolution_scheduler()

            # zero gradients.
            self.G.zero_grad()
            self.D.zero_grad()

            # update discriminator.
            self.x.data = self.feed_interpolated_input(self.loader.get_batch())
            if self.flag_add_noise:
                self.x = self.add_noise(self.x)
            self.z.data.resize_(self.loader.batchsize,
                                self.nz).normal_(0.0, 1.0)
            self.x_tilde = self.G(self.z)

            self.fx = self.D(self.x)
            self.fx_tilde = self.D(self.x_tilde.detach())

            loss_d = self.mse(self.fx.squeeze(), self.real_label) + \
                                self.mse(self.fx_tilde, self.fake_label)
            loss_d.backward()
            self.opt_d.step()

            # update generator.
            fx_tilde = self.D(self.x_tilde)
            loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
            loss_g.backward()
            self.opt_g.step()
            if self.globalIter % self.config.freq_print == 0:
                # logging.
                log_msg = sched_results[
                    'ticked'] + ' [E:{0}][T:{1}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resolution:{7:4}][{8}]'.format(
                        self.epoch, self.globalTick, self.stack,
                        len(self.loader.dataset), loss_d.item(),
                        loss_g.item(), self.resolution,
                        int(pow(2, floor(self.resolution))), self.phase,
                        self.complete['gen'], self.complete['dis'], self.lr)
                if hasattr(self, 'fadein') and self.fadein['dis'] is not None:
                    log_msg += '|D-Alpha:{:0.2f}'.format(
                        self.fadein['dis'].alpha)

                if hasattr(self, 'fadein') and self.fadein['gen'] is not None:
                    log_msg += '|G-Alpha:{:0.2f}'.format(
                        self.fadein['gen'].alpha)

                print(log_msg)
            if self.phase == 'final':
                final_step += 1
                if final_step > self.config.final_steps:
                    self.snapshot('repo/model')
                    break
            # tqdm.write(log_msg)

            # save model.
            self.snapshot('repo/model')

            # save image grid.
            if self.globalIter % self.config.save_img_every == 0:
                with torch.no_grad():
                    x_test = self.G(self.z_test)
                utils.mkdir('repo/save/grid')
                utils.save_image_grid(
                    x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(
                        int(self.globalIter / self.config.save_img_every),
                        self.phase, self.complete['gen'],
                        self.complete['dis']))
                utils.mkdir('repo/save/resolution_{}'.format(
                    int(floor(self.resolution))))
                utils.save_image_single(
                    x_test.data,
                    'repo/save/resolution_{}/{}_{}_G{}_D{}.jpg'.format(
                        int(floor(self.resolution)),
                        int(self.globalIter / self.config.save_img_every),
                        self.phase, self.complete['gen'],
                        self.complete['dis']))
                # import ipdb; ipdb.set_trace()
            # tensorboard visualization.
            if self.use_tb:
                with torch.no_grad():
                    x_test = self.G(self.z_test)
                self.tb.add_scalar('data/loss_g', loss_g.item(),
                                   self.globalIter)
                self.tb.add_scalar('data/loss_d', loss_d.item(),
                                   self.globalIter)
                self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                self.tb.add_scalar('tick/cur_resolution',
                                   int(pow(2, floor(self.resolution))),
                                   self.globalIter)
                '''IMAGE GRID
                self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
                self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
                self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)
                '''
Example #10
    def train(self):

        # optimizer
        betas = (self.config.beta1, self.config.beta2)
        if self.optimizer == 'adam':
            self.opt_g = Adam(filter(lambda p: p.requires_grad,
                                     self.G.parameters()),
                              lr=self.config.lr,
                              betas=betas,
                              weight_decay=0.0)
            self.opt_d = Adam(filter(lambda p: p.requires_grad,
                                     self.D.parameters()),
                              lr=self.config.lr,
                              betas=betas,
                              weight_decay=0.0)

        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        for step in range(2, self.max_resl):
            for iter in tqdm(
                    range(0, (self.trns_tick * 2 + self.stab_tick * 2) *
                          self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # resolution scheduler.
                self.resl_scheduler()

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(
                    self.loader.get_batch())
                self.z.data.resize_(self.loader.batchsize,
                                    self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                fx = self.D(self.x)
                fx_tilde = self.D(self.x_tilde.detach())
                loss_d = self.mse(fx, self.real_label) + self.mse(
                    fx_tilde, self.fake_label)
                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde, self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.globalTick, self.stack,
                    len(self.loader.dataset), loss_d.data[0], loss_g.data[0],
                    self.resl, int(pow(2, floor(self.resl))), self.phase,
                    self.complete['gen'], self.complete['dis'])
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo/model')

                # save image grid.
                if self.globalIter % self.config.save_img_every == 0:
                    x_test = self.G(self.z_test)
                    os.system('mkdir -p repo/save/grid')
                    utils.save_image_grid(
                        x_test.data, 'repo/save/grid/{}.jpg'.format(
                            int(self.globalIter / self.config.save_img_every)))
                    os.system('mkdir -p repo/save/resl_{}'.format(
                        int(floor(self.resl))))
                    utils.save_image_single(
                        x_test.data, 'repo/save/resl_{}/{}.jpg'.format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every)))

                # tensorboard visualization.
                if self.use_tb:
                    x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('tick/globalTick', int(self.globalTick),
                                       self.globalIter)
                    self.tb.add_image_grid('grid/x_test', 4,
                                           x_test.data.float(),
                                           self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4,
                                           self.x_tilde.data.float(),
                                           self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 1,
                                           self.x.data.float(),
                                           self.globalIter)
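The utils.save_image_grid and utils.save_image_single helpers used throughout these examples are not shown. One plausible sketch built on torchvision (an assumption about behaviour, not the original utils module) is:

import torch
from torchvision.utils import save_image

def save_image_grid(x: torch.Tensor, path: str, nrow: int = 4) -> None:
    """Assumed behaviour: tile a batch into a grid and write it to disk,
    rescaling from the generator's output range for display."""
    save_image(x, path, nrow=nrow, normalize=True)

def save_image_single(x: torch.Tensor, path: str) -> None:
    """Assumed behaviour: save only the first image of the batch."""
    save_image(x[:1], path, normalize=True)

save_image_grid(torch.randn(16, 3, 64, 64), 'grid.jpg')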