def __loss(self):
        """Build the training-loss tensors for the regulariser ``self.reg_type``.

        Attributes set on ``self``:
            loss_nll: ``mix_logistic_loss`` reconstruction term, first kept
                unreduced (``output_mean=False``) and finally replaced by its
                ``tf.reduce_mean``.
            lam: reset to ``0.0`` by this method (see note below).
            loss_reg: regularisation term; ``0`` when ``reg_type`` is None.
            kld / mmd / (mi, tc, dwkld): set only by the matching branch.
            bits_per_dim: ``tf.reduce_mean`` of ``bits_per_dim_tf``.
            loss: ``loss_nll + loss_reg``.

        Raises:
            ValueError: if ``self.reg_type`` is not None, 'kld', 'mmd' or
                'tc'. Without this check an unknown value would leave
                ``loss_reg`` unset (or stale from an earlier call).
        """
        print("******   Compute Loss   ******")
        if self.reg_type not in (None, 'kld', 'mmd', 'tc'):
            raise ValueError(
                "Unsupported reg_type: {0!r}".format(self.reg_type))

        # output_mean=False keeps the loss unreduced so it can be fed to
        # bits_per_dim_tf below before being averaged.
        self.loss_nll = mix_logistic_loss(self.x,
                                          self.mix_logistic_params,
                                          masks=self.masks,
                                          output_mean=False)

        # NOTE(review): this unconditionally overrides any lam configured
        # elsewhere, so the tf.maximum floors below only clamp at 0 --
        # confirm this is intended.
        self.lam = 0.0
        if self.reg_type is None:
            self.loss_reg = 0
        elif self.reg_type == 'kld':
            self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
        elif self.reg_type == 'mmd':
            # Compare z against 256 standard-normal draws of width z_dim.
            self.mmd = estimate_mmd(
                tf.random_normal(tf.stack([256, self.z_dim])), self.z)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
        else:  # 'tc' (validated above)
            # NOTE(review): N is hard-coded to 10000 here.
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=10000)
            self.loss_reg = self.mi + self.beta * self.tc + self.dwkld

        # dim = number of unmasked positions per example (sum of 1 - mask
        # over axes 1 and 2) times 3 -- presumably the colour-channel count;
        # assumes masks is (batch, H, W), TODO confirm.
        self.bits_per_dim = tf.reduce_mean(
            bits_per_dim_tf(nll=self.loss_nll,
                            dim=tf.reduce_sum(1 - self.masks, axis=[1, 2]) *
                            3))
        self.loss_nll = tf.reduce_mean(self.loss_nll)
        self.loss = self.loss_nll + self.loss_reg
    def __loss(self, reg):
        """Build the training-loss tensors for regulariser ``reg``.

        Args:
            reg: one of None, 'kld', 'mmd', 'tc', 'info-tc', 'tc-dwmmd',
                'mmd-tc'.

        Attributes set on ``self``:
            mmd, kld, mi, tc, dwkld, dwmmd, mmdtc: reset to None, then filled
                by the branch that uses them.
            gamma: hard-coded weight (1e3) on ``dwmmd`` in 'tc-dwmmd'.
            loss_ae: ``mix_logistic_loss`` reconstruction term.
            loss_reg: regularisation term; ``0`` when ``reg`` is None.
            mi: always overwritten at the end by ``estimate_mi(..., N=200000)``.
            loss: ``loss_ae + loss_reg``.

        Raises:
            ValueError: if ``reg`` is not a supported value. Without this
                check an unknown value would leave ``loss_reg`` unset (or
                stale from an earlier call).
        """
        print("******   Compute Loss   ******")
        if reg not in (None, 'kld', 'mmd', 'tc', 'info-tc', 'tc-dwmmd',
                       'mmd-tc'):
            raise ValueError("Unsupported reg: {0!r}".format(reg))

        self.mmd, self.kld, self.mi, self.tc, self.dwkld = [None] * 5
        self.gamma, self.dwmmd = 1e3, None  ## hard coded, experimental
        self.mmdtc = None
        self.loss_ae = mix_logistic_loss(self.x,
                                         self.mix_logistic_params,
                                         masks=self.masks)
        if reg is None:
            self.loss_reg = 0
        elif reg == 'kld':
            self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
        elif reg == 'mmd':
            # Compare z against 256 standard-normal draws of width z_dim.
            self.mmd = estimate_mmd(
                tf.random_normal(tf.stack([256, self.z_dim])), self.z)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
        elif reg == 'tc':
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
            self.loss_reg = self.mi + self.beta * self.tc + self.dwkld
        elif reg == 'info-tc':
            # Same decomposition as 'tc' but the mi term is dropped.
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
            self.loss_reg = self.beta * self.tc + self.dwkld
        elif reg == 'tc-dwmmd':
            self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
                self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
            # NOTE(review): tf.random_normal needs a fully-defined shape; this
            # assumes int_shape(self.z) has no None (static batch) -- confirm.
            self.dwmmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                      self.z,
                                      is_dimention_wise=True)
            self.loss_reg = self.beta * self.tc + self.dwmmd * self.gamma
        else:  # 'mmd-tc' (validated above)
            # NOTE(review): same static-shape assumption as above.
            self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                    self.z)
            self.mmdtc = estimate_mmdtc(self.z, self.random_indices)
            self.loss_reg = (self.mmd + self.beta * self.mmdtc) * 1e5

        # Overwrites any mi produced by the branches above (loss_reg already
        # holds the branch's own tensor, so only the self.mi attribute changes).
        self.mi = estimate_mi(self.z, self.z_mu, self.z_log_sigma_sq, N=200000)

        # Log the regulariser actually used to build the loss (the argument),
        # not self.reg, which the caller may not have set to the same value.
        print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
        self.loss = self.loss_ae + self.loss_reg
Example #3
0
    def __loss(self, reg):
        """Build the training-loss tensors for regulariser ``reg``.

        Args:
            reg: one of None, 'kld', 'mmd'.

        Attributes set on ``self``:
            mmd, kld: reset to None, then filled by the matching branch.
            loss_ae: ``mix_logistic_loss`` reconstruction term.
            loss_reg: regularisation term; ``0`` when ``reg`` is None.
            loss: ``loss_ae + loss_reg``.

        Raises:
            ValueError: if ``reg`` is not None, 'kld' or 'mmd'. Without this
                check an unknown value would leave ``loss_reg`` unset (or
                stale from an earlier call).
        """
        print("******   Compute Loss   ******")
        if reg not in (None, 'kld', 'mmd'):
            raise ValueError("Unsupported reg: {0!r}".format(reg))

        self.mmd, self.kld = None, None
        self.loss_ae = mix_logistic_loss(self.x,
                                         self.mix_logistic_params,
                                         masks=self.masks)
        if reg is None:
            self.loss_reg = 0
        elif reg == 'kld':
            self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
        else:  # 'mmd' (validated above)
            # Compare z against 256 standard-normal draws of width z_dim.
            self.mmd = estimate_mmd(
                tf.random_normal(tf.stack([256, self.z_dim])), self.z)
            self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)

        # Log the regulariser actually used to build the loss (the argument),
        # not self.reg, which the caller may not have set to the same value.
        print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
        self.loss = self.loss_ae + self.loss_reg
 def __loss(self, reg):
     """Build the training-loss tensors for regulariser ``reg``.

     Args:
         reg: one of None, 'kld', 'mmd', 'tc'.

     Attributes set on ``self``:
         mmd, kld, mi, tc, dwkld: reset to None, then filled by the
             matching branch.
         loss_ae: ``gaussian_recons_loss(self.x, self.x_hat)``.
         loss_reg: regularisation term; ``0`` when ``reg`` is None.
         loss: ``loss_ae + loss_reg``.

     Raises:
         ValueError: if ``reg`` is not None, 'kld', 'mmd' or 'tc'. Without
             this check an unknown value would leave ``loss_reg`` unset (or
             stale from an earlier call).
     """
     print("******   Compute Loss   ******")
     if reg not in (None, 'kld', 'mmd', 'tc'):
         raise ValueError("Unsupported reg: {0!r}".format(reg))

     self.mmd, self.kld, self.mi, self.tc, self.dwkld = [None] * 5
     self.loss_ae = gaussian_recons_loss(self.x, self.x_hat)
     if reg is None:
         self.loss_reg = 0
     elif reg == 'kld':
         self.kld = compute_gaussian_kld(self.z_mu, self.z_log_sigma_sq)
         self.loss_reg = self.beta * tf.maximum(self.lam, self.kld)
     elif reg == 'mmd':
         # NOTE(review): tf.random_normal needs a fully-defined shape; this
         # assumes int_shape(self.z) has no None (static batch) -- confirm.
         self.mmd = estimate_mmd(tf.random_normal(int_shape(self.z)),
                                 self.z)
         self.loss_reg = self.beta * tf.maximum(self.lam, self.mmd)
     else:  # 'tc' (validated above)
         self.mi, self.tc, self.dwkld = estimate_mi_tc_dwkld(
             self.z, self.z_mu, self.z_log_sigma_sq, N=self.N)
         self.loss_reg = self.mi + self.beta * self.tc + self.dwkld

     # Log the regulariser actually used to build the loss (the argument),
     # not self.reg, which the caller may not have set to the same value.
     print("reg:{0}, beta:{1}, lam:{2}".format(reg, self.beta, self.lam))
     self.loss = self.loss_ae + self.loss_reg