Code example #1
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# BaseVAE, VAE and the nns helper module are assumed to be provided by the surrounding project.


class FactorVAE(BaseVAE):
    """Class that implements the Factor Variational Auto-Encoder"""

    def __init__(self, n_input, n_hidden, dim_z, n_output, gamma, binary=True, **kwargs):
        """initialize neural networks
        :param gamma: weight for total correlation term in loss function
        """
        super(FactorVAE, self).__init__()
        self.dim_z = dim_z
        self.binary = binary
        self.gamma = gamma
        self.input_size = (n_input,)

        # VAE networks
        self.vae = VAE(n_input, n_hidden, dim_z, n_output, binary, **kwargs)

        # discriminator layers
        D_hidden_num = 3
        D_hidden_dim = 1000
        D_hidden_dims = [D_hidden_dim] * D_hidden_num
        D_act = nn.LeakyReLU
        D_act_args = {"negative_slope": 0.2, "inplace": False}
        D_output_dim = 2
        self.discriminator = nns.create_mlp(self.dim_z, D_hidden_dims,
                                            act_layer=D_act, act_args=D_act_args, norm=True)
        self.discriminator = nn.Sequential(
            self.discriminator,
            nn.Linear(D_hidden_dim, D_output_dim))

    def encode(self, x):
        """vae encode"""
        return self.vae.encode(x)

    def reparameterize(self, mu, logvar):
        """reparameterization trick"""
        return self.vae.reparameterize(mu, logvar)

    def decode(self, code):
        """vae deocde"""
        return self.vae.decode(code)

    def sample_latent(self, num, device, **kwargs):
        """vae sample latent"""
        return self.vae.sample_latent(num, device, **kwargs)

    def sample(self, num, device, **kwargs):
        """vae sample"""
        return self.vae.sample(num, device, **kwargs)

    def forward(self, input, no_dec=False):
        """autoencoder forward computation"""
        encoded = self.encode(input)
        mu, logvar = encoded
        z = self.reparameterize(mu, logvar) # latent variable z

        if no_dec:
            # no decoding
            return z

        return self.decode(z), encoded, z

    def decoded_to_output(self, decoded, **kwargs):
        """vae transform decoded result to output"""
        return self.vae.decoded_to_output(decoded, **kwargs)

    def reconstruct(self, input, **kwargs):
        """vae reconstruct"""
        return self.vae.reconstruct(input, **kwargs)

    def permute_dims(self, z):
        """permute separately each dimension of the z randomly in a batch
        :param z: [B x D] tensor
        :return: [B x D] tensor with each dim of D dims permuted randomly
        """
        B, D = z.size()
        # generate randomly permuted batch on each dimension
        permuted = []
        for i in range(D):
            ind = torch.randperm(B)
            permuted.append(z[:, i][ind].view(-1, 1))

        return torch.cat(permuted, dim=1)

    def loss_function(self, *inputs, **kwargs):
        """loss function described in the paper (eq. (2))"""
        optim_part = kwargs['optim_part'] # the part to optimize

        if optim_part == 'vae':
            # update VAE
            decoded = inputs[0]
            encoded = inputs[1]
            z = inputs[2]
            x = inputs[3]
            mu, logvar = encoded

            # KL divergence term
            KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()
            if self.binary:
                # likelihood term under Bernoulli MLP decoder
                MLD = F.binary_cross_entropy(decoded, x, reduction='sum').div(x.size(0))
            else:
                # likelihood term under Gaussian MLP decoder
                mu_o, logvar_o = decoded
                recon_x_distribution = Normal(
                    loc=mu_o, scale=torch.exp(0.5*logvar_o))
                MLD = -recon_x_distribution.log_prob(x).sum(1).mean()

            Dz = self.discriminator(z)
            tc_loss = (Dz[:, :1] - Dz[:, 1:]).mean()

            return {
                "loss": KLD + MLD + self.gamma * tc_loss,
                "KLD": KLD,
                "MLD": MLD,
                "tc_loss": tc_loss}
        elif optim_part == 'discriminator':
            # update discriminator
            Dz = inputs[0]
            Dz_pperm = inputs[1]
            device = Dz.device  # use the discriminator output's device (z is not defined in this branch)

            ones = torch.ones(
                Dz.size(0), dtype=torch.long, requires_grad=False).to(device)
            zeros = torch.zeros(
                Dz.size(0), dtype=torch.long, requires_grad=False).to(device)

            D_tc_loss = 0.5 * (F.cross_entropy(Dz, zeros) +
                            F.cross_entropy(Dz_pperm, ones))

            return {"loss": D_tc_loss, "D_tc_loss": D_tc_loss}

        else:
            raise Exception("no such network to optimize: {}".format(optim_part))
Code example #2
    'fuc', 'fut'
])
X = single.train_data
Y = single.train_label
C = single.train_captions
print(single.color_scheme)

latent_dim = int(sys.argv[1])
vae = VAE(784, [500], latent_dim)
vae.train(X,
          batch_size=20,
          num_epochs=100,
          rerun=False,
          model_filename='composite_l_%d' % latent_dim)

latent_z = vae.encode(X)
with open('latent_space_%d_dim_abbrev.txt' % latent_dim, 'w') as lf:
    for i, lz in enumerate(latent_z):
        if Y[i] in ['bat', 'bbt', 'bac', 'bbc']:
            for z in lz:
                lf.write('%f ' % z)
            lf.write('\n')

with open('captions_%d_dim_abbrev.txt' % latent_dim, 'w') as lf:
    for i, _ in enumerate(latent_z):
        if Y[i] in ['bat', 'bbt', 'bac', 'bbc']:
            lf.write('%s\n' % C[i])

if latent_dim == 2:
    latent_z = vae.encode(X)
    latent_y = [single.color_scheme[y] for y in Y]
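    # Hypothetical continuation (not in the source): scatter the 2-D codes coloured
    # by label; assumes vae.encode returns an (N, 2) array, and the output filename
    # is made up.
    import matplotlib.pyplot as plt
    plt.figure(figsize=(6, 6))
    plt.scatter(latent_z[:, 0], latent_z[:, 1], c=latent_y, s=2)
    plt.savefig('latent_space_%d_dim.png' % latent_dim)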
Code example #3
#torch.autograd.set_detect_anomaly(True)

# train
for e in range(1, args.epoch + 1):
    TOTAL_LOSS, CE_LOSS, VAT_LOSS, VAT_LOSS_ORTH, EN_LOSS = 0, 0, 0, 0, 0
    net.train()
    for [x_raw, y], [x_ul_raw] in zip(X_loader, X_ul_loader):
        if args.cuda:
            x_raw, y = x_raw.cuda(), y.cuda()
            x_ul_raw = x_ul_raw.cuda()

        x_raw, y = Variable(x_raw), Variable(y)
        x_ul_raw = Variable(x_ul_raw)

        out, out_ul = net(x_raw), net(x_ul_raw)
        mu, logvar = vae.encode(x_ul_raw)
        z = vae.reparameterize(mu, logvar)

        # compute the tangential (VAT) perturbation
        r_x = r_vat(z, x_ul_raw, out_ul)
        out_adv = net(x_ul_raw + r_x)

        # compute the normal-direction perturbation
        r_adv_orth = r_vat_orth(x_ul_raw, r_x, out_ul)
        # torch.cuda.empty_cache()
        out_adv_orth = net(x_ul_raw + r_adv_orth * args.epsilon2)

        # the loss has 4 parts
        # vat_loss only needs gradients w.r.t. out_adv; the same holds for vat_loss_orth
        optimizer.zero_grad()
        vat_loss = kldivergence(out_ul.detach(), out_adv)
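        # Hypothetical continuation (not from the source): the remaining three of the
        # four loss parts tracked above; args.alpha / args.beta and the entropy term
        # are assumptions, and F is torch.nn.functional.
        ce_loss = F.cross_entropy(out, y)                            # supervised part
        vat_loss_orth = kldivergence(out_ul.detach(), out_adv_orth)  # normal-direction VAT part
        probs = F.softmax(out_ul, dim=1)
        en_loss = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()  # conditional entropy part

        loss = ce_loss + args.alpha * vat_loss + args.beta * vat_loss_orth + en_loss
        loss.backward()
        optimizer.step()

        # running totals for logging
        TOTAL_LOSS += loss.item()
        CE_LOSS += ce_loss.item()
        VAT_LOSS += vat_loss.item()
        VAT_LOSS_ORTH += vat_loss_orth.item()
        EN_LOSS += en_loss.item()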
Code example #4
# Imports assumed by this snippet (the VAE class comes from the surrounding project).
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

import input_quickdraw

single = input_quickdraw.Single(['diamond'], 5000)
X = single.train_data
Y = single.train_label

vae = VAE(784, [500], 2)
vae.train(X,
          batch_size=200,
          num_epochs=100,
          rerun=False,
          model_filename='diamond_only')

print('Encoding...')

latent_z = tuple()
for i in range(3):
    latent_z = latent_z + (vae.encode(X), )
latent_z = np.concatenate(latent_z)

print('Binning...')

nbins = 100
x, y = latent_z.T
k = gaussian_kde(latent_z.T)  # kernel density estimate over the 2-D latent codes
xi, yi = np.mgrid[x.min():x.max():nbins * 1j, y.min():y.max():nbins * 1j]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))

plt.clf()
plt.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap='Reds')
plt.xlim([-4, 4])
plt.ylim([-4, 4])
plt.gca().set_aspect('equal')
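# Hypothetical final step (not in the source): attach a colour bar and save the figure;
# the filename is made up.
plt.colorbar(label='estimated density')
plt.savefig('diamond_latent_density.png', dpi=150)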
Code example #5
    mse_loss = list()
    kld_loss = list()
    for epoch in range(1, args.epochs + 1):
        mse, kld = train(epoch)
        mse_loss.append(mse.cpu().detach().clone().numpy())
        kld_loss.append(kld.cpu().detach().clone().numpy())
        scheduler.step()
        #test(epoch)
        if epoch % args.interval == 0:
            with torch.no_grad():
                sample = model.sample(64, device=device).cpu()
                sample = 0.5 * (sample + 1)
                sample = sample.clamp(0, 1)
                save_image(sample,
                           'result/recon/sample_' + str(epoch) + '.png')
    loss = np.array([mse_loss, kld_loss])
    np.save('result/loss/' + args.out, loss)

    features = list()
    labels = list()
    model.eval()
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            mu, _ = model.encode(data)
            features.extend(mu.cpu().clone().tolist())
            labels.extend(label.tolist())
    features = np.array(list(zip(np.array(features), labels)), dtype=object)  # pair each latent mean with its label; dtype=object avoids ragged-array errors on newer NumPy
    np.save('result/features/' + args.out, features)
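    # Hypothetical follow-up (not in the source): reload the saved (mean, label) pairs
    # and scatter the first two latent dimensions; assumes a latent space of at least
    # two dimensions and that args.out has no file extension (np.save appends '.npy').
    import matplotlib.pyplot as plt
    saved = np.load('result/features/' + args.out + '.npy', allow_pickle=True)
    mus = np.stack([np.asarray(m) for m, _ in saved])
    labs = np.array([l for _, l in saved])
    plt.figure(figsize=(6, 6))
    plt.scatter(mus[:, 0], mus[:, 1], c=labs, cmap='tab10', s=2)
    plt.savefig('result/features/' + args.out + '_scatter.png')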