def train_analytical(data_loader, ita, epoch, args):
    losses = 0
    recons = 0
    KLD_batches = 0
    N = len(data_loader)
    for i, (x, y) in tqdm(enumerate(data_loader, 0), total=len(data_loader), smoothing=0.9):
        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda()
        x_hat, mu, logvar, z, param = model(x)
        # Linearly anneal the KL capacity from 0 to args.capacity over args.reg_anneal iterations.
        model.module.capacity = torch.tensor(
            linear_annealing(0, args.capacity, ita, args.reg_anneal)).float().cuda().requires_grad_()
        loss, recon, KLD_total_mean, KLDs = model.module.losses(
            x, x_hat, mu, logvar, z, objective, equality=args.kl_equality)
        loss.backward()
        train_optimizer.step()

        # Accumulate per-batch statistics for the epoch averages.
        losses += float(loss.detach())
        recons += float(recon.detach())
        KLD_batches += float(KLD_total_mean.detach())
        # Variance of the posterior means across the batch, per latent unit.
        batch_var_means = torch.std(mu.detach(), dim=0).pow(2)
        KLDs = KLDs.detach().mean(0)

        # Per-iteration TensorBoard logging.
        writer.add_scalar("Loss/ita", loss.detach(), ita)
        writer.add_scalar("recons/ita", recon.detach(), ita)
        writer.add_scalar("KLD/ita", KLD_total_mean, ita)
        dic = {}
        for j in range(KLDs.shape[0]):
            dic['u_{}'.format(j)] = KLDs[j]
        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()
        dic = {}
        for j in range(batch_var_means.shape[0]):
            dic['var_u_{}'.format(j)] = batch_var_means[j]
        writer.add_scalars("std_means/ita", dic, ita)
        dic.clear()

        if ita >= args.max_iter:
            break
        if i == N - 1:
            # Report epoch averages after the last batch.
            losses /= N
            recons /= N
            KLD_batches /= N
            print("After an epoch in iteration_{} AVG: loss:{}, recons:{}, KLD:{}".format(
                ita, losses, recons, KLD_batches))
    return ita
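
# linear_annealing is defined elsewhere in the repo; the sketch below is an assumption
# inferred from how it is called above (ramp a scalar from `init` to `fin` over
# `annealing_steps` iterations, then hold it at `fin`). It is included only to make the
# capacity / gamma schedules in these training functions easier to follow and is not
# necessarily the repo's exact implementation.
def _linear_annealing_sketch(init, fin, step, annealing_steps):
    # With no annealing requested, use the final value immediately.
    if annealing_steps == 0:
        return fin
    # Linear ramp, clamped at the final value once `step` passes `annealing_steps`.
    return min(init + (fin - init) * step / annealing_steps, fin)
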
def train_TC_montecarlo(data_loader, ita, epoch, args):
    losses = 0
    recons = 0
    mutual_info = 0
    total_correlation = 0
    regularization = 0
    N = len(data_loader)
    dataset_size = N * args.batch_size
    AVG_KLD = torch.zeros(model.module.z_dim)
    means = []
    for i, (x, y) in tqdm(enumerate(data_loader, 0), total=len(data_loader), smoothing=0.9):
        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda(non_blocking=True)
        x_hat, mu, logvar, z, params = model(x)
        # Linearly anneal the regularisation weight gamma from 0 to 1 over args.reg_anneal iterations.
        model.module.gamma = linear_annealing(0, 1, ita, args.reg_anneal)
        loss, recon, mi, tc, reg = model.module.beta_tc_loss(
            x, x_hat, params, z, dataset_size)
        loss.backward()
        train_optimizer.step()

        # Accumulate per-batch statistics for the epoch averages.
        losses += float(loss)
        recons += float(recon)
        mutual_info += float(mi)
        total_correlation += float(tc)
        regularization += float(reg)
        # Per-unit analytical KL and variance of the posterior means across the batch.
        KLDs = model.module.kld_unit_guassians_per_sample(
            mu.clone().detach(), logvar.clone().detach())
        KLDs = KLDs.detach().mean(0)
        batch_var_means = torch.std(mu.detach(), dim=0).pow(2)
        # Collect statistics over the final epoch for the saved summary.
        if ita == args.max_iter or epoch == args.epochs - 1:
            AVG_KLD += KLDs
            means.append(mu.clone().detach())

        # Per-iteration TensorBoard logging.
        writer.add_scalar("Loss/ita", loss, ita)
        writer.add_scalar("recons/ita", recon, ita)
        writer.add_scalar("mutual_info/ita", mi, ita)
        writer.add_scalar("total_correlation/ita", tc, ita)
        writer.add_scalar("reg/ita", reg, ita)
        dic = {}
        for j in range(KLDs.shape[0]):
            dic['u_{}'.format(j)] = KLDs[j]
        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()
        dic = {}
        for j in range(batch_var_means.shape[0]):
            dic['var_u_{}'.format(j)] = batch_var_means[j]
        writer.add_scalars("std_means/ita", dic, ita)
        dic.clear()

        if ita >= args.max_iter:
            break
        if i == N - 1:
            # Report epoch averages after the last batch.
            losses /= N
            recons /= N
            mutual_info /= N
            total_correlation /= N
            regularization /= N
            print("After an epoch in iteration_{} AVG: loss:{}, recons:{}, "
                  "mutual_info:{}, total_correlation:{}, regularization:{}".format(
                      ita, losses, recons, mutual_info, total_correlation, regularization))

    if ita == args.max_iter or epoch == args.epochs - 1:
        # Average the per-unit KLDs over batches and save them together with the
        # variance of the posterior means over the final epoch.
        cat_means = torch.cat(means)
        VAR_means = torch.std(cat_means, dim=0).pow(2)
        torch.save({'AVG_KLDS': AVG_KLD / N, 'VAR_means': VAR_means},
                   "AVG_KLDs_VAR_means.pth")
    return ita
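
# model.module.kld_unit_guassians_per_sample is defined on the model elsewhere in the
# repo. For reference, the sketch below shows what such a method typically computes
# (an assumption, not the repo's exact code): the analytical KL between a diagonal
# Gaussian posterior N(mu, sigma^2) and the standard-normal prior, kept separate per
# latent dimension, i.e. a (batch, z_dim) tensor of 0.5 * (mu^2 + exp(logvar) - logvar - 1).
def _kld_unit_gaussians_per_sample_sketch(mu, logvar):
    # No reduction: one KL term per sample and per latent unit.
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
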
def train_analytical_factorvae_tc(data_loader, ita, epoch, args):
    losses = 0
    recons = 0
    KLD_batches = 0
    N = len(data_loader)
    for i, (x, x2, y) in tqdm(enumerate(data_loader, 0), total=len(data_loader), smoothing=0.9):
        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda()
        x_hat, mu, logvar, z, param = model(x)
        # Linearly anneal the KL capacity from 0 to args.capacity over args.reg_anneal iterations.
        model.module.capacity = torch.tensor(
            linear_annealing(0, args.capacity, ita, args.reg_anneal)).float().cuda()

        # FactorVAE total-correlation term: the mean difference between the two
        # discriminator logits on samples from q(z|x).
        vae_tc_loss = 0
        if args.factorvae_tc:
            dz = discriminator(z)
            vae_tc_loss = (dz[:, :1] - dz[:, 1:]).mean()
        loss, recon, KLD_total_mean, KLDs = model.module.losses(
            x, x_hat, mu, logvar, z, objective, vae_tc_loss)
        loss.backward(retain_graph=True)
        train_optimizer.step()

        # Discriminator update: separate latents from q(z|x) (label 0) from
        # dimension-wise permuted latents of a second batch x2 (label 1).
        x2 = x2.cuda()
        with torch.no_grad():
            params = model.module.encoder(x2).view(x2.size(0), args.z_dim, 2)
            mu2, logstd_var2 = params.select(-1, 0), params.select(-1, 1)
            z_prime = model.module.reparam(mu2, logstd_var2)
            z_pperm = permute_dims(z_prime).clone()
            params = model.module.encoder(x).view(x.size(0), args.z_dim, 2)
            mu_d, logstd_var_d = params.select(-1, 0), params.select(-1, 1)
            z_real = model.module.reparam(mu_d, logstd_var_d).clone()
        D_z = discriminator(z_real)
        D_z_pperm = discriminator(z_pperm)
        D_tc_loss = 0.5 * (F.cross_entropy(D_z, ZEROS) + F.cross_entropy(D_z_pperm, ONES))
        optim_D.zero_grad()
        D_tc_loss.backward()
        optim_D.step()

        # Accumulate per-batch statistics for the epoch averages.
        losses += float(loss.detach())
        recons += float(recon)
        KLD_batches += float(KLD_total_mean)
        KLDs = KLDs.detach().mean(0)

        # Per-iteration TensorBoard logging.
        writer.add_scalar("Loss/ita", loss, ita)
        writer.add_scalar("recons/ita", recon, ita)
        writer.add_scalar("KLD/ita", KLD_total_mean, ita)
        dic = {}
        for j in range(KLDs.shape[0]):
            dic['u_{}'.format(j)] = KLDs[j]
        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()

        if ita >= args.max_iter:
            break
        if i == N - 1:
            # Report epoch averages after the last batch.
            losses /= N
            recons /= N
            KLD_batches /= N
            print("After an epoch in iteration_{} AVG: loss:{}, recons:{}, KLD:{}".format(
                ita, losses, recons, KLD_batches))
    return ita
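
# permute_dims is assumed to implement the FactorVAE trick of shuffling each latent
# dimension independently across the batch, so that the permuted codes approximate the
# product of the marginals of q(z). The helper is defined elsewhere in the repo; the
# sketch below only illustrates that behaviour and may differ from the repo's version.
def _permute_dims_sketch(z):
    # z: (batch, z_dim). Each column gets its own random permutation of the batch.
    assert z.dim() == 2
    permuted = torch.zeros_like(z)
    for j in range(z.size(1)):
        idx = torch.randperm(z.size(0), device=z.device)
        permuted[:, j] = z[idx, j]
    return permuted
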
def train_TC_montecarlo(data_loader, ita, epoch, args):
    losses = 0
    recons = 0
    mutual_info = 0
    total_correlation = 0
    regularization = 0
    disc_mutual_info = 0
    disc_total_correlation = 0
    disc_regularization = 0
    N = len(data_loader)
    dataset_size = N * args.batch_size
    AVG_KLD = torch.zeros(model.module.num_latent_dims)
    means = []
    for i, (x, y) in tqdm(enumerate(data_loader, 0), total=len(data_loader), smoothing=0.9):
        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda(non_blocking=True)
        x_hat, mu, logvar, z, alphas, rep_as = model(x)
        # Linearly anneal the regularisation weight gamma from 0 to 1 over args.reg_anneal iterations.
        model.module.gamma = linear_annealing(0, 1, ita, args.reg_anneal)
        loss, recon, mi, tc, reg, mi_disc, tc_disc, reg_disc = \
            model.module.beta_tc_loss(x, x_hat, mu, logvar, z, alphas, rep_as, dataset_size)
        loss.backward()
        train_optimizer.step()

        # Accumulate per-batch statistics for the epoch averages (continuous and discrete terms).
        losses += float(loss)
        recons += float(recon)
        mutual_info += float(mi)
        total_correlation += float(tc)
        regularization += float(reg)
        disc_mutual_info += float(mi_disc)
        disc_total_correlation += float(tc_disc)
        disc_regularization += float(reg_disc)
        # Per-unit analytical KL and variance of the posterior means across the batch.
        KLDs = model.module.kld_unit_guassians_per_sample(mu.clone().detach(), logvar.clone().detach())
        KLDs = KLDs.detach().mean(0)
        batch_var_means = torch.std(mu.detach(), dim=0).pow(2)
        # Collect statistics over the final epoch for the saved summary.
        if ita == args.max_iter or epoch == args.epochs - 1:
            AVG_KLD += KLDs
            means.append(mu.clone().detach())

        # Discrete latents: KL of each categorical posterior from the uniform prior.
        if model.module.num_latent_dims_disc > 0:
            KLDs_categorical = []
            for alpha in alphas:
                uniform_params = torch.ones_like(alpha) / alpha.shape[1]
                kld = kl_divergence(Categorical(alpha.detach()), Categorical(uniform_params))
                KLDs_categorical.append(kld.clone().detach().mean())
            dic = {}
            for disc_dim in range(len(alphas)):
                dic['disc_{}'.format(disc_dim)] = float(KLDs_categorical[disc_dim])
            writer.add_scalars("KLDs_disc/ita", dic, ita)

        # Per-iteration TensorBoard logging.
        writer.add_scalar("Loss/ita", loss, ita)
        writer.add_scalar("recons/ita", recon, ita)
        writer.add_scalar("mutual_info/ita", mi, ita)
        writer.add_scalar("total_correlation/ita", tc, ita)
        writer.add_scalar("reg/ita", reg, ita)
        writer.add_scalar("mutual_info_disc/ita", mi_disc, ita)
        writer.add_scalar("total_correlation_disc/ita", tc_disc, ita)
        writer.add_scalar("reg_disc/ita", reg_disc, ita)
        dic = {}
        for j in range(KLDs.shape[0]):
            dic['u_{}'.format(j)] = KLDs[j]
        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()
        dic = {}
        for j in range(batch_var_means.shape[0]):
            dic['var_u_{}'.format(j)] = batch_var_means[j]
        writer.add_scalars("std_means/ita", dic, ita)
        dic.clear()

        if ita >= args.max_iter:
            break
        if i == N - 1:
            # Report epoch averages after the last batch.
            losses /= N
            recons /= N
            mutual_info /= N
            total_correlation /= N
            regularization /= N
            disc_mutual_info /= N
            disc_total_correlation /= N
            disc_regularization /= N
            print("After an epoch in iteration_%d AVG: loss:%0.2f, recons:%0.2f, "
                  "mutual_info:%0.2f, total_correlation:%0.2f, regularization:%0.2f "
                  "Discrete = mutual_info:%0.2f, total_correlation:%0.2f, regularization:%0.2f" %
                  (ita, losses, recons, mutual_info, total_correlation, regularization,
                   disc_mutual_info, disc_total_correlation, disc_regularization))

    if ita == args.max_iter or epoch == args.epochs - 1:
        # Average the per-unit KLDs over batches and save them together with the
        # variance of the posterior means over the final epoch.
        cat_means = torch.cat(means)
        VAR_means = torch.std(cat_means, dim=0).pow(2)
        torch.save({'AVG_KLDS': AVG_KLD / N, 'VAR_means': VAR_means},
                   "AVG_KLDs_VAR_means.pth")
    return ita
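
# How these train_* functions are typically driven (a hedged sketch, not the repo's
# actual entry point): `ita` is the global iteration counter that every function
# advances and returns, so the caller threads it across epochs and stops once
# args.max_iter is reached. The loop and names below are illustrative assumptions.
def _training_loop_sketch(data_loader, args):
    ita = 0
    for epoch in range(args.epochs):
        ita = train_TC_montecarlo(data_loader, ita, epoch, args)
        if ita >= args.max_iter:
            break
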