# Example #1
class Trainer(object):
    """Trains a conditional GAN (generator + discriminator).

    Supports three adversarial losses ('hinge', 'wgan_gp', 'dcgan'), linearly
    annealed instance noise injected into the discriminator inputs, periodic
    logging/plotting, image sampling from a fixed noise batch, and checkpointing.
    All hyper-parameters are read from ``config``; heavy lifting (dataloader,
    checkpoint I/O, plots) is delegated to the project-level ``utils`` module.
    """

    def __init__(self, config):
        """Store configuration, build dirs/dataloader/models, optionally load a checkpoint."""

        # Images data path & Output path
        self.dataset = config.dataset
        self.data_path = config.data_path
        self.save_path = os.path.join(config.save_path, config.name)

        # Training settings
        self.batch_size = config.batch_size
        self.total_step = config.total_step
        self.d_steps_per_iter = config.d_steps_per_iter
        self.g_steps_per_iter = config.g_steps_per_iter
        self.d_lr = config.d_lr
        self.g_lr = config.g_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.inst_noise_sigma = config.inst_noise_sigma
        self.inst_noise_sigma_iters = config.inst_noise_sigma_iters
        self.start = 0  # Unless using pre-trained model

        # Image transforms
        self.shuffle = config.shuffle
        self.drop_last = config.drop_last
        self.resize = config.resize
        self.imsize = config.imsize
        self.centercrop = config.centercrop
        self.centercrop_size = config.centercrop_size
        self.tanh_scale = config.tanh_scale
        self.normalize = config.normalize

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.save_n_images = config.save_n_images
        self.max_frames_per_gif = config.max_frames_per_gif

        # Pretrained model
        self.pretrained_model = config.pretrained_model

        # Misc
        self.manual_seed = config.manual_seed
        self.disable_cuda = config.disable_cuda
        self.parallel = config.parallel
        self.dataloader_args = config.dataloader_args

        # Output paths
        self.model_weights_path = os.path.join(self.save_path,
                                               config.model_weights_dir)
        self.sample_path = os.path.join(self.save_path, config.sample_dir)

        # Model hyper-parameters
        self.adv_loss = config.adv_loss
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.lambda_gp = config.lambda_gp

        # Model name
        self.name = config.name

        # Create directories if not exist
        utils.make_folder(self.save_path)
        utils.make_folder(self.model_weights_path)
        utils.make_folder(self.sample_path)

        # Copy files
        utils.write_config_to_file(config, self.save_path)
        utils.copy_scripts(self.save_path)

        # Check for CUDA (sets self.device)
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.batch_size, self.dataset, self.data_path, self.shuffle,
            self.drop_last, self.dataloader_args, self.resize, self.imsize,
            self.centercrop, self.centercrop_size)

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        # Start with pretrained model (if it exists)
        if self.pretrained_model != '':
            utils.load_pretrained_model(self)

        if self.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()

    def train(self):
        """Run the alternating D/G training loop for ``total_step`` iterations."""

        # Seed everything for reproducibility
        np.random.seed(self.manual_seed)
        random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)

        # For fast training
        cudnn.benchmark = True

        # For BatchNorm
        self.G.train()
        self.D.train()

        # Fixed noise for sampling from G
        fixed_noise = torch.randn(self.batch_size,
                                  self.z_dim,
                                  device=self.device)
        if self.num_of_classes < self.batch_size:
            # Tile the class range so every class appears in the sample batch
            fixed_labels = torch.from_numpy(
                np.tile(np.arange(self.num_of_classes),
                        self.batch_size // self.num_of_classes +
                        1)[:self.batch_size]).to(self.device)
        else:
            fixed_labels = torch.from_numpy(np.arange(self.batch_size)).to(
                self.device)

        # For gan loss.
        # BUGFIX: fill with 1.0 (float), not 1 — an integer fill value makes
        # torch.full infer a long tensor on modern PyTorch, which breaks
        # BCELoss (float targets required) and the hinge-loss arithmetic.
        label = torch.full((self.batch_size, ), 1.0, device=self.device)
        ones = torch.full((self.batch_size, ), 1.0, device=self.device)

        # Losses file
        log_file_name = os.path.join(self.save_path, 'log.txt')
        log_file = open(log_file_name, "wt")

        # Init
        start_time = time.time()
        G_losses = []
        D_losses_real = []
        D_losses_fake = []
        D_losses = []
        D_xs = []
        D_Gz_trainDs = []
        D_Gz_trainGs = []

        # Instance noise - make random noise mean (0) and std for injecting.
        # BUGFIX: float fill values (0.0 / float(sigma)) keep these tensors
        # floating point so torch.normal accepts them.
        inst_noise_mean = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize),
            0.0,
            device=self.device)
        inst_noise_std = torch.full(
            (self.batch_size, 3, self.imsize, self.imsize),
            float(self.inst_noise_sigma),
            device=self.device)

        # Start training
        for self.step in range(self.start, self.total_step):

            # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters
            inst_noise_sigma_curr = 0 if self.step > self.inst_noise_sigma_iters else (
                1 - self.step /
                self.inst_noise_sigma_iters) * self.inst_noise_sigma
            inst_noise_std.fill_(inst_noise_sigma_curr)

            # ================== TRAIN D ================== #

            for _ in range(self.d_steps_per_iter):

                # Zero grad
                self.reset_grad()

                # TRAIN with REAL

                # Get real images & real labels
                real_images, real_labels = self.get_real_samples()

                # Get D output for real images & real labels
                inst_noise = torch.normal(mean=inst_noise_mean,
                                          std=inst_noise_std).to(self.device)
                d_out_real = self.D(real_images + inst_noise, real_labels)

                # Compute D loss with real images & real labels
                if self.adv_loss == 'hinge':
                    d_loss_real = torch.nn.ReLU()(ones - d_out_real).mean()
                elif self.adv_loss == 'wgan_gp':
                    d_loss_real = -d_out_real.mean()
                else:
                    label.fill_(1)
                    d_loss_real = self.criterion(d_out_real, label)

                # Backward
                d_loss_real.backward()

                # TRAIN with FAKE

                # Create random noise
                z = torch.randn(self.batch_size,
                                self.z_dim,
                                device=self.device)

                # Generate fake images for same real labels
                fake_images = self.G(z, real_labels)

                # Get D output for fake images & same real labels
                # (detach so the G graph is not backpropagated through)
                inst_noise = torch.normal(mean=inst_noise_mean,
                                          std=inst_noise_std).to(self.device)
                d_out_fake = self.D(fake_images.detach() + inst_noise,
                                    real_labels)

                # Compute D loss with fake images & real labels
                if self.adv_loss == 'hinge':
                    d_loss_fake = torch.nn.ReLU()(ones + d_out_fake).mean()
                elif self.adv_loss == 'dcgan':
                    label.fill_(0)
                    d_loss_fake = self.criterion(d_out_fake, label)
                else:
                    d_loss_fake = d_out_fake.mean()

                # Backward
                d_loss_fake.backward()

                # If WGAN_GP, compute GP and add to D loss
                if self.adv_loss == 'wgan_gp':
                    d_loss_gp = self.lambda_gp * self.compute_gradient_penalty(
                        real_images, real_labels, fake_images.detach())
                    d_loss_gp.backward()

                # Optimize
                self.D_optimizer.step()

            # ================== TRAIN G ================== #

            for _ in range(self.g_steps_per_iter):

                # Zero grad
                self.reset_grad()

                # Get real images & real labels (only need real labels)
                real_images, real_labels = self.get_real_samples()

                # Create random noise
                z = torch.randn(self.batch_size, self.z_dim).to(self.device)

                # Generate fake images for same real labels
                fake_images = self.G(z, real_labels)

                # Get D output for fake images & same real labels
                inst_noise = torch.normal(mean=inst_noise_mean,
                                          std=inst_noise_std).to(self.device)
                g_out_fake = self.D(fake_images + inst_noise, real_labels)

                # Compute G loss with fake images & real labels
                if self.adv_loss == 'dcgan':
                    label.fill_(1)
                    g_loss = self.criterion(g_out_fake, label)
                else:
                    g_loss = -g_out_fake.mean()

                # Backward + Optimize
                g_loss.backward()
                self.G_optimizer.step()

            # Print out log info
            if self.step % self.log_step == 0:
                G_losses.append(g_loss.mean().item())
                D_losses_real.append(d_loss_real.mean().item())
                D_losses_fake.append(d_loss_fake.mean().item())
                D_loss = D_losses_real[-1] + D_losses_fake[-1]
                if self.adv_loss == 'wgan_gp':
                    D_loss += d_loss_gp.mean().item()
                D_losses.append(D_loss)
                D_xs.append(d_out_real.mean().item())
                D_Gz_trainDs.append(d_out_fake.mean().item())
                D_Gz_trainGs.append(g_out_fake.mean().item())
                curr_time = time.time()
                curr_time_str = datetime.datetime.fromtimestamp(
                    curr_time).strftime('%Y-%m-%d %H:%M:%S')
                elapsed = str(
                    datetime.timedelta(seconds=(curr_time - start_time)))
                log = (
                    "[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\n"
                    .format(curr_time_str, elapsed, self.step, self.total_step,
                            G_losses[-1], D_losses[-1], D_losses_real[-1],
                            D_losses_fake[-1], D_xs[-1], D_Gz_trainDs[-1],
                            D_Gz_trainGs[-1]))
                print(log)
                log_file.write(log)
                log_file.flush()
                utils.make_plots(G_losses, D_losses, D_losses_real,
                                 D_losses_fake, D_xs, D_Gz_trainDs,
                                 D_Gz_trainGs, self.log_step, self.save_path)

            # Sample images
            if self.step % self.sample_step == 0:
                self.G.eval()
                fake_images = self.G(fixed_noise, fixed_labels)
                self.G.train()
                sample_images = utils.denorm(
                    fake_images.detach()[:self.save_n_images])
                # Save batch images
                vutils.save_image(
                    sample_images,
                    os.path.join(self.sample_path,
                                 'fake_{:05d}.png'.format(self.step)))
                # Save gif
                utils.make_gif(
                    sample_images[0].cpu().numpy().transpose(1, 2, 0) * 255,
                    self.step,
                    self.sample_path,
                    self.name,
                    max_frames_per_gif=self.max_frames_per_gif)

            # Save model
            if self.step % self.model_save_step == 0:
                utils.save_ckpt(self)

    def build_models(self):
        """Instantiate G and D (optionally DataParallel) and their Adam optimizers."""
        self.G = Generator(self.z_dim, self.g_conv_dim,
                           self.num_of_classes).to(self.device)
        self.D = Discriminator(self.d_conv_dim,
                               self.num_of_classes).to(self.device)
        if 'cuda' in self.device.type and self.parallel and torch.cuda.device_count(
        ) > 1:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer: only optimize trainable parameters
        self.G_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        # print networks
        print(self.G)
        print(self.D)

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def get_real_samples(self):
        """Return the next (images, labels) batch, restarting the epoch when exhausted."""
        try:
            real_images, real_labels = next(self.data_iter)
        except StopIteration:
            # BUGFIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt and masked real dataloader errors.
            self.data_iter = iter(self.dataloader)
            real_images, real_labels = next(self.data_iter)

        real_images, real_labels = real_images.to(self.device), real_labels.to(
            self.device)
        return real_images, real_labels

    def compute_gradient_penalty(self, real_images, real_labels, fake_images):
        """WGAN-GP penalty: mean squared deviation of D's input-gradient norm from 1,
        evaluated at random interpolations between real and fake images."""
        # Per-sample interpolation coefficient, broadcast over C/H/W.
        # BUGFIX: was `.to(device)` — `device` is an undefined global here;
        # the Trainer's device lives on `self.device`.
        alpha = torch.rand(real_images.size(0), 1, 1,
                           1).expand_as(real_images).to(self.device)
        # BUGFIX: `torch.tensor(existing_tensor, requires_grad=True)` makes a
        # warned-about copy; build the interpolate and flag it for autograd
        # directly so gradients w.r.t. D's input can be taken below.
        interpolated = (alpha * real_images +
                        (1 - alpha) * fake_images).requires_grad_(True)
        out = self.D(interpolated, real_labels)
        exp_grad = torch.ones(out.size()).to(self.device)
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=exp_grad,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1)**2)
        return d_loss_gp
# Example #2
class Tester(object):
    """Loads a trained SAGAN-style generator and writes grids of generated samples.

    Configuration mirrors the training script; only ``test()`` / ``output_fig()``
    are exercised at inference time, but G and D plus their optimizers are built
    so that checkpoints load cleanly.
    """

    def __init__(self, data_loader, config):
        """Store configuration, build models, and optionally load a checkpoint."""

        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Version-qualified output paths (override the raw config paths above)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        self.test_store_path = os.path.join(config.test_store_path,
                                            self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def test(self):
        """Generate 500 batches from random noise and save each one as a figure."""
        # BUGFIX: the loop variable was named `iter`, shadowing the builtin.
        for idx in range(500):
            fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

            fake_images, _, _ = self.G(fixed_z)

            # NCHW -> NHWC for plotting
            fakeimage = np.transpose(var2numpy(fake_images.data), (0, 2, 3, 1))
            # BUGFIX: output_fig appends '.png' itself; passing a name that
            # already ended in '.png' produced '...image.png.png' files.
            self.output_fig(
                fakeimage,
                os.path.join(self.test_store_path,
                             '{:03d}_image'.format(idx + 1)))

    def output_fig(self, images_array, file_name):
        """Save ``images_array`` as a square image grid to ``file_name`` + '.png'."""
        plt.figure(figsize=(6, 6), dpi=100)
        plt.imshow(helper.images_square_grid(images_array))
        plt.axis("off")
        plt.savefig(file_name + '.png', bbox_inches='tight', pad_inches=0)

    def build_model(self):
        """Instantiate G and D on the GPU plus their (unused at test time) optimizers."""

        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer: only optimize trainable parameters
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Attach the project's tensorboard-style Logger to ``self.log_path``."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G and D state dicts saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))
# Example #3
class Trainer(object):
    """Trains a self-attention GAN (G and D returning attention maps).

    Supports the 'wgan-gp' and 'hinge' adversarial losses, periodic logging of
    the attention gamma values, image sampling from a fixed noise vector, and
    checkpointing of both networks.
    """

    def __init__(self, data_loader, config):
        """Store configuration, build models, and optionally resume from a checkpoint."""

        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.imchan = config.imchan
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Version-qualified output paths (override the raw config paths above)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Run the alternating D/G optimization loop for ``total_step`` iterations."""

        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model: pretrained_model holds the last saved step
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                # BUGFIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt and masked real dataloader errors.
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            else:
                # BUGFIX: any other value previously caused a NameError on
                # d_loss_real far from the real cause; fail fast instead.
                raise ValueError(
                    'unsupported adv_loss: {}'.format(self.adv_loss))

            # apply Gumbel Softmax
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on random real/fake interpolations
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize (separate optimizer step for the GP term)
                d_loss = self.lambda_gp * d_loss_gp

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                if self.G.attn2:
                    ave_gamma_l4 = "{:.4f}".format(
                        self.G.attn2.gamma.mean().item())
                else:
                    ave_gamma_l4 = "n/a"
                # BUGFIX: the value printed is d_loss_real, not d_out_real —
                # label it accordingly.
                print("Elapsed [{}], Step [{}/{}], d_loss_real: {:.4f}, "
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {}".format(
                          elapsed, step + 1, self.total_step,
                          d_loss_real.item(),
                          self.G.attn1.gamma.mean().item(), ave_gamma_l4))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G and D on the GPU (optionally DataParallel) and their optimizers."""
        self.G = Generator(batch_size=self.batch_size,
                           image_size=self.imsize,
                           z_dim=self.z_dim,
                           conv_dim=self.g_conv_dim,
                           image_channels=self.imchan).cuda()
        self.D = Discriminator(batch_size=self.batch_size,
                               image_size=self.imsize,
                               conv_dim=self.d_conv_dim,
                               image_channels=self.imchan).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer: only optimize trainable parameters
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Attach the project's tensorboard-style Logger to ``self.log_path``."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G and D state dicts saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one batch of real images (denormalized) to the sample path."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
class Trainer(object):
    """BiGAN/ALI-style trainer.

    Trains a Generator G, an Encoder E and a Discriminator D, where D
    scores (image, latent) pairs; G and E share one optimizer and are
    trained jointly against D.
    """

    def __init__(self, data_loader, config):
        """Copy hyper-parameters and paths from ``config``, build the
        networks and optionally restore pretrained weights.

        :param data_loader: iterable of (images, labels) batches
        :param config: namespace holding every setting used below
        """
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.ge_lr = config.ge_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.mura_class = config.mura_class
        self.mura_type = config.mura_type
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path — version-specific output dirs (these overwrite the raw
        # config values assigned just above).
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        if self.use_tensorboard:
            self.build_tensorboard()

        self.build_model()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Main optimization loop over ``self.total_step`` iterations.

        Each step: sample a real batch and a prior latent, score the
        (real, E(real)) and (G(z), z) pairs with D, update D, then update
        G and E jointly on the retained graph.
        """
        # Data iterator
        print('inside the train')
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_img, _ = next(data_iter)
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        if self.use_tensorboard:
            self.writer.add_image('img/fixed_img', denorm(fixed_img.data), 0)
        else:
            save_image(denorm(fixed_img.data),
                       os.path.join(self.sample_path, 'fixed_img.png'))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        self.D.train()
        self.E.train()
        self.G.train()

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            self.reset_grad()
            # Sample from data and prior.  Restart the loader when an
            # epoch is exhausted (was a bare `except:`, which silently
            # swallowed *every* exception, e.g. KeyboardInterrupt or
            # CUDA errors).
            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)

            real_images = tensor2var(real_images)
            fake_z = tensor2var(torch.randn(real_images.size(0), self.z_dim))

            # Instance noise added to D's inputs.
            # NOTE(review): (step + 1 - self.total_step) is negative until
            # the final step, so the std passed to normal_() is negative —
            # this looks like it was meant to anneal toward zero, e.g.
            # (self.total_step - step - 1); confirm against the original
            # experiment before changing.
            noise1 = torch.Tensor(real_images.size()).normal_(
                0, 0.01 * (step + 1 - self.total_step) / (step + 1))

            noise2 = torch.Tensor(real_images.size()).normal_(
                0, 0.01 * (step + 1 - self.total_step) / (step + 1))
            # Sample from condition
            real_z, _, _ = self.E(real_images)
            fake_images, gf1, gf2 = self.G(fake_z)

            dr, dr5, dr4, dr3, drz, dra2, dra1 = self.D(
                real_images + noise1, real_z)
            df, df5, df4, df3, dfz, dfa2, dfa1 = self.D(
                fake_images + noise2, fake_z)

            # Compute loss with real and fake images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(dr)
                d_loss_fake = df.mean()
                g_loss_fake = -df.mean()
                e_loss_real = -dr.mean()
            elif self.adv_loss == 'hinge1':
                # Actual hinge loss.
                d_loss_real = torch.nn.ReLU()(1.0 - dr).mean()
                d_loss_fake = torch.nn.ReLU()(1.0 + df).mean()
                g_loss_fake = -df.mean()
                e_loss_real = -dr.mean()
            elif self.adv_loss == 'hinge':
                # NOTE(review): despite the name, this branch is the
                # log-loss (standard GAN objective), not a hinge loss.
                d_loss_real = -log(dr).mean()
                d_loss_fake = -log(1.0 - df).mean()
                g_loss_fake = -log(df).mean()
                e_loss_real = -log(1.0 - dr).mean()
            elif self.adv_loss == 'inverse':
                d_loss_real = -log(1.0 - dr).mean()
                d_loss_fake = -log(df).mean()
                g_loss_fake = -log(1.0 - df).mean()
                e_loss_real = -log(dr).mean()

            # ================== Train D ================== #
            d_loss = d_loss_real + d_loss_fake
            # retain_graph=True: the same forward graph is reused for the
            # joint G/E update below.
            d_loss.backward(retain_graph=True)
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                # NOTE(review): D is called with (image, latent) -> 7
                # outputs above, but with a single argument -> 3 outputs
                # here; confirm the discriminator's signature supports both.
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp

                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and E ================== #
            ge_loss = g_loss_fake + e_loss_real
            ge_loss.backward()
            self.ge_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print(
                    f"Elapsed: [{elapsed}], step: [{step+1}/{self.total_step}], d_loss: {d_loss}, ge_loss: {ge_loss}"
                )

                if self.use_tensorboard:
                    self.writer.add_scalar('d/loss_real', d_loss_real.data,
                                           step + 1)
                    self.writer.add_scalar('d/loss_fake', d_loss_fake.data,
                                           step + 1)
                    self.writer.add_scalar('d/loss', d_loss.data, step + 1)
                    self.writer.add_scalar('ge/loss_real', e_loss_real.data,
                                           step + 1)
                    self.writer.add_scalar('ge/loss_fake', g_loss_fake.data,
                                           step + 1)
                    self.writer.add_scalar('ge/loss', ge_loss.data, step + 1)
                    # NOTE(review): under nn.DataParallel these attention
                    # attributes live on self.G.module, not self.G — this
                    # would raise AttributeError when self.parallel is set;
                    # confirm.
                    self.writer.add_scalar('ave_gamma/l3',
                                           self.G.attn1.gamma.mean().data,
                                           step + 1)
                    self.writer.add_scalar('ave_gamma/l4',
                                           self.G.attn2.gamma.mean().data,
                                           step + 1)

            # Sample images: generation from the fixed prior, and a
            # reconstruction round-trip G(E(fixed_img)).
            if (step + 1) % self.sample_step == 0:
                img_from_z, _, _ = self.G(fixed_z)
                z_from_img, _, _ = self.E(tensor2var(fixed_img))
                reimg_from_z, _, _ = self.G(z_from_img)

                if self.use_tensorboard:
                    self.writer.add_image('img/reimg_from_z',
                                          denorm(reimg_from_z.data), step + 1)
                    self.writer.add_image('img/img_from_z',
                                          denorm(img_from_z.data), step + 1)
                else:
                    save_image(
                        denorm(img_from_z.data),
                        os.path.join(self.sample_path,
                                     '{}_img_from_z.png'.format(step + 1)))
                    save_image(
                        denorm(reimg_from_z.data),
                        os.path.join(self.sample_path,
                                     '{}_reimg_from_z.png'.format(step + 1)))

            # Periodically checkpoint all three networks.
            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.E.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_E.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G, E and D and their two optimizers (G+E joint, D)."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim)
        self.E = Encoder(self.batch_size, self.imsize, self.z_dim,
                         self.d_conv_dim)
        self.D = Discriminator(self.batch_size, self.imsize, self.z_dim,
                               self.d_conv_dim)
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.E = nn.DataParallel(self.E)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer — G and E share a single Adam optimizer.
        self.ge_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   itertools.chain(self.G.parameters(), self.E.parameters())),
            self.ge_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        # print(self.G)
        # print(self.E)
        # print(self.D)

    def build_tensorboard(self):
        '''Initialize the tensorboard writer.'''
        self.writer = SummaryWriter(self.log_path)

    def load_pretrained_model(self):
        """Restore G, E and D from ``<step>_{G,E,D}.pth`` checkpoints."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.E.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_E.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.d_optimizer.zero_grad()
        self.ge_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Write one denormalized real batch from ``data_iter`` to ``real.png``."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
# Beispiel #5
# 0
class Trainer(object):
    def __init__(self, data_loader, config):
        """Set up training state from ``config`` and build the networks.

        :param data_loader: iterable of (images, labels) batches
        :param config: namespace holding every hyper-parameter and path
        """
        # Seed both CPU and GPU RNGs for reproducibility.
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)

        self.data_loader = data_loader
        self.model = config.model
        self.adv_loss = config.adv_loss  # 'wgan-gp' or 'hinge'

        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.extra = config.extra  # enable extragradient (t+1/2 lookahead) updates

        self.lambda_gp = config.lambda_gp  # gradient-penalty weight (wgan-gp)
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_scheduler = config.lr_scheduler  # > 0 enables exponential LR decay
        self.g_beta1 = config.g_beta1
        self.d_beta1 = config.d_beta1
        self.beta2 = config.beta2

        # Logging / checkpoint locations and frequencies.
        self.dataset = config.dataset
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.backup_freq = config.backup_freq
        self.bup_path = config.bup_path

        # Path
        self.optim = config.optim  # 'sgd' | 'adam' | 'svrgadam'
        self.svrg = config.svrg  # enable SVRG variance reduction
        self.avg_start = config.avg_start  # step to start parameter averaging (<0 disables)
        self.build_model()

        if self.svrg:
            # Full-dataset gradient estimates (control variates) and the
            # snapshot networks they were computed at.
            self.mu_g = []
            self.mu_d = []
            self.g_snapshot = copy.deepcopy(self.G)
            self.d_snapshot = copy.deepcopy(self.D)
            # On average one SVRG refresh per epoch.
            self.svrg_freq_sampler = bernoulli.Bernoulli(torch.tensor([1 / len(self.data_loader)]))

        self.info_logger = setup_logger(self.log_path)
        self.info_logger.info(config)
        self.cont = config.cont  # resume from backup if True

    def train(self):
        """Run the main training loop for ``self.total_step`` iterations.

        Per step: optionally refresh SVRG statistics (and averaged-net
        restarts), then run one (D, G) update pair, then handle logging,
        sampling and checkpointing.
        """
        self.data_gen = self._data_gen()

        # Fixed latent used to render comparable samples over time.
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Resume from a backup when continuing a previous run.
        # NOTE(review): self.load_backup / self.backup / update_avg_nets are
        # defined outside this excerpt — confirm they exist on this class.
        if self.cont:
            start = self.load_backup()
        else:
            start = 0

        start_time = time.time()
        if self.svrg:
            self.update_svrg_stats()
        for step in range(start, self.total_step):

            # =================== SVRG =================== #
            # Fires with prob. 1/len(loader): refresh snapshots and the
            # full-batch gradient estimates (mu).
            if self.svrg and self.svrg_freq_sampler.sample() == 1:

                # ================= Update Avg ================= #
                if self.avg_start >= 0 and step > 0 and step >= self.avg_start:
                    self.update_avg_nets()
                    if self.avg_freq_restart_sampler.sample() == 1:
                        # Restart training from the averaged weights.
                        self.G.load_state_dict(self.avg_g.state_dict())
                        self.D.load_state_dict(self.avg_d.state_dict())
                        self.avg_step = 1
                        self.info_logger.info('Params updated with avg-nets at %d-th step.' % step)

                self.update_svrg_stats()
                self.info_logger.info("SVRG stats updated at %d-th step." % step)

            # ================= Train pair ================= #
            d_loss_real = self._update_pair(step)

            # --- storing stuff ---
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], Step [{}/{}]".format(elapsed, step + 1, self.total_step))

            # Render samples from the raw, averaged and EMA generators.
            if (step + 1) % self.sample_step == 0:
                save_image(denorm(self.G(fixed_z).data),
                           os.path.join(self.sample_path, 'gen', 'iter%08d.png' % step))
                save_image(denorm(self.G_avg(fixed_z).data),
                           os.path.join(self.sample_path, 'gen_avg', 'iter%08d.png' % step))
                save_image(denorm(self.G_ema(fixed_z).data),
                           os.path.join(self.sample_path, 'gen_ema', 'iter%08d.png' % step))

            # Checkpoint all three generator variants.
            if self.model_save_step > 0 and (step+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, 'gen', 'iter%08d.pth' % step))
                torch.save(self.G_avg.state_dict(),
                           os.path.join(self.model_save_path, 'gen_avg', 'iter%08d.pth' % step))
                torch.save(self.G_ema.state_dict(),
                           os.path.join(self.model_save_path, 'gen_ema', 'iter%08d.pth' % step))
            if self.backup_freq > 0 and (step+1) % self.backup_freq == 0:
                self.backup(step)

    def _data_gen(self):
        """ Data iterator

        :return: s
        """
        data_iter = iter(self.data_loader)
        while True:
            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)
            yield real_images

    def _update_pair(self, step):
        """One optimization round for the (D, G) pair.

        With ``self.extra`` set this is an extra-gradient update: first a
        half-step on the copies (``D_extra``/``G_extra``), then the real
        step evaluated against the extrapolated opponents.

        :param step: current global iteration
        :return: discriminator loss on real images (float)
        """
        # Step the LR schedulers once per epoch, when enabled.
        _lr_scheduler = self.lr_scheduler > 0 and step > 0 and step % len(self.data_loader) == 0
        self.D.train()
        self.G.train()
        real_images = tensor2var(next(self.data_gen))

        # Copy the current t-step weights into the extrapolation nets.
        self._extra_sync_nets()
        if self.extra:
            # ================== Train D @ t + 1/2 ================== #
            self._backprop_disc(D=self.D_extra, G=self.G, real_images=real_images,
                                d_optim=self.d_optimizer_extra, svrg=self.svrg,
                                scheduler_d=self.scheduler_d_extra if _lr_scheduler else None)

            # ================== Train G @ t + 1/2 ================== #
            self._backprop_gen(G=self.G_extra, D=self.D, bsize=real_images.size(0),
                               g_optim=self.g_optimizer_extra, svrg=self.svrg,
                               scheduler_g=self.scheduler_g_extra if _lr_scheduler else None)

            real_images = tensor2var(next(self.data_gen))  # Re-sample

        # ================== Train D @ t + 1 ================== #
        # D is updated against the extrapolated generator G_extra.
        d_loss_real = self._backprop_disc(G=self.G_extra, D=self.D, real_images=real_images,
                                          d_optim=self.d_optimizer, svrg=self.svrg,
                                          scheduler_d=self.scheduler_d if _lr_scheduler else None)

        # ================== Train G and gumbel @ t + 1 ================== #
        self._backprop_gen(G=self.G, D=self.D_extra, bsize=real_images.size(0),
                           g_optim=self.g_optimizer, svrg=self.svrg,
                           scheduler_g=self.scheduler_g if _lr_scheduler else None)

        # === Moving avg Generator-nets ===
        # NOTE(review): _update_avg_gen/_update_ema_gen are defined outside
        # this excerpt — confirm their signatures.
        self._update_avg_gen(step)
        self._update_ema_gen()
        return d_loss_real

    def _normalize_acc_grads(self, net):
        """Divides accumulated gradients with len(self.data_loader)"""
        for _param in filter(lambda p: p.requires_grad, net.parameters()):
            _param.grad.data.div_(len(self.data_loader))

    def update_svrg_stats(self):
        """Recompute the SVRG control variates and refresh the snapshots.

        After this call ``self.mu_d`` / ``self.mu_g`` hold, for each
        trainable parameter, the gradient averaged over the full data set
        at the current weights, and the snapshot nets equal G/D.
        """
        self.mu_g, self.mu_d = [], []

        # Update mu_d  ####################
        # Accumulate D's gradients over one full pass (d_optim=None means
        # backprop only, no optimizer step), then divide by #batches.
        self.d_optimizer.zero_grad()
        for _, _data in enumerate(self.data_loader):
            real_images = tensor2var(_data[0])
            self._backprop_disc(self.G, self.D, real_images, d_optim=None, svrg=False)
        self._normalize_acc_grads(self.D)
        for _param in filter(lambda p: p.requires_grad, self.D.parameters()):
            self.mu_d.append(_param.grad.data.clone())

        # Update mu_g  ####################
        # Same for G: one fresh latent batch per loader batch.
        self.g_optimizer.zero_grad()
        for _ in range(len(self.data_loader)):
            self._backprop_gen(self.G, self.D, self.batch_size, g_optim=None, svrg=False)
        self._normalize_acc_grads(self.G)
        for _param in filter(lambda p: p.requires_grad, self.G.parameters()):
            self.mu_g.append(_param.grad.data.clone())

        # Update snapshots  ###############
        self.g_snapshot.load_state_dict(self.G.state_dict())
        self.d_snapshot.load_state_dict(self.D.state_dict())

    @staticmethod
    def _update_grads_svrg(params, snapshot_params, mu):
        """Helper function which updates the accumulated gradients of
        params by subtracting those of snapshot and adding mu.

        Operates in-place.
        See line 12 & 14 of Algo. 3 in SVRG-GAN.

        Raises:
            ValueError if the inputs have different lengths (the length
            corresponds to the number of layers in the network)

        :param params: [list of torch.nn.parameter.Parameter]
        :param snapshot_params: [torch.nn.parameter.Parameter]
        :param mu: [list of torch(.cuda).FloatTensor]
        :return: [None]
        """
        if not len(params) == len(snapshot_params) == len(mu):
            raise ValueError("Expected input of identical length. "
                             "Got {}, {}, {}".format(len(params),
                                                     len(snapshot_params),
                                                     len(mu)))
        for i in range(len(mu)):
            params[i].grad.data.sub_(snapshot_params[i].grad.data)
            params[i].grad.data.add_(mu[i])

    def _backprop_disc(self, G, D, real_images, d_optim=None, svrg=False, scheduler_d=None):
        """Updates D (Vs. G).

        Computes the adversarial D-loss on ``real_images`` plus a freshly
        generated fake batch and backpropagates it.  When ``d_optim`` is
        given, also applies the (optionally SVRG-corrected) optimizer step.

        :param G: generator producing the fake batch
        :param D: discriminator being updated
        :param real_images: batch of real images
        :param d_optim: if None, only backprop
        :param svrg: [bool] apply the SVRG gradient correction
        :param scheduler_d: [optional] LR scheduler, stepped after the update
        :return: D's loss on the real batch as a Python float
        """
        d_out_real = D(real_images)
        if self.adv_loss == 'wgan-gp':
            d_loss_real = - torch.mean(d_out_real)
        elif self.adv_loss == 'hinge':
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
        else:
            raise NotImplementedError

        # Score a fake batch generated from a fresh latent sample.
        z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
        fake_images = G(z)
        d_out_fake = D(fake_images)

        if self.adv_loss == 'wgan-gp':
            d_loss_fake = d_out_fake.mean()
        elif self.adv_loss == 'hinge':
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
        else:
            raise NotImplementedError

        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake
        if d_optim is not None:
            d_optim.zero_grad()
        d_loss.backward()
        if d_optim is not None:
            if svrg:  # d_snapshot Vs. g_snapshot
                # Recompute the same loss at the snapshot weights; its
                # gradient is subtracted and mu added (variance reduction).
                d_out_real = self.d_snapshot(real_images)
                d_out_fake = self.d_snapshot(self.g_snapshot(z))
                if self.adv_loss == 'wgan-gp':
                    d_s_loss_real = - torch.mean(d_out_real)
                    d_loss_fake = d_out_fake.mean()
                elif self.adv_loss == 'hinge':
                    d_s_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
                    d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
                else:
                    raise NotImplementedError

                d_loss = d_s_loss_real + d_loss_fake
                self.d_snapshot.zero_grad()
                d_loss.backward()

                self._update_grads_svrg(list(filter(lambda p: p.requires_grad, D.parameters())),
                                        list(filter(lambda p: p.requires_grad, self.d_snapshot.parameters())),
                                        self.mu_d)
            d_optim.step()
            if scheduler_d is not None:
                scheduler_d.step()

        if self.adv_loss == 'wgan-gp':  # Todo: add SVRG for wgan-gp
            raise NotImplementedError('SVRG-WGAN-gp is not implemented yet')
            # NOTE(review): everything below is unreachable because of the
            # raise above.  `d_optim.reset_grad()` would also fail at
            # runtime (optimizers expose zero_grad()).  Left as-is per the
            # author's Todo.
            # Compute gradient penalty
            alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
            interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data, requires_grad=True)
            out = D(interpolated)

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp

            if d_optim is not None:
                d_optim.reset_grad()
            d_loss.backward()
            if d_optim is not None:
                self.d_optimizer.step()
        return d_loss_real.data.item()

    def _backprop_gen(self, G, D, bsize, g_optim=None, svrg=False, scheduler_g=None):
        """Updates G (Vs. D).

        Samples a fresh latent batch, scores the generated images with D,
        and backpropagates the generator loss.  When ``g_optim`` is given,
        also applies the (optionally SVRG-corrected) optimizer step.

        :param G: generator being updated
        :param D: discriminator scoring the fake images
        :param bsize: number of latent samples to draw
        :param g_optim: if None only backprop.  (Was ``g_optim=True`` by
            default, which contradicted this contract and would crash on
            ``True.zero_grad()``; ``None`` restores the documented
            backprop-only default.)
        :param svrg: [bool] apply the SVRG gradient correction
        :param scheduler_g: [optional] LR scheduler, stepped after the update
        :return: generator loss as a Python float
        """
        z = tensor2var(torch.randn(bsize, self.z_dim))
        fake_images = G(z)

        g_out_fake = D(fake_images)  # batch x n
        if self.adv_loss == 'wgan-gp' or self.adv_loss == 'hinge':
            g_loss_fake = - g_out_fake.mean()
        else:
            # Fail loudly (consistent with _backprop_disc) instead of
            # hitting an UnboundLocalError on g_loss_fake below.
            raise NotImplementedError

        if g_optim is not None:
            g_optim.zero_grad()
        g_loss_fake.backward()
        if g_optim is not None:
            if svrg:  # G_snapshot Vs. D_snapshot
                # Same loss at the snapshot weights; its gradient is
                # subtracted and mu_g added (variance reduction).
                self.g_snapshot.zero_grad()
                if self.adv_loss == 'wgan-gp' or self.adv_loss == 'hinge':
                    (- self.d_snapshot(self.g_snapshot(z)).mean()).backward()
                else:
                    raise NotImplementedError
                self._update_grads_svrg(list(filter(lambda p: p.requires_grad, G.parameters())),
                                        list(filter(lambda p: p.requires_grad, self.g_snapshot.parameters())),
                                        self.mu_g)
            g_optim.step()
            if scheduler_g is not None:
                scheduler_g.step()
        return g_loss_fake.data.item()

    def build_model(self):
        """Instantiate G/D (plus extrapolation, averaged and EMA copies),
        their optimizers and optional exponential LR schedulers.
        """
        # Models                    ###################################################################
        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        # Todo: do not allocate unnecessary GPU mem for G_extra and D_extra if self.extra == False
        self.G_extra = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D_extra = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.avg_start >= 0:
            # Frozen copies used for uniform parameter averaging.
            self.avg_g = copy.deepcopy(self.G)
            self.avg_d = copy.deepcopy(self.D)
            self._requires_grad(self.avg_g, False)
            self._requires_grad(self.avg_d, False)
            self.avg_g.eval()
            self.avg_d.eval()
            self.avg_step = 1
            # With prob. 0.1 (sampled in train()) restart from the averages.
            self.avg_freq_restart_sampler = bernoulli.Bernoulli(.1)

        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
            self.G_extra = nn.DataParallel(self.G_extra)
            self.D_extra = nn.DataParallel(self.D_extra)
            if self.avg_start >= 0:
                self.avg_g = nn.DataParallel(self.avg_g)
                self.avg_d = nn.DataParallel(self.avg_d)
        self.G_extra.train()
        self.D_extra.train()

        # Frozen generator copies tracked as running average / EMA.
        self.G_avg = copy.deepcopy(self.G)
        self.G_ema = copy.deepcopy(self.G)
        self._requires_grad(self.G_avg, False)
        self._requires_grad(self.G_ema, False)

        # Logs, Loss & optimizers   ###################################################################
        grad_var_logger_g = setup_logger(self.log_path, 'log_grad_var_g.log')
        grad_var_logger_d = setup_logger(self.log_path, 'log_grad_var_d.log')
        grad_mean_logger_g = setup_logger(self.log_path, 'log_grad_mean_g.log')
        grad_mean_logger_d = setup_logger(self.log_path, 'log_grad_mean_d.log')

        # NOTE(review): stock torch.optim.SGD/Adam accept no
        # logger_mean/logger_var kwargs, and torch.optim.SvrgAdam does not
        # exist upstream — this project appears to rely on a patched
        # torch.optim; confirm before reuse.
        if self.optim == 'sgd':
            self.g_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.G.parameters()),
                                               self.g_lr,
                                               logger_mean=grad_mean_logger_g,
                                               logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.D.parameters()),
                                               self.d_lr,
                                               logger_mean=grad_mean_logger_d,
                                               logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                            self.G_extra.parameters()),
                                                     self.g_lr)
            self.d_optimizer_extra = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                            self.D_extra.parameters()),
                                                     self.d_lr)
        elif self.optim == 'adam':
            self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                                self.g_lr, [self.g_beta1, self.beta2],
                                                logger_mean=grad_mean_logger_g,
                                                logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                                self.d_lr, [self.d_beta1, self.beta2],
                                                logger_mean=grad_mean_logger_d,
                                                logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                             self.G_extra.parameters()),
                                                      self.g_lr, [self.g_beta1, self.beta2])
            self.d_optimizer_extra = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                             self.D_extra.parameters()),
                                                      self.d_lr, [self.d_beta1, self.beta2])
        elif self.optim == 'svrgadam':
            self.g_optimizer = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                                    self.g_lr, [self.g_beta1, self.beta2],
                                                    logger_mean=grad_mean_logger_g,
                                                    logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                                    self.d_lr, [self.d_beta1, self.beta2],
                                                    logger_mean=grad_mean_logger_d,
                                                    logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                          self.G_extra.parameters()),
                                                          self.g_lr, [self.g_beta1, self.beta2])
            self.d_optimizer_extra = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                          self.D_extra.parameters()),
                                                          self.d_lr, [self.d_beta1, self.beta2])
        else:
            # NOTE(review): message says "Adadelta" but the third supported
            # branch above is 'svrgadam' — runtime string left unchanged.
            raise NotImplementedError('Supported optimizers: SGD, Adam, Adadelta')

        if self.lr_scheduler > 0:  # Exponentially decaying learning rate
            self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(self.g_optimizer,
                                                                      gamma=self.lr_scheduler)
            self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(self.d_optimizer,
                                                                      gamma=self.lr_scheduler)
            self.scheduler_g_extra = torch.optim.lr_scheduler.ExponentialLR(self.g_optimizer_extra,
                                                                            gamma=self.lr_scheduler)
            self.scheduler_d_extra = torch.optim.lr_scheduler.ExponentialLR(self.d_optimizer_extra,
                                                                            gamma=self.lr_scheduler)

        print(self.G)
        print(self.D)

    def _extra_sync_nets(self):
        """ Helper function. Copies the current parameters to the t+1/2 parameters,
         stored as 'net' and 'extra_net', respectively.

        :return: [None]
        """
        self.G_extra.load_state_dict(self.G.state_dict())
        self.D_extra.load_state_dict(self.D.state_dict())

    @staticmethod
    def _update_avg(avg_net, net, avg_step):
        """Updates average network."""
        # Todo: input val
        net_param = list(net.parameters())
        for i, p in enumerate(avg_net.parameters()):
            p.mul_((avg_step - 1) / avg_step)
            p.add_(net_param[i].div(avg_step))

    @staticmethod
    def _requires_grad(_net, _bool=True):
        """Helper function which sets the requires_grad of _net to _bool.

        Raises:
            TypeError: _net is given but is not derived from nn.Module, or
                       _bool is not boolean

        :param _net: [nn.Module]
        :param _bool: [bool, optional] Default: True
        :return: [None]
        """
        if _net and not isinstance(_net, torch.nn.Module):
            raise TypeError("Expected torch.nn.Module. Got: {}".format(type(_net)))
        if not isinstance(_bool, bool):
            raise TypeError("Expected bool. Got: {}".format(type(_bool)))

        if _net is not None:
            for _w in _net.parameters():
                _w.requires_grad = _bool

    def update_avg_nets(self):
        """Fold the current G/D parameters into their running-average copies,
        then advance the averaging step counter."""
        for avg_net, live_net in ((self.avg_g, self.G), (self.avg_d, self.D)):
            self._update_avg(avg_net, live_net, self.avg_step)
        self.avg_step += 1

    def save_sample(self, data_iter):
        """Pull one batch of real images from `data_iter` and write it to
        '<sample_path>/real.png' (denormalized for viewing)."""
        images, _ = next(data_iter)
        out_file = os.path.join(self.sample_path, 'real.png')
        save_image(denorm(images), out_file)

    def backup(self, iteration):
        """Back-ups the networks & optimizers' states.

        Note: self.g_extra & self.d_extra are not stored, as these are copied from
        self.G & self.D at the beginning of each iteration. However, the optimizers
        are backed up.

        :param iteration: [int]
        :return: [None]
        """
        # Table of (stateful object, file name) pairs to persist.
        checkpoints = (
            (self.G, 'gen.pth'),
            (self.D, 'disc.pth'),
            (self.G_avg, 'gen_avg.pth'),
            (self.G_ema, 'gen_ema.pth'),
            (self.g_optimizer, 'gen_optim.pth'),
            (self.d_optimizer, 'disc_optim.pth'),
            (self.g_optimizer_extra, 'gen_extra_optim.pth'),
            (self.d_optimizer_extra, 'disc_extra_optim.pth'),
        )
        for obj, fname in checkpoints:
            torch.save(obj.state_dict(), os.path.join(self.bup_path, fname))

        # Record the iteration so load_backup() can resume from it.
        with open(os.path.join(self.bup_path, "timestamp.txt"), "w") as fff:
            fff.write("%d" % iteration)

    def load_backup(self):
        """Loads the Backed-up networks & optimizers' states.

        Note: self.g_extra & self.d_extra are not stored, as these are copied from
        self.G & self.D at the beginning of each iteration. However, the optimizers
        are backed up.

        :return: [int] timestamp to continue from
        """
        if not os.path.exists(self.bup_path):
            raise ValueError('Cannot load back-up. Directory {} '
                             'does not exist.'.format(self.bup_path))

        # Mirror of the (object, file name) table used by backup().
        checkpoints = (
            (self.G, 'gen.pth'),
            (self.D, 'disc.pth'),
            (self.G_avg, 'gen_avg.pth'),
            (self.G_ema, 'gen_ema.pth'),
            (self.g_optimizer, 'gen_optim.pth'),
            (self.d_optimizer, 'disc_optim.pth'),
            (self.g_optimizer_extra, 'gen_extra_optim.pth'),
            (self.d_optimizer_extra, 'disc_extra_optim.pth'),
        )
        for obj, fname in checkpoints:
            obj.load_state_dict(torch.load(os.path.join(self.bup_path, fname)))

        # The first line of timestamp.txt holds exactly one integer: the
        # iteration at which the backup was made. Resume one step after it.
        with open(os.path.join(self.bup_path, "timestamp.txt"), "r") as fff:
            values = [int(tok) for tok in next(fff).split()]
            if not len(values) == 1:
                raise ValueError('Could not determine timestamp of the backed-up models.')
            timestamp = int(values[0]) + 1

        self.info_logger.info("Loaded models from %s, at timestamp %d." %
                              (self.bup_path, timestamp))
        return timestamp

    def _update_avg_gen(self, n_gen_update):
        """ Updates the uniform average generator. """
        l_param = list(self.G.parameters())
        l_avg_param = list(self.G_avg.parameters())
        if len(l_param) != len(l_avg_param):
            raise ValueError("Got different lengths: {}, {}".format(len(l_param), len(l_avg_param)))

        for i in range(len(l_param)):
            l_avg_param[i].data.copy_(l_avg_param[i].data.mul(n_gen_update).div(n_gen_update + 1.).add(
                                      l_param[i].data.div(n_gen_update + 1.)))

    def _update_ema_gen(self, beta_ema=0.9999):
        """ Updates the exponential moving average generator. """
        l_param = list(self.G.parameters())
        l_ema_param = list(self.G_ema.parameters())
        if len(l_param) != len(l_ema_param):
            raise ValueError("Got different lengths: {}, {}".format(len(l_param), len(l_ema_param)))

        for i in range(len(l_param)):
            l_ema_param[i].data.copy_(l_ema_param[i].data.mul(beta_ema).add(
                l_param[i].data.mul(1-beta_ema)))
class Trainer(object):
    """Conditional SAGAN-style image-to-image trainer.

    Trains a Generator/Discriminator pair where the generator receives an
    image concatenated with a (target-class) label plane, and the
    discriminator scores images alone. Supports 'hinge' and 'softmax'
    adversarial losses (the 'wgan-gp' discriminator-fake branch is
    commented out — see NOTE in train()).

    Relies on module-level helpers defined elsewhere in this file:
    tensor2var, denorm, save_image, Generator, Discriminator,
    UpDownConvolutionalGenerator.
    """
    def __init__(self, data_loader, config):
        """Store the data loader and all config-derived hyper-parameters,
        then build the networks (and optionally tensorboard / pretrained
        weights)."""

        # Data loader
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss
        self.conv_G = config.conv_G

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path: the raw config paths above are overwritten here with
        # version-suffixed variants.
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path, self.version)
        self.cuda = torch.cuda.is_available() #and cuda
        print("Using cuda:", self.cuda)

        self.build_model()

        #self.use_tensorboard = True
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()



    def train(self):
        """Main training loop: alternate one D step and one G step per
        iteration, with periodic logging, sampling, and checkpointing."""

        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                items = next(data_iter)
            # NOTE(review): bare except restarts the iterator on ANY error
            # (intended for StopIteration only) — it can mask real data bugs.
            except:
                data_iter = iter(self.data_loader)
                items = next(data_iter)

            # X: input images; Y: label planes (class id broadcast over a
            # (B,1,H,W)-like map — presumably; confirm against the dataset).
            X, Y = items
            # fake_class: one random target class in [0, 6) per sample,
            # broadcast to Y's shape for concatenation with the image.
            fake_class = torch.Tensor(np.ones(Y.shape)* np.random.randint(0, 6, size=(Y.shape[0], 1, 1, 1)))
            X, Y = X.type(torch.FloatTensor), Y.type(torch.FloatTensor)
            #X, Y = Variable(X.cuda()), Variable(Y.cuda())
            X, Y = Variable(X), Variable(Y)
            if self.cuda:
                X = X.cuda()
                Y = Y.cuda()

            # Scalar class id per sample, read from the label plane's corner.
            class_label = Y[:,0,0,0]
            #class_one_hot = torch.zeros(Y.shape[0], 6)
            #for i, elem in enumerate(class_label):
                #class_one_hot[i, int(elem.item())] = 1.0
            # Despite the name, this is a LongTensor of class indices (as
            # required by F.cross_entropy), not a one-hot encoding.
            class_one_hot = class_label.type(torch.LongTensor)
            #FRITS: the real_disc_in consists of the images X and the desired class
            #desired class chosen randomly, different from real class Y
            real_disc_in = X#,torch.cat((X,Y), dim = 1)
            generator_in = torch.cat((X,fake_class), dim = 1)
            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            #real_images = tensor2var(real_images)
            #Frits TODO: why feed the real_disc_in to D?
            d_out_real,dr1,dr2 = self.D(real_disc_in)
            # NOTE(review): leftover debug prints — they fire every iteration.
            print(d_out_real)
            print(class_one_hot)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = - torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            elif self.adv_loss == 'softmax':
                d_loss_real = F.cross_entropy(d_out_real, class_one_hot).mean()

            # apply Gumbel Softmax

            #Changed to input both image and class
            fake_images,gf1,gf2 = self.G(torch.cat((X,Y), dim = 1))
            fake_disc_in = fake_images#torch.cat((fake_images, Y), dim = 1)
            d_out_fake,df1,df2 = self.D(fake_disc_in)

            # if self.adv_loss == 'wgan-gp':
            #     d_loss_fake = d_out_fake.mean()
            # NOTE(review): with adv_loss == 'wgan-gp', d_loss_fake is never
            # assigned (branch above is commented out) -> NameError at
            # "d_loss = d_loss_real + d_loss_fake" below.
            if self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            elif self.adv_loss == 'softmax':
                d_loss_fake = F.cross_entropy(d_out_fake, class_one_hot).mean()
            #elif self.adv_loss == 'softmax':


            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            #z = tensor2var(torch.randn(real_images.size(0), self.z_dim))

            #TODO Fritz: Do we need this?
            fake_images,_,_ = self.G(generator_in)
            fake_disc_in = fake_images#torch.cat((fake_images, Y), dim = 1)
            # Compute loss with fake images
            g_out_fake,_,_ = self.D(fake_disc_in)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = - g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = - g_out_fake.mean()
            elif self.adv_loss == 'softmax':
                # NOTE(review): minimizing negative cross-entropy pushes the
                # generator to MAXIMIZE D's classification error on the real
                # class label — verify this is the intended objective.
                g_loss_fake = - F.cross_entropy(g_out_fake, class_one_hot).mean()


            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_loss: {:.4f}, g_loss {:.4f}"
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step , d_loss.item(), g_loss_fake.item(),
                             self.G.attn1.gamma.mean().item(), self.G.attn2.gamma.mean().item() ))
                    # format(elapsed, step + 1, self.total_step, (step + 1),
                    #        self.total_step , d_loss.data[0], g_loss_fake.data[0],
                    #        self.G.attn1.gamma.mean().data[0], self.G.attn2.gamma.mean().data[0] ))
                with open('log_info.txt', 'a') as f:
                    # f.write("Step {}, D Loss {}, G Loss {}\n".format(step + 1, d_loss.data[0], g_loss_fake.data[0]))
                    f.write("Step {}, D Loss {}, G Loss {}\n".format(step + 1, d_loss.item(), g_loss_fake.item()))

            # Sample images: stack input / fake / label vertically (dim=2)
            # and log the real vs. target labels alongside.
            if (step + 1) % self.sample_step == 0:
                fake_images,_,_= self.G(generator_in)
                result = torch.cat((X, fake_images, Y), dim = 2)
                save_image(denorm(result.data),
                           os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))
                with open(os.path.join(self.sample_path,'step_{}.txt'.format(step+1)), 'a') as f:
                    # f.write("Step {}, D Loss {}, G Loss {}\n".format(step + 1, d_loss.data[0], g_loss_fake.data[0]))
                    real_labels = Y[:, 0, 0, 0]
                    fake_labels = fake_class[:, 0, 0, 0]
                    print(fake_labels)
                    f.write("Step {}, Real Labels: {}, Target(Fake) Labels {}\n".format(step + 1, real_labels, fake_labels))

            if (step+1) % model_save_step==0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_G.pth'.format(step + 1)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G (convolutional or plain, per self.conv_G) and D,
        move them to GPU if available, wrap in DataParallel if requested,
        and create their Adam optimizers."""
        self.G = None
        if self.conv_G:
            self.G = UpDownConvolutionalGenerator(self.batch_size,self.imsize, self.z_dim, self.g_conv_dim)
            if self.cuda:
                self.G = self.G.cuda()
        else:
            self.G = Generator(self.batch_size,self.imsize, self.z_dim, self.g_conv_dim)
            if self.cuda:
                self.G = self.G.cuda()
        self.D = Discriminator(self.batch_size,self.imsize, self.d_conv_dim)
        if self.cuda:
            self.D = self.D.cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the tensorboard-style Logger writing to self.log_path."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G/D weights saved at step `self.pretrained_model` from
        self.model_save_path."""
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients of both optimizers before a backward pass."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Write one batch of real images to '<sample_path>/real.png'."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images), os.path.join(self.sample_path, 'real.png'))
Beispiel #7
0
class Tester(object):
    """SAGAN trainer variant whose fixed-z samples are decoded into point
    clouds and visualized (despite the class name, its entry point is
    train()).

    Relies on module-level helpers defined elsewhere in this file:
    tensor2var, denorm, save_image, Generator, Discriminator.

    NOTE(review): this class assumes CUDA is available — .cuda() is called
    unconditionally in build_model() and train().
    """
    def __init__(self, data_loader, args, train_loader, model_decoder, chamfer,
                 vis_Valida):
        """Store loaders, the point-cloud decoder, the chamfer metric, the
        visualizers, and all args-derived hyper-parameters; then build the
        networks."""

        # decoder settings
        self.model_decoder = model_decoder   # latent -> point cloud decoder
        self.chamfer = chamfer               # chamfer distance (unused here)
        self.vis = vis_Valida                # list of per-sample visualizers
        self.j = 0

        # Data loader
        #  self.data_loader = data_loader
        self.train_loader = train_loader  # TODO

        # exact model and loss
        self.model = args.model
        self.adv_loss = args.adv_loss

        # Model hyper-parameters
        self.imsize = args.imsize
        self.g_num = args.g_num
        self.z_dim = args.z_dim
        self.g_conv_dim = args.g_conv_dim
        self.d_conv_dim = args.d_conv_dim
        self.parallel = args.parallel

        self.lambda_gp = args.lambda_gp
        self.total_step = args.total_step
        self.d_iters = args.d_iters
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.g_lr = args.g_lr
        self.d_lr = args.d_lr
        self.lr_decay = args.lr_decay
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.pretrained_model = args.pretrained_model

        self.dataset = args.dataset
        self.use_tensorboard = args.use_tensorboard
        self.image_path = args.image_path
        self.log_path = args.log_path
        self.model_save_path = args.model_save_path
        self.sample_path = args.sample_path
        self.log_step = args.log_step
        self.sample_step = args.sample_step
        self.model_save_step = args.model_save_step
        self.version = args.version

        # Path: overwrite the raw paths with version-suffixed variants.
        self.log_path = os.path.join(args.log_path, self.version)
        self.sample_path = os.path.join(args.sample_path, self.version)
        self.model_save_path = os.path.join(args.model_save_path, self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Standard SAGAN loop (hinge or wgan-gp loss) over latent codes;
        every sample_step the fixed-z fakes are decoded to point clouds and
        pushed to the visualizers."""

        # Data iterator
        #  data_iter = iter(self.data_loader)
        train_iter = iter(self.train_loader)  # TODO

        # step_per_epoch = len(self.data_loader)
        train_step_per_epoch = len(self.train_loader)  # TODO

        # model_save_step = int(self.model_save_step * step_per_epoch)
        model_save_step = int(self.model_save_step *
                              train_step_per_epoch)  # TODO

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                #  real_images, _ = next(data_iter)
                real_images = next(train_iter)  # TODO

            # NOTE(review): bare except restarts the iterator on ANY error
            # (intended for StopIteration only) — it can mask real data bugs.
            except:
                #  data_iter = iter(self.data_loader)
                train_iter = iter(self.train_loader)  # TODO

                # real_images, _ = next(data_iter)
                real_images = next(train_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            d_out_real, dr1 = self.D(real_images)  #,dr2
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1 = self.G(z)  #,gf2
            d_out_fake, df1 = self.D(fake_images)  #,df2

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on random interpolations between
                # real and fake samples (WGAN-GP), applied as a second,
                # separate D update.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _ = self.D(interpolated)  # TODO "_"

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _ = self.G(z)  # _

            # Compute loss with fake images
            g_out_fake, _ = self.D(fake_images)  # batch x n  TODO "_"
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                # NOTE(review): Tensor.data[0] was removed in PyTorch >= 0.5
                # — these should be .item(); this print will raise on modern
                # PyTorch.
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                    " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".format(
                        elapsed, step + 1, self.total_step, (step + 1),
                        self.total_step, d_loss_real.data[0],
                        self.G.attn1.gamma.mean().data[0],
                        self.G.attn2.gamma.mean().data[0]))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _ = self.G(fixed_z)  #TODO "_"

                # NOTE(review): hard-coded (64, 128) assumes batch_size == 64
                # and a 128-dim latent; breaks for any other configuration.
                encoded = fake_images.contiguous().view(64, 128)

                pc_1 = self.model_decoder(encoded)
                #pc_1_temp = pc_1[0, :, :]

                # Visualize the first 10 decoded point clouds, tagged with
                # the first latent coordinate of each sample.
                epoch = 0
                for self.j in range(0, 10):
                    pc_1_temp = pc_1[self.j, :, :]
                    test = fixed_z.detach().cpu().numpy()
                    # NOTE(review): np.asscalar is deprecated (removed in
                    # NumPy 1.23); test[self.j, 0].item() is the replacement.
                    test1 = np.asscalar(test[self.j, 0])

                    visuals = OrderedDict([('Validation Predicted_pc',
                                            pc_1_temp.detach().cpu().numpy())])
                    self.vis[self.j].display_current_results(visuals,
                                                             epoch,
                                                             step,
                                                             z=test1)

                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G and D on the GPU, wrap in DataParallel if
        requested, and create their Adam optimizers."""

        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the tensorboard-style Logger writing to self.log_path."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load G/D weights saved at step `self.pretrained_model` from
        self.model_save_path."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero the gradients of both optimizers before a backward pass."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Write one batch of real images to '<sample_path>/real.png'."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
Beispiel #8
0
class Solver(object):
    """Solver for training and testing StarGAN."""
    def __init__(self, celeba_loader, rafd_loader, config):
        """Initialize configurations."""

        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Model configurations.
        self.c_dim = config.c_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ['CelebA', 'RaFD']:
            # Device placement happens via .to(self.device) below; the old
            # unconditional .cuda() crashed on CPU-only machines even though
            # self.device correctly fell back to 'cpu'.
            self.G = Generator(self.batch_size, self.image_size, self.c_dim,
                               self.g_conv_dim)
            self.D = Discriminator(self.batch_size, self.image_size,
                                   self.c_dim, self.d_conv_dim)
        # TODO add config: self.parallel (see line 195 in sagan/trainer.py)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        # map_location keeps CPU-only restores of GPU-saved checkpoints working.
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self,
                      c_org,
                      c_dim=5,
                      dataset='CelebA',
                      selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # Get hair color indices. Initialized unconditionally so the CelebA
        # branch below can never hit an unbound name.
        hair_color_indices = []
        if dataset == 'CelebA':
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in [
                        'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
                ]:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:,
                                         i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss."""
        if dataset == 'CelebA':
            # reduction='sum' replaces the deprecated size_average=False
            # (removed-in-newer-PyTorch legacy argument); same value.
            return F.binary_cross_entropy_with_logits(
                logit, target, reduction='sum') / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset,
                                          self.selected_attrs)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels; restart the iterator when an
            # epoch is exhausted (only StopIteration is expected here).
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            if self.dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
            elif self.dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c_dim)
                c_trg = self.label2onehot(label_trg, self.c_dim)

            x_real = x_real.to(self.device)  # Input images.
            c_org = c_org.to(self.device)  # Original domain labels.
            c_trg = c_trg.to(self.device)  # Target domain labels.
            label_org = label_org.to(
                self.device)  # Labels for computing classification loss.
            label_trg = label_trg.to(
                self.device)  # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            # TODO: hinge loss (see line 107 & 117 in sagan/trainer.py)

            # Compute loss with real images.
            # dr1, dr2, df1, df2, gfd1, gfd2, gfu1, gfu2 are attention scores
            out_src, out_cls, dr1, dr2 = self.D(x_real)
            d_loss_real = -torch.mean(out_src)  # TODO: flip labels
            d_loss_cls = self.classification_loss(out_cls, label_org,
                                                  self.dataset)

            # Compute loss with fake images.
            x_fake, gfd1, gfd2, gfu1, gfu2 = self.G(x_real, c_trg)
            out_src, out_cls, df1, df2 = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _, _, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            #if (i+1) % self.n_critic == 0: ## SA-GAN: Every time
            # Original-to-target domain.
            x_fake, _, _, _, _ = self.G(x_real, c_trg)
            out_src, out_cls, _, _ = self.D(x_fake)
            g_loss_fake = -torch.mean(out_src)
            g_loss_cls = self.classification_loss(out_cls, label_trg,
                                                  self.dataset)

            # Target-to-original domain.
            x_reconst, _, _, _, _ = self.G(x_fake, c_org)
            g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

            # Backward and optimize.
            g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss['G/loss_fake'] = g_loss_fake.item()
            loss['G/loss_rec'] = g_loss_rec.item()
            loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake, _, _, _, _ = self.G(x_fixed, c_fixed)
                        x_fake_list.append(x_fake)
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir,
                                               '{}-images.jpg'.format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()),
                               sample_path,
                               nrow=1,
                               padding=0)
                    print('Saved real and fake images into {}...'.format(
                        sample_path))

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (
                    self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim,
                                                self.dataset,
                                                self.selected_attrs)

                # Translate images.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake, _, _, _, _ = self.G(x_real, c_trg)
                    x_fake_list.append(x_fake)

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir,
                                           '{}-images.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()),
                           result_path,
                           nrow=1,
                           padding=0)
                print('Saved real and fake images into {}...'.format(
                    result_path))
Example #9
0
class Tester(object):
    """Load a pretrained SA-GAN generator and render grids of generated samples."""

    def __init__(self, data_loader, config):

        self.data_loader = data_loader

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Output paths are namespaced by the run version. (The previous
        # un-joined assignments of model_save_path/test_path were dead
        # stores, immediately overwritten; removed.)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        self.test_path = os.path.join(config.test_path, self.version)

        self.build_model()

        self.load_pretrained_model()

    def build_model(self):
        """Instantiate G/D on the GPU and build their Adam optimizers."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Only parameters that require gradients are optimized.
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

    def load_pretrained_model(self):
        """Restore G and D weights saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def test(self):
        """Generate 500 batches of 9 samples and save each as a grid figure."""
        num_of_images = 9

        for i in range(500):
            z = tensor2var(torch.randn(num_of_images, self.z_dim))

            fake_images, _, _ = self.G(z)

            # (9, 3, H, W) -> (9, H, W, 3): channels-last for matplotlib.
            transpose_image = np.transpose(var2numpy(fake_images.data),
                                           (0, 2, 3, 1))
            self.output_fig(
                transpose_image,
                os.path.join(self.test_path,
                             '{}_fake_class_format.png'.format(i + 1)))

    def output_fig(self, images_array, file_name):
        """Save ``images_array`` as a square grid figure at ``file_name``.

        The shape of images_array should be (9, width, height, 3),
        28 <= width, height <= 112.
        """
        plt.figure(figsize=(6, 6), dpi=100)
        plt.imshow(helper.images_square_grid(images_array))
        plt.axis("off")
        plt.savefig(file_name, bbox_inches='tight', pad_inches=0)
        # Close the figure: test() calls this 500 times, and unclosed
        # figures accumulate and leak memory (matplotlib warns at 20+).
        plt.close()
Example #10
0
class Trainer(object):
    """SA-GAN style trainer: alternates D/G updates (wgan-gp or hinge loss)
    and periodically evaluates FID on generated samples via ``mytest``."""

    def __init__(self, data_loader, config):

        # Data loader
        self.data_loader = data_loader

        # Exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        # Non-empty string flags GPU mode for the FID computation below.
        self.gpu = 'gpu'

        # Path (namespaced by run version)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)
        self.path1 = config.path1
        self.path2 = config.path2
        self.dims = config.dims

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def mytest(self, step):
        """Generate samples, resize to CelebA geometry, and report mean FID.

        The FID is averaged over several random 40k-image subsets of the
        real data and appended to ``FID.txt``.

        NOTE(review): output paths are hard-coded to one machine; they
        should be promoted to config options.
        """
        num = 200
        z = tensor2var(torch.randn(num, self.z_dim))

        fake_images, gf1, gf2 = self.G(z)
        fake_images = fake_images.data
        # Save each sample and resize it to 178x218 (CelebA image size).
        for n in range(num):
            save_image(fake_images[n],
                       '/home/xinzi/dataset_k40/test_celeba/%d.jpg' % n)
            str0 = '/home/xinzi/dataset_k40/test_celeba/' + str(n) + '.jpg'
            im = Image.open(str0)
            im = im.resize((178, 218))
            im.save(str0)

        str1 = '/home/xinzi/dataset_k40/celebAtemp'
        times = 5
        fid_total = 0.0  # renamed from `sum`, which shadowed the builtin
        for i in range(times):
            # Rebuild the temp directory with a fresh random 40k subset.
            shutil.rmtree(str1)
            os.mkdir(str1)

            random_copyfile(self.path1, str1, 40000)

            fid_value = calculate_fid_given_paths(str1, self.path2, 100,
                                                  self.gpu != '', self.dims)
            fid_total = fid_total + fid_value
            print('FID: ', fid_value)
        print(float(fid_total / times))
        # Context manager guarantees the file is closed even on error.
        with open('FID.txt', 'a') as f:
            f.write('\n')
            f.write(str(float(fid_total / times)))

    def train(self):
        """Run the adversarial training loop for ``self.total_step`` steps."""

        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        # model_save_step is expressed in epochs; convert to steps.
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 1

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            # Restart the iterator once an epoch is exhausted; only
            # StopIteration is expected (was a bare except, which hid bugs).
            try:
                real_images, _ = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                real_images, _ = next(data_iter)

            # Compute loss with real images.
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores.
            real_images = tensor2var(real_images)  # wrap on CUDA as Variable
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # Compute loss with fake images.
            self.temp = real_images.size(0)
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            # NOTE(review): fake_images is not detached, so the D backward
            # pass also traverses G's graph; G is not updated here (only
            # d_optimizer.step() runs) but its gradients accumulate until
            # the next reset_grad().
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Gradient penalty on points interpolated between real and
                # fake images (WGAN-GP), applied as a separate D step.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                # .item() replaces the deprecated .data[0], which raises
                # IndexError on 0-dim tensors in PyTorch >= 0.5.
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                    " ave_gamma_l3: , ave_gamma_l4: ".format(
                        elapsed, step + 1, self.total_step, (step + 1),
                        self.total_step, d_loss_real.item()))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

            # Periodically compute FID once training has warmed up.
            if step >= 3000:
                if step % 400 == 0:
                    print('====================testing====================')
                    self.mytest(step)

    def build_model(self):
        """Instantiate G/D on the GPU, build their Adam optimizers, and
        echo both network architectures to stdout."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Only parameters that require gradients are optimized.
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # Print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the project's tensorboard-style logger on ``self.log_path``."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Restore G and D weights saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Clear accumulated gradients on both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Pull one real batch and save it as ``real.png`` for reference."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
Example #11
0
class Trainer(object):
    """Conditional VAE-GAN trainer for paired camera views.

    Combines an adversarial loss (hinge or WGAN-GP) with a cross-view VAE
    reconstruction objective and a triplet loss on encoder means.  Camera
    pose (pitch / yaw / distance) and task progress are discretized into
    one-hot vectors and concatenated with the sampled latent as generator
    conditioning.
    """

    def __init__(self, loader, config, data_loader_val=None):
        """Unpack train/val loaders, copy hyper-parameters off ``config``,
        set up versioned output paths, and build the networks.

        NOTE(review): the ``data_loader_val`` parameter is immediately
        shadowed by the unpacking of ``loader`` below, so the keyword
        argument is effectively unused.
        """

        # Data loader
        data_loader, data_loader_val = loader
        self.data_loader = data_loader
        self.data_loader_val = data_loader_val

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        # Conditioning width: (20 pitch + 40 yaw + 10 distance) bins per
        # view, two views, plus 5 task-progress bins.  Must agree with the
        # bin counts hard-coded in train() (asserted there).
        self.cam_view_z = (20 + 40 + 10) * 2 + 5
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Path
        # The raw config paths above are overwritten with versioned ones.
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.vae_rec_path = os.path.join(config.sample_path, "vae_rec")
        os.makedirs(self.vae_rec_path, exist_ok=True)  # TODO
        print('vae_rec_path: {}'.format(self.vae_rec_path))
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        # Scale factor applied to the pixel-level VAE/triplet losses.
        self.num_pixels = self.imsize * 2 * 3

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):
        """Main loop: per step, update D (plus optional gradient penalty),
        run the cross-view VAE reconstruction passes, then update G with
        adversarial + VAE + ramped triplet losses; periodically log,
        sample images, and checkpoint."""
        def cycle(iterable):
            while True:
                for x in iterable:
                    yield x

        # Using itertools.cycle has an important drawback, in that it does not shuffle the data after each iteration:
        # WARNING  itertools.cycle  does not shuffle the data after each iteratio
        # Data iterator
        data_iter = iter(cycle(self.data_loader))
        self.loader_val_iter = iter(cycle(self.data_loader_val))
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        # Populated with the first step's conditioning+latent and reused
        # for all sampled images thereafter.
        fixed_z = None

        # Start with trained model
        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time
        start_time = time.time()

        num_views = 2
        key_views = ["frames views {}".format(i) for i in range(num_views)]

        lable_keys_cam_view_info = []  # list with keys for view 0 and view 1
        for view_i in range(num_views):
            lable_keys_cam_view_info.append([
                "cam_pitch_view_{}".format(view_i),
                "cam_yaw_view_{}".format(view_i),
                "cam_distance_view_{}".format(view_i)
            ])

        mapping_cam_info_lable = OrderedDict()
        mapping_cam_info_one_hot = OrderedDict()
        # create a different mapping for echt setting
        n_classes = []
        for cam_info_view in lable_keys_cam_view_info:
            for cam_inf in cam_info_view:
                # Bin ranges and counts per camera parameter.  The counts
                # (20/40/10) must match the terms in self.cam_view_z — see
                # the assert below.
                if "pitch" in cam_inf:
                    min_val, max_val = -50, -35.
                    n_bins = 20
                elif "yaw" in cam_inf:
                    min_val, max_val = -60., 210.
                    n_bins = 40
                elif "distance" in cam_inf:
                    min_val, max_val = 0.7, 1.
                    n_bins = 10

                to_l, to_hot_l = create_lable_func(min_val, max_val, n_bins)
                mapping_cam_info_lable[cam_inf] = to_l
                mapping_cam_info_one_hot[cam_inf] = to_hot_l
                if "view_0" in cam_inf:
                    n_classes.append(n_bins)
        print('cam view one hot infputs {}'.format(n_classes))
        task_progess_bins = 5
        _, task_progress_hot_func = create_lable_func(0,
                                                      115,
                                                      n_bins=task_progess_bins,
                                                      clip=True)

        assert sum(n_classes) * 2 + task_progess_bins == self.cam_view_z

        def changing_factor(start, end, steps):
            # Linear ramp from `start` to `end` over `steps` yields.
            for i in range(steps):
                yield i / (steps / (end - start)) + start

        # Loss-weight ramps consumed once per training step.
        cycle_factor_gen = changing_factor(0.5, 1., self.total_step)
        triplet_factor_gen = changing_factor(0.1, 1., self.total_step)
        for step in range(start, self.total_step):

            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            if isinstance(self.data_loader.dataset, DoubleViewPairDataset):
                data = next(data_iter)
                # NOTE(review): presumably sklearn.utils.shuffle — randomly
                # swaps which view is treated as view 0 — confirm import.
                key_views, lable_keys_cam_view_info = shuffle(
                    key_views, lable_keys_cam_view_info)
                # real_images = torch.cat([data[key_views[0]], data[key_views[1]]])
                #  for now only view 0
                real_images = data[key_views[0]]
                real_images_view1 = data[key_views[1]]
                label_c = OrderedDict()
                label_c_hot_in = OrderedDict()
                for key_l, lable_func in mapping_cam_info_lable.items():
                    # contin cam values to labels
                    label_c[key_l] = torch.tensor(lable_func(
                        data[key_l])).cuda()
                    label_c_hot_in[key_l] = torch.tensor(
                        mapping_cam_info_one_hot[key_l](data[key_l]),
                        dtype=torch.float32).cuda()
                d_one_hot_view0 = [
                    label_c_hot_in[l] for l in lable_keys_cam_view_info[0]
                ]
                d_one_hot_view1 = [
                    label_c_hot_in[l] for l in lable_keys_cam_view_info[1]
                ]
                d_task_progress = torch.tensor(task_progress_hot_func(
                    data['frame index']),
                                               dtype=torch.float32).cuda()
            else:
                # NOTE(review): this branch leaves real_images_view1 and the
                # conditioning tensors (d_one_hot_*, d_task_progress)
                # undefined, so only DoubleViewPairDataset is actually
                # supported below.
                real_images, _ = next(data_iter)
            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores
            real_images = tensor2var(real_images)
            real_images_view1 = tensor2var(real_images_view1)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            encoded = self.G.encoder(real_images)
            sampled = self.G.encoder.sampler(encoded)
            # NOTE(review): this random z is immediately overwritten by the
            # torch.cat below — dead code.
            z = torch.randn(real_images.size(0), self.z_dim).cuda()
            z = torch.cat(
                [*d_one_hot_view0, *d_one_hot_view1, d_task_progress, sampled],
                dim=1)  # add view info from to
            if fixed_z is None:
                fixed_z = tensor2var(
                    torch.cat([
                        *d_one_hot_view0, *d_one_hot_view1, d_task_progress,
                        sampled
                    ],
                              dim=1))  # add view info
            z = tensor2var(z)
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty
                # Penalize the gradient norm on random interpolations
                # between real and fake samples (WGAN-GP), as a second,
                # separate optimizer step.
                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train VAE================== #
            # Pass 1: encode view 0, decode conditioned on view-1 pose, and
            # reconstruct view 1 (cross-view reconstruction).
            encoded = self.G.encoder(real_images)
            mu_0 = encoded[0]
            logvar = encoded[1]
            # Standard VAE KL divergence term (element-wise, then summed).
            KLD_element = mu_0.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            KLD = torch.sum(KLD_element).mul_(-0.5)
            # save_image(denorm(real_images[::2]), os.path.join(self.sample_path, "ancor.png"))
            # save_image(denorm(real_images[1::2]), os.path.join(self.sample_path, "neg.png"))
            # save_image(denorm(real_images_view1[::2]), os.path.join(self.sample_path, "pos.png"))

            sampled = self.G.encoder.sampler(encoded)
            z = torch.cat(
                [*d_one_hot_view0, *d_one_hot_view1, d_task_progress, sampled],
                dim=1)  # add view info 0
            z = tensor2var(z)
            fake_images_0, _, _ = self.G(z)
            MSEerr = self.MSECriterion(fake_images_0, real_images_view1)
            rec = fake_images_0
            VAEerr = MSEerr + KLD * 0.1
            # encode the fake view and recon loss to view1
            # Pass 2: re-encode the generated view and reconstruct back to
            # view 0, with the pose conditioning order swapped.
            encoded = self.G.encoder(fake_images_0)
            mu_1 = encoded[0]
            logvar = encoded[1]
            KLD_element = mu_1.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            KLD = torch.sum(KLD_element).mul_(-0.5)
            sampled = self.G.encoder.sampler(encoded)
            z = torch.cat(
                [*d_one_hot_view1, *d_one_hot_view0, d_task_progress, sampled],
                dim=1)  # add view info 1
            z = tensor2var(z)
            fake_images_view1, _, _ = self.G(z)
            rec_fake = fake_images_view1
            MSEerr = self.MSECriterion(fake_images_view1, real_images)
            VAEerr += (KLD * 0.1 + MSEerr) * next(
                cycle_factor_gen)  # (KLD + MSEerr)  # *0.5
            # Triplet on encoder means; the even/odd slicing presumably
            # relies on the batch layout (anchor/negative interleaved) —
            # TODO confirm against the data loader.
            triplet_loss = self.triplet_loss(anchor=mu_0[::2],
                                             positive=mu_0[1::2],
                                             negative=mu_1[::2])
            # ================== Train G and gumbel ================== #
            # Create random noise
            # z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            # fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            # fake_images = torch.cat([fake_images_0, fake_images_view1]) # rm triplets
            fake_images = torch.cat(
                [fake_images_0[::2], fake_images_view1[::2]])  # rm triplets
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            VAEerr *= self.num_pixels
            triplet_loss *= self.num_pixels
            loss = g_loss_fake * 4. + VAEerr + triplet_loss * next(
                triplet_factor_gen)
            loss.backward()

            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print(
                    "Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                    " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f},vae {:.4f}".
                    format(elapsed, step + 1, self.total_step, (step + 1),
                           self.total_step, d_loss_real,
                           self.G.attn1.gamma.mean(),
                           self.G.attn2.gamma.mean(), VAEerr))

                # NOTE(review): `vis` is a module-level name (visdom-style
                # plotting handle, judging by vis.line/vis.images) — not
                # defined in this class.
                if vis is not None:
                    kw_update_vis = None

                    if self.d_plot is not None:
                        kw_update_vis = 'append'
                        # kw_update_vis["update"] = 'append'
                    self.d_plot = vis.line(np.array(
                        [d_loss_real.detach().cpu().numpy()]),
                                           X=np.array([step]),
                                           win=self.d_plot,
                                           update=kw_update_vis,
                                           opts=dict(title="d_loss_real",
                                                     xlabel='Timestep',
                                                     ylabel='loss'))
                    self.d_plot_fake = vis.line(np.array(
                        [d_loss_fake.detach().cpu().numpy()]),
                                                X=np.array([step]),
                                                win=self.d_plot_fake,
                                                update=kw_update_vis,
                                                opts=dict(title="d_loss_fake",
                                                          xlabel='Timestep',
                                                          ylabel='loss'))
                    self.d_plot_vae = vis.line(np.array(
                        [VAEerr.detach().cpu().numpy()]),
                                               X=np.array([step]),
                                               win=self.d_plot_vae,
                                               update=kw_update_vis,
                                               opts=dict(title="VAEerr",
                                                         xlabel='Timestep',
                                                         ylabel='loss'))
                    self.d_plot_triplet_loss = vis.line(
                        np.array([triplet_loss.detach().cpu().numpy()]),
                        X=np.array([step]),
                        win=self.d_plot_triplet_loss,
                        update=kw_update_vis,
                        opts=dict(title="triplet_loss",
                                  xlabel='Timestep',
                                  ylabel='loss'))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                fake_images = denorm(fake_images)
                save_image(
                    fake_images.data,
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(step + 1)))
                n = 8
                imgs = denorm(torch.cat([real_images.data[:n], rec.data[:n]]))
                imgs_rec_fake = denorm(
                    torch.cat([real_images_view1.data[:n], rec_fake.data[:n]]))
                title = '{}_var_rec_real'.format(step + 1)
                title_rec_fake = '{}_var_rec_fake'.format(step + 1)
                title_fixed = '{}_fixed'.format(step + 1)
                save_image(imgs,
                           os.path.join(self.vae_rec_path, title + ".png"),
                           nrow=n)
                distance_pos, product_pos, distance_neg, product_neg = self._get_view_pair_distances(
                )

                print("distance_pos {:.3}, neg {:.3},dot pos {:.3} neg {:.3}".
                      format(distance_pos, distance_neg, product_pos,
                             product_neg))
                if vis is not None:
                    self.rec_win = vis.images(
                        imgs,
                        win=self.rec_win,
                        opts=dict(title=title, width=64 * n, height=64 * 2),
                    )
                    self.rec_fake_win = vis.images(
                        imgs_rec_fake,
                        win=self.rec_fake_win,
                        opts=dict(title=title_rec_fake,
                                  width=64 * n,
                                  height=64 * 2),
                    )
                    self.fixed_win = vis.images(
                        fake_images,
                        win=self.fixed_win,
                        opts=dict(title=title_fixed,
                                  width=64 * n,
                                  height=64 * 4),
                    )

                    kw_update_vis = None
                    if self.d_plot_distance_pos is not None:
                        kw_update_vis = 'append'
                    self.d_plot_distance_pos = vis.line(
                        np.array([distance_pos]),
                        X=np.array([step]),
                        win=self.d_plot_distance_pos,
                        update=kw_update_vis,
                        opts=dict(title="distance pos",
                                  xlabel='Timestep',
                                  ylabel='dist'))
                    self.d_plot_distance_neg = vis.line(
                        np.array([distance_neg]),
                        X=np.array([step]),
                        win=self.d_plot_distance_neg,
                        update=kw_update_vis,
                        opts=dict(title="distance neg",
                                  xlabel='Timestep',
                                  ylabel='dist'))
            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def _get_view_pair_distances(self):
        """Encode a validation batch of view pairs and return
        (mean positive distance, mean positive dot product,
        mean negative distance, mean negative dot product),
        where negatives are all non-matching index pairs."""
        def encode(x):
            encoded = self.G.encoder(x)
            mu = encoded[0]
            return mu

        # dot product are mean free
        key_views = ["frames views {}".format(i) for i in range(2)]
        sample_batched = next(self.loader_val_iter)
        anchor_emb = encode(sample_batched[key_views[0]].cuda())
        positive_emb = encode(sample_batched[key_views[1]].cuda())
        distance_pos = np.linalg.norm(anchor_emb.data.cpu().numpy() -
                                      positive_emb.data.cpu().numpy(),
                                      axis=1).mean()
        dots = []
        for e1, e2 in zip(anchor_emb.data.cpu().numpy(),
                          positive_emb.data.cpu().numpy()):
            dots.append(np.dot(e1 - e1.mean(), e2 - e2.mean()))
        product_pos = np.mean(dots)

        n = len(anchor_emb)
        emb_pos = anchor_emb.data.cpu().numpy()
        emb_neg = positive_emb.data.cpu().numpy()
        cnt, distance_neg, product_neg = 0., 0., 0.
        # Average distance / dot product over all cross-view pairs with
        # mismatched indices (the "negatives").
        for i in range(n):
            for j in range(n):
                if j != i:
                    d_negative = np.linalg.norm(emb_pos[i] - emb_neg[j])
                    distance_neg += d_negative
                    product_neg += np.dot(emb_pos[i] - emb_pos[i].mean(),
                                          emb_neg[j] - emb_neg[j].mean())
                    cnt += 1
        distance_neg /= cnt
        product_neg /= cnt
        # distance_pos = np.asscalar(distance_pos)
        # product_pos = np.asscalar(product_pos)
        # distance_neg = np.asscalar(distance_neg)
        # product_neg = np.asscalar(product_neg)
        return distance_pos, product_pos, distance_neg, product_neg

    def build_model(self):
        """Create G (with the extra cam_view_z conditioning width) and D on
        the GPU, the VAE/triplet criteria, the Adam optimizers, and null
        out all visdom window handles."""
        self.rec_win = None
        self.rec_fake_win = None
        self.fixed_win = None
        self.d_plot = None
        self.d_plot_fake = None
        self.d_plot_vae = None
        self.d_plot_triplet_loss = None
        self.d_plot_distance_neg = None
        self.d_plot_distance_pos = None
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.cam_view_z, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
        self.MSECriterion = nn.MSELoss()
        self.triplet_loss = nn.TripletMarginLoss()
        # Loss and optimizer
        # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Attach the project Logger writing to ``self.log_path``."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Restore G and D weights saved at step ``self.pretrained_model``."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Zero gradients on both optimizers before a backward pass."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Write one denormalized real batch to sample_path/real.png."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))

    # NOTE(review): duplicate definition — this shadows the save_sample
    # above; the bodies are identical, so behavior is unchanged, but one
    # of the two should be removed.
    def save_sample(self, data_iter):
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
class Trainer(object):
    def __init__(self, data_loader, config):
        """Store the data loader and training configuration, derive the
        versioned output paths, then build the networks (plus tensorboard
        logger / pretrained weights when configured)."""

        # Data loader
        self.data_loader = data_loader

        # Copy model/loss choice, hyper-parameters, and bookkeeping
        # settings straight off the config object (same order as before).
        for attr in ('model', 'adv_loss', 'imsize', 'g_num', 'z_dim',
                     'g_conv_dim', 'd_conv_dim', 'parallel', 'lambda_gp',
                     'total_step', 'd_iters', 'batch_size', 'num_workers',
                     'g_lr', 'd_lr', 'lr_decay', 'beta1', 'beta2',
                     'pretrained_model', 'dataset', 'use_tensorboard',
                     'image_path', 'log_path', 'model_save_path',
                     'sample_path', 'log_step', 'sample_step',
                     'model_save_step', 'version'):
            setattr(self, attr, getattr(config, attr))

        # Versioned output directories (overwrite the raw config paths).
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Resume from a saved checkpoint when one is given.
        if self.pretrained_model:
            self.load_pretrained_model()

    def train(self):

        # Data iterator

        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging

        fixed_z = tensor2var(
            torch.normal(0,
                         torch.ones([self.batch_size, self.z_dim]) * 3))

        # Start with trained model

        if self.pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0

        # Start time

        start_time = time.time()
        i = 0
        for step in range(start, self.total_step):

            # ================== Train D ================== #

            self.D.train()
            self.G.train()

            try:
                (real_images, _) = next(data_iter)
            except:
                data_iter = iter(self.data_loader)
                (real_images, _) = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention scores

            real_images = tensor2var(real_images)
            d_out_real = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # apply Gumbel Softmax
            z = tensor2var(
                torch.normal(0,
                             torch.ones([real_images.size(0), self.z_dim]) *
                             3))
            # (fake_images, gf1, gf2) = self.G(z)
            (fake_images, gf2) = self.G(z)

            if i < 1:
                print('***** Result Image size now *****')
                print(fake_images.size())
                # print(gf1.size())
                print(gf2.size())
            i = i + 1

            d_out_fake = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize

            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':

                # Compute gradient penalty

                alpha = torch.rand(real_images.size(0), 1, 1,
                                   1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data +
                                        (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out = self.D(interpolated)

                grad = torch.autograd.grad(
                    outputs=out,
                    inputs=interpolated,
                    grad_outputs=torch.ones(out.size()).cuda(),
                    retain_graph=True,
                    create_graph=True,
                    only_inputs=True,
                )[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize

                d_loss = self.lambda_gp * d_loss_gp

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G and gumbel ================== #
            # Create random noise
            z = tensor2var(
                torch.normal(0,
                             torch.ones([real_images.size(0), self.z_dim]) *
                             3))
            # (fake_images, _, _) = self.G(z)
            (fake_images, _) = self.G(z)

            # Compute loss with fake images

            g_out_fake = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info

            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print(
                    'Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, ave_gamma_l4: {:.4f}'
                    .format(
                        elapsed,
                        step + 1,
                        self.total_step,
                        step + 1,
                        self.total_step,
                        d_loss_real.data[0],
                        self.G.module.attn2.gamma.mean().data[0],
                    ))

                # (1) Log values of the losses (scalars)

                info = {
                    'd_loss_real': d_loss_real.data[0],
                    'd_loss_fake': d_loss_fake.data[0],
                    'd_loss': d_loss.data[0],
                    'g_loss_fake': g_loss_fake.data[0],
                    # 'ave_gamma_l3': self.G.module.attn1.gamma.mean().data[0],
                    'ave_gamma_l4': self.G.module.attn2.gamma.mean().data[0],
                }

                for (tag, value) in info.items():
                    self.logger.scalar_summary(tag, value, step + 1)

            # Sample images / Save and log

            if (step + 1) % self.sample_step == 0:

                # (2) Log values and gradients of the parameters (histogram)

                for (net, name) in zip([self.G, self.D], ['G_', 'D_']):
                    for (tag, value) in net.named_parameters():
                        tag = name + tag.replace('.', '/')
                        self.logger.histo_summary(tag,
                                                  value.data.cpu().numpy(),
                                                  step + 1)

                # (3) Log the tensorboard

                info = \
                    {'fake_images': (fake_images.view(fake_images.size())[:
                     16, :, :, :]).data.cpu().numpy(),
                     'real_images': (real_images.view(real_images.size())[:
                     16, :, :, :]).data.cpu().numpy()}

                # (fake_images, at1, at2) = self.G(fixed_z)
                (fake_images, at2) = self.G(fixed_z)
                if (step + 1) % (self.sample_step * 10) == 0:
                    save_image(
                        denorm(fake_images.data),
                        os.path.join(self.sample_path,
                                     '{}_fake.png'.format(step + 1)))

                # print('***** Fake Image size now *****')
                # print('fake_images ', fake_images.size())
                # print('at2 ', at2.size())   # B * N * N
                at2_4d = at2.view(
                    *(at2.size()[0],
                      at2.size()[1], int(math.sqrt(at2.size()[2])),
                      int(math.sqrt(at2.size()[2]))))  # W * N * W * H
                # print('at2_4d ', at2_4d.size())
                at2_mean = at2_4d.mean(dim=1, keepdim=False)  # B * W * H
                # print('at2_mean ', at2_mean.size())

                print('***** start create activation map *****')
                attn_list = []
                for i in range(at2.size()[0]):
                    # print('fake_images size: ',fake_images[i].size())
                    # print('at2 mean size', at2_mean[i].size())

                    f = BytesIO()
                    img = np.uint8(
                        fake_images[i, :, :, :].mul(255).data.cpu().numpy())
                    a = np.uint8(at2_mean[i, :, :].mul(255).data.cpu().numpy())
                    # print('image: ', img.shape)
                    # print('a shape: ',a.shape)

                    im_image = img.reshape(img.shape[1], img.shape[2],
                                           img.shape[0])
                    im_attn = cv2.applyColorMap(a, cv2.COLORMAP_JET)

                    img_with_heatmap = np.float32(im_attn) + np.float32(
                        im_image)
                    img_with_heatmap = img_with_heatmap / np.max(
                        img_with_heatmap)

                    attn_np = np.uint8((255 * img_with_heatmap).reshape(
                        img_with_heatmap.shape[2], img_with_heatmap.shape[0],
                        img_with_heatmap.shape[1]))
                    attn_torch = torch.from_numpy(attn_np)
                    # print('final attn image size: ', attn_torch.size())
                    attn_list.append(attn_torch.unsqueeze(0))

                attn_images = torch.cat(attn_list)
                print('attn images list: ', attn_images.size())
                info['attn_images'] = (attn_images.view(
                    attn_images.size())[:16, :, :, :]).numpy()

                for (tag, image) in info.items():
                    self.logger.image_summary(tag, image, step + 1)

            if (step + 1) % model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_G.pth'.format(step + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 '{}_D.pth'.format(step + 1)))

    def build_model(self):
        """Instantiate G and D (on GPU), optionally wrap them for
        multi-GPU, and create their Adam optimizers over trainable
        parameters only."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize,
                               self.d_conv_dim).cuda()

        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Optimizers skip frozen (requires_grad=False) parameters.
        betas = [self.beta1, self.beta2]
        g_params = [p for p in self.G.parameters() if p.requires_grad]
        d_params = [p for p in self.D.parameters() if p.requires_grad]
        self.g_optimizer = torch.optim.Adam(g_params, self.g_lr, betas)
        self.d_optimizer = torch.optim.Adam(d_params, self.d_lr, betas)

        self.c_loss = torch.nn.CrossEntropyLoss()

        # Print network architectures for the log.
        print(self.G)
        print(self.D)

    def build_tensorboard(self):
        """Create the project's tensorboard logger writing to self.log_path.

        The import is kept local so the logger module is only required
        when tensorboard logging is actually enabled.
        """
        from logger import Logger
        self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Restore G and D weights from the checkpoints saved at training
        step ``self.pretrained_model``."""
        for net, suffix in ((self.G, 'G'), (self.D, 'D')):
            ckpt_path = os.path.join(
                self.model_save_path,
                '{}_{}.pth'.format(self.pretrained_model, suffix))
            net.load_state_dict(torch.load(ckpt_path))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Pull one batch of real images from the iterator and save it,
        denormalized, as 'real.png' in the sample directory."""
        real_images, _ = next(data_iter)
        out_path = os.path.join(self.sample_path, 'real.png')
        save_image(denorm(real_images), out_path)

    def save_gradient_images(self, gradient, file_name):
        """
            Exports the original gradient image

        Args:
            gradient (np arr): Numpy array of the gradient with shape (3, 224, 224)
            file_name (str): File name to be exported

        The gradient is min-max normalized to [0, 1] before saving. A
        constant (zero-range) gradient is left at all-zeros instead of
        dividing by zero, which previously produced NaNs; using
        out-of-place division also avoids the TypeError that in-place
        true division raises on integer arrays.
        """
        os.makedirs('attn2/results', exist_ok=True)
        # Normalize: shift to zero minimum, then scale by the range.
        gradient = gradient - gradient.min()
        grad_range = gradient.max()
        if grad_range > 0:  # guard against 0/0 for constant gradients
            gradient = gradient / grad_range
        # Save image
        path = os.path.join('attn2/results', file_name + '.jpg')
        im = gradient
        if isinstance(im, np.ndarray):
            if len(im.shape) == 2:
                # Promote (H, W) to (1, H, W) so the channel logic applies.
                im = np.expand_dims(im, axis=0)
            if im.shape[0] == 1:
                # Converting an image with depth = 1 to depth = 3, repeating the same values
                # For some reason PIL complains when I want to save channel image as jpg without
                # additional format in the .save()
                im = np.repeat(im, 3, axis=0)
                # Convert to values to range 1-255 and W,H, D
            if im.shape[0] == 3:
                # CHW floats in [0, 1] -> HWC in [0, 255] for PIL.
                im = im.transpose(1, 2, 0) * 255
            im = Image.fromarray(im.astype(np.uint8))
        im.save(path)
# Beispiel #13
# 0
class Tester(object):
    """Evaluation driver for a pretrained Generator/Encoder/Discriminator
    triple: restores checkpointed weights and scores validation images by
    reconstruction and feature-matching error.

    Fix: a leftover ``import ipdb; ipdb.set_trace()`` debugger breakpoint
    inside :meth:`test` was removed — it halted every run interactively
    and depended on a dev-only package.
    """

    def __init__(self, data_loader, config):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.ge_lr = config.ge_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.mura_class = config.mura_class
        self.mura_type = config.mura_type
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Version-specific paths (deliberately overwrite the raw config
        # paths assigned just above).
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path,
                                            self.version)

        # Build tensorboard for debugging
        self.build_tensorboard()

        # Build model
        self.build_model()

        # Load models
        self.load_pretrained_model()

    def test(self):
        """Run every validation batch through E -> G -> D and compute
        per-sample reconstruction (l1/l2) and feature-matching (ld*)
        statistics, scatter-plotting l1 against the sample index."""
        data_iter = iter(self.data_loader)
        self.D.eval()
        self.E.eval()
        self.G.eval()

        with torch.no_grad():
            for i, data in enumerate(data_iter):

                val_images, val_labels = data
                val_images = tensor2var(val_images)

                # Run val images through models X -> E(X) -> G(E(X))
                z, ef1, ef2 = self.E(val_images)
                re_images, gf1, gf2 = self.G(z)

                dv, dv5, dv4, dv3, dvz, dva2, dva1 = self.D(val_images, z)
                dr, dr5, dr4, dr3, drz, dra2, dra1 = self.D(re_images, z)

                # Compute residual loss
                l1 = (re_images - val_images).abs()
                l2 = (re_images - val_images).pow(2).sqrt()
                # Compute feature matching loss
                ld = (dv - dr).abs().view((self.batch_size, -1)).mean(dim=1)
                ld5 = (dv5 - dr5).abs().view((self.batch_size, -1)).mean(dim=1)
                ld4 = (dv4 - dr4).abs().view((self.batch_size, -1)).mean(dim=1)
                ld3 = (dv3 - dr3).abs().view((self.batch_size, -1)).mean(dim=1)

                # NOTE(review): l2 and the ld* values are computed but never
                # used below (they were previously inspected via a debugger
                # breakpoint, now removed); l1 is a 4-D per-pixel tensor and
                # likely needs a per-sample reduction (e.g.
                # l1.view(self.batch_size, -1).mean(dim=1)) before plotting
                # -- confirm intended behavior.
                plt.scatter(range(1, self.batch_size + 1), l1, c=val_labels)

    def build_tensorboard(self):
        '''Initialize tensorboard writer'''
        self.writer = SummaryWriter(self.log_path)

    def build_model(self):
        """Instantiate G, E and D on the target device, optionally wrap
        for multi-GPU, and create Adam optimizers over trainable params."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).to(self.device)
        self.E = Encoder(self.batch_size, self.imsize, self.z_dim,
                         self.d_conv_dim).to(self.device)
        self.D = Discriminator(self.batch_size, self.imsize, self.z_dim,
                               self.d_conv_dim).to(self.device)
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.E = nn.DataParallel(self.E)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer: G and E share one optimizer (joint update).
        self.ge_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   itertools.chain(self.G.parameters(), self.E.parameters())),
            self.ge_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr,
            [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()
        # print networks
        print(self.G)
        print(self.E)
        print(self.D)

    def load_pretrained_model(self):
        """Restore G, E and D weights saved at step self.pretrained_model."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))
        self.E.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_E.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def reset_grad(self):
        """Clear accumulated gradients on both optimizers."""
        self.d_optimizer.zero_grad()
        self.ge_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save one denormalized batch of real images as 'real.png'."""
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))