def diffeo_pushforward(self, diffeo: Diffeomorphism[P, P_],
                       chart: Chart[P_]) -> "Tensor[P_]":
    """Compute the pushforward of this tensor along a diffeomorphism.

    This isn't a true pushforward; it also requires pulling back
    contravariant indices along the inverse. It's more like a change of
    coordinates.
    """

    def coord_map_forward(c: jnp.DeviceArray) -> jnp.DeviceArray:
        return chart.point_to_coords(
            diffeo.forward(self.point.chart.coords_to_point(c)))

    def coord_map_backward(c: jnp.DeviceArray) -> jnp.DeviceArray:
        return self.point.chart.point_to_coords(
            diffeo.backward(chart.coords_to_point(c)))

    image = coord_map_forward(self.point.coords)
    jacobian_backward = jax.jacfwd(coord_map_backward)(image)
    # At every step we contract the first index of the tensor; the
    # transformed index is appended as the last index, so after every
    # index has been transformed they end up back in the original order.
    transformed_t = self.t_coords
    for _ in range(self.n_contra):
        # transform a contravariant index by pulling back;
        # in this case right multiplication is what we wanted anyway
        transformed_t = jnp.tensordot(transformed_t, jacobian_backward,
                                      axes=([0], [0]))
    jacobian_forward = jax.jacfwd(coord_map_forward)(self.point.coords)
    for _ in range(self.n_cov):
        # we actually want left multiplication, so contract axis 1 of the Jacobian
        transformed_t = jnp.tensordot(transformed_t, jacobian_forward,
                                      axes=([0], [1]))
    return Tensor(ChartPoint(image, chart), transformed_t, self.n_contra)
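# Standalone sketch (not the library's Chart/Diffeomorphism API) of the
# axis-ordering trick used above: contracting axis 0 of the tensor with a
# Jacobian appends the transformed index as the last axis, so transforming
# every index once restores the original order. All names are illustrative.
import jax
import jax.numpy as jnp

_toy_map = lambda c: jnp.array([c[0] * c[1], c[0] + c[1]])
_jac = jax.jacfwd(_toy_map)(jnp.array([1.0, 2.0]))   # shape (2, 2)
_t = jnp.arange(8.0).reshape(2, 2, 2)                # rank-3 tensor
for _ in range(3):
    _t = jnp.tensordot(_t, _jac, axes=([0], [1]))    # index cycles to the back
assert _t.shape == (2, 2, 2)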
def correction_layer(Kl, Phi):
    if len(Kl.shape) == 2:  ## For FNNs
        return Kl @ Phi @ Kl
    elif len(Kl.shape) == 4:  ## For 1D CNNs
        N_tr = Kl.shape[0]
        D = Kl.shape[-1]
        correction = 0
        for i in range(N_tr):
            for j in range(N_tr):
                correction += Phi[i, j] * jnp.tensordot(
                    Kl[:, i], Kl[j, :], axes=(-1, 1)) / D
        return jnp.moveaxis(correction, 2, 1)
    elif len(Kl.shape) == 6:  ## For 2D CNNs
        N_tr = Kl.shape[0]
        w = Kl.shape[-1]
        D = w**2
        Kl = Kl.reshape(N_tr, N_tr, D, D)
        correction = 0
        for i in range(N_tr):
            for j in range(N_tr):
                correction += Phi[i, j] * jnp.tensordot(
                    Kl[:, i], Kl[j, :], axes=(-1, 1)) / D
        return jnp.moveaxis(correction, 2, 1).reshape(N_tr, N_tr, w, w, w, w)
    else:
        raise NotImplementedError(
            f'correction_layer expects a kernel of rank 2, 4, or 6; got shape {Kl.shape}')
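# Toy call of the rank-2 (FNN) branch above; the shapes are illustrative
# stand-ins, not data from the source.
import jax.numpy as jnp

_N = 3
_Kl = jnp.eye(_N) + 0.1 * jnp.ones((_N, _N))   # layer kernel, (N, N)
_Phi = 0.01 * jnp.ones((_N, _N))               # label-dependent factor, (N, N)
assert correction_layer(_Kl, _Phi).shape == (_N, _N)   # = Kl @ Phi @ Kl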
def Potts_ShiftGaugeZeroSum(potts):
    # make L x q array of the mean value of h_i for all L sites
    h_mean = np.tensordot(np.mean(potts.h, axis=1), np.ones(potts.q), axes=0)
    # update h to zero-sum gauge
    potts.h = index_update(potts.h, index[:, :], potts.h - h_mean)

    # make L x L x q x q array of the mean value of e_ij for all L choose 2 site pairs
    # should be symmetric and 0 on the diagonal
    e_mean = np.tensordot(np.mean(potts.e, axis=(1, 3)),
                          np.ones((potts.q, potts.q)), axes=0)
    # transpose e so that it is L x L x q x q
    e_transpose = np.transpose(potts.e, axes=[0, 2, 1, 3])
    # shift to zero-sum gauge
    e_transpose = index_update(e_transpose, index[:, :, :, :],
                               e_transpose - e_mean)
    # undo the transposition to get an L x q x L x q array
    potts.e = index_update(potts.e, index[:, :, :, :],
                           np.transpose(e_transpose, axes=(0, 2, 1, 3)))
    return
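# The axes=0 tensordot above is an outer product used purely for tiling:
# it turns a length-L vector of column means into an L x q array with
# constant rows. A plain-NumPy check (in the source, `np` may be jax.numpy):
import numpy as onp

_means = onp.array([0.1, 0.2])                       # length L = 2
_tiled = onp.tensordot(_means, onp.ones(3), axes=0)  # shape (L, q) = (2, 3)
assert onp.allclose(_tiled, onp.broadcast_to(_means[:, None], (2, 3)))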
def system_id(self):
    """Returns the current estimate of the hidden system dynamics."""
    assert self.T > 0
    k = self.k if self.k else int(0.15 * self.T)

    # transform eta and x
    eta_np = np.array(self.eta)
    x_np = np.array(self.x_history)

    # prepare vectors and retrieve B
    scan_len = self.T - k - 1  # need extra -1 because we iterate over j=0,..,k
    N_j = np.array([
        np.dot(x_np[j + 1:j + 1 + scan_len].T, eta_np[:scan_len])
        for j in range(k + 1)
    ]) / scan_len
    B = N_j[0]

    # retrieve A
    C_0, C_1 = N_j[:-1], N_j[1:]
    C_inv = np.linalg.inv(
        np.tensordot(C_0, C_0, axes=([0, 2], [0, 2])) +
        self.gamma * np.identity(self.n))
    A = np.tensordot(C_1, C_0, axes=([0, 2], [0, 2])) @ C_inv + B @ self.K
    return (A, B)
def _parse_nn_pepo_obc(C, D, vL, vR, vB, vT, lx, ly):
    assert (lx > 2) and (ly > 2)  # otherwise there is no bulk to put the Ds in
    x_d = lx // 2
    y_d = ly // 2

    # HORIZONTAL
    vL_C = np.tensordot(vL, C, [0, 2])  # (p,p*,r)
    C_vR = np.tensordot(vR, C, [0, 3])  # (p,p*,l)
    vB_D = np.tensordot(vB, D, [0, 4])  # (p,p*,l,r,u)
    D_vT = np.tensordot(vT, D, [0, 5])  # (p,p*,l,r,d)

    left_col = [vL_C[:, :, :, None]] \
        + [vL_C[:, :, None, :, None]] * (ly - 2) \
        + [vL_C[:, :, None, :]]
    # bottom C: (p,p*,i,j) = (p,p*,l,r) -> (p,p*,r,l) -> (p,p*,r,u,l)
    # bulk C:   (p,p*,i,j) = (p,p*,l,r) -> (p,p*,u,l,d,r)
    # top C:    (p,p*,i,j) = (p,p*,l,r) -> (p,p*,l,d,r)
    mid_col = [np.transpose(C, [0, 1, 3, 2])[:, :, :, None, :]] \
        + [C[:, :, None, :, None, :]] * (ly - 2) \
        + [C[:, :, :, None, :]]
    # vB_D: (p,p*,ijl) = (p,p*,lru) -> (p,p*,rul)
    # D:    (p,p*,ijkl) -> (p,p*,likj) = (p,p*,uldr)
    # D_vT: (p,p*,ijk) = (p,p*,lrd) -> (p,p*,ldr)
    d_col = [np.transpose(vB_D, [0, 1, 3, 4, 2])] \
        + [np.transpose(D, [0, 1, 5, 2, 4, 3])] * (ly - 2) \
        + [np.transpose(D_vT, [0, 1, 2, 4, 3])]
    right_col = [C_vR[:, :, None, :]] \
        + [C_vR[:, :, None, :, None]] * (ly - 2) \
        + [C_vR[:, :, :, None]]
    tensors = [left_col] \
        + [mid_col] * (x_d - 1) \
        + [d_col] \
        + [mid_col] * (lx - x_d - 2) \
        + [right_col]
    # even if the NnPepo is hermitian, the two separate Pepos might not be
    pepo_hor = Pepo(tensors, OBC, False)

    # VERTICAL
    # rotate tensors clockwise: (p,p*,u,l,d,r) -> (p,p*,l,d,r,u)
    _rotate90 = partial(np.transpose, axes=[0, 1, 3, 4, 5, 2])
    # the tensor at new location (x,y) was at (-y-1,x) before
    tensors = [[tensors[-y - 1][0] for y in range(ly)]] \
        + [[tensors[-1][x]]
           + [_rotate90(tensors[-y - 1][x]) for y in range(1, ly - 1)]
           + [tensors[0][x]] for x in range(1, lx - 1)] \
        + [[tensors[-y - 1][-1] for y in range(ly)]]
    pepo_vert = Pepo(tensors, OBC, False)
    return pepo_hor, pepo_vert
def lstm_cell(hc, x):
    h, c = hc
    p = params
    # the four gates (f, i, o, g) are stacked along the last axis of w/u/b
    tmp = jnp.tensordot(x, p["w"], [-1, 0]) + jnp.tensordot(
        h, p["u"], [-1, 0]) + p["b"]
    ft, it, ot, gt = tmp.T
    # the +1 biases the forget gate towards remembering at initialization
    ct = jax.nn.sigmoid(ft + 1) * c + jax.nn.sigmoid(it) * jnp.tanh(gt)
    ht = jax.nn.sigmoid(ot) * jnp.tanh(ct)
    return (ht, ct), ct
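# Hypothetical shapes and a scan over the cell above, assuming lstm_cell
# and params live at module scope; the gate unpacking via tmp.T implies
# w: (in_dim, hidden, 4), u: (hidden, hidden, 4), b: (hidden, 4).
import jax
import jax.numpy as jnp

in_dim, hidden, T = 3, 5, 10
_kw, _ku, _kx = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w": 0.1 * jax.random.normal(_kw, (in_dim, hidden, 4)),
    "u": 0.1 * jax.random.normal(_ku, (hidden, hidden, 4)),
    "b": jnp.zeros((hidden, 4)),
}
xs = jax.random.normal(_kx, (T, in_dim))
h0 = (jnp.zeros(hidden), jnp.zeros(hidden))
(_, _), cs = jax.lax.scan(lstm_cell, h0, xs)  # cs: (T, hidden) cell states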
def gru_cell(h, x):
    p = params["zr"]
    tmp = jnp.tensordot(x, p["w"], [-1, 0]) + jnp.tensordot(
        h, p["u"], [-1, 0]) + p["b"]
    zt, rt = jax.nn.sigmoid(tmp).T
    ht = jnp.tanh(x @ params["h"]["w"] + (h * rt) @ params["h"]["u"] +
                  params["h"]["b"])
    h = (1 - zt) * h + zt * ht
    return h, h
def policy_loss(M, bias, w, cost_t=cost_fn):
    y = np.zeros((n, 1))
    for h in range(HH - 1):
        v = -self.K @ y + np.tensordot(
            M, w[h:h + H], axes=([0, 2], [0, 1])) + bias
        y = A @ y + B @ v + w[h + H]
    # Don't update state at the end
    v = -self.K @ y + np.tensordot(
        M, w[h:h + H], axes=([0, 2], [0, 1])) + bias
    return cost_t(y, v)
def counterfact_loss(M, w):
    y = np.zeros(self.n)
    for h in range(HH - H - 1):
        v = -self.K @ y + np.tensordot(
            M, w[h:(h + self.H)], axes=([0, 2], [0, 1]))
        y = A @ y + B @ v + w[(h + H)]
    v = -self.K @ y + np.tensordot(
        M, w[h:(h + self.H)], axes=([0, 2], [0, 1]))
    cost = loss_fn(y, v)
    return cost
def contract_with(self, other) -> complex:
    assert self.L == other.L
    # (r1,r2,u) & (r1,r2,u) -> (u,u)
    tens = np.tensordot(self.tensors[0], other.tensors[0], [[0, 1], [0, 1]])
    u, u_ = tens.shape
    col = [np.reshape(tens, [u * u_])]
    col += tree_multimap(_contract_with__bulk_contraction2,
                         self.tensors[1:-1], other.tensors[1:-1])
    # (d,r1,r2) & (d,r1,r2) -> (d,d)
    tens = np.tensordot(self.tensors[-1], other.tensors[-1], [[1, 2], [1, 2]])
    d, d_ = tens.shape
    col.append(np.reshape(tens, [d * d_]))
    res = np.linalg.multi_dot(col)
    return res * self.norm * other.norm
def interpolator(self, PhiX, PhiE):
    iPhiE = np.linalg.inv(
        np.tensordot(PhiE, PhiE, axes=([1, 2], [1, 2])) +
        self.reg_inv * onp.eye(self.num_anchor_points))
    Lambda = np.einsum('ijk,ljk,lm', PhiX, PhiE, iPhiE)
    if self.simplex:
        # not really a projection onto the simplex: rows are rescaled to
        # (approximately) sum to one, with 1e-3 guarding against division by zero
        Lambda = Lambda / (np.sum(Lambda, axis=1)[:, np.newaxis] + 1e-3)
    B = np.tensordot(Lambda, PhiE, axes=(1, 0))
    return B, Lambda
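# The first tensordot above is a Gram matrix over the flattened trailing
# axes: G[i, l] = sum_{j,k} PhiE[i,j,k] * PhiE[l,j,k]. A quick check with
# illustrative shapes:
import numpy as onp

_PhiE = onp.random.default_rng(0).normal(size=(5, 3, 4))
_G = onp.tensordot(_PhiE, _PhiE, axes=([1, 2], [1, 2]))
assert onp.allclose(_G, _PhiE.reshape(5, -1) @ _PhiE.reshape(5, -1).T)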
def MSAWeight_PB(msa):
    gap_idx = msa.abc.charmap['-']
    q = msa.abc.q
    ax = msa.ax
    (N, L) = ax.shape

    ## step 1: get counts
    c = np.sum(msa.ax_1hot, axis=0)
    # set gap counts to 0
    c = index_update(c, index[:, gap_idx], 0)
    # get N x L array with the count value for the corresponding residue in the alignment
    # first, get N x L "column id" array (convenient for vmap)
    # col_id[n,i] = i
    col_id = np.int16(np.tensordot(np.ones(N), np.arange(L), axes=0))
    # ax_c[n,i] = c[i, ax[n,i]]
    ax_c = Get_Henikoff_Counts_Residue(col_id, ax, c)

    ## step 2: get the number of unique characters in each column
    r = np.float32(np.sum(np.array(c > 0), axis=1))
    # transform r from an L x 1 array to an N x L array, where r2[n,i] = r[i];
    # this allows easy elementwise operations with ax_c
    r2 = np.tensordot(np.ones(N), r, axes=0)

    ## step 3: get ungapped sequence lengths
    nongap = np.array(ax != gap_idx)
    l = np.float32(np.sum(nongap, axis=1))

    ## step 4: calculate unnormalized weights
    # get array of the main terms in the Henikoff sum
    # wgt_un[n,i] = 1 / (r[i] * c[i, ax[n,i]])
    wgt_un = np.reciprocal(np.multiply(ax_c, r2))
    # set all terms involving a gap to zero
    wgt_un = np.nan_to_num(np.multiply(wgt_un, nongap))
    # sum across all positions to get the preliminary unnormalized weight for each sequence
    wgt_un = np.sum(wgt_un, axis=1)
    # divide by the gapless sequence length
    wgt_un = np.divide(wgt_un, l)

    ## step 5: normalize the sequence weights
    wgt = (wgt_un * np.float32(N)) / np.sum(wgt_un)
    msa.wgt = wgt
    return
def initialize_params(dataset, weights, **kwargs):
    # Initialize based on the mean and covariance of the data
    loc, var, num_datapoints = 0, 0, 0
    for data_dict, these_weights in zip(dataset, weights):
        data = data_dict["data"]
        # equivalent to np.einsum('n,ni->i', these_weights, data)
        # and np.einsum('n,ni->i', these_weights, data**2):
        loc += np.tensordot(these_weights, data, axes=(0, 0))
        var += np.tensordot(these_weights, data**2, axes=(0, 0))
        num_datapoints += these_weights.sum()

    loc = loc / num_datapoints
    var = (var / num_datapoints - loc**2)
    df = 3.0
    return (df, ), (loc, var)
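# The tensordot calls above are weighted sums over the batch axis,
# matching the einsum forms noted in the comment:
import numpy as onp

_w = onp.array([0.2, 0.3, 0.5])
_X = onp.arange(6.0).reshape(3, 2)
assert onp.allclose(onp.tensordot(_w, _X, axes=(0, 0)),
                    onp.einsum('n,ni->i', _w, _X))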
def var_gate_exact(top_state, site, bottom_state):
    '''
    Goal: find argmax_{gate} <top_state | gate | bottom_state>,
    where the gate acts on (site, site+1).

    Input:
        top_state: the top state (not yet conjugated!)
        site: the gate is applied on (site, site+1)
        bottom_state: the bottom state
    Return:
        new_gate
    '''
    total_dim = top_state.size
    L = int(np.log2(total_dim))
    top_theta = np.reshape(top_state, [(2**site), 4, 2**(L - (site + 2))])
    bottom_theta = np.reshape(bottom_state,
                              [(2**site), 4, 2**(L - (site + 2))])
    # [..., upper_p, ...], [..., lower_p, ...] --> (upper_p, lower_p)
    M = np.tensordot(top_theta.conj(), bottom_theta, axes=([0, 2], [0, 2]))
    ## If the convention is (lower_p, upper_p), uncomment the following line:
    # M = M.T

    ### For a detailed explanation of the formula, see the function var_gate
    U, _, Vd = misc.svd(M, full_matrices=False)
    new_gate = np.dot(U, Vd).conj()  # [TODO: remove]

    new_gate = new_gate.reshape([2, 2, 2, 2])
    return new_gate
def theory_cnn(x_train, y_train, beta, kernel_fns, hidden_widths):
    N_tr = x_train.shape[0]
    n0 = x_train.shape[1] * x_train.shape[2]
    D = n0  # flattened spatial dimension
    nd = y_train.shape[1]
    # tensordot over the channel axis (axis 3)
    Gxx = jnp.moveaxis(jnp.tensordot(x_train, x_train, (3, 3)), 3, 1)
    Gyy = y_train @ y_train.T / nd

    K_nngp = []
    for i in range(len(kernel_fns)):
        print(convert_nt(kernel_fns[i](x_train).nngp, i).shape)
        K_nngp += [convert_nt(kernel_fns[i](x_train).nngp, i)]

    KPsi = jnp.trace(Gxx.reshape(N_tr, N_tr, D, D), axis1=2, axis2=3) / n0
    # equivalently: x_train.reshape(N_tr, -1) @ x_train.reshape(N_tr, -1).T / D
    I = jnp.eye(N_tr)
    gamma = KPsi + I / beta
    gamma_inv = jnp.linalg.inv(gamma)
    Phi = gamma_inv @ (Gyy - KPsi - I / beta) @ gamma_inv

    prefactor = jnp.cumsum(nd / jnp.array(hidden_widths))
    K_theory = []
    for i in range(len(prefactor)):
        K_theory += [
            K_nngp[i] + prefactor[i] * correction_layer(K_nngp[i], Phi)
        ]
    return K_nngp, K_theory, Gxx, Gyy
def counterfact_loss(E, W):
    y, cost = np.zeros((n, 1)), 0
    for h in range(H):
        v = -K @ y + np.tensordot(E, W[h:h + M], axes=([0, 2], [0, 1]))
        cost += (y.T @ Q @ y + v.T @ R @ v)[0][0]
        y = A @ y + B @ v + W[h + M]
    return cost
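# Self-contained toy instance of the rollout above; every matrix below is
# a hypothetical stand-in bound at module scope for the closure variables
# (np, n, K, A, B, Q, R, H, M), not the controller's real parameters.
import numpy as np

n, m, M, H = 2, 1, 3, 5
A, B = 0.9 * np.eye(n), np.ones((n, m))
Q, R, K = np.eye(n), np.eye(m), np.zeros((m, n))
E = 0.01 * np.ones((M, m, n))                       # disturbance-feedback params
W = np.random.default_rng(0).normal(size=(H + M, n, 1))
print(counterfact_loss(E, W))                       # scalar rollout cost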
def contract(self, contra_index: int, other: "Tensor[P]", cov_index: int):
    """Contract a contravariant index of self with a covariant index of other.

    The contracted indices are removed. The order of the remaining indices
    is as in tensor_prod.
    """
    # it would be easier to implement this as tensor_prod then trace
    if contra_index < 0 or contra_index >= self.n_contra:
        raise ValueError(f"contra_index out of bounds: {contra_index}")
    if cov_index < 0 or cov_index >= other.n_cov:
        raise ValueError(f"cov_index out of bounds: {cov_index}")
    unordered = jnp.tensordot(
        self.t_coords,
        other.t_coords,
        ([contra_index], [other.n_contra + cov_index]),
    )
    # currently ordered self:contra, self:cov, other:contra, other:cov
    # want self:contra, other:contra, self:cov, other:cov
    # (one index missing from each of self:contra and other:cov)
    axis_order = [
        *range(self.n_contra - 1),
        *range(self.n_indices - 1, self.n_indices - 1 + other.n_contra),
        *range(self.n_contra - 1, self.n_indices - 1),
        *range(
            self.n_indices - 1 + other.n_contra,
            self.n_indices + other.n_indices - 2,
        ),
    ]
    ordered = jnp.transpose(unordered, axis_order)
    return Tensor(self.point, ordered, self.n_contra + other.n_contra - 1)
def apply_kernel(self,
                 scaling: jnp.ndarray,
                 eps: float = None,
                 axis: int = None):
    """Applies the grid kernel on a scaling vector.

    See notes in parent class for use.

    Reshapes the scaling vector as a grid, applies the kernel along each
    grid dimension, and then ravels the output back into a vector. More
    implementation details in https://arxiv.org/pdf/1708.01955.pdf

    Args:
      scaling: jnp.ndarray, a vector of scaling (>0) values.
      eps: float, regularization strength.
      axis: axis (0 or 1) along which summation should be carried out.

    Returns:
      a vector, the result of the kernel applied to scaling.
    """
    scaling = jnp.reshape(scaling, self.grid_size)
    indices = list(range(1, self.grid_dimension))
    for dimension, kernel in enumerate(self.kernel_matrices):
        ind = indices.copy()
        ind.insert(dimension, 0)
        scaling = jnp.tensordot(kernel, scaling,
                                axes=([0], [dimension])).transpose(ind)
    return scaling.ravel()
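# Standalone sketch of the separable grid-kernel loop above: tensordot
# puts the new axis in front, and the transpose with `ind` moves it back
# to position `dimension`. With identity kernels the round trip is exact.
import jax.numpy as jnp

_grid_size = (4, 5)
_kernels = [jnp.eye(4), jnp.eye(5)]
_scaling = jnp.arange(20.0)
_x = _scaling.reshape(_grid_size)
for _dim, _k in enumerate(_kernels):
    _ind = list(range(1, len(_grid_size)))
    _ind.insert(_dim, 0)
    _x = jnp.tensordot(_k, _x, axes=([0], [_dim])).transpose(_ind)
assert jnp.allclose(_x.ravel(), _scaling)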
def features(x):
    """Compute the kitchen sink feature."""
    # We need to contract the last axis of x with the first axis of w - do
    # this with tensordot. The result has shape (?, ?, num_random_features).
    return jnp.sqrt(2 / num_random_features) * jnp.cos(
        jnp.sqrt(2 / gamma) * jnp.tensordot(x, w, axes=1) + b)
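# Hypothetical setup for the random-feature map above (the shapes are
# assumptions: w maps input dimension d to num_random_features, b is a
# uniform phase), assuming `features` is bound at module scope.
import jax
import jax.numpy as jnp

d, num_random_features, gamma = 3, 128, 1.0
_kw, _kb, _kx = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(_kw, (d, num_random_features))
b = jax.random.uniform(_kb, (num_random_features,), maxval=2 * jnp.pi)
_x = jax.random.normal(_kx, (10, 7, d))
assert features(_x).shape == (10, 7, num_random_features)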
def nngp_fn_diag(nngp):
    xs, ws = quad_points
    x = xs.reshape((xs.shape[0],) + (1,) * nngp.ndim)
    x_axes = (0,)
    nngp = np.expand_dims(nngp, x_axes)
    fval = fn(_sqrt(2 * nngp) * x)**2
    return np.tensordot(ws, fval, (x_axes, x_axes)) / np.sqrt(np.pi)
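# Sketch of the Gauss-Hermite pattern above with hypothetical stand-ins:
# hermgauss nodes/weights for `quad_points` and a ReLU for `fn`. The
# tensordot averages fn(sqrt(2 * nngp) * x)**2 over the nodes; for ReLU
# the result is nngp / 2 analytically.
import numpy as onp

_xs, _ws = onp.polynomial.hermite.hermgauss(32)
_nngp = onp.array([0.5, 1.0, 2.0])
_fval = onp.maximum(onp.sqrt(2 * _nngp) * _xs[:, None], 0.0)**2
_diag = onp.tensordot(_ws, _fval, (0, 0)) / onp.sqrt(onp.pi)  # ~ _nngp / 2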
def preconditioned_grad(self, grad, preconditioners):
    """Precondition the gradient.

    Args:
      grad: A gradient tensor to precondition.
      preconditioners: A list of preconditioners to apply.

    Returns:
      A preconditioned gradient.
    """
    reshaped_grad = jnp.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, g in enumerate(partitioned_grads):
        preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
                                                   num_splits]
        rank = len(g.shape)
        precond_g = g
        # contracting axis 0 each time cycles the dimensions, so after
        # `rank` steps every axis has been preconditioned exactly once
        for j in range(rank):
            precond_g = jnp.tensordot(precond_g,
                                      preconditioners_for_grad[j],
                                      axes=[[0], [0]])
        preconditioned_partitioned_grads.append(precond_g)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return jnp.reshape(merged_grad, self._original_shape)
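# Sketch of the axis-cycling contraction in the inner loop above:
# contracting axis 0 moves that dimension to the back, so after `rank`
# steps the original axis order is restored. Identity preconditioners
# make the whole loop a no-op, which gives a quick shape-and-value check.
import jax.numpy as jnp

_g = jnp.arange(24.0).reshape(2, 3, 4)
_out = _g
for _p in [jnp.eye(2), jnp.eye(3), jnp.eye(4)]:
    _out = jnp.tensordot(_out, _p, axes=[[0], [0]])
assert _out.shape == _g.shape and jnp.allclose(_out, _g)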
def counterfact_loss(E, off, W, H, M, x, env_sim, cost_func, U_old, k, K,
                     X_old, D, F, alpha, C):
    y, cost = x, 0
    for h in range(H):
        u_delta = jnp.tensordot(
            E, jax.lax.dynamic_slice(W, (h, 0), (M, W.shape[1])),
            axes=([0, 2], [0, 1])) + off
        u = (U_old[h] + alpha * k[h]
             + K[h] @ (y.flatten() - X_old[h].flatten())
             + C * u_delta)
        cost += cost_func(y, u, env_sim)  # accumulate the per-step cost
        new_state, _ = env_sim(y, u)
        y = y.unflatten(new_state.flatten() + W[h + M])
        ## Removed the branching below for performance
        # if w_is == "de":
        #     y = y.unflatten(new_state.flatten() + W[h + M])
        # elif w_is == "dede":
        #     y = y.unflatten(new_state.flatten() + D[h + M] + W[h + M])
        # else:
        #     y = y.unflatten(
        #         X_old[h + M + 1].flatten()
        #         + F[h + M][0] @ (y.flatten() - X_old[h + M].flatten())
        #         + F[h + M][1] @ (u - U_old[h + M])
        #         + W[h + M])
    return cost
def update_fun(step, grads, state):
    """Apply a step of the optimizer."""
    del step  # Unused.
    params, grad_seq = state
    grad_seq = append_to_sequence(grad_seq, grads)
    params -= jnp.tensordot(meta_params, grad_seq, axes=1)
    return (params, grad_seq)
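# Minimal sketch of the learned-optimizer update above, with a
# hypothetical `append_to_sequence` (a rolling buffer of the last
# `seq_len` gradients) and a hypothetical `meta_params` vector holding
# one learned coefficient per buffered gradient.
import jax.numpy as jnp

def append_to_sequence(seq, item):
    # drop the oldest entry and append the newest
    return jnp.concatenate([seq[1:], item[None]], axis=0)

seq_len, dim = 4, 3
meta_params = jnp.full((seq_len,), 0.1)
_params = jnp.ones(dim)
_grad_seq = jnp.zeros((seq_len, dim))
_grad_seq = append_to_sequence(_grad_seq, jnp.array([0.5, -0.2, 0.1]))
_params -= jnp.tensordot(meta_params, _grad_seq, axes=1)  # weighted grad sum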
def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
    t = np.array(t) * learning_rate
    t_shape, t_ndim = t.shape, t.ndim
    first_t_axes = tuple(range(t_ndim))
    t = t.reshape((-1, 1))

    rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
    rhs = np.moveaxis(rhs, trace_axes, last_t_axes).reshape((-1,) + rhs_shape)
    shape = t_shape + k_train_train.shape[1::2] + rhs_shape

    if fx_train_0 is not None:
        dfx_train = expm1_fn(rhs, t).reshape(shape)
        dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
        fx_train_t = np.expand_dims(fx_train_0, first_t_axes) + dfx_train

    if fx_test_0 is not None:
        dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
        dfx_test = np.tensordot(k_test_train, dfx_test, (odd, non_t_axes))
        dfx_test = np.moveaxis(
            dfx_test,
            tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) + last_t_axes,
            tuple(range(t_ndim)) + trace_axes)
        fx_test_t = np.expand_dims(fx_test_0, first_t_axes) + dfx_test

    if fx_train_0 is not None and fx_test_0 is not None:
        return fx_train_t, fx_test_t
    if fx_test_0 is None:
        return fx_train_t
    return fx_test_t
def conjugate_m_step(expectations, nonconjugate_params):
    # Compute expected sufficient statistics
    suff_stats = None
    num_datapoints = 0
    for expects, data_dict, these_weights in zip(expectations, dataset,
                                                 weights):
        these_stats = cls.expected_sufficient_statistics(
            nonconjugate_params, expectations=expects, **data_dict, **kwargs)

        # weight the statistics if weights are given
        these_stats = tuple(
            np.tensordot(these_weights, s, axes=(0, 0)) for s in these_stats)

        # add to our accumulated statistics
        suff_stats = sum_tuples(suff_stats, these_stats)

        # update the number of datapoints
        num_datapoints += these_weights.sum()

    # Find the optimal parameters for the conjugate part of the compound distribution
    posterior_stats = suff_stats
    posterior_counts = num_datapoints
    if prior is not None:
        posterior_stats = sum_tuples(prior.pseudo_obs, posterior_stats)
        posterior_counts += prior.pseudo_counts

    # Compute the posterior distribution
    posterior_class = get_compound(cls)
    posterior = posterior_class.from_stats(posterior_stats, posterior_counts,
                                           **kwargs)
    return posterior.mode()
def fit(cls, dataset, weights=None, prior=None, **kwargs):
    """Compute the maximum a posteriori (MAP) estimate of the distribution
    parameters. For uninformative priors, this reduces to the maximum
    likelihood estimate.
    """
    # if no weights are given, weight every datapoint equally
    if weights is None:
        weights = [None] * len(dataset)

    # Compute the sufficient statistics and the number of datapoints
    suff_stats = None
    num_datapoints = 0
    for data_dict, these_weights in zip(dataset, weights):
        these_stats = cls.sufficient_statistics(**data_dict, **kwargs)

        # weight the statistics if weights are given
        if these_weights is not None:
            these_stats = tuple(
                np.tensordot(these_weights, s, axes=(0, 0))
                for s in these_stats)
            num_datapoints += these_weights.sum()
        else:
            # unweighted: count the batch size off the leading stats axis
            num_datapoints += these_stats[0].shape[0]
            these_stats = tuple(s.sum(axis=0) for s in these_stats)

        # add to our accumulated statistics
        suff_stats = sum_tuples(suff_stats, these_stats)

    return cls.fit_with_stats(suff_stats, num_datapoints, prior=prior,
                              **kwargs)
def additive_kernel(
    x1,
    x2,
    lengthscales,
    additive_alphas,
    kernel_alphas,
    base_kernel_fun,
    diag_only,
    jitter=DEFAULT_JITTER,
):
    N = additive_alphas.shape[0]

    # TODO: Could make this more general to support other kernels
    to_vmap = lambda x1, x2, lengthscale, alpha: base_kernel_fun(
        x1.reshape(-1, 1),
        x2.reshape(-1, 1),
        lengthscale.reshape(-1),
        alpha,
        diag_only,
        jitter,
    )

    map_res = vmap(to_vmap)(x1.T, x2.T, lengthscales, kernel_alphas)
    girard_res = newton_girard_combination(map_res, N)
    # weighted sum of the per-order additive kernels
    kernel_res = jnp.tensordot(additive_alphas, girard_res, axes=(0, 0))
    return kernel_res
def update_fun(step, grads, state):
    """Apply a step of the optimizer."""
    del step  # Unused.
    params, grad_seq, param_seq = state
    grad_seq = append_to_sequence(grad_seq, grads)
    param_seq = append_to_sequence(param_seq, params)

    # Differences in parameters.
    # TODO(nirum): This recomputes differences at every iteration. Should
    # time this to ensure that the repeated jnp.diff call is not too slow.
    delta_params = jnp.diff(param_seq, axis=0)

    grad_term = jnp.tensordot(theta_grad, grad_seq, axes=1)
    dx_term = jnp.tensordot(theta_dx, delta_params, axes=1)
    params -= (grad_term + dx_term)
    return (params, grad_seq, param_seq)
def tensor(self, other):
    if not isinstance(other, Tensor):
        raise TypeError(messages.type_err(Tensor, other))
    dom, cod = self.dom + other.dom, self.cod + other.cod
    array = np.tensordot(self.array, other.array, 0)\
        if self.array.shape and other.array.shape\
        else self.array * other.array
    return Tensor(dom, cod, array)
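# The tensordot(..., 0) branch above is the outer (monoidal) product: the
# result's index list is the concatenation of the two factors' indices,
# mirroring the dom/cod concatenation. Illustrative shapes:
import numpy as onp

_a = onp.arange(4.0).reshape(2, 2)
_b = onp.arange(3.0)
assert onp.tensordot(_a, _b, 0).shape == (2, 2, 3)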
def get_action(self):
    if self.T < self.T_0:
        return self.sys_id.get_action(self.x)
    M_tilde = self.M + self.delta * self.eps[-1]
    # choose action
    self.u = -self.K @ self.x + np.tensordot(
        M_tilde, self.w_past, axes=([0, 2], [0, 1]))
    return self.u