Code example #1
File: iwae.py  Project: dementrock/iwae
    def gradIminibatch_srng(self, x, srng, num_samples, model_type='iwae'):
        # rep_x = T.extra_ops.repeat(x, num_samples, axis=0)
        rep_x = t_repeat(x, num_samples, axis=0)  # works marginally faster than theano's T.extra_ops.repeat
        q_samples = self.q_samplesIx_srng(rep_x, srng)

        log_ws = self.log_weightsIq_samples(q_samples)

        log_ws_matrix = log_ws.reshape((x.shape[0], num_samples))
        log_ws_minus_max = log_ws_matrix - T.max(log_ws_matrix, axis=1, keepdims=True)
        ws = T.exp(log_ws_minus_max)
        ws_normalized = ws / T.sum(ws, axis=1, keepdims=True)
        ws_normalized_vector = T.reshape(ws_normalized, log_ws.shape)

        dummy_vec = T.vector(dtype=theano.config.floatX)

        if model_type in ['vae', 'VAE']:
            print "Training a VAE"
            return collections.OrderedDict([(
                                             param,
                                             T.grad(T.sum(log_ws)/T.cast(num_samples, log_ws.dtype), param)
                                            )
                                            for param in self.params])
        else:
            print "Training an IWAE"
            return collections.OrderedDict([(
                                             param,
                                             theano.clone(
                                                T.grad(T.dot(log_ws, dummy_vec), param),
                                                replace={dummy_vec: ws_normalized_vector})
                                            )
                                            for param in self.params])
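
A note on the dummy_vec / theano.clone construction: T.grad(T.dot(log_ws, dummy_vec), param) builds the gradient as a vector-Jacobian product against an arbitrary vector, and theano.clone substitutes the normalized importance weights for that vector only after differentiation, so no gradient flows through the normalization. With the reparameterization trick this yields the IWAE gradient estimator of Burda et al. (2015):

    \nabla_\theta \mathcal{L}_k = \sum_{i=1}^{k} \tilde{w}_i \, \nabla_\theta \log w_i,
    \qquad \tilde{w}_i = \frac{w_i}{\sum_{j=1}^{k} w_j},

with the \tilde{w}_i held constant. Subtracting the per-row maximum before exponentiating (log_ws_minus_max) is the standard overflow guard; it cancels in the normalization.
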
Code example #2
File: iwae.py  Project: shafiahmed/iwae
    def gradIminibatch_srng(self, x, srng, num_samples, model_type='iwae'):
        # rep_x = T.extra_ops.repeat(x, num_samples, axis=0)
        rep_x = t_repeat(x, num_samples, axis=0)  # works marginally faster than theano's T.extra_ops.repeat
        q_samples = self.q_samplesIx_srng(rep_x, srng)

        log_ws = self.log_weightsIq_samples(q_samples)

        log_ws_matrix = log_ws.reshape((x.shape[0], num_samples))
        log_ws_minus_max = log_ws_matrix - T.max(
            log_ws_matrix, axis=1, keepdims=True)
        ws = T.exp(log_ws_minus_max)
        ws_normalized = ws / T.sum(ws, axis=1, keepdims=True)
        ws_normalized_vector = T.reshape(ws_normalized, log_ws.shape)

        dummy_vec = T.vector(dtype=theano.config.floatX)

        if model_type in ['vae', 'VAE']:
            print "Training a VAE"
            return collections.OrderedDict([(
                param,
                T.grad(
                    T.sum(log_ws) / T.cast(num_samples, log_ws.dtype), param),
            ) for param in self.params])
        else:
            print "Training an IWAE"
            return collections.OrderedDict([
                (param,
                 theano.clone(T.grad(T.dot(log_ws, dummy_vec), param),
                              replace={dummy_vec: ws_normalized_vector}))
                for param in self.params
            ])
Code example #3
    def gradIminibatch_srng(self,
                            x,
                            srng,
                            num_samples,
                            model_type='iwae',
                            backward_pass='******'):
        # rep_x = T.extra_ops.repeat(x, num_samples, axis=0)
        rep_x = t_repeat(x, num_samples, axis=0)  # works marginally faster than theano's T.extra_ops.repeat
        q_samples = self.q_samplesIx_srng(rep_x, srng)

        log_ws = self.log_weightsIq_samples(q_samples)

        log_ws_matrix = log_ws.reshape((x.shape[0], num_samples))

        # for alpha divergence (take 0 <= alpha <= 1);
        # see the math for why we can set alpha = 1 directly
        # under the reparameterization trick
        if backward_pass == 'full':
            log_ws_matrix *= (1.0 - self.alpha)

            log_ws_minus_max = log_ws_matrix - T.max(
                log_ws_matrix, axis=1, keepdims=True)
            ws = T.exp(log_ws_minus_max)
            ws_normalized = ws / T.sum(ws, axis=1, keepdims=True)
            ws_normalized_vector = T.reshape(ws_normalized, log_ws.shape)

            dummy_vec = T.vector(dtype=theano.config.floatX)

        else:
            # just take the particle that has the largest (unnormalised) weight
            # NOTE: might pick different particles for different data points!
            log_ws_max = log_ws_matrix.max(axis=1)

        if backward_pass == 'max':
            print "Training an AAE with largest particle"
            return collections.OrderedDict([
                (param,
                 T.grad(T.sum(log_ws_max) / T.cast(1, log_ws.dtype), param))
                for param in self.params
            ])
        elif model_type in ['vae', 'VAE']:
            print "Training a VAE"
            return collections.OrderedDict([
                (param,
                 T.grad(
                     T.sum(log_ws) / T.cast(num_samples, log_ws.dtype), param))
                for param in self.params
            ])
        else:
            # NB: this branch assumes backward_pass == 'full'; otherwise
            # dummy_vec and ws_normalized_vector are undefined (NameError)
            print("Training an AAE with alpha = %.2f, k = %d"
                  % (self.alpha, num_samples))
            return collections.OrderedDict([
                (param,
                 theano.clone(T.grad(T.dot(log_ws, dummy_vec), param),
                              replace={dummy_vec: ws_normalized_vector}))
                for param in self.params
            ])
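
For context on the 'full' branch: scaling the log-weights by (1 - alpha) before normalizing matches the reparameterized gradient of the variational Rényi bound (Li and Turner, 2016); this is a reading of the code, not a claim about the repo's documentation:

    \mathcal{L}_\alpha = \frac{1}{1-\alpha} \log \frac{1}{K} \sum_{i=1}^{K} w_i^{\,1-\alpha},
    \qquad \nabla_\theta \mathcal{L}_\alpha = \sum_{i=1}^{K} \frac{w_i^{\,1-\alpha}}{\sum_j w_j^{\,1-\alpha}} \, \nabla_\theta \log w_i.

Setting alpha = 0 recovers the IWAE estimator of examples #1 and #2; setting alpha = 1 makes the weights uniform, which is the VAE branch and is what the "set alpha = 1" comment alludes to. The 'max' branch instead backpropagates only through the largest-weight particle per data point, resembling the VR-max approximation.
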
Code example #4
    def get_aux_mult(self):
        a_first_row_unnorm = (self.digams_1_cumsum - self.digams_1p2_cumsum
                              + self.digams[:, 1]).reshape((1, self.K))

        a_first_row_unnorm_rep = t_repeat(a_first_row_unnorm, self.K,
                                          axis=0).reshape((self.K, self.K))

        a = T.exp(a_first_row_unnorm_rep) * T.tril(T.ones((self.K, self.K)))

        return a / T.sum(a, 1).reshape((self.K, 1))
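
The pattern in get_aux_mult — exponentiate a row of unnormalized log-weights, tile it, mask with T.tril, and row-normalize — produces a row-stochastic lower-triangular matrix: row k is a distribution supported on indices 0..k. A minimal NumPy sketch of just that pattern (tril_softmax is a hypothetical name, not from the repo):

    import numpy as np

    def tril_softmax(log_row, K):
        # tile the K unnormalized log-weights into a K x K matrix
        a = np.exp(np.repeat(log_row.reshape(1, K), K, axis=0))
        # zero out everything above the diagonal, then normalize each row
        a *= np.tril(np.ones((K, K)))
        return a / a.sum(axis=1, keepdims=True)
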
Code example #5
File: iwae.py  Project: dementrock/iwae
    def log_marginal_likelihood_estimate(self, x, num_samples, srng):
        num_xs = x.shape[0]
        # rep_x = T.extra_ops.repeat(x, num_samples, axis=0)
        rep_x = t_repeat(x, num_samples, axis=0)
        samples = self.q_samplesIx_srng(rep_x, srng)

        log_ws = self.log_weightsIq_samples(samples)
        log_ws_matrix = T.reshape(log_ws, (num_xs, num_samples))
        log_marginal_estimate = log_mean_exp(log_ws_matrix, axis=1)

        return log_marginal_estimate
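
The log_mean_exp helper is not shown in this snippet; what is being computed is the IWAE log-likelihood estimate in its numerically stabilized form,

    \log \hat{p}(x) = \log \frac{1}{k} \sum_{i=1}^{k} w_i
    = m + \log \frac{1}{k} \sum_{i=1}^{k} e^{\log w_i - m},
    \qquad m = \max_i \log w_i.

A sketch of such a helper in the same Theano style (an assumed definition, not necessarily the repo's):

    def log_mean_exp(x, axis):
        # subtract the per-row max so the exponentials cannot overflow
        m = T.max(x, axis=axis, keepdims=True)
        return T.max(x, axis=axis) + T.log(T.mean(T.exp(x - m), axis=axis))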