def forward(self, x):
    """Encode an input batch, sample the latent hierarchy and decode it.

    The input is fed to the stem as ``2 * x - 1.0`` (presumably mapping
    [0, 1] data to [-1, 1] -- TODO confirm the input range). It is then
    passed through the pre-processing cells and the encoder tower. Encoder
    ``'combiner_enc'`` cells are not applied there; they are stored with the
    feature they would see and replayed, deepest first, by the decoder.

    The decoder starts from the learned ``prior_ftr0`` feature, broadcast over
    the batch. At each ``'combiner_dec'`` cell after the first one it:

    * builds a prior from the decoder state;
    * builds a posterior from the stored encoder feature (residual to the
      prior when ``self.res_dist`` is set);
    * samples ``z``;
    * optionally refines ``z`` with ``self.num_flows`` normalizing-flow
      cells.

    The first latent group instead uses the top encoder feature for its
    posterior and a zero-mean, ``log_sigma = 0`` prior.

    Returns:
        ``(logits, log_q, log_p, kl_all, kl_diag)``, where:

        * ``logits`` comes from ``self.image_conditional``.
        * ``log_q`` / ``log_p`` sum, over dims 1-3 and over all latent
          groups, the posterior / prior log-densities of the sampled
          latents.
        * ``kl_all`` has one entry per group: the KL summed over dims 1-3.
        * ``kl_diag`` has one entry per group: the KL summed over dims 2-3
          and averaged over dim 0.

        The KL is ``log_q - log_p`` when ``self.with_nf`` is set, and
        ``q.kl(p)`` otherwise.
    """
    h = self.stem(2 * x - 1.0)

    # Pre-processing cells.
    for pre_cell in self.pre_process:
        h = pre_cell(h)

    # Encoder tower: defer combiner cells, remembering what each one saw.
    enc_combiners, enc_skips = [], []
    for enc_cell in self.enc_tower:
        if enc_cell.cell_type == 'combiner_enc':
            enc_combiners.append(enc_cell)
            enc_skips.append(h)
        else:
            h = enc_cell(h)

    # The decoder visits the stored combiners from the deepest one upward.
    enc_combiners = enc_combiners[::-1]
    enc_skips = enc_skips[::-1]

    # ---- first latent group: posterior from the top encoder feature ----
    top_ftr = self.enc0(h)  # reduces the channel dimension
    mu_q, log_sig_q = torch.chunk(self.enc_sampler[0](top_ftr), 2, dim=1)
    q_dist = Normal(mu_q, log_sig_q)
    z, _ = q_dist.sample()
    log_q_conv = q_dist.log_p(z)
    for flow_idx in range(self.num_flows):
        z, log_det = self.nf_cells[flow_idx](z, top_ftr)
        log_q_conv -= log_det
    flow_base = self.num_flows
    all_q, all_log_q = [q_dist], [log_q_conv]

    # Prior of the first group: zero mean, zero log-sigma.
    p_dist = Normal(mu=torch.zeros_like(z), log_sigma=torch.zeros_like(z))
    log_p_conv = p_dist.log_p(z)
    all_p, all_log_p = [p_dist], [log_p_conv]

    # ---- decoder tower ----
    # Rebinding h drops the encoder feature: the decoder starts only from the
    # learned prior feature, never from deterministic features of x.
    h = self.prior_ftr0.unsqueeze(0).expand(z.size(0), -1, -1, -1)
    group = 0
    for dec_cell in self.dec_tower:
        if dec_cell.cell_type != 'combiner_dec':
            h = dec_cell(h)
            continue

        if group > 0:
            # Prior parameters from the current decoder state.
            mu_p, log_sig_p = torch.chunk(
                self.dec_sampler[group - 1](h), 2, dim=1)

            # Posterior: merge the stored encoder feature with the decoder state.
            ftr = enc_combiners[group - 1](enc_skips[group - 1], h)
            mu_q, log_sig_q = torch.chunk(self.enc_sampler[group](ftr), 2, dim=1)
            if self.res_dist:
                q_dist = Normal(mu_p + mu_q, log_sig_p + log_sig_q)
            else:
                q_dist = Normal(mu_q, log_sig_q)
            z, _ = q_dist.sample()
            log_q_conv = q_dist.log_p(z)

            # Normalizing flows owned by this group.
            for flow_idx in range(self.num_flows):
                z, log_det = self.nf_cells[flow_base + flow_idx](z, ftr)
                log_q_conv -= log_det
            flow_base += self.num_flows
            all_log_q.append(log_q_conv)
            all_q.append(q_dist)

            # Density of the (flowed) sample under the prior.
            p_dist = Normal(mu_p, log_sig_p)
            log_p_conv = p_dist.log_p(z)
            all_p.append(p_dist)
            all_log_p.append(log_p_conv)

        # The combiner injects the sample into the decoder state.
        h = dec_cell(h, z)
        group += 1

    if self.vanilla_vae:
        h = self.stem_decoder(z)

    for post_cell in self.post_process:
        h = post_cell(h)

    logits = self.image_conditional(h)

    # ---- KL terms and total log-densities ----
    kl_all, kl_diag = [], []
    log_p, log_q = 0., 0.
    for q_dist, p_dist, lq, lp in zip(all_q, all_p, all_log_q, all_log_p):
        kl_per_var = lq - lp if self.with_nf else q_dist.kl(p_dist)
        kl_diag.append(torch.mean(torch.sum(kl_per_var, dim=[2, 3]), dim=0))
        kl_all.append(torch.sum(kl_per_var, dim=[1, 2, 3]))
        log_q += torch.sum(lq, dim=[1, 2, 3])
        log_p += torch.sum(lp, dim=[1, 2, 3])
    return logits, log_q, log_p, kl_all, kl_diag
def forward(self, x, global_step, args):
    """Run the 1-D hierarchical VAE on ``x`` and compute its training loss.

    Encoding and latent sampling follow the usual hierarchical scheme:

    * ``x`` is preprocessed and passed through the stem, the
      pre-processing cells and the encoder tower. Encoder combiner cells
      are stored and replayed, deepest first, by the decoder.
    * The decoder starts from the learned ``prior_ftr0``, expanded over the
      batch as a 3-D tensor (presumably ``(batch, channels, length)`` --
      TODO confirm).
    * At each ``'combiner_dec'`` cell after the first one it samples a
      latent group from the approximate posterior, refined by
      ``self.num_flows`` flow cells, and evaluates it under the matching
      prior.

    Loss:
        ``mean(recon_loss + balanced_kl) + (bn_loss + norm_loss) * wdn_coeff``.

        * ``kl_coeff`` comes from ``utils.kl_coeff`` using ``global_step``.
        * ``wdn_coeff`` is ``args.weight_decay_norm``. When
          ``args.weight_decay_norm_anneal`` is set, it is instead
          interpolated log-linearly from ``args.weight_decay_norm_init`` to
          ``args.weight_decay_norm``, weighted by ``kl_coeff``.

    Args:
        x: input batch. It is cast to half when ``args.fp16`` is set, and it
            is also the reconstruction target.
        global_step: current training step, used for KL annealing.
        args: config namespace. Reads ``fp16``, ``kl_anneal_portion``,
            ``kl_const_portion``, ``num_total_iter``, ``kl_const_coeff``,
            ``weight_decay_norm_anneal``, ``weight_decay_norm_init`` and
            ``weight_decay_norm``.

    Returns:
        ``(output, loss, metrics)``, where:

        * ``output`` is ``self.decoder_output(logits)``.
        * ``loss`` is the scalar training loss.
        * ``metrics`` maps ``recon_loss``, ``bn_loss``, ``norm_loss``,
          ``wdn_coeff``, ``kl_all`` and ``kl_coeff`` to detached tensors.
    """
    if args.fp16:
        x = x.half()

    alpha_i = utils.kl_balancer_coeff(
        num_scales=self.num_latent_scales,
        groups_per_scale=self.groups_per_scale,
        fun='square')

    x_in = self.preprocess(x)
    if args.fp16:
        x_in = x_in.half()
    s = self.stem(x_in)

    # perform pre-processing
    for cell in self.pre_process:
        s = cell(s)

    # run the main encoder tower; combiner cells are deferred to the decoder
    combiner_cells_enc = []
    combiner_cells_s = []
    for cell in self.enc_tower:
        if cell.cell_type == 'combiner_enc':
            combiner_cells_enc.append(cell)
            combiner_cells_s.append(s)
        else:
            s = cell(s)

    # reverse combiner cells and their input for decoder
    combiner_cells_enc.reverse()
    combiner_cells_s.reverse()

    idx_dec = 0
    ftr = self.enc0(s)  # this reduces the channel dimension
    param0 = self.enc_sampler[idx_dec](ftr)
    mu_q, log_sig_q = torch.chunk(param0, 2, dim=1)
    dist = Normal(mu_q, log_sig_q)  # for the first approx. posterior
    z, _ = dist.sample()
    log_q_conv = dist.log_p(z)

    # apply normalizing flows to the first group
    nf_offset = 0
    for n in range(self.num_flows):
        z, log_det = self.nf_cells[n](z, ftr)
        log_q_conv -= log_det
    nf_offset += self.num_flows
    all_q = [dist]
    all_log_q = [log_q_conv]

    # prior for z0: zero mean, zero log-sigma
    dist = Normal(mu=torch.zeros_like(z), log_sigma=torch.zeros_like(z))
    log_p_conv = dist.log_p(z)
    all_p = [dist]
    all_log_p = [log_p_conv]

    # The decoder starts from the learned prior feature, so no deterministic
    # features of x are passed to it.
    s = self.prior_ftr0.unsqueeze(0)
    batch_size = z.size(0)
    s = s.expand(batch_size, -1, -1)
    for cell in self.dec_tower:
        if cell.cell_type == 'combiner_dec':
            if idx_dec > 0:
                # form prior
                param = self.dec_sampler[idx_dec - 1](s)
                mu_p, log_sig_p = torch.chunk(param, 2, dim=1)

                # form encoder
                ftr = combiner_cells_enc[idx_dec - 1](
                    combiner_cells_s[idx_dec - 1], s)
                param = self.enc_sampler[idx_dec](ftr)
                mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
                dist = Normal(mu_p + mu_q, log_sig_p + log_sig_q) if self.res_dist else Normal(
                    mu_q, log_sig_q)
                z, _ = dist.sample()
                log_q_conv = dist.log_p(z)
                # apply NF
                for n in range(self.num_flows):
                    z, log_det = self.nf_cells[nf_offset + n](z, ftr)
                    log_q_conv -= log_det
                nf_offset += self.num_flows
                all_log_q.append(log_q_conv)
                all_q.append(dist)

                # evaluate log_p(z)
                dist = Normal(mu_p, log_sig_p)
                log_p_conv = dist.log_p(z)
                all_p.append(dist)
                all_log_p.append(log_p_conv)

            # 'combiner_dec'
            s = cell(s, z)
            idx_dec += 1
        else:
            s = cell(s)

    if self.vanilla_vae:
        s = self.stem_decoder(z)

    for cell in self.post_process:
        s = cell(s)

    logits = self.image_conditional(s)

    # Per-group KL, summed over dims 1-2: a Monte-Carlo estimate when flows
    # are used, otherwise the analytic KL of the two distributions.
    kl_all = []
    for q, p, log_q_conv, log_p_conv in zip(all_q, all_p, all_log_q, all_log_p):
        if self.with_nf:
            kl_per_var = log_q_conv - log_p_conv
        else:
            kl_per_var = q.kl(p)
        kl_all.append(torch.sum(kl_per_var, dim=[1, 2]))

    output = self.decoder_output(logits)

    kl_coeff = utils.kl_coeff(global_step,
                              args.kl_anneal_portion * args.num_total_iter,
                              args.kl_const_portion * args.num_total_iter,
                              args.kl_const_coeff)
    recon_loss = utils.reconstruction_loss(output, x, crop=self.crop_output)
    balanced_kl, _, _ = utils.kl_balancer(kl_all, kl_coeff, kl_balance=True,
                                          alpha_i=alpha_i)

    nelbo_batch = recon_loss + balanced_kl
    loss = torch.mean(nelbo_batch)
    bn_loss = self.batchnorm_loss()
    norm_loss = self.spectral_norm_parallel()

    if args.weight_decay_norm_anneal:
        assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
        # log-linear interpolation from the initial to the final coefficient
        wdn_coeff = (1. - kl_coeff) * np.log(
            args.weight_decay_norm_init) + kl_coeff * np.log(
                args.weight_decay_norm)
        wdn_coeff = np.exp(wdn_coeff)
    else:
        wdn_coeff = args.weight_decay_norm

    loss += bn_loss * wdn_coeff + norm_loss * wdn_coeff

    metrics = dict(recon_loss=recon_loss,
                   bn_loss=bn_loss,
                   norm_loss=norm_loss,
                   wdn_coeff=wdn_coeff,
                   kl_all=torch.mean(sum(kl_all)),
                   kl_coeff=kl_coeff)
    # wdn_coeff (and possibly kl_coeff) is a plain Python / NumPy scalar with
    # no .detach(); as_tensor wraps those and returns existing tensors as-is.
    metrics = {key: torch.as_tensor(val).detach()
               for key, val in metrics.items()}
    return output, loss, metrics