import numpy as np
import torch
from torch.cuda.amp import autocast

import utils  # repo-local helpers; `Normal` and the global `args` come from the surrounding codebase


def train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging):
    alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales,
                                      groups_per_scale=model.groups_per_scale, fun='square')
    nelbo = utils.AvgrageMeter()
    model.train()
    for step, x in enumerate(train_queue):
        x = x[0] if len(x) > 1 else x
        x = x.half().cuda()

        # change bit length
        x = utils.pre_process(x, args.num_x_bits)

        # warm-up lr
        if global_step < warmup_iters:
            lr = args.learning_rate * float(global_step) / warmup_iters
            for param_group in cnn_optimizer.param_groups:
                param_group['lr'] = lr

        # sync parameters; it may not be necessary
        if step % 100 == 0:
            utils.average_params(model.parameters(), args.distributed)

        cnn_optimizer.zero_grad()
        with autocast():
            logits, log_q, log_p, kl_all, kl_diag = model(x)

            output = model.decoder_output(logits)
            kl_coeff = utils.kl_coeff(global_step, args.kl_anneal_portion * args.num_total_iter,
                                      args.kl_const_portion * args.num_total_iter, args.kl_const_coeff)

            recon_loss = utils.reconstruction_loss(output, x, crop=model.crop_output)
            balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)

            nelbo_batch = recon_loss + balanced_kl
            loss = torch.mean(nelbo_batch)
            norm_loss = model.spectral_norm_parallel()
            bn_loss = model.batchnorm_loss()

            # get spectral regularization coefficient (lambda)
            if args.weight_decay_norm_anneal:
                assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, \
                    'init and final wdn should be positive.'
                # interpolate log-linearly between the initial and final coefficients
                wdn_coeff = (1. - kl_coeff) * np.log(args.weight_decay_norm_init) + \
                    kl_coeff * np.log(args.weight_decay_norm)
                wdn_coeff = np.exp(wdn_coeff)
            else:
                wdn_coeff = args.weight_decay_norm

            loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff

        grad_scalar.scale(loss).backward()
        utils.average_gradients(model.parameters(), args.distributed)
        grad_scalar.step(cnn_optimizer)
        grad_scalar.update()

        nelbo.update(loss.data, 1)

        if (global_step + 1) % 100 == 0:
            if (global_step + 1) % 1000 == 0:  # reduced frequency
                n = int(np.floor(np.sqrt(x.size(0))))
                x_img = x[:n * n]
                output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \
                    else output.sample()
                output_img = output_img[:n * n]
                x_tiled = utils.tile_image(x_img, n)
                output_tiled = utils.tile_image(output_img, n)
                in_out_tiled = torch.cat((x_tiled, output_tiled), dim=2)
                writer.add_image('reconstruction', in_out_tiled, global_step)

            # norm
            writer.add_scalar('train/norm_loss', norm_loss, global_step)
            writer.add_scalar('train/bn_loss', bn_loss, global_step)
            writer.add_scalar('train/norm_coeff', wdn_coeff, global_step)

            utils.average_tensor(nelbo.avg, args.distributed)
            logging.info('train %d %f', global_step, nelbo.avg)
            writer.add_scalar('train/nelbo_avg', nelbo.avg, global_step)
            writer.add_scalar('train/lr', cnn_optimizer.state_dict()['param_groups'][0]['lr'], global_step)
            writer.add_scalar('train/nelbo_iter', loss, global_step)
            writer.add_scalar('train/kl_iter', torch.mean(sum(kl_all)), global_step)
            writer.add_scalar('train/recon_iter',
                              torch.mean(utils.reconstruction_loss(output, x, crop=model.crop_output)),
                              global_step)
            writer.add_scalar('kl_coeff/coeff', kl_coeff, global_step)
            total_active = 0
            for i, kl_diag_i in enumerate(kl_diag):
                utils.average_tensor(kl_diag_i, args.distributed)
                num_active = torch.sum(kl_diag_i > 0.1).detach()
                total_active += num_active

                # kl_coeff
                writer.add_scalar('kl/active_%d' % i, num_active, global_step)
                writer.add_scalar('kl_coeff/layer_%d' % i, kl_coeffs[i], global_step)
                writer.add_scalar('kl_vals/layer_%d' % i, kl_vals[i], global_step)
            writer.add_scalar('kl/total_active', total_active, global_step)

        global_step += 1

    utils.average_tensor(nelbo.avg, args.distributed)
    return nelbo.avg, global_step
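

# --- Illustrative sketch (not part of the original file) --------------------
# `utils.kl_coeff` above is assumed to implement the usual linear KL warm-up:
# hold the coefficient at `min_kl_coeff` during the constant phase, then ramp
# linearly to 1.0 over the annealing phase. A minimal stand-in under that
# assumption (name and exact clamping are hypothetical):
def _kl_coeff_sketch(step, total_step, constant_step, min_kl_coeff):
    # linear ramp after the constant phase, clamped to [min_kl_coeff, 1.0]
    return max(min((step - constant_step) / total_step, 1.0), min_kl_coeff)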
def forward(self, x, global_step, args):
    if args.fp16:
        x = x.half()
    metrics = {}
    alpha_i = utils.kl_balancer_coeff(num_scales=self.num_latent_scales,
                                      groups_per_scale=self.groups_per_scale, fun='square')

    x_in = self.preprocess(x)
    if args.fp16:
        x_in = x_in.half()
    s = self.stem(x_in)

    # perform pre-processing
    for cell in self.pre_process:
        s = cell(s)

    # run the main encoder tower
    combiner_cells_enc = []
    combiner_cells_s = []
    for cell in self.enc_tower:
        if cell.cell_type == 'combiner_enc':
            combiner_cells_enc.append(cell)
            combiner_cells_s.append(s)
        else:
            s = cell(s)

    # reverse combiner cells and their input for the decoder
    combiner_cells_enc.reverse()
    combiner_cells_s.reverse()

    idx_dec = 0
    ftr = self.enc0(s)  # this reduces the channel dimension
    param0 = self.enc_sampler[idx_dec](ftr)
    mu_q, log_sig_q = torch.chunk(param0, 2, dim=1)
    dist = Normal(mu_q, log_sig_q)  # for the first approx. posterior
    z, _ = dist.sample()
    log_q_conv = dist.log_p(z)

    # apply normalizing flows
    nf_offset = 0
    for n in range(self.num_flows):
        z, log_det = self.nf_cells[n](z, ftr)
        log_q_conv -= log_det
    nf_offset += self.num_flows
    all_q = [dist]
    all_log_q = [log_q_conv]

    # to make sure we do not pass any deterministic features from x to the decoder
    s = 0

    # prior for z0
    dist = Normal(mu=torch.zeros_like(z), log_sigma=torch.zeros_like(z))
    log_p_conv = dist.log_p(z)
    all_p = [dist]
    all_log_p = [log_p_conv]

    idx_dec = 0
    s = self.prior_ftr0.unsqueeze(0)
    batch_size = z.size(0)
    s = s.expand(batch_size, -1, -1)
    for cell in self.dec_tower:
        if cell.cell_type == 'combiner_dec':
            if idx_dec > 0:
                # form prior
                param = self.dec_sampler[idx_dec - 1](s)
                mu_p, log_sig_p = torch.chunk(param, 2, dim=1)

                # form encoder
                ftr = combiner_cells_enc[idx_dec - 1](combiner_cells_s[idx_dec - 1], s)
                param = self.enc_sampler[idx_dec](ftr)
                mu_q, log_sig_q = torch.chunk(param, 2, dim=1)
                dist = Normal(mu_p + mu_q, log_sig_p + log_sig_q) if self.res_dist \
                    else Normal(mu_q, log_sig_q)
                z, _ = dist.sample()
                log_q_conv = dist.log_p(z)

                # apply NF
                for n in range(self.num_flows):
                    z, log_det = self.nf_cells[nf_offset + n](z, ftr)
                    log_q_conv -= log_det
                nf_offset += self.num_flows
                all_log_q.append(log_q_conv)
                all_q.append(dist)

                # evaluate log_p(z)
                dist = Normal(mu_p, log_sig_p)
                log_p_conv = dist.log_p(z)
                all_p.append(dist)
                all_log_p.append(log_p_conv)

            # 'combiner_dec'
            s = cell(s, z)
            idx_dec += 1
        else:
            s = cell(s)

    if self.vanilla_vae:
        s = self.stem_decoder(z)

    for cell in self.post_process:
        s = cell(s)

    logits = self.image_conditional(s)

    # compute kl
    kl_all = []
    kl_diag = []
    log_p, log_q = 0., 0.
    for q, p, log_q_conv, log_p_conv in zip(all_q, all_p, all_log_q, all_log_p):
        if self.with_nf:
            # with flows the KL has no closed form; use the Monte Carlo estimate
            kl_per_var = log_q_conv - log_p_conv
        else:
            kl_per_var = q.kl(p)

        kl_diag.append(torch.mean(torch.sum(kl_per_var, dim=2), dim=0))
        kl_all.append(torch.sum(kl_per_var, dim=[1, 2]))
        log_q += torch.sum(log_q_conv, dim=[1, 2])
        log_p += torch.sum(log_p_conv, dim=[1, 2])

    output = self.decoder_output(logits)

    # disabled spectral-loss helpers from an audio variant of this code; they
    # reference `hps`, `t`, and spectral_loss functions that are not defined here
    """
    def _spectral_loss(x_target, x_out, args):
        if hps.use_nonrelative_specloss:
            sl = spectral_loss(x_target, x_out, args) / args.bandwidth['spec']
        else:
            sl = spectral_convergence(x_target, x_out, args)
        sl = t.mean(sl)
        return sl

    def _multispectral_loss(x_target, x_out, args):
        sl = multispectral_loss(x_target, x_out, args) / args.bandwidth['spec']
        sl = t.mean(sl)
        return sl
    """

    kl_coeff = utils.kl_coeff(global_step, args.kl_anneal_portion * args.num_total_iter,
                              args.kl_const_portion * args.num_total_iter, args.kl_const_coeff)
    recon_loss = utils.reconstruction_loss(output, x, crop=self.crop_output)
    balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)
    nelbo_batch = recon_loss + balanced_kl
    loss = torch.mean(nelbo_batch)

    bn_loss = self.batchnorm_loss()
    norm_loss = self.spectral_norm_parallel()

    # x_target = audio_postprocess(x.float(), args)
    # x_out = audio_postprocess(output.sample(), args)
    # spec_loss = _spectral_loss(x_target, x_out, args)
    # multispec_loss = _multispectral_loss(x_target, x_out, args)

    if args.weight_decay_norm_anneal:
        assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, \
            'init and final wdn should be positive.'
        wdn_coeff = (1. - kl_coeff) * np.log(args.weight_decay_norm_init) + \
            kl_coeff * np.log(args.weight_decay_norm)
        wdn_coeff = np.exp(wdn_coeff)
    else:
        wdn_coeff = args.weight_decay_norm

    loss += bn_loss * wdn_coeff + norm_loss * wdn_coeff

    metrics.update(dict(recon_loss=recon_loss, bn_loss=bn_loss, norm_loss=norm_loss,
                        wdn_coeff=wdn_coeff, kl_all=torch.mean(sum(kl_all)), kl_coeff=kl_coeff))
    for key, val in metrics.items():
        # wdn_coeff and kl_coeff may be plain floats, so only detach tensors
        metrics[key] = val.detach() if torch.is_tensor(val) else val
    return output, loss, metrics
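

# --- Illustrative sketch (not part of the original file) --------------------
# When `self.with_nf` is False in the KL loop above, `q.kl(p)` is assumed to be
# the closed-form KL between diagonal Gaussians,
#   KL(q || p) = log(sigma_p / sigma_q)
#              + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2,
# evaluated element-wise. A minimal stand-alone version (name hypothetical):
def _diag_gaussian_kl_sketch(mu_q, sigma_q, mu_p, sigma_p):
    # element-wise KL; summing over the latent dims gives the per-group KL
    term1 = (mu_q - mu_p) / sigma_p
    term2 = sigma_q / sigma_p
    return 0.5 * (term1 ** 2 + term2 ** 2) - 0.5 - torch.log(term2)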