Example #1
0
 def _free_energy(self, v):
     """Build a graph estimating the free energy of visible configurations.

     The estimate is the batch mean of two terms:
       * the visible-bias term ``-v . vb``;
       * the interaction term ``-(v W) . h_hat``, where ``h_hat`` is one draw
         from a multinomial with ``n_samples`` trials over ``n_hidden``
         equally likely hidden units (all logits equal).
     A log-gamma term ``-lgamma(M + K) + lgamma(M + 1) + lgamma(K)`` is then
     added, with ``M = n_samples`` and ``K = n_hidden``.

     Parameters
     ----------
     v : tf.Tensor
         Visible units. The einsum ``'ij,j->i'`` requires rank 2; presumably
         shaped (batch, n_visible) — TODO confirm against callers.

     Returns
     -------
     tf.Tensor
         Scalar tensor with the estimated free energy.
     """
     K = float(self.n_hidden)
     M = float(self.n_samples)
     with tf.name_scope('free_energy'):
         T1 = -tf.einsum('ij,j->i', v, self._vb)
         T2 = -tf.matmul(v, self._W)
         # Tensor shapes must be integers: ``tf.ones([K])`` with the float K
         # raises a TypeError, so the shape is built from the integer count.
         h_hat = Multinomial(total_count=M,
                             logits=tf.ones([int(self.n_hidden)])).sample()
         T3 = tf.einsum('ij,j->i', T2, h_hat)
         fe = tf.reduce_mean(T1 + T3, axis=0)
         fe += -tf.lgamma(M + K) + tf.lgamma(M + 1) + tf.lgamma(K)
     return fe
def multinomial(policy, game_state):
    """Return a single-trial multinomial over the empty cells of a 9-cell board.

    A cell of ``game_state`` equal to zero counts as free. The policy weights
    of occupied cells are zeroed, and the result is divided by the sum over
    the entire masked tensor plus 1e-5 (to avoid dividing by zero).

    Parameters
    ----------
    policy : tf.Tensor
        Per-cell policy weights; presumably shape (1, 9) — TODO confirm.
    game_state : tf.Tensor
        Board states with 9 cells per row; compared element-wise against a
        (1, 9) zero tensor.

    Returns
    -------
    Multinomial
        Distribution with ``total_count=1.`` and the masked probabilities.
    """
    # 1.0 for an empty cell, 0.0 for an occupied one.
    empty_mask = tf.to_float(tf.equal(game_state, tf.zeros((1, 9))))

    def _to_diagonal(row):
        # One 9-entry mask row -> 9x9 diagonal selection matrix.
        return tf.diag(tf.reshape(row, (9, )))

    def _masked_policy(selector):
        # Keep only the policy entries selected by the diagonal mask.
        return tf.transpose(tf.matmul(selector, tf.transpose(policy)))

    selectors = tf.map_fn(_to_diagonal, empty_mask)
    masked = tf.map_fn(_masked_policy, selectors)

    # NOTE(review): the normaliser sums over every element of ``masked``, not
    # per board, so rows only sum to ~1 when a single board is passed in.
    denominator = tf.reduce_sum(masked) + tf.constant(1e-5)
    probabilities = masked / denominator

    return Multinomial(total_count=1., probs=probabilities)
Example #3
0
    def __call__(self, session, trainX, trainY, testX, testY):
        """ Initialize the actual graph

        Builds the minibatch sampler, the MAP log-loss, a hold-out
        cross-validation metric, TensorBoard summaries and the clipped Adam
        training op, then runs the global variable initializer.

        Parameters
        ----------
        session : tf.Session
            Tensorflow session
        trainX : sparse array in coo format
            Training input OTU table, where rows are samples and columns are
            observations
        trainY : np.array
            Training output metabolite table
        testX : sparse array in coo format
            Test input OTU table, where rows are samples and columns are
            observations.  This is mainly for cross validation.
        testY : np.array
            Test output metabolite table.  This is mainly for cross validation.
        """
        self.session = session
        self.nnz = len(trainX.data)
        self.d1 = trainX.shape[1]
        self.d2 = trainY.shape[1]
        self.cv_size = len(testX.data)

        # keep the multinomial sampling on the cpu
        # https://github.com/tensorflow/tensorflow/issues/18058
        with tf.device('/cpu:0'):
            X_ph = tf.SparseTensor(indices=np.array([trainX.row,
                                                     trainX.col]).T,
                                   values=trainX.data,
                                   dense_shape=trainX.shape)
            Y_ph = tf.constant(trainY, dtype=tf.float32)

            X_holdout = tf.SparseTensor(indices=np.array(
                [testX.row, testX.col]).T,
                                        values=testX.data,
                                        dense_shape=testX.shape)
            Y_holdout = tf.constant(testY, dtype=tf.float32)

            total_count = tf.reduce_sum(Y_ph, axis=1)
            # Pick nonzero entries of trainX with probability proportional to
            # their value (log-values are used as logits).
            batch_ids = tf.multinomial(
                tf.log(tf.reshape(X_ph.values, [1, -1])), self.batch_size)
            # Squeeze only the leading axis: a bare squeeze turns a
            # batch_size of 1 into a scalar and breaks the matmuls below.
            batch_ids = tf.squeeze(batch_ids, axis=0)
            X_samples = tf.gather(X_ph.indices, 0, axis=1)
            X_obs = tf.gather(X_ph.indices, 1, axis=1)
            sample_ids = tf.gather(X_samples, batch_ids)

            Y_batch = tf.gather(Y_ph, sample_ids)
            X_batch = tf.gather(X_obs, batch_ids)

        with tf.device(self.device_name):
            self.qUmain = tf.Variable(tf.random_normal([self.d1, self.p]),
                                      name='qU')
            self.qUbias = tf.Variable(tf.random_normal([self.d1, 1]),
                                      name='qUbias')
            self.qVmain = tf.Variable(tf.random_normal([self.p, self.d2 - 1]),
                                      name='qV')
            self.qVbias = tf.Variable(tf.random_normal([1, self.d2 - 1]),
                                      name='qVbias')

            # The ones column/row pair each bias with a constant, so
            # qU @ qV = qVbias + qUbias + qUmain @ qVmain.
            qU = tf.concat([tf.ones([self.d1, 1]), self.qUbias, self.qUmain],
                           axis=1)
            qV = tf.concat(
                [self.qVbias,
                 tf.ones([1, self.d2 - 1]), self.qVmain], axis=0)

            # regression coefficients distribution
            Umain = Normal(loc=tf.zeros([self.d1, self.p]) + self.u_mean,
                           scale=tf.ones([self.d1, self.p]) * self.u_scale,
                           name='U')
            Ubias = Normal(loc=tf.zeros([self.d1, 1]) + self.u_mean,
                           scale=tf.ones([self.d1, 1]) * self.u_scale,
                           name='biasU')

            Vmain = Normal(loc=tf.zeros([self.p, self.d2 - 1]) + self.v_mean,
                           scale=tf.ones([self.p, self.d2 - 1]) * self.v_scale,
                           name='V')
            Vbias = Normal(loc=tf.zeros([1, self.d2 - 1]) + self.v_mean,
                           scale=tf.ones([1, self.d2 - 1]) * self.v_scale,
                           name='biasV')

            du = tf.gather(qU, X_batch, axis=0, name='du')
            # The first output category is the reference: its logit is 0.
            dv = tf.concat([tf.zeros([self.batch_size, 1]), du @ qV],
                           axis=1,
                           name='dv')

            tc = tf.gather(total_count, sample_ids)
            Y = Multinomial(total_count=tc, logits=dv, name='Y')
            num_samples = trainX.shape[0]
            # Rescale the minibatch likelihood to the full data size.
            norm = num_samples / self.batch_size
            logprob_vmain = tf.reduce_sum(Vmain.log_prob(self.qVmain),
                                          name='logprob_vmain')
            logprob_vbias = tf.reduce_sum(Vbias.log_prob(self.qVbias),
                                          name='logprob_vbias')
            logprob_umain = tf.reduce_sum(Umain.log_prob(self.qUmain),
                                          name='logprob_umain')
            logprob_ubias = tf.reduce_sum(Ubias.log_prob(self.qUbias),
                                          name='logprob_ubias')
            logprob_y = tf.reduce_sum(Y.log_prob(Y_batch), name='logprob_y')
            self.log_loss = -(logprob_y * norm + logprob_umain +
                              logprob_ubias + logprob_vmain + logprob_vbias)

        # keep the multinomial sampling on the cpu
        # https://github.com/tensorflow/tensorflow/issues/18058
        with tf.device('/cpu:0'):
            # cross validation
            with tf.name_scope('accuracy'):
                cv_batch_ids = tf.multinomial(
                    tf.log(tf.reshape(X_holdout.values, [1, -1])),
                    self.cv_size)
                # Same leading-axis-only squeeze as for the training batch.
                cv_batch_ids = tf.squeeze(cv_batch_ids, axis=0)
                X_cv_samples = tf.gather(X_holdout.indices, 0, axis=1)
                X_cv = tf.gather(X_holdout.indices, 1, axis=1)
                cv_sample_ids = tf.gather(X_cv_samples, cv_batch_ids)

                Y_cvbatch = tf.gather(Y_holdout, cv_sample_ids)
                X_cvbatch = tf.gather(X_cv, cv_batch_ids)
                holdout_count = tf.reduce_sum(Y_cvbatch, axis=1)
                cv_du = tf.gather(qU, X_cvbatch, axis=0, name='cv_du')
                pred = tf.reshape(holdout_count, [-1, 1]) * tf.nn.softmax(
                    tf.concat([tf.zeros([self.cv_size, 1]), cv_du @ qV],
                              axis=1,
                              name='pred'))

                # Mean absolute error of the predicted counts.
                self.cv = tf.reduce_mean(tf.squeeze(tf.abs(pred - Y_cvbatch)))

        # keep all summaries on the cpu
        with tf.device('/cpu:0'):
            tf.summary.scalar('logloss', self.log_loss)
            # NOTE(review): tagged 'cv_rmse' but self.cv is a mean absolute
            # error; the tag is kept for existing dashboards.
            tf.summary.scalar('cv_rmse', self.cv)
            tf.summary.histogram('qUmain', self.qUmain)
            tf.summary.histogram('qVmain', self.qVmain)
            tf.summary.histogram('qUbias', self.qUbias)
            tf.summary.histogram('qVbias', self.qVbias)
            self.merged = tf.summary.merge_all()

            self.writer = tf.summary.FileWriter(self.save_path,
                                                self.session.graph)

        with tf.device(self.device_name):
            with tf.name_scope('optimize'):
                optimizer = tf.train.AdamOptimizer(self.learning_rate,
                                                   beta1=self.beta_1,
                                                   beta2=self.beta_2)

                gradients, self.variables = zip(
                    *optimizer.compute_gradients(self.log_loss))
                self.gradients, _ = tf.clip_by_global_norm(
                    gradients, self.clipnorm)
                self.train = optimizer.apply_gradients(
                    zip(self.gradients, self.variables))

        tf.global_variables_initializer().run()
Example #4
0
 def _sample(self, means):
     """Return a multinomial over ``self.n_samples`` trials shaped by ``means``.

     ``means`` is divided by its grand total (summed over every element), and
     the ratio is cast to float32 to serve as the event probabilities.
     """
     grand_total = tf.reduce_sum(means)
     normalised = tf.to_float(means / grand_total)
     return Multinomial(total_count=self.n_samples, probs=normalised)
Example #5
0
def main(_):
    """Fit a multinomial regression of feature counts on sample metadata.

    Options come from the module-level ``FLAGS``. The function:
      * filters and aligns the training and hold-out tables with their
        metadata;
      * builds design matrices from ``opts.formula`` via ``dmatrix``;
      * trains point estimates of the bias ``qgamma`` and coefficients
        ``qbeta`` with gradient-clipped Adam;
      * writes TensorBoard summaries and periodic checkpoints to
        ``opts.save_path``;
      * prints the elapsed time and the hold-out ``cross_validation`` scores.

    Parameters
    ----------
    _ : object
        Ignored; presumably the argv list supplied by ``tf.app.run`` — confirm.
    """
    opts = Options(save_path=FLAGS.save_path,
                   train_biom=FLAGS.train_biom,
                   test_biom=FLAGS.test_biom,
                   train_metadata=FLAGS.train_metadata,
                   test_metadata=FLAGS.test_metadata,
                   formula=FLAGS.formula,
                   learning_rate=FLAGS.learning_rate,
                   clipping_size=FLAGS.clipping_size,
                   beta_mean=FLAGS.beta_mean,
                   beta_scale=FLAGS.beta_scale,
                   gamma_mean=FLAGS.gamma_mean,
                   gamma_scale=FLAGS.gamma_scale,
                   epochs_to_train=FLAGS.epochs_to_train,
                   num_neg_samples=FLAGS.num_neg_samples,
                   batch_size=FLAGS.batch_size,
                   min_sample_count=FLAGS.min_sample_count,
                   min_feature_count=FLAGS.min_feature_count,
                   statistics_interval=FLAGS.statistics_interval,
                   summary_interval=FLAGS.summary_interval,
                   checkpoint_interval=FLAGS.checkpoint_interval)
    # preprocessing
    train_table, train_metadata = opts.train_table, opts.train_metadata
    train_metadata = train_metadata.loc[train_table.ids(axis='sample')]

    sample_filter = lambda val, id_, md: (
        (id_ in train_metadata.index) and np.sum(val) > opts.min_sample_count)
    read_filter = lambda val, id_, md: np.sum(val) > opts.min_feature_count
    metadata_filter = lambda val, id_, md: id_ in train_metadata.index

    train_table = train_table.filter(metadata_filter, axis='sample')
    train_table = train_table.filter(sample_filter, axis='sample')
    train_table = train_table.filter(read_filter, axis='observation')
    train_metadata = train_metadata.loc[train_table.ids(axis='sample')]
    sort_f = lambda xs: [xs[train_metadata.index.get_loc(x)] for x in xs]
    train_table = train_table.sort(sort_f=sort_f, axis='sample')
    train_metadata = dmatrix(opts.formula,
                             train_metadata,
                             return_type='dataframe')

    # hold out data preprocessing
    test_table, test_metadata = opts.test_table, opts.test_metadata
    metadata_filter = lambda val, id_, md: id_ in test_metadata.index
    obs_lookup = set(train_table.ids(axis='observation'))
    feat_filter = lambda val, id_, md: id_ in obs_lookup

    test_table = test_table.filter(metadata_filter, axis='sample')
    test_table = test_table.filter(feat_filter, axis='observation')
    test_metadata = test_metadata.loc[test_table.ids(axis='sample')]
    sort_f = lambda xs: [xs[test_metadata.index.get_loc(x)] for x in xs]
    test_table = test_table.sort(sort_f=sort_f, axis='sample')
    test_metadata = dmatrix(opts.formula,
                            test_metadata,
                            return_type='dataframe')

    p = train_metadata.shape[1]  # number of covariates
    y_data = np.array(train_table.matrix_data.todense()).T
    y_test = np.array(test_table.matrix_data.todense()).T
    N, D = y_data.shape
    save_path = opts.save_path
    learning_rate = opts.learning_rate
    batch_size = opts.batch_size
    gamma_mean, gamma_scale = opts.gamma_mean, opts.gamma_scale
    beta_mean, beta_scale = opts.beta_mean, opts.beta_scale
    num_iter = (N // batch_size) * opts.epochs_to_train
    holdout_size = test_metadata.shape[0]
    checkpoint_interval = opts.checkpoint_interval

    # Model code
    with tf.Graph().as_default(), tf.Session() as session:
        with tf.device("/cpu:0"):
            # Place holder variables to accept input data
            G_ph = tf.placeholder(tf.float32, [batch_size, p], name='G_ph')
            Y_ph = tf.placeholder(tf.float32, [batch_size, D], name='Y_ph')
            G_holdout = tf.placeholder(tf.float32, [holdout_size, p],
                                       name='G_holdout')
            Y_holdout = tf.placeholder(tf.float32, [holdout_size, D],
                                       name='Y_holdout')
            total_count = tf.placeholder(tf.float32, [batch_size],
                                         name='total_count')

            # Define PointMass Variables first
            qgamma = tf.Variable(tf.random_normal([1, D]), name='qgamma')
            qbeta = tf.Variable(tf.random_normal([p, D]), name='qB')

            # Distributions
            # species bias
            gamma = Normal(loc=tf.zeros([1, D]) + gamma_mean,
                           scale=tf.ones([1, D]) * gamma_scale,
                           name='gamma')
            # regression coefficients distribution
            beta = Normal(loc=tf.zeros([p, D]) + beta_mean,
                          scale=tf.ones([p, D]) * beta_scale,
                          name='B')

            Bprime = tf.concat([qgamma, qbeta], axis=0)

            # add bias terms for samples
            Gprime = tf.concat([tf.ones([batch_size, 1]), G_ph], axis=1)

            eta = tf.matmul(Gprime, Bprime)
            phi = tf.nn.log_softmax(eta)
            Y = Multinomial(total_count=total_count, logits=phi, name='Y')

            # The minibatch likelihood is rescaled by N / batch_size.
            loss = -(tf.reduce_mean(gamma.log_prob(qgamma)) + \
                     tf.reduce_mean(beta.log_prob(qbeta)) + \
                     tf.reduce_mean(Y.log_prob(Y_ph)) * (N / batch_size))
            loss = tf.Print(loss, [loss])
            optimizer = tf.train.AdamOptimizer(learning_rate)

            gradients, variables = zip(*optimizer.compute_gradients(loss))
            gradients, _ = tf.clip_by_global_norm(gradients,
                                                  opts.clipping_size)
            train = optimizer.apply_gradients(zip(gradients, variables))

            with tf.name_scope('accuracy'):
                holdout_count = tf.reduce_sum(Y_holdout, axis=1)
                # Same linear predictor as training: bias + G @ beta.
                pred = tf.reshape(holdout_count, [-1, 1]) * tf.nn.softmax(
                    tf.matmul(G_holdout, qbeta) + qgamma)
                mse = tf.reduce_mean(tf.squeeze(tf.abs(pred - Y_holdout)))
                tf.summary.scalar('mean_absolute_error', mse)

            tf.summary.scalar('loss', loss)
            tf.summary.histogram('qbeta', qbeta)
            tf.summary.histogram('qgamma', qgamma)
            merged = tf.summary.merge_all()

            tf.global_variables_initializer().run()

            writer = tf.summary.FileWriter(save_path, session.graph)

            losses = np.array([0.] * num_iter)
            idx = np.arange(train_metadata.shape[0])
            # Create (truncate) the run log; nothing is written to it yet.
            # The context manager makes sure the handle is not leaked.
            with open(os.path.join(save_path, 'run.log'), 'w'):
                pass

            last_checkpoint_time = 0
            start_time = time.time()
            saver = tf.train.Saver()
            fetches = [train, merged, loss]
            try:
                for i in range(num_iter):
                    batch_idx = np.random.choice(idx, size=batch_size)
                    feed_dict = {
                        Y_ph: y_data[batch_idx].astype(np.float32),
                        G_ph: train_metadata.values[batch_idx].astype(
                            np.float32),
                        # The merged summaries include the hold-out error,
                        # so the hold-out placeholders are fed every step.
                        Y_holdout: y_test.astype(np.float32),
                        G_holdout: test_metadata.values.astype(np.float32),
                        total_count:
                        y_data[batch_idx].sum(axis=1).astype(np.float32)
                    }

                    if i % 1000 == 0:
                        # Every 1000 steps also record a full execution trace.
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        _, summary, train_loss = session.run(
                            fetches,
                            feed_dict=feed_dict,
                            options=run_options,
                            run_metadata=run_metadata)
                        writer.add_run_metadata(run_metadata, 'step%d' % i)
                    else:
                        _, summary, train_loss = session.run(
                            fetches, feed_dict=feed_dict)
                    writer.add_summary(summary, i)

                    now = time.time()
                    if now - last_checkpoint_time > checkpoint_interval:
                        saver.save(session,
                                   os.path.join(opts.save_path, "model.ckpt"),
                                   global_step=i)
                        last_checkpoint_time = now

                    losses[i] = train_loss
            finally:
                # Flush pending events so the last summaries are not lost.
                writer.close()
            elapsed_time = time.time() - start_time
            print('Elapsed Time: %f seconds' % elapsed_time)

            # Cross validation
            pred_beta = qbeta.eval()
            pred_gamma = qgamma.eval()
            mse, mrc = cross_validation(test_metadata.values, pred_beta,
                                        pred_gamma, y_test)
            print("MSE: %f, MRC: %f" % (mse, mrc))
Example #6
0
    def __call__(self, session, trainX, trainY, testX, testY):
        """ Initialize the actual graph

        Builds the pieces below, then runs the global variable initializer:
          * a multinomial regression whose first output category is the
            reference (its logit is fixed to zero);
          * the hold-out mean absolute error;
          * TensorBoard summaries;
          * a gradient-clipped Adam training op.

        Parameters
        ----------
        session : tf.Session
            Tensorflow session
        trainX : np.array
            Input training design matrix.
        trainY : np.array
            Output training OTU table, where rows are samples and columns are
            observations.
        testX : np.array
            Input testing design matrix.
        testY : np.array
            Output testing OTU table, where rows are samples and columns are
            observations.
        """
        self.session = session
        self.N, self.p = trainX.shape
        self.D = trainY.shape[1]
        holdout_size = testX.shape[0]

        # Training and hold-out data are embedded in the graph as constants.
        self.X_ph = tf.constant(trainX, dtype=tf.float32, name='G_ph')
        self.Y_ph = tf.constant(trainY, dtype=tf.float32, name='Y_ph')
        self.X_holdout = tf.constant(testX, dtype=tf.float32, name='G_holdout')
        self.Y_holdout = tf.constant(testY, dtype=tf.float32, name='Y_holdout')

        # Draw minibatch row indices uniformly with replacement. Squeeze only
        # the leading axis so that a batch_size of 1 still gives a rank-1
        # index vector (a bare squeeze would collapse it to a scalar).
        batch_ids = tf.multinomial(tf.ones([1, self.N]), self.batch_size)
        sample_ids = tf.squeeze(batch_ids, axis=0)

        Y_batch = tf.gather(self.Y_ph, sample_ids, axis=0)
        X_batch = tf.gather(self.X_ph, sample_ids, axis=0)

        total_count = tf.reduce_sum(Y_batch, axis=1)
        holdout_count = tf.reduce_sum(self.Y_holdout, axis=1)

        # Define PointMass Variables first
        self.qbeta = tf.Variable(tf.random_normal([self.p, self.D - 1]),
                                 name='qB')

        # regression coefficients distribution
        beta = Normal(loc=tf.zeros([self.p, self.D - 1]) + self.beta_mean,
                      scale=tf.ones([self.p, self.D - 1]) * self.beta_scale,
                      name='B')

        eta = tf.matmul(X_batch, self.qbeta, name='eta')

        # Prepend a zero logit for the reference category.
        phi = tf.nn.log_softmax(tf.concat(
            [tf.zeros([self.batch_size, 1]), eta], axis=1),
                                name='phi')

        Y = Multinomial(total_count=total_count, logits=phi, name='Y')

        # cross validation
        with tf.name_scope('accuracy'):
            pred = tf.reshape(holdout_count, [-1, 1]) * tf.nn.softmax(
                tf.concat([
                    tf.zeros([holdout_size, 1]),
                    tf.matmul(self.X_holdout, self.qbeta)
                ],
                          axis=1),
                name='phi')

            self.cv = tf.reduce_mean(tf.squeeze(tf.abs(pred - self.Y_holdout)))
            tf.summary.scalar('mean_absolute_error', self.cv)

        # The minibatch likelihood is rescaled by N / batch_size.
        self.loss = -(tf.reduce_sum(beta.log_prob(self.qbeta)) +
                      tf.reduce_sum(Y.log_prob(Y_batch)) *
                      (self.N / self.batch_size))

        optimizer = tf.train.AdamOptimizer(self.learning_rate,
                                           beta1=self.beta_1,
                                           beta2=self.beta_2)

        gradients, variables = zip(*optimizer.compute_gradients(self.loss))
        self.gradients, _ = tf.clip_by_global_norm(gradients, self.clipnorm)
        # Apply the clipped gradients so that clipnorm actually takes effect.
        self.train = optimizer.apply_gradients(zip(self.gradients, variables))

        tf.summary.scalar('loss', self.loss)
        tf.summary.histogram('qbeta', self.qbeta)
        self.merged = tf.summary.merge_all()
        if self.save_path is not None:
            self.writer = tf.summary.FileWriter(self.save_path,
                                                self.session.graph)
        else:
            self.writer = None
        tf.global_variables_initializer().run()
Example #7
0
    def _initialize_parameters(self, hparams, ppm):
        """Create the hyperparameter, variational and edge-parameter variables.

        Parameters
        ----------
        hparams : dict
            Must provide the keys 'su', 'tu', 'a', 'b', 'size_u', 'si', 'ti',
            'c', 'd' and 'size_i'. ``size_u`` and ``size_i`` are also stored
            on ``self``.
        ppm : bool
            How user/item parameters are initialised:
              * truthy: from random gamma draws, without degree information;
              * falsy: from ``self.user_degree`` / ``self.item_degree`` and
                the total edge weight ``e``.

        Side effects
        ------------
        Sets the following attributes:
          * hyperparameter variables: lsu, tu, a, b, lsi, ti, c, d, and the
            derived su and si tensors;
          * user parameters: gam_shp, gam_rte, theta_shp, theta_rte, g;
          * item parameters: omega_shp, omega_rte, beta_shp, beta_rte, w;
          * the edge parameter: sg_edge_param or lphi;
          * the q_* variational distributions;
          * the qm_du / qm_di caches;
          * the i_tot_mass_m / u_tot_mass_m tensors.
        """

        K = np.float32(self.K)

        su, tu, a, b, self.size_u = (hparams['su'], hparams['tu'], hparams['a'], hparams['b'], hparams['size_u'])
        si, ti, c, d, self.size_i = (hparams['si'], hparams['ti'], hparams['c'], hparams['d'], hparams['size_i'])

        with tf.name_scope("hparams"), tf.device(self.device):
            ## Hyperparameters
            # su = 1 - softplus(lsu), so su stays below 1 during optimisation.
            self.lsu = tf.Variable(softplus_inverse(-hparams['su'] + 1.), dtype=tf.float32, name="lsu")
            self.su = -tf.nn.softplus(self.lsu) + 1.

            self.tu = tf.Variable(hparams['tu'], dtype=tf.float32, name="tu")

            self.a = tf.Variable(hparams['a'], dtype=tf.float32, name="a")
            self.b = tf.Variable(hparams['b'], dtype=tf.float32, name="b")

            # si = 1 - softplus(lsi), so si stays below 1 during optimisation.
            self.lsi = tf.Variable(softplus_inverse(-hparams['si'] + 1.), dtype=tf.float32, name="lsi")
            self.si = -tf.nn.softplus(self.lsi) + 1.

            self.ti = tf.Variable(hparams['ti'], dtype=tf.float32, name="ti")

            self.c = tf.Variable(hparams['c'], dtype=tf.float32, name="c")
            self.d = tf.Variable(hparams['d'], dtype=tf.float32, name="d")

        e = np.sum(self.edge_vals_d, dtype=np.float32)

        # initial values for total user and total item masses of type K
        # set st \sum_k tim_k * tum_k = e (which is in fact a bit higher than it oughta be)
        # and using item_mass / user_mass ~ item_size / user_size (which is only kind of true)
        tum_init = np.sqrt(self.size_u / self.size_i * e / K)
        tim_init = np.sqrt(self.size_i / self.size_u * e / K)

        with tf.name_scope("user_params"), tf.device(self.device):
            # shape params are read off immediately from update equations
            # rate params set to be consistent w \gam_i ~ 1, \sum_j beta_jk beta_k ~ \sqrt(e/k) (which is self consistent)
            if ppm :
                # If creating the principled predictive (ppm), don't have the user_degree. Just create some random initialization for now, we'll update it with a default value
                # The shape variable is named "gam_shp" (it was previously
                # mislabelled "gam_rte", clashing with the rate variable and
                # diverging from the non-ppm branch's variable names).
                self.gam_shp = tf.Variable(tf.random_gamma([self.U, 1], 5., 5., seed=self.seed), dtype=tf.float32, name="gam_shp")
                self.gam_rte = tf.Variable(tf.random_gamma([self.U, 1], 5., 5., seed=self.seed), dtype=tf.float32, name="gam_rte")
                self.theta_shp = tf.Variable(tf.random_gamma([self.U, self.K], 10., 10., seed=self.seed), name="theta_shp")
                self.theta_rte = tf.Variable(tf.random_gamma([self.U, self.K], 5., 5., seed=self.seed), name="theta_rte")
                self.g = tf.Variable(tf.random_gamma([self.K, 1], 0.001, 1, seed=self.seed) + TINY, name="g")
            else:
                user_degs = np.expand_dims(self.user_degree, axis=1)
                self.gam_shp = tf.Variable((user_degs - su), name="gam_shp")  # s^U
                self.gam_rte = tf.Variable(np.sqrt(e) * (0.9 + 0.1*tf.random_gamma([self.U, 1], 5., 5., seed=self.seed)), dtype=tf.float32, name="gam_rte")  # r^U
                init_gam_mean = self.gam_shp.initial_value / self.gam_rte.initial_value
                self.theta_shp = tf.Variable((a + user_degs/K) * tf.random_gamma([self.U, self.K], 10., 10., seed=self.seed), name="theta_shp")  # kap^U
                self.theta_rte = tf.Variable((b + init_gam_mean * tim_init)*(0.9 + 0.1*tf.random_gamma([self.U, self.K], 5., 5., seed=self.seed)), name="theta_rte")  # lam^U
                self.g = tf.Variable(tf.random_gamma([self.K, 1], 0.001, 1, seed=self.seed) + TINY, name="g")  # g


        with tf.name_scope("item_params"), tf.device(self.device):
            ## Items
            if ppm:
                self.omega_shp = tf.Variable(tf.random_gamma([self.I, 1], 5., 5., seed=self.seed), name="omega_shp")  # s^I
                self.omega_rte = tf.Variable(tf.random_gamma([self.I, 1], 5., 5., seed=self.seed), dtype=tf.float32, name="omega_rte")  # r^I
                self.beta_shp = tf.Variable(tf.random_gamma([self.I, self.K], 10., 10., seed=self.seed), name="beta_shp")  # kap^I
                self.beta_rte = tf.Variable(tf.random_gamma([self.I, self.K], 5., 5., seed=self.seed), name="beta_rte")  # lam^I
                self.w = tf.Variable(tf.random_gamma([self.K, 1], 0.001, 1, seed=self.seed) + TINY, name="w")  # w
            else:
                item_degs = np.expand_dims(self.item_degree, axis=1)
                self.omega_shp = tf.Variable((item_degs - si), name="omega_shp")  # s^I
                self.omega_rte = tf.Variable(np.sqrt(e) * (0.9 + 0.1*tf.random_gamma([self.I, 1], 5., 5., seed=self.seed)), dtype=tf.float32, name="omega_rte")  # r^I
                init_omega_mean = self.omega_shp.initial_value / self.omega_rte.initial_value
                self.beta_shp = tf.Variable((c + item_degs/K) * tf.random_gamma([self.I, self.K], 10., 10., seed=self.seed), name="beta_shp")  # kap^I
                self.beta_rte = tf.Variable((d + init_omega_mean*tum_init) * (0.9 + 0.1*tf.random_gamma([self.I, self.K], 5., 5., seed=self.seed)), name="beta_rte")  # lam^I
                self.w = tf.Variable(tf.random_gamma([self.K, 1], 0.001, 1, seed=self.seed) + TINY, name="w")  # w

        with tf.device('/cpu:0'):
            with tf.variable_scope("edge_params", reuse=None):
                ## Edges
                # One row of K parameters per occupied pair, split into
                # self.edge_param_splits partitions along axis 0.
                if self.simple_graph:
                    # set init value so there's approximately 1 expected edge between each pair... WARNING: this may be profoundly stupid
                    self.sg_edge_param = tf.get_variable(name="sg_edge_param", shape=[self.occupied_pairs, self.K], dtype=tf.float32,
                                    initializer=tf.random_normal_initializer(mean=-np.log(K), stddev=1. / K, seed=self.seed),
                                    partitioner=tf.fixed_size_partitioner(self.edge_param_splits, 0))
                else:
                    self.lphi = tf.get_variable(name="lphi", shape=[self.occupied_pairs, self.K], dtype=tf.float32,
                                    initializer=tf.random_normal_initializer(mean=0, stddev=1. / K, seed=self.seed),
                                    partitioner=tf.fixed_size_partitioner(self.edge_param_splits, 0))

        with tf.name_scope("variational_post"), tf.device(self.device):

            # Variational posterior distributions
            self.q_gam = Gamma(concentration=self.gam_shp, rate=self.gam_rte, name="q_gam")
            self.q_theta = Gamma(concentration=self.theta_shp, rate=self.theta_rte, name="q_theta")
            self.q_g = PointMass(self.g, name="q_g")

            self.q_omega = Gamma(concentration=self.omega_shp, rate=self.omega_rte, name="q_omega")
            self.q_beta = Gamma(concentration=self.beta_shp, rate=self.beta_rte, name="q_beta")
            self.q_w = PointMass(self.w, name="q_w")

            if self.simple_graph:
                self.q_e_aux_vals = tPoissonMulti(log_lams=self.sg_edge_param, name="q_e_aux_vals") # q_edges_aux_flat
            else:
                self.q_e_aux_vals = Multinomial(total_count=self.edge_vals, logits=self.lphi, name="q_e_aux_vals") # q_edges_aux_flat
                self.q_e_aux_vals_mean = self.q_e_aux_vals.mean()

        with tf.name_scope("degree_vars"):
            # create some structures to make it easy to work with the expected value (wrt q) of the edges

            # qm_du[u,k] is the expected weighted degree of user u counting only edges of type k
            # qm_du[u,k] = E_q[e^k_i.] in the language of the paper
            # Initialised to ones here as a placeholder; presumably overwritten
            # elsewhere with the expected edge counts — TODO confirm.
            # we use a tf.Variable here to cache the q_e_aux_vals.mean() value
            self.qm_du = tf.Variable(tf.ones([self.U, self.K], dtype=tf.float32), name="qm_du")
            self.qm_di = tf.Variable(tf.ones([self.I, self.K], dtype=tf.float32), name="qm_di")

        # Total Item Mass:
        self.i_tot_mass_m = self.q_w.mean() + tf.matmul(self.q_beta.mean(), self.q_omega.mean(), transpose_a=True)
        # Total User Mass:
        self.u_tot_mass_m = self.q_g.mean() + tf.matmul(self.q_theta.mean(), self.q_gam.mean(), transpose_a=True)