Example #1
def _test(distribution, bijector, n):
    x = TransformedDistribution(distribution=distribution,
                                bijector=bijector,
                                validate_args=True)
    val_est = get_dims(x.sample(n))
    val_true = n + get_dims(distribution.mean())
    assert val_est == val_true
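A hedged invocation sketch for the helper above; `get_dims` is assumed to return a tensor's dimensions as a Python list (a stand-in definition is included so the sketch is self-contained), and `n` is passed as a list so that the `n + get_dims(...)` concatenation in the assert works.

import tensorflow as tf
from edward.models import Normal, TransformedDistribution
from tensorflow.contrib.distributions import bijectors


def get_dims(tensor):
    # Stand-in for the dimension helper assumed by _test; returns the
    # static dimensions of a tensor as a Python list.
    return tensor.shape.as_list()


_test(Normal(loc=tf.zeros(3), scale=tf.ones(3)), bijectors.Exp(), n=[5])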
Example #2
    def test_auto_transform_true(self):
        with self.test_session() as sess:
            # Match normal || softplus-inverse-normal distribution with
            # automated transformation on latter (assuming it is softplus).
            x = TransformedDistribution(
                distribution=Normal(0.0, 0.5),
                bijector=tf.contrib.distributions.bijectors.Softplus())
            x.support = 'nonnegative'
            qx = Normal(loc=tf.Variable(tf.random_normal([])),
                        scale=tf.nn.softplus(tf.Variable(tf.random_normal(
                            []))))

            inference = ed.KLqp({x: qx})
            inference.initialize(auto_transform=True, n_samples=5, n_iter=1000)
            tf.global_variables_initializer().run()
            for _ in range(inference.n_iter):
                info_dict = inference.update()

            # Check approximation on constrained space has same moments as
            # target distribution.
            n_samples = 10000
            x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
            x_unconstrained = inference.transformations[x]
            qx_constrained = transform(
                qx, bijectors.Invert(x_unconstrained.bijector))
            qx_mean, qx_var = tf.nn.moments(qx_constrained.sample(n_samples),
                                            0)
            stats = sess.run([x_mean, qx_mean, x_var, qx_var])
            self.assertAllClose(info_dict['loss'], 0.0, rtol=0.2, atol=0.2)
            self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
            self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
Example #3
    def __init__(self, M, C, theta_prior, delta_prior, a_prior):

        self.M = M
        self.C = C
        self.theta_prior = theta_prior  # prior of ability
        self.delta_prior = delta_prior  # prior of difficulty
        self.a_prior = a_prior  # prior of discrimination

        if isinstance(a_prior, ed.RandomVariable):
            # variational posterior of discrimination
            self.qa = Normal(loc=tf.Variable(tf.ones([M])),
                             scale=tf.nn.softplus(
                                 tf.Variable(tf.ones([M]) * .5)),
                             name='qa')
        else:
            self.qa = a_prior

        with tf.variable_scope('local'):
            # variational posterior of ability
            if isinstance(self.theta_prior, RandomVariable):
                self.qtheta = TransformedDistribution(
                    distribution=Normal(
                        loc=tf.Variable(tf.random_normal([C])),
                        scale=tf.nn.softplus(tf.Variable(tf.random_normal([C])))),
                    bijector=ds.bijectors.Sigmoid(),
                    sample_shape=[M],
                    name='qtheta')
            else:
                self.qtheta = self.theta_prior
            # variational posterior of difficulty
            self.qdelta = TransformedDistribution(
                distribution=Normal(
                    loc=tf.Variable(tf.random_normal([M])),
                    scale=tf.nn.softplus(tf.Variable(tf.random_normal([M])))),
                bijector=ds.bijectors.Sigmoid(),
                sample_shape=[C],
                name='qdelta')

        alpha = (tf.transpose(self.qtheta) / self.qdelta)**self.qa

        beta = ((1. - tf.transpose(self.qtheta)) / (1. - self.qdelta))**self.qa

        # observed variable
        self.x = Beta(tf.transpose(alpha), tf.transpose(beta))
Example #4
    def test_hmc_default(self):
        with self.test_session() as sess:
            x = TransformedDistribution(
                distribution=Normal(1.0, 1.0),
                bijector=tf.contrib.distributions.bijectors.Softplus())
            x.support = 'nonnegative'

            inference = ed.HMC([x])
            inference.initialize(auto_transform=True, step_size=0.8)
            tf.global_variables_initializer().run()
            for _ in range(inference.n_iter):
                info_dict = inference.update()
                inference.print_progress(info_dict)

            # Check approximation on constrained space has same moments as
            # target distribution.
            n_samples = 10000
            x_unconstrained = inference.transformations[x]
            qx = inference.latent_vars[x_unconstrained]
            qx_constrained = Empirical(
                x_unconstrained.bijector.inverse(qx.params))
            x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
            qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
            stats = sess.run([x_mean, qx_mean, x_var, qx_var])
            self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
            self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
Example #5
  def test_auto_transform_true(self):
    with self.test_session() as sess:
      # Match normal || softplus-inverse-normal distribution with
      # automated transformation on latter (assuming it is softplus).
      x = TransformedDistribution(
          distribution=Normal(0.0, 0.5),
          bijector=tf.contrib.distributions.bijectors.Softplus())
      x.support = 'nonnegative'
      qx = Normal(loc=tf.Variable(tf.random_normal([])),
                  scale=tf.nn.softplus(tf.Variable(tf.random_normal([]))))

      inference = ed.KLqp({x: qx})
      inference.initialize(auto_transform=True, n_samples=5, n_iter=1000)
      tf.global_variables_initializer().run()
      for _ in range(inference.n_iter):
        info_dict = inference.update()

      # Check approximation on constrained space has same moments as
      # target distribution.
      n_samples = 10000
      x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
      x_unconstrained = inference.transformations[x]
      qx_constrained = transform(qx, bijectors.Invert(x_unconstrained.bijector))
      qx_mean, qx_var = tf.nn.moments(qx_constrained.sample(n_samples), 0)
      stats = sess.run([x_mean, qx_mean, x_var, qx_var])
      self.assertAllClose(info_dict['loss'], 0.0, rtol=0.2, atol=0.2)
      self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
      self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
Example #6
    def fit(self, X, batch_idx=None, max_iter=100, max_time=60):
        tf.reset_default_graph()

        # Data size
        N = X.shape[0]
        P = X.shape[1]

        if not self.batch_correction:
            batch_idx = None

        # Number of experimental batches
        if batch_idx is not None:
            self.n_batches = np.unique(batch_idx[:, 0]).size
        else:
            self.n_batches = 0

        # Prior for cell scalings
        log_library_size = np.log(np.sum(X, axis=1))
        self.mean_llib, self.std_llib = np.mean(log_library_size), np.std(
            log_library_size)

        if self.minibatch_size is not None:
            # Create ZINBayes computation graph
            self.define_stochastic_model(P, self.n_components)
            inference_local, inference_global = self.define_stochastic_inference(
                N, P, self.n_components)

            self.run_stochastic_inference(X,
                                          inference_local,
                                          inference_global,
                                          n_iterations=max_iter)
        else:
            # Create ZINBayes computation graph
            self.define_model(N, P, self.n_components, batch_idx=batch_idx)
            self.inference = self.define_inference(X)

            # If we want to assess convergence during inference on held-out data
            inference_val = None
            if self.validation and self.X_test is not None:
                self.define_val_model(self.X_test.shape[0], P,
                                      self.n_components)
                inference_val = self.define_val_inference(self.X_test)

            # Run inference
            self.loss = self.run_inference(self.inference,
                                           inference_val=inference_val,
                                           n_iterations=max_iter)

        # Get estimated variational distributions of global latent variables
        self.est_qW0 = TransformedDistribution(
            distribution=Normal(self.qW0.distribution.loc.eval(),
                                self.qW0.distribution.scale.eval()),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.est_qr = TransformedDistribution(
            distribution=Normal(self.qr.distribution.loc.eval(),
                                self.qr.distribution.scale.eval()),
            bijector=tf.contrib.distributions.bijectors.Exp())
        if self.zero_inflation:
            self.est_qW1 = Normal(self.qW1.loc.eval(), self.qW1.scale.eval())
Example #7
    def define_inference(self, X):
        # Local latent variables
        # self.qz = lognormal_q(self.z.shape)
        self.qz = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.z.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.z.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.qlam = lognormal_q(self.lam.shape)
        self.qlam = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.lam.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.lam.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # Global latent variables
        # self.qr = lognormal_q(self.r.shape)
        self.qr = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.r.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.qW0 = lognormal_q(self.W0.shape)
        self.qW0 = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.W0.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        latent_vars_dict = {
            self.z: self.qz,
            self.lam: self.qlam,
            self.r: self.qr,
            self.W0: self.qW0
        }

        if self.zero_inflation:
            self.qW1 = Normal(
                tf.Variable(tf.zeros(self.W1.shape)),
                tf.nn.softplus(tf.Variable(0.1 * tf.ones(self.W1.shape))))
            latent_vars_dict[self.W1] = self.qW1

        if self.scalings:
            # self.ql = lognormal_q(self.l.shape)
            self.ql = TransformedDistribution(
                distribution=Normal(
                    tf.Variable(self.mean_llib * tf.ones(self.l.shape)),
                    tf.nn.softplus(
                        tf.Variable(self.std_llib * tf.ones(self.l.shape)))),
                bijector=tf.contrib.distributions.bijectors.Exp())
            latent_vars_dict[self.l] = self.ql

        inference = ed.ReparameterizationKLqp(
            latent_vars_dict, data={self.likelihood: tf.cast(X, tf.float32)})

        return inference
Example #8
def _test(base_dist_cls, transform, inverse, log_det_jacobian, n,
          **base_dist_args):
    x = TransformedDistribution(base_dist_cls=base_dist_cls,
                                transform=transform,
                                inverse=inverse,
                                log_det_jacobian=log_det_jacobian,
                                **base_dist_args)
    val_est = get_dims(x.sample(n))
    val_true = n + get_dims(base_dist_args['mu'])
    assert val_est == val_true
Example #9
def _test(base_dist_cls, transform, inverse, log_det_jacobian, n, **base_dist_args):
    x = TransformedDistribution(
        base_dist_cls=base_dist_cls,
        transform=transform,
        inverse=inverse,
        log_det_jacobian=log_det_jacobian,
        **base_dist_args
    )
    val_est = get_dims(x.sample(n))
    val_true = n + get_dims(base_dist_args["mu"])
    assert val_est == val_true
Example #10
    def define_val_model(self, N, P, K):
        # Define new graph
        self.z_test = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))
        self.l_test = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        rho_test = tf.matmul(self.z_test, self.W0)
        rho_test = rho_test / tf.reshape(tf.reduce_sum(rho_test, axis=1),
                                         (-1, 1))  # NxP

        self.lam_test = Gamma(self.r, self.r / (rho_test * self.l_test))

        if self.zero_inflation:
            logit_pi_test = tf.matmul(self.z_test, self.W1)

            pi_test = tf.minimum(
                tf.maximum(tf.nn.sigmoid(logit_pi_test), 1e-7), 1. - 1e-7)
            cat_test = Categorical(
                probs=tf.stack([pi_test, 1. - pi_test], axis=2))

            components_test = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=self.lam_test)
            ]
            self.likelihood_test = Mixture(cat=cat_test,
                                           components=components_test)
        else:
            self.likelihood_test = Poisson(rate=self.lam_test)
Example #11
  def test_auto_transform_false(self):
    with self.test_session():
      # Match normal || softplus-inverse-normal distribution without
      # automated transformation; it should fail.
      x = TransformedDistribution(
          distribution=Normal(0.0, 0.5),
          bijector=tf.contrib.distributions.bijectors.Softplus())
      x.support = 'nonnegative'
      qx = Normal(loc=tf.Variable(tf.random_normal([])),
                  scale=tf.nn.softplus(tf.Variable(tf.random_normal([]))))

      inference = ed.KLqp({x: qx})
      inference.initialize(auto_transform=False, n_samples=5, n_iter=150)
      tf.global_variables_initializer().run()
      for _ in range(inference.n_iter):
        info_dict = inference.update()

      self.assertAllEqual(info_dict['loss'], np.nan)
Example #12
def lognormal_q(shape):
    min_scale = 1e-5
    loc_init = tf.random_normal(shape)
    scale_init = 0.1 * tf.random_normal(shape)
    rv = TransformedDistribution(
        distribution=Normal(
            tf.Variable(loc_init),
            tf.maximum(tf.nn.softplus(tf.Variable(scale_init)), min_scale)),
        bijector=tf.contrib.distributions.bijectors.Exp())
    return rv
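A minimal usage sketch of the helper above, assuming Edward's `ed.KLqp` with a `Gamma`-distributed positive latent rate and Poisson observations; the shapes and the data are illustrative only.

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Gamma, Poisson

# Illustrative model: a positive latent rate with a Gamma prior and Poisson counts.
rate = Gamma(2.0 * tf.ones([10]), 1.0 * tf.ones([10]))
x = Poisson(rate=rate)
x_data = np.ones(10, dtype=np.float32)  # placeholder observations

# The log-normal variational posterior matches the (0, inf) support of the rate.
qrate = lognormal_q([10])

inference = ed.KLqp({rate: qrate}, data={x: x_data})
inference.run(n_iter=500)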
Example #13
    def test_auto_transform_false(self):
        with self.test_session():
            # Match normal || softplus-inverse-normal distribution without
            # automated transformation; it should fail.
            x = TransformedDistribution(
                distribution=Normal(0.0, 0.5),
                bijector=tf.contrib.distributions.bijectors.Softplus())
            x.support = 'nonnegative'
            qx = Normal(loc=tf.Variable(tf.random_normal([])),
                        scale=tf.nn.softplus(tf.Variable(tf.random_normal(
                            []))))

            inference = ed.KLqp({x: qx})
            inference.initialize(auto_transform=False, n_samples=5, n_iter=150)
            tf.global_variables_initializer().run()
            for _ in range(inference.n_iter):
                info_dict = inference.update()

            self.assertAllEqual(info_dict['loss'], np.nan)
Example #14
def lognormal_q(shape, name=None):
  with tf.variable_scope(name, default_name="lognormal_q"):
    min_scale = 1e-5
    loc = tf.get_variable("loc", shape)
    scale = tf.get_variable(
        "scale", shape, initializer=tf.random_normal_initializer(stddev=0.1))
    rv = TransformedDistribution(
        distribution=Normal(loc, tf.maximum(tf.nn.softplus(scale), min_scale)),
        bijector=tf.contrib.distributions.bijectors.Exp())
    return rv
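A short sketch of why the `tf.variable_scope(name, default_name=...)` wrapper is convenient: repeated calls get uniquified scopes, so several independent log-normal factors can be created without name collisions (the shapes and names below are illustrative).

qz1 = lognormal_q([16], name="qz1")  # creates variables 'qz1/loc' and 'qz1/scale'
qz2 = lognormal_q([16])              # falls back to 'lognormal_q/loc', 'lognormal_q/scale'
qz3 = lognormal_q([16])              # uniquified to 'lognormal_q_1/loc', 'lognormal_q_1/scale'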
Example #15
  def test_hmc_default(self):
    with self.test_session() as sess:
      x = TransformedDistribution(
          distribution=Normal(1.0, 1.0),
          bijector=tf.contrib.distributions.bijectors.Softplus())
      x.support = 'nonnegative'

      inference = ed.HMC([x])
      inference.initialize(auto_transform=True, step_size=0.8)
      tf.global_variables_initializer().run()
      for _ in range(inference.n_iter):
        info_dict = inference.update()
        inference.print_progress(info_dict)

      # Check approximation on constrained space has same moments as
      # target distribution.
      n_samples = 1000
      qx_constrained = inference.latent_vars[x]
      x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
      qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
      stats = sess.run([x_mean, qx_mean, x_var, qx_var])
      self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
      self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
Example #16
    def _initialize_output_model(self):
        with self.sess.as_default():
            with self.sess.graph.as_default():
                if self.output_scale is None:
                    output_scale = self.decoder_scale
                else:
                    output_scale = self.output_scale

                if self.mv:
                    self.out = MultivariateNormalTriL(
                        loc=tf.layers.Flatten()(self.decoder),
                        scale_tril=output_scale
                    )
                else:
                    self.out = Normal(
                        loc=self.decoder,
                        scale=output_scale
                    )
                if self.normalize_data and self.constrain_output:
                    self.out = TransformedDistribution(
                        self.out,
                        bijector=tf.contrib.distributions.bijectors.Sigmoid()
                    )
Example #17
  def test_hmc_custom(self):
    with self.test_session() as sess:
      x = TransformedDistribution(
          distribution=Normal(1.0, 1.0),
          bijector=tf.contrib.distributions.bijectors.Softplus())
      x.support = 'nonnegative'
      qx = Empirical(tf.Variable(tf.random_normal([1000])))

      inference = ed.HMC({x: qx})
      inference.initialize(auto_transform=True, step_size=0.8)
      tf.global_variables_initializer().run()
      for _ in range(inference.n_iter):
        info_dict = inference.update()

      # Check approximation on constrained space has same moments as
      # target distribution.
      n_samples = 10000
      x_unconstrained = inference.transformations[x]
      qx_constrained_params = x_unconstrained.bijector.inverse(qx.params)
      x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
      qx_mean, qx_var = tf.nn.moments(qx_constrained_params[500:], 0)
      stats = sess.run([x_mean, qx_mean, x_var, qx_var])
      self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
      self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
Example #18
        def construct_model():
            nku = len(Ku)
            nkv = len(Kv)

            obs = tf.placeholder(tf.float32, R_.shape)

            Ug = TransformedDistribution(distribution=Normal(tf.zeros([nku]), tf.ones([nku])),
                                         bijector=tf.contrib.distributions.bijectors.Exp())
            Vg = TransformedDistribution(distribution=Normal(tf.zeros([nkv]), tf.ones([nkv])),
                                         bijector=tf.contrib.distributions.bijectors.Exp())

            Ua = TransformedDistribution(distribution=Normal(tf.zeros([1]), tf.ones([1])),
                                         bijector=tf.contrib.distributions.bijectors.Exp())
            Va = TransformedDistribution(distribution=Normal(tf.zeros([1]), tf.ones([1])),
                                         bijector=tf.contrib.distributions.bijectors.Exp())

            cKu = tf.cholesky(Ku + tf.eye(I) / Ua)  # TODO: rank 1 chol update
            cKv = tf.cholesky(Kv + tf.eye(J) / Va)

            Uw1 = MultivariateNormalTriL(tf.zeros([L, I]),
                                         tf.reduce_sum(cKu / tf.reshape(tf.sqrt(Ug), [nku, 1, 1]), axis=0))
            Vw1 = MultivariateNormalTriL(tf.zeros([L, J]),
                                         tf.reduce_sum(cKv / tf.reshape(tf.sqrt(Vg), [nkv, 1, 1]), axis=0))

            logits = nn(Uw1, Vw1)
            R = AugmentedBernoulli(logits=logits, c=c, obs=obs, value=tf.cast(logits > 0, tf.int32))

            qUg = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([nku])),
                                                                               tf.Variable(tf.ones([nku]))),
                                          bijector=tf.contrib.distributions.bijectors.Exp())
            qVg = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([nkv])),
                                                                               tf.Variable(tf.ones([nkv]))),
                                          bijector=tf.contrib.distributions.bijectors.Exp())
            qUa = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([1])),
                                                                               tf.Variable(tf.ones([1]))),
                                          bijector=tf.contrib.distributions.bijectors.Exp())
            qVa = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([1])),
                                                                               tf.Variable(tf.ones([1]))),
                                          bijector=tf.contrib.distributions.bijectors.Exp())
            qUw1 = MultivariateNormalTriL(tf.Variable(tf.zeros([L, I])), tf.Variable(tf.eye(I)))
            qVw1 = MultivariateNormalTriL(tf.Variable(tf.zeros([L, J])), tf.Variable(tf.eye(J)))

            return obs, Ug, Vg, Ua, Va, cKu, cKv, Uw1, Vw1, R, qUg, qVg, qUa, qVa, qUw1, qVw1
Example #19
    def define_val_inference(self, X):
        self.qz_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.z_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 *
                                           tf.ones(self.z_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.qlam = lognormal_q(self.lam.shape)
        self.qlam_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.lam_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 *
                                           tf.ones(self.lam_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.ql = lognormal_q(self.l.shape)
        self.ql_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(self.mean_llib * tf.ones(self.l_test.shape)),
                tf.nn.softplus(
                    tf.Variable(
                        np.sqrt(self.std_llib) * tf.ones(self.l_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if self.zero_inflation:
            inference = ed.ReparameterizationKLqp(
                {
                    self.z_test: self.qz_test,
                    self.lam_test: self.qlam_test,
                    self.l_test: self.ql_test
                },
                data={
                    self.likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.qW0,
                    self.W1: self.qW1,
                    self.r: self.qr
                })
        else:
            inference = ed.ReparameterizationKLqp(
                {
                    self.z_test: self.qz_test,
                    self.lam_test: self.qlam_test,
                    self.l_test: self.ql_test
                },
                data={
                    self.likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.qW0,
                    self.r: self.qr
                })

        return inference
Example #20
    def define_stochastic_model(self, P, K):
        M = self.minibatch_size

        self.W0 = Gamma(0.1 * tf.ones([K, P]), 0.3 * tf.ones([K, P]))
        if self.zero_inflation:
            self.W1 = Normal(tf.zeros([K, P]), tf.ones([K, P]))

        self.z = Gamma(2. * tf.ones([M, K]), 1. * tf.ones([M, K]))

        self.r = Gamma(2. * tf.ones([P]), 1. * tf.ones([P]))

        self.l = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([M, 1]),
                                self.std_llib * tf.ones([M, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.rho = tf.matmul(self.z, self.W0)
        self.rho = self.rho / tf.reshape(tf.reduce_sum(self.rho, axis=1),
                                         (-1, 1))  # NxP

        self.lam = Gamma(self.r, self.r / (self.rho * self.l))

        if self.zero_inflation:
            self.logit_pi = tf.matmul(self.z, self.W1)
            self.pi = tf.minimum(
                tf.maximum(tf.nn.sigmoid(self.logit_pi), 1e-7), 1. - 1e-7)

            self.cat = Categorical(
                probs=tf.stack([self.pi, 1. - self.pi], axis=2))

            self.components = [
                Poisson(rate=1e-30 * tf.ones([M, P])),
                Poisson(rate=self.lam)
            ]

            self.likelihood = Mixture(cat=self.cat, components=self.components)
        else:
            self.likelihood = Poisson(rate=self.lam)
Example #21
    def __init__(self, hdims, zdim, xdim, gen_scale=1.):
        x_ph = tf.placeholder(tf.float32, [None, xdim])
        batch_size = tf.shape(x_ph)[0]
        sample_size = tf.placeholder(tf.int32, [])

        # Define the generative network (p(x | z))
        with tf.variable_scope('generative', reuse=tf.AUTO_REUSE):
            z = Normal(loc=tf.zeros([batch_size, zdim]),
                       scale=tf.ones([batch_size, zdim]))

            hidden = tf.layers.dense(z, hdims[0], activation=tf.nn.relu, name="dense1")
            loc = tf.layers.dense(hidden, xdim, name="dense2")

            x_gen = TransformedDistribution(
                distribution=tfd.Normal(loc=loc, scale=gen_scale),
                bijector=tfd.bijectors.Exp(),
                name="LogNormalTransformedDistribution"
            )
            #x_gen = Bernoulli(logits=loc)

        # Define the inference network (q(z | x))
        with tf.variable_scope('inference', reuse=tf.AUTO_REUSE):
            hidden = tf.layers.dense(x_ph, hdims[0], activation=tf.nn.relu)
            qloc = tf.layers.dense(hidden, zdim)
            qscale = tf.layers.dense(hidden, zdim, activation=tf.nn.softplus)
            qz = Normal(loc=qloc, scale=qscale)
            qz_sample = qz.sample(sample_size)

        # Define the generative network using posterior samples from q(z | x)
        with tf.variable_scope('generative'):
            qz_sample = tf.reshape(qz_sample, [-1, zdim])
            hidden = tf.layers.dense(qz_sample, hdims[0], activation=tf.nn.relu, reuse=True, name="dense1")
            loc = tf.layers.dense(hidden, xdim, reuse=True, name="dense2")

            x_gen_post = tf.exp(loc)

        self.x_ph = x_ph
        self.x_data = self.x_ph
        self.batch_size = batch_size
        self.sample_size = sample_size

        self.ops = {
            'generative': x_gen,
            'inference': qz_sample,
            'generative_post': x_gen_post
        }

        self.kl_coef = tf.placeholder(tf.float32, ())
        with tf.variable_scope('inference', reuse=tf.AUTO_REUSE):
            self.inference = ed.KLqp({z: qz}, data={x_gen: self.x_data})
            self.lr = tf.placeholder(tf.float32, shape=())

            optimizer = tf.train.RMSPropOptimizer(self.lr, epsilon=0.9)

            self.inference.initialize(
                optimizer=optimizer,
                n_samples=10,
                kl_scaling={z: self.kl_coef}
            )

            # Build elbo loss to evaluate on validation data
            self.eval_loss, _ = self.inference.build_loss_and_gradients([])
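A hedged training-loop sketch for the class above; the constructor call and the `next_batch` data iterator are hypothetical, and only the placeholders and the `inference` object created in `__init__` are taken from the code.

import edward as ed
import tensorflow as tf

# model = SomeVAE(hdims=[256], zdim=8, xdim=784)  # hypothetical instantiation
sess = ed.get_session()
sess.run(tf.global_variables_initializer())
for step in range(1000):
    x_batch = next_batch(100)  # hypothetical iterator returning a [100, xdim] array
    info_dict = model.inference.update(feed_dict={
        model.x_ph: x_batch,
        model.lr: 1e-3,
        model.kl_coef: min(1.0, step / 100.0),  # simple KL warm-up schedule (illustrative)
    })
    if step % 100 == 0:
        print(step, info_dict['loss'])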
Example #22
class DNNSegBayes(DNNSeg):
    _INITIALIZATION_KWARGS = UNSUPERVISED_WORD_CLASSIFIER_BAYES_INITIALIZATION_KWARGS

    _doc_header = """
           Bayesian implementation of unsupervised word classifier.

       """
    _doc_args = DNNSeg._doc_args
    _doc_kwargs = DNNSeg._doc_kwargs
    _doc_kwargs += '\n' + '\n'.join([' ' * 8 + ':param %s' % x.key + ': ' + '; '.join(
        [x.dtypes_str(), x.descr]) + ' **Default**: ``%s``.' % (x.default_value if not isinstance(x.default_value,
                                                                                                  str) else "'%s'" % x.default_value)
                                     for x in _INITIALIZATION_KWARGS])
    __doc__ = _doc_header + _doc_args + _doc_kwargs

    def __init__(self, k, train_data, **kwargs):
        super(DNNSegBayes, self).__init__(
            k,
            train_data,
            **kwargs
        )

        for kwarg in DNNSegBayes._INITIALIZATION_KWARGS:
            setattr(self, kwarg.key, kwargs.pop(kwarg.key, kwarg.default_value))

        kwarg_keys = [x.key for x in DNNSeg._INITIALIZATION_KWARGS]
        for kwarg_key in kwargs:
            if kwarg_key not in kwarg_keys:
                raise TypeError('__init__() got an unexpected keyword argument %s' % kwarg_key)

        assert self.declare_priors or self.relaxed, 'Priors must be explicitly declared unless relaxed==True'

        self._initialize_metadata()

    def _initialize_metadata(self):
        super(DNNSegBayes, self)._initialize_metadata()

        self.inference_map = {}

    def _pack_metadata(self):
        md = super(DNNSegBayes, self)._pack_metadata()

        for kwarg in DNNSegBayes._INITIALIZATION_KWARGS:
            md[kwarg.key] = getattr(self, kwarg.key)

        return md

    def _unpack_metadata(self, md):
        super(DNNSegBayes, self)._unpack_metadata(md)

        for kwarg in DNNSegBayes._INITIALIZATION_KWARGS:
            setattr(self, kwarg.key, md.pop(kwarg.key, kwarg.default_value))

        if len(md) > 0:
            sys.stderr.write(
                'Saved model contained unrecognized attributes %s which are being ignored\n' % sorted(list(md.keys())))

    def _initialize_classifier(self):
        with self.sess.as_default():
            with self.sess.graph.as_default():
                if self.trainable_temp:
                    temp = tf.Variable(self.temp, name='temp')
                else:
                    temp = self.temp
                if self.binary_classifier:
                    if self.relaxed:
                        self.encoding_q = RelaxedBernoulli(temp, logits = self.encoder[:,:self.k])
                    else:
                        self.encoding_q = Bernoulli(logits = self.encoder[:,:self.k], dtype=self.FLOAT_TF)
                    if self.declare_priors:
                        if self.relaxed:
                            self.encoding = RelaxedBernoulli(temp, probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) * 0.5)
                        else:
                            self.encoding = Bernoulli(probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) * 0.5, dtype=self.FLOAT_TF)
                        if self.k:
                            self.inference_map[self.encoding] = self.encoding_q
                    else:
                        self.encoding = self.encoding_q
                else:
                    if self.relaxed:
                        self.encoding_q = RelaxedOneHotCategorical(temp, logits = self.encoder[:,:self.k])
                    else:
                        self.encoding_q = OneHotCategorical(logits = self.encoder[:,:self.k], dtype=self.FLOAT_TF)
                    if self.declare_priors:
                        if self.relaxed:
                            self.encoding = RelaxedOneHotCategorical(temp, probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) / self.k)
                        else:
                            self.encoding = OneHotCategorical(probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) / self.k, dtype=self.FLOAT_TF)
                        if self.k:
                            self.inference_map[self.encoding] = self.encoding_q
                    else:
                        self.encoding = self.encoding_q

    def _initialize_decoder_scale(self):
        with self.sess.as_default():
            with self.sess.graph.as_default():
                dim = self.n_timesteps_output_bwd * self.frame_dim

                if self.decoder_type == 'rnn':

                    decoder_scale = tf.nn.softplus(
                        tf.keras.layers.LSTM(
                            self.frame_dim,
                            recurrent_activation='sigmoid',
                            return_sequences=True,
                            unroll=self.unroll
                        )(self.decoder_in)
                    )

                elif self.decoder_type == 'cnn':
                    assert self.n_timesteps_output_bwd is not None, 'n_timesteps_output must be defined when decoder_type == "cnn"'

                    decoder_scale = tf.layers.dense(self.decoder_in, self.n_timesteps * self.frame_dim)[..., None]
                    decoder_scale = tf.reshape(decoder_scale, (self.batch_len, self.n_timesteps_output_bwd, self.frame_dim, 1))
                    decoder_scale = tf.keras.layers.Conv2D(self.conv_n_filters, self.conv_kernel_size, padding='same', activation='elu')(decoder_scale)
                    decoder_scale = tf.keras.layers.Conv2D(1, self.conv_kernel_size, padding='same', activation='linear')(decoder_scale)
                    decoder_scale = tf.squeeze(decoder_scale, axis=-1)

                elif self.decoder_type in ['dense', 'dense_resnet']:
                    n_classes = int(2 ** self.k) if self.binary_classifier else int(self.k)

                    # First layer
                    decoder_scale = DenseLayer(
                        self.decoder_in,
                        self.n_timesteps_output_bwd * self.frame_dim,
                        self.training,
                        activation=tf.nn.elu,
                        batch_normalize=self.batch_normalize,
                        session=self.sess
                    )

                    # Intermediate layers
                    if self.decoder_type == 'dense':
                        for i in range(1, self.n_layers_decoder - 1):
                            decoder_scale = DenseLayer(
                                decoder_scale,
                                self.n_timesteps_output_bwd * self.frame_dim,
                                self.training,
                                activation=tf.nn.elu,
                                batch_normalize=self.batch_normalize,
                                session=self.sess
                            )
                    else: # self.decoder_type = 'dense_resnet'
                        for i in range(1, self.n_layers_decoder - 1):
                            decoder_scale = DenseResidualLayer(
                                decoder_scale,
                                self.training,
                                units=self.n_timesteps_output_bwd * self.frame_dim,
                                layers_inner=self.resnet_n_layers_inner,
                                activation_inner=tf.nn.elu,
                                activation=None,
                                batch_normalize=self.batch_normalize,
                                session=self.sess
                            )

                    # Last layer
                    if self.n_layers_decoder > 1:
                        decoder_scale = DenseLayer(
                            decoder_scale,
                            self.n_timesteps_output_bwd * self.frame_dim,
                            self.training,
                            activation=None,
                            batch_normalize=False,
                            session=self.sess
                        )

                    # Reshape
                    decoder_scale = tf.reshape(decoder_scale, (self.batch_len, self.n_timesteps_output_bwd, self.frame_dim))

                else:
                    raise ValueError('Decoder type "%s" not supported at this time' %self.decoder_type)

                self.decoder_scale = tf.nn.softplus(decoder_scale) + self.epsilon

    # Override this method to include scale params for output distribution
    def _initialize_decoder(self):
        super(DNNSegBayes, self)._initialize_decoder()
        self._initialize_decoder_scale()

    def _initialize_output_model(self):
        with self.sess.as_default():
            with self.sess.graph.as_default():
                if self.output_scale is None:
                    output_scale = self.decoder_scale
                else:
                    output_scale = self.output_scale

                if self.mv:
                    self.out = MultivariateNormalTriL(
                        loc=tf.layers.Flatten()(self.decoder),
                        scale_tril=output_scale
                    )
                else:
                    self.out = Normal(
                        loc=self.decoder,
                        scale=output_scale
                    )
                if self.normalize_data and self.constrain_output:
                    self.out = TransformedDistribution(
                        self.out,
                        bijector=tf.contrib.distributions.bijectors.Sigmoid()
                    )


    def _initialize_objective(self, n_train):
        n_train_minibatch = self.n_minibatch(n_train)
        minibatch_scale = self.minibatch_scale(n_train)

        with self.sess.as_default():
            with self.sess.graph.as_default():
                if self.mv:
                    y = tf.layers.Flatten()(self.y_bwd)
                    y_mask = tf.layers.Flatten()(self.y_bwd_mask[..., None] * tf.ones_like(self.y_bwd))
                else:
                    y = self.y_bwd
                    y_mask = self.y_bwd_mask

                # Define access points to important layers
                if len(self.inference_map) > 0:
                    self.out_post = ed.copy(self.out, self.inference_map)
                else:
                    self.out_post = self.out
                if self.mv:
                    self.reconst = tf.reshape(self.out_post * y_mask, [-1, self.n_timesteps_output_bwd, self.frame_dim])
                    # self.reconst_mean = tf.reshape(self.out_post.mean() * y_mask, [-1, self.n_timesteps_output, self.frame_dim])
                else:
                    self.reconst = self.out_post * y_mask[..., None]
                    # self.reconst_mean = self.out_post.mean() * y_mask[..., None]

                if len(self.inference_map) > 0:
                    self.encoding_post = ed.copy(self.encoding, self.inference_map)
                    self.labels_post = ed.copy(self.labels, self.inference_map)
                    self.label_probs_post = ed.copy(self.label_probs, self.inference_map)
                else:
                    self.encoding_post = self.encoding
                    self.labels_post = self.labels
                    self.label_probs_post = self.label_probs

                self.llprior = self.out.log_prob(y)
                self.ll_post = self.out_post.log_prob(y)

                self.optim = self._initialize_optimizer(self.optim_name)
                if self.variational():
                    self.inference = getattr(ed, self.inference_name)(self.inference_map, data={self.out: y})
                    if self.mask_padding:
                        self.inference.initialize(
                            n_samples=self.n_samples,
                            n_iter=n_train_minibatch * self.n_iter,
                            # n_print=n_train_minibatch * self.log_freq,
                            n_print=0,
                            logdir=self.outdir + '/tensorboard/edward',
                            log_timestamp=False,
                            scale={self.out: y_mask[...,None] * minibatch_scale},
                            optimizer=self.optim
                        )
                    else:
                        self.inference.initialize(
                            n_samples=self.n_samples,
                            n_iter=n_train_minibatch * self.n_iter,
                            # n_print=n_train_minibatch * self.log_freq,
                            n_print=0,
                            logdir=self.outdir + '/tensorboard/edward',
                            log_timestamp=False,
                            scale={self.out: minibatch_scale},
                            optimizer=self.optim
                        )
                else:
                    raise ValueError('Only variational inferences are supported at this time')

    def set_output_scale(self, output_scale, trainable=False):
        with self.sess.as_default():
            with self.sess.graph.as_default():
                if trainable:
                    self.output_scale = tf.Variable(output_scale, dtype=self.FLOAT_TF)
                    self.sess.run(tf.variables_initializer(self.output_scale))
                else:
                    self.output_scale = tf.constant(output_scale, dtype=self.FLOAT_TF)


    def run_train_step(
            self,
            feed_dict,
            return_loss=True,
            return_reconstructions=False,
            return_labels=False,
            return_label_probs=False,
            return_encoding_entropy=False,
            return_segmentation_probs=False
    ):
        info_dict = self.inference.update(feed_dict)

        out_dict = {}
        if return_loss:
            out_dict['loss'] = info_dict['loss']

        if return_reconstructions or return_labels or return_label_probs:
            to_run = []
            to_run_names = []
            if return_reconstructions:
                to_run.append(self.out_post)
                to_run_names.append('reconst')
            if return_labels:
                to_run.append(self.labels_post)
                to_run_names.append('labels')
            if return_label_probs:
                to_run.append(self.label_probs_post)
                to_run_names.append('label_probs')
            if return_encoding_entropy:
                to_run.append(self.encoding_entropy_mean)
                to_run_names.append('encoding_entropy')
            if self.encoder_type.lower() == 'softhmlstm' and return_segmentation_probs:
                to_run.append(self.segmentation_probs)
                to_run_names.append('segmentation_probs')

            output = self.sess.run(to_run, feed_dict=feed_dict)
            for i, x in enumerate(output):
                out_dict[to_run_names[i]] = x

        return out_dict

    def variational(self):
        """
        Check whether the model uses variational Bayes.

        :return: ``True`` if the model is variational, ``False`` otherwise.
        """
        return self.inference_name in [
            'KLpq',
            'KLqp',
            'ImplicitKLqp',
            'ReparameterizationEntropyKLqp',
            'ReparameterizationKLKLqp',
            'ReparameterizationKLqp',
            'ScoreEntropyKLqp',
            'ScoreKLKLqp',
            'ScoreKLqp',
            'ScoreRBKLqp',
            'WakeSleep'
        ]

    def report_settings(self, indent=0):
        out = super(DNNSegBayes, self).report_settings(indent=indent)
        for kwarg in UNSUPERVISED_WORD_CLASSIFIER_BAYES_INITIALIZATION_KWARGS:
            val = getattr(self, kwarg.key)
            out += ' ' * indent + '  %s: %s\n' %(kwarg.key, "\"%s\"" %val if isinstance(val, str) else val)

        out += '\n'

        return out
Example #23
    def evaluate_loglikelihood(self, X, batch_idx=None):
        """
		This is the ELBO, which is a lower bound on the marginal log-likelihood.
		We perform some local optimization on the new data points to obtain the ELBO of the new data.
		"""
        N = X.shape[0]
        P = X.shape[1]
        K = self.n_components

        # Define new graph conditioned on the posterior global factors
        z_test = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))
        l_test = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if batch_idx is not None and self.n_batches > 0:
            rho_test = tf.matmul(
                tf.concat([
                    z_test,
                    tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                            tf.float32)
                ],
                          axis=1), self.W0)
        else:
            rho_test = tf.matmul(z_test, self.W0)
        rho_test = rho_test / tf.reshape(tf.reduce_sum(rho_test, axis=1),
                                         (-1, 1))  # NxP

        lam_test = Gamma(self.r, self.r / (rho_test * l_test))

        if self.zero_inflation:
            if batch_idx is not None and self.n_batches > 0:
                logit_pi_test = tf.matmul(
                    tf.concat([
                        z_test,
                        tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                                tf.float32)
                    ],
                              axis=1), self.W1)
            else:
                logit_pi_test = tf.matmul(z_test, self.W1)

            pi_test = tf.minimum(
                tf.maximum(tf.nn.sigmoid(logit_pi_test), 1e-7), 1. - 1e-7)
            cat_test = Categorical(
                probs=tf.stack([pi_test, 1. - pi_test], axis=2))

            components_test = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=lam_test)
            ]
            likelihood_test = Mixture(cat=cat_test, components=components_test)
        else:
            likelihood_test = Poisson(rate=lam_test)

        qz_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(z_test.shape)),
                tf.nn.softplus(tf.Variable(1. * tf.ones(z_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        qlam_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(lam_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(lam_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        ql_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(self.mean_llib * tf.ones(l_test.shape)),
                tf.nn.softplus(
                    tf.Variable(
                        np.sqrt(self.std_llib) * tf.ones(l_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if self.zero_inflation:
            inference_local = ed.ReparameterizationKLqp(
                {
                    z_test: qz_test,
                    lam_test: qlam_test,
                    l_test: ql_test
                },
                data={
                    likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.est_qW0,
                    self.W1: self.est_qW1,
                    self.r: self.est_qr
                })
        else:
            inference_local = ed.ReparameterizationKLqp(
                {
                    z_test: qz_test,
                    lam_test: qlam_test,
                    l_test: ql_test
                },
                data={
                    likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.est_qW0,
                    self.r: self.est_qr
                })

        inference_local.run(n_iter=self.test_iterations,
                            n_samples=self.n_mc_samples)

        return -self.sess.run(inference_local.loss,
                              feed_dict={likelihood_test: X.astype('float32')
                                         }) / N
Example #24
# TODO: beta is actually the expm of C: correct this!

import edward as ed
import tensorflow as tf
from edward.models import Normal, InverseGamma, PointMass, Uniform, TransformedDistribution
# simulate data
d = 50
T = 300
X, C, S = MOU_sim(N=d, Sigma=None, mu=0, T=T, connectivity_strength=8.)

# the model
mu = tf.constant(0.)  # Normal(loc=tf.zeros([d]), scale=1.*tf.ones([d]))
beta = Normal(loc=tf.ones([d, d]), scale=2. * tf.ones([d, d]))
ds = tf.contrib.distributions
C = TransformedDistribution(distribution=beta,
                            bijector=ds.bijectors.Exp(),
                            name="LogNormalTransformedDistribution")

noise_proc = InverseGamma(concentration=tf.ones([d]),
                          rate=tf.ones([d]))  # tf.constant(0.1)
noise_obs = tf.constant(0.1)  # InverseGamma(alpha=1.0, beta=1.0)

x = [0] * T
x[0] = Normal(loc=mu, scale=10. * tf.ones([d]))
for n in range(1, T):
    x[n] = Normal(loc=mu + tf.tensordot(C, x[n - 1], axes=[[1], [0]]),
                  scale=noise_proc * tf.ones([d]))
## map inference
#print("setting up distributions")
#qmu = PointMass(params=tf.Variable(tf.zeros([d])))
#qbeta = PointMass(params=tf.Variable(tf.zeros([d,d])))
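A hedged continuation of the commented-out MAP block, assuming `ed.MAP` and that `MOU_sim` returns `X` with shape `(T, d)` so that row `X[n]` lines up with the latent state `x[n]`. The point masses are placed on `C` and `noise_proc`, the random variables the observations actually depend on; this is a sketch, not the author's inference.

# Point-mass (MAP) approximations, kept positive via exp/softplus parameterizations.
qC = PointMass(params=tf.exp(tf.Variable(tf.zeros([d, d]))))
qnoise = PointMass(params=tf.nn.softplus(tf.Variable(tf.ones([d]))))

data = {x[n]: X[n].astype('float32') for n in range(T)}  # assumes X has shape (T, d)

inference = ed.MAP({C: qC, noise_proc: qnoise}, data=data)
inference.run(n_iter=500)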
Example #25
    def define_stochastic_inference(self, N, P, K):
        M = self.minibatch_size

        qz_vars = [
            tf.Variable(tf.ones([N, K]), name='qz_loc'),
            tf.Variable(0.01 * tf.ones([N, K]), name='qz_scale')
        ]
        qlam_vars = [
            tf.Variable(tf.ones([N, P]), name='qlam_loc'),
            tf.Variable(0.01 * tf.ones([N, P]), name='qlam_scale')
        ]
        ql_vars = [
            tf.Variable(self.mean_llib * tf.ones([N, 1]), name='ql_loc'),
            tf.Variable(self.std_llib * tf.ones([N, 1]), name='ql_scale')
        ]
        qlocal_vars = [qz_vars, qlam_vars, ql_vars]
        qlocal_vars = [item for sublist in qlocal_vars for item in sublist]

        self.idx_ph = tf.placeholder(tf.int32, M)
        self.qz = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[0], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[1], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.qlam = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[2], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[3], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.ql = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[4], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[5], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.qW0 = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.W0.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.qW1 = Normal(tf.Variable(tf.zeros(self.W1.shape)),
                          tf.nn.softplus(tf.Variable(tf.ones(self.W1.shape))))
        self.qr = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.r.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.x_ph = tf.placeholder(tf.float32, [M, P])
        inference_global = ed.ReparameterizationKLqp(
            {
                self.r: self.qr,
                self.W0: self.qW0,
                self.W1: self.qW1
            },
            data={
                self.likelihood: self.x_ph,
                self.z: self.qz,
                self.lam: self.qlam,
                self.l: self.ql
            })
        inference_local = ed.ReparameterizationKLqp(
            {
                self.z: self.qz,
                self.lam: self.qlam,
                self.l: self.ql
            },
            data={
                self.likelihood: self.x_ph,
                self.r: self.qr,
                self.W0: self.qW0,
                self.W1: self.qW1
            })

        return inference_local, inference_global
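A hedged sketch of how the two inference objects returned above might be driven; the actual loop lives in `run_stochastic_inference`, which this example does not show, so the `N/M` scaling and the alternating local/global updates below follow Edward's usual stochastic variational inference recipe and are assumptions.

import edward as ed
import numpy as np
import tensorflow as tf

def stochastic_vi_loop_sketch(model, X, inference_local, inference_global, n_iterations=100):
    # Illustrative only: scale minibatch terms by N/M so each update approximates the full ELBO.
    N, M = X.shape[0], model.minibatch_size
    scale = {model.likelihood: float(N) / M, model.z: float(N) / M,
             model.lam: float(N) / M, model.l: float(N) / M}
    inference_local.initialize(scale=scale)
    inference_global.initialize(scale=scale)

    sess = ed.get_session()
    sess.run(tf.global_variables_initializer())

    for _ in range(n_iterations):
        idx = np.random.choice(N, M, replace=False)
        feed = {model.idx_ph: idx, model.x_ph: X[idx].astype('float32')}
        inference_global.update(feed_dict=feed)  # refresh q(r), q(W0), q(W1)
        inference_local.update(feed_dict=feed)   # refresh q(z), q(lam), q(l)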
Example #26
class ZINBayes(BaseEstimator, TransformerMixin):
    def __init__(self,
                 n_components=10,
                 n_mc_samples=1,
                 gene_dispersion=True,
                 zero_inflation=True,
                 scalings=True,
                 batch_correction=False,
                 test_iterations=100,
                 optimizer=None,
                 minibatch_size=None,
                 validation=False,
                 X_test=None):
        self.n_components = n_components
        self.est_X = None
        self.est_L = None
        self.est_Z = None

        self.zero_inflation = zero_inflation
        if zero_inflation:
            print('Considering zero-inflation.')

        self.batch_correction = batch_correction
        if batch_correction:
            print('Performing batch correction.')

        self.scalings = scalings
        if scalings:
            print('Considering cell-specific scalings.')

        self.gene_dispersion = gene_dispersion
        if gene_dispersion:
            print('Considering gene-specific dispersion.')

        self.n_mc_samples = n_mc_samples
        self.test_iterations = test_iterations

        self.optimizer = optimizer
        self.minibatch_size = minibatch_size

        # if validation, use X_test to assess convergence
        self.validation = validation and X_test is not None
        self.X_test = X_test
        self.loss_dict = {'t_loss': [], 'v_loss': []}

        sess = ed.get_session()
        sess.close()
        tf.reset_default_graph()

    def close_session(self):
        return self.sess.close()

    def fit(self, X, batch_idx=None, max_iter=100, max_time=60):
        tf.reset_default_graph()

        # Data size
        N = X.shape[0]
        P = X.shape[1]

        if not self.batch_correction:
            batch_idx = None

        # Number of experimental batches
        if batch_idx is not None:
            self.n_batches = np.unique(batch_idx[:, 0]).size
        else:
            self.n_batches = 0

        # Prior for cell scalings
        log_library_size = np.log(np.sum(X, axis=1))
        self.mean_llib, self.std_llib = np.mean(log_library_size), np.std(
            log_library_size)

        if self.minibatch_size is not None:
            # Create ZINBayes computation graph
            self.define_stochastic_model(P, self.n_components)
            inference_local, inference_global = self.define_stochastic_inference(
                N, P, self.n_components)

            self.run_stochastic_inference(X,
                                          inference_local,
                                          inference_global,
                                          n_iterations=max_iter)
        else:
            # Create ZINBayes computation graph
            self.define_model(N, P, self.n_components, batch_idx=batch_idx)
            self.inference = self.define_inference(X)

            # If we want to assess convergence during inference on held-out data
            inference_val = None
            if self.validation and self.X_test is not None:
                self.define_val_model(self.X_test.shape[0], P,
                                      self.n_components)
                inference_val = self.define_val_inference(self.X_test)

            # Run inference
            self.loss = self.run_inference(self.inference,
                                           inference_val=inference_val,
                                           n_iterations=max_iter)

        # Get estimated variational distributions of global latent variables
        self.est_qW0 = TransformedDistribution(
            distribution=Normal(self.qW0.distribution.loc.eval(),
                                self.qW0.distribution.scale.eval()),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.est_qr = TransformedDistribution(
            distribution=Normal(self.qr.distribution.loc.eval(),
                                self.qr.distribution.scale.eval()),
            bijector=tf.contrib.distributions.bijectors.Exp())
        if self.zero_inflation:
            self.est_qW1 = Normal(self.qW1.loc.eval(), self.qW1.scale.eval())

    def transform(self):
        if self.minibatch_size is None:
            self.est_Z = self.sess.run(tf.exp(self.qz.distribution.loc))
            if self.scalings:
                self.est_L = self.sess.run(tf.exp(self.ql.distribution.loc))
            self.est_X = self.posterior_nb_mean()

            return self.est_Z

    def fit_transform(self, X, batch_idx=None, max_iter=100, max_time=60):
        self.fit(X, batch_idx=batch_idx, max_iter=max_iter, max_time=max_time)
        return self.transform()

    def get_est_X(self):
        return self.est_X

    def get_est_l(self):
        return self.est_L

    def score(self, X, batch_idx=None):
        return self.evaluate_loglikelihood(X, batch_idx=batch_idx)

    # def define_model(self, N, P, K, batch_idx=None):
    # 	self.W0 = Gamma(.1 * tf.ones([K + self.n_batches, P]), .3 * tf.ones([K + self.n_batches, P]))

    # 	self.z = Gamma(16. * tf.ones([N, K]), 4. * tf.ones([N, K]))

    # 	self.a = Gamma(2. * tf.ones([1,P]), 1. * tf.ones([1,P]))
    # 	self.r = Gamma(self.a, self.a)

    # 	self.l = Gamma(self.mean_llib**2 / self.std_llib**2 * tf.ones([N, 1]), self.mean_llib / self.std_llib**2 * tf.ones([N, 1]))

    # 	rho = tf.matmul(self.z, self.W0)

    # 	self.likelihood = Poisson(rate=self.r*rho)

    def define_model(self, N, P, K, batch_idx=None):
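        # Generative model (as coded below): W0 ~ Gamma are non-negative gene
        # loadings, z ~ Gamma are per-cell factors, r ~ Gamma is the shared or
        # gene-specific dispersion, and l is a log-normal library-size factor
        # built as TransformedDistribution(Normal, Exp).  The Gamma-Poisson
        # compound on lam yields negative-binomial counts; with zero-inflation
        # the likelihood becomes a two-component mixture of a near-zero Poisson
        # and Poisson(lam).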
        self.W0 = Gamma(.1 * tf.ones([K + self.n_batches, P]),
                        .3 * tf.ones([K + self.n_batches, P]))
        if self.zero_inflation:
            self.W1 = Normal(tf.zeros([K + self.n_batches, P]),
                             tf.ones([K + self.n_batches, P]))

        self.z = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))

        disp_size = 1
        if self.gene_dispersion:
            disp_size = P
        self.r = Gamma(2. * tf.ones([disp_size]), 1. * tf.ones([disp_size]))

        self.l = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if batch_idx is not None and self.n_batches > 0:
            self.rho = tf.matmul(
                tf.concat([
                    self.z,
                    tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                            tf.float32)
                ],
                          axis=1), self.W0)
        else:
            self.rho = tf.matmul(self.z, self.W0)

        if self.scalings:
            self.rho = self.rho / tf.reshape(tf.reduce_sum(self.rho, axis=1),
                                             (-1, 1))  # NxP
            self.lam = Gamma(self.r, self.r / (self.rho * self.l))
        else:
            self.lam = Gamma(self.r, self.r / self.rho)

        if self.zero_inflation:
            if batch_idx is not None and self.n_batches > 0:
                self.logit_pi = tf.matmul(
                    tf.concat([
                        self.z,
                        tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                                tf.float32)
                    ],
                              axis=1), self.W1)
            else:
                self.logit_pi = tf.matmul(self.z, self.W1)
            self.pi = tf.minimum(
                tf.maximum(tf.nn.sigmoid(self.logit_pi), 1e-7), 1. - 1e-7)

            self.cat = Categorical(
                probs=tf.stack([self.pi, 1. - self.pi], axis=2))

            self.components = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=self.lam)
            ]

            self.likelihood = Mixture(cat=self.cat, components=self.components)
        else:
            self.likelihood = Poisson(rate=self.lam)

    # def define_inference(self, X):
    # 	# Local latent variables
    # 	# self.qz = lognormal_q(self.z.shape)
    # 	self.qz = TransformedDistribution(
    # 	distribution=Normal(tf.Variable(tf.ones(self.z.shape)), tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.z.shape)))),
    # 	bijector=tf.contrib.distributions.bijectors.Exp())

    # 	self.ql = TransformedDistribution(
    # 	distribution=Normal(tf.Variable(self.mean_llib * tf.ones(self.l.shape)), tf.nn.softplus(tf.Variable(self.std_llib * tf.ones(self.l.shape)))),
    # 	bijector=tf.contrib.distributions.bijectors.Exp())

    # 	self.qr = TransformedDistribution(
    # 	distribution=Normal(tf.Variable(tf.ones(self.r.shape)), tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
    # 	bijector=tf.contrib.distributions.bijectors.Exp())

    # 	self.qa = TransformedDistribution(
    # 	distribution=Normal(tf.Variable(tf.ones(self.a.shape)), tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.a.shape)))),
    # 	bijector=tf.contrib.distributions.bijectors.Exp())

    # 	self.qW0 = TransformedDistribution(
    # 	distribution=Normal(tf.Variable(tf.ones(self.W0.shape)), tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
    # 	bijector=tf.contrib.distributions.bijectors.Exp())

    # 	latent_vars_dict = {self.z: self.qz, self.a: self.qa,  self.r: self.qr,  self.W0: self.qW0}

    # 	inference = ed.ReparameterizationKLqp(latent_vars_dict, data={self.likelihood: tf.cast(X, tf.float32)})

    # 	return inference

    def define_inference(self, X):
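        # Mean-field variational family: each positive latent variable gets a
        # log-normal factor, built as TransformedDistribution(Normal, Exp) with
        # free loc/scale variables.  The commented-out `lognormal_q` calls refer
        # to a helper whose definition is not shown in this file; a minimal
        # sketch of such a helper (an assumption, not the original code) would
        # be:
        #
        #     def lognormal_q(shape):
        #         loc = tf.Variable(tf.ones(shape))
        #         scale = tf.nn.softplus(tf.Variable(0.01 * tf.ones(shape)))
        #         return TransformedDistribution(
        #             distribution=Normal(loc, scale),
        #             bijector=tf.contrib.distributions.bijectors.Exp())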
        # Local latent variables
        # self.qz = lognormal_q(self.z.shape)
        self.qz = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.z.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.z.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.qlam = lognormal_q(self.lam.shape)
        self.qlam = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.lam.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.lam.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # Global latent variables
        # self.qr = lognormal_q(self.r.shape)
        self.qr = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.r.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.qW0 = lognormal_q(self.W0.shape)
        self.qW0 = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.W0.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        latent_vars_dict = {
            self.z: self.qz,
            self.lam: self.qlam,
            self.r: self.qr,
            self.W0: self.qW0
        }

        if self.zero_inflation:
            self.qW1 = Normal(
                tf.Variable(tf.zeros(self.W1.shape)),
                tf.nn.softplus(tf.Variable(0.1 * tf.ones(self.W1.shape))))
            latent_vars_dict[self.W1] = self.qW1

        if self.scalings:
            # self.ql = lognormal_q(self.l.shape)
            self.ql = TransformedDistribution(
                distribution=Normal(
                    tf.Variable(self.mean_llib * tf.ones(self.l.shape)),
                    tf.nn.softplus(
                        tf.Variable(self.std_llib * tf.ones(self.l.shape)))),
                bijector=tf.contrib.distributions.bijectors.Exp())
            latent_vars_dict[self.l] = self.ql

        inference = ed.ReparameterizationKLqp(
            latent_vars_dict, data={self.likelihood: tf.cast(X, tf.float32)})

        return inference

    def define_val_model(self, N, P, K):
        # Define new graph
        self.z_test = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))
        self.l_test = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        rho_test = tf.matmul(self.z_test, self.W0)
        rho_test = rho_test / tf.reshape(tf.reduce_sum(rho_test, axis=1),
                                         (-1, 1))  # NxP

        self.lam_test = Gamma(self.r, self.r / (rho_test * self.l_test))

        if self.zero_inflation:
            logit_pi_test = tf.matmul(self.z_test, self.W1)

            pi_test = tf.minimum(
                tf.maximum(tf.nn.sigmoid(logit_pi_test), 1e-7), 1. - 1e-7)
            cat_test = Categorical(
                probs=tf.stack([pi_test, 1. - pi_test], axis=2))

            components_test = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=self.lam_test)
            ]
            self.likelihood_test = Mixture(cat=cat_test,
                                           components=components_test)
        else:
            self.likelihood_test = Poisson(rate=self.lam_test)

    def define_val_inference(self, X):
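        # Held-out inference: the trained global posteriors (qW0, qW1, qr) are
        # passed through `data`, so only the local factors of the validation
        # cells (qz_test, qlam_test, ql_test) are optimized.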
        self.qz_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.z_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 *
                                           tf.ones(self.z_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.qlam = lognormal_q(self.lam.shape)
        self.qlam_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.lam_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 *
                                           tf.ones(self.lam_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        # self.ql = lognormal_q(self.l.shape)
        self.ql_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(self.mean_llib * tf.ones(self.l_test.shape)),
                tf.nn.softplus(
                    tf.Variable(
                        np.sqrt(self.std_llib) * tf.ones(self.l_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if self.zero_inflation:
            inference = ed.ReparameterizationKLqp(
                {
                    self.z_test: self.qz_test,
                    self.lam_test: self.qlam_test,
                    self.l_test: self.ql_test
                },
                data={
                    self.likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.qW0,
                    self.W1: self.qW1,
                    self.r: self.qr
                })
        else:
            inference = ed.ReparameterizationKLqp(
                {
                    self.z_test: self.qz_test,
                    self.lam_test: self.qlam_test,
                    self.l_test: self.ql_test
                },
                data={
                    self.likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.qW0,
                    self.r: self.qr
                })

        return inference

    # def run_inference(self, inference, inference_val=None, n_iterations=1000):
    # 	N = self.l.shape.as_list()[0]
    # 	self.inference_e.initialize(n_iter=n_iterations, n_samples=self.n_mc_samples, optimizer=self.optimizer)
    # 	self.inference_m.initialize(optimizer=self.optimizer)

    # 	self.sess = ed.get_session()
    # 	tf.global_variables_initializer().run()

    # 	for i in range(self.inference_e.n_iter):
    # 		info_dict = self.inference_e.update()
    # 		self.inference_m.update()
    # 		info_dict['loss'] = info_dict['loss'] / N
    # 		self.inference_e.print_progress(info_dict)
    # 		self.loss_dict['t_loss'].append(info_dict["loss"])

    # 	self.inference_e.finalize()
    # 	self.inference_m.finalize()

    def run_inference(self, inference, inference_val=None, n_iterations=1000):
        N = self.l.shape.as_list()[0]
        inference.initialize(n_iter=n_iterations,
                             n_samples=self.n_mc_samples,
                             optimizer=self.optimizer)

        if inference_val is not None:
            N_val = self.l_test.shape.as_list()[0]
            inference_val.initialize(n_samples=self.n_mc_samples,
                                     optimizer=self.optimizer)

        self.sess = ed.get_session()
        tf.global_variables_initializer().run()

        for i in range(inference.n_iter):
            info_dict = inference.update()
            info_dict['loss'] = info_dict['loss'] / N
            inference.print_progress(info_dict)
            self.loss_dict['t_loss'].append(info_dict["loss"])

            if inference_val is not None:
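                # Re-fit only the held-out local factors at each step: reset
                # the validation optimizer, re-initialize q(l), q(z), q(lam)
                # for the test cells, then take a few local updates with the
                # current global posteriors.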
                self.sess.run(inference_val.reset)
                self.sess.run(
                    tf.variables_initializer(self.ql_test.get_variables()))
                self.sess.run(
                    tf.variables_initializer(self.qz_test.get_variables()))
                self.sess.run(
                    tf.variables_initializer(self.qlam_test.get_variables()))
                for _ in range(5):
                    val_info_dict = inference_val.update()
                self.loss_dict['v_loss'].append(val_info_dict['loss'] / N_val)

        inference.finalize()

        if inference_val is not None:
            inference_val.finalize()

    def evaluate_loglikelihood(self, X, batch_idx=None):
        """
		This is the ELBO, which is a lower bound on the marginal log-likelihood.
		We perform some local optimization on the new data points to obtain the ELBO of the new data.
		"""
        N = X.shape[0]
        P = X.shape[1]
        K = self.n_components

        # Define new graph conditioned on the posterior global factors
        z_test = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))
        l_test = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if batch_idx is not None and self.n_batches > 0:
            rho_test = tf.matmul(
                tf.concat([
                    z_test,
                    tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                            tf.float32)
                ],
                          axis=1), self.W0)
        else:
            rho_test = tf.matmul(z_test, self.W0)
        rho_test = rho_test / tf.reshape(tf.reduce_sum(rho_test, axis=1),
                                         (-1, 1))  # NxP

        lam_test = Gamma(self.r, self.r / (rho_test * l_test))

        if self.zero_inflation:
            if batch_idx is not None and self.n_batches > 0:
                logit_pi_test = tf.matmul(
                    tf.concat([
                        z_test,
                        tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                                tf.float32)
                    ],
                              axis=1), self.W1)
            else:
                logit_pi_test = tf.matmul(z_test, self.W1)

            pi_test = tf.minimum(
                tf.maximum(tf.nn.sigmoid(logit_pi_test), 1e-7), 1. - 1e-7)
            cat_test = Categorical(
                probs=tf.stack([pi_test, 1. - pi_test], axis=2))

            components_test = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=lam_test)
            ]
            likelihood_test = Mixture(cat=cat_test, components=components_test)
        else:
            likelihood_test = Poisson(rate=lam_test)

        qz_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(z_test.shape)),
                tf.nn.softplus(tf.Variable(1. * tf.ones(z_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        qlam_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(lam_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(lam_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        ql_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(self.mean_llib * tf.ones(l_test.shape)),
                tf.nn.softplus(
                    tf.Variable(
                        np.sqrt(self.std_llib) * tf.ones(l_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if self.zero_inflation:
            inference_local = ed.ReparameterizationKLqp(
                {
                    z_test: qz_test,
                    lam_test: qlam_test,
                    l_test: ql_test
                },
                data={
                    likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.est_qW0,
                    self.W1: self.est_qW1,
                    self.r: self.est_qr
                })
        else:
            inference_local = ed.ReparameterizationKLqp(
                {
                    z_test: qz_test,
                    lam_test: qlam_test,
                    l_test: ql_test
                },
                data={
                    likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.est_qW0,
                    self.r: self.est_qr
                })

        inference_local.run(n_iter=self.test_iterations,
                            n_samples=self.n_mc_samples)

        return -self.sess.run(
            inference_local.loss,
            feed_dict={likelihood_test: X.astype('float32')}) / N

    def posterior_nb_mean(self):
        est_rho = self.sess.run(self.rho,
                                feed_dict={
                                    self.z:
                                    np.exp(self.qz.distribution.loc.eval()),
                                    self.W0:
                                    np.exp(self.qW0.distribution.loc.eval())
                                })
        est_l = 1
        if self.scalings:
            est_l = np.exp(self.ql.distribution.loc.eval())
        est_mean = est_rho * est_l

        return est_mean

    def define_stochastic_model(self, P, K):
        M = self.minibatch_size

        self.W0 = Gamma(0.1 * tf.ones([K, P]), 0.3 * tf.ones([K, P]))
        if self.zero_inflation:
            self.W1 = Normal(tf.zeros([K, P]), tf.ones([K, P]))

        self.z = Gamma(2. * tf.ones([M, K]), 1. * tf.ones([M, K]))

        self.r = Gamma(2. * tf.ones([P]), 1. * tf.ones([P]))

        self.l = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([M, 1]),
                                self.std_llib * tf.ones([M, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.rho = tf.matmul(self.z, self.W0)
        self.rho = self.rho / tf.reshape(tf.reduce_sum(self.rho, axis=1),
                                         (-1, 1))  # NxP

        self.lam = Gamma(self.r, self.r / (self.rho * self.l))

        if self.zero_inflation:
            self.logit_pi = tf.matmul(self.z, self.W1)
            self.pi = tf.minimum(
                tf.maximum(tf.nn.sigmoid(self.logit_pi), 1e-7), 1. - 1e-7)

            self.cat = Categorical(
                probs=tf.stack([self.pi, 1. - self.pi], axis=2))

            self.components = [
                Poisson(rate=1e-30 * tf.ones([M, P])),
                Poisson(rate=self.lam)
            ]

            self.likelihood = Mixture(cat=self.cat, components=self.components)
        else:
            self.likelihood = Poisson(rate=self.lam)

    def define_stochastic_inference(self, N, P, K):
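        # Data subsampling without amortization: full-size (N-row) variational
        # parameter tables are kept for the local factors, and the rows of the
        # current minibatch are selected with tf.gather via idx_ph.  Inference
        # is split into a local step (updates qz, qlam, ql) and a global step
        # (updates qr, qW0 and, when zero-inflated, qW1).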
        M = self.minibatch_size

        qz_vars = [
            tf.Variable(tf.ones([N, K]), name='qz_loc'),
            tf.Variable(0.01 * tf.ones([N, K]), name='qz_scale')
        ]
        qlam_vars = [
            tf.Variable(tf.ones([N, P]), name='qlam_loc'),
            tf.Variable(0.01 * tf.ones([N, P]), name='qlam_scale')
        ]
        ql_vars = [
            tf.Variable(self.mean_llib * tf.ones([N, 1]), name='ql_loc'),
            tf.Variable(self.std_llib * tf.ones([N, 1]), name='ql_scale')
        ]
        qlocal_vars = [qz_vars, qlam_vars, ql_vars]
        qlocal_vars = [item for sublist in qlocal_vars for item in sublist]

        self.idx_ph = tf.placeholder(tf.int32, [M])
        self.qz = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[0], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[1], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.qlam = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[2], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[3], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.ql = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[4], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[5], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.qW0 = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.W0.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.qr = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.r.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # W1 (and therefore qW1) only exists when the model is zero-inflated.
        global_latent_vars = {self.r: self.qr, self.W0: self.qW0}
        if self.zero_inflation:
            self.qW1 = Normal(
                tf.Variable(tf.zeros(self.W1.shape)),
                tf.nn.softplus(tf.Variable(tf.ones(self.W1.shape))))
            global_latent_vars[self.W1] = self.qW1

        self.x_ph = tf.placeholder(tf.float32, [M, P])
        local_latent_vars = {self.z: self.qz, self.lam: self.qlam, self.l: self.ql}

        global_data = {self.likelihood: self.x_ph}
        global_data.update(local_latent_vars)
        local_data = {self.likelihood: self.x_ph}
        local_data.update(global_latent_vars)

        inference_global = ed.ReparameterizationKLqp(global_latent_vars,
                                                     data=global_data)
        inference_local = ed.ReparameterizationKLqp(local_latent_vars,
                                                    data=local_data)

        return inference_local, inference_global

    def run_stochastic_inference(self,
                                 X,
                                 inference_local,
                                 inference_global,
                                 n_iterations=100):
        M = self.minibatch_size
        N = X.shape[0]

        # Run inference
        inference_global.initialize(
            scale={
                self.likelihood: float(N) / M,
                self.z: float(N) / M,
                self.lam: float(N) / M,
                self.l: float(N) / M
            })
        inference_local.initialize(
            scale={
                self.likelihood: float(N) / M,
                self.z: float(N) / M,
                self.lam: float(N) / M,
                self.l: float(N) / M
            })

        self.sess = ed.get_session()
        tf.global_variables_initializer().run()

        self.loss = []

        n_iter_per_epoch = N // M
        pbar = ed.Progbar(n_iterations)
        for epoch in range(n_iterations):
            #     print("Epoch: {0}".format(epoch))
            avg_loss = 0.0

            for t in range(1, n_iter_per_epoch + 1):
                x_batch, idx_batch = next_batch(X, M)

                #         inference_local.update(feed_dict={x_ph: x_batch})
                for _ in range(5):  # make local inferences
                    info_dict = inference_local.update(feed_dict={
                        self.x_ph: x_batch,
                        self.idx_ph: idx_batch
                    })

                info_dict = inference_global.update(feed_dict={
                    self.x_ph: x_batch,
                    self.idx_ph: idx_batch
                })
                avg_loss += info_dict['loss']

            # Print a lower bound to the average marginal likelihood for a cell
            avg_loss /= n_iter_per_epoch
            avg_loss /= M

            #     print("-log p(x) <= {:0.3f}\n".format(avg_loss), end='\r')
            self.loss.append(avg_loss)
            pbar.update(epoch, values={'Loss': avg_loss})

        inference_global.finalize()
        inference_local.finalize()
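
# A minimal usage sketch, kept as comments because the constructor's defaults
# are not shown here; it assumes the class above is named `ZINBayes` (as its
# inline comments suggest) and that `counts` is a cells-by-genes count matrix:
#
#     import numpy as np
#     counts = np.random.poisson(5., size=(200, 50))
#     model = ZINBayes(n_components=10, zero_inflation=True, scalings=True)
#     latent = model.fit_transform(counts, max_iter=100)  # N x K cell factors
#     denoised = model.get_est_X()                        # posterior NB mean
#     elbo_per_cell = model.score(counts)
#     model.close_session()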
Exemplo n.º 27
0
x_ph = tf.placeholder(tf.float32, [M, D])

U = Gamma(0.1, 0.5, sample_shape=[M, K])
V = Gamma(0.1, 0.3, sample_shape=[D, K])
x = Poisson(tf.matmul(U, V, transpose_b=True))

min_scale = 1e-5

qV_variables = [
    tf.Variable(tf.random_uniform([D, K])),
    tf.Variable(tf.random_uniform([D, K]))
]

qV = TransformedDistribution(
    distribution=Normal(qV_variables[0],
                        tf.maximum(tf.nn.softplus(qV_variables[1]), min_scale)),
    bijector=tf.contrib.distributions.bijectors.Exp())

qU_variables = [
    tf.Variable(tf.random_uniform([N, K])),
    tf.Variable(tf.random_uniform([N, K]))
]


qU = TransformedDistribution(
    distribution=Normal(
        tf.gather(qU_variables[0], idx_ph),
        tf.maximum(tf.nn.softplus(tf.gather(qU_variables[1], idx_ph)),
                   min_scale)),
    bijector=tf.contrib.distributions.bijectors.Exp())
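
# The snippet above defines the model and the subsampled variational factors
# but not the inference wiring.  A sketch of one way to complete it, assuming
# N, D, K, M, idx_ph and the training matrix X_train are defined as in the
# ZINBayes example, and reusing the `next_batch` minibatch helper used there
# (its definition is not shown in this file):
inference = ed.ReparameterizationKLqp({U: qU, V: qV}, data={x: x_ph})
# Scale the minibatch terms so each step is an unbiased estimate of the
# full-data ELBO.
inference.initialize(n_iter=1000,
                     n_samples=5,
                     scale={x: float(N) / M, U: float(N) / M})

sess = ed.get_session()  # installs a default session for .run() below
tf.global_variables_initializer().run()
for _ in range(inference.n_iter):
    x_batch, idx_batch = next_batch(X_train, M)
    inference.update(feed_dict={x_ph: x_batch, idx_ph: idx_batch})
inference.finalize()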
Exemplo n.º 28
0
    def define_model(self, N, P, K, batch_idx=None):
        self.W0 = Gamma(.1 * tf.ones([K + self.n_batches, P]),
                        .3 * tf.ones([K + self.n_batches, P]))
        if self.zero_inflation:
            self.W1 = Normal(tf.zeros([K + self.n_batches, P]),
                             tf.ones([K + self.n_batches, P]))

        self.z = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))

        disp_size = 1
        if self.gene_dispersion:
            disp_size = P
        self.r = Gamma(2. * tf.ones([disp_size]), 1. * tf.ones([disp_size]))

        self.l = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if batch_idx is not None and self.n_batches > 0:
            self.rho = tf.matmul(
                tf.concat([
                    self.z,
                    tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                            tf.float32)
                ],
                          axis=1), self.W0)
        else:
            self.rho = tf.matmul(self.z, self.W0)

        if self.scalings:
            self.rho = self.rho / tf.reshape(tf.reduce_sum(self.rho, axis=1),
                                             (-1, 1))  # NxP
            self.lam = Gamma(self.r, self.r / (self.rho * self.l))
        else:
            self.lam = Gamma(self.r, self.r / self.rho)

        if self.zero_inflation:
            if batch_idx is not None and self.n_batches > 0:
                self.logit_pi = tf.matmul(
                    tf.concat([
                        self.z,
                        tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches),
                                tf.float32)
                    ],
                              axis=1), self.W1)
            else:
                self.logit_pi = tf.matmul(self.z, self.W1)
            self.pi = tf.minimum(
                tf.maximum(tf.nn.sigmoid(self.logit_pi), 1e-7), 1. - 1e-7)

            self.cat = Categorical(
                probs=tf.stack([self.pi, 1. - self.pi], axis=2))

            self.components = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=self.lam)
            ]

            self.likelihood = Mixture(cat=self.cat, components=self.components)
        else:
            self.likelihood = Poisson(rate=self.lam)