Example #1
0
    def __call__(self, x):
        """Compute the importance-weighted (IWAE-style) loss for a batch.

        Draws ``self.num_zsamples`` samples z ~ q(z|x) and forms one log
        weight per sample,

            log p(x|z) + t * (log p(z) - log q(z|x)),

        with annealing temperature t = min(self.temperature['value'], 1.0).
        The weights are combined with log-mean-exp (logsumexp over the sample
        axis minus log(num_zsamples)).

        Args:
            x: Input batch; passed to ``self.encode`` and scored by
                ``bernoulli_logp`` against the decoder logits.

        Returns:
            ``-sum(obj_batch) / obj_batch.shape[0]``, i.e. the negated bound
            averaged over the leading (batch) axis. Also stored in
            ``self.obj``.

        Side effects:
            Sets ``self.importance_weights``, ``self.w_holder``, ``self.kl``,
            ``self.logp``, ``self.obj_batch``, ``self.obj`` and
            ``self.timing_info``, and advances ``self.temperature['value']``
            by ``self.temperature['increment']``.
        """
        # Obtain parameters for q(z|x); encode() is expected to populate
        # self.qmu and self.qln_var, which are read below.
        encoding_time = time.time()
        self.encode(x)
        encoding_time = float(time.time() - encoding_time)

        decoding_time_average = 0.

        # Not read in this method; presumably consumed by callers.
        self.importance_weights = 0
        self.w_holder = []
        self.kl = 0
        self.logp = 0

        # The temperature is only advanced after sampling, so it is the same
        # for every sample in this call; read it once.
        current_temperature = min(self.temperature['value'], 1.0)

        for _ in xrange(self.num_zsamples):
            # Sample z ~ q(z|x)
            z = F.gaussian(self.qmu, self.qln_var)

            # Compute log q(z|x)
            encoder_log = gaussian_logp(z, self.qmu, self.qln_var)

            # Obtain parameters for p(x|z); decode() is expected to populate
            # self.p_ber_prob_logit.
            decoding_time = time.time()
            self.decode(z)
            decoding_time = time.time() - decoding_time
            decoding_time_average += decoding_time

            # Compute log p(x|z)
            decoder_log = bernoulli_logp(x, self.p_ber_prob_logit)

            # Compute log p(z)
            prior_log = gaussian_logp0(z)

            # Store this sample's (tempered) log importance weight
            self.w_holder.append(decoder_log + current_temperature *
                                 (prior_log - encoder_log))

            # KL and log-likelihood estimates, recorded for reporting only
            # (they do not enter the objective).
            self.kl += (encoder_log - prior_log)
            self.logp += decoder_log

        self.temperature['value'] += self.temperature['increment']

        # log((1/K) * sum_k w_k), computed stably via logsumexp over the
        # stacked sample axis (axis 0).
        logps = F.stack(self.w_holder)
        self.obj_batch = F.logsumexp(logps, axis=0) - np.log(self.num_zsamples)
        self.kl /= self.num_zsamples
        self.logp /= self.num_zsamples

        decoding_time_average /= self.num_zsamples

        batch_size = self.obj_batch.shape[0]

        self.obj = -F.sum(self.obj_batch) / batch_size
        self.timing_info = np.array([encoding_time, decoding_time_average])

        return self.obj
Example #2
0
    def __call__(self, x):
        """Return the negative tempered ELBO averaged over the batch.

        ``self.encode`` fits q(z|x). Then ``self.num_zsamples`` latent draws
        are decoded, and the Monte Carlo means of log p(x|z) (``self.logp``)
        and log q(z|x) - log p(z) (``self.kl``) are combined per example as
        ``logp - t * kl``, where t = min(self.temperature['value'], 1.0) is
        read before the temperature is advanced.

        Side effects: sets ``self.kl``, ``self.logp``, ``self.obj_batch``,
        ``self.obj`` and ``self.timing_info``, and adds
        ``self.temperature['increment']`` to ``self.temperature['value']``.
        """
        enc_start = time.time()
        self.encode(x)
        encoding_time = float(time.time() - enc_start)

        n_samples = self.num_zsamples
        total_decode_seconds = 0.
        kl_sum = 0
        logp_sum = 0

        for _ in xrange(n_samples):
            # Reparameterised draw z ~ q(z|x) and its log density under q.
            z = F.gaussian(self.qmu, self.qln_var)
            log_q = gaussian_logp(z, self.qmu, self.qln_var)

            # decode() is expected to set self.p_ber_prob_logit.
            tic = time.time()
            self.decode(z)
            total_decode_seconds += time.time() - tic

            log_prior = gaussian_logp0(z)
            kl_sum += (log_q - log_prior)
            logp_sum += bernoulli_logp(x, self.p_ber_prob_logit)

        # Use the pre-increment temperature for this call, then advance it.
        schedule = self.temperature
        current_temperature = min(schedule['value'], 1.0)
        schedule['value'] += schedule['increment']

        decoding_time_average = total_decode_seconds / n_samples
        self.logp = logp_sum / n_samples
        self.kl = kl_sum / n_samples

        self.obj_batch = self.logp - (current_temperature * self.kl)
        self.timing_info = np.array([encoding_time, decoding_time_average])

        self.obj = -F.sum(self.obj_batch) / self.obj_batch.shape[0]
        return self.obj
Example #3
0
    def __call__(self, x):
        """Compute the negative tempered ELBO for a VAE with an IAF posterior.

        For each of ``self.num_zsamples`` samples, z_0 ~ q(z_0|x) is pushed
        through ``self.num_trans`` inverse autoregressive flow (IAF) steps.
        The sample's log q is log q(z_0|x) plus each step's ``delta_logq``
        (presumably yielding log q(z_T|x) -- depends on ``self.iaf``). The
        per-sample objective is

            log p(x|z_T) + t * (log p(z_T) - log q),

        with t = min(self.temperature['value'], 1.0), averaged over samples.

        Args:
            x: Input batch; passed to ``self.encode`` and scored by a Gaussian
                likelihood with parameters ``self.pmu`` / ``self.pln_var``.

        Returns:
            ``-sum(obj_batch) / obj_batch.shape[0]``; also stored in
            ``self.obj``.

        Side effects:
            Sets ``self.logp_xz``, ``self.logq``, ``self.logp``, ``self.kl``,
            ``self.obj_batch``, ``self.obj`` and ``self.timing_info``, and
            advances ``self.temperature['value']``.
        """
        # Obtain parameters for q(z|x)
        encoding_time = time.time()
        self.encode(x)
        encoding_time = float(time.time() - encoding_time)

        # Sums over samples of tempered log p(x, z) and tempered log q
        self.logp_xz = 0
        self.logq = 0

        # For reporting purposes only
        self.logp = 0
        self.kl = 0

        decoding_time_average = 0.

        current_temperature = min(self.temperature['value'], 1.0)
        self.temperature['value'] += self.temperature['increment']

        for _ in xrange(self.num_zsamples):
            # z_0 ~ q(z_0|x)
            z = F.gaussian(self.qmu, self.qln_var)

            # Note: the timed "decoding" span includes the IAF transforms.
            decoding_time = time.time()

            # log q for THIS sample: base density plus each IAF step's
            # correction.
            sample_logq = gaussian_logp(z, self.qmu, self.qln_var)
            for i in range(self.num_trans):
                a_layer_name = 'qiaf_a' + str(i + 1)
                b_layer_name = 'qiaf_b' + str(i + 1)
                z, delta_logq = self.iaf(z, self.qh, self[a_layer_name],
                                         self[b_layer_name])
                sample_logq += delta_logq

            # Temper only this sample's term. Scaling the running total here
            # would re-apply the temperature to earlier samples every
            # iteration.
            sample_logq *= current_temperature
            self.logq += sample_logq

            # Compute p(x|z)
            self.decode(z)

            decoding_time = time.time() - decoding_time
            decoding_time_average += decoding_time

            # Compute objective, p(x,z)
            logx_given_z = gaussian_logp(x, self.pmu, self.pln_var)  # p(x|z)
            logz = (current_temperature * gaussian_logp0(z))  # p(z)
            self.logp_xz += (logx_given_z + logz)

            # For reporting purposes only (per-sample terms)
            self.logp += logx_given_z
            self.kl += (sample_logq - logz)

        decoding_time_average /= self.num_zsamples

        # For reporting purposes only
        self.logp /= self.num_zsamples
        self.kl /= self.num_zsamples

        # Monte Carlo average of the per-sample ELBO
        self.obj_batch = self.logp_xz - self.logq
        self.obj_batch /= self.num_zsamples
        self.timing_info = np.array([encoding_time, decoding_time_average])

        batch_size = self.obj_batch.shape[0]
        self.obj = -F.sum(self.obj_batch) / batch_size

        return self.obj
Example #4
0
    def __call__(self, x):
        """Compute the negative tempered ELBO for an auxiliary-variable model.

        A single auxiliary sample a ~ q(a|x) is drawn. Then
        ``self.num_zsamples`` samples z ~ q(z|x, a) are used to average

            log p(a|x, z) + log p(x|z) + t * log p(z)
                - log q(a|x) - t * log q(z|x, a),

        with t = min(self.temperature['value'], 1.0).

        Args:
            x: Input batch; passed to ``self.encode_a``, ``self.encode_z`` and
                ``self.decode_a``, and scored by the Gaussian likelihood from
                ``self.decode``.

        Returns:
            ``-sum(obj_batch) / obj_batch.shape[0]``; also stored in
            ``self.obj``.

        Side effects:
            Sets ``self.obj_batch``, ``self.kl`` (tempered KL term),
            ``self.logp`` (mean log p(x|z)), ``self.obj`` and
            ``self.timing_info``, and advances ``self.temperature['value']``.
        """
        # Compute parameters for q(a|x) and draw one auxiliary sample
        encoding_time_1 = time.time()
        qmu_a, qln_var_a = self.encode_a(x)
        encoding_time_1 = float(time.time() - encoding_time_1)

        a_enc = F.gaussian(qmu_a, qln_var_a)

        # Compute parameters for q(z|x, a)
        encoding_time_2 = time.time()
        qmu_z, qln_var_z = self.encode_z(x, a_enc)
        encoding_time_2 = float(time.time() - encoding_time_2)

        encoding_time = encoding_time_1 + encoding_time_2

        decoding_time_average = 0.

        logp_a_xz = 0
        logp_x_z = 0
        logp_z = 0
        logq_z_ax = 0

        current_temperature = min(self.temperature['value'], 1.0)
        self.temperature['value'] += self.temperature['increment']

        # a_enc is drawn once per call, so log q(a|x) is identical for every
        # z sample below; evaluate it once.
        logq_a_x = gaussian_logp(a_enc, qmu_a, qln_var_a)

        for _ in xrange(self.num_zsamples):
            # z ~ q(z|x, a). Sample from the parameters encode_z just
            # returned, the same ones used to evaluate log q(z|x, a) below.
            z = F.gaussian(qmu_z, qln_var_z)

            # Compute parameters for p(a|x, z) and p(x|z)
            decoding_time = time.time()
            pmu_a, pln_var_a = self.decode_a(z, x)
            pmu_x, pln_var_x = self.decode(z)
            decoding_time = time.time() - decoding_time
            decoding_time_average += decoding_time

            logp_a_xz += gaussian_logp(a_enc, pmu_a, pln_var_a)
            logp_x_z += gaussian_logp(x, pmu_x, pln_var_x)
            logp_z += current_temperature * gaussian_logp0(z)
            logq_z_ax += current_temperature * gaussian_logp(
                z, qmu_z, qln_var_z)

        logp_a_xz /= self.num_zsamples
        logp_x_z /= self.num_zsamples
        logp_z /= self.num_zsamples
        logq_z_ax /= self.num_zsamples

        decoding_time_average /= self.num_zsamples

        self.obj_batch = logp_a_xz + logp_x_z + logp_z - logq_a_x - logq_z_ax
        # Reported quantities: tempered KL term and mean log p(x|z)
        self.kl = logq_z_ax - logp_z
        self.logp = logp_x_z

        self.timing_info = np.array([encoding_time, decoding_time_average])

        batch_size = self.obj_batch.shape[0]

        self.obj = -F.sum(self.obj_batch) / batch_size

        return self.obj
Example #5
0
    def __call__(self, x):
        """Compute the tempered free energy for a planar-flow VAE.

        For each of ``self.num_zsamples`` samples, z_0 ~ q(z_0|x) is mapped
        through ``self.planar_flows`` to z_K and decoded. The per-sample free
        energy is

            q_prior_log - log p(x, z_K) - t * sum_i log(1 + <flow term i>),

        where the prior and q_0 terms are tempered by
        t = min(self.temperature['value'], 1.0). It is averaged over samples.

        Args:
            x: Input batch; passed to ``self.encode`` and scored by
                ``bernoulli_logp`` against the decoder logits.

        Returns:
            ``sum(obj_batch) / obj_batch.shape[0]`` (already a loss -- no
            negation); also stored in ``self.obj``.

        Side effects:
            Sets ``self.obj_batch``, ``self.obj``, ``self.kl`` (tempered KL
            estimate), ``self.logp`` (mean log p(x|z_K)) and
            ``self.timing_info``, and advances ``self.temperature['value']``.
        """
        # Compute q(z_0|x)
        encoding_time = time.time()
        qmu, qln_var = self.encode(x)
        encoding_time = float(time.time() - encoding_time)

        decoding_time_average = 0.
        free_energy = 0

        self.kl = 0
        self.logp = 0

        current_temperature = min(self.temperature['value'], 1.0)
        self.temperature['value'] += self.temperature['increment']

        for _ in xrange(self.num_zsamples):
            # Sample z_0 ~ q(z_0|x) from the parameters encode() returned
            z_0 = F.gaussian(qmu, qln_var)

            # Perform planar flow mappings, Equation (10)
            decoding_time = time.time()
            z_K = self.planar_flows(z_0)

            # Obtain parameters for p(x|z_K)
            p_ber_prob_logit = self.decode(z_K)
            decoding_time = time.time() - decoding_time
            decoding_time_average += decoding_time

            # Tempered log density of z_0.
            # NOTE(review): this is meant to be log q(z_0|x), but
            # gaussian_logp0 takes no mean/variance arguments, so it cannot
            # depend on qmu/qln_var, the parameters z_0 was drawn from.
            # Confirm whether gaussian_logp(z_0, qmu, qln_var) was intended.
            q_prior_log = current_temperature * gaussian_logp0(z_0)

            # Compute log p(x|z_K)
            decoder_log = bernoulli_logp(x, p_ber_prob_logit)
            # Compute tempered log p(z_K)
            p_prior_log = current_temperature * gaussian_logp0(z_K)

            # Compute log p(x, z_K) = log p(x|z_K) + log p(z_K)
            joint_log = decoder_log + p_prior_log

            # Tempered sum of the flows' log-det-Jacobian terms. self.phi is
            # presumably populated by planar_flows() for this sample.
            q_K_log = 0
            for i in range(self.num_trans):
                flow_u_name = 'flow_u_' + str(i)
                jacobian_term = F.sum(self[flow_u_name](self.phi[i]), axis=1)
                # NOTE(review): no absolute value is taken, so this is NaN
                # whenever 1 + jacobian_term <= 0 unless the flow
                # parameterisation rules that out.
                q_K_log += F.log(1 + jacobian_term)
            q_K_log *= current_temperature

            # Accumulate this sample's free energy. Previously only the last
            # sample's terms were used.
            free_energy += (q_prior_log - joint_log) - q_K_log

            # For recording purposes only. KL estimate:
            # log q_K(z_K) - log p(z_K) = q_prior_log - q_K_log - p_prior_log
            self.logp += decoder_log
            self.kl += (q_prior_log - q_K_log - p_prior_log)

        decoding_time_average /= self.num_zsamples

        # Monte Carlo average of the per-sample free energy
        self.obj_batch = free_energy / self.num_zsamples
        batch_size = self.obj_batch.shape[0]

        self.obj = F.sum(self.obj_batch) / batch_size

        self.kl /= self.num_zsamples
        self.logp /= self.num_zsamples

        # Report the average decoding time, not just the last sample's
        self.timing_info = np.array([encoding_time, decoding_time_average])

        return self.obj