def __loss(self):
    """Build the loss terms for the mixture-of-logistics model.

    Sets on ``self``:
        loss_nll: mean of the mix_logistic_loss output (computed with
            ``output_mean=False`` and averaged here).
        bits_per_dim: mean of ``bits_per_dim_tf`` over the unreduced NLL.
        loss_reg: regularizer chosen by ``self.reg_type``.
        loss: ``loss_nll + loss_reg``.
    Depending on the branch, also sets ``kld``, ``mmd`` or ``mi``/``tc``/``dwkld``.

    Supported ``self.reg_type`` values:
        None  -- no regularizer (``loss_reg = 0``).
        'kld' -- ``beta * max(lam, KLD(z_mu, z_log_sigma_sq))``.
        'mmd' -- ``beta * max(lam, MMD(N(0, I) samples, z))``.
        'tc'  -- ``mi + beta * tc + dwkld``.

    Raises:
        ValueError: if ``self.reg_type`` is not one of the values above.
            Previously this fell through and left ``self.loss_reg`` unset, or
            stale from an earlier call.
    """
    print("****** Compute Loss ******")
    # output_mean=False keeps the unreduced NLL so bits_per_dim can use it
    # before the final reduce_mean below.
    self.loss_nll = mix_logistic_loss(
        self.x, self.mix_logistic_params, masks=self.masks,
        output_mean=False)
    # Hard-coded floor for the 'kld'/'mmd' penalties; note this overwrites
    # any lam value assigned before this method runs.
    self.lam = 0.0
    if self.reg_type is None:
        self.loss_reg = 0
    elif self.reg_type == 'kld':
        self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
    elif self.reg_type == 'mmd':
        # Compare 256 draws from a standard normal of width z_dim against z.
        self.mmd = estimate_mmd(
            tf.random_normal(tf.stack([256, self.z_dim])), self.z)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
    elif self.reg_type == 'tc':
        # N=10000 is passed as a fixed constant here.
        # NOTE(review): presumably the dataset size — confirm.
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=10000)
        self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
    else:
        raise ValueError("unknown reg_type: {!r}".format(self.reg_type))
    # dim counts unmasked positions (sum of 1 - masks over axes 1 and 2)
    # times 3. NOTE(review): assumes masks is (batch, H, W) and 3 image
    # channels — confirm.
    self.bits_per_dim = tf.reduce_mean(
        bits_per_dim_tf(
            nll=self.loss_nll,
            dim=tf.reduce_sum(1 - self.masks, axis=[1, 2]) * 3))
    self.loss_nll = tf.reduce_mean(self.loss_nll)
    self.loss = self.loss_nll + self.loss_reg
def __loss(self, reg):
    """Build the autoencoder loss with the regularizer selected by ``reg``.

    Sets ``self.loss_ae`` (mix_logistic_loss reconstruction term),
    ``self.loss_reg`` and ``self.loss = self.loss_ae + self.loss_reg``.
    Diagnostic terms (mmd, kld, mi, tc, dwkld, dwmmd, mmdtc) are reset to
    None and filled in only by the branches that compute them.

    Args:
        reg: one of
            None        -- no regularizer.
            'kld'       -- ``beta * max(lam, KLD)``.
            'mmd'       -- ``beta * max(lam, MMD)`` with 256 prior samples.
            'tc'        -- ``mi + beta * tc + dwkld``.
            'info-tc'   -- ``beta * tc + dwkld`` (mi computed but not used).
            'tc-dwmmd'  -- ``beta * tc + gamma * dimension-wise MMD``.
            'mmd-tc'    -- ``(mmd + beta * mmdtc) * 1e5``; also computes mi.

    Raises:
        ValueError: if ``reg`` is not one of the values above. Previously this
            left ``self.loss_reg`` unset, or stale from an earlier call.
    """
    print("****** Compute Loss ******")
    self.mmd, self.kld, self.mi, self.tc, self.dwkld = [None] * 5
    self.gamma, self.dwmmd = 1e3, None  ## hard coded, experimental
    self.mmdtc = None
    self.loss_ae = mix_logistic_loss(
        self.x, self.mix_logistic_params, masks=self.masks)
    if reg is None:
        self.loss_reg = 0
    elif reg == 'kld':
        self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
    elif reg == 'mmd':
        # Compare 256 draws from a standard normal of width z_dim against z.
        self.mmd = estimate_mmd(
            tf.random_normal(tf.stack([256, self.z_dim])), self.z)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
    elif reg == 'tc':
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
        self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
    elif reg == 'info-tc':
        # Same decomposition as 'tc' but the mi term is left out of the loss.
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
        self.loss_reg = self.beta * self.tc + self.dwkld
    elif reg == 'tc-dwmmd':
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
        # 'is_dimention_wise' (sic) is the keyword estimate_mmd expects.
        self.dwmmd = estimate_mmd(
            tf.random_normal(int_shape(self.z)), self.z,
            is_dimention_wise=True)
        self.loss_reg = self.beta * self.tc + self.dwmmd * self.gamma
    elif reg == 'mmd-tc':
        self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)), self.z)
        self.mmdtc = estimate_mmdtc(self.z, self.random_indices)
        # 1e5 is a hard-coded scale on the combined regularizer.
        self.loss_reg = (self.mmd + self.beta * self.mmdtc) * 1e5
        # mi is only tracked here; it does not enter loss_reg. Note the
        # fixed N=200000 rather than self.N.
        self.mi = estimate_mi(
            self.z, self.z_mu, self.z_log_sigma_sq, N=200000)
    else:
        raise ValueError("unknown reg: {!r}".format(reg))
    # Report the regularizer actually used (the argument), not self.reg.
    print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
    self.loss = self.loss_ae + self.loss_reg
def __loss(self, reg):
    """Build the autoencoder loss with an optional KLD or MMD regularizer.

    Sets ``self.loss_ae`` (mix_logistic_loss reconstruction term),
    ``self.loss_reg`` and ``self.loss = self.loss_ae + self.loss_reg``.
    ``self.mmd`` and ``self.kld`` are reset to None and only the one computed
    by the selected branch is filled in.

    Args:
        reg: None (no regularizer), 'kld' (``beta * max(lam, KLD)``) or
            'mmd' (``beta * max(lam, MMD)`` against 256 prior samples).

    Raises:
        ValueError: if ``reg`` is not one of the values above. Previously this
            left ``self.loss_reg`` unset, or stale from an earlier call.
    """
    print("****** Compute Loss ******")
    self.mmd, self.kld = [None, None]
    self.loss_ae = mix_logistic_loss(
        self.x, self.mix_logistic_params, masks=self.masks)
    if reg is None:
        self.loss_reg = 0
    elif reg == 'kld':
        self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
    elif reg == 'mmd':
        # Compare 256 draws from a standard normal of width z_dim against z.
        self.mmd = estimate_mmd(
            tf.random_normal(tf.stack([256, self.z_dim])), self.z)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
    else:
        raise ValueError("unknown reg: {!r}".format(reg))
    # Report the regularizer actually used (the argument), not self.reg.
    print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
    self.loss = self.loss_ae + self.loss_reg
def __loss(self, reg):
    """Build the Gaussian-reconstruction loss with the regularizer ``reg``.

    Sets ``self.loss_ae`` (``gaussian_recons_loss(self.x, self.x_hat)``),
    ``self.loss_reg`` and ``self.loss = self.loss_ae + self.loss_reg``.
    Diagnostic terms (mmd, kld, mi, tc, dwkld) are reset to None and filled
    in only by the branch that computes them.

    Args:
        reg: None (no regularizer), 'kld' (``beta * max(lam, KLD)``),
            'mmd' (``beta * max(lam, MMD)`` against prior samples shaped like
            z) or 'tc' (``mi + beta * tc + dwkld``).

    Raises:
        ValueError: if ``reg`` is not one of the values above. Previously this
            left ``self.loss_reg`` unset, or stale from an earlier call.
    """
    print("****** Compute Loss ******")
    self.mmd, self.kld, self.mi, self.tc, self.dwkld = [None] * 5
    self.loss_ae = gaussian_recons_loss(self.x, self.x_hat)
    if reg is None:
        self.loss_reg = 0
    elif reg == 'kld':
        self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
    elif reg == 'mmd':
        # Prior samples are drawn with the same static shape as z.
        self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)), self.z)
        self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
    elif reg == 'tc':
        self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
            self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
        self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
    else:
        raise ValueError("unknown reg: {!r}".format(reg))
    # Report the regularizer actually used (the argument), not self.reg.
    print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
    self.loss = self.loss_ae + self.loss_reg