Example #1
def train(data_loader, G, z_train, g_optimizer, z_sample, train_dict, args, myargs, **kwargs):
  for i, (imgs, _) in enumerate(tqdm.tqdm(data_loader)):
    train_dict['batches_done'] += 1
    summary = {}

    G.train()
    imgs = imgs.cuda()

    # train G
    z_train.sample_()
    f_imgs = G(z_train)

    sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(x=imgs.view(imgs.size(0), -1), y=f_imgs.view(f_imgs.size(0), -1),
                                                 epsilon=args.sinkhorn_eps, niter=args.sinkhorn_niter,
                                                 cuda=True, pi_detach=args.sinkhorn_pi_detach)
    summary['sinkhorn_d'] = sinkhorn_d.item()

    g_loss = sinkhorn_d
    summary['g_loss'] = g_loss.item()

    G.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    if i % args.sample_every == 0:
      # sample images
      G.eval()
      f_imgs_sample = G(z_sample)
      merged_img = torchvision.utils.make_grid(f_imgs_sample, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images('G_z', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])
      merged_img = torchvision.utils.make_grid(imgs, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images('x', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])
      # checkpoint
      myargs.checkpoint.save_checkpoint(checkpoint_dict=myargs.checkpoint_dict, filename='ckpt.tar')
      # summary
      for key in summary:
        myargs.writer.add_scalar('train_vs_batch/%s'%key, summary[key], train_dict['batches_done'])

      G.train()
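
The helper sinkhorn_autodiff.sinkhorn_loss used above is not listed on this page. A minimal log-domain Sinkhorn sketch with the same calling convention (the signature, and the use of pi_detach to optionally stop gradients through the transport plan, are assumptions inferred from the call sites) could look like this:

import torch

def sinkhorn_loss(x, y, epsilon=0.01, niter=100, cuda=True, pi_detach=False):
  # Entropic-regularized OT cost between two flattened batches (sketch only).
  # `cuda` is kept only to mirror the call sites; tensors stay on their own device.
  C = torch.cdist(x, y, p=2) ** 2  # pairwise squared-Euclidean cost
  n, m = C.shape
  mu = torch.full((n,), 1.0 / n, device=C.device)
  nu = torch.full((m,), 1.0 / m, device=C.device)
  u, v = torch.zeros_like(mu), torch.zeros_like(nu)
  for _ in range(niter):  # Sinkhorn iterations in log space for stability
    M = (-C + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon
    u = u + epsilon * (torch.log(mu) - torch.logsumexp(M, dim=1))
    M = (-C + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon
    v = v + epsilon * (torch.log(nu) - torch.logsumexp(M, dim=0))
  pi = torch.exp((-C + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon)
  if pi_detach:
    pi = pi.detach()  # block gradients through the transport plan
  return torch.sum(pi * C)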
Example #2
    def train(data_loader, G, G_ema, D, ema, z_train, z_sample,
              g_optimizer, d_optimizer, train_dict, args, myargs, **kwargs):
        for i, (imgs, _) in enumerate(tqdm.tqdm(data_loader)):
            train_dict['batches_done'] += 1
            step = train_dict['batches_done']
            summary = {}
            summary_d_logits_mean = {}
            summary_wd = {}

            G.train()
            G_ema.train()

            imgs = imgs.cuda()
            bs = imgs.size(0)

            z_train.sample_()
            f_imgs = G(z_train[:bs])

            # train D
            with torch.no_grad():
                sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
                    x=imgs.view(bs, -1),
                    y=f_imgs.view(bs, -1).detach(),
                    epsilon=args.sinkhorn_eps,
                    niter=args.sinkhorn_niter,
                    cuda=True,
                    pi_detach=args.sinkhorn_pi_detach)
                summary_wd['D_sinkhorn_d'] = sinkhorn_d.item()

            r_logit = D(imgs)
            r_logit_mean = r_logit.mean()
            f_logit = D(f_imgs.detach())
            f_logit_mean = f_logit.mean()
            summary_d_logits_mean['D_r_logit_mean'] = r_logit_mean.item()
            summary_d_logits_mean['D_f_logit_mean'] = f_logit_mean.item()

            # Wasserstein-1 Distance
            wd = r_logit_mean - f_logit_mean
            gp = gan_losses.wgan_gp_gradient_penalty(imgs.data, f_imgs.data, D)
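            # bound the critic's estimate: the relu term penalizes wd whenever it exceeds the (detached) Sinkhorn distance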
            d_loss = -wd + gp * 10.0 + torch.relu(wd - sinkhorn_d.item())
            summary_wd['wd'] = wd.item()
            summary['gp'] = gp.item()
            summary['d_loss'] = d_loss.item()

            D.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            if step % args.n_critic == 0:
                # train G
                z_train.sample_()
                f_imgs = G(z_train)

                sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
                    x=imgs.view(imgs.size(0), -1),
                    y=f_imgs.view(f_imgs.size(0), -1),
                    epsilon=args.sinkhorn_eps,
                    niter=args.sinkhorn_niter,
                    cuda=True,
                    pi_detach=args.sinkhorn_pi_detach)
                summary_wd['G_sinkhorn_d'] = sinkhorn_d.item()

                f_logit = D(f_imgs)
                f_logit_mean = f_logit.mean()
                g_loss = -f_logit_mean + args.lambda_sinkhorn * sinkhorn_d
                summary_d_logits_mean['G_f_logit_mean'] = f_logit_mean.item()
                summary['g_loss'] = g_loss.item()

                D.zero_grad()
                G.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            # end iter
            ema.update(train_dict['batches_done'])

            if i % args.sample_every == 0:
                # sample images
                G.eval()
                G_ema.eval()
                G_z = G(z_sample)
                merged_img = torchvision.utils.make_grid(G_z,
                                                         normalize=True,
                                                         pad_value=1,
                                                         nrow=16)
                myargs.writer.add_images('G_z',
                                         merged_img.view(1, *merged_img.shape),
                                         train_dict['batches_done'])
                # G_ema
                G_ema_z = G_ema(z_sample)
                merged_img = torchvision.utils.make_grid(G_ema_z,
                                                         normalize=True,
                                                         pad_value=1,
                                                         nrow=16)
                myargs.writer.add_images('G_ema_z',
                                         merged_img.view(1, *merged_img.shape),
                                         train_dict['batches_done'])
                # x
                merged_img = torchvision.utils.make_grid(imgs,
                                                         normalize=True,
                                                         pad_value=1,
                                                         nrow=16)
                myargs.writer.add_images('x',
                                         merged_img.view(1, *merged_img.shape),
                                         train_dict['batches_done'])
                # checkpoint
                myargs.checkpoint.save_checkpoint(
                    checkpoint_dict=myargs.checkpoint_dict,
                    filename='ckpt.tar')
                # summary
                for key in summary:
                    myargs.writer.add_scalar('train_vs_batch/%s' % key,
                                             summary[key],
                                             train_dict['batches_done'])
                myargs.writer.add_scalars('train_vs_batch',
                                          summary_d_logits_mean,
                                          train_dict['batches_done'])
                myargs.writer.add_scalars('wd', summary_wd,
                                          train_dict['batches_done'])

                G.train()
            elif train_dict['batches_done'] <= 20000:
                for key in summary:
                    myargs.writer.add_scalar('train_vs_batch/%s' % key,
                                             summary[key],
                                             train_dict['batches_done'])
                myargs.writer.add_scalars('train_vs_batch',
                                          summary_d_logits_mean,
                                          train_dict['batches_done'])
                myargs.writer.add_scalars('wd', summary_wd,
                                          train_dict['batches_done'])
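
Example #2 also relies on gan_losses.wgan_gp_gradient_penalty(imgs.data, f_imgs.data, D), which is not shown on this page. A minimal sketch of the standard WGAN-GP penalty under that (assumed) positional signature, returning the raw penalty that the caller then scales by 10.0, is:

import torch
from torch import autograd

def wgan_gp_gradient_penalty(real, fake, D):
    # Penalty on the critic's gradient norm at points interpolated between real and fake.
    bs = real.size(0)
    alpha = torch.rand(bs, 1, 1, 1, device=real.device)
    interp = (alpha * real + (1.0 - alpha) * fake).requires_grad_(True)
    d_interp = D(interp)
    grads = autograd.grad(outputs=d_interp, inputs=interp,
                          grad_outputs=torch.ones_like(d_interp),
                          create_graph=True, retain_graph=True)[0]
    grads = grads.view(bs, -1)
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()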
Example #3
  def train_one_epoch(self, ):
    config = self.config.train_one_epoch
    if config.dummy_train:
      return
    myargs = self.myargs
    train_dict = self.train_dict
    pbar = tqdm.tqdm(self.data_loader, file=myargs.stdout)
    self._summary_create()
    for i, (imgs, _) in enumerate(pbar):
      self.G.train()
      self.D.train()
      train_dict['batches_done'] += 1

      imgs = imgs.cuda()
      bs = imgs.size(0)
      if bs != self.config.noise.batch_size_train:
        return
      self.z_train.sample_()
      with torch.no_grad():
        f_imgs = self.G(self.z_train[:bs])

      imgs.requires_grad_()
      f_imgs.requires_grad_()
      # train D
      D_r_logit = self.D(imgs)
      D_r_logit_mean = D_r_logit.mean()
      D_f_logit = self.D(f_imgs)
      D_f_logit_mean = D_f_logit.mean()
      self.summary_logit_mean['D_r_logit_mean'] = D_r_logit_mean.item()
      self.summary_logit_mean['D_f_logit_mean'] = D_f_logit_mean.item()

      self.d_optimizer.zero_grad()

      wd = D_f_logit_mean - D_r_logit_mean
      # Backward gp loss in this func
      gp = gan_losses.wgan_div_gradient_penalty(
        real_imgs=imgs, fake_imgs=f_imgs, real_validity=D_r_logit,
        fake_validity=D_f_logit, backward=True, retain_graph=True
      )

      if config.bound_type == 'constant':
        raise NotImplementedError
        D_loss = -wd + torch.relu(wd - float(config.bound))
        # D_loss = -wd + gp * config.gp_lambda + \
        #          torch.relu(wd - float(config.bound))
        self.summary_wd['bound'] = config.bound
      elif config.bound_type == 'sinkhorn':
        with torch.no_grad():
          sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
            x=imgs.view(bs, -1), y=f_imgs.view(bs, -1).detach(),
            epsilon=config.sinkhorn_eps, niter=config.sinkhorn_niter,
            cuda=True, pi_detach=config.sinkhorn_pi_detach)
          self.summary_wd['D_sinkhorn_d'] = sinkhorn_d.item()
        D_loss = -wd + torch.relu(wd - sinkhorn_d.item())
      else:
        D_loss = -wd

      self.summary_wd['wd'] = wd.item()
      self.summary['gp'] = gp.item()
      self.summary['D_loss'] = D_loss.item()

      D_loss.backward()
      self.d_optimizer.step()

      if i % config.n_critic == 0:
        # train G
        self.z_train.sample_()
        f_imgs = self.G(self.z_train)

        D_f_logit = self.D(f_imgs)
        D_f_logit_mean = D_f_logit.mean()
        g_loss_only = D_f_logit_mean

        if config.bound_type == 'sinkhorn':
          sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
            x=imgs.view(imgs.size(0), -1), y=f_imgs.view(f_imgs.size(0), -1),
            epsilon=config.sinkhorn_eps, niter=config.sinkhorn_niter,
            cuda=True, pi_detach=config.sinkhorn_pi_detach)
          self.summary_wd['G_sinkhorn_d'] = sinkhorn_d.item()
          G_loss = g_loss_only + config.sinkhorn_lambda * sinkhorn_d
        else:
          G_loss = g_loss_only
        self.summary_logit_mean['G_f_logit_mean'] = D_f_logit_mean.item()
        self.summary['g_loss_only'] = g_loss_only.item()
        self.summary['G_loss'] = G_loss.item()

        self.g_optimizer.zero_grad()
        G_loss.backward()
        self.g_optimizer.step()

        # end iter
        self.ema.update(train_dict['batches_done'])

      if i % config.sample_every == 0:

        # checkpoint
        myargs.checkpoint.save_checkpoint(
          checkpoint_dict=myargs.checkpoint_dict, filename='ckpt.tar')
        # summary
        self._summary_scalars()

      elif train_dict['batches_done'] <= config.sample_start_iter:
        self._summary_scalars()

      if i % (len(self.data_loader)//4) == 0:
        # save images
        self._summary_figures()
        self._summary_images(imgs=imgs, itr=self.train_dict['batches_done'])
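
Examples #2 to #4 keep an averaged generator G_ema up to date through calls like ema.update(train_dict['batches_done']). A minimal sketch of such an exponential-moving-average helper (the class name, the default decay, and the start_itr warm-up are assumptions) is:

import torch

class EMA(object):
  # Keeps target's weights as an exponential moving average of source's weights.
  def __init__(self, source, target, decay=0.999, start_itr=0):
    self.source, self.target = source, target
    self.decay, self.start_itr = decay, start_itr
    self.target.load_state_dict(self.source.state_dict())  # start from an exact copy

  @torch.no_grad()
  def update(self, itr=None):
    # Before start_itr simply copy the weights (decay 0); afterwards blend them.
    decay = self.decay if (itr is None or itr >= self.start_itr) else 0.0
    src = dict(self.source.named_parameters())
    for name, tgt_p in self.target.named_parameters():
      tgt_p.copy_(decay * tgt_p + (1.0 - decay) * src[name])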
Example #4
    def train_one_epoch(self, ):
        config = self.config.train_one_epoch
        if config.dummy_train:
            return
        myargs = self.myargs
        train_dict = myargs.checkpoint_dict['train_dict']
        batch_size = self.config.noise.batch_size
        pbar = tqdm.tqdm(self.loaders[0], file=self.myargs.stdout)
        for i, (images, labels) in enumerate(pbar):
            if len(images) % batch_size != 0:
                return
            summary_dict = collections.defaultdict(dict)
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            self.G.train()
            self.D.train()
            self.GD.train()
            self.G_ema.train()
            images, labels = images.cuda(), labels.cuda()

            # Optionally toggle D and G's "require_grad"
            if config.toggle_grads:
                gan_utils.toggle_grad(self.D, True)
                gan_utils.toggle_grad(self.G, False)

            # How many chunks to split x and y into?
            x = torch.split(images, batch_size)
            y = torch.split(labels, batch_size)
            counter = 0
            for step_index in range(config.num_D_steps):
                self.G.optim.zero_grad()
                self.D.optim.zero_grad()
                if getattr(config, 'weigh_loss', False):
                    self.alpha_optim.zero_grad()
                # Increment the iteration counter
                train_dict['batches_done'] += 1
                # If accumulating gradients, loop multiple times before an optimizer step
                for accumulation_index in range(config.num_D_accumulations):
                    self.z_.sample_()
                    dy = y[counter]
                    gy = dy
                    # self.y_.sample_()
                    D_fake, D_real, G_z = self.GD(z=self.z_[:batch_size],
                                                  gy=gy,
                                                  x=x[counter],
                                                  dy=dy,
                                                  train_G=False,
                                                  split_D=config.split_D,
                                                  return_G_z=True)

                    # Compute components of D's loss, average them, and divide by
                    # the number of gradient accumulations
                    D_real_mean, D_fake_mean, wd, _ = gan_losses.wgan_discriminator_loss(
                        r_logit=D_real, f_logit=D_fake)
                    # entropy loss
                    if getattr(config, 'weigh_loss',
                               False) and config.use_entropy:
                        weight_logit = 1. / torch.stack(
                            self.alpha_params).exp()
                        weight = F.softmax(weight_logit, dim=0)
                        entropy = (-weight * F.log_softmax(weight_logit, dim=0)).sum()
                        (-config.entropy_lambda * entropy).backward()
                    # gp
                    if getattr(config, 'weigh_loss', False) and \
                            hasattr(self, 'alpha_gp'):
                        if config.use_entropy:
                            weight_logit = 1. / torch.stack(
                                self.alpha_params).exp()
                            alpha_gp_logit = 1. / self.alpha_gp.exp()
                            weight = alpha_gp_logit.exp() / weight_logit.exp().sum()
                        else:
                            weight = 1. / self.alpha_gp.exp()
                        weight_gp = weight
                        sigma_gp = torch.exp(self.alpha_gp / 2.)
                        gp_loss, gp = gan_losses.wgan_gp_gradient_penalty_cond(
                            x=x[counter],
                            G_z=G_z,
                            gy=gy,
                            f=self.D,
                            backward=True,
                            gp_lambda=weight_gp,
                            return_gp=True)
                        (self.alpha_gp / 2.).backward()
                        summary_dict['sigma']['sigma_gp'] = sigma_gp.item()
                        summary_dict['weight']['weight_gp'] = weight_gp.item()
                        summary_dict['gp']['gp_loss'] = gp_loss.item()
                        summary_dict['gp']['gp_raw'] = gp.item()
                    else:
                        gp_loss = gan_losses.wgan_gp_gradient_penalty_cond(
                            x=x[counter],
                            G_z=G_z,
                            gy=gy,
                            f=self.D,
                            backward=True,
                            gp_lambda=config.gp_lambda)
                        summary_dict['gp']['gp_loss'] = gp_loss.item()

                    # wd
                    if getattr(config, 'weigh_loss', False) and \
                            hasattr(self, 'alpha_wd'):
                        if config.use_entropy:
                            weight_logit = 1. / torch.stack(
                                self.alpha_params).exp()
                            alpha_wd_logit = 1. / self.alpha_wd.exp()
                            weight = alpha_wd_logit.exp() / weight_logit.exp().sum()
                        else:
                            weight = 1. / self.alpha_wd.exp()
                        weight_wd = weight
                        sigma_wd = torch.exp(self.alpha_wd / 2.)
                        D_loss = weight_wd.item() * (-wd) + \
                          weight_wd * wd.abs().item() + self.alpha_wd / 2.
                        summary_dict['sigma']['sigma_wd'] = sigma_wd.item()
                        summary_dict['weight']['weight_wd'] = weight_wd.item()
                        summary_dict['wd']['wd_loss'] = \
                          weight_wd.item() * wd.abs().item()
                        summary_dict['wd']['wd_raw'] = wd.item()
                    elif getattr(config, 'd_sinkhorn', False):
                        sinkhorn_c = config.sinkhorn_c
                        with torch.no_grad():
                            sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
                                x=x[counter].view(batch_size, -1),
                                y=G_z.view(batch_size, -1).detach(),
                                epsilon=sinkhorn_c.sinkhorn_eps,
                                niter=sinkhorn_c.sinkhorn_niter,
                                cuda=True,
                                pi_detach=sinkhorn_c.sinkhorn_pi_detach)
                            summary_dict['wd'][
                                'D_sinkhorn_d'] = sinkhorn_d.item()
                        D_loss = -wd + torch.relu(wd - sinkhorn_d.item())
                        summary_dict['wd']['wd_raw'] = wd.item()
                    else:
                        D_loss = -wd
                        summary_dict['wd']['wd_raw'] = wd.item()

                    D_losses = D_loss / float(config.num_D_accumulations)
                    # Accumulate gradients
                    D_losses.backward()
                    counter += 1
                    summary_dict['D_logit_mean'][
                        'D_real_mean'] = D_real_mean.item()
                    summary_dict['D_logit_mean'][
                        'D_fake_mean'] = D_fake_mean.item()
                    summary_dict['scalars']['D_loss'] = D_loss.item()

                # End accumulation
                # Optionally apply ortho reg in D
                if config.D_ortho > 0.0:
                    # Debug print to indicate we're using ortho reg in D.
                    print('using modified ortho reg in D')
                    weight_regularity.ortho(self.D, config.D_ortho)

                # for name, value in self.D.named_parameters():
                #   self.myargs.writer.add_histogram(
                #     name, value.grad, train_dict['batches_done'])
                self.D.optim.step()
                if getattr(config, 'weigh_loss', False):
                    self.alpha_optim.step()

            # Optionally toggle "requires_grad"
            if config.toggle_grads:
                gan_utils.toggle_grad(self.D, False)
                gan_utils.toggle_grad(self.G, True)

            # Zero G's gradients by default before training G, for safety
            self.G.optim.zero_grad()
            if getattr(config, 'weigh_loss', False):
                self.alpha_optim.zero_grad()
            # If accumulating gradients, loop multiple times
            for accumulation_index in range(config.num_G_accumulations):
                self.z_.sample_()
                gy = dy
                # self.y_.sample_()
                D_fake, G_z = self.GD(z=self.z_,
                                      gy=gy,
                                      train_G=True,
                                      return_G_z=True,
                                      split_D=config.split_D)

                G_fake_mean, _ = gan_losses.wgan_generator_loss(f_logit=D_fake)
                # wd for generator
                if getattr(config, 'weigh_loss', False) and \
                          hasattr(self, 'alpha_g'):
                    if config.use_entropy:
                        weight_logit = 1. / torch.stack(
                            self.alpha_params).exp()
                        alpha_g_logit = 1. / self.alpha_g.exp()
                        weight = alpha_g_logit.exp() / weight_logit.exp().sum()
                    else:
                        weight = 1. / self.alpha_g.exp()
                    weight_g = weight
                    sigma_g = torch.exp(self.alpha_g / 2.)
                    wd_g = D_real_mean.item() - G_fake_mean
                    G_loss = weight_g.item() * (wd_g) + \
                             weight_g * wd_g.abs().item() + self.alpha_g / 2.
                    summary_dict['sigma']['sigma_g'] = sigma_g.item()
                    summary_dict['weight']['weight_g'] = weight_g.item()
                    summary_dict['wd']['wd_g_raw'] = wd_g.item()
                    summary_dict['wd']['wd_g_loss'] = weight_g.item() * wd_g.abs().item()
                elif getattr(config, 'g_sinkhorn', False):
                    sinkhorn_c = config.sinkhorn_c
                    sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
                        x=x[-1].view(x[-1].size(0), -1),
                        y=G_z.view(G_z.size(0), -1),
                        epsilon=sinkhorn_c.sinkhorn_eps,
                        niter=sinkhorn_c.sinkhorn_niter,
                        cuda=True,
                        pi_detach=sinkhorn_c.sinkhorn_pi_detach)
                    summary_dict['wd']['G_sinkhorn_d'] = sinkhorn_d.item()
                    g_loss_only = -G_fake_mean
                    summary_dict['scalars']['g_loss_only'] = g_loss_only.item()
                    G_loss = g_loss_only + sinkhorn_c.sinkhorn_lambda * sinkhorn_d
                else:
                    G_loss = -G_fake_mean
                # Accumulate gradients
                G_loss = G_loss / float(config.num_G_accumulations)
                G_loss.backward()

                summary_dict['D_logit_mean']['G_fake_mean'] = G_fake_mean.item()
                summary_dict['scalars']['G_loss'] = G_loss.item()

            # Optionally apply modified ortho reg in G
            if config.G_ortho > 0.0:
                # Debug print to indicate we're using ortho reg in G
                print('using modified ortho reg in G')
                # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                weight_regularity.ortho(
                    self.G,
                    config.G_ortho,
                    blacklist=[param for param in self.G.shared.parameters()])
            self.G.optim.step()
            if getattr(config, 'weigh_loss', False):
                self.alpha_optim.step()

            # If we have an ema, update it, regardless of if we test with it or not
            self.ema.update(train_dict['batches_done'])
            pbar.set_description('wd=%f' % wd.item())

            if i % config.sample_every == 0:
                # weights
                self.save_checkpoint(filename='ckpt.tar')
                self.summary_dicts(summary_dicts=summary_dict,
                                   prefix='train_one_epoch',
                                   step=train_dict['batches_done'])

                # singular values
                # G_sv_dict = utils_BigGAN.get_SVs(self.G, 'G')
                # D_sv_dict = utils_BigGAN.get_SVs(self.D, 'D')
                # myargs.writer.add_scalars('sv/G_sv_dict', G_sv_dict,
                #                           train_dict['batches_done'])
                # myargs.writer.add_scalars('sv/D_sv_dict', D_sv_dict,
                #                           train_dict['batches_done'])

            elif train_dict['batches_done'] <= config.sample_start_iter:
                # scalars
                self.summary_dicts(summary_dicts=summary_dict,
                                   prefix='train_one_epoch',
                                   step=train_dict['batches_done'])

            if (i + 1) % (len(self.loaders[0]) // 2) == 0:
                # save images
                # self.summary_figures(summary_dicts=summary_dict,
                #                      prefix='train_one_epoch')
                # samples
                self._summary_images(imgs=x[0])
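
The weigh_loss branches in Example #4 scale each loss term by 1. / alpha.exp() for a learned parameter alpha and add alpha / 2. as a regularizer, in the spirit of learned uncertainty-based loss weighting. A self-contained sketch of that pattern (the placeholder losses, parameter names, and optimizer settings are illustrative only, and the item()-based detaching used above is omitted for brevity) is:

import torch

# placeholder loss values standing in for wd and the gradient penalty
loss_wd = torch.tensor(1.3)
loss_gp = torch.tensor(0.4)

# one learned log-variance per loss term
alpha_wd = torch.zeros(1, requires_grad=True)
alpha_gp = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([alpha_wd, alpha_gp], lr=1e-3)

# each term is weighted by 1/exp(alpha) and regularized by alpha/2,
# mirroring the D_loss construction in the weigh_loss branch above
total = (1.0 / alpha_wd.exp()) * loss_wd + alpha_wd / 2.0 \
      + (1.0 / alpha_gp.exp()) * loss_gp + alpha_gp / 2.0

alpha_optim.zero_grad()
total.backward()
alpha_optim.step()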