Example #1
    def diffeo_pushforward(self, diffeo: Diffeomorphism[P, P_],
                           chart: Chart[P_]) -> "Tensor[P_]":
        """Compute the pushforward of this tensor along a diffeomorphism

        This isn't a true pushforward; it also requires pulling back contravariant
        indices along the inverse. It's more like a change of coordinates.
        """
        def coord_map_forward(c: jnp.DeviceArray) -> jnp.DeviceArray:
            return chart.point_to_coords(
                diffeo.forward(self.point.chart.coords_to_point(c)))

        def coord_map_backward(c: jnp.DeviceArray) -> jnp.DeviceArray:
            return self.point.chart.point_to_coords(
                diffeo.backward(chart.coords_to_point(c)))

        image = coord_map_forward(self.point.coords)
        jacobian_backward = jax.jacfwd(coord_map_backward)(image)
        # at every step, we contract the first index of tensor
        # transformed index is appended as last index so they end in the right order
        transformed_t = self.t_coords
        for _ in range(self.n_contra):
            # transform contravariant index by pulling back
            # in this case right multiplication is what we wanted anyway
            transformed_t = jnp.tensordot(transformed_t,
                                          jacobian_backward,
                                          axes=([0], [0]))

        jacobian_forward = jax.jacfwd(coord_map_forward)(self.point.coords)
        for _ in range(self.n_cov):
            # we actually want left multiplication, so contract axis 1 of jacobian
            transformed_t = jnp.tensordot(transformed_t,
                                          jacobian_forward,
                                          axes=([0], [1]))

        return Tensor(ChartPoint(image, chart), transformed_t, self.n_contra)
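A minimal stand-alone sketch of the contraction pattern above, using a hypothetical two-dimensional coordinate map (the map, the tensor, and the shapes below are illustrative assumptions, not part of the class):

import jax
import jax.numpy as jnp

def forward(c):
    return jnp.array([c[0] + c[1], c[1]])

def backward(c):
    return jnp.array([c[0] - c[1], c[1]])

coords = jnp.array([0.3, 0.7])
t = jnp.arange(4.0).reshape(2, 2)  # one contravariant, one covariant index

jacobian_backward = jax.jacfwd(backward)(forward(coords))
jacobian_forward = jax.jacfwd(forward)(coords)

# contravariant index contracts with the backward Jacobian, covariant index
# with the forward Jacobian, mirroring the two loops in diffeo_pushforward
t_new = jnp.tensordot(t, jacobian_backward, axes=([0], [0]))
t_new = jnp.tensordot(t_new, jacobian_forward, axes=([0], [1]))
print(t_new.shape)  # (2, 2)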
def correction_layer(Kl, Phi):
    if len(Kl.shape) == 2:  ## For FNNs
        return Kl @ Phi @ Kl

    elif len(Kl.shape) == 4:  ## For 1D CNNs
        N_tr = Kl.shape[0]
        D = Kl.shape[-1]

        correction = 0
        for i in range(N_tr):
            for j in range(N_tr):
                correction += Phi[i, j] * jnp.tensordot(
                    Kl[:, i], Kl[j, :], axes=((-1), (1))) / D

        return np.moveaxis(correction, 2, 1)

    elif len(Kl.shape) == 6:  ## For 2D CNNs
        N_tr = Kl.shape[0]
        w = Kl.shape[-1]
        D = w**2

        Kl = Kl.reshape(N_tr, N_tr, D, D)
        correction = 0
        for i in range(N_tr):
            for j in range(N_tr):
                correction += Phi[i, j] * jnp.tensordot(
                    Kl[:, i], Kl[j, :], axes=((-1), (1))) / D

        return np.moveaxis(correction, 2, 1).reshape(N_tr, N_tr, w, w, w, w)

    else:
        raise NotImplementedError(f'Unsupported kernel shape: {Kl.shape}')
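A quick shape check of the 1D-CNN branch with hypothetical sizes (assumes the correction_layer above is in scope): Kl of shape (N_tr, N_tr, D, D) together with Phi of shape (N_tr, N_tr) yields a correction of the same shape as Kl.

import jax.numpy as jnp

N_tr, D = 3, 5
Kl = jnp.ones((N_tr, N_tr, D, D))
Phi = jnp.eye(N_tr)
print(correction_layer(Kl, Phi).shape)  # (3, 3, 5, 5)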
Example #3
def Potts_ShiftGaugeZeroSum(potts):

    # make L x q array of mean value of h_i for all L sites
    h_mean = np.tensordot(np.mean(potts.h, axis=1), np.ones(potts.q), axes=0)

    # update h to zero-sum gauge
    potts.h = index_update(potts.h, index[:, :], potts.h - h_mean)

    # make L x q x L x q of mean value of e_ij for all L choose 2 site pairs
    # should be symmetric and 0 on diagonal
    e_mean = np.tensordot(np.mean(potts.e, axis=(1, 3)),
                          np.ones((potts.q, potts.q)),
                          axes=0)

    # transpose e so that it is L x L x q x q
    e_transpose = np.transpose(potts.e, axes=[0, 2, 1, 3])

    # shift to zero-sum gauge
    e_transpose = index_update(e_transpose, index[:, :, :, :],
                               e_transpose - e_mean)

    # undo the transposition to get an L x q x L x q array
    potts.e = index_update(potts.e, index[:, :, :, :],
                           np.transpose(e_transpose, axes=(0, 2, 1, 3)))

    return
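The jax.ops index_update/index helpers used above have since been removed from JAX in favor of the .at[] syntax; a minimal sketch of the same field shift with that API, assuming h has shape (L, q):

import jax.numpy as jnp

L, q = 6, 4
h = jnp.arange(L * q, dtype=jnp.float32).reshape(L, q)
h_mean = jnp.tensordot(jnp.mean(h, axis=1), jnp.ones(q), axes=0)
h = h.at[:, :].set(h - h_mean)  # same effect as index_update(h, index[:, :], h - h_mean)
print(jnp.allclose(h.sum(axis=1), 0.0, atol=1e-5))  # True: each row now sums to zero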
Example #4
    def system_id(self):
        """ returns current estimate of hidden system dynamics """
        assert self.T > 0
        k = self.k if self.k else int(0.15 * self.T)

        # transform eta and x
        eta_np = np.array(self.eta)
        x_np = np.array(self.x_history)

        # prepare vectors and retrieve B
        scan_len = self.T - k - 1  # need extra -1 because we iterate over j=0,..,k
        N_j = np.array([
            np.dot(x_np[j + 1:j + 1 + scan_len].T, eta_np[:scan_len])
            for j in range(k + 1)
        ]) / scan_len
        B = N_j[0]  # np.dot(x_np[1:].T, eta_np[:-1]) / (self.T-1)
        #B = np.dot(x_np[1:].T, eta_np[:-1]) / (self.T-1)
        # retrieve A
        C_0, C_1 = N_j[:-1], N_j[1:]
        C_inv = np.linalg.inv(
            np.tensordot(C_0, C_0, axes=([0, 2], [0, 2])) +
            self.gamma * np.identity(self.n))
        A = np.tensordot(C_1, C_0, axes=([0, 2], [0, 2])) @ C_inv + B @ self.K

        return (A, B)
def _parse_nn_pepo_obc(C, D, vL, vR, vB, vT, lx, ly):
    assert (lx > 2) and (ly > 2)  # otherwise there is no bulk to put the Ds in
    x_d = lx // 2
    y_d = ly // 2

    # HORIZONTAL
    vL_C = np.tensordot(vL, C, [0, 2])  # (p,p*,r)
    C_vR = np.tensordot(vR, C, [0, 3])  # (p,p*,l)
    vB_D = np.tensordot(vB, D, [0, 4])  # (p,p*,l,r,u)
    D_vT = np.tensordot(vT, D, [0, 5])  # (p,p*,l,r,d)

    left_col = [
        vL_C[:, :, :, None]
    ] + [vL_C[:, :, None, :, None]] * (ly - 2) + [vL_C[:, :, None, :]]

    # bottom C:  (p,p*,i,j) = (p,p*,l,r) -> (p,p*,r,l) -> (p,p*,r,u,l)
    # bulk C: (p,p*,i,j) = (p,p*,l,r) -> (p,p*,u,l,d,r)
    # top C: (p,p*,i,j) = (p,p*,l,r) -> (p,p*,l,d,r)
    mid_col = [np.transpose(C, [0, 1, 3, 2])[:, :, :, None, :]] \
              + [C[:, :, None, :, None, :]] * (ly - 2) \
              + [C[:, :, :, None, :]]

    # vB_D: (p,p*,ijl) = (p,p*,lru) -> (p,p*,rul)
    # D: (p,p*,ijkl) -> (p,p*,likj) = (p,p*,uldr)
    # D_vT: (p,p*,ijk) = (p,p*,lrd) -> (p,p*,ldr)
    d_col = [np.transpose(vB_D, [0, 1, 3, 4, 2])] \
            + [np.transpose(D, [0, 1, 5, 2, 4, 3])] * (ly - 2) \
            + [np.transpose(D_vT, [0, 1, 2, 4, 3])]

    right_col = [
        C_vR[:, :, None, :]
    ] + [C_vR[:, :, None, :, None]] * (ly - 2) + [C_vR[:, :, :, None]]
    tensors = [left_col] \
              + [mid_col] * (x_d - 1) \
              + [d_col] \
              + [mid_col] * (lx - x_d - 2) \
              + [right_col]
    pepo_hor = Pepo(
        tensors, OBC, False
    )  # even if the NnPepo is Hermitian, the two separate Pepos might not be.

    # VERTICAL
    # rotate tensors clockwise

    # (p,p*,u,l,d,r) -> (p,p*,l,d,r,u)
    _rotate90 = partial(np.transpose, axes=[0, 1, 3, 4, 5, 2])

    # tensor at new location (x,y) was at (-y-1,x) before
    tensors = [[tensors[-y - 1][0] for y in range(ly)]] \
              + [[tensors[-1][x]] + [_rotate90(tensors[-y - 1][x]) for y in range(1, ly - 1)]
                 + [tensors[0][x]] for x in range(1, lx - 1)] \
              + [[tensors[-y - 1][-1] for y in range(ly)]]

    pepo_vert = Pepo(
        tensors, OBC, False
    )  # even if the NnPepo is Hermitian, the two separate Pepos might not be.

    return pepo_hor, pepo_vert
Example #6
 def lstm_cell(hc, x):
     h, c = hc
     p = params
     tmp = jnp.tensordot(x, p["w"], [-1, 0]) + jnp.tensordot(
         h, p["u"], [-1, 0]) + p["b"]
     ft, it, ot, gt = tmp.T
     ct = jax.nn.sigmoid(ft + 1) * c + jax.nn.sigmoid(it) * jnp.tanh(gt)
     ht = jax.nn.sigmoid(ot) * jnp.tanh(ct)
     return (ht, ct), ct
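A minimal sketch of driving the cell above with jax.lax.scan. The parameter shapes are assumptions inferred from the gate unpacking (w: (d_in, d_hidden, 4), u: (d_hidden, d_hidden, 4), b: (d_hidden, 4)); the cell is restated here only so the snippet is self-contained.

import jax
import jax.numpy as jnp

d_in, d_hidden, T = 3, 5, 7
kw, ku, kx = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w": jax.random.normal(kw, (d_in, d_hidden, 4)),
    "u": jax.random.normal(ku, (d_hidden, d_hidden, 4)),
    "b": jnp.zeros((d_hidden, 4)),
}

def lstm_cell(hc, x):  # same body as above, with params captured from this scope
    h, c = hc
    p = params
    tmp = jnp.tensordot(x, p["w"], [-1, 0]) + jnp.tensordot(h, p["u"], [-1, 0]) + p["b"]
    ft, it, ot, gt = tmp.T
    ct = jax.nn.sigmoid(ft + 1) * c + jax.nn.sigmoid(it) * jnp.tanh(gt)
    ht = jax.nn.sigmoid(ot) * jnp.tanh(ct)
    return (ht, ct), ct

xs = jax.random.normal(kx, (T, d_in))
carry0 = (jnp.zeros(d_hidden), jnp.zeros(d_hidden))
_, cell_states = jax.lax.scan(lstm_cell, carry0, xs)
print(cell_states.shape)  # (7, 5): one cell state per time step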
Example #7
 def gru_cell(h, x):
     p = params["zr"]
     tmp = jnp.tensordot(x, p["w"], [-1, 0]) + jnp.tensordot(
         h, p["u"], [-1, 0]) + p["b"]
     zt, rt = jax.nn.sigmoid(tmp).T
     ht = jnp.tanh(x @ params["h"]["w"] + (h * rt) @ params["h"]["u"] +
                   params["h"]["b"])
     h = (1 - zt) * h + zt * ht
     return h, h
Example #8
 def policy_loss(M, bias, w, cost_t=cost_fn):
     y = np.zeros((n, 1))
     for h in range(HH - 1):
         v = -self.K @ y + np.tensordot(
             M, w[h:h + H], axes=([0, 2], [0, 1])) + bias
         y = A @ y + B @ v + w[h + H]
     # Don't update state at the end
     v = -self.K @ y + np.tensordot(
         M, w[h:h + H], axes=([0, 2], [0, 1])) + bias
     return cost_t(y, v)
Example #9
 def counterfact_loss(M, w):
     y = np.zeros(self.n)
     for h in range(HH - H - 1):
         v = -self.K @ y + np.tensordot(
             M, w[h:(h + self.H)], axes=([0, 2], [0, 1]))
         y = A @ y + B @ v + w[(h + H)]
     v = -self.K @ y + np.tensordot(
         M, w[h:(h + self.H)], axes=([0, 2], [0, 1]))
     cost = loss_fn(y, v)
     return cost
Example #10
 def contract_with(self, other) -> complex:
     assert self.L == other.L
     tens = np.tensordot(self.tensors[0], other.tensors[0], [[0, 1], [0, 1]])  # (r1,r2,u) & (r1,r2,u) -> (u,u)
     u, u_ = tens.shape
     col = [np.reshape(tens, [u * u_])]
     col += tree_multimap(_contract_with__bulk_contraction2, self.tensors[1:-1], other.tensors[1:-1])
     tens = np.tensordot(self.tensors[-1], other.tensors[-1], [[1, 2], [1, 2]])  # (d,r1,r2) & (d,r1,r2) -> (d,d)
     d, d_ = tens.shape
     col.append(np.reshape(tens, [d * d_]))
     res = np.linalg.multi_dot(col)
     return res * self.norm * other.norm
Example #11
    def interpolator(self, PhiX, PhiE):

        iPhiE = np.linalg.inv(
            np.tensordot(PhiE, PhiE, axes=([1, 2], [1, 2])) +
            self.reg_inv * onp.eye(self.num_anchor_points))
        Lambda = np.einsum('ijk,ljk,lm', PhiX, PhiE, iPhiE)
        if self.simplex:
            Lambda = Lambda / (np.sum(Lambda, axis=1)[:, np.newaxis] + 1e-3
                               )  # not really a projection on the simplex

        B = np.tensordot(Lambda, PhiE, axes=(1, 0))

        return B, Lambda
Example #12
def MSAWeight_PB(msa):
    gap_idx = msa.abc.charmap['-']
    q = msa.abc.q
    ax = msa.ax
    (N, L) = ax.shape

    ## step 1: get counts:

    c = np.sum(msa.ax_1hot, axis=0)

    # set gap counts to 0
    c = index_update(c, index[:, gap_idx], 0)

    # get N x L array with count value for corresponding residue in alignment
    # first, get  N x L "column id" array (convenient for vmap)
    # col_id[n,i] = i
    col_id = np.int16(np.tensordot(np.ones(N), np.arange(L), axes=0))
    # ax_c[n, i] = c[i, ax[n,i]]
    ax_c = Get_Henikoff_Counts_Residue(col_id, ax, c)

    ## step 2: get number of unique characters in each column
    r = np.float32(np.sum(np.array(c > 0), axis=1))

    # transform r from Lx1 array to NxL array, where r2[n,i] = r[i]
    # will allow for easy elementwise operations with ax_c
    r2 = np.tensordot(np.ones(N), r, axes=0)

    ## step 3: get ungapped seq lengths
    nongap = np.array(ax != gap_idx)
    l = np.float32(np.sum(nongap, axis=1))

    ## step 4: calculate unnormalized weights
    # get array of main terms in Henikoff sum
    #wgt_un[n,i] = 1 / (r_[i] * c[i, ax[n,i] ])
    wgt_un = np.reciprocal(np.multiply(ax_c, r2))

    # set all terms involving  gap to zero
    wgt_un = np.nan_to_num(np.multiply(wgt_un, nongap))

    # sum across all positions to get the preliminary unnormalized weight for each sequence
    wgt_un = np.sum(wgt_un, axis=1)

    # divide by gapless sequence length
    wgt_un = np.divide(wgt_un, l)

    ## step 5: normalize sequence weights
    wgt = (wgt_un * np.float32(N)) / np.sum(wgt_un)
    msa.wgt = wgt

    return
Example #13
    def initialize_params(dataset, weights, **kwargs):
        # Initialize based on the mean and covariance of the data
        loc, var, num_datapoints = 0, 0, 0
        for data_dict, these_weights in zip(dataset, weights):
            data = data_dict["data"]
            # loc += np.einsum('n,ni->i', these_weights, data)
            # var += np.einsum('n,ni->i', these_weights, data**2)
            loc += np.tensordot(these_weights, data, axes=(0, 0))
            var += np.tensordot(these_weights, data**2, axes=(0, 0))
            num_datapoints += these_weights.sum()

        loc = loc / num_datapoints
        var = (var / num_datapoints - loc**2)
        df = 3.0
        return (df, ), (loc, var)
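A quick check that the tensordot calls above match the einsum expressions left in the comments, with hypothetical weights of shape (n,) and data of shape (n, d):

import numpy as np

rng = np.random.default_rng(0)
these_weights = rng.random(5)
data = rng.normal(size=(5, 2))

a = np.tensordot(these_weights, data, axes=(0, 0))
b = np.einsum('n,ni->i', these_weights, data)
print(np.allclose(a, b))  # True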
def var_gate_exact(top_state, site, bottom_state):
    '''
    Goal:
        to find argmax_{gate} <top_state | gate | bottom_state>
        where gate is acting on (site, site+1)
    Input:
        top_state: (not yet conjugated)
        site: the gate acts on (site, site+1)
        bottom_state
    Return:
        new_gate
    '''
    total_dim = top_state.size
    L = int(np.log2(total_dim))
    top_theta = np.reshape(top_state, [(2**site), 4, 2**(L - (site + 2))])
    bottom_theta = np.reshape(bottom_state,
                              [(2**site), 4, 2**(L - (site + 2))])

    M = np.tensordot(top_theta.conj(), bottom_theta, axes=([0, 2], [
        0, 2
    ]))  # [ ..., upper_p, ...], [..., lower_p, ...] --> upper_p, lower_p
    ## If the convention is lower_p, upper_p
    ## uncomment the following line.
    # M = M.T  # the convention is lower_p, upper_p

    ### For detailed explanation of the formula, see function var_gate
    U, _, Vd = misc.svd(M, full_matrices=False)
    new_gate = np.dot(U, Vd).conj()
    # [TODO:remove] new_gate = new_gate.reshape([2, 2, 2, 2])

    return new_gate
def theory_cnn(x_train, y_train, beta, kernel_fns, hidden_widths):

    N_tr = x_train.shape[0]
    n0 = x_train.shape[1] * x_train.shape[2]
    nd = y_train.shape[1]

    # contract over the channel axis, then move the second sample axis next to the first
    Gxx = jnp.moveaxis(jnp.tensordot(x_train, x_train, (3, 3)), 3, 1)
    Gyy = y_train @ y_train.T / nd

    K_nngp = []
    for i in range(len(kernel_fns)):
        print(convert_nt(kernel_fns[i](x_train, ).nngp).shape)
        K_nngp += [convert_nt(kernel_fns[i](x_train, ).nngp, i)]

    D = n0  # spatial size (H * W) of each input
    KPsi = jnp.trace(Gxx.reshape(N_tr, N_tr, D, D), axis1=2, axis2=3) / n0
    #     KPsi_2 = x_train.reshape(N_tr,-1)@x_train.reshape(N_tr,-1).T/D
    #     print((KPsi-KPsi_2).std())

    I = jnp.eye(N_tr)
    gamma = KPsi + I / beta
    gamma_inv = jnp.linalg.inv(gamma)
    Phi = gamma_inv @ (Gyy - KPsi - I / beta) @ gamma_inv

    prefactor = jnp.cumsum(nd / jnp.array(hidden_widths))

    K_theory = []
    for i in range(len(prefactor)):
        K_theory += [
            K_nngp[i] + prefactor[i] * correction_layer(K_nngp[i], Phi)
        ]

    return K_nngp, K_theory, Gxx, Gyy
Example #16
def counterfact_loss(E, W):
    y, cost = np.zeros((n, 1)), 0
    for h in range(H):
        v = - K @ y + np.tensordot(E, W[h : h+M], axes = ([0, 2], [0, 1]))
        cost += (y.T @ Q @ y + v.T @ R @ v)[0][0]
        y = A @ y + B @ v + W[h+M]
    return cost
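A small self-contained check of the contraction above with hypothetical shapes: if E has shape (M, n, n) and the window W[h:h+M] has shape (M, n, 1), tensordot over axes ([0, 2], [0, 1]) is the (n, 1) sum of E[h] @ W[h] over the window.

import numpy as np

M_len, n = 4, 3
rng = np.random.default_rng(0)
E = rng.normal(size=(M_len, n, n))
W = rng.normal(size=(M_len, n, 1))

via_tensordot = np.tensordot(E, W, axes=([0, 2], [0, 1]))
via_loop = sum(E[h] @ W[h] for h in range(M_len))
print(np.allclose(via_tensordot, via_loop))  # True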
Example #17
    def contract(self, contra_index: int, other: "Tensor[P]", cov_index: int):
        """Contract a contravariant index of self with a covariant index of other.

        The contracted indices are removed. The order of the remaining indices is as in
        tensor_prod.
        """
        # it would be easier to implement this as tensor_prod then trace
        if contra_index < 0 or contra_index >= self.n_contra:
            raise ValueError(f"contra_index out of bounds: {contra_index}")
        if cov_index < 0 or cov_index >= other.n_cov:
            raise ValueError(f"cov_index out of bounds: {cov_index}")

        unordered = jnp.tensordot(
            self.t_coords,
            other.t_coords,
            ([contra_index], other.n_contra + cov_index),
        )
        # currently ordered self:contra, self:cov, other:contra, other:cov
        # want self:contra, other:contra, self:cov, other:cov
        # 1 missing each from self:contra, other:cov
        axis_order = [
            *range(self.n_contra - 1),
            *range(self.n_indices - 1, self.n_indices - 1 + other.n_contra),
            *range(self.n_contra - 1, self.n_indices - 1),
            *range(
                self.n_indices - 1 + other.n_contra,
                self.n_indices + other.n_indices - 2,
            ),
        ]
        ordered = jnp.transpose(unordered, axis_order)
        return Tensor(self.point, ordered, self.n_contra + other.n_contra - 1)
Example #18
    def apply_kernel(self,
                     scaling: jnp.ndarray,
                     eps: float = None,
                     axis: int = None):
        """Applies grid kernel on scaling vector.

    See notes in parent class for use.

    Reshapes scaling vector as a grid, applies kernels onto each slice, and
    then ravels backs the output as a vector.

    More implementation details in https://arxiv.org/pdf/1708.01955.pdf

    Args:
      scaling: jnp.ndarray, a vector of scaling (>0) values.
      eps: float, regularization strength
      axis: axis (0 or 1) along which summation should be carried out.

    Returns:
      a vector, the result of kernel applied onto scaling.
    """
        scaling = jnp.reshape(scaling, self.grid_size)
        indices = list(range(1, self.grid_dimension))
        for dimension, kernel in enumerate(self.kernel_matrices):
            ind = indices.copy()
            ind.insert(dimension, 0)
            scaling = jnp.tensordot(kernel, scaling,
                                    axes=([0], [dimension])).transpose(ind)
        return scaling.ravel()
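A minimal sketch of the separable grid contraction above on a hypothetical 2-D grid: applying a symmetric per-dimension kernel to each axis is equivalent to multiplying the raveled vector by the Kronecker product of the kernels.

import numpy as np

n1, n2 = 3, 4
rng = np.random.default_rng(0)
k1 = rng.random((n1, n1)); k1 = k1 + k1.T  # symmetric, like a Gibbs kernel
k2 = rng.random((n2, n2)); k2 = k2 + k2.T
v = rng.random(n1 * n2)

g = v.reshape(n1, n2)
g = np.tensordot(k1, g, axes=([0], [0]))                  # contract along axis 0
g = np.tensordot(k2, g, axes=([0], [1])).transpose(1, 0)  # contract along axis 1
print(np.allclose(g.ravel(), np.kron(k1, k2) @ v))        # True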
Example #19
 def features(x):
     """Compute the kitchen sink feature."""
     # We need to contract last axis of x with first of W - do this with
     # tensordot. The result has shape:
     #   (?, ?, num_random_features)
     return jnp.sqrt(2 / num_random_features) * jnp.cos(
         jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)
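A hedged sketch of the surrounding setup: the snippet assumes random projection weights w of shape (d, num_random_features) and phases b, captured from an enclosing scope. The shapes and sampling below are illustrative assumptions only.

import jax
import jax.numpy as jnp

d, num_random_features, gamma = 3, 16, 1.0
key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(key_w, (d, num_random_features))
b = jax.random.uniform(key_b, (num_random_features,), maxval=2 * jnp.pi)

x = jax.random.normal(key_x, (5, 7, d))
phi = jnp.sqrt(2 / num_random_features) * jnp.cos(
    jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)
print(phi.shape)  # (5, 7, 16)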
Example #20
 def nngp_fn_diag(nngp):
     xs, ws = quad_points
     x = xs.reshape((xs.shape[0], ) + (1, ) * nngp.ndim)
     x_axes = (0, )
     nngp = np.expand_dims(nngp, x_axes)
     fval = fn(_sqrt(2 * nngp) * x)**2
     return np.tensordot(ws, fval, (x_axes, x_axes)) / np.sqrt(np.pi)
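A quick sanity check of the quadrature rule above on a scalar nngp entry, assuming quad_points comes from a Gauss-Hermite rule (an assumption for this sketch); with fn as the identity the weighted sum reduces to the variance itself.

import numpy as np

xs, ws = np.polynomial.hermite.hermgauss(32)
nngp = 0.7
fn = lambda z: z
fval = fn(np.sqrt(2 * nngp) * xs)**2
print(np.tensordot(ws, fval, (0, 0)) / np.sqrt(np.pi))  # ~0.7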
  def preconditioned_grad(self, grad, preconditioners):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.
      preconditioners: A list of preconditioners to apply.

    Returns:
      A preconditioned gradient.
    """

    reshaped_grad = jnp.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, g in enumerate(partitioned_grads):
      preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
                                                 num_splits]
      rank = len(g.shape)
      precond_g = g
      for j in range(rank):
        precond_g = jnp.tensordot(
            precond_g, preconditioners_for_grad[j], axes=[[0], [0]])
      preconditioned_partitioned_grads.append(precond_g)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return jnp.reshape(merged_grad, self._original_shape)
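A minimal sketch of the per-axis contraction loop above: contracting axis 0 with each preconditioner in turn cycles the axes, so after rank steps every axis has been preconditioned once and the original axis order is restored. Identity matrices serve as stand-in preconditioners.

import jax.numpy as jnp

g = jnp.arange(24.0).reshape(2, 3, 4)
preconditioners_for_grad = [jnp.eye(2), jnp.eye(3), jnp.eye(4)]

precond_g = g
for p in preconditioners_for_grad:
    precond_g = jnp.tensordot(precond_g, p, axes=[[0], [0]])

print(jnp.allclose(precond_g, g))  # True with identity preconditioners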
Example #22
def counterfact_loss(E, off, W, H, M, x, env_sim, cost_func, U_old, k, K,
                     X_old, D, F, alpha, C):
    y, cost = x, 0
    for h in range(H):
        u_delta = jnp.tensordot(E,
                                jax.lax.dynamic_slice(W, (h, 0),
                                                      (M, W.shape[1])),
                                axes=([0, 2], [0, 1])) + off
        u = (U_old[h] + alpha * k[h] +
             K[h] @ (y.flatten() - X_old[h].flatten()) + C * u_delta)
        cost = cost_func(y, u, env_sim)

        new_state, _ = env_sim(y, u)
        y = y.unflatten(new_state.flatten() + W[h + M])
        ## Removing the bottom functionality for performance
        # if w_is == "de":
        #     y = y.unflatten(new_state.flatten() + W[h + M])
        # elif w_is == "dede":
        #     y = y.unflatten(new_state.flatten() + D[h + M] + W[h + M])
        # else:
        #     y = y.unflatten(
        #         X_old[h + M + 1].flatten()
        #         + F[h + M][0] @ (y.flatten() - X_old[h + M].flatten())
        #         + F[h + M][1] @ (u - U_old[h + M])
        #         + W[h + M]
        #     )
    return cost
 def update_fun(step, grads, state):
     """Apply a step of the optimizer."""
     del step  # Unused.
     params, grad_seq = state
     grad_seq = append_to_sequence(grad_seq, grads)
     params -= jnp.tensordot(meta_params, grad_seq, axes=1)
     return (params, grad_seq)
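The tensordot with axes=1 above is a weighted sum of the stored gradients along the history axis; a quick check with hypothetical shapes:

import numpy as np

meta_params = np.array([0.5, 0.3, 0.2])              # one weight per history slot
grad_seq = np.arange(12, dtype=float).reshape(3, 4)  # (history, param_dim)

a = np.tensordot(meta_params, grad_seq, axes=1)
b = sum(m * g for m, g in zip(meta_params, grad_seq))
print(np.allclose(a, b))  # True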
Example #24
        def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
            t = np.array(t) * learning_rate
            t_shape, t_ndim = t.shape, t.ndim
            first_t_axes = tuple(range(t_ndim))
            t = t.reshape((-1, 1))

            rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
            rhs = np.moveaxis(rhs, trace_axes,
                              last_t_axes).reshape((-1, ) + rhs_shape)
            shape = t_shape + k_train_train.shape[1::2] + rhs_shape

            if fx_train_0 is not None:
                dfx_train = expm1_fn(rhs, t).reshape(shape)
                dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
                fx_train_t = np.expand_dims(fx_train_0,
                                            first_t_axes) + dfx_train

            if fx_test_0 is not None:
                dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
                dfx_test = np.tensordot(k_test_train, dfx_test,
                                        (odd, non_t_axes))
                dfx_test = np.moveaxis(
                    dfx_test,
                    tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) +
                    last_t_axes,
                    tuple(range(t_ndim)) + trace_axes)
                fx_test_t = np.expand_dims(fx_test_0, first_t_axes) + dfx_test

            if fx_train_0 is not None and fx_test_0 is not None:
                return fx_train_t, fx_test_t
            if fx_test_0 is None:
                return fx_train_t
            return fx_test_t
Example #25
        def conjugate_m_step(expectations, nonconjugate_params):
            # Compute expected sufficient statistics
            suff_stats = None
            num_datapoints = 0
            for expects, data_dict, these_weights in zip(
                    expectations, dataset, weights):
                these_stats = cls.expected_sufficient_statistics(
                    nonconjugate_params,
                    expectations=expects,
                    **data_dict,
                    **kwargs)

                # weight the statistics if weights are given
                these_stats = tuple(
                    np.tensordot(these_weights, s, axes=(0, 0))
                    for s in these_stats)

                # add to our accumulated statistics
                suff_stats = sum_tuples(suff_stats, these_stats)

                # update the number of datapoints
                num_datapoints += these_weights.sum()

            # Find the optimal parameters for the conjugate part of the compound distribution
            posterior_stats = suff_stats
            posterior_counts = num_datapoints
            if prior is not None:
                posterior_stats = sum_tuples(prior.pseudo_obs, posterior_stats)
                posterior_counts += prior.pseudo_counts

            # Compute the posterior distribution
            posterior_class = get_compound(cls)
            posterior = posterior_class.from_stats(posterior_stats,
                                                   posterior_counts, **kwargs)
            return posterior.mode()
Example #26
    def fit(cls, dataset, weights=None, prior=None, **kwargs):
        """Compute the maximum a posteriori (MAP) estimate of the distribution
        parameters.  For uninformative priors, this reduces to the maximum
        likelihood estimate.
        """
        # Compute the sufficient statistics and the number of datapoints
        suff_stats = None
        num_datapoints = 0
        for data_dict, these_weights in zip(dataset, weights):
            these_stats = cls.sufficient_statistics(**data_dict, **kwargs)

            # weight the statistics if weights are given
            if these_weights is not None:
                these_stats = tuple(
                    np.tensordot(these_weights, s, axes=(0, 0))
                    for s in these_stats)
            else:
                these_stats = tuple(s.sum(axis=0) for s in these_stats)

            # add to our accumulated statistics
            suff_stats = sum_tuples(suff_stats, these_stats)

            # update the number of datapoints
            num_datapoints += these_weights.sum()

        return cls.fit_with_stats(suff_stats,
                                  num_datapoints,
                                  prior=prior,
                                  **kwargs)
Example #27
def additive_kernel(
    x1,
    x2,
    lengthscales,
    additive_alphas,
    kernel_alphas,
    base_kernel_fun,
    diag_only,
    jitter=DEFAULT_JITTER,
):

    N = additive_alphas.shape[0]

    # TODO: Could make more general to support other kernels
    to_vmap = lambda x1, x2, lengthscale, alpha: base_kernel_fun(
        x1.reshape(-1, 1),
        x2.reshape(-1, 1),
        lengthscale.reshape(-1,),
        alpha,
        diag_only,
        jitter,
    )

    map_res = vmap(to_vmap)(x1.T, x2.T, lengthscales, kernel_alphas)

    girard_res = newton_girard_combination(map_res, N)

    kernel_res = jnp.tensordot(additive_alphas, girard_res, axes=(0, 0))

    return kernel_res
        def update_fun(step, grads, state):
            """Apply a step of the optimizer."""
            del step  # Unused.
            params, grad_seq, param_seq = state
            grad_seq = append_to_sequence(grad_seq, grads)
            param_seq = append_to_sequence(param_seq, params)

            # Differences in parameters.
            # TODO(nirum): This recomputes differences at every iteration. Should
            # time this to ensure that the repeated jnp.diff call is not too slow.
            delta_params = jnp.diff(param_seq, axis=0)

            grad_term = jnp.tensordot(theta_grad, grad_seq, axes=1)
            dx_term = jnp.tensordot(theta_dx, delta_params, axes=1)
            params -= (grad_term + dx_term)

            return (params, grad_seq, param_seq)
Example #29
 def tensor(self, other):
     if not isinstance(other, Tensor):
         raise TypeError(messages.type_err(Tensor, other))
     dom, cod = self.dom + other.dom, self.cod + other.cod
     array = np.tensordot(self.array, other.array, 0)\
         if self.array.shape and other.array.shape\
         else self.array * other.array
     return Tensor(dom, cod, array)
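The np.tensordot(..., 0) call above is an outer product (no axes are contracted); a tiny check with arbitrary arrays:

import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(4)
t = np.tensordot(a, b, 0)
print(t.shape, np.allclose(t, a[..., None] * b))  # (2, 3, 4) True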
Example #30
    def get_action(self):
        if self.T < self.T_0:
            return self.sys_id.get_action(self.x)

        M_tilde = self.M + self.delta * self.eps[-1]
        #choose action
        self.u = -self.K @ self.x + np.tensordot(
            M_tilde, self.w_past, axes=([0, 2], [0, 1]))
        return self.u