예제 #1
0
    def train(self, dataset):
        """Run the full WGAN training loop over ``dataset``.

        For each epoch: perform ``n_critic`` discriminator updates per batch,
        then one generator update, tracking running means of both losses in a
        progress bar, and save a grid of samples generated from a fixed
        latent batch at the end of every epoch.
        """
        # Fixed latent batch so per-epoch sample grids are comparable.
        z = tf.constant(random.normal((FLAGS.n_samples, 1, 1, self.z_dim)))
        g_train_loss = metrics.Mean()
        d_train_loss = metrics.Mean()

        for epoch in range(self.epochs):
            bar = pbar(self.total_images, self.batch_size, epoch, self.epochs)
            for batch in dataset:
                # WGAN: n_critic discriminator steps per generator step.
                for _ in range(self.n_critic):
                    # BUG FIX: train_d was previously called twice per critic
                    # step (the first update's loss was discarded), doubling
                    # the effective number of discriminator updates.
                    d_loss = self.train_d(batch)
                    d_train_loss(d_loss)

                # BUG FIX: train_g was also called a second time after its
                # loss had been recorded, doubling the generator updates.
                g_loss = self.train_g()
                g_train_loss(g_loss)

                bar.postfix['g_loss'] = f'{g_train_loss.result():6.3f}'
                bar.postfix['d_loss'] = f'{d_train_loss.result():6.3f}'
                bar.update(self.batch_size)

            g_train_loss.reset_states()
            d_train_loss.reset_states()

            bar.close()
            del bar

            samples = self.generate_samples(z)
            image_grid = img_merge(samples, n_rows=8).squeeze()
            save_image_grid(image_grid, epoch + 1)
예제 #2
0
파일: trainer.py 프로젝트: mikanCan/PG-GAN
 def save_image(self, image, path):
     """Save an image grid under ``<save_dir>/images/<path>``.

     The filename encodes the save-interval index, training phase and the
     gen/dis completion percentages.
     """
     save_path = self.save_dir + 'images/' + path
     # BUG FIX: os.mkdir fails when an intermediate directory (e.g. 'images/')
     # does not exist, and the exists()-then-mkdir pattern is race-prone.
     # makedirs with exist_ok creates the whole chain atomically w.r.t. races.
     os.makedirs(save_path, exist_ok=True)
     utils.save_image_grid(
         image.data, save_path + '{}_{}_G{}_D{}.jpg'.format(
             int(self.globalIter / self.config.save_img_every), self.phase,
             self.complete['gen'], self.complete['dis']))
예제 #3
0
def run_training(args):
    """Initialize and run the full training process using the hyper-params in args.

    Supports 'dcgan' and 'wgan_gp' via ``args.gan_type``. Saves a sample image
    grid and model checkpoints at the frequencies given in the train config.
    """
    device, data_loader, train_config, generator, discriminator, optimizers = init_training(
        args)

    # generate a sample with fixed seed, and reset the seed to pseudo-random
    torch.manual_seed(42)
    z_sample = torch.randn(train_config['batch_size'],
                           train_config['latent_dim'], 1, 1).to(device)
    # BUG FIX: random.randint requires integer bounds; 1e10 is a float and
    # raises TypeError on modern Python versions.
    torch.manual_seed(random.randint(0, int(1e10)))

    # Loss function for DCGAN (wgan_gp computes its loss inside its step fn)
    if args.gan_type == 'dcgan':
        loss_fn = torch.nn.BCELoss().to(device)

    # Training
    print('Training:')
    for epoch in range(train_config['start_epoch'],
                       train_config['max_epoch'] + 1):
        for batch_idx, batch in enumerate(data_loader):

            if args.gan_type == 'dcgan':
                g_loss, d_loss = \
                    training_step_dcgan(batch, device, generator, discriminator, optimizers, train_config, loss_fn)
            elif args.gan_type == 'wgan_gp':
                # The generator is only updated every n_critic steps, so the
                # step may return None for the generator loss; keep the last
                # real value. (Renamed from `_`, which was being read.)
                maybe_g_loss, d_loss = \
                    training_step_wgan_gp(batch_idx, batch, device, train_config, generator, discriminator, optimizers)
                if maybe_g_loss is not None:
                    g_loss = maybe_g_loss

        print('\nEpoch {}/{}:\n'
              '  Discriminator loss={:.4f}\n'
              '  Generator loss={:.4f}'.format(epoch,
                                               train_config['max_epoch'],
                                               d_loss.item(), g_loss.item()))

        if epoch == 1 or epoch % train_config['sample_save_freq'] == 0:
            # Save sample
            gen_sample = generator(z_sample)
            save_image_grid(
                img_batch=gen_sample[:train_config['grid_size']**2].detach(
                ).cpu().numpy(),
                grid_size=train_config['grid_size'],
                epoch=epoch,
                img_path=os.path.join(
                    args.checkpoint_path, 'samples',
                    'checkpoint_ep{}_sample.png'.format(epoch)))
            print('Image sample saved.')

        if epoch == 1 or epoch % train_config['save_freq'] == 0:
            # Save checkpoint
            gen_path = os.path.join(args.checkpoint_path, 'weights',
                                    'checkpoint_ep{}_gen.pt'.format(epoch))
            disc_path = os.path.join(args.checkpoint_path, 'weights',
                                     'checkpoint_ep{}_disc.pt'.format(epoch))
            torch.save(generator.state_dict(), gen_path)
            torch.save(discriminator.state_dict(), disc_path)
            print('Checkpoint.')
예제 #4
0
    def train(self, dataset):
        """Run the full WGAN training loop over ``dataset``.

        For each epoch: perform ``n_critic`` discriminator updates per batch,
        then one generator update, and save a grid of samples generated from
        a fixed latent batch at the end of every epoch.
        """
        # Fixed latent batch so per-epoch sample grids are comparable.
        z = tf.constant(
            random.normal((self.params.n_samples, 1, 1, self.z_dim)))
        g_train_loss = metrics.Mean()
        d_train_loss = metrics.Mean()

        for epoch in range(self.epochs):
            for batch in dataset:
                # WGAN: n_critic discriminator steps per generator step.
                for _ in range(self.n_critic):
                    # BUG FIX: train_d was previously called twice per critic
                    # step (the first update's loss was discarded), doubling
                    # the effective number of discriminator updates.
                    d_loss = self.train_d(batch)
                    d_train_loss(d_loss)

                # BUG FIX: train_g was also called a second time after its
                # loss had been recorded, doubling the generator updates.
                g_loss = self.train_g()
                g_train_loss(g_loss)

            g_train_loss.reset_states()
            d_train_loss.reset_states()

            samples = self.generate_samples(z)
            image_grid = img_merge(samples, n_rows=8).squeeze()
            save_image_grid(image_grid, epoch + 1)
예제 #5
0
def evaluate_pred(config):
    """Evaluate predicted future frames for FutureGAN / baseline models.

    Loads the FutureGAN generator from checkpoint (or uses the 'CopyLast' /
    precomputed-prediction baselines), iterates sequentially over the test
    videos, predicts ``nframes_pred`` future frames from ``nframes_in`` input
    frames (recursively when ``config.deep_pred > 1``), computes the metrics
    named in ``config.metrics``, and saves sample frames, gifs and aggregated
    statistics to ``config.test_dir``.
    """

    # define directories
    model_name = config.model

    test_data_root = config.data_root
    if config.deep_pred > 1:
        test_dir = config.test_dir + '/' + config.experiment_name + '/deep-pred{}/'.format(
            config.deep_pred) + model_name
    else:
        test_dir = config.test_dir + '/' + config.experiment_name + '/pred/' + model_name
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
    sample_dir = test_dir + '/samples'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)

    nframes_in = config.nframes_in
    nframes_pred = config.nframes_pred * config.deep_pred
    nframes = nframes_in + nframes_pred
    img_size = int(config.resl)
    nworkers = 4

    # load model
    if config.model == 'FutureGAN':
        ckpt = torch.load(config.model_path)
        # model structure
        G = ckpt['G_structure']
        # load model parameters
        G.load_state_dict(ckpt['state_dict'])
        G.eval()
        G = G.module.model
        if use_cuda:
            G = G.cuda()
        print(' ... loading FutureGAN`s FutureGenerator from checkpoint: {}'.
              format(config.model_path))

    # load test dataset
    transform = transforms.Compose([
        transforms.Resize(size=(img_size, img_size),
                          interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])
    if config.model == 'FutureGAN' or config.model == 'CopyLast':
        dataset_gt = VideoFolder(video_root=test_data_root,
                                 video_ext=config.ext,
                                 nframes=nframes,
                                 loader=video_loader,
                                 transform=transform)
        dataloader_gt = DataLoader(
            dataset=dataset_gt,
            batch_size=config.batch_size,
            sampler=sampler.SequentialSampler(dataset_gt),
            num_workers=nworkers)
    else:
        dataset_gt = VideoFolder(video_root=test_data_root + '/in_gt',
                                 video_ext=config.ext,
                                 nframes=nframes,
                                 loader=video_loader,
                                 transform=transform)
        dataset_pred = VideoFolder(video_root=test_data_root + '/in_pred',
                                   video_ext=config.ext,
                                   nframes=nframes,
                                   loader=video_loader,
                                   transform=transform)
        dataloader_pred = DataLoader(
            dataset=dataset_pred,
            batch_size=config.batch_size,
            sampler=sampler.SequentialSampler(dataset_pred),
            num_workers=nworkers)
        dataloader_gt = DataLoader(
            dataset=dataset_gt,
            batch_size=config.batch_size,
            sampler=sampler.SequentialSampler(dataset_gt),
            num_workers=nworkers)
        data_iter_pred = iter(dataloader_pred)
    test_len = len(dataset_gt)
    data_iter_gt = iter(dataloader_gt)

    # save model structure to file
    if config.model == 'FutureGAN':
        # count model parameters
        nparams_g = count_model_params(G)
        with open(
                test_dir +
                '/model_structure_{}x{}.txt'.format(img_size, img_size),
                'w') as f:
            print('--------------------------------------------------', file=f)
            print('Sequences in test dataset: ', len(dataset_gt), file=f)
            print('Number of model parameters: ', file=f)
            print(nparams_g, file=f)
            print('--------------------------------------------------', file=f)
            print('Model structure: ', file=f)
            print(G, file=f)
            print('--------------------------------------------------', file=f)
            print(
                ' ... FutureGAN`s FutureGenerator has been loaded successfully from checkpoint ... '
            )
            # BUG FIX: typo 'struture' -> 'structure' in the log message.
            print(' ... saving model structure to {}'.format(f))

    # save test configuration
    with open(test_dir + '/eval_config.txt', 'w') as f:
        print('------------- test configuration -------------', file=f)
        for l, m in vars(config).items():
            print(('{}: {}').format(l, m), file=f)
        print(' ... loading test configuration ... ')
        print(' ... saving test configuration {}'.format(f))

    # define tensors
    if config.model == 'FutureGAN':
        print(' ... testing FutureGAN ...')
        if config.deep_pred > 1:
            print(
                ' ... recursively predicting {}x{} future frames from {} input frames ...'
                .format(config.deep_pred, config.nframes_pred, nframes_in))
        else:
            print(' ... predicting {} future frames from {} input frames ...'.
                  format(nframes_pred, nframes_in))
    z = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_in, img_size,
                          img_size))
    z_in = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_in, img_size,
                          img_size))
    x_pred = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_pred, img_size,
                          img_size))
    x = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes, img_size,
                          img_size))
    x_eval = Variable(
        torch.FloatTensor(config.batch_size, config.nc, nframes_pred, img_size,
                          img_size))

    # define tensors for evaluation
    if config.metrics is not None:
        print(' ... evaluating {} ...'.format(model_name))
        if 'ms_ssim' in config.metrics and img_size < 32:
            raise ValueError(
                'For calculating `ms_ssim`, your dataset must consist of images at least of size 32x32!'
            )

        metrics_values = {}
        for metric_name in config.metrics:
            metrics_values['{}_frames'.format(metric_name)] = torch.zeros_like(
                torch.FloatTensor(test_len, nframes_pred))
            metrics_values['{}_avg'.format(metric_name)] = torch.zeros_like(
                torch.FloatTensor(test_len, 1))
            print(' ... calculating {} ...'.format(metric_name))

    # test loop
    if config.metrics is not None:
        metrics_i_video = {}
        for metric_name in config.metrics:
            metrics_i_video['{}_i_video'.format(metric_name)] = 0

    i_save_video = 1
    i_save_gif = 1

    for step in tqdm(range(len(data_iter_gt))):

        # input frames
        x.data = next(data_iter_gt)
        x_eval.data = x.data[:, :, nframes_in:, :, :]
        z.data = x.data[:, :, :nframes_in, :, :]

        if use_cuda:
            x = x.cuda()
            x_eval = x_eval.cuda()
            z = z.cuda()
            x_pred = x_pred.cuda()

        # predict video frames
        # !!! TODO !!! for deep_pred > 1: correctly implemented only if nframes_in == nframes_pred
        if config.model == 'FutureGAN':
            z_in.data = z.data
            for i_deep_pred in range(0, config.deep_pred):
                x_pred[:z_in.size(0), :, i_deep_pred *
                       config.nframes_pred:(i_deep_pred *
                                            config.nframes_pred) +
                       config.nframes_pred, :, :] = G(z_in).detach()
                z_in.data = x_pred.data[:, :,
                                        i_deep_pred * config.nframes_pred:
                                        (i_deep_pred * config.nframes_pred) +
                                        config.nframes_pred, :, :]

        elif config.model == 'CopyLast':
            for i_baseline_frame in range(x_pred.size(2)):
                x_pred.data[:x.size(0), :,
                            i_baseline_frame, :, :] = x.data[:, :, nframes_in -
                                                             1, :, :]

        else:
            x_pred.data = next(data_iter_pred)[:x.size(0), :,
                                               nframes_in:, :, :]

        # calculate eval statistics
        if config.metrics is not None:
            for metric_name in config.metrics:
                calculate_metric = getattr(eval_metrics,
                                           'calculate_{}'.format(metric_name))

                for i_batch in range(x.size(0)):
                    for i_frame in range(nframes_pred):
                        metrics_values['{}_frames'.format(metric_name)][
                            metrics_i_video['{}_i_video'.format(metric_name)],
                            i_frame] = calculate_metric(
                                x_pred[i_batch, :, i_frame, :, :],
                                x_eval[i_batch, :, i_frame, :, :])
                        metrics_values['{}_avg'.format(metric_name)][
                            metrics_i_video['{}_i_video'.format(
                                metric_name)]] = torch.mean(
                                    metrics_values['{}_frames'.format(
                                        metric_name)][metrics_i_video[
                                            '{}_i_video'.format(metric_name)]])
                    metrics_i_video['{}_i_video'.format(
                        metric_name
                    )] = metrics_i_video['{}_i_video'.format(metric_name)] + 1

        # save frames
        # BUG FIX: `is not 0` compared identity with an int literal (undefined
        # behavior, SyntaxWarning on CPython >= 3.8); use `!= 0`.
        if config.save_frames_every != 0 and config.model == 'FutureGAN':
            if step % config.save_frames_every == 0 or step == 0:
                for i_save_batch in range(x.size(0)):
                    if not os.path.exists(
                            sample_dir +
                            '/in_gt/video{:04d}'.format(i_save_video)):
                        os.makedirs(sample_dir +
                                    '/in_gt/video{:04d}'.format(i_save_video))
                    if not os.path.exists(
                            sample_dir +
                            '/in_pred/video{:04d}'.format(i_save_video)):
                        os.makedirs(
                            sample_dir +
                            '/in_pred/video{:04d}'.format(i_save_video))
                    for i_save_z in range(z.size(2)):
                        save_image_grid(
                            z.data[i_save_batch, :,
                                   i_save_z, :, :].unsqueeze(0), sample_dir +
                            '/in_gt/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_z + 1,
                                    img_size, img_size), img_size, 1)
                        save_image_grid(
                            z.data[i_save_batch, :,
                                   i_save_z, :, :].unsqueeze(0), sample_dir +
                            '/in_pred/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_z + 1,
                                    img_size, img_size), img_size, 1)
                    for i_save_x_pred in range(x_pred.size(2)):
                        save_image_grid(
                            x_eval.data[i_save_batch, :,
                                        i_save_x_pred, :, :].unsqueeze(0),
                            sample_dir +
                            '/in_gt/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_x_pred +
                                    1 + nframes_in, img_size, img_size),
                            img_size, 1)
                        save_image_grid(
                            x_pred.data[i_save_batch, :,
                                        i_save_x_pred, :, :].unsqueeze(0),
                            sample_dir +
                            '/in_pred/video{:04d}/video{:04d}_frame{:04d}_R{}x{}.png'
                            .format(i_save_video, i_save_video, i_save_x_pred +
                                    1 + nframes_in, img_size, img_size),
                            img_size, 1)
                    i_save_video = i_save_video + 1

        # save gifs
        # BUG FIX: `is not 0` -> `!= 0` (see above).
        if config.save_gif_every != 0:
            if step % config.save_gif_every == 0 or step == 0:
                for i_save_batch in range(x.size(0)):
                    if not os.path.exists(
                            sample_dir +
                            '/in_gt/video{:04d}'.format(i_save_gif)):
                        os.makedirs(sample_dir +
                                    '/in_gt/video{:04d}'.format(i_save_gif))
                    if not os.path.exists(
                            sample_dir +
                            '/in_pred/video{:04d}'.format(i_save_gif)):
                        os.makedirs(sample_dir +
                                    '/in_pred/video{:04d}'.format(i_save_gif))
                    frames = []
                    for i_save_z in range(z.size(2)):
                        frames.append(
                            get_image_grid(
                                z.data[i_save_batch, :,
                                       i_save_z, :, :].unsqueeze(0), img_size,
                                1, config.in_border, config.npx_border))
                    for i_save_x_pred in range(x_pred.size(2)):
                        frames.append(
                            get_image_grid(
                                x_eval.data[i_save_batch, :,
                                            i_save_x_pred, :, :].unsqueeze(0),
                                img_size, 1, config.out_border,
                                config.npx_border))
                    imageio.mimsave(
                        sample_dir +
                        '/in_gt/video{:04d}/video{:04d}_R{}x{}.gif'.format(
                            i_save_gif, i_save_gif, img_size, img_size),
                        frames)
                    frames = []
                    for i_save_z in range(z.size(2)):
                        frames.append(
                            get_image_grid(
                                z.data[i_save_batch, :,
                                       i_save_z, :, :].unsqueeze(0), img_size,
                                1, config.in_border, config.npx_border))
                    for i_save_x_pred in range(x_pred.size(2)):
                        frames.append(
                            get_image_grid(
                                x_pred.data[i_save_batch, :,
                                            i_save_x_pred, :, :].unsqueeze(0),
                                img_size, 1, config.out_border,
                                config.npx_border))
                    imageio.mimsave(
                        sample_dir +
                        '/in_pred/video{:04d}/video{:04d}_R{}x{}.gif'.format(
                            i_save_gif, i_save_gif, img_size, img_size),
                        frames)
                    i_save_gif = i_save_gif + 1

    # BUG FIX: `is not 0` -> `!= 0` (twice, see above).
    if config.save_frames_every != 0 and config.model == 'FutureGAN':
        print(' ... saving video frames to dir: {}'.format(sample_dir))
        if config.save_gif_every != 0:
            print(' ... saving gifs to dir: {}'.format(sample_dir))

    # calculate and save mean eval statistics
    if config.metrics is not None:
        metrics_mean_values = {}
        for metric_name in config.metrics:
            metrics_mean_values['{}_frames'.format(metric_name)] = torch.mean(
                metrics_values['{}_frames'.format(metric_name)], 0)
            metrics_mean_values['{}_avg'.format(metric_name)] = torch.mean(
                metrics_values['{}_avg'.format(metric_name)], 0)
            torch.save(
                metrics_mean_values['{}_frames'.format(metric_name)],
                os.path.join(test_dir, '{}_frames.pt'.format(metric_name)))
            torch.save(metrics_mean_values['{}_avg'.format(metric_name)],
                       os.path.join(test_dir, '{}_avg.pt'.format(metric_name)))

        print(' ... saving evaluation statistics to dir: {}'.format(test_dir))
예제 #6
0
    def train(self):
        """Progressive-growing GAN training loop (LSGAN-style MSE losses).

        Iterates over resolution steps, alternating discriminator and
        generator updates, logging per-iteration, snapshotting the model,
        and periodically saving image grids / tensorboard summaries.
        """
        # Fixed noise batch used only for generating test/sample images.
        # NOTE(review): Variable(volatile=True) was removed in PyTorch >= 0.4
        # (use torch.no_grad() instead) — confirm the targeted torch version.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
        
        # Outer loop: one step per resolution stage (plus transition/stabilize
        # padding); inner loop: iterations sized by tick counts * batch size.
        for step in range(2, self.max_resl+1+5):
            for iter in tqdm(range(0,(self.trns_tick*2+self.stab_tick*2)*self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter+1
                # Track images seen; roll over into a new epoch when a full
                # pass of the dataset has been consumed.
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack%(ceil(len(self.loader.dataset))))

                # reslolution scheduler.
                self.resl_scheduler()
                
                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator: real batch (optionally noised) vs.
                # detached fake batch, MSE against real/fake labels.
                self.x.data = self.feed_interpolated_input(self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)
               
                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())
                
                loss_d = self.mse(self.fx.squeeze(), self.real_label) +  self.mse(self.fx_tilde, self.fake_label)
                loss_d.backward()
                self.opt_d.step()

                # update generator: push D(G(z)) toward the real label.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()
                
                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.item(), loss_g.item(), self.resl, int(pow(2,floor(self.resl))), self.phase, int(pow(2,floor(self.resl))) if False else self.complete['gen'], self.complete['dis'], self.lr) if False else ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.item(), loss_g.item(), self.resl, int(pow(2,floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo/model')

                # save image grid.
                if self.globalIter%self.config.save_img_every == 0:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    utils.mkdir('repo/save/grid')
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))
                    utils.mkdir('repo/save/resl_{}'.format(int(floor(self.resl))))
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)),int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))

                # tensorboard visualization.
                # NOTE(review): loss_g[0] / loss_d[0] index a 0-dim loss
                # tensor — this fails on PyTorch >= 0.5 (use .item()); verify.
                if self.use_tb:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g[0].item(), self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d[0].item(), self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.globalIter)
                    # NOTE(review): the triple-quoted string opened on the next
                    # line is never closed within this view; everything below
                    # it is inert string content (an older copy of train), not
                    # executed code. Preserved byte-for-byte.
                    '''IMAGE GRID
    def train(self):
        # noise for test.
        sample_batch = self.loader.get_batch()
        print(sample_batch)
        self.z_test = sample_batch['encods']
        print("0self.z_test")
        print(self.z_test)
        print("1self.z_test")
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=False)

        self.z_test.data.resize_(self.loader.batchsize, self.nz)

        for step in range(2, self.max_resl + 1 + 5):
            for iter in tqdm(
                    range(0, (self.trns_tick * 2 + self.stab_tick * 2) *
                          self.TICK, self.loader.batchsize)):
                sample_batch = self.loader.get_batch()

                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # reslolution scheduler.
                self.resl_scheduler()

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(
                    sample_batch['image'])
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z = sample_batch['encods']
                print("2self.z")
                print(self.z_test)
                print("3self.z")
                if self.use_cuda:
                    self.z = self.z.cuda()
                self.z = Variable(self.z, volatile=False)
                self.z.data.resize_(self.loader.batchsize, self.nz)
                self.x_tilde = self.G(self.z.float())

                self.fx = self.D(self.x.float())
                self.fx_tilde = self.D(self.x_tilde.detach())
                loss_d = self.mse(self.fx, self.real_label) + self.mse(
                    self.fx_tilde, self.fake_label)

                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde, self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.globalTick, self.stack,
                    len(self.loader.dataset), loss_d.data[0], loss_g.data[0],
                    self.resl, int(pow(2, floor(self.resl))), self.phase,
                    self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo_enco/model')

                # save image grid.
                if self.globalIter % self.config.save_img_every == 0:
                    x_test = self.G(self.z_test.float())
                    os.system('mkdir -p repo_enco/save/grid')
                    utils.save_image_grid(
                        x_test.data,
                        'repo_enco/save/grid/{}_{}_G{}_D{}.jpg'.format(
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))
                    os.system('mkdir -p repo_enco/save/resl_{}'.format(
                        int(floor(self.resl))))
                    utils.save_image_single(
                        x_test.data,
                        'repo_enco/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))

                # tensorboard visualization.
                if self.use_tb:
                    x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl',
                                       int(pow(2, floor(self.resl))),
                                       self.globalIter)
                    self.tb.add_image_grid(
                        'grid/x_test', 4,
                        utils.adjust_dyn_range(x_test.data.float(), [-1, 1],
                                               [0, 1]), self.globalIter)
                    self.tb.add_image_grid(
                        'grid/x_tilde', 4,
                        utils.adjust_dyn_range(self.x_tilde.data.float(),
                                               [-1, 1], [0, 1]),
                        self.globalIter)
                    self.tb.add_image_grid(
                        'grid/x_intp', 4,
                        utils.adjust_dyn_range(self.x.data.float(), [-1, 1],
                                               [0, 1]), self.globalIter)
예제 #8
0
    def train(self):
        """Progressive-growing GAN training loop (resume-aware variant).

        For each resolution step, iterates over (trns_tick*2 + stab_tick*2)
        * TICK samples' worth of batches: trains D with an MSE (LSGAN-style)
        adversarial loss plus a gradient penalty and an epsilon penalty,
        then trains G, with periodic logging, model snapshots, image grids
        and TensorBoard scalars.
        """
        # noise for test.
        # Fixed latent batch, re-used every time a sample grid is rendered
        # so successive grids are directly comparable.
        self.z_test = tf.constant if False else None  # placeholder removed
예제 #9
0
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        for step in range(2, self.max_resl + 1 + 5):
            for iter in tqdm(
                    range(0, (self.trns_tick * 2 + self.stab_tick * 2) *
                          self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # reslolution scheduler.
                self.resl_scheduler()

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                self.x.data = self.feed_interpolated_input(
                    self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize,
                                    self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())

                loss_d = self.mse(self.fx.squeeze(), self.real_label) + \
                                self.mse(self.fx_tilde.squeeze(), self.fake_label)

                # GP
                r = torch.rand_like(self.x)
                self.x_hat = torch.autograd.Variable(
                    r * self.x + (1 - r) * self.x_tilde.detach(),
                    requires_grad=True)
                self.fx_hat = self.D(self.x_hat)
                gradients = torch.autograd.grad(outputs=self.fx_hat,
                                                inputs=self.x_hat,
                                                grad_outputs=torch.ones_like(
                                                    self.fx_hat),
                                                create_graph=True,
                                                retain_graph=True,
                                                only_inputs=True)[0]
                gradients = gradients.view(gradients.size(0), -1)
                gradient_penalty = (
                    (gradients.norm(2, dim=1) - 1)**2).mean() * self.gp_lambda

                # DP
                drift_penalty = (self.fx.norm(2, dim=1)**
                                 2).mean() * self.dp_epsilon

                if self.config.loss == 'WGAN':
                    pass
                elif self.config.loss == 'WGAN-GP':
                    loss_d += gradient_penalty
                elif self.config.loss == 'WGAN-DP':
                    loss_d += drift_penalty
                elif self.config.loss == 'PG-GAN':
                    loss_d += (gradient_penalty + drift_penalty)
                else:
                    raise NotImplementedError

                loss_d.backward()
                self.opt_d.step()

                # update generator.
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.globalTick, self.stack,
                    len(self.loader.dataset), loss_d.item(), loss_g.item(),
                    self.resl, int(pow(2, floor(self.resl))), self.phase,
                    self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('log/model')

                # save image grid.
                if self.globalIter % self.config.save_img_every == 0:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    utils.save_image_grid(
                        x_test.data, 'log/save/grid/{}_{}_G{}_D{}.jpg'.format(
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))
                    utils.save_image_single(
                        x_test.data,
                        'log/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every),
                            self.phase, self.complete['gen'],
                            self.complete['dis']))

                # tensorboard visualization.
                if self.use_tb:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g[0].item(),
                                       self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d[0].item(),
                                       self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl',
                                       int(pow(2, floor(self.resl))),
                                       self.globalIter)
                    '''IMAGE GRID
예제 #10
0
    def train(self):
        # noise for test
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)

        for step in range(0, self.max_resl + 1 + 5):
            for iter in tqdm(range(0, (self.trns_tick * 2 + self.stab_tick * 2) * self.TICK, self.loader.batchsize)):
                self.global_iter = self.global_iter + 1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack % (ceil(len(self.loader.dataset))))

                # Resolution scheduler
                self.resl_scheduler()

                # Zero the gradients
                self.G.zero_grad()
                self.D.zero_grad()

                # Update discriminator
                self.x.data = self.feed_interpolated_input(self.loader.get_batch())
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                self.fx = self.D(self.x)
                self.fx_tilde = self.D(self.x_tilde.detach())
                real_loss = self.criterion(torch.squeeze(self.fx), self.real_label)
                fake_loss = self.criterion(torch.squeeze(self.fx_tilde), self.fake_label)
                loss_d = real_loss + fake_loss

                # Compute gradients and apply update to parameters
                loss_d.backward()
                self.opt_d.step()

                # Update generator
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.criterion(torch.squeeze(fx_tilde), self.real_label.detach())
                
                # Compute gradients and apply update to parameters
                loss_g.backward()
                self.opt_g.step()

                # Log information
                log_msg = ' [epoch:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch,
                    self.global_tick,
                    self.stack,
                    len(self.loader.dataset),
                    loss_d.data[0],
                    loss_g.data[0],
                    self.resl,
                    int(pow(2, floor(self.resl))),
                    self.phase,
                    self.complete['gen'],
                    self.complete['dis'],
                    self.lr)
                tqdm.write(log_msg)

                # Save the model
                self.snapshot('./repo/model')

                # Save the image grid
                if self.global_iter % self.config.save_img_every == 0:
                    x_test = self.G(self.z_test)
                    os.system('mkdir -p repo/save/grid')
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.global_iter / self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))
                    os.system('mkdir -p repo/save/resl_{}'.format(int(floor(self.resl))))
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)), int(self.global_iter / self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))

                # Tensorboard visualization
                if self.use_tb:
                    x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0], self.global_iter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0], self.global_iter)
                    self.tb.add_scalar('tick/lr', self.lr, self.global_iter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.global_iter)
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1, 1], [0, 1]), self.global_iter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1, 1], [0, 1]), self.global_iter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1, 1], [0, 1]), self.global_iter)
예제 #11
0
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
        if self.use_captions:
            test_caps_set = False
            self.caps_test = torch.FloatTensor(self.loader.batchsize, self.ncap)
            if self.use_cuda:
                self.caps_test = self.caps_test.cuda()
            self.caps_test = Variable(self.caps_test, volatile=True)
        
        
        for step in range(2, self.max_resl+1+5):
            for iter in tqdm(range(0,(self.trns_tick*2+self.stab_tick*2)*self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter+1
                self.stack = self.stack + self.loader.batchsize
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack%(ceil(len(self.loader.dataset))))

                # reslolution scheduler.
                self.resl_scheduler()
                
                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                if self.use_captions:
                    batch_imgs, batch_caps = self.loader.get_batch()
                    if self.use_cuda:
                        batch_caps = batch_caps.cuda()
                    self.caps.data = batch_caps
                    if not test_caps_set:
                        self.caps_test.data = batch_caps
                        test_caps_set = True
                else:
                    batch_imgs, _ = self.loader.get_batch()
                self.x.data = self.feed_interpolated_input(batch_imgs)
                if self.flag_add_noise:
                    self.x = self.add_noise(self.x)
                self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
                if not self.use_captions:
                    self.x_tilde = self.G(self.z)
                else:
                    self.x_tilde = self.G(self.z, self.caps)
                if not self.use_captions:
                    self.fx = self.D(self.x)
                    self.fx_tilde = self.D(self.x_tilde.detach())
                else:
                    self.fx = self.D(self.x, self.caps)
                    self.fx_tilde = self.D(self.x_tilde.detach(), self.caps)

                if self.gan_type == 'lsgan':
                    loss_d = self.mse(self.fx, self.real_label) + self.mse(self.fx_tilde, self.fake_label)
                elif self.gan_type == 'wgan-gp':
                    D_real_loss = -torch.mean(self.fx_tilde)
                    D_fake_loss = torch.mean(self.x_tilde)

                    if self.use_cuda:
                        alpha = torch.rand(self.x.size().cuda())
                    else:
                        alpha = torch.rand(self.x.size())

                    x_hat = Variable(alpha * self.x.data + (1- alpha) * self.G.data, requires_grad=True)

                    pred_hat = self.D(x_hat)

                    if self.use_cuda:
                        gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(),
                                     create_graph=True, retain_graph=True, only_inputs=True)[0]
                    else:
                        gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()),
                                         create_graph=True, retain_graph=True, only_inputs=True)[0]

                    gradient_penalty = self.lambda * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()

                    loss_d = D_real_loss + D_fake_loss + gradient_penalty

                loss_d.backward()
                self.opt_d.step()

                # update generator.
                if not self.use_captions:
                    fx_tilde = self.D(self.x_tilde)
                else:
                    fx_tilde = self.D(self.x_tilde, self.caps)

                if self.gan_type == 'lsgan':
                    loss_g = self.mse(fx_tilde, self.real_label.detach())
                elif self.gan_type == 'wgan-gp':
                    loss_g = -torch.mean(fx_tilde)

                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.data[0], loss_g.data[0], self.resl, int(pow(2,floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # save model.
                self.snapshot('repo/model')

                # save image grid.
                if self.globalIter%self.config.save_img_every == 0:
                    if not self.use_captions:
                        x_test = self.G(self.z_test)
                    else:
                        x_test = self.G(self.z_test, self.caps_test)
                    os.system('mkdir -p repo/save/grid')
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))
                    os.system('mkdir -p repo/save/resl_{}'.format(int(floor(self.resl))))
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)),int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']))


                # tensorboard visualization.
                if self.use_tb:
                    if not self.use_captions:
                        x_test = self.G(self.z_test)
                    else:
                        x_test = self.G(self.z_test, self.caps_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0], self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0], self.globalIter)
                    self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.globalIter)
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)
예제 #12
0
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)

        # summary(self.G.module.model, input_size=(512, ))
        # summary(self.D.module.model, input_size=(3, 4, 4))
<<<<<<< HEAD
=======
        # exit()
>>>>>>> 70f249f1f09f20dcdc0d807fa8124fd44f1b6256

        net.soft_copy_param(self.Gs, self.G, 1.)
        x_test = self.G(self.z_test)
        Gs_test = self.Gs(self.z_test)
        os.system('mkdir -p repo/save/grid')
        utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.png'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']), imsize=2**self.max_resl*4)
        utils.save_image_grid(Gs_test.data, 'repo/save/grid/{}_{}_G{}_D{}_Gs.png'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis']), imsize=2**self.max_resl*4)
        
        for step in range(int(floor(self.resl)), self.max_resl+1+5):
            if self.phase == 'init':
                total_tick = self.stab_tick
                start_tick = self.globalTick
            else:
                total_tick = self.trns_tick + self.stab_tick
                start_tick = self.globalTick - (step - 2.5) * total_tick
                if step > self.max_resl:
                    start_tick = 0
            print('Start from tick', start_tick, 'till', total_tick)
            for iter in tqdm(range(int(start_tick) * self.TICK, (total_tick)*self.TICK, self.loader.batchsize*self.minibatch_repeat)):
                self.globalIter = self.globalIter+self.minibatch_repeat
                self.stack = self.stack + self.loader.batchsize*self.minibatch_repeat
예제 #13
0
    def train(self):
        # noise for test.
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        # for step in range(2, self.max_resolution+1+5):
        # for iter in range(0,(self.transition_tick*2+self.stablize_tick*2)*self.TICK, self.loader.batchsize):
        final_step = 0
        while True:
            self.globalIter = self.globalIter + 1
            self.stack = self.stack + self.loader.batchsize
            if self.stack > ceil(len(self.loader.dataset)):
                self.epoch = self.epoch + 1
                self.stack = int(self.stack % (ceil(len(self.loader.dataset))))

            # resolutionolution scheduler.
            sched_results = self.resolution_scheduler()

            # zero gradients.
            self.G.zero_grad()
            self.D.zero_grad()

            # update discriminator.
            self.x.data = self.feed_interpolated_input(self.loader.get_batch())
            if self.flag_add_noise:
                self.x = self.add_noise(self.x)
            self.z.data.resize_(self.loader.batchsize,
                                self.nz).normal_(0.0, 1.0)
            self.x_tilde = self.G(self.z)

            self.fx = self.D(self.x)
            self.fx_tilde = self.D(self.x_tilde.detach())

            loss_d = self.mse(self.fx.squeeze(), self.real_label) + \
                                self.mse(self.fx_tilde, self.fake_label)
            loss_d.backward()
            self.opt_d.step()

            # update generator.
            fx_tilde = self.D(self.x_tilde)
            loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach())
            loss_g.backward()
            self.opt_g.step()
            if self.globalIter % self.config.freq_print == 0:
                # logging.
                log_msg = sched_results[
                    'ticked'] + ' [E:{0}][T:{1}]  errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resolution:{7:4}][{8}]'.format(
                        self.epoch, self.globalTick, self.stack,
                        len(self.loader.dataset), loss_d.item(),
                        loss_g.item(), self.resolution,
                        int(pow(2, floor(self.resolution))), self.phase,
                        self.complete['gen'], self.complete['dis'], self.lr)
                if hasattr(self, 'fadein') and self.fadein['dis'] is not None:
                    log_msg += '|D-Alpha:{:0.2f}'.format(
                        self.fadein['dis'].alpha)

                if hasattr(self, 'fadein') and self.fadein['gen'] is not None:
                    log_msg += '|G-Alpha:{:0.2f}'.format(
                        self.fadein['gen'].alpha)

                print(log_msg)
            if self.phase == 'final':
                final_step += 1
                if final_step > self.config.final_steps:
                    self.snapshot('repo/model')
                    break
            # tqdm.write(log_msg)

            # save model.
            self.snapshot('repo/model')

            # save image grid.
            if self.globalIter % self.config.save_img_every == 0:
                with torch.no_grad():
                    x_test = self.G(self.z_test)
                utils.mkdir('repo/save/grid')
                utils.save_image_grid(
                    x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(
                        int(self.globalIter / self.config.save_img_every),
                        self.phase, self.complete['gen'],
                        self.complete['dis']))
                utils.mkdir('repo/save/resolution_{}'.format(
                    int(floor(self.resolution))))
                utils.save_image_single(
                    x_test.data,
                    'repo/save/resolution_{}/{}_{}_G{}_D{}.jpg'.format(
                        int(floor(self.resolution)),
                        int(self.globalIter / self.config.save_img_every),
                        self.phase, self.complete['gen'],
                        self.complete['dis']))
                # import ipdb; ipdb.set_trace()
            # tensorboard visualization.
            if self.use_tb:
                with torch.no_grad():
                    x_test = self.G(self.z_test)
                self.tb.add_scalar('data/loss_g', loss_g.item(),
                                   self.globalIter)
                self.tb.add_scalar('data/loss_d', loss_d.item(),
                                   self.globalIter)
                self.tb.add_scalar('tick/lr', self.lr, self.globalIter)
                self.tb.add_scalar('tick/cur_resolution',
                                   int(pow(2, floor(self.resolution))),
                                   self.globalIter)
                '''IMAGE GRID
예제 #14
0
    def train(self):
        """Progressive-growing GAN training loop (LSGAN-style MSE losses).

        Alternates discriminator and generator updates while the
        resolution scheduler grows/blends the working resolution, and
        periodically snapshots the model and writes sample image grids.

        NOTE(review): written against a legacy PyTorch API
        (``Variable``, ``volatile=True``, ``loss.data[0]``); on
        PyTorch >= 0.4 these map to ``torch.no_grad()`` and
        ``loss.item()`` — confirm target version before running.
        """

        # optimizer
        betas = (self.config.beta1, self.config.beta2)
        if self.optimizer == 'adam':
            # Only parameters that currently require gradients are
            # optimized — layers frozen during growth are excluded.
            self.opt_g = Adam(filter(lambda p: p.requires_grad,
                                     self.G.parameters()),
                              lr=self.config.lr,
                              betas=betas,
                              weight_decay=0.0)
            self.opt_d = Adam(filter(lambda p: p.requires_grad,
                                     self.D.parameters()),
                              lr=self.config.lr,
                              betas=betas,
                              weight_decay=0.0)

        # Fixed latent noise for test sampling, drawn once so the saved
        # grids are comparable across training. volatile=True disables
        # autograd tracking (legacy equivalent of torch.no_grad()).
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz)
        if self.use_cuda:
            self.z_test = self.z_test.cuda()
        self.z_test = Variable(self.z_test, volatile=True)
        self.z_test.data.resize_(self.loader.batchsize,
                                 self.nz).normal_(0.0, 1.0)

        # One outer step per resolution level; the inner loop is measured
        # in images seen (transition + stabilization ticks).
        for step in range(2, self.max_resl):
            for iter in tqdm(  # NOTE(review): `iter` shadows the builtin
                    range(0, (self.trns_tick * 2 + self.stab_tick * 2) *
                          self.TICK, self.loader.batchsize)):
                self.globalIter = self.globalIter + 1
                self.stack = self.stack + self.loader.batchsize
                # `stack` counts images within the current epoch; wrap it
                # and bump the epoch counter once the dataset is exhausted.
                if self.stack > ceil(len(self.loader.dataset)):
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack %
                                     (ceil(len(self.loader.dataset))))

                # resolution scheduler: advances fade-in/stabilize phases.
                self.resl_scheduler()

                # zero gradients.
                self.G.zero_grad()
                self.D.zero_grad()

                # update discriminator.
                # Real batch is blended between resolutions during the
                # fade-in phase by feed_interpolated_input.
                self.x.data = self.feed_interpolated_input(
                    self.loader.get_batch())
                self.z.data.resize_(self.loader.batchsize,
                                    self.nz).normal_(0.0, 1.0)
                self.x_tilde = self.G(self.z)

                fx = self.D(self.x)
                # detach() keeps the D update from backpropagating into G.
                fx_tilde = self.D(self.x_tilde.detach())
                # LSGAN objective: MSE against real/fake label targets.
                loss_d = self.mse(fx, self.real_label) + self.mse(
                    fx_tilde, self.fake_label)
                loss_d.backward()
                self.opt_d.step()

                # update generator: push D's score on fakes toward "real".
                fx_tilde = self.D(self.x_tilde)
                loss_g = self.mse(fx_tilde, self.real_label.detach())
                loss_g.backward()
                self.opt_g.step()

                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}]  errD: {4:.4f} | errG: {5:.4f} | [cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.globalTick, self.stack,
                    len(self.loader.dataset), loss_d.data[0], loss_g.data[0],
                    self.resl, int(pow(2, floor(self.resl))), self.phase,
                    self.complete['gen'], self.complete['dis'])
                tqdm.write(log_msg)

                # save model. Called every iteration — presumably
                # rate-limited inside snapshot(); TODO confirm.
                self.snapshot('repo/model')

                # save image grid (full grid plus a per-resolution single
                # image) every save_img_every iterations.
                if self.globalIter % self.config.save_img_every == 0:
                    x_test = self.G(self.z_test)
                    os.system('mkdir -p repo/save/grid')
                    utils.save_image_grid(
                        x_test.data, 'repo/save/grid/{}.jpg'.format(
                            int(self.globalIter / self.config.save_img_every)))
                    os.system('mkdir -p repo/save/resl_{}'.format(
                        int(floor(self.resl))))
                    utils.save_image_single(
                        x_test.data, 'repo/save/resl_{}/{}.jpg'.format(
                            int(floor(self.resl)),
                            int(self.globalIter / self.config.save_img_every)))

                # tensorboard visualization: scalar losses plus image
                # grids of test samples, current fakes, and the
                # (possibly resolution-blended) real inputs.
                if self.use_tb:
                    x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0],
                                       self.globalIter)
                    self.tb.add_scalar('tick/globalTick', int(self.globalTick),
                                       self.globalIter)
                    self.tb.add_image_grid('grid/x_test', 4,
                                           x_test.data.float(),
                                           self.globalIter)
                    self.tb.add_image_grid('grid/x_tilde', 4,
                                           self.x_tilde.data.float(),
                                           self.globalIter)
                    self.tb.add_image_grid('grid/x_intp', 1,
                                           self.x.data.float(),
                                           self.globalIter)