Ejemplo n.º 1
0
    def down_apply_node_func(self, nodes):
        """Downward-pass node update: build the child joint and accumulate statistics.

        The node's own posterior is read from 'eta', or from 'beta' when no
        'eta' entry exists (the root case). It is combined with the stored
        'gamma_ch', 'gamma_p_ch' and 'gamma_r' fields through the ``thlp``
        helpers (presumably log-space arithmetic; confirm against ``thlp``).
        The resulting tensors are passed to ``accumulate_posterior`` for
        ``self.U`` (non-leaf rows) and ``self.p`` (leaf rows).

        Args:
            nodes: node batch whose ``data`` mapping provides 'gamma_ch',
                'gamma_p_ch', 'gamma_r', 'is_leaf' and, depending on the
                configuration, 'eta', 't' and 'pos'.

        Returns:
            dict with 'eta_ch' (the joint summed over dim 3) and 'eta'
            (the node posterior, passed through unchanged).
        """
        data = nodes.data
        # Only the root lacks an 'eta' entry; it starts from 'beta' instead.
        eta_u = data['eta'] if 'eta' in data else data['beta']

        ch_term = data['gamma_ch']  # has shape (bs x L x h x h)
        ch_norm = data['gamma_p_ch'].unsqueeze(2)  # (bs x L x 1 x h)
        up_term = data['gamma_r'].unsqueeze(1).unsqueeze(2)  # (bs x 1 x 1 x h)
        n_ch_mask = ch_norm.exp().sum((2, 3), keepdim=True)

        numer = thlp.mul(thlp.div(ch_term, ch_norm * n_ch_mask), up_term)
        denom = thlp.sum_over(numer, 2, keepdim=True)
        # P(Q_l, Q_ch_l | X), has shape (bs x L x h x h)
        scaled = thlp.mul(numer, eta_u.unsqueeze(1).unsqueeze(2))
        eta_u_chl = thlp.div(scaled, denom * n_ch_mask)

        leaf_rows = data['is_leaf']
        inner_rows = th.logical_not(leaf_rows)

        def types_for(rows):
            # Type ids are only looked up when the model has several types.
            return data['t'][rows] if self.num_types > 1 else None

        self.accumulate_posterior(self.U,
                                  eta_u_chl[inner_rows],
                                  types=types_for(inner_rows))

        leaf_types = types_for(leaf_rows)
        # Positions are only supplied when the model is not position-stationary.
        leaf_pos = None if self.pos_stationarity else data['pos'][leaf_rows]
        self.accumulate_posterior(self.p,
                                  eta_u[leaf_rows],
                                  types=leaf_types,
                                  pos=leaf_pos)

        return {'eta_ch': thlp.sum_over(eta_u_chl, 3), 'eta': eta_u}
Ejemplo n.º 2
0
    def down_apply_node_func(self, nodes):
        """Downward-pass node update for the rank-factorised variant.

        The node posterior ('eta', or 'beta' for the root) is first combined
        with the gathered ``self.U_output`` parameter and 'gamma_r' to obtain
        the joint annotated as P(Q_u, R_u | X). Its marginal over dim 2 is
        then combined with 'gamma_ch' / 'gamma_p_ch' to obtain the joint
        annotated as P(Q_l, R_l | X). Both joints and the leaf posterior are
        passed to ``accumulate_posterior``.

        Args:
            nodes: node batch whose ``data`` mapping provides 'gamma_r',
                'gamma_ch', 'gamma_p_ch', 'is_leaf' and, depending on the
                configuration, 'eta', 't' and 'pos'.

        Returns:
            dict with 'eta_ch' (the child joint summed over dim 3) and 'eta'.
        """
        data = nodes.data
        # Only the root lacks an 'eta' entry; it starts from 'beta' instead.
        eta_u = data['eta'] if 'eta' in data else data['beta']

        out_param = self.__gather_param__(
            self.U_output, types=data['t'] if self.num_types > 1 else None)

        rank_term = data['gamma_r']  # has shape (bs x rank)
        out_numer = thlp.mul(out_param,
                             rank_term.unsqueeze(2))  # bs x rank x h
        out_denom = thlp.sum_over(out_numer, 1, keepdim=True)
        # P(Q_u, R_u | X), has shape bs x rank x h
        eta_ur = thlp.mul(thlp.div(out_numer, out_denom), eta_u.unsqueeze(1))

        # Marginal over dim 2, reshaped to (bs x 1 x 1 x rank)
        eta_r = thlp.sum_over(eta_ur, 2).unsqueeze(1).unsqueeze(2)

        ch_term = data['gamma_ch']  # has shape (bs x L x h x rank)
        ch_norm = data['gamma_p_ch'].unsqueeze(2)  # (bs x L x 1 x rank)
        rank_term_4d = rank_term.unsqueeze(1).unsqueeze(2)  # (bs x 1 x 1 x rank)
        n_ch_mask = ch_norm.exp().sum((2, 3), keepdim=True)

        ch_numer = thlp.mul(thlp.div(ch_term, ch_norm * n_ch_mask),
                            rank_term_4d)
        ch_denom = thlp.sum_over(ch_numer, 2, keepdim=True)
        # P(Q_l, R_l | X), has shape (bs x L x h x rank)
        eta_ur_ch = thlp.div(thlp.mul(ch_numer, eta_r), ch_denom * n_ch_mask)

        leaf_rows = data['is_leaf']
        inner_rows = th.logical_not(leaf_rows)

        def types_for(rows):
            # Type ids are only looked up when the model has several types.
            return data['t'][rows] if self.num_types > 1 else None

        self.accumulate_posterior(self.U,
                                  eta_ur_ch[inner_rows],
                                  types=types_for(inner_rows))
        self.accumulate_posterior(self.U_output,
                                  eta_ur[inner_rows],
                                  types=types_for(inner_rows))

        leaf_types = types_for(leaf_rows)
        # Positions are only supplied when the model is not position-stationary.
        leaf_pos = None if self.pos_stationarity else data['pos'][leaf_rows]
        self.accumulate_posterior(self.p,
                                  eta_u[leaf_rows],
                                  types=leaf_types,
                                  pos=leaf_pos)

        return {'eta_ch': thlp.sum_over(eta_ur_ch, 3), 'eta': eta_u}
Ejemplo n.º 3
0
    def down_apply_node_func(self, nodes):
        """Downward-pass node update for the full-joint variant.

        Builds the joint annotated as P(Q_u, Q_1, ..., Q_L | X) from
        'beta_ch', 'beta_np' and the node posterior ('eta', or 'beta' for the
        root), sums out the last dim, and then marginalises that result onto
        each of the ``max_output_degree`` child dims in turn. The full joint
        and the leaf posterior are passed to ``accumulate_posterior``.

        Args:
            nodes: node batch whose ``data`` mapping provides 'beta_ch',
                'beta_np', 'is_leaf' and, depending on the configuration,
                'eta' and 't'.

        Returns:
            dict with 'eta_ch' (per-child marginals with the last entry of
            the final dim dropped) and 'eta'.
        """
        data = nodes.data
        # Only the root lacks an 'eta' entry; it starts from 'beta' instead.
        eta_u = data['eta'] if 'eta' in data else data['beta']

        n_children = self.max_output_degree
        joint_in = data['beta_ch']  # has shape (bs x h x ... x h x h)
        norm_in = data['beta_np']
        # Broadcast shape: batch dim, one singleton per child dim, then h_size.
        bcast_shape = [-1] + [1] * n_children + [self.h_size]

        # P(Q_u, Q_1, ..., Q_L | X)
        eta_uch = thlp.div(thlp.mul(joint_in, eta_u.view(*bcast_shape)),
                           norm_in.view(*bcast_shape))
        # P(Q_1, ..., Q_L | X)
        eta_joint_ch = thlp.sum_over(eta_uch, -1)

        batch = eta_u.shape[0]
        eta_ch = th.empty(batch, n_children, self.h_size + 1)
        child_dims = set(range(1, n_children + 1))
        for dim in range(1, n_children + 1):
            # Sum out every child dim except the one being kept.
            other_dims = list(child_dims - {dim})
            eta_ch[:, dim - 1, :] = thlp.sum_over(eta_joint_ch, other_dims)

        # accumulate posterior
        leaf_rows = data['is_leaf']
        inner_rows = th.logical_not(leaf_rows)

        def types_for(rows):
            # Type ids are only looked up when the model has several types.
            return data['t'][rows] if self.num_types > 1 else None

        self.accumulate_posterior(self.U,
                                  eta_uch[inner_rows],
                                  types=types_for(inner_rows))
        self.accumulate_posterior(self.p,
                                  eta_u[leaf_rows],
                                  types=types_for(leaf_rows))

        return {'eta_ch': eta_ch[:, :, :-1], 'eta': eta_u}
Ejemplo n.º 4
0
    def down_apply_node_func(self, nodes):
        """Downward-pass node update that walks the children from last to first.

        Leaf posteriors are accumulated into ``self.p`` first. If the batch
        contains any internal node, the joint of the node state and its rank
        variable is computed from ``self.U_output`` and 'gamma_r', and then a
        backward loop over child positions computes one joint per position,
        accumulating it into ``self.U`` and carrying a running ``last_eta``
        from one position to the previous one.

        Args:
            nodes: node batch whose ``data`` mapping provides 'is_leaf',
                'n_ch', 'gamma_r', 'gamma_less_l', 'beta_ch', 't' and,
                depending on the configuration, 'eta' and 'pos'.

        Returns:
            dict with 'eta_ch' (one row per node in the batch; rows of leaf
            nodes keep the value produced by ``thlp.zeros``) and 'eta'.
        """
        if 'eta' in nodes.data:
            eta_u = nodes.data['eta']
        else:
            # root
            eta_u = nodes.data['beta']

        is_leaf = nodes.data['is_leaf']
        is_internal = th.logical_not(is_leaf)
        # n_ch - 1 for each internal node; compared against the 0-based
        # position index i in the loop below.
        n_ch_list = nodes.data['n_ch'][is_internal] - 1
        gamma_r = nodes.data['gamma_r'][is_internal]  # has shape (bs x rank)
        # NOTE(review): the original carried two shape notes for gamma_less_l
        # ("bs x rank+1" and "bs x L x rank+1"); the 3-index use below
        # matches the latter.
        gamma_less_l = nodes.data['gamma_less_l'][
            is_internal]  # has shape bs x L x rank+1
        beta_ch = nodes.data['beta_ch'][is_internal]  # has shape bs x L x h
        t = nodes.data['t'][is_internal]

        # Leaf statistics: positions are only passed when the model is not
        # position-stationary.
        self.accumulate_posterior(
            self.p,
            eta_u[is_leaf],
            types=nodes.data['t'][is_leaf] if self.num_types > 1 else None,
            pos=nodes.data['pos'][is_leaf]
            if not self.pos_stationarity else None)

        # Output buffer covering every node in the batch (leaves included).
        eta_u_ch_all = thlp.zeros(eta_u.shape[0], self.max_output_degree,
                                  self.h_size)
        if th.any(is_internal):
            # computation only on internal nodes
            # compute P(Q_u, R_u | X)
            U_out = self.__gather_param__(
                self.U_output, types=t if self.num_types > 1 else None)
            # U_out has shape bs x rank x h
            a = thlp.mul(gamma_r.unsqueeze(2), U_out)
            b = thlp.sum_over(a, 1, keepdim=True)
            eta_ur = thlp.div(thlp.mul(a, eta_u[is_internal].unsqueeze(1)),
                              b)  # has shape bs x rank x h

            # compute P(R_u, R_L | X)
            eta_r = thlp.sum_over(eta_ur, 2)  # has shape bs x rank_U
            # Disabled R_output path, kept for reference:
            # R_out = self.__gather_param__(self.R_output, types=t if self.num_types > 1 else None)
            # # R_out has shape bs x rank_L x rank_U
            # a = thlp.mul(R_out, gamma_L[:, :-1].unsqueeze(2))  # has shape bs x rank_L x rank_U
            # b = thlp.sum_over(a, 1, keepdim=True)
            # eta_rul = thlp.div(thlp.mul(a, eta_r.unsqueeze(1)), b)  # has shape bs x rank_L x rank_U

            # compute P(R_l, R_l-1, Q_l | X)
            # eta_rL = thlp.sum_over(eta_rul, 2)  # has shape bs x rank
            # With the R_output path disabled, eta_rL is just another name for
            # eta_r (same tensor object, no copy).
            eta_rL = eta_r
            U = self.__gather_param__(self.U,
                                      types=t if self.num_types > 1 else None)
            # Broadcast U along the position dim when positions share
            # parameters, and along the batch dim when a single parameter set
            # was gathered.
            if self.pos_stationarity:
                U = U.expand((-1, self.max_output_degree, -1, -1, -1))
            if U.size(0) == 1:
                U = U.expand((gamma_less_l.size(0), -1, -1, -1, -1))
            # U has shape bs x L x h x rank+1  x rank
            eta_u_ch = thlp.zeros(eta_rL.shape[0], self.max_output_degree,
                                  self.h_size)
            # NOTE(review): last_eta aliases eta_rL / eta_r, so the in-place
            # writes at the end of each iteration also modify those names.
            last_eta = eta_rL  # has shape bs x rank
            # Walk child positions from the last one down to position 0.
            for i in range(self.max_output_degree - 1, -1, -1):
                # Rows (internal nodes) whose n_ch - 1 is at least i.
                pos_flag = i <= n_ch_list
                if th.any(pos_flag):
                    if i > 0:
                        # Combine U at position i with gamma_less_l at
                        # position i - 1 and beta_ch at position i, restricted
                        # to the selected rows.
                        a = thlp.mul(
                            U[pos_flag, i, :, :, :],
                            gamma_less_l[pos_flag,
                                         i - 1, :].unsqueeze(1).unsqueeze(3))
                        a = thlp.mul(
                            a, beta_ch[pos_flag,
                                       i, :].unsqueeze(2).unsqueeze(3))
                    else:
                        # Position 0: only the last index of the rank+1 dim is
                        # filled; every other entry keeps the thlp.zeros value.
                        # NOTE(review): this branch indexes U and beta_ch with
                        # ':' rather than pos_flag, so it lines up with
                        # last_eta[pos_flag, :] below only if pos_flag selects
                        # every row at i == 0 - confirm n_ch >= 1 for all
                        # internal nodes.
                        a = thlp.zeros(*(U.shape[:1] + U.shape[2:]))
                        a[:, :,
                          -1, :] = thlp.mul(U[:, i, :, -1, :],
                                            beta_ch[:, i, :].unsqueeze(2))
                    # Normalise over dims 1 and 2, then weight by the running
                    # last_eta of the selected rows.
                    b = thlp.sum_over(a, (1, 2), keepdim=True)
                    eta_rul_rlprec = thlp.div(
                        thlp.mul(
                            a,
                            last_eta[pos_flag, :].unsqueeze(1).unsqueeze(2)),
                        b)

                    # When positions are not stationary, every selected row is
                    # tagged with the current position index i.
                    self.accumulate_posterior(
                        self.U,
                        eta_rul_rlprec,
                        types=t[pos_flag] if self.num_types > 1 else None,
                        pos=th.full(
                            (eta_rul_rlprec.shape[0],
                             1), i, dtype=th.long).squeeze(1)
                        if not self.pos_stationarity else None)

                    # Sum out dims 2 and 3 to get the value stored for child i.
                    eta_u_ch[pos_flag,
                             i, :] = thlp.sum_over(eta_rul_rlprec, (2, 3))
                    # Sum out dims 1 and 3, drop the last entry of what
                    # remains, and carry it to the next (smaller) i.
                    last_eta[pos_flag, :] = thlp.sum_over(
                        eta_rul_rlprec, (1, 3))[:, :-1]

            # accumulate posterior
            self.accumulate_posterior(self.U_output,
                                      eta_ur,
                                      types=t if self.num_types > 1 else None)
            # self.accumulate_posterior(self.R_output, eta_rul, types=t if self.num_types > 1 else None)
            # Scatter the internal-node results back into the full batch.
            eta_u_ch_all[is_internal] = eta_u_ch

        return {'eta_ch': eta_u_ch_all, 'eta': eta_u}
Ejemplo n.º 5
0
    def down_apply_node_func(self, nodes):
        """Downward-pass node update for the variant with a joint rank tensor.

        Three joints are computed in sequence: the one annotated as
        P(Q_u, R_u | X) from ``self.U_output`` and 'gamma_r'; the one
        annotated as P(R_u, R_1, ..., R_L | X) from 'gamma_ch_all'; and, after
        marginalising the latter onto each child dim, the one annotated as
        P(Q_l, R_l | X) from ``self.U`` and 'beta_ch'. All three, plus the
        leaf posterior, are passed to ``accumulate_posterior``.

        Args:
            nodes: node batch whose ``data`` mapping provides 'gamma_r',
                'gamma_ch_all', 'beta_ch', 'is_leaf', 'pos' and, depending on
                the configuration, 'eta' and 't'.

        Returns:
            dict with 'eta_ch' (the child joint summed over dim 3) and 'eta'.
        """
        data = nodes.data
        # Only the root lacks an 'eta' entry; it starts from 'beta' instead.
        eta_u = data['eta'] if 'eta' in data else data['beta']

        batch = eta_u.shape[0]
        n_children = self.max_output_degree

        rank_term = data['gamma_r']  # has shape (bs x rank)
        out_param = self.__gather_param__(
            self.U_output, types=data['t'] if self.num_types > 1 else None)
        # out_param has shape bs x rank x h
        out_numer = thlp.mul(rank_term.unsqueeze(2), out_param)
        out_denom = thlp.sum_over(out_numer, 1, keepdim=True)
        # P(Q_u, R_u | X), has shape bs x rank x h
        eta_ur = thlp.div(thlp.mul(out_numer, eta_u.unsqueeze(1)), out_denom)

        eta_r = thlp.sum_over(eta_ur, 2)  # has shape bs x rank

        rank_joint = data['gamma_ch_all']  # has shape bs x r x ... x r x r
        # Broadcast shape: batch dim, one singleton per child dim, then rank.
        bcast_shape = [-1] + [1] * n_children + [self.rank]
        # P(R_u, R_1, ..., R_L | X)
        eta_ru_rch = thlp.div(thlp.mul(rank_joint, eta_r.view(*bcast_shape)),
                              rank_term.view(*bcast_shape))
        # P(R_1, ..., R_L | X)
        eta_rch = thlp.sum_over(eta_ru_rch, -1)  # has shape bs x r+1 x ... r+1

        eta_rl = thlp.zeros(batch, n_children, self.rank + 1)
        child_dims = set(range(1, n_children + 1))
        for dim in range(1, n_children + 1):
            # Sum out every child dim except the one being kept.
            other_dims = list(child_dims - {dim})
            eta_rl[:, dim - 1, :] = thlp.sum_over(eta_rch, other_dims)

        ch_param = self.__gather_param__(
            self.U, types=data['t'] if self.num_types > 1 else None)
        # ch_param has shape bs x L x h x rank
        ch_numer = thlp.mul(ch_param, data['beta_ch'].unsqueeze(3))
        ch_denom = thlp.sum_over(ch_numer, 2, keepdim=True)
        # P(Q_l, R_l | X), has shape bs x L x h x rank
        eta_rql = thlp.div(
            thlp.mul(ch_numer, eta_rl[:, :, :-1].unsqueeze(2)), ch_denom)

        # accumulate posterior
        leaf_rows = data['is_leaf']
        inner_rows = th.logical_not(leaf_rows)

        def types_for(rows):
            # Type ids are only looked up when the model has several types.
            return data['t'][rows] if self.num_types > 1 else None

        self.accumulate_posterior(self.U,
                                  eta_rql[inner_rows],
                                  types=types_for(inner_rows))
        self.accumulate_posterior(self.U_output,
                                  eta_ur[inner_rows],
                                  types=types_for(inner_rows))
        self.accumulate_posterior(self.G,
                                  eta_ru_rch[inner_rows],
                                  types=types_for(inner_rows))
        # NOTE(review): positions are always passed here, with no
        # stationarity switch - confirm this is intended for this model.
        self.accumulate_posterior(self.p,
                                  eta_u[leaf_rows],
                                  types=types_for(leaf_rows),
                                  pos=data['pos'][leaf_rows])

        return {'eta_ch': thlp.sum_over(eta_rql, 3), 'eta': eta_u}