Example #1
def train_analytical(data_loader, ita, epoch, args):

    losses = 0
    recons = 0
    KLD_batches = 0
    #KLDs_batches = np.zeros(args.z_dim)

    N = len(data_loader)
    for i, (x, y) in tqdm(enumerate(data_loader, 0),
                          total=len(data_loader),
                          smoothing=0.9):

        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda()
        x_hat, mu, logvar, z, param = model(x)
        model.module.capacity = torch.tensor(
            linear_annealing(0, args.capacity, ita,
                             args.reg_anneal)).float().cuda().requires_grad_()

        loss, recon, KLD_total_mean, KLDs = model.module.losses(
            x, x_hat, mu, logvar, z, objective, equality=args.kl_equality)
        loss.backward()
        train_optimizer.step()

        losses += float(loss.detach())
        recons += float(recon.detach())
        KLD_batches += float(KLD_total_mean.detach())
        batch_var_means = torch.std(mu.detach(), dim=0).pow(2)
        KLDs = KLDs.detach().mean(0)
        writer.add_scalar("Loss/ita", loss.detach(), ita)
        writer.add_scalar("recons/ita", recon.detach(), ita)
        writer.add_scalar("KLD/ita", KLD_total_mean, ita)
        dic = {}
        for j in range(KLDs.shape[0]):

            dic['u_{}'.format(j)] = KLDs[j]

        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()
        dic = {}
        for j in range(batch_var_means.shape[0]):
            dic['var_u_{}'.format(j)] = batch_var_means[j]
        writer.add_scalars("std_means/ita", dic, ita)
        dic.clear()

        if ita >= args.max_iter:
            break
    if i == N - 1:
        losses /= N
        recons /= N
        KLD_batches /= N

        print("After an epoch in itaration_{} AVG: loss:{}, recons:{}, KLD:{}".
              format(ita, losses, recons, KLD_batches))

    return ita
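
All four examples call a `linear_annealing` helper that is not included in the snippets. As a point of reference, here is a minimal sketch of what such a helper usually does, assuming it ramps a value linearly from `init` to `fin` over `annealing_steps` iterations and then holds it at `fin` (the signature and behavior are assumptions, not the original implementation):

def linear_annealing(init, fin, step, annealing_steps):
    # Hypothetical sketch, not the original helper: increase linearly from
    # `init` to `fin` over `annealing_steps` iterations, then stay at `fin`.
    if annealing_steps == 0:
        return fin
    delta = fin - init
    return min(init + delta * step / annealing_steps, fin)
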
Example #2
def train_TC_montecarlo(data_loader, ita, epoch, args):

    losses = 0
    recons = 0
    KLD_batches = 0
    mutual_info = 0
    total_correlation = 0
    regularization = 0

    N = len(data_loader)
    dataset_size = N * args.batch_size
    AVG_KLD = torch.zeros(model.module.z_dim)
    means = []
    for i, (x, y) in tqdm(enumerate(data_loader, 0),
                          total=len(data_loader),
                          smoothing=0.9):

        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda(non_blocking=True)

        x_hat, mu, logvar, z, params = model(x)
        model.module.gamma = linear_annealing(0, 1, ita, args.reg_anneal)
        loss, recon, mi, tc, reg = model.module.beta_tc_loss(
            x, x_hat, params, z, dataset_size)

        loss.backward()
        train_optimizer.step()

        losses += float(loss)
        recons += float(recon)
        mutual_info += float(mi)
        total_correlation += float(tc)
        regularization += float(reg)
        KLDs = model.module.kld_unit_guassians_per_sample(
            mu.clone().detach(),
            logvar.clone().detach())
        KLDs = KLDs.detach().mean(0)
        batch_var_means = torch.std(mu.detach(), dim=0).pow(2)
        #var_means = torch.std(mu.detach(),dim=0).pow(2)
        if ita == args.max_iter or epoch == args.epochs - 1:
            AVG_KLD += KLDs
            means.append(mu.clone().detach())

        writer.add_scalar("Loss/ita", loss, ita)
        writer.add_scalar("recons/ita", recon, ita)
        writer.add_scalar("mutual_info/ita", mi, ita)
        writer.add_scalar("totall_coorelation/ita", tc, ita)
        writer.add_scalar("reg/ita", reg, ita)
        dic = {}
        # AVG_KLD accumulates per-unit KLDs over the final epoch and is
        # normalized by N when it is saved below
        for j in range(KLDs.shape[0]):
            dic['u_{}'.format(j)] = KLDs[j]

        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()
        dic = {}
        for j in range(batch_var_means.shape[0]):
            dic['var_u_{}'.format(j)] = batch_var_means[j]
        writer.add_scalars("std_means/ita", dic, ita)
        dic.clear()
        if ita >= args.max_iter:
            break

    if i == N - 1:
        losses /= N
        recons /= N
        mutual_info /= N
        total_correlation /= N
        regularization /= N

        print("After an epoch, at iteration {} AVG: loss: {}, recons: {}, "
              "mutual_info: {}, total_correlation: {}, regularization: {}".format(
                  ita, losses, recons, mutual_info, total_correlation,
                  regularization))
    if ita == args.max_iter or epoch == args.epochs - 1:
        cat_means = torch.cat(means)
        VAR_means = torch.std(cat_means, dim=0).pow(2)
        torch.save({
            'AVG_KLDS': AVG_KLD / N,
            'VAR_means': VAR_means
        }, "AVG_KLDs_VAR_means.pth")

    return ita
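
The `KLD_units/ita` logging above relies on `model.module.kld_unit_guassians_per_sample`, which is not shown. Assuming it returns the analytical KL divergence between each unit's posterior N(mu_j, sigma_j^2) and the standard normal prior, per sample and per latent unit, a minimal sketch would be:

def kld_per_unit(mu, logvar):
    # Hypothetical sketch of a per-sample, per-unit Gaussian KLD (an assumption
    # about what kld_unit_guassians_per_sample computes):
    # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)  # shape (batch, z_dim)

Taking `.mean(0)` over the result, as the loop does with `KLDs`, gives the batch-average KLD per latent unit.
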
Example #3
def train_analytical_factorvae_tc(data_loader, ita, epoch, args):

    losses = 0
    recons = 0
    KLD_batches = 0
    #KLDs_batches = np.zeros(args.z_dim)

    N = len(data_loader)
    for i, (x, x2, y) in tqdm(enumerate(data_loader, 0),
                              total=len(data_loader),
                              smoothing=0.9):

        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda()
        x_hat, mu, logvar, z, param = model(x)
        model.module.capacity = torch.tensor(
            linear_annealing(0, args.capacity, ita,
                             args.reg_anneal)).float().cuda()
        vae_tc_loss = 0
        if args.factorvae_tc:
            dz = discriminator(z)

            vae_tc_loss = (dz[:, :1] - dz[:, 1:]).mean()

        loss, recon, KLD_total_mean, KLDs = model.module.losses(
            x, x_hat, mu, logvar, z, objective, vae_tc_loss)
        loss.backward(retain_graph=True)
        train_optimizer.step()

        x2 = x2.cuda()
        with torch.no_grad():
            params = model.module.encoder(x2).view(x2.size(0), args.z_dim, 2)
            mu, logstd_var = params.select(-1, 0), params.select(-1, 1)
            z_prime = model.module.reparam(mu, logstd_var)

            z_pperm = permute_dims(z_prime).clone()

            params = model.module.encoder(x).view(x.size(0), args.z_dim, 2)
            mu, logstd_var = params.select(-1, 0), params.select(-1, 1)
            dz_pr = model.module.reparam(mu, logstd_var).clone()

        # classify encoder samples as class 0 and dimension-permuted samples as class 1
        D_z = discriminator(dz_pr)
        D_z_pperm = discriminator(z_pperm)
        D_tc_loss = 0.5 * (F.cross_entropy(D_z, ZEROS) +
                           F.cross_entropy(D_z_pperm, ONES))

        optim_D.zero_grad()
        D_tc_loss.backward()
        optim_D.step()

        losses += float(loss.clone())
        recons += float(recon)
        KLD_batches += float(KLD_total_mean)

        KLDs = KLDs.detach().mean(0)

        writer.add_scalar("Loss/ita", loss, ita)
        writer.add_scalar("recons/ita", recon, ita)
        writer.add_scalar("KLD/ita", KLD_total_mean, ita)
        dic = {}
        for j in range(KLDs.shape[0]):

            dic['u_{}'.format(j)] = KLDs[j]

        writer.add_scalars("KLD_units/ita", dic, ita)
        dic.clear()
        if ita >= args.max_iter:
            break
    if i == N - 1:
        losses /= N
        recons /= N
        KLD_batches /= N

        print("After an epoch in itaration_{} AVG: loss:{}, recons:{}, KLD:{}".
              format(ita, losses, recons, KLD_batches))

    return ita
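
The discriminator update above uses a `permute_dims` helper that is not shown. In FactorVAE, it shuffles each latent dimension independently across the batch so the permuted codes approximate samples from the product of the marginals. A minimal sketch under that assumption:

import torch


def permute_dims(z):
    # Hypothetical sketch of permute_dims (an assumption, not the original
    # helper): permute every latent dimension independently across the batch.
    assert z.dim() == 2
    B = z.size(0)
    permuted = [col[torch.randperm(B, device=z.device)]
                for col in z.split(1, dim=1)]
    return torch.cat(permuted, dim=1)
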
Example #4
def train_TC_montecarlo(data_loader, ita, epoch, args):

    losses = 0
    recons = 0
    KLD_batches = 0
    mutual_info = 0
    total_correlation = 0
    regularization = 0
    disc_mutual_info = 0
    disc_totall_correlation = 0
    disc_regularization = 0
    N = len(data_loader)
    dataset_size = N * args.batch_size
    AVG_KLD = torch.zeros(model.module.num_latent_dims)
    means = []
    for i, (x, y) in tqdm(enumerate(data_loader, 0),
                          total=len(data_loader),
                          smoothing=0.9):

        ita += 1
        train_optimizer.zero_grad()
        x = x.cuda(non_blocking=True)

        x_hat, mu, logvar, z, alphas, rep_as = model(x)
        model.module.gamma = linear_annealing(0, 1, ita, args.reg_anneal)
        loss, recon, mi, tc, reg, mi_disc, tc_disc, reg_disc = \
            model.module.beta_tc_loss(x, x_hat, mu, logvar, z, alphas, rep_as, dataset_size)
     
        
        loss.backward()
        train_optimizer.step()

        
        losses += float(loss)
        recons += float(recon)
        mutual_info += float(mi)
        total_correlation += float(tc)
        regularization += float(reg)
        disc_mutual_info += float(mi_disc)
        disc_totall_correlation += float(tc_disc)
        disc_regularization += float(reg_disc)
        KLDs = model.module.kld_unit_guassians_per_sample(
            mu.clone().detach(), logvar.clone().detach())
        KLDs = KLDs.detach().mean(0)
        batch_var_means = torch.std(mu.detach(), dim=0).pow(2)

        if ita == args.max_iter or epoch == args.epochs - 1:
            AVG_KLD += KLDs
            means.append(mu.clone().detach())
        # discrete latents:
        if model.module.num_latent_dims_disc > 0:
            KLDs_categorical = []
            for alpha in alphas:
                uniform_params = torch.ones_like(alpha) / alpha.shape[1]
                kld = kl_divergence(Categorical(alpha.detach()),
                                    Categorical(uniform_params))
                KLDs_categorical.append(kld.clone().detach().mean())

            dic = {}
            for disc_dim in range(len(alphas)):
                dic['disc_{}'.format(disc_dim)] = float(KLDs_categorical[disc_dim])
            writer.add_scalars("KLDs_disc/ita", dic, ita)
        writer.add_scalar("Loss/ita",loss,ita)
        writer.add_scalar("recons/ita",recon,ita)
        writer.add_scalar("mutual_info/ita",mi,ita)
        writer.add_scalar("totall_coorelation/ita",tc,ita)
        writer.add_scalar("reg/ita",reg,ita)
        writer.add_scalar("mutual_info_disc/ita",mi_disc,ita)
        writer.add_scalar("totall_coorelation_disc/ita",tc_disc,ita)
        writer.add_scalar("reg_disc/ita",reg_disc,ita)
        dic = {}
        AVG_KLD /= dataset_size
        for j in range(KLDs.shape[0]):  
            dic['u_{}'.format(j)] = KLDs[j]

        writer.add_scalars("KLD_units/ita",dic, ita)
        dic.clear()
        dic = {}
        for j in range(batch_var_means.shape[0]):  
            dic['var_u_{}'.format(j)] = batch_var_means[j]
        writer.add_scalars("std_means/ita",dic, ita)
        dic.clear()
        if ita >=args.max_iter:
            break
    
    if i == N - 1:
        losses /= N
        recons /= N
        mutual_info /= N
        total_correlation /= N
        regularization /= N
        disc_mutual_info /= N
        disc_totall_correlation /= N
        disc_regularization /= N
        print("After an epoch, at iteration %d AVG: loss: %0.2f, recons: %0.2f, "
              "mutual_info: %0.2f, total_correlation: %0.2f, regularization: %0.2f, "
              "Discrete: mutual_info: %0.2f, total_correlation: %0.2f, regularization: %0.2f" %
              (ita, losses, recons, mutual_info, total_correlation, regularization,
               disc_mutual_info, disc_totall_correlation, disc_regularization))
    if ita == args.max_iter or epoch == args.epochs - 1:
        cat_means = torch.cat(means)
        VAR_means = torch.std(cat_means, dim=0).pow(2)
        torch.save({'AVG_KLDS': AVG_KLD / N, 'VAR_means': VAR_means},
                   "AVG_KLDs_VAR_means.pth")
       
    return ita
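
None of the snippets show the surrounding driver. Assuming an argparse-style `args` with `epochs` and `max_iter` and a prepared `data_loader` (hypothetical names, not part of the original code), a training function like the one above would typically be driven like this, with `ita` carried across epochs:

ita = 0
for epoch in range(args.epochs):
    # each train_* function advances the global iteration counter and
    # stops early once args.max_iter is reached
    ita = train_TC_montecarlo(data_loader, ita, epoch, args)
    if ita >= args.max_iter:
        break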