def setup(self, x_shape):
    """Build the training graph for the adversarial autoencoder + GAN model.

    The compiled objective is ``sum(loss) + sum(-self.gan_loss)``.

    * ``loss`` accumulates the latent encoder's loss and the reconstruction
      error.
    * ``self.gan_loss`` is the per-row GAN log-likelihood: ``log(d + eps)``
      on the real rows (``self.x_src``) and ``log(1 - d + eps)`` on the
      generated rows.

    The generated rows are the decoded encoder codes, followed by decoded
    samples from ``self.latent_encoder.samples`` when ``self.sample_z`` is
    set.

    Parameters
    ----------
    x_shape : sequence of int
        Shape of an input batch; ``x_shape[0]`` is the batch size.

    Raises
    ------
    ValueError
        If ``self.real_vs_gen_weight`` is not strictly between 0 and 1.
        The gradient balancing divides by the per-row weights, so a weight
        of 0 (or of 1, which zeroes the generated rows' weight) would give
        infinite scale factors.
    """
    if not 0.0 < self.real_vs_gen_weight < 1.0:
        raise ValueError(
            'real_vs_gen_weight must be strictly between 0 and 1, got %r'
            % (self.real_vs_gen_weight,))
    batch_size = x_shape[0]
    self.x_src = expr.Source(x_shape)
    loss = 0

    # Encode
    enc = self.encoder(self.x_src)
    z, self.encoder_loss = self.latent_encoder.encode(enc, batch_size)
    loss += self.encoder_loss

    # Decode
    x_tilde = self.decoder(z)
    if self.recon_depth > 0:
        # Reconstruction error in discriminator feature space: run the
        # reconstructions and the originals through discriminator_recon as
        # one batch, then split the result back into the two halves.
        x = expr.Concatenate(axis=0)(x_tilde, self.x_src)
        d = self.discriminator_recon(x)
        d = expr.Reshape((batch_size*2, -1))(d)
        d_x_tilde, d_x = expr.Slices([batch_size])(d)
        loss += self.recon_error(d_x_tilde, d_x)
    else:
        loss += self.recon_error(x_tilde, self.x_src)

    # Kill gradient from GAN loss to AE encoder
    z = ScaleGradient(0.0)(z)

    # Decode for GAN loss
    gen_size = batch_size
    if self.sample_z:
        gen_size += batch_size
        z_samples = self.latent_encoder.samples(batch_size)
        z = expr.Concatenate(axis=0)(z, z_samples)
    x = self.decoder_neggrad(z)
    x = expr.Concatenate(axis=0)(self.x_src, x)

    # Scale gradients to balance real vs. generated contributions to the GAN
    # discriminator. The generated weight is spread over gen_size rows, so
    # the real and generated parts contribute in the ratio w : (1 - w).
    dis_batch_size = batch_size + gen_size
    real_weight = self.real_vs_gen_weight
    gen_weight = (1-self.real_vs_gen_weight) * float(batch_size)/gen_size
    weights = np.zeros((dis_batch_size, 1))
    weights[:batch_size] = real_weight
    weights[batch_size:] = gen_weight
    dis_weights = ca.array(weights)
    # The discriminator output gradient is scaled by the weights, and the
    # gradient leaving the discriminator input is scaled by their inverse.
    # Presumably this confines the weighting to the discriminator itself.
    weight_shape = (dis_batch_size,) + (1,) * (len(x_shape) - 1)
    dis_weights_inv = ca.array(1.0 / np.reshape(weights, weight_shape))
    x = ScaleGradient(dis_weights_inv)(x)

    # Discriminate
    d = self.discriminator(x)
    d = ScaleGradient(dis_weights)(d)
    # d*sign + offset equals d on real rows and 1 - d on generated rows.
    sign = np.ones((dis_batch_size, 1), dtype=ca.float_)
    sign[batch_size:] = -1.0
    offset = np.zeros_like(sign)
    offset[batch_size:] = 1.0
    self.gan_loss = expr.log(d*sign + offset + self.eps)
    self._graph = expr.ExprGraph(expr.sum(loss) + expr.sum(-self.gan_loss))
    self._graph.out_grad = ca.array(1.0)
    self._graph.setup()
def setup(self, x_shape):
    """Build the GAN training graph.

    The discriminator sees the real batch ``self.x_src`` followed by an
    equally sized batch produced by ``self.generator`` from
    ``expr.random.normal`` noise of width ``self.n_hidden``. The compiled
    objective is ``-sum(self.gan_loss)``, where ``self.gan_loss`` is
    ``log(d + eps)`` on real rows and ``log(1 - d + eps)`` on generated rows.

    When ``self.real_vs_gen_weight`` differs from 0.5, the real rows are
    weighted by ``w`` and the generated rows by ``1 - w`` in the
    discriminator's gradient.

    Parameters
    ----------
    x_shape : sequence of int
        Shape of an input batch; ``x_shape[0]`` is the batch size.

    Raises
    ------
    ValueError
        If ``self.real_vs_gen_weight`` is not strictly between 0 and 1.
        The reweighting divides by the per-row weights, so 0 or 1 would
        give infinite scale factors.
    """
    if not 0.0 < self.real_vs_gen_weight < 1.0:
        raise ValueError(
            'real_vs_gen_weight must be strictly between 0 and 1, got %r'
            % (self.real_vs_gen_weight,))
    batch_size = x_shape[0]
    self.x_src = expr.Source(x_shape)
    z = expr.random.normal(size=(batch_size, self.n_hidden))
    x_tilde = self.generator(z)
    x = expr.Concatenate(axis=0)(self.x_src, x_tilde)
    dis_batch_size = batch_size*2
    # At exactly 0.5 both halves are weighted equally and reweighting is
    # skipped.
    reweight = self.real_vs_gen_weight != 0.5
    if reweight:
        # Scale gradients to balance real vs. generated contributions to
        # GAN discriminator
        weights = np.zeros((dis_batch_size, 1))
        weights[:batch_size] = self.real_vs_gen_weight
        weights[batch_size:] = (1-self.real_vs_gen_weight)
        dis_weights = ca.array(weights)
        # Inverse weights, broadcastable over the per-sample axes of x.
        weight_shape = (dis_batch_size,) + (1,) * (len(x_shape) - 1)
        dis_weights_inv = ca.array(1.0 / np.reshape(weights, weight_shape))
        x = ScaleGradient(dis_weights_inv)(x)

    # Discriminate
    d = self.discriminator(x)
    if reweight:
        d = ScaleGradient(dis_weights)(d)
    # d*sign + offset equals d on real rows and 1 - d on generated rows.
    sign = np.ones((dis_batch_size, 1), dtype=ca.float_)
    sign[batch_size:] = -1.0
    offset = np.zeros_like(sign)
    offset[batch_size:] = 1.0
    self.gan_loss = expr.log(d*sign + offset + self.eps)
    self._graph = expr.ExprGraph(-expr.sum(self.gan_loss))
    self._graph.out_grad = ca.array(1.0)
    self._graph.setup()
def setup(self, x_shape):
    """Build the GAN training graph.

    The discriminator sees the real batch ``self.x_src`` followed by an
    equally sized batch produced by ``self.generator`` from
    ``ex.random.normal`` noise of width ``self.n_hidden``.
    ``self.loss = sum(self.gan_loss)``, where ``self.gan_loss`` is
    ``log(d + eps)`` on real rows and ``log(1 - d + eps)`` on generated rows.
    The backward pass is seeded with a gradient of -1.0, i.e. the gradients
    computed are those of ``-sum(self.gan_loss)``.

    When ``self.real_vs_gen_weight`` differs from 0.5, the real rows are
    weighted by ``w`` and the generated rows by ``1 - w`` in the
    discriminator's gradient.

    Parameters
    ----------
    x_shape : sequence of int
        Shape of an input batch; ``x_shape[0]`` is the batch size.

    Raises
    ------
    ValueError
        If ``self.real_vs_gen_weight`` is not strictly between 0 and 1.
        The reweighting divides by the per-row weights, so 0 or 1 would
        give infinite scale factors.
    """
    if not 0.0 < self.real_vs_gen_weight < 1.0:
        raise ValueError(
            'real_vs_gen_weight must be strictly between 0 and 1, got %r'
            % (self.real_vs_gen_weight,))
    batch_size = x_shape[0]
    self.x_src = ex.Source(x_shape)
    z = ex.random.normal(size=(batch_size, self.n_hidden))
    x_tilde = self.generator(z)
    x = ex.Concatenate(axis=0)(self.x_src, x_tilde)
    dis_batch_size = batch_size * 2
    # At exactly 0.5 both halves are weighted equally and reweighting is
    # skipped.
    reweight = self.real_vs_gen_weight != 0.5
    if reweight:
        # Scale gradients to balance real vs. generated contributions to
        # GAN discriminator
        weights = np.zeros((dis_batch_size, 1))
        weights[:batch_size] = self.real_vs_gen_weight
        weights[batch_size:] = (1 - self.real_vs_gen_weight)
        dis_weights = ca.array(weights)
        # Inverse weights, broadcastable over the per-sample axes of x.
        weight_shape = (dis_batch_size,) + (1,) * (len(x_shape) - 1)
        dis_weights_inv = ca.array(1.0 / np.reshape(weights, weight_shape))
        x = ScaleGradient(dis_weights_inv)(x)

    # Discriminate
    d = self.discriminator(x)
    if reweight:
        d = ScaleGradient(dis_weights)(d)
    # d * sign + offset equals d on real rows and 1 - d on generated rows.
    sign = np.ones((dis_batch_size, 1), dtype=ca.float_)
    sign[batch_size:] = -1.0
    offset = np.zeros_like(sign)
    offset[batch_size:] = 1.0
    self.gan_loss = ex.log(d * sign + offset + self.eps)
    self.loss = ex.sum(self.gan_loss)
    self._graph = ex.graph.ExprGraph(self.loss)
    self._graph.setup()
    self.loss.grad_array = ca.array(-1.0)
def setup(self, x_shape, y_shape):
    """Build the training graph for the conditional VAE / GAN / VAE-GAN.

    ``self.mode`` selects which parts of the model are built:

    * ``'vae'``: encoder, sampler and generator. The loss is
      ``0.5 * self.kld + sum(self.logpxz)``.
    * ``'gan'``: generator fed with ``self.sampler.samples()``, plus the
      discriminator. The loss is ``-sum(self.gan_loss)``.
    * ``'vaegan'``: both of the above. The GAN branch decodes the encoded
      ``z`` (with its gradient to the encoder killed) together with
      ``z_eps`` through ``self.generator_neg``. The loss is the VAE loss
      plus ``-sum(self.gan_loss)``.

    ``self.gan_loss`` is ``log(d)`` on real rows and ``log(1 - d)`` on
    generated rows, with ``d`` clipped to ``[eps, 1 - eps]``.

    Parameters
    ----------
    x_shape : sequence of int
        Shape of an input batch; ``x_shape[0]`` is the batch size.
    y_shape : sequence of int
        Shape of the conditioning batch fed through ``self.y_src``.

    Raises
    ------
    ValueError
        If ``self.mode`` is not ``'vae'``, ``'gan'`` or ``'vaegan'``.
    """
    # Without this check an unknown mode leaves `loss` unassigned, and the
    # method fails with an UnboundLocalError when the graph is created.
    if self.mode not in ('vae', 'gan', 'vaegan'):
        raise ValueError("mode must be 'vae', 'gan' or 'vaegan', got %r"
                         % (self.mode,))
    batch_size = x_shape[0]
    self.sampler.batch_size = batch_size
    self.x_src = expr.Source(x_shape)
    self.y_src = expr.Source(y_shape)
    if self.mode in ('vae', 'vaegan'):
        h_enc = self.encoder(self.x_src, self.y_src)
        z, z_mu, z_log_sigma, z_eps = self.sampler(h_enc)
        self.kld = KLDivergence()(z_mu, z_log_sigma)
        x_tilde = self.generator(z, self.y_src)
        self.logpxz = self.reconstruct_error(x_tilde, self.x_src)
        loss = 0.5*self.kld + expr.sum(self.logpxz)
    if self.mode in ('gan', 'vaegan'):
        # y holds the conditioning labels of the generated rows.
        y = self.y_src
        if self.mode == 'gan':
            z = self.sampler.samples()
            x_tilde = self.generator(z, y)
            gen_size = batch_size
        else:  # 'vaegan'
            # Kill gradient from the GAN loss to the encoder.
            z = ScaleGradient(0.0)(z)
            z = expr.Concatenate(axis=0)(z, z_eps)
            y = expr.Concatenate(axis=0)(y, self.y_src)
            x_tilde = self.generator_neg(z, y)
            gen_size = batch_size*2
        # Real rows first, then generated rows. The labels are concatenated
        # in the same order, so each image is paired with its own condition.
        x = expr.Concatenate(axis=0)(self.x_src, x_tilde)
        y = expr.Concatenate(axis=0)(self.y_src, y)
        d = self.discriminator(x, y)
        d = expr.clip(d, self.eps, 1.0-self.eps)
        real_size = batch_size
        # d*sign + offset equals d on real rows and 1 - d on generated rows.
        sign = np.ones((real_size + gen_size, 1), dtype=ca.float_)
        sign[real_size:] = -1.0
        offset = np.zeros_like(sign)
        offset[real_size:] = 1.0
        self.gan_loss = expr.log(d*sign + offset)
        if self.mode == 'gan':
            loss = expr.sum(-self.gan_loss)
        else:  # 'vaegan'
            loss = loss + expr.sum(-self.gan_loss)
    self._graph = expr.ExprGraph(loss)
    self._graph.out_grad = ca.array(1.0)
    self._graph.setup()
def setup(self, x_shape):
    """Compile the VAE objective graph for input batches of ``x_shape``.

    Wires ``encoder -> sampler -> decoder`` on a fresh ``expr.Source``. The
    graph is stored in ``self._lowerbound_graph``, seeded with an output
    gradient of 1.0. Despite the attribute name, the graph's output is
    ``KLDivergence()(z_mu, z_log_sigma) + sum(reconstruct_error(x_tilde,
    x))``.
    """
    self.sampler.batch_size = x_shape[0]
    self.x_src = expr.Source(x_shape)
    # The sampler yields the latent sample plus the two statistics consumed
    # by the KL term.
    z, mu, log_sigma = self.sampler(self.encoder(self.x_src))
    kl_term = KLDivergence()(mu, log_sigma)
    recon_term = expr.sum(
        self.reconstruct_error(self.decoder(z), self.x_src))
    graph = self._lowerbound_graph = expr.ExprGraph(kl_term + recon_term)
    graph.out_grad = ca.array(1.0)
    graph.setup()
def setup(self, x_shape):
    """Compile the VAE objective graph for input batches of ``x_shape``.

    Wires ``encoder -> sampler -> decoder`` on a fresh ``expr.Source``.
    ``self.lowerbound`` is set to ``KLDivergence()(z_mu, z_log_sigma) +
    sum(reconstruct_error(x_tilde, x))`` and wrapped in ``self._graph``.
    The backward pass is seeded with a gradient of 1.0 on
    ``self.lowerbound``.
    """
    self.sampler.batch_size = x_shape[0]
    self.x_src = expr.Source(x_shape)
    latent, mu, log_sigma = self.sampler(self.encoder(self.x_src))
    kl_term = KLDivergence()(mu, log_sigma)
    x_recon = self.decoder(latent)
    recon_err = self.reconstruct_error(x_recon, self.x_src)
    self.lowerbound = kl_term + expr.sum(recon_err)
    graph = expr.graph.ExprGraph(self.lowerbound)
    self._graph = graph
    graph.setup()
    self.lowerbound.grad_array = ca.array(1.0)
def encode(self, h_enc, batch_size):
    """Encode ``h_enc`` and build the adversarial loss on the latent code.

    ``self.discriminator`` is fed ``batch_size`` draws from ``self.samples``
    (first half), followed by the ``batch_size`` encoded codes (second
    half). The encoded codes pass through ``ScaleGradient(-1.0)`` first;
    presumably this is a gradient reversal, so the encoder is pushed
    against the discriminator.

    Returns
    -------
    z, loss
        ``z`` is the encoded code wrapped in
        ``ScaleGradient(self.recon_weight)``. ``loss`` is
        ``-sum(log(d_prior + eps)) - sum(log(1 - d_encoded + eps))``.
    """
    z = self.z_enc(h_enc)
    z_adv = ScaleGradient(-1.0)(z)
    prior_z = self.samples(batch_size)
    d_z = self.discriminator(ex.Concatenate(axis=0)(prior_z, z_adv))
    n_rows = 2 * batch_size
    labels = np.ones((n_rows, 1), dtype=ca.float_)
    labels[batch_size:] = -1.0
    shift = np.zeros_like(labels)
    shift[batch_size:] = 1.0
    # d_z * labels + shift is d_z on prior rows and 1 - d_z on encoded rows.
    log_lik = ex.log(d_z * labels + shift + self.eps)
    loss = ex.sum(-log_lik)
    return ScaleGradient(self.recon_weight)(z), loss
def encode(self, h_enc, batch_size):
    """Encode ``h_enc`` and build the adversarial loss on the latent code.

    The discriminator input is ``batch_size`` samples from ``self.samples``
    stacked on top of the ``batch_size`` encoded codes. The encoded codes
    first go through ``ScaleGradient(-1.0)``, which is presumably a
    gradient reversal towards the encoder.

    Returns the pair ``(code, loss)``:

    * ``code`` is the encoding wrapped in
      ``ScaleGradient(self.recon_weight)``.
    * ``loss`` is ``sum(-log(d_prior + eps)) + sum(-log(1 - d_encoded +
      eps))``.
    """
    code = self.z_enc(h_enc)
    reversed_code = ScaleGradient(-1.0)(code)
    disc_in = ex.Concatenate(axis=0)(self.samples(batch_size), reversed_code)
    disc_out = self.discriminator(disc_in)
    # Prior rows keep disc_out; encoded rows are mapped to 1 - disc_out.
    polarity = np.ones((batch_size*2, 1), dtype=ca.float_)
    polarity[batch_size:] = -1.0
    bias = np.zeros_like(polarity)
    bias[batch_size:] = 1.0
    neg_log = -ex.log(disc_out*polarity + bias + self.eps)
    loss = ex.sum(neg_log)
    code = ScaleGradient(self.recon_weight)(code)
    return code, loss
def setup(self, x_shape):
    """Build the training graph for the adversarial autoencoder + GAN model.

    The compiled objective is ``self.loss = sum(loss) - sum(self.gan_loss)``.

    * ``loss`` accumulates the latent encoder's loss and the reconstruction
      error.
    * ``self.gan_loss`` is ``log(d + eps)`` on real rows and
      ``log(1 - d + eps)`` on generated rows.

    The generated rows are the decoded encoder codes (if
    ``self.discriminate_ae_recon``), followed by decoded samples from
    ``self.latent_encoder.samples`` (if ``self.discriminate_sample_z``).

    Parameters
    ----------
    x_shape : sequence of int
        Shape of an input batch; ``x_shape[0]`` is the batch size.

    Raises
    ------
    ValueError
        If no generated samples would reach the GAN: neither
        ``discriminate_ae_recon`` nor ``discriminate_sample_z`` is set, or
        the batch is empty.
    ValueError
        If ``self.real_vs_gen_weight`` is not strictly between 0 and 1.
        The gradient balancing divides by the per-row weights.
    """
    batch_size = x_shape[0]
    # Validate the configuration before building any graph nodes.
    n_gen_sources = (int(bool(self.discriminate_ae_recon)) +
                     int(bool(self.discriminate_sample_z)))
    gen_size = n_gen_sources * batch_size
    if gen_size == 0:
        raise ValueError('GAN does not receive any generated samples.')
    if not 0.0 < self.real_vs_gen_weight < 1.0:
        raise ValueError(
            'real_vs_gen_weight must be strictly between 0 and 1, got %r'
            % (self.real_vs_gen_weight,))

    self.x_src = ex.Source(x_shape)
    loss = 0

    # Encode
    enc = self.encoder(self.x_src)
    z, self.encoder_loss = self.latent_encoder.encode(enc, batch_size)
    loss += self.encoder_loss

    # Decode
    x_tilde = self.decoder(z)
    if self.recon_depth > 0:
        # Reconstruction error in discriminator feature space: run the
        # reconstructions and the originals through discriminator_recon as
        # one batch, then split the result back into the two halves.
        x = ex.Concatenate(axis=0)(x_tilde, self.x_src)
        d = self.discriminator_recon(x)
        d_x_tilde, d_x = ex.Slices([batch_size])(d)
        loss += self.recon_error(d_x_tilde, d_x)
    else:
        loss += self.recon_error(x_tilde, self.x_src)

    # Decode for GAN loss
    if self.discriminate_ae_recon:
        # Kill gradient from GAN loss to AE encoder. One kill is enough; the
        # encoded z only reaches the GAN through this branch.
        z = ScaleGradient(0.0)(z)
    if self.discriminate_sample_z:
        z_samples = self.latent_encoder.samples(batch_size)
        if self.discriminate_ae_recon:
            z = ex.Concatenate(axis=0)(z, z_samples)
        else:
            z = z_samples
    x = self.decoder_neggrad(z)
    x = ex.Concatenate(axis=0)(self.x_src, x)

    # Scale gradients to balance real vs. generated contributions to the GAN
    # discriminator. The generated weight is spread over gen_size rows, so
    # the real and generated parts contribute in the ratio w : (1 - w).
    dis_batch_size = batch_size + gen_size
    real_weight = self.real_vs_gen_weight
    gen_weight = (1 - self.real_vs_gen_weight) * float(batch_size) / gen_size
    weights = np.zeros((dis_batch_size, ))
    weights[:batch_size] = real_weight
    weights[batch_size:] = gen_weight
    dis_weights = ca.array(weights)
    # The discriminator output gradient is scaled by the weights, and the
    # gradient leaving the discriminator input is scaled by their inverse.
    # Presumably this confines the weighting to the discriminator itself.
    weight_shape = (dis_batch_size, ) + (1, ) * (len(x_shape) - 1)
    dis_weights_inv = ca.array(1.0 / np.reshape(weights, weight_shape))
    x = ScaleGradient(dis_weights_inv)(x)

    # Discriminate
    d = self.discriminator(x)
    d = ex.Reshape((-1, ))(d)
    d = ScaleGradient(dis_weights)(d)
    # d * sign + offset equals d on real rows and 1 - d on generated rows.
    sign = np.ones((dis_batch_size, ), dtype=ca.float_)
    sign[batch_size:] = -1.0
    offset = np.zeros_like(sign)
    offset[batch_size:] = 1.0
    self.gan_loss = ex.log(d * sign + offset + self.eps)
    self.loss = ex.sum(loss) - ex.sum(self.gan_loss)
    self._graph = ex.graph.ExprGraph(self.loss)
    self._graph.setup()
    self.loss.grad_array = ca.array(1.0)