def stochastic_layer(layer_in, n, samples, nonlin=None):
    """Build a diagonal-Gaussian stochastic layer on top of ``layer_in``.

    Two parallel dense heads produce the mean and the log-variance, and a
    SampleLayer draws samples from them. Weights and biases of both heads are
    drawn from ``init.Normal(init_w)``; ``init_w`` is a free name and must be
    defined in the enclosing/global scope.

    :param layer_in: Incoming layer.
    :param n: Number of units of each head (dimensionality of the sample).
    :param samples: Value passed as ``eq_samples`` to SampleLayer.
    :param nonlin: Nonlinearity applied to BOTH heads (None = linear).
    :return: Tuple ``(sample_layer, mu_layer, logvar_layer)``.
    """
    def _head():
        # Both heads share the same initialisation scheme; calling this once
        # per head gives each its own parameters.
        return DenseLayer(layer_in,
                          num_units=n,
                          W=init.Normal(init_w),
                          b=init.Normal(init_w),
                          nonlinearity=nonlin)

    mu = _head()
    logvar = _head()
    sampled = SampleLayer(mu, logvar, eq_samples=samples, iw_samples=1)
    return sampled, mu, logvar
 def stochastic_layer(layer_in, n, samples, nonlin=None):
     """Return ``(sample_layer, mu_layer, logvar_layer)`` for a Gaussian head.

     The mean and log-variance are two independent dense layers over
     ``layer_in`` with ``n`` units each, weights drawn from
     ``init.Normal(init_w, mean=.0)`` and biases from ``init.Normal(init_w)``.
     ``nonlin`` is applied to both. ``samples`` is forwarded to SampleLayer as
     ``eq_samples``. ``init_w`` is a free name resolved from the outer scope.
     """
     # The generator yields the mean head first, then the log-variance head.
     mu, logvar = (
         DenseLayer(layer_in, n, init.Normal(init_w, mean=.0),
                    init.Normal(init_w), nonlin)
         for _ in range(2))
     return (SampleLayer(mu, logvar, eq_samples=samples, iw_samples=1),
             mu, logvar)
Esempio n. 3
0
    def __init__(self,
                 n_x,
                 n_a,
                 n_z,
                 n_y,
                 a_hidden,
                 z_hidden,
                 xhat_hidden,
                 y_hidden,
                 trans_func=rectify,
                 x_dist='bernoulli'):
        """
        Initialize an auxiliary deep generative model consisting of
        discriminative classifier q(y|a,x),
        generative model P p(xhat|z,y),
        inference model Q q(a|x) and q(z|x,y).
        Weights use the Bengio and Glorot (2010) initialization scheme;
        biases are drawn from init.Normal(1e-3).
        :param n_x: Number of inputs.
        :param n_a: Number of auxiliary units.
        :param n_z: Number of latent units.
        :param n_y: Number of classes.
        :param a_hidden: List of number of deterministic hidden q(a|x) (may be empty).
        :param z_hidden: List of number of deterministic hidden q(z|x,y) (must be non-empty).
        :param xhat_hidden: List of number of deterministic hidden p(xhat|z,y) (must be non-empty).
        :param y_hidden: List of number of deterministic hidden q(y|a,x) (must be non-empty).
        :param trans_func: The transfer function used in the deterministic layers.
        :param x_dist: The x distribution, 'bernoulli', 'multinomial' or 'gaussian'.
        :raises ValueError: If x_dist is not one of the supported distributions.
        """
        # Fail fast: an unknown x_dist would otherwise only surface later as
        # an unbound local when the output layers are collected.
        if x_dist not in ('bernoulli', 'multinomial', 'gaussian'):
            raise ValueError(
                "x_dist must be 'bernoulli', 'multinomial' or 'gaussian', "
                "got %r" % (x_dist,))

        super(ADGMSSL, self).__init__(n_x, a_hidden + z_hidden + xhat_hidden,
                                      n_a + n_z, trans_func)
        self.y_hidden = y_hidden
        self.x_dist = x_dist
        self.n_y = n_y
        self.n_x = n_x
        self.n_a = n_a
        self.n_z = n_z

        self._srng = RandomStreams()

        # Symbolic upscaling of the discriminative term.
        self.sym_beta = T.scalar('beta')
        self.sym_x_l = T.matrix('x')  # symbolic labeled inputs
        self.sym_t_l = T.matrix('t')  # symbolic labeled targets
        self.sym_x_u = T.matrix('x')  # symbolic unlabeled inputs
        # Symbolic number of labeled data points in a batch.
        self.sym_bs_l = T.iscalar('bs_l')
        # Symbolic number of Monte Carlo samples.
        self.sym_samples = T.iscalar('samples')
        self.sym_y = T.matrix('y')
        self.sym_z = T.matrix('z')

        ### Input layers ###
        l_x_in = InputLayer((None, n_x))
        l_y_in = InputLayer((None, n_y))

        ### Auxiliary q(a|x) ###
        l_a_x = l_x_in
        for hid in a_hidden:
            l_a_x = DenseLayer(l_a_x, hid, init.GlorotNormal('relu'),
                               init.Normal(1e-3), self.transf)
        l_a_x_mu = DenseLayer(l_a_x, n_a, init.GlorotNormal(),
                              init.Normal(1e-3), None)
        l_a_x_logvar = DenseLayer(l_a_x, n_a, init.GlorotNormal(),
                                  init.Normal(1e-3), None)
        l_a_x = SampleLayer(l_a_x_mu,
                            l_a_x_logvar,
                            eq_samples=self.sym_samples)
        # Reshape all layers to align them for multiple samples in the lower bound calculation.
        l_a_x_reshaped = ReshapeLayer(l_a_x, (-1, self.sym_samples, 1, n_a))
        l_a_x_mu_reshaped = DimshuffleLayer(l_a_x_mu, (0, 'x', 'x', 1))
        l_a_x_logvar_reshaped = DimshuffleLayer(l_a_x_logvar, (0, 'x', 'x', 1))

        ### Classifier q(y|a,x) ###
        # Project a and x separately and sum the projections. This is
        # equivalent to one dense layer over the concatenation [a, x]; the x
        # projection (dimshuffled with 'x' axes) broadcasts over the samples.
        l_a_to_y = DenseLayer(l_a_x, y_hidden[0], init.GlorotNormal('relu'),
                              init.Normal(1e-3), None)
        l_a_to_y = ReshapeLayer(l_a_to_y,
                                (-1, self.sym_samples, 1, y_hidden[0]))
        l_x_to_y = DenseLayer(l_x_in, y_hidden[0], init.GlorotNormal('relu'),
                              init.Normal(1e-3), None)
        l_x_to_y = DimshuffleLayer(l_x_to_y, (0, 'x', 'x', 1))
        l_y_xa = ReshapeLayer(ElemwiseSumLayer([l_a_to_y, l_x_to_y]),
                              (-1, y_hidden[0]))
        l_y_xa = NonlinearityLayer(l_y_xa, self.transf)

        for hid in y_hidden[1:]:
            l_y_xa = DenseLayer(l_y_xa, hid, init.GlorotUniform('relu'),
                                init.Normal(1e-3), self.transf)
        l_y_xa = DenseLayer(l_y_xa, n_y, init.GlorotUniform(),
                            init.Normal(1e-3), softmax)
        l_y_xa_reshaped = ReshapeLayer(l_y_xa, (-1, self.sym_samples, 1, n_y))

        ### Recognition q(z|x,y) ###
        # Same scheme as the classifier: sum separate projections of x and y.
        l_x_to_z = DenseLayer(l_x_in, z_hidden[0], init.GlorotNormal('relu'),
                              init.Normal(1e-3), None)
        l_x_to_z = DimshuffleLayer(l_x_to_z, (0, 'x', 'x', 1))
        l_y_to_z = DenseLayer(l_y_in, z_hidden[0], init.GlorotNormal('relu'),
                              init.Normal(1e-3), None)
        l_y_to_z = DimshuffleLayer(l_y_to_z, (0, 'x', 'x', 1))
        l_z_xy = ReshapeLayer(ElemwiseSumLayer([l_x_to_z, l_y_to_z]),
                              [-1, z_hidden[0]])
        l_z_xy = NonlinearityLayer(l_z_xy, self.transf)

        for hid in z_hidden[1:]:
            l_z_xy = DenseLayer(l_z_xy, hid, init.GlorotNormal('relu'),
                                init.Normal(1e-3), self.transf)
        l_z_axy_mu = DenseLayer(l_z_xy, n_z, init.GlorotNormal(),
                                init.Normal(1e-3), None)
        l_z_axy_logvar = DenseLayer(l_z_xy, n_z, init.GlorotNormal(),
                                    init.Normal(1e-3), None)
        l_z_xy = SampleLayer(l_z_axy_mu,
                             l_z_axy_logvar,
                             eq_samples=self.sym_samples)
        # Reshape all layers to align them for multiple samples in the lower bound calculation.
        l_z_axy_mu_reshaped = DimshuffleLayer(l_z_axy_mu, (0, 'x', 'x', 1))
        l_z_axy_logvar_reshaped = DimshuffleLayer(l_z_axy_logvar,
                                                  (0, 'x', 'x', 1))
        l_z_axy_reshaped = ReshapeLayer(l_z_xy, (-1, self.sym_samples, 1, n_z))

        ### Generative p(xhat|z,y) ###
        # Sum separate projections of the z samples and of y; the y projection
        # (dimshuffled with 'x' axes) broadcasts over the sample axis of z.
        l_y_to_xhat = DenseLayer(l_y_in, xhat_hidden[0],
                                 init.GlorotNormal('relu'), init.Normal(1e-3),
                                 None)
        l_y_to_xhat = DimshuffleLayer(l_y_to_xhat, (0, 'x', 'x', 1))
        l_z_to_xhat = DenseLayer(l_z_xy, xhat_hidden[0],
                                 init.GlorotNormal('relu'), init.Normal(1e-3),
                                 None)
        l_z_to_xhat = ReshapeLayer(l_z_to_xhat,
                                   (-1, self.sym_samples, 1, xhat_hidden[0]))
        l_xhat_zy = ReshapeLayer(ElemwiseSumLayer([l_z_to_xhat, l_y_to_xhat]),
                                 [-1, xhat_hidden[0]])
        l_xhat_zy = NonlinearityLayer(l_xhat_zy, self.transf)
        for hid in xhat_hidden[1:]:
            l_xhat_zy = DenseLayer(l_xhat_zy, hid,
                                   init.GlorotNormal('relu'),
                                   init.Normal(1e-3), self.transf)

        # Only the Gaussian likelihood has separate mean/log-variance layers.
        l_xhat_zy_mu_reshaped = None
        l_xhat_zy_logvar_reshaped = None
        if x_dist == 'bernoulli':
            l_xhat_zy = DenseLayer(l_xhat_zy, n_x, init.GlorotNormal(),
                                   init.Normal(1e-3), sigmoid)
        elif x_dist == 'multinomial':
            l_xhat_zy = DenseLayer(l_xhat_zy, n_x, init.GlorotNormal(),
                                   init.Normal(1e-3), softmax)
        elif x_dist == 'gaussian':
            l_xhat_zy_mu = DenseLayer(l_xhat_zy, n_x, init.GlorotNormal(),
                                      init.Normal(1e-3), None)
            l_xhat_zy_logvar = DenseLayer(l_xhat_zy, n_x, init.GlorotNormal(),
                                          init.Normal(1e-3), None)
            l_xhat_zy = SampleLayer(l_xhat_zy_mu,
                                    l_xhat_zy_logvar,
                                    eq_samples=1)
            l_xhat_zy_mu_reshaped = ReshapeLayer(
                l_xhat_zy_mu, (-1, self.sym_samples, 1, n_x))
            l_xhat_zy_logvar_reshaped = ReshapeLayer(
                l_xhat_zy_logvar, (-1, self.sym_samples, 1, n_x))
        l_xhat_zy_reshaped = ReshapeLayer(l_xhat_zy,
                                          (-1, self.sym_samples, 1, n_x))

        ### Various class variables ###
        self.l_x_in = l_x_in
        self.l_y_in = l_y_in
        self.l_a_mu = l_a_x_mu_reshaped
        self.l_a_logvar = l_a_x_logvar_reshaped
        self.l_a = l_a_x_reshaped
        self.l_z_mu = l_z_axy_mu_reshaped
        self.l_z_logvar = l_z_axy_logvar_reshaped
        self.l_z = l_z_axy_reshaped
        self.l_y = l_y_xa_reshaped
        self.l_xhat_mu = l_xhat_zy_mu_reshaped
        self.l_xhat_logvar = l_xhat_zy_logvar_reshaped
        self.l_xhat = l_xhat_zy_reshaped

        self.model_params = get_all_params([self.l_xhat, self.l_y])

        ### Calculate networks shapes for documentation ###
        self.qa_shapes = self.get_model_shape(get_all_params(l_a_x))
        self.qy_shapes = self.get_model_shape(
            get_all_params(l_y_xa))[len(self.qa_shapes) - 1:]
        self.qz_shapes = self.get_model_shape(get_all_params(l_z_xy))
        self.px_shapes = self.get_model_shape(
            get_all_params(l_xhat_zy))[(len(self.qz_shapes) - 1):]

        ### Predefined functions for generating xhat and y ###
        # xhat is generated from given z samples (fed in place of the z
        # SampleLayer) and labels y; the output is averaged over samples.
        inputs = {l_z_xy: self.sym_z, self.l_y_in: self.sym_y}
        outputs = get_output(self.l_xhat, inputs,
                             deterministic=True).mean(axis=(1, 2))
        inputs = [self.sym_z, self.sym_y, self.sym_samples]
        self.f_xhat = theano.function(inputs, outputs)

        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_y, self.sym_x_l,
                             deterministic=True).mean(axis=(1, 2))
        self.f_y = theano.function(inputs, outputs)

        # Skip the first (len(a_hidden) + 2) * 2 trainable params — presumably
        # W and b of the q(a|x) hidden, mu and logvar layers — TODO confirm
        # against get_all_params ordering.
        self.y_params = get_all_params(
            self.l_y, trainable=True)[(len(a_hidden) + 2) * 2::]
        self.xhat_params = get_all_params(self.l_xhat, trainable=True)
    def __init__(self,
                 n_c,
                 n_z,
                 qz_hid,
                 px_hid,
                 enc_rnn=256,
                 dec_rnn=256,
                 n_l=28,
                 nonlinearity=rectify,
                 px_nonlinearity=None,
                 x_dist='bernoulli',
                 batchnorm=False,
                 seed=1234):
        """
        Initialize a recurrent variational auto-encoder with a bidirectional
        LSTM encoder q(z|x) and a bidirectional LSTM decoder p(x|z).
        Hidden dense weights are initialized using the Bengio and Glorot (2010)
        initialization scheme.
        :param n_c: Number of input features per time step.
        :param n_z: Number of latent units.
        :param qz_hid: List of number of deterministic hidden units in q(z|x).
        :param px_hid: List of number of deterministic hidden units in p(x|z).
        :param enc_rnn: Number of units in each encoder LSTM direction.
        :param dec_rnn: Number of units in each decoder LSTM direction.
        :param n_l: Sequence length (number of time steps).
        :param nonlinearity: The transfer function used in the deterministic layers.
        :param px_nonlinearity: Nonlinearity applied to the mean and log-variance
            layers when x_dist is 'gaussian'.
        :param x_dist: The x distribution, 'bernoulli', 'multinomial',
            'gaussian' or 'linear'.
        :param batchnorm: Boolean value for batch normalization.
        :param seed: The random seed.
        :raises ValueError: If x_dist is not one of the supported distributions.
        """
        # Fail fast: an unknown x_dist would otherwise leave l_px as an
        # unshaped hidden layer and only fail later, if at all.
        if x_dist not in ('bernoulli', 'multinomial', 'gaussian', 'linear'):
            raise ValueError(
                "x_dist must be 'bernoulli', 'multinomial', 'gaussian' or "
                "'linear', got %r" % (x_dist,))

        super(RVAE, self).__init__(n_c, qz_hid + px_hid, n_z, nonlinearity)
        self.x_dist = x_dist
        self.n_x = n_c
        self.seq_length = n_l
        self.n_z = n_z
        self.batchnorm = batchnorm
        self._srng = RandomStreams(seed)

        # Std of the Normal initializers used for biases and Gaussian heads.
        init_w = 1e-3
        # Glorot gain for hidden dense weights. Lasagne's Glorot initializers
        # accept a number or the string 'relu' (gain sqrt(2)); the empty
        # string used previously is not a valid gain and crashed weight
        # sampling for other nonlinearities, so fall back to the default 1.0.
        hid_w = "relu" if nonlinearity in (rectify, softplus) else 1.0

        # Define symbolic variables for theano functions.
        self.sym_x = T.tensor3('x')  # inputs
        self.sym_z = T.matrix('z')
        self.sym_samples = T.iscalar('samples')  # MC samples
        self.sym_warmup = T.fscalar('warmup')

        # Assist methods for collecting the layers
        def dense_layer(layer_in,
                        n,
                        dist_w=init.GlorotNormal,
                        dist_b=init.Normal):
            """Linear dense layer, optional BatchNorm, then self.transf."""
            dense = DenseLayer(layer_in,
                               num_units=n,
                               W=dist_w(hid_w),
                               b=dist_b(init_w),
                               nonlinearity=None)
            if batchnorm:
                dense = BatchNormLayer(dense)
            return NonlinearityLayer(dense, self.transf)

        def stochastic_layer(layer_in, n, samples, nonlin=None):
            """Gaussian head: returns (sample layer, mu layer, logvar layer)."""
            mu = DenseLayer(layer_in,
                            n,
                            W=init.Normal(init_w, mean=.0),
                            b=init.Normal(init_w),
                            nonlinearity=nonlin)
            logvar = DenseLayer(layer_in,
                                n,
                                W=init.Normal(init_w, mean=.0),
                                b=init.Normal(init_w),
                                nonlinearity=nonlin)
            # logvar = ConstrainLayer(logvar, scale=1, max=T.log(-0.999 * self.sym_warmup + 1.0999))
            return SampleLayer(mu, logvar, eq_samples=samples,
                               iw_samples=1), mu, logvar

        def lstm_layer(layer_in,
                       nunits,
                       return_final,
                       backwards=False,
                       name='LSTM'):
            """LSTM without peepholes; the forget-gate bias starts at 5.0."""
            ingate = Gate(W_in=init.Uniform(0.01),
                          W_hid=init.Uniform(0.01),
                          b=init.Constant(0.0))
            forgetgate = Gate(W_in=init.Uniform(0.01),
                              W_hid=init.Uniform(0.01),
                              b=init.Constant(5.0))
            cell = Gate(
                W_cell=None,
                nonlinearity=T.tanh,
                W_in=init.Uniform(0.01),
                W_hid=init.Uniform(0.01),
            )
            outgate = Gate(W_in=init.Uniform(0.01),
                           W_hid=init.Uniform(0.01),
                           b=init.Constant(0.0))

            lstm = LSTMLayer(layer_in,
                             num_units=nunits,
                             backwards=backwards,
                             peepholes=False,
                             ingate=ingate,
                             forgetgate=forgetgate,
                             cell=cell,
                             outgate=outgate,
                             name=name,
                             only_return_final=return_final)
            return lstm

        # RNN encoder: bidirectional LSTM, keeping only the final state of
        # each direction.
        l_x_in = InputLayer((None, n_l, n_c))
        l_enc_forward = lstm_layer(l_x_in,
                                   enc_rnn,
                                   return_final=True,
                                   backwards=False,
                                   name='enc_forward')
        l_enc_backward = lstm_layer(l_x_in,
                                    enc_rnn,
                                    return_final=True,
                                    backwards=True,
                                    name='enc_backward')
        l_enc_concat = ConcatLayer([l_enc_forward, l_enc_backward], axis=-1)
        l_enc = dense_layer(l_enc_concat, enc_rnn)

        # # Overwrite encoder
        # l_enc = dense_layer(l_x_in, enc_rnn)

        # Recognition q(z|x)
        l_qz = l_enc
        for hid in qz_hid:
            l_qz = dense_layer(l_qz, hid)

        # Reparameterisation and sample
        # NOTE(review): the mean head's weights are drawn around 1.0 rather
        # than 0.0 — confirm this is intended.
        l_qz_mu = DenseLayer(l_qz,
                             n_z,
                             W=init.Normal(init_w, mean=1.0),
                             b=init.Normal(init_w),
                             nonlinearity=None)
        l_qz_logvar = DenseLayer(l_qz,
                                 n_z,
                                 init.Normal(init_w),
                                 init.Normal(init_w),
                                 nonlinearity=None)
        l_qz = SampleLayer(l_qz_mu,
                           l_qz_logvar,
                           eq_samples=self.sym_samples,
                           iw_samples=1)

        # Generative p(x|z): presumably RepeatLayer tiles z over the n_l time
        # steps — confirm against RepeatLayer.
        l_qz_repeat = RepeatLayer(l_qz, n=n_l)

        # Skip connection from the encoder into the decoder input.
        # NOTE: a Python `if` cannot branch on a symbolic value at graph-build
        # time. The former guard `if T.ge(self.sym_warmup, 0.4):` tested the
        # truthiness of the Theano variable itself (always true), so this
        # connection was always part of the graph and sym_warmup had no
        # effect. It is now built unconditionally to make that explicit;
        # gating it on warmup would require a symbolic mask.
        l_skip_enc_repeat = RepeatLayer(l_enc, n=n_l)
        l_qz_repeat = ConcatLayer([l_qz_repeat, l_skip_enc_repeat], axis=-1)

        l_dec_forward = lstm_layer(l_qz_repeat,
                                   dec_rnn,
                                   return_final=False,
                                   backwards=False,
                                   name='dec_forward')
        l_dec_backward = lstm_layer(l_qz_repeat,
                                    dec_rnn,
                                    return_final=False,
                                    backwards=True,
                                    name='dec_backward')
        l_dec_concat = ConcatLayer([l_dec_forward, l_dec_backward], axis=-1)
        # Merge the time axis into the batch axis so the dense layers below
        # act on every time step independently.
        l_dec = ReshapeLayer(l_dec_concat, (-1, 2 * dec_rnn))
        l_dec = dense_layer(l_dec, dec_rnn)

        # # Overwrite decoder
        # l_dec = dense_layer(l_qz, n_l)

        # Add additional dense layers
        l_px = l_dec
        for hid in px_hid:
            l_px = dense_layer(l_px, hid)

        # Reshape the last dimension and perhaps model with a distribution
        if x_dist == 'bernoulli':
            l_px = DenseLayer(l_px, n_c, init.GlorotNormal(),
                              init.Normal(init_w), sigmoid)
        elif x_dist == 'multinomial':
            l_px = DenseLayer(l_px, n_c, init.GlorotNormal(),
                              init.Normal(init_w), softmax)
        elif x_dist == 'gaussian':
            l_px, l_px_mu, l_px_logvar = stochastic_layer(
                l_px, n_c, self.sym_samples, nonlin=px_nonlinearity)
        elif x_dist == 'linear':
            l_px = DenseLayer(l_px, n_c, nonlinearity=None)

        # Reshape all the model layers to have the same size
        self.l_x_in = l_x_in

        self.l_qz = ReshapeLayer(l_qz, (-1, self.sym_samples, 1, n_z))
        self.l_qz_mu = DimshuffleLayer(l_qz_mu, (0, 'x', 'x', 1))
        self.l_qz_logvar = DimshuffleLayer(l_qz_logvar, (0, 'x', 'x', 1))

        def _align_px(layer):
            # Regroup per-time-step rows and move the sample axes ahead of
            # the time axis. NOTE(review): this reads rows as
            # (.., n_l, samples) while the decoder emits them as
            # (.., samples, n_l); the two agree only when sym_samples == 1
            # (the encoder skip connection also requires that) — verify.
            return DimshuffleLayer(
                ReshapeLayer(layer, (-1, n_l, self.sym_samples, 1, n_c)),
                (0, 2, 3, 1, 4))

        self.l_px = _align_px(l_px)
        self.l_px_mu = None
        self.l_px_logvar = None
        if x_dist == "gaussian":
            self.l_px_mu = _align_px(l_px_mu)
            self.l_px_logvar = _align_px(l_px_logvar)

        # Predefined functions
        inputs = {self.l_x_in: self.sym_x}
        outputs = get_output(l_qz, inputs, deterministic=True)
        self.f_qz = theano.function([self.sym_x, self.sym_samples],
                                    outputs,
                                    on_unused_input='warn')

        # Decode from given latent samples. sym_x is still required because
        # the encoder skip connection feeds the decoder.
        inputs = {l_qz: self.sym_z, self.l_x_in: self.sym_x}
        outputs = get_output(self.l_px, inputs,
                             deterministic=True).mean(axis=(1, 2))
        self.f_px = theano.function([self.sym_x, self.sym_z, self.sym_samples],
                                    outputs,
                                    on_unused_input='warn')

        if x_dist == "gaussian":
            outputs = get_output(self.l_px_mu, inputs,
                                 deterministic=True).mean(axis=(1, 2))
            self.f_mu = theano.function(
                [self.sym_x, self.sym_z, self.sym_samples],
                outputs,
                on_unused_input='ignore')

            outputs = get_output(self.l_px_logvar, inputs,
                                 deterministic=True).mean(axis=(1, 2))
            self.f_var = theano.function(
                [self.sym_x, self.sym_z, self.sym_samples],
                outputs,
                on_unused_input='ignore')

        # Define model parameters
        self.model_params = get_all_params([self.l_px])
        self.trainable_model_params = get_all_params([self.l_px],
                                                     trainable=True)