def cond_ll(self, inputs, targets, lengths, z, pad_id):
    """Conditional log-likelihood log p(x|z) of ``targets`` given ``z``.

    The latent code is projected through ``self.fc`` + tanh to build the
    decoder's initial hidden state; the decoder then scores ``inputs``
    and the negated reconstruction loss against ``targets`` is returned,
    averaged over the z batch and broadcast to shape ``(z.size(0), 1)``.
    """
    # Project z into the recurrent initial state, then split along the
    # last axis into the two state tensors the decoder expects.
    state = torch.tanh(self.fc(z)).unsqueeze(0)
    state = [part.contiguous() for part in torch.chunk(state, 2, 2)]

    embedded = self.embed(inputs)
    hidden_seq, _ = self.decoder(embedded, lengths, init_hidden=state)
    logits = self.fcout(hidden_seq)

    # recon_loss sums over the batch; spread the scalar across the z
    # batch, normalise by its size, and flip the sign so a loss becomes
    # a log-likelihood.
    nll = recon_loss(logits, targets, pad_id).expand(z.size(0), 1) / z.size(0)
    return -nll
def baseline(args, data_iter, model, optimizer, epoch, train=True):
    """Run one epoch of the baseline VAE (no rotation discriminator).

    In training mode the model is optimised on the plain ELBO
    (reconstruction + KL, summed over the batch and divided by
    ``args.batch_size``).  In evaluation mode each batch is randomly
    rotated first and the same losses are measured without any update.

    Returns:
        Tuple ``(nll, re_loss, kl_divergence, discriminator_loss)`` of
        per-sample averages over the dataset.  ``discriminator_loss``
        is always zero here; it is kept so the return signature matches
        the adversarial ``run`` variant.
    """
    batch_size = args.batch_size
    size = len(data_iter.dataset)
    if train:
        model.train()
    else:
        model.eval()
    re_loss = 0
    kl_divergence = 0
    discriminator_loss = 0
    for i, (data, label) in enumerate(data_iter):
        data = data.to(args.device)
        # Placeholder so the return signature matches the adversarial run().
        disloss = torch.zeros(1).to(args.device)
        if train:
            recon, q_z, p_z, z = model(data)
            recon = recon.view(-1, data.size(-2), data.size(-1))
            reloss = recon_loss(recon, data)  # sum over batch
            kld = total_kld(q_z, p_z)  # sum over batch
            optimizer.zero_grad()
            loss = (reloss + kld) / batch_size
            loss.backward()
            optimizer.step()
        else:
            # Fix: evaluate without building the autograd graph; the
            # original tracked gradients here for no reason, wasting memory.
            with torch.no_grad():
                angles = torch.randint(0, 3, (data.size(0), )).to(args.device)
                r_data = batch_rotate(data.clone(), angles)
                r_recon, r_qz, r_pz, r_z = model(r_data)
                r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
                reloss = recon_loss(r_recon, r_data)
                kld = total_kld(r_qz, r_pz)
        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        discriminator_loss += disloss.item() / size
    # NLL proxy: reconstruction + KL (per-sample averages); hoisted out
    # of the loop where it was redundantly recomputed every iteration.
    nll = re_loss + kl_divergence
    return nll, re_loss, kl_divergence, discriminator_loss
def train(args):
    """Train a StarGAN generator/discriminator pair (NNabla).

    Builds static computation graphs for the discriminator loss
    (adversarial + domain classification + WGAN-GP gradient penalty)
    and the generator loss (adversarial + classification +
    reconstruction), then alternates solver updates: the discriminator
    every iteration, the generator every ``args.n_critic`` iterations.
    Finally saves parameters, dumps the config as JSON, and translates
    ``args.num_test`` test batches.
    """
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim,
        num_downsample=args.num_downsample, num_upsample=args.num_upsample,
        repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator,
                                      image_size=args.image_size,
                                      conv_dim=args.d_conv_dim,
                                      c_dim=args.c_dim,
                                      repeat_num=args.d_repeat_num)

    # Graph inputs: real image batch plus original / target domain labels.
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake: cuts the graph so the
    # discriminator pass does not backprop straight into the generator.
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty: penalise ||grad D(x_hat)|| deviating from 1 on
    # random interpolations between real and fake images.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)
    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])
    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images (cycle back to the original domain).
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True
    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)
    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir,
                                  image_size=args.image_size,
                                  crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False,
        with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir,
                                       image_size=args.image_size,
                                       crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False,
        with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly (permute the batch's labels).
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()
        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # Updated only once per n_critic discriminator updates.
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            # Zero BOTH solvers: g_loss.backward also deposits gradients
            # on discriminator parameters, which must not leak into the
            # next solver_dis.update().
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # Propagate the gradient accumulated at the unlinked variable
            # back through the generator graph.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

        if (i + 1) % args.sample_step == 0:
            # save image.
            save_results(i, args, x_real, x_fake,
                         label_org, label_trg, x_recon)
            if args.test_during_training:
                # translate images from test dataset.
                # NOTE(review): rand_idx is reused from the training batch
                # above — confirm that is intended for the test labels.
                x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                label_trg.d = test_label_ndarray[rand_idx]
                x_fake.forward(clear_no_need_grad=True)
                save_results(i, args, x_real, x_fake, label_org,
                             label_trg, None, is_training=False)

        # Learning rates get decayed linearly over the second half of
        # training, every lr_update_step iterations.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(
            datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]
        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
def main(args):
    """Train a class-conditional VAE with an HSIC independence penalty.

    The dataset is inferred from the last path component of ``args.data``.
    After training, writes a grid of class-conditional reconstructions,
    saves the model weights, and optionally produces a t-SNE plot of the
    latent space (``args.tsne``).

    Raises:
        ValueError: if ``args.data`` names an unknown dataset.
    """
    print("Loading data")
    dataset = args.data.rstrip('/').split('/')[-1]
    torch.cuda.set_device(args.cuda)
    device = args.device
    if dataset == 'mnist':
        train_loader, test_loader = get_mnist(args.batch_size, 'data/mnist')
        num = 10
    elif dataset == 'fashion':
        train_loader, test_loader = get_fashion_mnist(args.batch_size,
                                                      'data/fashion')
        num = 10
    elif dataset == 'svhn':
        train_loader, test_loader, _ = get_svhn(args.batch_size, 'data/svhn')
        num = 10
    elif dataset == 'stl':
        train_loader, test_loader, _ = get_stl10(args.batch_size, 'data/stl10')
        # Fix: ``num`` was never set on this branch, so constructing the
        # model below raised NameError.  STL-10 has 10 classes.
        num = 10
    elif dataset == 'cifar':
        train_loader, test_loader = get_cifar(args.batch_size, 'data/cifar')
        num = 10
    elif dataset == 'chair':
        train_loader, test_loader = get_chair(args.batch_size,
                                              '~/data/rendered_chairs')
        num = 1393
    elif dataset == 'yale':
        train_loader, test_loader = get_yale(args.batch_size, 'data/yale')
        num = 38
    else:
        # Fix: an unknown dataset previously fell through with the
        # loaders and ``num`` undefined and crashed later with a
        # confusing NameError.
        raise ValueError('unknown dataset: {}'.format(dataset))

    model = VAE(28 * 28, args.code_dim, args.batch_size, num,
                dataset).to(device)
    # Small feature map used only inside the HSIC estimator.
    phi = nn.Sequential(
        nn.Linear(args.code_dim, args.phi_dim),
        nn.LeakyReLU(0.2, True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_phi = torch.optim.Adam(phi.parameters(), lr=args.lr)
    criterion = nn.MSELoss(reduction='sum')

    for epoch in range(args.epochs):
        re_loss = 0
        kl_div = 0
        size = len(train_loader.dataset)
        for data, target in train_loader:
            data, target = data.squeeze(1).to(device), target.to(device)
            c = F.one_hot(target.long(), num_classes=num).float()
            output, q_z, p_z, z = model(data, c)
            hsic = HSIC(phi(z), target.long(), num)
            if dataset == 'mnist' or dataset == 'fashion':
                reloss = recon_loss(output, data.view(-1, 28 * 28))
            else:
                reloss = criterion(output, data)
            kld = total_kld(q_z, p_z)
            # ELBO plus weighted HSIC penalty between latent and label.
            loss = reloss + kld + args.c * hsic
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Adversarial phi update: maximise HSIC w.r.t. phi only
            # (z is detached so the encoder is untouched).
            optimizer_phi.zero_grad()
            neg = -HSIC(phi(z.detach()), target.long(), num)
            neg.backward()
            optimizer_phi.step()
            re_loss += reloss.item() / size
            kl_div += kld.item() / size
        print('-' * 50)
        print(" Epoch {} |re loss {:5.2f} | kl div {:5.2f} | hs {:5.2f}".format(
            epoch, re_loss, kl_div, hsic))

    # Take one test batch for the conditional-reconstruction grid.
    for data, target in test_loader:
        data, target = data.squeeze(1).to(device), target.to(device)
        c = F.one_hot(target.long(), num_classes=num).float()
        output, _, _, z = model(data, c)
        break
    if dataset == 'mnist' or dataset == 'fashion':
        img_size = [data.size(0), 1, 28, 28]
    else:
        img_size = [data.size(0), 3, 32, 32]
    images = [data.view(img_size)[:30].cpu()]
    # Decode the same latents under every class condition.
    for i in range(10):
        c = F.one_hot(torch.ones(z.size(0)).long() * i,
                      num_classes=num).float().to(device)
        output = model.decoder(torch.cat((z, c), dim=-1))
        images.append(output.view(img_size)[:30].cpu())
    images = torch.cat(images, dim=0)
    save_image(images, 'imgs/recon_c{}_{}.png'.format(int(args.c), dataset),
               nrow=30)
    torch.save(model.state_dict(), 'vae_c{}_{}.pt'.format(int(args.c), dataset))

    if args.tsne:
        # Collect a few test batches and embed their latents with t-SNE.
        datas, targets = [], []
        for i, (data, target) in enumerate(test_loader):
            datas.append(data), targets.append(target)
            if i >= 5:
                break
        data, target = torch.cat(datas, dim=0), torch.cat(targets, dim=0)
        c = F.one_hot(target.long(), num_classes=num).float()
        _, _, _, z = model(data.to(args.device), c.to(args.device))
        z, target = z.detach().cpu().numpy(), target.cpu().numpy()
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        z_2d = tsne.fit_transform(z)
        plt.figure(figsize=(6, 5))
        plot_embedding(z_2d, target)
        plt.savefig('tsnes/tsne_c{}_{}.png'.format(int(args.c), dataset))
def train(self):
    """Train a FactorVAE: alternate VAE updates with updates of the
    total-correlation (TC) discriminator.

    Each batch supplies two independent halves: the first trains the
    VAE on reconstruction + KL + ``self.beta`` * TC estimate; the
    second provides latents whose dimensions are permuted to form the
    discriminator's "factorised" class.  Per-epoch averages are logged
    to TensorBoard; afterwards the model is saved and a latent
    traversal GIF is produced.
    """
    # Create the dataloader
    self.dataloader = self.create_dataloader()
    for e in range(self.num_epochs):
        discriminator_loss = 0
        va_loss = 0
        for i, states_batch in enumerate(self.dataloader):
            states_true_1, states_true_2 = states_batch
            states_true_1 = states_true_1.to(self.device)
            # Fix: this transfer was duplicated later in the loop body.
            states_true_2 = states_true_2.to(self.device)

            # ---- VAE step ----
            reconstruction, latent, mu, logvar = self.model(states_true_1)
            vae_recon_loss = loss.recon_loss(states_true_1, reconstruction)
            vae_kl_divergence = loss.kl_divergence(mu, logvar)
            d_z = self.disc(latent)
            # Density-ratio estimate of total correlation from the
            # discriminator logits.
            vae_tc_loss = (d_z[:, :1] - d_z[:, 1:]).mean()
            vae_loss = vae_recon_loss + vae_kl_divergence + self.beta * vae_tc_loss

            self.optim_vae.zero_grad()
            # retain_graph: d_z is reused for the discriminator loss below.
            vae_loss.backward(retain_graph=True)
            self.optim_vae.step()
            va_loss += vae_loss.item()

            # ---- Discriminator step ----
            z_prime = self.model(states_true_2, no_dec=True)
            z_pperm = loss.permute_dims(z_prime).detach()
            D_z_pperm = self.disc(z_pperm)

            # Fix: build the class-label tensors from the actual batch
            # size instead of wrapping cached self.zeros / self.ones in a
            # bare ``except`` — that handler silently swallowed *every*
            # error, not just the last-batch size mismatch it targeted.
            n = d_z.size(0)
            zeros = torch.zeros(n, dtype=torch.long, device=self.device)
            ones = torch.ones(n, dtype=torch.long, device=self.device)
            D_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) +
                               F.cross_entropy(D_z_pperm, ones))

            self.optim_disc.zero_grad()
            D_tc_loss.backward()
            self.optim_disc.step()
            discriminator_loss += D_tc_loss.item()

        # Add the loss to tensorboard (per-epoch averages).
        self.writer.add_scalar('data/disc_loss',
                               discriminator_loss / len(self.dataloader), e)
        self.writer.add_scalar('data/vae_loss',
                               va_loss / len(self.dataloader), e)
    self.writer.close()
    # Save model
    self.save_model()
    # GIF visualization
    self.visualize_traverse()
def run(args, data_iter, model, optimizer, epoch, train=True):
    """Run one epoch of the rotation-aware VAE.

    With ``args.ro`` enabled, every non-identity rotation in the global
    ``rotations`` list is applied to each batch; an auxiliary
    discriminator ``D`` (optimised by the global ``optimizer_D``) is
    trained to predict the rotation from the latent code, while the VAE
    is trained adversarially against it.

    Returns:
        Tuple ``(nll, re_loss, kl_divergence, discriminator_loss)`` of
        per-sample averages over the dataset (reconstruction/KL are
        measured on the *unrotated* data only).
    """
    batch_size = args.batch_size
    size = len(data_iter.dataset)
    if train:
        model.train()
    else:
        model.eval()
    # data_iter.init_epoch()
    re_loss = 0
    kl_divergence = 0
    discriminator_loss = 0
    nll = 0
    for i, (data, label) in enumerate(data_iter):
        data = data.to(args.device)

        # Plain VAE pass on the unrotated batch.
        recon, q_z, p_z, z = model(data)
        recon = recon.view(-1, data.size(-2), data.size(-1))
        reloss = recon_loss(recon, data)  # sum over batch
        kld = total_kld(q_z, p_z)  # sum over batch

        disloss = torch.zeros(1).to(args.device)
        if args.ro:
            # Accumulate losses for every non-identity rotation.
            # NOTE(review): ``rotations``, ``D``, ``disc_loss`` and
            # ``optimizer_D`` are module-level names — confirm they are
            # initialised before this function is called.
            disloss, r_reloss, r_kld = [], [], []
            for d in range(1, len(rotations)):
                # Same rotation index for the whole batch.
                angles = torch.tensor([d], dtype=torch.long,
                                      device=args.device).expand(data.size(0))
                r_data = batch_rotate(data.clone(), angles)
                r_recon, r_qz, r_pz, r_z = model(r_data)
                r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
                D_z = D(r_z)
                disloss.append(disc_loss(D_z, angles))  # sum over batch
                r_reloss.append(recon_loss(r_recon, r_data))
                r_kld.append(total_kld(r_qz, r_pz))
            disloss = sum(disloss)  # / (len(rotations)-1)
            r_reloss = sum(r_reloss)  # / (len(rotations)-1)
            r_kld = sum(r_kld)  # / (len(rotations)-1)
            # angles = torch.randint(0, 3, (data.size(0), )).to(args.device)
            # r_data = batch_rotate(data.clone(), angles)
            # r_recon, r_qz, r_pz, r_z = model(r_data)
            # r_recon = r_recon.view(-1, 1, data.size(-2), data.size(-1))
            # D_z = D(r_z)
            # disloss = disc_loss(D_z, angles)  # sum over batch
            # r_reloss = recon_loss(r_recon, r_data)
            # r_kld = total_kld(r_qz, r_pz)
        if train:
            if args.ro:
                # Discriminator step first; retain_graph so the VAE step
                # below can reuse the same computation graph.
                optimizer_D.zero_grad()
                D_loss = disloss / batch_size
                D_loss.backward(retain_graph=True)
                optimizer_D.step()
                # VAE step: reconstruction + KL on clean and rotated
                # data, minus the discriminator loss (adversarial term).
                optimizer.zero_grad()
                loss = (reloss + kld + r_reloss + r_kld - disloss) / batch_size
                loss.backward()
                optimizer.step()
            else:
                optimizer.zero_grad()
                loss = (reloss + kld) / batch_size
                loss.backward()
                optimizer.step()
        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        discriminator_loss += disloss.item() / size
        nll = re_loss + kl_divergence
    return nll, re_loss, kl_divergence, discriminator_loss
def train(args):
    """Train an InstaGAN-style CycleGAN (NNabla).

    Two generators (x<->y, each carrying an instance mask channel) and
    two LSGAN discriminators are trained with cycle-consistency,
    identity-mapping and context-preserving losses.  All static graphs
    are built up front; each iteration performs one discriminator
    update followed by one generator update, with linear learning-rate
    decay after ``args.lr_decay_start_epoch``.  Samples and parameters
    are saved once per epoch.
    """
    # Variable size.
    bs, ch, h, w = args.batch_size, 3, args.loadSizeH, args.loadSizeW

    # Determine normalization method.
    if args.norm == "instance":
        norm_layer = functools.partial(PF.instance_normalization,
                                       fix_parameters=True, no_bias=True,
                                       no_scale=True)
    else:
        norm_layer = PF.batch_normalization

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(models.generator, input_nc=args.input_nc,
                                  output_nc=args.output_nc, ngf=args.ngf,
                                  norm_layer=norm_layer, use_dropout=False,
                                  n_blocks=9, padding_type='reflect')
    discriminator = functools.partial(models.discriminator,
                                      input_nc=args.output_nc,
                                      ndf=args.ndf, n_layers=args.n_layers_D,
                                      norm_layer=norm_layer,
                                      use_sigmoid=False)

    # --------------------- Computation Graphs --------------------
    # Input images and masks of both source / target domain
    x = nn.Variable([bs, ch, h, w], need_grad=False)
    a = nn.Variable([bs, 1, h, w], need_grad=False)
    y = nn.Variable([bs, ch, h, w], need_grad=False)
    b = nn.Variable([bs, 1, h, w], need_grad=False)

    # Apply image augmentation and get an unlinked variable
    xa_aug = image_augmentation(args, x, a)
    xa_aug.persistent = True
    xa_aug_unlinked = xa_aug.get_unlinked_variable()

    yb_aug = image_augmentation(args, y, b)
    yb_aug.persistent = True
    yb_aug_unlinked = yb_aug.get_unlinked_variable()

    # variables used for Image Pool
    # NOTE(review): these history variables are declared but never used
    # below — confirm whether the image pool was meant to be wired in.
    x_history = nn.Variable([bs, ch, h, w])
    a_history = nn.Variable([bs, 1, h, w])
    y_history = nn.Variable([bs, ch, h, w])
    b_history = nn.Variable([bs, 1, h, w])

    # Generate Images (x -> y')
    with nn.parameter_scope("gen_x2y"):
        yb_fake = generator(xa_aug_unlinked)
    yb_fake.persistent = True
    yb_fake_unlinked = yb_fake.get_unlinked_variable()

    # Generate Images (y -> x')
    with nn.parameter_scope("gen_y2x"):
        xa_fake = generator(yb_aug_unlinked)
    xa_fake.persistent = True
    xa_fake_unlinked = xa_fake.get_unlinked_variable()

    # Reconstruct Images (y' -> x)
    with nn.parameter_scope("gen_y2x"):
        xa_recon = generator(yb_fake_unlinked)
    xa_recon.persistent = True

    # Reconstruct Images (x' -> y)
    with nn.parameter_scope("gen_x2y"):
        yb_recon = generator(xa_fake_unlinked)
    yb_recon.persistent = True

    # Use Discriminator on y' and x'
    with nn.parameter_scope("dis_y"):
        d_y_fake = discriminator(yb_fake_unlinked)
    d_y_fake.persistent = True

    with nn.parameter_scope("dis_x"):
        d_x_fake = discriminator(xa_fake_unlinked)
    d_x_fake.persistent = True

    # Use Discriminator on y and x
    with nn.parameter_scope("dis_y"):
        d_y_real = discriminator(yb_aug_unlinked)

    with nn.parameter_scope("dis_x"):
        d_x_real = discriminator(xa_aug_unlinked)

    # Identity Mapping (x -> x)
    with nn.parameter_scope("gen_y2x"):
        xa_idt = generator(xa_aug_unlinked)

    # Identity Mapping (y -> y)
    with nn.parameter_scope("gen_x2y"):
        yb_idt = generator(yb_aug_unlinked)

    # -------------------- Loss --------------------
    # (LS)GAN Loss (for Discriminator)
    # NOTE(review): loss_dis_x is built from the *dis_y* outputs (and
    # vice versa).  Both feed the summed loss_dis, so training is
    # unaffected, but the per-domain monitor labels may be swapped —
    # confirm.
    loss_dis_x = (loss.lsgan_loss(d_y_fake, False) +
                  loss.lsgan_loss(d_y_real, True)) * 0.5
    loss_dis_y = (loss.lsgan_loss(d_x_fake, False) +
                  loss.lsgan_loss(d_x_real, True)) * 0.5
    loss_dis = loss_dis_x + loss_dis_y

    # Cycle Consistency Loss
    loss_cyc_x = args.lambda_cyc * loss.recon_loss(xa_recon, xa_aug_unlinked)
    loss_cyc_y = args.lambda_cyc * loss.recon_loss(yb_recon, yb_aug_unlinked)
    loss_cyc = loss_cyc_x + loss_cyc_y

    # Identity Mapping Loss
    loss_idt_x = args.lambda_idt * loss.recon_loss(xa_idt, xa_aug_unlinked)
    loss_idt_y = args.lambda_idt * loss.recon_loss(yb_idt, yb_aug_unlinked)
    loss_idt = loss_idt_x + loss_idt_y

    # Context Preserving Loss
    loss_ctx_x = args.lambda_ctx * \
        loss.context_preserving_loss(xa_aug_unlinked, yb_fake_unlinked)
    loss_ctx_y = args.lambda_ctx * \
        loss.context_preserving_loss(yb_aug_unlinked, xa_fake_unlinked)
    loss_ctx = loss_ctx_x + loss_ctx_y

    # (LS)GAN Loss (for Generator)
    d_loss_gen_x = loss.lsgan_loss(d_x_fake, True)
    d_loss_gen_y = loss.lsgan_loss(d_y_fake, True)
    d_loss_gen = d_loss_gen_x + d_loss_gen_y

    # Total Loss for Generator
    loss_gen = loss_cyc + loss_idt + loss_ctx + d_loss_gen

    # --------------------- Solvers --------------------
    # Initial learning rates
    G_lr = args.learning_rate_G
    #D_lr = args.learning_rate_D
    # As opposed to the description in the paper, D_lr is set the same as G_lr.
    D_lr = args.learning_rate_G

    # Define solvers
    solver_gen_x2y = S.Adam(G_lr, args.beta1, args.beta2)
    solver_gen_y2x = S.Adam(G_lr, args.beta1, args.beta2)
    solver_dis_x = S.Adam(D_lr, args.beta1, args.beta2)
    solver_dis_y = S.Adam(D_lr, args.beta1, args.beta2)

    # Set Parameters to each solver
    with nn.parameter_scope("gen_x2y"):
        solver_gen_x2y.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen_y2x"):
        solver_gen_y2x.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis_x"):
        solver_dis_x.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis_y"):
        solver_dis_y.set_parameters(nn.get_parameters())

    # create convenient functions manipulating Solvers
    def solvers_zero_grad():
        # Zeroing Gradients of all solvers
        solver_gen_x2y.zero_grad()
        solver_gen_y2x.zero_grad()
        solver_dis_x.zero_grad()
        solver_dis_y.zero_grad()

    def solvers_update_parameters(new_D_lr, new_G_lr):
        # Learning rate updater
        solver_gen_x2y.set_learning_rate(new_G_lr)
        solver_gen_y2x.set_learning_rate(new_G_lr)
        solver_dis_x.set_learning_rate(new_D_lr)
        solver_dis_y.set_learning_rate(new_D_lr)

    # -------------------- Data Iterators --------------------
    ds_train_A = insta_gan_data_source(
        args, train=True, domain="A", shuffle=True)
    di_train_A = insta_gan_data_iterator(ds_train_A, args.batch_size)
    ds_train_B = insta_gan_data_source(
        args, train=True, domain="B", shuffle=True)
    di_train_B = insta_gan_data_iterator(ds_train_B, args.batch_size)

    # -------------------- Monitors --------------------
    monitoring_targets_dis = {'discriminator_loss_x': loss_dis_x,
                              'discriminator_loss_y': loss_dis_y}
    monitors_dis = Monitors(args, monitoring_targets_dis)

    monitoring_targets_gen = {'generator_loss_x': d_loss_gen_x,
                              'generator_loss_y': d_loss_gen_y,
                              'reconstruction_loss_x': loss_cyc_x,
                              'reconstruction_loss_y': loss_cyc_y,
                              'identity_mapping_loss_x': loss_idt_x,
                              'identity_mapping_loss_y': loss_idt_y,
                              'content_preserving_loss_x': loss_ctx_x,
                              'content_preserving_loss_y': loss_ctx_y}
    monitors_gen = Monitors(args, monitoring_targets_gen)

    monitor_time = MonitorTimeElapsed(
        "Training_time", Monitor(args.monitor_path), args.log_step)

    # Training loop
    epoch = 0
    n_images = max([ds_train_B.size, ds_train_A.size])
    print("{} images exist.".format(n_images))
    max_iter = args.max_epoch * n_images // args.batch_size
    decay_iter = args.max_epoch - args.lr_decay_start_epoch

    for i in range(max_iter):
        if i % (n_images // args.batch_size) == 0 and i > 0:
            # Learning Rate Decay: linear, applied once per epoch after
            # the decay start epoch.
            epoch += 1
            print("epoch {}".format(epoch))
            if epoch >= args.lr_decay_start_epoch:
                new_D_lr = D_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_iter - 1))
                new_G_lr = G_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_iter - 1))
                solvers_update_parameters(new_D_lr, new_G_lr)
                print("Current learning rate for Discriminator: {}".format(
                    solver_dis_x.learning_rate()))
                print("Current learning rate for Generator: {}".format(
                    solver_gen_x2y.learning_rate()))

        # Get data
        x_data, a_data = di_train_A.next()
        y_data, b_data = di_train_B.next()
        x.d, a.d = x_data, a_data
        y.d, b.d = y_data, b_data

        solvers_zero_grad()

        # Image Augmentation
        nn.forward_all([xa_aug, yb_aug], clear_buffer=True)

        # Generate fake images
        nn.forward_all([xa_fake, yb_fake], clear_no_need_grad=True)

        # -------- Train Discriminator --------
        loss_dis.forward(clear_no_need_grad=True)
        monitors_dis.add(i)
        loss_dis.backward(clear_buffer=True)
        solver_dis_x.update()
        solver_dis_y.update()

        # -------- Train Generators --------
        # since the gradients computed above remain, reset to zero.
        xa_fake_unlinked.grad.zero()
        yb_fake_unlinked.grad.zero()
        solvers_zero_grad()

        loss_gen.forward(clear_no_need_grad=True)
        monitors_gen.add(i)
        monitor_time.add(i)
        loss_gen.backward(clear_buffer=True)
        # Propagate gradients stored at the unlinked fake variables back
        # through the generators.
        xa_fake.backward(grad=None, clear_buffer=True)
        yb_fake.backward(grad=None, clear_buffer=True)
        solver_gen_x2y.update()
        solver_gen_y2x.update()

        if i % (n_images // args.batch_size) == 0:
            # save translation results after every epoch.
            save_images(args, i, xa_aug, yb_fake,
                        domain="x", reconstructed=xa_recon)
            save_images(args, i, yb_aug, xa_fake,
                        domain="y", reconstructed=yb_recon)
            # save pretrained parameters
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
def run(args, data_iter, model, pad_id, optimizer, epoch, train=True):
    """Run one epoch of the text (flow-)VAE language model.

    Variants are driven by flags on ``args``: normalizing flows on the
    posterior (``args.flow``), MMD regularisation (``args.mmd``), KL
    annealing via ``weight_schedule`` (``args.kla``, capped at
    ``args.t``), importance-weighted NLL estimation (``args.iw`` with
    ``args.nsamples`` samples), and disabling the KL term entirely
    (``args.nokld``).

    Returns:
        Tuple of epoch statistics: (re_loss, kl_divergence,
        flow_kl_divergence, mutual_information1, mutual_information2,
        mmd_loss, nll_ppl, negative_ll, iw_negative_ll, sum_log_j,
        start, batch_time).
    """
    if train is True:
        model.train()
    else:
        model.eval()
    data_iter.init_epoch()
    batch_time = AverageMeter()
    # Number of examples processed this epoch (the dataset may be
    # truncated by args.epoch_size batches).
    size = min(len(data_iter.data()), args.epoch_size * args.batch_size)
    re_loss = 0
    kl_divergence = 0
    flow_kl_divergence = 0
    mutual_information1, mutual_information2 = 0, 0
    seq_words = 0
    mmd_loss = 0
    negative_ll = 0
    iw_negative_ll = 0
    sum_log_j = 0
    start = time.time()
    end = time.time()
    for i, batch in enumerate(data_iter):
        if i == args.epoch_size:
            break
        texts, lengths = batch.text
        batch_size = texts.size(0)
        # Teacher forcing: feed tokens [0..n-1], predict tokens [1..n].
        inputs = texts[:, :-1].clone()
        targets = texts[:, 1:].clone()
        q_z, p_z, z, outputs, sum_log_jacobian, penalty, z0 = model(
            inputs, lengths - 1, pad_id)
        if args.loss_type == 'entropy':
            reloss = recon_loss(outputs, targets, pad_id, id=args.loss_type)
        else:
            # NOTE(review): argument order differs from the branch above
            # (inputs/outputs vs outputs/targets) — confirm recon_loss
            # expects this for non-entropy loss types.
            reloss = recon_loss(inputs, outputs, pad_id, id=args.loss_type)
        kld = total_kld(q_z, p_z)
        if args.flow:
            f_kld = flow_kld(q_z, p_z, z, z0, sum_log_jacobian)
        else:
            f_kld = torch.zeros(1)
        mi_z = mutual_info(q_z, p_z, z0)
        nll = compute_nll(q_z, p_z, z, z0, sum_log_jacobian, reloss)
        if args.iw:
            iw_nll = model.iw_nll(q_z, p_z, inputs, targets, lengths - 1,
                                  pad_id, args.nsamples)
        else:
            iw_nll = torch.zeros(1)
        if args.flow:
            mi_flow = mutual_info_flow(q_z, p_z, z, z0, sum_log_jacobian)
        else:
            mi_flow = torch.zeros(1).to(z.device)
        mmd = torch.zeros(1).to(z.device)
        # KL annealing weight from the global step across epochs.
        kld_weight = weight_schedule(
            args.epoch_size * (epoch - 1) + i) if args.kla else 1.
        if args.mmd:
            # prior_samples = torch.randn(z.size(0), z.size(-1)).to(z.device)
            mmd = compute_mmd(p_z, q_z, args.kernel)
        # Cap the annealed weight at args.t, or zero it out entirely.
        if kld_weight > args.t:
            kld_weight = args.t
        if args.nokld:
            kld_weight = 0
        if train is True:
            optimizer.zero_grad()
            if args.flow:
                # loss = reloss / batch_size + kld_weight * (kld - torch.sum(sum_log_jacobian) + torch.sum(penalty)) / batch_size + (args.mmd_w - kld_weight) * mmd
                loss = reloss / batch_size + kld_weight * (q_z.log_prob(
                    z0).sum() - p_z.log_prob(z).sum()) / batch_size - (
                        torch.sum(sum_log_jacobian) - torch.sum(penalty)
                    ) / batch_size + (args.mmd_w - kld_weight) * mmd
            else:
                loss = (reloss + kld_weight * kld) / batch_size + (
                    args.mmd_w - kld_weight) * mmd
            loss.backward()
            optimizer.step()
        re_loss += reloss.item() / size
        kl_divergence += kld.item() / size
        flow_kl_divergence += f_kld.item() * batch_size / size
        mutual_information1 += mi_z.item() * batch_size / size
        mutual_information2 += mi_flow.item() * batch_size / size
        seq_words += torch.sum(lengths - 1).item()
        mmd_loss += mmd.item() * batch_size / size
        negative_ll += nll.item() * batch_size / size
        iw_negative_ll += iw_nll.item() * batch_size / size
        sum_log_j += torch.sum(sum_log_jacobian).item() / size
        # NOTE(review): ``end`` is never refreshed inside the loop, so
        # this records time since epoch start rather than per-batch time
        # — confirm intended.
        batch_time.update(time.time() - end)
    # Clamp runaway KL values for reporting.
    if kl_divergence > 100:
        kl_divergence = 100
        flow_kl_divergence = 100
    # Perplexity from the (importance-weighted) per-word NLL.
    if args.iw:
        nll_ppl = math.exp(iw_negative_ll * size / seq_words)
    else:
        nll_ppl = math.exp(negative_ll * size / seq_words)
    return re_loss, kl_divergence, flow_kl_divergence, mutual_information1, mutual_information2, mmd_loss, nll_ppl, negative_ll, iw_negative_ll, sum_log_j, start, batch_time