def test_glow(self):
    """Round-trip check: forward pass then inverse must reconstruct both the
    input data and the log-determinant; also draw a base-distribution sample."""
    net = Glow(width=12, depth=3, n_levels=3)
    x = torch.randn(args.batch_size, 3, 32, 32)

    zs, logd = net(x)
    recon_x, inv_logd = net.inverse(zs)
    y, _ = net.inverse(batch_size=args.batch_size)

    # Norms of the round-trip residuals.
    d_data = (recon_x - x).norm()
    d_data_y = (x - y).norm()
    d_logd = (logd + inv_logd).norm()

    assert d_data < 1e-3, 'Data reconstruction fail - norm of difference = {}.'.format(
        d_data)
    # assert d_data_y < 1e-3, 'Data reconstruction (inv > base > inv) fail - norm of difference = {}.'.format(d_data_y)
    assert d_logd < 1e-3, 'Log determinant inversion fail. - norm of difference = {}'.format(
        d_logd)
class BeautyGlow(nn.Module):
    """BeautyGlow-style makeup-transfer head on top of a Glow flow.

    A learned linear map ``W`` splits a Glow latent ``l`` into a facial-identity
    part ``W l`` and a makeup part ``(I - W) l``; transfer recombines the
    source identity with the reference makeup.
    """

    def __init__(self):
        # Bug fix: nn.Module.__init__ must run before submodules are assigned,
        # otherwise registering nn.Linear/Glow raises AttributeError.
        super().__init__()
        self.w = nn.Linear(128, 128, bias=False)
        self.glow = Glow(3, 32, 4, affine=True, conv_lu=True)

    def forward(self, reference, source, l_x, l_y):
        """Transfer makeup from *reference* onto *source*.

        ``l_x`` / ``l_y`` appear to be precomputed non-makeup / makeup latents
        used as regression targets — TODO confirm against the caller.
        Returns ``(result, loss)``: the decoded image and the total loss.
        """
        # Encode both images into Glow latent space.
        l_ref = self.glow.reverse(reference)
        l_source = self.glow.reverse(source)

        # Identity components F = W l.
        f_ref = self.w(l_ref)
        f_source = self.w(l_source)  # bug fix: was `self.w(l_souece)` (NameError)

        # Makeup component M = (I - W) l.  (Unused `m_source` and a leftover
        # debug print of `l_source_y` were removed.)
        residual = torch.eye(128) - self.w.weight
        m_ref = F.linear(l_ref, residual)

        # Transferred latent: reference makeup + source identity.
        l_source_y = m_ref + f_source
        # NOTE(review): decoding l_source rather than l_source_y looks
        # suspicious, but is kept as-is to preserve the caller contract.
        result = self.glow(l_source)

        perceptual_loss = F.mse_loss(f_ref, l_source)
        makeup_loss = F.mse_loss(m_ref, l_y - l_x)
        intra_domain_loss = F.mse_loss(f_ref, l_x) + F.mse_loss(l_source, l_y)

        # Cosine-similarity-style terms (normalised by products of mean squares).
        l2_norm_f = F.mse_loss(f_ref, torch.zeros(f_ref.size())) * \
            F.mse_loss(l_y, torch.zeros(l_y.size()))
        sim_f = torch.sum(f_ref * l_y) / l2_norm_f
        l2_norm_l = F.mse_loss(l_source, torch.zeros(l_source.size())) * \
            F.mse_loss(l_x, torch.zeros(l_x.size()))
        sim_l = torch.sum(l_source * l_x) / l2_norm_l
        inter_domain_loss = 1 + sim_f + 1 + sim_l

        cycle_f = F.mse_loss(self.w(l_source_y), f_source)
        # Bug fix: `m_ref` was passed as the *bias* of F.linear and F.mse_loss
        # received a single argument (TypeError at runtime). Compare the makeup
        # component of the transferred latent against the reference makeup.
        cycle_m = F.mse_loss(F.linear(l_source_y, residual), m_ref)
        cycle_consistency_loss = cycle_f + cycle_m

        # Loss weights (hyper-parameters).
        perceptual = 0.01
        cycle = 0.001
        makeup = 0.1
        intra = 0.1
        inter = 1000
        # Bug fix: `perceptual` was defined but never applied to its loss term.
        loss = perceptual * perceptual_loss + cycle * cycle_consistency_loss \
            + makeup * makeup_loss + intra * intra_domain_loss \
            + inter * inter_domain_loss
        return result, loss
def __init__(self):
    """Build the learned latent-splitting matrix W and the Glow backbone."""
    # Bug fix: nn.Module.__init__ must run before submodules are assigned,
    # otherwise registering nn.Linear/Glow raises AttributeError.
    super().__init__()
    self.w = nn.Linear(128, 128, bias=False)
    self.glow = Glow(3, 32, 4, affine=True, conv_lu=True)
def test_glow_3_3(self):
    """Smoke-train a width-24, depth-3, 3-level Glow."""
    glow = Glow(width=24, depth=3, n_levels=3)
    self._train(glow, 3)
def test_glow_1_1(self):
    """Smoke-train the smallest configuration: width 12, single block, single level."""
    glow = Glow(width=12, depth=1, n_levels=1)
    self._train(glow, 3)
def test_glow_depth_2_levels_2(self):
    # 1. sample data; 2. run model forward and reverse; 3. reconstruct data;
    # 4. measure KL between a Gaussian fitted to the data and the base distribution
    model = Glow(width=12, depth=2, n_levels=2)
    self.test_kl(model)
def train(self, params):
    """Train a Glow model on the image dataset described by *params*.

    params keys read here: batch_size, learning_rate, max_epoch, interval,
    image_shape, dataset_name, max_grad_clip, max_grad_norm.
    Checkpoints and sample images are written under
    ./result/<timestamp>_<dataset_name>/.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = params["batch_size"]
    learning_rate = params["learning_rate"]
    max_epoch = params["max_epoch"]
    interval = params["interval"]  # checkpoint/sample every `interval` epochs
    image_shape = params["image_shape"]
    dataset_name = params["dataset_name"]
    max_grad_clip = params["max_grad_clip"]
    max_grad_norm = params["max_grad_norm"]
    train_dataset = util.ImageDataset(params)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  num_workers=4, shuffle=True, drop_last=True)
    # Timestamped output directories for weights and generated samples.
    dt_now = datetime.now()
    dt_seq = dt_now.strftime("%y%m%d_%H%M")
    result_dir = os.path.join("./result", f"{dt_seq}_{dataset_name}")
    weight_dir = os.path.join(result_dir, "weights")
    sample_dir = os.path.join(result_dir, "sample")
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)
    glow = Glow(params).to(device)
    optimizer = Adam(glow.parameters(), lr=learning_rate)
    # ActNorm layers are data-dependently initialized on the first batch only.
    initialized = False
    for epoch in range(max_epoch):
        for i, batch in enumerate(train_dataloader):
            batch = batch.to(device)
            if not initialized:
                glow.initialize_actnorm(batch)
                initialized = True
            z, nll = glow.inference(batch)
            loss_generative = torch.mean(nll)
            optimizer.zero_grad()
            loss_generative.backward()
            # Clip by value first, then by global norm, before stepping.
            torch.nn.utils.clip_grad_value_(glow.parameters(), max_grad_clip)
            torch.nn.utils.clip_grad_norm_(glow.parameters(), max_grad_norm)
            optimizer.step()
            print(
                f"epoch {epoch} {i}/{len(train_dataloader)}, loss: {loss_generative.item():.4f}"
            )
        if epoch % interval == 0:
            # Checkpoint model + optimizer, then render samples from the last
            # batch's latent z (presumably for qualitative monitoring).
            torch.save(glow.state_dict(), f"{weight_dir}/{epoch}_glow.pth")
            torch.save(optimizer.state_dict(), f"{weight_dir}/{epoch}_opt.pth")
            filename = f"{epoch}_glow.png"
            with torch.no_grad():
                img = glow.generate(z, eps_std=0.5)
            util.save_samples(img, sample_dir, filename, image_shape,
                              num_tiles=4)
def train(cfg):
    """Train an MXNet Glow on MNIST, logging bits-per-dim and periodic image
    grids (real / generated / reconstructed) to TensorBoard."""
    date_today = date.today().strftime("%b-%d-%Y")
    summary_writer = SummaryWriter(cfg.log_dir, flush_secs=5,
                                   filename_suffix=date_today)
    train_data = mx.gluon.data.vision.MNIST(
        train=True).transform_first(data_xform)
    train_loader = mx.gluon.data.DataLoader(train_data, shuffle=True,
                                            batch_size=cfg.batch_size)
    image_shape = train_data[0][0].shape
    # No initialization. Custom blocks encapsulate initialization and setting of data.
    net = Glow(image_shape, cfg.K, cfg.L, cfg.affine, cfg.filter_size,
               cfg.temp, cfg.n_bits)
    ctx = get_context(cfg.use_gpu)
    net = set_context(net, ctx)
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam',
                               {'learning_rate': cfg.lr})
    n_samples = len(train_loader)
    update_interval = n_samples // 2  # store the loss with summary writer twice
    loss_buffer = LossBuffer()
    global_step = 1
    for epoch in range(1, cfg.n_epochs + 1):
        for idx, (batch, label) in enumerate(train_loader):
            print(f'Epoch {epoch} - Batch {idx}/{n_samples}', end='\r')
            # Split the batch across all devices in ctx.
            data = mx.gluon.utils.split_and_load(batch, ctx)
            with mx.autograd.record():
                for X in data:
                    z_list, nll, bpd = net(X)
                    # new_loss returns the previously buffered loss (if any).
                    prev_loss = loss_buffer.new_loss(bpd.mean())
            loss_buffer.loss.backward()
            trainer.step(1)
            if prev_loss is not None and global_step % update_interval == 0:
                loss = prev_loss.asscalar()
                summary_writer.add_scalar(tag='bpd', value=loss,
                                          global_step=global_step)
            global_step += 1
        # Sample from latent space to generate random digit and reverse from latent
        if (epoch % cfg.plot_interval) == 0:
            x_generate = net.reverse()[0]
            x_generate = x_generate.reshape(1, *x_generate.shape)
            # Reconstruction uses the last z_list produced in the epoch.
            x_recon = net.reverse(z_list[-1])[0]
            x_recon = x_recon.reshape(1, *x_recon.shape)
            x_real = data[0][0].reshape(1, *data[0][0].shape)
            # Clip to the data range [-0.5, 0.5] and shift to [0, 1] for display.
            minim = -0.5
            maxim = 0.5
            x_generate = x_generate.clip(minim, maxim)
            x_generate += -minim
            x_recon = x_recon.clip(minim, maxim)
            x_recon += -minim
            x_real += -minim
            img = mx.nd.concatenate([x_real, x_generate, x_recon],
                                    axis=0).asnumpy()
            summary_writer.add_image(tag='generations', image=img,
                                     global_step=global_step)
    summary_writer.close()
# Script entry: parse CLI options, build generator/discriminators, define losses.
opt = parser.parse_args()
print(opt)
if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

###### Definition of variables ######
# Networks
# NOTE(review): assumes opt.generator is restricted to {"baseline", "glow"}
# by the argparse definition above — otherwise `generator` is unbound.
if opt.generator == "baseline":
    generator = CycleConsistentGenerator(opt.input_nc, opt.output_nc)
    generator.apply(weights_init_normal)
elif opt.generator == "glow":
    generator = Glow(16, opt.input_nc, 256, squeeze=4)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)
if opt.cuda:
    generator.cuda()
    netD_A.cuda()
    netD_B.cuda()
# Discriminators always get the normal weight init (the glow generator does not).
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Lossess
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()