import torch
from torch import nn
import torch.nn.functional as F

# Glow backbone; the constructor arguments below match the
# rosinality/glow-pytorch signature Glow(in_channel, n_flow, n_block,
# affine, conv_lu). The import path is an assumption.
from model import Glow


class BeautyGlow(nn.Module):
    def __init__(self):
        super().__init__()
        # W maps a latent to its facial-feature part; (I - W) yields the makeup part.
        self.w = nn.Linear(128, 128, bias=False)
        self.glow = Glow(3, 32, 4, affine=True, conv_lu=True)

    def forward(self, reference, source, l_x, l_y):
        # In this codebase, glow.reverse() is the image-to-latent direction
        # and glow() is the latent-to-image direction.
        l_ref = self.glow.reverse(reference)
        l_source = self.glow.reverse(source)

        identity = torch.eye(128, device=l_ref.device)

        # Split each latent into facial-feature and makeup components.
        f_ref = self.w(l_ref)
        f_source = self.w(l_source)
        m_ref = F.linear(l_ref, identity - self.w.weight)
        m_source = F.linear(l_source, identity - self.w.weight)

        # Transfer: source facial features combined with reference makeup.
        l_source_y = m_ref + f_source
        # Decode the transferred latent back to an image.
        result = self.glow(l_source_y)

        perceptual_loss = F.mse_loss(f_ref, l_source)
        makeup_loss = F.mse_loss(m_ref, l_y - l_x)
        intra_domain_loss = F.mse_loss(f_ref, l_x) + F.mse_loss(l_source, l_y)

        # Cross-domain similarities, normalized by mean squared magnitudes.
        l2_norm_f = F.mse_loss(f_ref, torch.zeros_like(f_ref)) * \
            F.mse_loss(l_y, torch.zeros_like(l_y))
        sim_f = torch.sum(f_ref * l_y) / l2_norm_f
        l2_norm_l = F.mse_loss(l_source, torch.zeros_like(l_source)) * \
            F.mse_loss(l_x, torch.zeros_like(l_x))
        sim_l = torch.sum(l_source * l_x) / l2_norm_l
        inter_domain_loss = (1 + sim_f) + (1 + sim_l)

        # Cycle consistency: re-splitting the transferred latent should recover
        # the source features and the reference makeup.
        cycle_f = F.mse_loss(self.w(l_source_y), f_source)
        cycle_m = F.mse_loss(
            F.linear(l_source_y, identity - self.w.weight), m_ref)
        cycle_consistency_loss = cycle_f + cycle_m

        # Loss weights.
        perceptual = 0.01
        cycle = 0.001
        makeup = 0.1
        intra = 0.1
        inter = 1000
        loss = (perceptual * perceptual_loss
                + cycle * cycle_consistency_loss
                + makeup * makeup_loss
                + intra * intra_domain_loss
                + inter * inter_domain_loss)
        return result, loss
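The heart of the module is the learned decomposition of a latent vector l into a facial-feature part f = W·l and a makeup part m = (I - W)·l, which sum back to l exactly. The sketch below isolates that split on random latents, reusing the torch imports above; the 128-dimensional latent size is carried over from the Linear layer, and the untrained W is purely illustrative.

# Minimal sketch of the feature/makeup latent split, under the assumptions above.
w = nn.Linear(128, 128, bias=False)
eye = torch.eye(128)

l_a = torch.randn(4, 128)  # stand-in latents of made-up reference faces
l_b = torch.randn(4, 128)  # stand-in latents of bare source faces

f_b = w(l_b)                         # facial features of the source
m_a = F.linear(l_a, eye - w.weight)  # makeup component of the reference
l_transfer = m_a + f_b               # transferred latent, as in forward()

# The split is exact by construction: W*l + (I - W)*l == l.
assert torch.allclose(w(l_a) + F.linear(l_a, eye - w.weight), l_a, atol=1e-5)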
from datetime import date

import mxnet as mx
from mxboard import SummaryWriter

# Project helpers assumed to live alongside this script: data_xform (pixel
# preprocessing), get_context / set_context (device handling), LossBuffer
# (keeps the previous loss alive for logging), and the Glow network itself.
from glow import Glow
from utils import data_xform, get_context, set_context, LossBuffer


def train(cfg):
    date_today = date.today().strftime("%b-%d-%Y")
    summary_writer = SummaryWriter(cfg.log_dir, flush_secs=5,
                                   filename_suffix=date_today)

    train_data = mx.gluon.data.vision.MNIST(train=True).transform_first(data_xform)
    train_loader = mx.gluon.data.DataLoader(train_data, shuffle=True,
                                            batch_size=cfg.batch_size)
    image_shape = train_data[0][0].shape

    # No explicit initialization: the custom blocks encapsulate parameter
    # initialization and data binding themselves.
    net = Glow(image_shape, cfg.K, cfg.L, cfg.affine, cfg.filter_size,
               cfg.temp, cfg.n_bits)
    ctx = get_context(cfg.use_gpu)
    net = set_context(net, ctx)
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam',
                               {'learning_rate': cfg.lr})

    n_batches = len(train_loader)
    update_interval = n_batches // 2  # log the loss twice per epoch
    loss_buffer = LossBuffer()
    global_step = 1

    for epoch in range(1, cfg.n_epochs + 1):
        for idx, (batch, label) in enumerate(train_loader):
            print(f'Epoch {epoch} - Batch {idx}/{n_batches}', end='\r')
            data = mx.gluon.utils.split_and_load(batch, ctx)
            with mx.autograd.record():
                for X in data:
                    z_list, nll, bpd = net(X)
                    prev_loss = loss_buffer.new_loss(bpd.mean())
            loss_buffer.loss.backward()
            trainer.step(1)

            if prev_loss is not None and global_step % update_interval == 0:
                loss = prev_loss.asscalar()
                summary_writer.add_scalar(tag='bpd', value=loss,
                                          global_step=global_step)
            global_step += 1

        # Periodically sample from the latent space to generate a random digit
        # and reconstruct the last real batch from its latents.
        if epoch % cfg.plot_interval == 0:
            x_generate = net.reverse()[0]
            x_generate = x_generate.reshape(1, *x_generate.shape)
            x_recon = net.reverse(z_list[-1])[0]
            x_recon = x_recon.reshape(1, *x_recon.shape)
            x_real = data[0][0].reshape(1, *data[0][0].shape)

            # Clip to the training range [-0.5, 0.5], then shift to [0, 1] for display.
            minim, maxim = -0.5, 0.5
            x_generate = x_generate.clip(minim, maxim) - minim
            x_recon = x_recon.clip(minim, maxim) - minim
            x_real = x_real - minim

            img = mx.nd.concatenate([x_real, x_generate, x_recon], axis=0).asnumpy()
            summary_writer.add_image(tag='generations', image=img,
                                     global_step=global_step)

    summary_writer.close()
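Since train() only reads attributes off cfg, a plain namespace is enough to drive it. The field names below are exactly those referenced in the function; the values and the inline interpretations are placeholder assumptions, not tuned settings.

# Hypothetical entry point; all cfg values here are illustrative guesses.
if __name__ == '__main__':
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        log_dir='./logs',    # mxboard event directory
        batch_size=64,
        K=16, L=2,           # likely flow steps per block / number of blocks
        affine=True,         # affine vs. additive coupling
        filter_size=512,     # presumably the coupling networks' hidden width
        temp=0.7,            # sampling temperature
        n_bits=8,            # bit depth used when dequantizing pixels
        use_gpu=False,
        lr=1e-3,
        n_epochs=10,
        plot_interval=1,     # epochs between image dumps
    )
    train(cfg)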