def hfactor_bfe_obj(factor, T, w):
    """
    Get the contribution to the BFE from a single hybrid or continuous factor (for a purely discrete factor, use
    dfactor_bfe_obj instead, which is more efficient).
    :param factor: a hybrid/continuous Factor object.
    :param T: number of quadrature points per continuous variable.
    :param w: tensor of K mixture weights.
    :return: tuple (bfe, aux_obj): the factor's scalar BFE contribution and the surrogate objective for gradients.
    """
    K = np.prod(w.shape)
    einsum_eq = utils.outer_prod_einsum_equation(len(factor.nb),
                                                 common_first_ndims=1)
    w_broadcast = tf.reshape(w,
                             [-1] + [1] * len(factor.nb))  # K x 1 x 1 ... x 1

    coefs, axes = get_hfactor_expectation_coefs_points(
        factor, K, T)  # [[K, V1], [K, V2], ..., [K, Vn]]
    coefs = tf.einsum(
        einsum_eq,
        *coefs)  # K x V1 x V2 x ... Vn; K grids of Hadamard products
    belief = eval_hfactor_belief(factor, axes, w)  # K x V1 x V2 x ... Vn
    lpot = utils.eval_fun_grid(factor.log_potential_fun,
                               arrs=axes)  # K x V1 x V2 x ... Vn
    log_belief = tf.log(belief)
    F = -lpot + log_belief
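    # note: the weighted integrand below is wrapped in stop_gradient, so bfe is only correct as a value; gradients
    # should be taken through aux_obj instead (see the derivation sketched in dfactors_bfe_obj)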
    prod = tf.stop_gradient(
        w_broadcast * coefs *
        F)  # weighted component-wise Hadamard products for K expectations
    bfe = tf.reduce_sum(prod)
    aux_obj = tf.reduce_sum(prod * log_belief)

    return bfe, aux_obj
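

# Note on the einsum trick used throughout this module: utils.outer_prod_einsum_equation(n, common_first_ndims=m) is
# assumed to produce an equation that takes n tensors sharing m leading batch dims and forms their outer product over
# the remaining dims, e.g. 'ka,kb->kab' for n = 2, m = 1. A minimal numpy illustration:
#
#     import numpy as np
#     a = np.random.rand(4, 2)             # K x V1
#     b = np.random.rand(4, 3)             # K x V2
#     out = np.einsum('ka,kb->kab', a, b)  # K x V1 x V2; out[k] == np.outer(a[k], b[k])
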
def eval_hfactors_belief(factors, axes, w):
    """
    Evaluate multiple hybrid/continuous factors' beliefs on grid(s), assuming the factors have the same nb_domain_types
    (i.e., the same kinds of variables in their cliques) so the dimensions match.
    :param factors: length C list of factor objects sharing the same nb_domain_types.
    :param axes: list of mats [C x M x V1, C x M x V2, ..., C x M x Vn]; we allow the flexibility to evaluate on M > 1
    ndgrids for each factor simultaneously
    :param w: tensor of K mixture weights.
    :return: a [C x M x V1 x V2 x ... x Vn] tensor, whose (c, m, v1, ..., vn)th coordinate is the mixture belief of the
    cth factor (i.e., factors[c]) evaluated on the point (axes[0][c,m,v1], axes[1][c,m,v2], ..., axes[n][c,m,vn])
    """
    K = np.prod(w.shape)
    C = len(factors)
    M = int(axes[0].shape[1])  # number of grids for each factor
    comp_probs = []
    factor = factors[0]
    n = len(factor.nb)
    for i, domain_type in enumerate(factor.nb_domain_types):
        factors_ith_nb = [
            f.nb[i] for f in factors
        ]  # the ith neighbor (rv in clique) across all factors
        if domain_type[0] == 'd':  # discrete
            comp_prob = tf.stack(
                [rv.belief_params_['pi'] for rv in factors_ith_nb], axis=1
            )  # K x C x Vi, where Vi is the number of dstates of factor.nb[i]
            comp_prob = comp_prob[:, :, None, :]  # K x C x 1 x Vi
            comp_prob = tf.tile(
                comp_prob, [1, 1, M, 1])  # K x C x M x Vi; same for all M axes
        elif domain_type[0] == 'c':  # cont, assuming Gaussian for now
            Mu_KC11 = tf.stack(
                [rv.belief_params_['mu_K1'] for rv in factors_ith_nb],
                axis=1)[:, :, None]
            Var_inv_KC11 = tf.stack(
                [rv.belief_params_['var_inv_K1'] for rv in factors_ith_nb],
                axis=1)[:, :, None]
            # eval pdf of axes[i] under all K scalar comps of ith nodes in all the cliques; result is K x C x M x Vi
            comp_prob = (2 * np.pi) ** (-0.5) * tf.sqrt(Var_inv_KC11) * \
                        tf.exp(-0.5 * (axes[i] - Mu_KC11) ** 2 * Var_inv_KC11)
        else:
            raise NotImplementedError
        comp_probs.append(comp_prob)

    # multiply all dimensions together, then weigh by w
    einsum_eq = utils.outer_prod_einsum_equation(len(factor.nb),
                                                 common_first_ndims=3)
    joint_comp_probs = tf.einsum(einsum_eq,
                                 *comp_probs)  # K x C x M x V1 x V2 x ... Vn
    w_broadcast = tf.reshape(w, [K] + [1] * (len(factor.nb) + 2))
    return tf.reduce_sum(w_broadcast * joint_comp_probs,
                         axis=0)  # C x M x V1 x V2 x ... Vn


def eval_dfactor_belief(factor, w):
    """
    Evaluate discrete factor belief on the ndgrid tensor formed by the Cartesian product of node states.
    :param factor: a discrete Factor object.
    :param w: tensor of K mixture weights.
    :return: a [V1 x V2 x ... x Vn] tensor of mixture belief values on the joint state grid.
    """
    einsum_eq = utils.outer_prod_einsum_equation(len(factor.nb),
                                                 common_first_ndims=1)
    comp_probs = [rv.belief_params_['pi']
                  for rv in factor.nb]  # [K x V1, K x V2, ..., K x Vn]
    # multiply all dimensions together, all K components at once
    joint_comp_probs = tf.einsum(einsum_eq,
                                 *comp_probs)  # K x V1 x V2 x ... Vn
    w_broadcast = tf.reshape(w,
                             [-1] + [1] * len(comp_probs))  # K x 1 x 1 ... x 1
    belief = tf.reduce_sum(w_broadcast * joint_comp_probs,
                           axis=0)  # V1 x V2 x ... Vn
    return belief
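

# For reference, eval_dfactor_belief computes the mixture belief
#     b(x_1, ..., x_n) = sum_k w_k * prod_i p_{k,i}(x_i)
# on the full joint state grid. A minimal numpy sketch for n = 2 (the pi arrays below are made-up stand-ins for the
# rv.belief_params_['pi'] tensors):
#
#     import numpy as np
#     w = np.array([0.6, 0.4])                    # K = 2 mixture weights
#     pi1 = np.array([[0.2, 0.8], [0.5, 0.5]])    # K x V1
#     pi2 = np.array([[0.1, 0.9], [0.7, 0.3]])    # K x V2
#     joint = np.einsum('ka,kb->kab', pi1, pi2)   # K x V1 x V2
#     belief = (w[:, None, None] * joint).sum(0)  # V1 x V2; entries sum to 1
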
def eval_hfactor_belief(factor, axes, w):
    """
    Evaluate hybrid/continuous factor's belief on grid(s).
    :param factor: a hybrid/continuous Factor object.
    :param axes: list of mats [M x V1, M x V2, ..., M x Vn]; we allow the flexibility to evaluate on M > 1 ndgrids
    simultaneously
    :param w: tensor of K mixture weights.
    :return: a [M x V1 x V2 x ... x Vn] tensor, whose (m, v1, ..., vn)th coordinate is the mixture belief evaluated on
    the point (axes[0][m,v1], axes[1][m,v2], ..., axes[n][m,vn])
    """
    M = int(axes[0].shape[0])  # number of grids
    comp_probs = []
    for i, rv in enumerate(factor.nb):
        if rv.domain_type[0] == 'd':  # discrete
            comp_prob = rv.belief_params_['pi']  # assuming the states of Xi are sorted, so p_ki(rv.states) = p_ki
            comp_prob = comp_prob[:, None, :]  # K x 1 x Vi
            comp_prob = tf.tile(comp_prob,
                                [1, M, 1])  # K x M x Vi; same for all M axes
        elif rv.domain_type[0] == 'c':  # cont, assuming Gaussian for now
            mean_K1 = rv.belief_params_['mu_K1']
            mean_K11 = mean_K1[:, :, None]
            var_inv_K1 = rv.belief_params_['var_inv_K1']
            var_inv_K11 = var_inv_K1[:, :, None]
            # eval pdf of axes[i] (M x Vi) under all K scalar comps of ith node in the clique; result is K x M x Vi
            comp_prob = (2 * np.pi) ** (-0.5) * tf.sqrt(var_inv_K11) * \
                        tf.exp(-0.5 * (axes[i] - mean_K11) ** 2 * var_inv_K11)
        else:
            raise NotImplementedError
        comp_probs.append(comp_prob)

    # multiply all dimensions together, then weigh by w
    einsum_eq = utils.outer_prod_einsum_equation(len(factor.nb),
                                                 common_first_ndims=2)
    joint_comp_probs = tf.einsum(einsum_eq,
                                 *comp_probs)  # K x M x V1 x V2 x ... Vn
    w_broadcast = tf.reshape(w, [-1] + [1] * (len(factor.nb) + 1))
    return tf.reduce_sum(w_broadcast * joint_comp_probs,
                         axis=0)  # M x V1 x V2 x ... Vn
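

# The continuous branches in eval_hfactor_belief / eval_hfactors_belief evaluate the Gaussian pdf in precision form:
#     N(x; mu, var) = (2*pi)^(-1/2) * sqrt(var_inv) * exp(-0.5 * (x - mu)^2 * var_inv),  with var_inv = 1/var.
# A quick numpy/scipy check of this identity (illustrative only; not part of the module's API):
#
#     import numpy as np
#     from scipy.stats import norm
#     x, mu, var = 0.3, 1.0, 2.0
#     pdf = (2 * np.pi) ** (-0.5) * np.sqrt(1 / var) * np.exp(-0.5 * (x - mu) ** 2 / var)
#     assert np.isclose(pdf, norm.pdf(x, loc=mu, scale=np.sqrt(var)))
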
def dfactors_bfe_obj(factors, w, neg_lpot_only=False):
    """
    Get the contribution to the BFE from multiple discrete factors that have the same kinds of neighboring rvs (i.e.,
    the rvs in factor.nb should have the same number of dstates).
    :param factors: length C list of factor objects that have the same factor.nb_domain_types.
    :param w: tensor of K mixture weights.
    :param neg_lpot_only: if False (default), compute E_b[-log pot + log b] as in the BFE;
    if True, only compute E_b[-log pot] (dropping the log belief from the integrand), to be used with the negative
    ELBO (for NPVI).
    :return: tuple (bfe, aux_obj) of scalar tensors.
    """
    # group factors with the same types of log potentials together for efficient evaluation later
    factors_with_unique_log_potential_fun_types, unique_log_potential_fun_types = \
        utils.get_unique_subsets(factors, key=lambda f: type(f.log_potential_fun))
    factors = sum(factors_with_unique_log_potential_fun_types,
                  [])  # join together into flat list

    C = len(factors)
    factor = factors[0]
    n = len(factor.nb)

    einsum_eq = utils.outer_prod_einsum_equation(len(factor.nb),
                                                 common_first_ndims=2)
    comp_probs = [
        tf.stack([f.nb[i].belief_params_['pi'] for f in factors], axis=0)
        for i in range(n)
    ]  # [C x K x V1, C x K x V2, ..., C x K x Vn]
    # multiply all dimensions together, all C factors and all K components at once
    joint_comp_probs = tf.einsum(einsum_eq,
                                 *comp_probs)  # C x K x V1 x V2 x ... Vn
    w_broadcast = tf.reshape(w,
                             [-1] + [1] * len(comp_probs))  # K x 1 x 1 ... x 1
    belief = tf.reduce_sum(w_broadcast * joint_comp_probs,
                           axis=1)  # C x V1 x V2 x ... Vn
    axes = [
        np.stack([f.nb[i].values for f in factors], axis=0) for i in range(n)
    ]  # [C x V1, C x V2, ..., C x Vn]

    lpot = group_eval_log_potential_funs(
        factors_with_unique_log_potential_fun_types,
        unique_log_potential_fun_types, axes)  # C x V1 x V2 x ... Vn
    if not lpot.dtype == 'float64':  # tf crashes when adding an int tensor to a float tensor
        lpot = tf.cast(lpot, 'float64')
    log_belief = tf.log(belief)
    if neg_lpot_only:
        F = -lpot
    else:
        F = -lpot + log_belief
    prod = tf.stop_gradient(belief * F)  # stop_gradient is needed for aux_obj
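    # aux_obj is a gradient surrogate: with prod = b*F held fixed by stop_gradient, d(aux_obj) = sum(b*F*d log b)
    # = sum(F*db), which equals d(sum(b*F)) because the extra term sum(b*dF) = sum(b*d log b) = sum(db) vanishes
    # as long as b stays normalized (the log potential does not depend on the belief parameters)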

    factors_bfes = tf.reduce_sum(prod, axis=list(range(
        1, n + 1)))  # reduce the last n dimensions
    factors_aux_objs = tf.reduce_sum(
        prod * log_belief,
        axis=list(range(1, n + 1)))  # reduce the last n dimensions

    sharing_counts = np.array([factor.sharing_count for factor in factors],
                              dtype='float')
    bfe = tf.reduce_sum(sharing_counts * factors_bfes)
    aux_obj = tf.reduce_sum(sharing_counts * factors_aux_objs)

    return bfe, aux_obj


def hfactors_bfe_obj(factors, T, w, dtype='float64', neg_lpot_only=False):
    """
    Get the contribution to the BFE from multiple hybrid (or continuous) factors that have the same types of neighboring
    rvs.
    :param factors: length C list of factor objects that have the same nb_domain_types.
    :param T: number of Gauss-Hermite quadrature points per continuous variable.
    :param w: tensor of K mixture weights.
    :param dtype: float type to use.
    :param neg_lpot_only: if False (default), compute E_b[-log pot + log b] as in the BFE;
    if True, only compute E_b[-log pot] (dropping the log belief from the integrand), to be used with the negative
    ELBO (for NPVI).
    :return: tuple (bfe, aux_obj) of scalar tensors.
    """
    # group factors with the same types of log potentials together for efficient evaluation later
    factors_with_unique_log_potential_fun_types, unique_log_potential_fun_types = \
        utils.get_unique_subsets(factors, key=lambda f: type(f.log_potential_fun))
    factors = sum(factors_with_unique_log_potential_fun_types,
                  [])  # join together into flat list

    K = np.prod(w.shape)
    C = len(factors)
    factor = factors[0]
    n = len(factor.nb)

    ghq_points, ghq_weights = roots_hermite(T)  # assuming Gaussian for now
    ghq_coef = np.pi ** (-0.5)  # from the change of variables (see note below)
    ghq_weights = ghq_coef * ghq_weights  # fold ghq_coef into the quadrature weights, so it needn't be tracked later
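    # change of variables: for X ~ N(mu, var), E[f(X)] = pi^(-1/2) * int exp(-t^2) f(sqrt(2*var)*t + mu) dt, hence
    # the pi^(-1/2) factor here and the evaluation points sqrt(2*var)*x_t + mu constructed below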
    ghq_weights = tf.constant(ghq_weights, dtype=dtype)
    ghq_weights_CKT = tf.tile(tf.reshape(ghq_weights, [1, 1, T]),
                              [C, K, 1])  # C x K x T

    coefs = [None] * n  # will be [[C, K, V1], [C, K, V2], ..., [C, K, Vn]]
    axes = [None] * n  # will be [[C, K, V1], [C, K, V2], ..., [C, K, Vn]]

    comp_probs = []  # for evaluating beliefs along the way
    for i, domain_type in enumerate(factor.nb_domain_types):
        factors_ith_nb = [
            f.nb[i] for f in factors
        ]  # the ith neighbor (rv in clique) across all factors
        if domain_type[0] == 'd':
            rv = factor.nb[i]
            c = tf.stack(
                [r.belief_params_['pi'] for r in factors_ith_nb], axis=0
            )  # C x K x Vi, where Vi is the number of dstates of factor.nb[i]

            coefs[i] = c  # the prob params are exactly the inner-product coefficients in the expectations
            a = np.tile(
                np.reshape(rv.values, [1, 1, -1]),
                [C, K, 1])  # C x K x dstates[i] (last dimension repeated)
            a = tf.constant(
                a, dtype=dtype
            )  # otherwise tf complains about multiplying int tensor with float tensor
            axes[i] = a

            # eval_hfactors_belief
            comp_prob = tf.transpose(c, [1, 0, 2])  # K x C x Vi
            comp_prob = comp_prob[:, :, None, :]  # K x C x 1 x Vi
            comp_prob = tf.tile(
                comp_prob,
                [1, 1, K, 1])  # K x C x M(=K) x Vi; same for all M(=K) axes
        elif domain_type[0] == 'c':
            Mu_CK = tf.stack(
                [rv.belief_params_['mu'] for rv in factors_ith_nb],
                axis=0)  # C x K
            Var_CK = tf.stack(
                [rv.belief_params_['var'] for rv in factors_ith_nb],
                axis=0)  # C x K
            coefs[i] = ghq_weights_CKT
            a = (2 * Var_CK[:, :, None]
                 )**0.5 * ghq_points + Mu_CK[:, :, None]  # C x K x T
            a = tf.stop_gradient(
                a)  # don't want to differentiate w.r.t. evaluation points
            axes[i] = a

            # eval_hfactors_belief
            Mu_KC11 = tf.transpose(Mu_CK)[:, :, None, None]  # K x C x 1 x 1
            Var_inv_KC11 = tf.stack(
                [rv.belief_params_['var_inv_K1'] for rv in factors_ith_nb],
                axis=1)[:, :, None]
            # eval pdf of axes[i] under all K scalar comps of ith nodes in all the cliques; result is K x C x M(=K) x Vi
            comp_prob = (2 * np.pi) ** (-0.5) * tf.sqrt(Var_inv_KC11) * \
                        tf.exp(-0.5 * (axes[i] - Mu_KC11) ** 2 * Var_inv_KC11)
        else:
            raise NotImplementedError
        comp_probs.append(comp_prob)

    # eval_hfactors_belief
    # multiply all dimensions together, then weigh by w
    einsum_eq = utils.outer_prod_einsum_equation(len(factor.nb),
                                                 common_first_ndims=3)
    joint_comp_probs = tf.einsum(einsum_eq,
                                 *comp_probs)  # K x C x M x V1 x V2 x ... Vn
    w_broadcast = tf.reshape(w, [K] + [1] * (len(factor.nb) + 2))
    belief = tf.reduce_sum(w_broadcast * joint_comp_probs,
                           axis=0)  # C x M x V1 x V2 x ... Vn
    # the above inlines belief = eval_hfactors_belief(factors, axes, w) with M = K, so belief is C x K x V1 x V2 x ... Vn

    einsum_eq = utils.outer_prod_einsum_equation(n, common_first_ndims=2)
    coefs = tf.einsum(
        einsum_eq,
        *coefs)  # C x K x V1 x V2 x ... Vn; C x K grids of Hadamard products

    lpot = group_eval_log_potential_funs(
        factors_with_unique_log_potential_fun_types,
        unique_log_potential_fun_types, axes)  # C x K x V1 x V2 x ... Vn
    log_belief = tf.log(belief)
    if neg_lpot_only:
        F = -lpot
    else:
        F = -lpot + log_belief
    w_broadcast = tf.reshape(w, [-1] + [1] * n)  # K x 1 x 1 ... x 1
    prod = tf.stop_gradient(
        w_broadcast * coefs *
        F)  # weighted component-wise Hadamard products for C x K expectations
    factors_bfes = tf.reduce_sum(prod, axis=list(range(
        1, n + 2)))  # reduce the last (n+1) dimensions
    factors_aux_objs = tf.reduce_sum(
        prod * log_belief,
        axis=list(range(1, n + 2)))  # reduce the last (n+1) dimensions

    sharing_counts = np.array([factor.sharing_count for factor in factors],
                              dtype='float')
    bfe = tf.reduce_sum(sharing_counts * factors_bfes)
    aux_obj = tf.reduce_sum(sharing_counts * factors_aux_objs)

    return bfe, aux_obj
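

# Illustrative sanity check (not used elsewhere in the module): verifies the Gauss-Hermite change of variables relied
# on in hfactors_bfe_obj, E_{N(mu, var)}[f(X)] ~= pi^(-1/2) * sum_t w_t * f(sqrt(2*var) * x_t + mu), where
# (x_t, w_t) = roots_hermite(T); the rule is exact for polynomials of degree < 2T.
def _ghq_sanity_check(T=20, mu=1.5, var=0.7):
    x, wts = roots_hermite(T)
    approx = np.pi ** (-0.5) * np.sum(wts * (np.sqrt(2 * var) * x + mu) ** 2)  # quadrature estimate of E[X^2]
    assert np.isclose(approx, var + mu ** 2)  # exact second moment: E[X^2] = var + mu^2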