Example #1
    def up_apply_node_func(self, nodes):
        x = nodes.data['evid']  # represents P(x_u | Q_u); has shape (bs x h)

        if 'gamma_r' in nodes.data:
            gamma_r = nodes.data['gamma_r']
            gamma_ch = nodes.data['gamma_ch']
            gamma_p_ch = nodes.data['gamma_p_ch']

            U_out = self.__gather_param__(
                self.U_output,
                types=nodes.data['t'] if self.num_types > 1 else None)
            # U_out has shape bs x rank x h
            beta = thlp.sum_over(
                thlp.mul(nodes.data['gamma_r'].unsqueeze(2), U_out), 1)
        else:
            beta = self.__gather_param__(
                self.p,
                types=nodes.data['t'] if self.num_types > 1 else None,
                pos=nodes.data['pos'] if not self.pos_stationarity else None)
            bs = beta.size(0)
            gamma_r = th.zeros((bs, self.rank))
            gamma_ch = th.zeros(
                (bs, self.max_output_degree, self.h_size, self.rank))
            gamma_p_ch = th.zeros((bs, self.max_output_degree, self.rank))

        beta = thlp.mul(x, beta)  # has shape (bs x h)
        beta, N_u = thlp.normalise(beta, 1, get_Z=True)

        return {
            'beta': beta,
            'N_u': N_u,
            'gamma_r': gamma_r,
            'gamma_ch': gamma_ch,
            'gamma_p_ch': gamma_p_ch
        }
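All the examples in this section rely on a small log-space helper module, referred to as thlp. Its implementation is not shown here; the following is only a minimal sketch, with names and semantics inferred from how the helpers are used above and below (the actual module may differ):

import torch as th

# Hypothetical log-space helpers, inferred from usage; not the original thlp module.
def mul(a, b):
    # product of probabilities == sum of log-probabilities (with broadcasting)
    return a + b

def div(a, b):
    # ratio of probabilities == difference of log-probabilities
    return a - b

def sum_over(a, dims, keepdim=False):
    # marginalisation == logsumexp over the given dimension(s)
    return th.logsumexp(a, dim=dims, keepdim=keepdim)

def zeros(*size):
    # "probability zero" entries, i.e. log(0) = -inf
    return th.full(size, float('-inf'))

def normalise(a, dims, get_Z=False):
    # rescale along dims so the entries sum to one in probability space;
    # optionally return the log normalisation constant
    Z = th.logsumexp(a, dim=dims, keepdim=True)
    out = a - Z
    if get_Z:
        return out, th.logsumexp(a, dim=dims)
    return out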
Example #2
    def up_reduce_func(self, nodes):
        bs = nodes.mailbox['beta_ch'].shape[0]
        n_ch = nodes.mailbox['beta_ch'].shape[1]
        beta_ch = thlp.zeros(bs, self.max_output_degree, self.h_size)
        beta_ch[:, :n_ch, :] = nodes.mailbox['beta_ch']

        # compute beta_r
        U = self.__gather_param__(
            self.U,
            types=nodes.data['t'] if self.num_types > 1 else None,
        )
        # U has shape bs x L x h x rank
        gamma_ch_rl = thlp.sum_over(thlp.mul(U, beta_ch.unsqueeze(3)),
                                    2)  # has shape bs x L x rank

        G = self.__gather_param__(
            self.G, types=nodes.data['t'] if self.num_types > 1 else None)

        for i in range(self.max_output_degree):
            # TODO: we are assuming the bottom (missing) children come last
            if i < n_ch:
                # real child: append a log(0) entry for the extra "bottom" slot
                btm = thlp.zeros(bs, 1)
                x = th.cat((gamma_ch_rl[:, i, :], btm), 1)
            else:
                # missing child: all probability mass on the bottom slot (log 1 = 0)
                x = thlp.zeros(1, self.rank + 1)
                x[:, -1] = 0

            new_shape = [x.shape[0]] + [1] * i + [
                self.rank + 1
            ] + [1] * (self.max_output_degree - i)
            G = thlp.mul(G, x.view(*new_shape))

        gamma_r = thlp.sum_over(G, list(range(1, self.max_output_degree + 1)))
        return {'gamma_r': gamma_r, 'gamma_ch_all': G, 'beta_ch': beta_ch}
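The loop above contracts the core tensor G with one vector per child position by reshaping each vector so that it broadcasts only along its own axis, then sums all the child axes out at once. A small standalone illustration of that pattern in ordinary (non-log) arithmetic, with made-up sizes, could be:

import torch as th

bs, L, r = 2, 3, 4                     # hypothetical sizes: batch, arity, rank
G = th.rand(bs, r, r, r, r)            # core tensor: one axis per child plus a final output axis
vecs = [th.rand(bs, r) for _ in range(L)]

out = G
for i, v in enumerate(vecs):
    # reshape so v broadcasts only along the axis of child i
    shape = [bs] + [1] * i + [r] + [1] * (L - i)
    out = out * v.view(*shape)

# contract all L child axes, leaving the batch and the output axis
gamma = out.sum(dim=list(range(1, L + 1)))   # shape: (bs, r)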
Example #3
    def down_apply_node_func(self, nodes):
        if 'eta' in nodes.data:
            eta_u = nodes.data['eta']
        else:
            # root
            eta_u = nodes.data['beta']

        gamma_ch = nodes.data['gamma_ch']  # has shape (bs x L x h x h)
        gamma_p_ch = nodes.data['gamma_p_ch'].unsqueeze(
            2)  # has shape (bs x L x 1 x h)
        gamma_r = nodes.data['gamma_r'].unsqueeze(1).unsqueeze(
            2)  # has shape (bs x 1 x 1 x h)
        n_ch_mask = gamma_p_ch.exp().sum((2, 3), keepdim=True)
        a = thlp.div(gamma_ch, gamma_p_ch * n_ch_mask)
        a = thlp.mul(a, gamma_r)
        b = thlp.sum_over(a, 2, keepdim=True)
        # P(Q_l, Q_ch_l | X)
        eta_u_chl = thlp.div(thlp.mul(a,
                                      eta_u.unsqueeze(1).unsqueeze(2)),
                             b * n_ch_mask)  # has shape (bs x L x h x h)

        is_leaf = nodes.data['is_leaf']
        is_internal = th.logical_not(is_leaf)
        self.accumulate_posterior(
            self.U,
            eta_u_chl[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.p,
            eta_u[is_leaf],
            types=nodes.data['t'][is_leaf] if self.num_types > 1 else None,
            pos=nodes.data['pos'][is_leaf]
            if not self.pos_stationarity else None)

        return {'eta_ch': thlp.sum_over(eta_u_chl, 3), 'eta': eta_u}
Example #4
    def down_apply_node_func(self, nodes):
        if 'eta' in nodes.data:
            eta_u = nodes.data['eta']
        else:
            # root
            eta_u = nodes.data['beta']

        U_out = self.__gather_param__(
            self.U_output,
            types=nodes.data['t'] if self.num_types > 1 else None)

        gamma_r = nodes.data['gamma_r']  # has shape (bs x rank)
        a = thlp.mul(U_out, gamma_r.unsqueeze(2))  # has shape bs x rank x h
        b = thlp.sum_over(a, 1, keepdim=True)
        # P(Q_u, R_u | X)
        eta_ur = thlp.mul(thlp.div(a, b),
                          eta_u.unsqueeze(1))  # has shape bs x rank x h

        eta_r = thlp.sum_over(eta_ur, 2).unsqueeze(1).unsqueeze(
            2)  # has shape (bs x 1 x 1 x rank)
        gamma_ch = nodes.data['gamma_ch']  # has shape (bs x L x h x rank)
        gamma_p_ch = nodes.data['gamma_p_ch'].unsqueeze(
            2)  # has shape (bs x L x 1 x rank)
        gamma_r = gamma_r.unsqueeze(1).unsqueeze(
            2)  # has shape (bs x 1 x 1 x rank)
        n_ch_mask = gamma_p_ch.exp().sum((2, 3), keepdim=True)
        a = thlp.div(gamma_ch, gamma_p_ch * n_ch_mask)
        a = thlp.mul(a, gamma_r)
        b = thlp.sum_over(a, 2, keepdim=True)
        # P(Q_l, R_l | X)
        eta_ur_ch = thlp.div(thlp.mul(a, eta_r),
                             b * n_ch_mask)  # has shape (bs x L x h x rank)

        is_leaf = nodes.data['is_leaf']
        is_internal = th.logical_not(is_leaf)
        self.accumulate_posterior(
            self.U,
            eta_ur_ch[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.U_output,
            eta_ur[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.p,
            eta_u[is_leaf],
            types=nodes.data['t'][is_leaf] if self.num_types > 1 else None,
            pos=nodes.data['pos'][is_leaf]
            if not self.pos_stationarity else None)

        return {'eta_ch': thlp.sum_over(eta_ur_ch, 3), 'eta': eta_u}
Example #5
    def up_apply_node_func(self, nodes):
        x = nodes.data['evid']  # represents P(x_u | Q_u); has shape (bs x h)

        if 'beta_np' in nodes.data:
            beta = nodes.data['beta_np']  # has shape (bs x h)
            beta_np = nodes.data['beta_np']
            beta_ch = nodes.data['beta_ch']
        else:
            beta = self.__gather_param__(
                self.p, types=nodes.data['t'] if self.num_types > 1 else None)
            bs = beta.size(0)
            beta_np = th.zeros((bs, self.h_size))
            beta_ch = th.zeros(
                [bs] +
                [self.h_size + 1
                 for i in range(self.max_output_degree)] + [self.h_size])
        beta = thlp.mul(x, beta)  # has shape (bs x h)

        # normalise
        beta, N_u = thlp.normalise(beta, 1, get_Z=True)

        return {
            'beta': beta,
            'N_u': N_u,
            'beta_np': beta_np,
            'beta_ch': beta_ch
        }
Example #6
    def up_reduce_func(self, nodes):
        bs = nodes.mailbox['beta_ch'].shape[0]
        n_ch = nodes.mailbox['beta_ch'].shape[1]
        beta_ch = thlp.zeros(bs, self.max_output_degree, self.h_size)
        beta_ch[:, :n_ch, :] = nodes.mailbox['beta_ch']
        # TODO: we are assuming beta_ch is ordered according to pos, so that bottom (missing) children come last

        # compute beta_r
        U = self.__gather_param__(
            self.U, types=nodes.data['t'] if self.num_types > 1 else None)
        # U has shape bs x L x h x rank+1 x rank

        gamma_ch_rl = thlp.sum_over(
            thlp.mul(U,
                     beta_ch.unsqueeze(3).unsqueeze(4)), 2)
        # has shape bs x L x rank+1 x rank

        gamma_less_l = thlp.zeros(bs, self.max_output_degree, self.rank + 1)
        gamma_less_l[:, 0, :-1] = gamma_ch_rl[:, 0,
                                              -1, :]  # has shape bs x rank
        for i in range(1, n_ch):
            gamma_i = gamma_ch_rl[:, i, :, :]  # has shape bs x (rank+1) x rank
            gamma_prev = gamma_less_l[:, i - 1, :].unsqueeze(
                2)  # has shape bs x (rank+1) x 1
            gamma_less_l[:,
                         i, :-1] = thlp.sum_over(thlp.mul(gamma_i, gamma_prev),
                                                 1)
        gamma_less_l[:, n_ch:, -1] = 0

        # R_out = self.__gather_param__(self.R_output, types=nodes.data['t'] if self.num_types > 1 else None)
        # gamma_r = thlp.sum_over(thlp.mul(R_out, gamma_less_l[:, n_ch-1, :-1].unsqueeze(2)), 1)  # has shape bs x rank
        gamma_r = gamma_less_l[:, n_ch - 1, :-1]

        return {
            'gamma_less_l': gamma_less_l,
            'gamma_r': gamma_r,
            'beta_ch': beta_ch,
            'n_ch': th.full([bs], n_ch, dtype=th.long)
        }
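The gamma_less_l recursion above is a left-to-right scan over the children: a running state of size rank (+1 for the bottom slot) is contracted with each child's per-position table in turn, much like the forward recursion of an HMM over positions. A compact sketch of that scan in ordinary (non-log) arithmetic, with made-up sizes:

import torch as th

bs, L, r = 2, 4, 3                 # hypothetical sizes: batch, number of children, rank
tables = th.rand(bs, L, r, r)      # one (prev-state x state) table per child position
state = th.rand(bs, r)             # running state after the first child

for i in range(1, L):
    # contract the previous state with position i's table:
    # (bs, r, r) * (bs, r, 1), summed over the prev-state axis -> (bs, r)
    state = (tables[:, i] * state.unsqueeze(2)).sum(dim=1)

# state now summarises all L children, shape (bs, r)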
Example #7
    def up_reduce_func(self, nodes):
        bs = nodes.mailbox['beta_ch'].shape[0]
        n_ch = nodes.mailbox['beta_ch'].shape[1]
        beta_ch = thlp.zeros(bs, self.max_output_degree, self.h_size)
        beta_ch[:, :n_ch, :] = nodes.mailbox['beta_ch']

        U = self.__gather_param__(
            self.U, types=nodes.data['t'] if self.num_types > 1 else None)
        # has shape bs x L x h x h
        Sp = self.__gather_param__(
            self.Sp, types=nodes.data['t'] if self.num_types > 1 else None)
        # has shape bs x L
        gamma_ch = thlp.mul(
            thlp.mul(beta_ch.unsqueeze(3), U),
            Sp.unsqueeze(2).unsqueeze(3))  # has shape (bs x L x h x h)
        gamma_p_ch = thlp.sum_over(gamma_ch, 2)  # has shape (bs x L x h)

        gamma_r = thlp.sum_over(gamma_p_ch, 1)  # has shape (bs x h)

        return {
            'gamma_r': gamma_r,
            'gamma_ch': gamma_ch,
            'gamma_p_ch': gamma_p_ch
        }
Example #8
    def down_apply_node_func(self, nodes):
        if 'eta' in nodes.data:
            eta_u = nodes.data['eta']
        else:
            # root
            eta_u = nodes.data['beta']

        beta_ch = nodes.data['beta_ch']  # has shape (bs x (h+1) x ... x (h+1) x h)
        beta_np = nodes.data['beta_np']
        new_shape = [-1] + [1] * self.max_output_degree + [self.h_size]
        # P(Q_u, Q_1, ..., Q_L | X)
        eta_uch = thlp.div(thlp.mul(beta_ch, eta_u.view(*new_shape)),
                           beta_np.view(*new_shape))
        # P(Q_1, ..., Q_L | X)
        eta_joint_ch = thlp.sum_over(eta_uch, -1)

        bs = eta_u.shape[0]
        eta_ch = th.empty(bs, self.max_output_degree, self.h_size + 1)
        for i in range(self.max_output_degree):
            sum_over_var = list(
                set(range(1, self.max_output_degree + 1)) - {i + 1})
            eta_ch[:, i, :] = thlp.sum_over(eta_joint_ch, sum_over_var)

        # accumulate posterior
        is_leaf = nodes.data['is_leaf']
        is_internal = th.logical_not(is_leaf)
        self.accumulate_posterior(
            self.U,
            eta_uch[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.p,
            eta_u[is_leaf],
            types=nodes.data['t'][is_leaf] if self.num_types > 1 else None)

        return {'eta_ch': eta_ch[:, :, :-1], 'eta': eta_u}
Example #9
    def up_reduce_func(self, nodes):
        beta_ch = nodes.mailbox['beta_ch']  # has shape (bs x n_ch x h)
        n_ch = beta_ch.shape[1]
        bs = beta_ch.shape[0]

        U = self.__gather_param__(
            self.U, types=nodes.data['t'] if self.num_types > 1 else None)

        for i in range(self.max_output_degree):
            # TODO: we are assuming the bottom (missing) children come last
            if i < n_ch:
                # real child: append a log(0) entry for the extra "bottom" slot
                btm = thlp.zeros(bs, 1)
                x = th.cat((beta_ch[:, i, :], btm), 1)
            else:
                # missing child: all probability mass on the bottom slot (log 1 = 0)
                x = thlp.zeros(1, self.h_size + 1)
                x[:, -1] = 0

            new_shape = [x.shape[0]] + [1] * i + [
                self.h_size + 1
            ] + [1] * (self.max_output_degree - i)
            U = thlp.mul(U, x.view(*new_shape))

        beta = thlp.sum_over(U, list(range(1, self.max_output_degree + 1)))
        return {'beta_np': beta, 'beta_ch': U}
Example #10
    def down_apply_node_func(self, nodes):
        if 'eta' in nodes.data:
            eta_u = nodes.data['eta']
        else:
            # root
            eta_u = nodes.data['beta']

        is_leaf = nodes.data['is_leaf']
        is_internal = th.logical_not(is_leaf)
        n_ch_list = nodes.data['n_ch'][is_internal] - 1
        gamma_r = nodes.data['gamma_r'][is_internal]  # has shape (bs x rank)
        # has shape bs x rank+1
        gamma_less_l = nodes.data['gamma_less_l'][
            is_internal]  # has shape bs x L x rank+1
        beta_ch = nodes.data['beta_ch'][is_internal]  # has shape bs x L x h
        t = nodes.data['t'][is_internal]

        self.accumulate_posterior(
            self.p,
            eta_u[is_leaf],
            types=nodes.data['t'][is_leaf] if self.num_types > 1 else None,
            pos=nodes.data['pos'][is_leaf]
            if not self.pos_stationarity else None)

        eta_u_ch_all = thlp.zeros(eta_u.shape[0], self.max_output_degree,
                                  self.h_size)
        if th.any(is_internal):
            # computation only on internal nodes
            # compute P(Q_u, R_u | X)
            U_out = self.__gather_param__(
                self.U_output, types=t if self.num_types > 1 else None)
            # U_out has shape bs x rank x h
            a = thlp.mul(gamma_r.unsqueeze(2), U_out)
            b = thlp.sum_over(a, 1, keepdim=True)
            eta_ur = thlp.div(thlp.mul(a, eta_u[is_internal].unsqueeze(1)),
                              b)  # has shape bs x rank x h

            # compute P(R_u, R_L | X)
            eta_r = thlp.sum_over(eta_ur, 2)  # has shape bs x rank_U
            # R_out = self.__gather_param__(self.R_output, types=t if self.num_types > 1 else None)
            # # R_out has shape bs x rank_L x rank_U
            # a = thlp.mul(R_out, gamma_L[:, :-1].unsqueeze(2))  # has shape bs x rank_L x rank_U
            # b = thlp.sum_over(a, 1, keepdim=True)
            # eta_rul = thlp.div(thlp.mul(a, eta_r.unsqueeze(1)), b)  # has shape bs x rank_L x rank_U

            # compute P(R_l, R_l-1, Q_l | X)
            # eta_rL = thlp.sum_over(eta_rul, 2)  # has shape bs x rank
            eta_rL = eta_r
            U = self.__gather_param__(self.U,
                                      types=t if self.num_types > 1 else None)
            if self.pos_stationarity:
                U = U.expand((-1, self.max_output_degree, -1, -1, -1))
            if U.size(0) == 1:
                U = U.expand((gamma_less_l.size(0), -1, -1, -1, -1))
            # U has shape bs x L x h x rank+1  x rank
            eta_u_ch = thlp.zeros(eta_rL.shape[0], self.max_output_degree,
                                  self.h_size)
            last_eta = eta_rL  # has shape bs x rank
            for i in range(self.max_output_degree - 1, -1, -1):
                pos_flag = i <= n_ch_list
                if th.any(pos_flag):
                    if i > 0:
                        a = thlp.mul(
                            U[pos_flag, i, :, :, :],
                            gamma_less_l[pos_flag,
                                         i - 1, :].unsqueeze(1).unsqueeze(3))
                        a = thlp.mul(
                            a, beta_ch[pos_flag,
                                       i, :].unsqueeze(2).unsqueeze(3))
                    else:
                        a = thlp.zeros(*(U.shape[:1] + U.shape[2:]))
                        a[:, :,
                          -1, :] = thlp.mul(U[:, i, :, -1, :],
                                            beta_ch[:, i, :].unsqueeze(2))
                    b = thlp.sum_over(a, (1, 2), keepdim=True)
                    eta_rul_rlprec = thlp.div(
                        thlp.mul(
                            a,
                            last_eta[pos_flag, :].unsqueeze(1).unsqueeze(2)),
                        b)

                    self.accumulate_posterior(
                        self.U,
                        eta_rul_rlprec,
                        types=t[pos_flag] if self.num_types > 1 else None,
                        pos=th.full(
                            (eta_rul_rlprec.shape[0],
                             1), i, dtype=th.long).squeeze(1)
                        if not self.pos_stationarity else None)

                    eta_u_ch[pos_flag,
                             i, :] = thlp.sum_over(eta_rul_rlprec, (2, 3))
                    last_eta[pos_flag, :] = thlp.sum_over(
                        eta_rul_rlprec, (1, 3))[:, :-1]

            # accumulate posterior
            self.accumulate_posterior(self.U_output,
                                      eta_ur,
                                      types=t if self.num_types > 1 else None)
            # self.accumulate_posterior(self.R_output, eta_rul, types=t if self.num_types > 1 else None)
            eta_u_ch_all[is_internal] = eta_u_ch

        return {'eta_ch': eta_u_ch_all, 'eta': eta_u}
Example #11
    def down_apply_node_func(self, nodes):
        if 'eta' in nodes.data:
            eta_u = nodes.data['eta']
        else:
            # root
            eta_u = nodes.data['beta']

        bs = eta_u.shape[0]

        gamma_r = nodes.data['gamma_r']  # has shape (bs x rank)
        U_out = self.__gather_param__(
            self.U_output,
            types=nodes.data['t'] if self.num_types > 1 else None)
        # U_out has shape bs x rank x h
        a = thlp.mul(gamma_r.unsqueeze(2), U_out)
        b = thlp.sum_over(a, 1, keepdim=True)
        # P(Q_u, R_u | X)
        eta_ur = thlp.div(thlp.mul(a, eta_u.unsqueeze(1)),
                          b)  # has shape bs x rank x h

        eta_r = thlp.sum_over(eta_ur, 2)  # has shape bs x rank

        gamma_ch_all = nodes.data[
            'gamma_ch_all']  # has shape (bs x (r+1) x ... x (r+1) x r)
        new_shape = [-1] + [1] * self.max_output_degree + [self.rank]
        # P(R_u, R_1, ..., R_L | X)
        eta_ru_rch = thlp.div(thlp.mul(gamma_ch_all, eta_r.view(*new_shape)),
                              gamma_r.view(*new_shape))
        # P(R_1, ..., R_L | X)
        eta_rch = thlp.sum_over(eta_ru_rch, -1)  # has shape (bs x (r+1) x ... x (r+1))

        eta_rl = thlp.zeros(bs, self.max_output_degree, self.rank + 1)
        for i in range(self.max_output_degree):
            sum_over_var = list(
                set(range(1, self.max_output_degree + 1)) - {i + 1})
            eta_rl[:, i, :] = thlp.sum_over(eta_rch, sum_over_var)

        U = self.__gather_param__(
            self.U, types=nodes.data['t'] if self.num_types > 1 else None)
        #  U has shape bs x L x h x rank
        a = thlp.mul(U, nodes.data['beta_ch'].unsqueeze(3))
        b = thlp.sum_over(a, 2, keepdim=True)
        # P(Q_l, R_l | X)
        eta_rql = thlp.div(thlp.mul(a, eta_rl[:, :, :-1].unsqueeze(2)),
                           b)  # has shape bs x L x h x rank

        # accumulate posterior
        is_leaf = nodes.data['is_leaf']
        is_internal = th.logical_not(is_leaf)

        self.accumulate_posterior(
            self.U,
            eta_rql[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.U_output,
            eta_ur[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.G,
            eta_ru_rch[is_internal],
            types=nodes.data['t'][is_internal] if self.num_types > 1 else None)
        self.accumulate_posterior(
            self.p,
            eta_u[is_leaf],
            types=nodes.data['t'][is_leaf] if self.num_types > 1 else None,
            pos=nodes.data['pos'][is_leaf])

        return {'eta_ch': thlp.sum_over(eta_rql, 3), 'eta': eta_u}
    def forward(self, *t_list, out_data=None):

        ################################################################################################################
        # UPWARD
        ################################################################################################################
        beta_root_list = []
        loglike_list = []
        for t in t_list:
            # register upward functions
            t.set_n_initializer(dgl.init.zero_initializer)

            # set evidence
            if self.x_emission is not None:
                if self.x_embedding is not None:
                    x_mask = (t.ndata['x'] != ConstValues.NO_ELEMENT)
                    t.ndata['x_embs'] = self.x_embedding(
                        t.ndata['x'] * x_mask) * x_mask.view(-1, 1)
                    t.ndata['x_mask'] = x_mask
                    evid = self.x_emission.set_evidence(
                        t.ndata['x_embs']) * x_mask.view(-1, 1)
                else:
                    evid = self.x_emission.set_evidence(t.ndata['x'])
            else:
                evid = th.zeros(t.number_of_nodes(),
                                self.state_transition.h_size)

            if self.training and not self.only_root_state and self.y_emission is not None:
                evid = thlp.mul(evid, self.y_emission.set_evidence(out_data))

            t.ndata['evid'] = evid

            # modify types to consider bottom
            t.ndata['t'][t.ndata['t'] == ConstValues.
                         NO_ELEMENT] = self.state_transition.num_types - 1

            # remove -1 in position
            t.ndata['pos'] = t.ndata['pos'] * (
                t.ndata['pos'] != ConstValues.NO_ELEMENT)  #t.ndata['pos_mask']

            # start propagation
            dgl.prop_nodes_topo(
                t,
                message_func=self.state_transition.up_message_func,
                reduce_func=self.state_transition.up_reduce_func,
                apply_node_func=self.state_transition.up_apply_node_func)

            root_ids = [
                i for i in range(t.number_of_nodes()) if t.out_degree(i) == 0
            ]
            beta_root_list.append(t.ndata['beta'][root_ids])

            loglike_list.append(t.ndata['N_u'].sum())
        ################################################################################################################
        eta_root_list = beta_root_list

        if self.only_root_state:
            if self.training:
                if self.y_emission is not None:
                    joint_prob = self.y_emission.set_evidence(
                        out_data)  # bs x h x ... x h
                    bs = joint_prob.size(0)
                    n_vars = len(beta_root_list)
                    for i, beta_i in enumerate(beta_root_list):
                        joint_prob = thlp.mul(
                            joint_prob,
                            beta_i.view(*([bs] + [1] * (i) + [-1] + [1] *
                                          (n_vars - i - 1))))

                    y_eta, y_Z = thlp.normalise(joint_prob,
                                                list(range(1,
                                                           joint_prob.ndim)),
                                                get_Z=True)

                    # accumulate posterior
                    self.y_emission.accumulate_posterior(y_eta, out_data)

                    loglike_list.append(y_Z.sum())

                    eta_root_list = []
                    if len(beta_root_list) == 1:
                        # only one tree, no var elimination is needed
                        eta_root_list.append(y_eta)
                    else:
                        for i in range(len(beta_root_list)):
                            sum_over_vars = list(
                                set(range(1, y_eta.ndim)) - {i + 1})
                            eta_root_list.append(
                                thlp.sum_over(y_eta, sum_over_vars))
            else:
                # we do not need downward
                if self.y_emission is not None:
                    return self.y_emission(*beta_root_list)
                else:
                    return beta_root_list
        ################################################################################################################
        # DOWNWARD
        ################################################################################################################
        all_eta_list = []
        for idx_t, t in enumerate(t_list):
            leaf_ids = [
                i for i in range(t.number_of_nodes()) if t.in_degrees(i) == 0
            ]
            t.ndata['is_leaf'] = th.zeros_like(t.ndata['t'], dtype=th.bool)
            t.ndata['is_leaf'][leaf_ids] = 1

            # set base case for downward recursion
            root_ids = [
                i for i in range(t.number_of_nodes()) if t.out_degree(i) == 0
            ]
            t.ndata['beta'][root_ids] = eta_root_list[idx_t]

            t_rev = self.__reverse_dgl_batch__(t)

            # downward
            t_rev.set_n_initializer(dgl.init.zero_initializer)

            # propagate
            dgl.prop_nodes_topo(
                t_rev,
                message_func=self.state_transition.down_message_func,
                reduce_func=self.state_transition.down_reduce_func,
                apply_node_func=self.state_transition.down_apply_node_func)

            # return the posterior
            eta = t_rev.ndata['eta']
            t.ndata['eta'] = eta

            # append posterior
            all_eta_list.append(eta)

            if self.training:
                # accumulate posterior
                if self.x_embedding is not None:
                    x_mask = t.ndata['x_mask']
                    self.x_emission.accumulate_posterior(
                        t.ndata['eta'][x_mask], t.ndata['x_embs'][x_mask])
                else:
                    self.x_emission.accumulate_posterior(
                        t.ndata['eta'], t.ndata['x'])
                if not self.only_root_state and self.y_emission is not None:
                    self.y_emission.accumulate_posterior(
                        t.ndata['eta'], out_data)
        ################################################################################################################

        # compute the returned value
        if self.training:
            return th.stack(loglike_list).sum()
        else:
            # here only_root_state is false
            if self.y_emission is None:
                return all_eta_list
            else:
                return self.y_emission(*all_eta_list)  # return p(y_i|X)
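A hedged usage sketch for the forward pass above, assuming the enclosing model has been instantiated as model with no output emission (y_emission is None), and that input trees are DGL graphs whose edges point from child to parent and whose nodes carry the 'x', 't' and 'pos' fields read above. All concrete values below are illustrative, and -1 stands in for ConstValues.NO_ELEMENT:

import dgl
import torch as th

# hypothetical three-node tree: nodes 0 and 1 are leaves attached to root 2
g = dgl.graph(([0, 1], [2, 2]))              # edges go child -> parent
g.ndata['x'] = th.tensor([5, 7, 2])          # inputs consumed by x_embedding / x_emission
g.ndata['t'] = th.tensor([0, 0, 0])          # node types
g.ndata['pos'] = th.tensor([0, 1, -1])       # position among siblings; -1 = no position

# 'model' is assumed to be an instance of the class defining forward(*t_list, out_data=None)
model.train()
log_like = model(g)   # training mode: summed log-likelihood (the N_u terms) over the batch

model.eval()
etas = model(g)       # inference mode: list of per-node posteriors (or p(y|X) when y_emission is set)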