import numpy as np
import torch
from torch import nn
from torch.distributions import Dirichlet

# NOTE: the SwapGAN base class lives elsewhere in this project; adjust the
# import path to match the repository layout.
from .swapgan import SwapGAN


# NOTE: this first ThreeGAN definition is shadowed by the refactored
# ThreeGAN defined further below (which is itself superseded by KGAN).
class ThreeGAN(SwapGAN):

    def __init__(self, *args, **kwargs):
        super(ThreeGAN, self).__init__(*args, **kwargs)
        self.dirichlet = Dirichlet(torch.FloatTensor([1.0, 1.0, 1.0]))

    def sampler(self, bs, f, is_2d, **kwargs):
        """Sampler function, which outputs an alpha which you can use
        to produce a convex combination among three examples.

        :param bs: batch size
        :param f: number of units / feature maps at encoding
        :param is_2d: is the bottleneck a 2d tensor?
        :returns: for 'mixup', a Dirichlet sample of shape `(bs, 3)`;
          for 'fm', a one-hot mask of shape `(bs, 3, f)`.
        """
        if self.mixer == 'mixup':
            # `sample_n` is deprecated in newer PyTorch; `sample((bs,))`
            # is equivalent.
            alpha = self.dirichlet.sample((bs,))
        elif self.mixer == 'fm':
            # One-hot over dim=1: feature j of each example is assigned
            # wholly to one of the three mix components.
            alpha = torch.zeros((bs, 3, f)).float()
            for b in range(bs):
                for j in range(alpha.shape[2]):
                    alpha[b, np.random.randint(0, alpha.shape[1]), j] = 1.
        else:
            raise Exception("Not implemented for mixup scheme: %s" %
                            self.mixer)
        if self.use_cuda:
            alpha = alpha.cuda()
        return alpha

    def train_on_instance(self, x_batch, y_batch, **kwargs):
        self._train()
        for key in self.optim:
            self.optim[key].zero_grad()
        ## ------------------
        ## Generator training
        ## ------------------
        enc = self.generator.encode(x_batch)
        dec_enc = self.generator.decode(enc)
        recon_loss = torch.mean(torch.abs(dec_enc - x_batch))
        disc_g_recon_loss = self.gan_loss_fn(self.disc_x(dec_enc)[0], 1)
        perm = torch.randperm(x_batch.size(0))
        perm2 = torch.randperm(x_batch.size(0))
        is_2d = len(enc.size()) == 2
        alpha = self.sampler(x_batch.size(0), enc.size(1), is_2d)
        if self.mixer == 'mixup':
            enc_mix = alpha[:, 0].view(x_batch.size(0), 1, 1, 1)*enc + \
                alpha[:, 1].view(x_batch.size(0), 1, 1, 1)*enc[perm] + \
                alpha[:, 2].view(x_batch.size(0), 1, 1, 1)*enc[perm2]
        else:
            enc_mix = alpha[:, 0].view(x_batch.size(0), enc.size(1), 1, 1)*enc + \
                alpha[:, 1].view(x_batch.size(0), enc.size(1), 1, 1)*enc[perm] + \
                alpha[:, 2].view(x_batch.size(0), enc.size(1), 1, 1)*enc[perm2]
        dec_enc_mix = self.generator.decode(enc_mix)
        disc_g_mix_loss = self.gan_loss_fn(self.disc_x(dec_enc_mix)[0], 1)
        if self.beta > 0:
            consist_loss = torch.mean(
                torch.abs(self.generator.encode(dec_enc_mix) - enc_mix))
        else:
            consist_loss = torch.FloatTensor([0.])
            if self.use_cuda:
                consist_loss = consist_loss.cuda()
        gen_loss = self.lamb * recon_loss
        if self.disable_g_recon is False:
            gen_loss = gen_loss + disc_g_recon_loss
        if self.disable_mix is False:
            gen_loss = gen_loss + disc_g_mix_loss + self.beta * consist_loss
        if (kwargs['iter'] - 1) % self.update_g_every == 0:
            gen_loss.backward()
            self.optim['g'].step()
        ## ----------------------
        ## Discriminator on image
        ## ----------------------
        self.optim['disc_x'].zero_grad()
        d_losses = []
        # Do real images.
        dx_out, cx_out = self.disc_x(x_batch)
        d_x_real = self.gan_loss_fn(dx_out, 1)
        d_losses.append(d_x_real)
        if self.disable_g_recon is False:
            # Do reconstructions.
            d_x_fake = self.gan_loss_fn(self.disc_x(dec_enc.detach())[0], 0)
            d_losses.append(d_x_fake)
        if self.disable_mix is False:
            # Do mixes.
            d_out_mix = self.gan_loss_fn(
                self.disc_x(dec_enc_mix.detach())[0], 0)
            d_losses.append(d_out_mix)
        d_x = sum(d_losses)
        d_x.backward()
        self.optim['disc_x'].step()
        ## ----------------------------------------------
        ## Classifier on bottleneck (NOTE: for debugging)
        ## ----------------------------------------------
        if self.cls_enc is not None:
            self.optim['cls_enc'].zero_grad()
            if hasattr(self.cls_enc, 'legacy'):
                enc_flat = enc.detach().view(-1, self.cls_enc.n_in)
            else:
                enc_flat = enc.detach()
            cls_enc_out = self.cls_enc(enc_flat)
            cls_enc_preds_log = torch.log_softmax(cls_enc_out, dim=1)
            cls_enc_loss = nn.NLLLoss()(cls_enc_preds_log,
                                        y_batch.argmax(dim=1).long())
            with torch.no_grad():
                cls_enc_preds = torch.softmax(cls_enc_out, dim=1)
                cls_enc_acc = (cls_enc_preds.argmax(dim=1) ==
                               y_batch.argmax(dim=1).long()).float().mean()
            cls_enc_loss.backward()
            self.optim['cls_enc'].step()
        losses = {
            'gen_loss': gen_loss.item(),
            'disc_g_recon': disc_g_recon_loss.item(),
            'disc_g_mix': disc_g_mix_loss.item(),
            'recon': recon_loss.item(),
            'consist': consist_loss.item(),
            'd_x': d_x.item() / len(d_losses)
        }
        if self.cls_enc is not None:
            losses['cls_enc_loss'] = cls_enc_loss.item()
            losses['cls_enc_acc'] = cls_enc_acc.item()
        outputs = {
            'recon': dec_enc,
            'mix': dec_enc_mix,
            'perm': perm,
            'input': x_batch,
        }
        return losses, outputs

    def eval_on_instance(self, x_batch, y_batch, **kwargs):
        self._eval()
        with torch.no_grad():
            enc = self.generator.encode(x_batch)
            dec_enc = self.generator.decode(enc)
            recon_loss = torch.mean(torch.abs(dec_enc - x_batch))
            perm = torch.randperm(x_batch.size(0))
            perm2 = torch.randperm(x_batch.size(0))
            is_2d = len(enc.size()) == 2
            alpha = self.sampler(x_batch.size(0), enc.size(1), is_2d)
            if self.mixer == 'mixup':
                enc_mix = alpha[:, 0].view(x_batch.size(0), 1, 1, 1)*enc + \
                    alpha[:, 1].view(x_batch.size(0), 1, 1, 1)*enc[perm] + \
                    alpha[:, 2].view(x_batch.size(0), 1, 1, 1)*enc[perm2]
            else:
                enc_mix = alpha[:, 0].view(x_batch.size(0), enc.size(1), 1, 1)*enc + \
                    alpha[:, 1].view(x_batch.size(0), enc.size(1), 1, 1)*enc[perm] + \
                    alpha[:, 2].view(x_batch.size(0), enc.size(1), 1, 1)*enc[perm2]
            dec_enc_mix = self.generator.decode(enc_mix)
            losses = {}
            if self.cls_enc is not None:
                if hasattr(self.cls_enc, 'legacy'):
                    enc_flat = enc.detach().view(-1, self.cls_enc.n_in)
                else:
                    enc_flat = enc.detach()
                cls_enc_out = self.cls_enc(enc_flat)
                cls_enc_preds = torch.softmax(cls_enc_out, dim=1)
                cls_enc_acc = (cls_enc_preds.argmax(dim=1) ==
                               y_batch.argmax(dim=1).long()).float().mean()
                losses['cls_enc_acc'] = cls_enc_acc.item()
            outputs = {
                'recon': dec_enc,
                'mix': dec_enc_mix,
                'perm': perm,
                'input': x_batch,
            }
        return losses, outputs
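
# Illustrative sketch (not part of the trainer above): the two kinds of
# alpha that `sampler` produces. Under 'mixup', each example gets one
# Dirichlet-distributed weight triple applied to every feature; under 'fm',
# each feature dimension is assigned wholly to one of the three mix
# components via a one-hot mask. Function and argument names here are made
# up for demonstration.
def _demo_threegan_alphas(bs=4, f=8):
    # 'mixup': rows of shape (bs, 3) summing to 1, e.g. [0.2, 0.5, 0.3].
    mixup_alpha = Dirichlet(torch.FloatTensor([1., 1., 1.])).sample((bs,))
    assert torch.allclose(mixup_alpha.sum(dim=1), torch.ones(bs))
    # 'fm': a (bs, 3, f) mask that is one-hot over dim=1, vectorising the
    # per-feature loop used in `sampler`.
    fm_alpha = torch.zeros(bs, 3, f)
    idx = torch.randint(0, 3, (bs, f))  # which component owns feature j
    fm_alpha.scatter_(1, idx.unsqueeze(1), 1.)
    assert torch.allclose(fm_alpha.sum(dim=1), torch.ones(bs, f))
    return mixup_alpha, fm_alpha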
class ThreeGAN(SwapGAN): """This class is old. You should instead use KGAN, which allows you to choose the k value for mixing. """ def __init__(self, *args, **kwargs): super(ThreeGAN, self).__init__(*args, **kwargs) if self.cls > 0: raise NotImplementedError("ThreeGAN not implemented for cls > 0") self.dirichlet = Dirichlet(torch.FloatTensor([1.0, 1.0, 1.0])) def sampler(self, bs, f, is_2d, **kwargs): """Sampler function, which outputs an alpha which you can use to produce a convex combination between two examples. :param bs: batch size :param f: number of units / feature maps at encoding :param is_2d: is the bottleneck a 2d tensor? :returns: an alpha of shape `(bs, f)` is `is_2d` is set, otherwise `(bs, f, 1, 1)`. :rtype: """ if self.mixer == 'mixup': with torch.no_grad(): alpha = self.dirichlet.sample_n(bs) if not is_2d: alpha = alpha.reshape(-1, alpha.size(1), 1, 1) elif self.mixer == 'fm': #alpha = torch.randint(0, 3, size=(bs, f, 1, 1)).long() if is_2d: alpha = np.zeros((bs, 3, f)).astype(np.float32) else: alpha = np.zeros((bs, 3, f, 1, 1)).astype(np.float32) for b in range(bs): for j in range(f): alpha[b, np.random.randint(0, 3), j] = 1. alpha = torch.from_numpy(alpha).float() else: raise Exception("Not implemented for mixup scheme: %s" % self.mixer) if self.use_cuda: alpha = alpha.cuda() return alpha def mix(self, enc): """Perform mixing operation on the encoding `enc`. :param enc: encoding of shape (bs, f) (if 2d) or (bs, f, h, w) if 4d. """ perm = torch.randperm(enc.size(0)) perm2 = torch.randperm(enc.size(0)) is_2d = True if len(enc.size()) == 2 else False alpha = self.sampler(enc.size(0), enc.size(1), is_2d) if self.mixer == 'mixup': enc_mix = alpha[:, 0:1]*enc + \ alpha[:, 1:2]*enc[perm] + \ alpha[:, 2:3]*enc[perm2] else: enc_mix = alpha[:, 0]*enc + \ alpha[:, 1]*enc[perm] + \ alpha[:, 2]*enc[perm2] return enc_mix, perm
class KGAN(SwapGAN):

    def __init__(self, k=10, *args, **kwargs):
        super(KGAN, self).__init__(*args, **kwargs)
        if self.cls > 0:
            raise NotImplementedError("KGAN not implemented for cls > 0")
        self.dirichlet = Dirichlet(torch.FloatTensor([1.0] * k))
        self.k = k

    def sampler(self, bs, f, is_2d, **kwargs):
        """Sampler function, which outputs an alpha which you can use
        to produce a convex combination among `k` examples.

        :param bs: batch size
        :param f: number of units / feature maps at encoding
        :param is_2d: is the bottleneck a 2d tensor?
        :returns: for 'mixup', a Dirichlet sample of shape `(bs, k)`
          (or `(bs, k, 1, 1)` if the bottleneck is 4d); for 'fm', a
          one-hot mask of shape `(bs, k, f)` (or `(bs, k, f, 1, 1)`).
        """
        if self.mixer == 'mixup':
            with torch.no_grad():
                alpha = self.dirichlet.sample((bs,))
            if not is_2d:
                alpha = alpha.reshape(-1, alpha.size(1), 1, 1)
        elif self.mixer == 'fm':
            if is_2d:
                alpha = np.zeros((bs, self.k, f)).astype(np.float32)
            else:
                alpha = np.zeros((bs, self.k, f, 1, 1)).astype(np.float32)
            for b in range(bs):
                for j in range(f):
                    alpha[b, np.random.randint(0, self.k), j] = 1.
            alpha = torch.from_numpy(alpha).float()
        else:
            raise Exception("Not implemented for mixup scheme: %s" %
                            self.mixer)
        if self.use_cuda:
            alpha = alpha.cuda()
        return alpha

    def mix(self, enc):
        """Perform the mixing operation on the encoding `enc`.

        :param enc: encoding of shape `(bs, f)` (if 2d) or
          `(bs, f, h, w)` (if 4d).
        """
        # The first 'permutation' is the identity, so each mix stays
        # anchored to its own example.
        perms = [torch.arange(0, enc.size(0))] + \
            [torch.randperm(enc.size(0)) for _ in range(self.k - 1)]
        is_2d = len(enc.size()) == 2
        alpha = self.sampler(enc.size(0), enc.size(1), is_2d)
        enc_mix = 0.
        if self.mixer == 'mixup':
            for i in range(len(perms)):
                enc_mix += alpha[:, i:i + 1] * enc[perms[i]]
        else:
            for i in range(len(perms)):
                # Accumulate the k-way sum for the 'fm' mask as well.
                enc_mix += alpha[:, i] * enc[perms[i]]
        return enc_mix, perms[0]
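
# Standalone sketch of KGAN's k-way 'mixup' combination, independent of
# the trainer classes (names here are illustrative). As in `mix` above,
# the first permutation is the identity, so every mix stays anchored to
# its own example.
def _demo_k_mix(enc, k=4):
    bs = enc.size(0)
    alpha = Dirichlet(torch.FloatTensor([1.0] * k)).sample((bs,))  # (bs, k)
    if enc.dim() == 4:
        alpha = alpha.view(bs, k, 1, 1)
    perms = [torch.arange(bs)] + [torch.randperm(bs) for _ in range(k - 1)]
    enc_mix = sum(alpha[:, i:i + 1] * enc[perms[i]] for i in range(k))
    return enc_mix, perms[0]
#
# Example:
#     enc = torch.randn(16, 64, 4, 4)   # a fake (bs, f, h, w) bottleneck
#     enc_mix, identity_perm = _demo_k_mix(enc, k=4)
#     assert enc_mix.shape == enc.shape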