class SimpleNet(object):
    """Small variational autoencoder built from Chainer ``L.Linear`` layers.

    Encoder: one tanh hidden layer followed by two linear heads that give the
    mean and the log-variance of the latent code. Decoder: one tanh hidden
    layer followed by a linear output layer back to the input size.

    Attributes:
        model: ``FunctionSet`` holding the five linear layers.
        optimizer: Adam optimizer set up on ``model``.
        rec_loss: reconstruction term computed by the most recent ``cost``.
        loss: total loss computed by the most recent ``cost``.
    """

    def __init__(self, n_in=1, n_hidden_1=5, n_hidden_2=5):
        """Build the layers and the optimizer.

        Args:
            n_in: input (and reconstructed output) dimensionality.
            n_hidden_1: width of the tanh hidden layer in encoder and decoder.
            n_hidden_2: dimensionality of the latent code.
        """
        self.model = FunctionSet(
            en1=L.Linear(n_in, n_hidden_1),
            en2_mu=L.Linear(n_hidden_1, n_hidden_2),
            en2_var=L.Linear(n_hidden_1, n_hidden_2),
            de1=L.Linear(n_hidden_2, n_hidden_1),
            de2=L.Linear(n_hidden_1, n_in),
        )
        self.optimizer = optimizers.Adam()
        # Pass the model itself. In Chainer >= 1.5 (needed for L.Linear),
        # FunctionSet.collect_parameters() is deprecated and only returns
        # the model while emitting a FutureWarning.
        self.optimizer.setup(self.model)

    def encode(self, x_var):
        """Map ``x_var`` to the parameters of the latent Gaussian.

        Returns:
            Tuple ``(mu, ln_var)``. ``mu`` is the mean head. ``ln_var`` is
            the second head, which ``cost`` passes to ``F.gaussian`` and
            ``gaussian_kl_divergence`` as a log-variance.
        """
        h1 = F.tanh(self.model.en1(x_var))
        mu = self.model.en2_mu(h1)
        ln_var = self.model.en2_var(h1)
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        """Map latent code ``z`` back to input space.

        Args:
            z: latent code.
            sigmoid: if True (default), apply ``F.sigmoid`` to the output.
                If False, return the raw pre-activation. ``cost`` uses the
                False form as the input to ``F.bernoulli_nll``.
        """
        h1 = F.tanh(self.model.de1(z))
        h2 = self.model.de2(h1)
        return F.sigmoid(h2) if sigmoid else h2

    def cost(self, x_var, C=1.0, k=1):
        """Compute the training loss for a mini-batch.

        Args:
            x_var: input mini-batch. Its first axis is treated as the batch
                axis (presumably shape ``(batch, n_in)``; confirm with callers).
            C: weight on the KL-divergence term.
            k: number of latent samples averaged for the reconstruction term.

        Returns:
            The total loss ``rec_loss + C * KL / batchsize``, which is also
            stored in ``self.loss``. The reconstruction part is stored in
            ``self.rec_loss``.
        """
        mu, ln_var = self.encode(x_var)
        batchsize = len(mu.data)
        rec_loss = 0
        for _ in six.moves.range(k):
            # Sample a latent code from (mu, ln_var). Decode without the
            # sigmoid because bernoulli_nll is given the pre-activation.
            z = F.gaussian(mu, ln_var)
            rec_loss += F.bernoulli_nll(x_var, self.decode(z, sigmoid=False)) \
                / (k * batchsize)
        self.rec_loss = rec_loss
        self.loss = self.rec_loss + \
            C * gaussian_kl_divergence(mu, ln_var) / batchsize
        return self.loss