def partial_trace(A, A_label):
    """Partial trace of tensor ``A`` over every label repeated in ``A_label``.

    Args:
        A: tensor to trace.
        A_label: 1-D array with one label per axis of ``A``. A label that
            occurs more than once marks a pair of axes to contract; each
            repeated label is expected to occur exactly twice, and both axes
            of a pair must have the same dimension.

    Returns:
        ``(B, B_label, cont_label)`` where ``B`` is ``A`` traced over every
        repeated pair of axes, ``B_label`` holds the labels of the remaining
        (free) axes in their original order and ``cont_label`` holds the
        contracted labels (sorted). If no label repeats, ``(A, A_label, [])``
        is returned unchanged.
    """
    num_cont = len(A_label) - len(np.unique(A_label))
    if num_cont <= 0:
        return A, A_label, []

    # One row per repeated label: [first axis, second axis].
    dup_list = [np.where(A_label == ele)[0]
                for ele in np.unique(A_label)
                if np.sum(A_label == ele) > 1]
    # Column-major flatten -> [all first occurrences..., all second occurrences...]
    cont_ind = np.array(dup_list).reshape(2 * num_cont, order='F')
    free_ind = onp.delete(np.arange(len(A_label)), cont_ind)

    shape = np.array(A.shape)
    cont_dim = int(np.prod(shape[cont_ind[:num_cont]]))
    free_dim = shape[free_ind]
    free_size = int(np.prod(free_dim))

    B_label = onp.delete(A_label, cont_ind)
    cont_label = np.unique(A_label[cont_ind])

    # Free axes first, then the two groups of contracted axes, so the traced
    # indices become the two trailing matrix axes of a (free, cont, cont) array.
    A = A.transpose(np.append(free_ind, cont_ind)).reshape(free_size, cont_dim, cont_dim)
    B = np.trace(A, axis1=1, axis2=2)
    return B.reshape(tuple(int(d) for d in free_dim)), B_label, cont_label
def rde(lr: Union[float, Schedule] = 2**-15,
        train: Union[bool, Schedule] = False,
        Rs: Array = jnp.unique(jnp.abs(comm.const("16QAM", norm=True))),
        const: Optional[Array] = None) -> AdaptiveFilter:
    """Radius Directed adaptive Equalizer (RDE).

    Args:
        lr: learning rate, a scalar or a Schedule.
        train: training-mode schedule. Either a bool, applied globally within
            one call, or an array of bools that switches training mode on a
            per-iteration basis.
        Rs: radii of the target constellation. The default is the distinct
            magnitudes of the normalized 16QAM constellation.
        const: Optional constellation. When given, ``Rs`` is replaced by the
            distinct magnitudes of ``const``.

    Returns:
        an ``AdaptiveFilter`` object built from ``init``, ``update`` and ``apply``.

    References:
        - [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
          carrier phase recovery in a 16-QAM optical coherent system. Journal
          of lightwave technology, 27(15), pp.3042-3049.
    """
    lr = cxopt.make_schedule(lr)
    train = cxopt.make_schedule(train)

    if const is not None:
        Rs = jnp.array(jnp.unique(jnp.abs(const)))

    def init(dims=2, w0=None, taps=32, dtype=np.complex64):
        if w0 is None:
            # Identity MIMO filter: unit centre tap on each diagonal channel.
            w0 = np.zeros((dims, dims, taps), dtype=dtype)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        return w0

    def loss_fn(w, u, x, i):
        v = r2c(mimo(w, u)[None, :])
        # Target squared radius: |x|^2 in training mode, otherwise the radius
        # in Rs closest to |v| (|Rs * v/|v| - v| == |Rs - |v||).
        R2 = jnp.where(
            train(i),
            jnp.abs(x)**2,
            Rs[jnp.argmin(jnp.abs(Rs[:, None] * v / jnp.abs(v) - v), axis=0)]**2)
        l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0, :])**2))
        return l

    def update(i, w, inp):
        u, x = inp
        l, g = jax.value_and_grad(loss_fn)(w, u, x, i)
        out = (w, l)
        # Gradient step; conj() because w is complex.
        w = w - lr(i) * g.conj()
        return w, out

    def apply(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, apply)
def bdd_message_func(self, edges):
    """Message function for the block-diagonal-decomposition regularizer.

    Computes ``msg = h_src @ W_r`` for every edge, where ``W_r`` (for edge
    type ``r`` taken from ``edges.data['type']``) is split into
    ``self.num_bases`` diagonal blocks of shape
    ``(self.submat_in, self.submat_out)``.

    Args:
        edges: edge batch exposing ``src['h']`` (source features),
            ``data['type']`` (edge-type ids) and optionally ``data['norm']``.

    Returns:
        ``{'msg': msg}`` with ``msg`` of shape ``(num_edges, self.out_feat)``.

    Raises:
        TypeError: if ``src['h']`` is a 1-D integer array (node-ID features),
            which block decomposition cannot handle.
    """
    src_h = edges.src['h']
    etype_ids = edges.data['type']
    # Accept any integer dtype: under JAX's default 32-bit mode IDs are int32,
    # so comparing against int64 alone would let them through.
    if jnp.issubdtype(src_h.dtype, jnp.integer) and src_h.ndim == 1:
        raise TypeError('Block decomposition does not allow integer ID feature.')

    # calculate msg @ W_r before put msg into edge
    if self.low_mem:
        # One edge type at a time instead of gathering a weight per edge.
        msg = jnp.zeros((src_h.shape[0], self.out_feat))
        for etype in jnp.unique(etype_ids):
            idx = jnp.nonzero(etype_ids == etype)[0]
            w = self.weight[etype].reshape((self.num_bases, self.submat_in, self.submat_out))
            src = src_h[idx].reshape((-1, self.num_bases, self.submat_in))
            sub_msg = jnp.einsum('abc,bcd->abd', src, w).reshape((-1, self.out_feat))
            # jax.ops.index_update was removed from JAX; .at[].set replaces it.
            msg = msg.at[idx].set(sub_msg)
    else:
        # Gather each edge's weights and apply every block as a batched matmul.
        weight = jnp.take(self.weight, etype_ids, 0).reshape(
            (-1, self.submat_in, self.submat_out))
        node = src_h.reshape((-1, 1, self.submat_in))
        msg = jax.lax.batch_matmul(node, weight).reshape((-1, self.out_feat))

    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
def basis_message_func(self, edges):
    """Message function for the basis regularizer.

    When ``self.num_bases < self.num_rels`` the per-relation weights are built
    as linear combinations (``self.w_comp``) of ``self.num_bases`` basis
    matrices; otherwise ``self.weight`` is used directly as one
    ``(in_feat, out_feat)`` matrix per relation.

    Args:
        edges: edge batch exposing ``src['h']`` (source features or integer
            node IDs), ``data['type']`` (edge-type ids) and optionally
            ``data['norm']``.

    Returns:
        ``{'msg': msg}`` where each row is ``h_src @ W_r`` for the edge's type.
    """
    if self.num_bases < self.num_rels:
        # generate all weights from bases
        weight = self.weight.reshape((self.num_bases, self.in_feat * self.out_feat))
        weight = jnp.matmul(self.w_comp, weight).reshape(
            (self.num_rels, self.in_feat, self.out_feat))
    else:
        weight = self.weight

    src_h = edges.src['h']
    # Integer features are node IDs and must go through index select; accept
    # any integer dtype (IDs are int32 under JAX's default 32-bit mode).
    is_index = jnp.issubdtype(src_h.dtype, jnp.integer)
    if not is_index and self.low_mem:
        # One edge type at a time instead of gathering a weight per edge.
        etype_ids = edges.data['type']
        msg = jnp.zeros((src_h.shape[0], self.out_feat))
        for etype in jnp.unique(etype_ids):
            idx = jnp.nonzero(etype_ids == etype)[0]
            sub_msg = jnp.matmul(src_h[idx], weight[etype])
            # jax.ops.index_update was removed from JAX; .at[].set replaces it.
            msg = msg.at[idx].set(sub_msg)
    else:
        # put W_r into edges then do msg @ W_r
        msg = utils.bmm_maybe_select(src_h, weight, edges.data['type'])

    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
    """The pre-to-post synaptic mean computation.

    Parameters
    ----------
    pre_values: float, jax.numpy.ndarray, JaxArray, Variable
        The pre-synaptic values.
    pre_ids: jax.numpy.ndarray, JaxArray
        The connected pre-synaptic neuron ids. Required unless ``pre_values``
        is a scalar.
    post_ids: jax.numpy.ndarray, JaxArray
        The connected post-synaptic neuron ids.
    post_num: int
        Output dimension. The number of post-synaptic neurons.

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        The value with the size of post-synaptic neurons. Post-synaptic
        neurons without any connection are 0.
    """
    out = jnp.zeros(post_num, dtype=profile.float_)
    pre_values = as_device_array(pre_values)
    post_ids = as_device_array(post_ids)
    if jnp.ndim(pre_values) == 0:
        # Every synapse carries the same scalar, so the mean at each connected
        # post neuron is that scalar. Writing an identical value to duplicate
        # indices is idempotent, so no (non-jittable) jnp.unique is needed.
        return out.at[post_ids].set(pre_values)
    _raise_pre_ids_is_none(pre_ids)
    pre_ids = as_device_array(pre_ids)
    pre_values = pre2syn(pre_values, pre_ids)
    return syn2post_mean(pre_values, post_ids, post_num)
def sum_from_unique(
        cls, input: np.array, mean: bool = True) -> Tuple[np.array, np.array, "SparseReduce"]:
    """Group positions of equal values in ``input``, bucketed by group size.

    For each unique value of ``input`` the positions where it occurs are
    collected. Groups are ordered by their size (number of occurrences), and
    consecutive groups of equal size are stacked into one 2-D index array.

    Returns:
        ``(unique_values, counts, SparseReduce(buckets, mean))``; the unique
        values and their counts are both in the size-sorted order, and
        ``buckets`` holds one stacked index array per distinct group size.
    """
    un, cts = np.unique(input, return_counts=True)
    positions = [np.argwhere(input == value).flatten() for value in un]

    sizes = np.array([pos.size for pos in positions])
    order = np.argsort(sizes)
    sorted_sizes = sizes[order]
    ordered_positions = [positions[k] for k in order]

    # Bucket boundaries: start, every index where the group size changes, end.
    size_changes = np.argwhere(sorted_sizes[:-1] - sorted_sizes[1:] != 0).flatten() + 1
    bounds = [0] + list(size_changes) + [len(sizes)]
    buckets = [np.array(ordered_positions[lo:hi]) for lo, hi in zip(bounds[:-1], bounds[1:])]

    return un[order], cts[order], SparseReduce(buckets, mean)
def _log_compare(mat, cats, significance_test=scipy.stats.ttest_ind): """Calculates pairwise log ratios between all features and performs a significiance test (i.e. t-test) to determine if there is a significant difference in feature ratios with respect to the variable of interest. Parameters ---------- mat: np.array rows correspond to samples and columns correspond to features (i.e. OTUs) cats: np.array, float Vector of categories significance_test: function statistical test to run Returns: -------- log_ratio : np.array log ratio pvalue matrix """ r, c = mat.shape log_ratio = np.zeros((c, c)) log_mat = np.log(mat) cs = np.unique(cats) def func(x): return significance_test(*[x[cats == k] for k in cs]) for i in range(c - 1): ratio = (log_mat[:, i].T - log_mat[:, i + 1:].T).T m, p = np.apply_along_axis(func, axis=0, arr=ratio) log_ratio[i, i + 1:] = np.squeeze(np.array(p.T)) return log_ratio
def compute_fun(R, **kwargs):
    """Angular descriptors for every unordered pair of species types.

    ``displacement``, ``species`` and ``_all_pairs_angular`` come from the
    enclosing scope. For each species type a displacement matrix against all
    particles is built; ``_all_pairs_angular`` is then summed over axes 1 and
    2 for every type pair ``(i, j)`` with ``i <= j``, and the results are
    stacked horizontally.
    """
    D_fn = space.map_product(partial(displacement, **kwargs))
    # Compute the unique species types once and reuse them for every pair.
    atom_types = np.unique(species)
    D_different_types = [D_fn(R[species == atom_type, :], R) for atom_type in atom_types]

    out = []
    for i, D_i in enumerate(D_different_types):
        for D_j in D_different_types[i:]:
            # Tuple axis: NumPy reductions reject a list here.
            out.append(np.sum(_all_pairs_angular(D_i, D_j), axis=(1, 2)))
    return np.hstack(out)
def prob_inf_house_size_iter(state, hh_sizes_, house_dist):
    """
    Function that computes the probability of an individual getting infected given their household size.

    @param state : Encodes the state of each individual in the population at the end of each iteration of the simulation; a value > 0 counts as infected
    @type : Array of shape (# of iterations, population size)
    @param hh_sizes_ : The size of each individual's household
    @type : Array of length = population size
    @param house_dist : Distribution of household sizes, one entry per distinct household size present in hh_sizes_, in ascending size order
    @type : List or 1D array
    @return : The probability of infection for each household size (averaged over iterations) and the mean probability of infection
    @type : Tuple
    @raise ValueError : if house_dist does not have one entry per distinct household size
    """
    hh_sizes = np.asarray(hh_sizes_)
    house_dist = np.asarray(house_dist)
    # Reference sizes come from the whole population, so a size with no
    # infected member in some iteration still gets its own (zero) slot.
    sizes = np.unique(hh_sizes)
    if sizes.shape[0] != house_dist.shape[0]:
        raise ValueError("house_dist has %d entries but there are %d distinct household sizes"
                         % (house_dist.shape[0], sizes.shape[0]))

    pop = len(state[0])
    # size_onehot[k, s] is True when individual k lives in a household of size sizes[s].
    size_onehot = hh_sizes[:, None] == sizes[None, :]

    prob_rows = []
    mean_rows = []
    for i in range(len(state)):
        infected = np.asarray(state[i]) > 0
        counts = np.sum(size_onehot & infected[:, None], axis=0)
        # Bayes rule: P(inf | size) = P(size | inf) * P(inf) / P(size) with
        # P(size | inf) = counts / n_inf and P(inf) = n_inf / pop, which
        # reduces to counts / pop / P(size) (no division by n_inf == 0).
        prob_rows.append(counts / pop / house_dist)
        mean_rows.append(np.sum(infected) / pop)

    # Returns the probability of infection given household size
    return np.average(np.stack(prob_rows), axis=0), np.average(np.array(mean_rows))
def arithmetic_encoding_num_bits(v: jnp.ndarray) -> int:
    """Estimate the number of bits needed to store ``v`` via arithmetic coding.

    NaNs are replaced via ``jnp.nan_to_num`` and ``v`` is flattened. The
    estimate is the histogram cost (``_hist_bits``) plus ``v.size`` times the
    per-symbol entropy (``_entropy``) plus a fixed 66-bit overhead
    (2 * 32 + 2; presumably header fields - confirm).
    """
    flat = jnp.nan_to_num(v).flatten()
    symbols = jnp.unique(flat)
    per_symbol_entropy = _entropy(flat, symbols)
    header_bits = _hist_bits(flat, symbols)
    payload_bits = flat.size * per_symbol_entropy
    return header_bits + payload_bits + (2 * 32) + 2
def unique(x, return_index=False, return_inverse=False, return_counts=False, axis=None):
    """``jnp.unique`` that accepts and returns ``JaxArray`` objects.

    When any of ``return_index``, ``return_inverse`` or ``return_counts`` is
    set, ``jnp.unique`` returns a tuple of arrays; each element is wrapped in
    its own ``JaxArray`` and a tuple is returned, mirroring ``jnp.unique``.
    Otherwise a single ``JaxArray`` of the unique values is returned.
    """
    if isinstance(x, JaxArray):
        x = x.value
    result = jnp.unique(x,
                        return_index=return_index,
                        return_inverse=return_inverse,
                        return_counts=return_counts,
                        axis=axis)
    if isinstance(result, tuple):
        return tuple(JaxArray(r) for r in result)
    return JaxArray(result)
def pair_correlation(displacement_or_metric: Union[DisplacementFn, MetricFn],
                     radii: Array,
                     sigma: float,
                     species: Array = None):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. It is defined by
    $g(r) = <\sum_{i\neq j}\delta(r - |r_i - r_j|)>$.
    We approximate the delta function by a Gaussian,
    $\delta(r) \approx {1 \over \sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$.
    The implementation drops the normalization constant and divides by
    $r^{d-1}$ ($d$ = spatial dimension), so each neighbor $j$ of particle $i$
    contributes $e^{-(r - r_{ij})^2 / (2\sigma^2)} / r^{d-1}$.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      radii: An array of radii at which we would like to compute g(r).
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional array of integers specifying the species of each
        particle. If species is None then we compute a single g(r) for all
        particles, otherwise we compute one g(r) for each species.

    Returns:
      A function `g_fn(R)` that computes the (unnormalized, per-particle) pair
      correlation function for a collection of particles. Without species it
      returns an array of shape `(R.shape[0], len(radii))`; with species it
      returns a list with one such per-species array for each unique value in
      `species`.

    Raises:
      TypeError: if `species` is given but is not an integer array.
    """
    d = space.canonicalize_displacement_or_metric(displacement_or_metric)
    d = space.map_product(d)

    def pairwise(dr, dim):
        return jnp.exp(-f32(0.5) * (dr - radii)**2 / sigma**2) / radii**(dim - 1)
    pairwise = vmap(vmap(pairwise, (0, None)), (0, None))

    if species is None:
        def g_fn(R):
            dim = R.shape[-1]
            # Exclude self-pairs (i == j).
            mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
            return jnp.sum(mask[:, :, jnp.newaxis] * pairwise(d(R, R), dim), axis=(1,))
    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R):
            dim = R.shape[-1]
            g_R = []
            # Exclude self-pairs (i == j).
            mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
            for s in species_types:
                Rs = R[species == s]
                mask_s = mask[:, species == s, jnp.newaxis]
                g_R += [jnp.sum(mask_s * pairwise(d(Rs, R), dim), axis=(1,))]
            return g_R
    return g_fn
def heavy_atoms(self):
    """Return the number of atoms in ``self.Z`` whose atomic number is not 1.

    Prints a notice naming ``self.filename`` when no hydrogen (Z == 1) is
    present. Hydrogens are counted directly: building a dict keyed by the
    elements of ``jnp.unique`` fails because JAX array scalars are unhashable.
    """
    num_hydrogens = int(jnp.sum(self.Z == 1))
    if num_hydrogens == 0:
        print("In file %s no hydrogens were reported" % self.filename)
    return self.Z.size - num_hydrogens
def check_inputs(connect_list, flat_connect, dims_list, cont_order):
    """Check consistency of NCON inputs.

    Validates, in order:
    the number of tensors against the number of label sublists, the number of
    indices per tensor against its labels, that ``cont_order`` is a
    permutation of the positive labels, that negative labels -1..-k each
    appear exactly once, and that every positive label appears exactly twice
    on axes of equal dimension.

    Returns:
        True when all checks pass.

    Raises:
        ValueError: describing the first failed check.
    """
    positive = flat_connect[flat_connect > 0]
    negative = flat_connect[flat_connect < 0]

    num_tensors = len(dims_list)
    if num_tensors != len(connect_list):
        raise ValueError('NCON error: %i tensors given but %i index sublists given'
                         % (num_tensors, len(connect_list)))

    for tensor_idx, (dims, labels) in enumerate(zip(dims_list, connect_list)):
        if len(dims) != len(labels):
            raise ValueError('NCON error: number of indices does not match number of labels on tensor %i: '
                             '%i-indices versus %i-labels' % (tensor_idx, len(dims), len(labels)))

    if not np.array_equal(np.sort(cont_order), np.unique(positive)):
        raise ValueError('NCON error: invalid contraction order')

    for label in range(-1, -len(negative) - 1, -1):
        occurrences = np.count_nonzero(negative == label)
        if occurrences == 0:
            raise ValueError('NCON error: no index labelled %i' % label)
        if occurrences > 1:
            raise ValueError('NCON error: more than one index labelled %i' % label)

    flat_dims = np.array([dim for dims in dims_list for dim in dims])
    for label in np.unique(positive):
        occurrences = np.count_nonzero(positive == label)
        if occurrences == 1:
            raise ValueError('NCON error: only one index labelled %i' % label)
        if occurrences > 2:
            raise ValueError('NCON error: more than two indices labelled %i' % label)
        first_dim, second_dim = flat_dims[flat_connect == label][:2]
        if first_dim != second_dim:
            raise ValueError('NCON error: tensor dimension mismatch on index labelled %i: '
                             'dim-%i versus dim-%i' % (label, first_dim, second_dim))
    return True
def sum_from_unique(
        cls, input: Array, mean: bool = True) -> Tuple[np.array, np.array, "LinearReduce"]:
    """Build a dense group-sum (or group-mean) operator over unique values.

    Row ``r`` of the ``(num_unique, len(input))`` matrix has a weight at every
    position where ``input`` equals the r-th unique value: 1 for a sum, or
    ``1 / count`` when ``mean`` is True.

    Returns:
        ``(unique_values, counts, LinearReduce(matrix))``.
    """
    un, cts = np.unique(input, return_counts=True)
    weights = np.zeros((un.size, input.shape[0]))
    for row, value in enumerate(un):
        positions = np.argwhere(input == value).flatten()
        count = cts[row].squeeze()
        fill = np.ones(int(count)).squeeze()
        if mean:
            fill = fill / count
        weights = weights.at[row, positions.squeeze()].set(fill)
    return un, cts, LinearReduce(weights)
def get_group_zellner(groups, X, isgmom=False):
    """Block-diagonal group Zellner prior matrices.

    For each group ``g`` (columns of ``X`` with ``groups == g``, ``p_j`` of
    them) the block ``X_j^T X_j * n / p_term`` is written into ``Vinv`` and
    its inverse into ``V``, where ``p_term = p_j`` if ``isgmom`` else
    ``p_j + 2``. Note that V=(XtX)^-1 and Vinv=XtX (per block, scaled).

    Returns:
        ``(V, Vinv)``, both of shape ``(p, p)``; entries outside the group
        blocks are 0.
    """
    n, p = X.shape
    Vinv = jnp.zeros((p, p))
    V = jnp.zeros((p, p))
    columns = jnp.arange(p)
    for group, p_j in zip(*jnp.unique(groups, return_counts=True)):
        mask = columns[groups == group]
        X_j = X[:, mask]
        # The 5-argument lax.cond form was removed from JAX; jnp.where works
        # for both a Python bool and a traced boolean.
        p_term = jnp.where(isgmom, p_j, p_j + 2)
        aux = jnp.dot(X_j.T, X_j) * n / p_term
        block = jnp.ix_(mask, mask)
        Vinv = Vinv.at[block].set(aux)
        V = V.at[block].set(jnp.linalg.inv(aux))
    return V, Vinv
def random_adjacency(key: jnp.ndarray,
                     num_nodes: int,
                     num_edges: int,
                     dtype=jnp.float32) -> COO:
    """Adjacency matrix of a random connected undirected graph.

    Steps: sample ``num_edges`` flat indices (``row * num_nodes + col``), drop
    sampled self-edges, add the ring ``i -> (i + 1) % num_nodes`` so the
    graph is connected, mirror every edge, and deduplicate. The final edge
    count is therefore only approximately ``num_edges``.

    Args:
        key: `jax.random.PRNGKey`.
        num_nodes: number of nodes in returned graph.
        num_edges: number of random internal edges initially sampled.
        dtype: dtype of returned JAXSparse.

    Returns:
        COO, shape (num_nodes, num_nodes), weights all ones.
    """
    shape = num_nodes, num_nodes

    # Flat indices sampled as floats in [0, num_nodes**2) and truncated.
    sampled = jax.random.uniform(
        key, shape=(num_edges,), dtype=jnp.float32, maxval=num_nodes**2,
    ).astype(jnp.int32)
    sampled = sampled[(sampled // num_nodes) != (sampled % num_nodes)]

    nodes = jnp.arange(num_nodes, dtype=jnp.int32)
    ring = nodes * num_nodes + (nodes + 1) % num_nodes
    forward = jnp.concatenate((sampled, ring))

    # Mirror each (row, col) to (col, row) to make the graph undirected.
    src, dst = jnp.unravel_index(forward, shape)
    backward = jnp.ravel_multi_index((dst, src), shape)

    flat = jnp.unique(jnp.concatenate((forward, backward)))
    row, col = jnp.unravel_index(flat, shape)
    return COO((jnp.ones((row.size,), dtype=dtype), row, col), shape=shape)
def create_spatiotemporal_grid(X, Y):
    """Create a dense [T, R, D] grid from scattered spatio-temporal data.

    ``X`` holds one row per observation: column 0 is time, the remaining
    ``D = X.shape[1] - 1`` columns are spatial coordinates (any ``D >= 1``).
    Every (time, location) combination of the unique times and unique
    locations gets a grid cell; cells without an observation are filled with
    NaN. Full duplicates (all columns equal) are removed, keeping one of them
    (the first after sorting).

    Returns:
        ``(times, R_grid, Y_grid)`` with shapes ``(T, 1)``, ``(T, R, D)`` and
        ``(T, R, 1)``.

    Raises:
        NotImplementedError: if ``X`` has no spatial columns.
    """
    if Y.ndim < 2:
        Y = Y[:, None]
    num_spatial_dims = X.shape[1] - 1
    if num_spatial_dims < 1:
        raise NotImplementedError

    # lexsort treats the LAST key as primary: pass columns in reverse so rows
    # are sorted by column 0, then 1, 2, ...
    sort_ind = nnp.lexsort(tuple(X[:, col] for col in reversed(range(X.shape[1]))))
    X = X[sort_ind]
    Y = Y[sort_ind]

    unique_time = np.unique(X[:, 0])
    unique_space = nnp.unique(X[:, 1:], axis=0)
    N_t = unique_time.shape[0]
    N_r = unique_space.shape[0]

    # Every location repeated once per time step (time-major order).
    R_flat = np.tile(unique_space, [N_t, 1]).reshape(-1, num_spatial_dims)
    Y_dummy = np.nan * np.zeros([N_t * N_r, 1])
    time_duplicate = np.tile(unique_time, [N_r, 1]).T.flatten()
    X_dummy = np.block([time_duplicate[:, None], R_flat])

    # Real data precede the NaN dummies, so unique(return_index) keeps the data.
    X_all = np.vstack([X, X_dummy])
    Y_all = np.vstack([Y, Y_dummy])
    X_unique, ind = nnp.unique(X_all, axis=0, return_index=True)
    Y_unique = Y_all[ind]

    grid_shape = (N_t,) + unique_space.shape
    R_grid = X_unique[:, 1:].reshape(grid_shape)
    Y_grid = Y_unique.reshape(grid_shape[:-1] + (1,))
    return unique_time[:, None], R_grid, Y_grid
def compute_fun(R, **kwargs):
    """Radial symmetry functions, one block of columns per neighbor species.

    ``metric``, ``species``, ``etas``, ``cutoff_distance``,
    ``_behler_parrinello_cutoff_fn`` and ``vmap`` come from the enclosing
    scope. For each unique species type, the Gaussian radial terms
    ``exp(-eta * dr**2) * cutoff(dr)`` are summed over that type's
    particles; the per-type results are stacked horizontally.
    """
    pairwise_metric = space.map_product(partial(metric, **kwargs))

    def radial_fn(eta, dr):
        return np.exp(-eta * dr**2) * _behler_parrinello_cutoff_fn(dr, cutoff_distance)

    def return_radial(atom_type):
        """Returns the radial symmetry functions for neighbor type atom_type."""
        neighbors = R[species == atom_type, :]
        dr = pairwise_metric(R, neighbors)
        radial = vmap(radial_fn, (0, None))(etas, dr)
        return np.sum(radial, axis=1).T

    per_type = [return_radial(t) for t in np.unique(species)]
    return np.hstack(per_type)
def union_bad_ants(JDs):
    """Return all the bad antennas for the specified JDs

    Bad antennas per Julian Day are read from ``bad_ants_idr2.pkl`` next to
    this module.

    :param: Julian Days
    :type: ndarray, list
    :return: Sorted union of bad antennas for JDs (empty for no JDs)
    :rtype: ndarray
    :raises KeyError: if a JD is missing from the bad-antenna table
    """
    bad_ants_fn = os.path.join(os.path.dirname(__file__), 'bad_ants_idr2.pkl')
    # NOTE: pickle is acceptable only because this file ships with the
    # package; never point this loader at untrusted data.
    with open(bad_ants_fn, 'rb') as f:
        bad_ants_dict = pickle.load(f)
    # Concatenate once (np.append in a loop is quadratic); np.unique already
    # returns sorted values, so no extra sort is needed.
    per_jd = [np.ravel(bad_ants_dict[JD]) for JD in JDs]
    return np.unique(np.concatenate([np.array([], dtype=int)] + per_jd))
def updateMeans(self, X, clusters, means):
    """Average the current means with the per-cluster means of ``X``.

    Rows of ``X`` are grouped by their label in ``clusters``; the mean of each
    group (groups ordered by ascending label) is averaged with the
    corresponding row of ``means``.

    NOTE(review): every row of ``means`` must have at least one member in
    ``clusters`` -- empty clusters shift or break the row alignment.
    """
    n_features = X.shape[1]
    labelled = jnp.hstack((X, clusters.reshape(clusters.shape[0], 1)))
    # Sort rows by label so each cluster is a contiguous run.
    labelled = labelled[labelled[:, n_features].argsort()]
    _, run_starts = jnp.unique(labelled[:, n_features], return_index=True)
    runs = jnp.split(labelled[:, :n_features], run_starts[1:])
    cluster_means = jnp.array([jnp.mean(run, axis=0) for run in runs])
    return (means + cluster_means) / 2
def _load_dataset():
    """Load COVTYPE as a binary classification problem.

    Features are standardized per column and an intercept column of ones is
    appended. Labels become True for the most frequent category and False
    otherwise.

    Returns:
        ``(features, labels)``.
    """
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()

    # normalize features and add intercept
    features = (features - features.mean(0)) / features.std(0)
    features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

    # make binary feature: argmax(counts) is a position in `categories`, so
    # map it back to the actual label value before comparing.
    categories, counts = jnp.unique(labels, return_counts=True)
    specific_category = categories[jnp.argmax(counts)]
    labels = labels == specific_category

    N = features.shape[0]
    print("Data shape:", features.shape)
    print("Label distribution: {} has label 1, {} has label 0".format(
        labels.sum(), N - labels.sum()))
    return features, labels
def plot_gmm_changepoints(ax, gmm_output, timesteps=None):
    """Plot GMM observations and latent states with changepoint markers.

    ``ax[0]`` shows ``gmm_output["observed"]``; ``ax[1]`` shows
    ``gmm_output["latent"]`` with a faint horizontal line per state. Every
    axis gets a dotted red vertical line at each changepoint. ``timesteps``
    defaults to ``0 .. len(observed) - 1``.

    NOTE(review): ``changepoints`` is neither a parameter nor a local; it is
    resolved from a scope not visible here -- confirm it is defined before
    calling.
    """
    observed = gmm_output["observed"]
    states = gmm_output["latent"]
    n_states = len(jnp.unique(states))
    num_steps = len(observed)
    if timesteps is None:
        timesteps = jnp.arange(num_steps)

    obs_ax, state_ax = ax[0], ax[1]
    obs_ax.plot(timesteps, observed, marker="o", markersize=3, linewidth=1, c="tab:gray")
    state_ax.scatter(timesteps, states, c="tab:gray")
    state_ax.set_yticks(jnp.arange(n_states))
    for level in range(n_states):
        state_ax.axhline(y=level, c="tab:gray", alpha=0.3)

    for changepoint in changepoints:
        for axi in ax:
            axi.axvline(x=changepoint, c="tab:red", linestyle="dotted")

    for axi in ax:
        axi.set_xlim(timesteps[0], timesteps[-1])
def glmm(dept, male, applications, admit=None):
    """Binomial admissions model with correlated per-department effects.

    Each department gets a 2-vector offset (intercept, slope on ``male``)
    drawn through a non-centered parameterization with scale ``sigma`` and
    LKJ-Cholesky correlation ``L_Rho``.

    NOTE(review): ``dept`` is used directly as a row index into the
    ``num_dept`` offsets, so it presumably holds codes 0..num_dept-1 -- confirm.
    """
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # non-centered parameterization
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    intercept = v_mu[0] + v[dept, 0]
    slope = v_mu[1] + v[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
def compute_fun(R: Array, neighbor: NeighborList, **kwargs) -> Array:
    """Radial symmetry functions from a neighbor list, one block per species.

    ``metric``, ``species``, ``etas``, ``cutoff_distance``,
    ``_behler_parrinello_cutoff_fn``, ``util`` and ``vmap`` come from the
    enclosing scope. Neighbor slots with ``idx >= R.shape[0]`` (padding) or a
    different species are masked out before summing.
    """
    neighbor_metric = space.map_neighbor(partial(metric, **kwargs))

    def radial_fn(eta, dr):
        return np.exp(-eta * dr**2) * _behler_parrinello_cutoff_fn(dr, cutoff_distance)

    def return_radial(atom_type):
        """Returns the radial symmetry functions for neighbor type atom_type."""
        idx = neighbor.idx
        R_neigh = R[idx]
        species_neigh = species[idx]
        keep = np.logical_and(idx < R.shape[0], species_neigh == atom_type)
        dr = neighbor_metric(R, R_neigh)
        radial = vmap(radial_fn, (0, None))(etas, dr)
        return util.high_precision_sum(radial * keep[np.newaxis, :, :], axis=2).T

    return np.hstack([return_radial(t) for t in np.unique(species)])
def lib_as_grid(self):
    """Convert the library parameters to pixel indices in each dimension,
    and build and store a KDTree for the pixel coordinates.

    Sets ``self.gridpoints`` / ``self.binwidths`` (per label: sorted unique
    values and their successive differences), ``self.X`` (one row per
    library entry, one column per label, holding the index of the value in
    its grid) and ``self._kdt``; prints the KDTree build time.
    """
    self.gridpoints = {}
    self.binwidths = {}
    for label in self.labels:
        grid = np.unique(self.libparams[label])
        self.gridpoints[label] = grid
        self.binwidths[label] = np.diff(grid)

    # right=True maps a value equal to a gridpoint onto that gridpoint's index.
    pixel_coords = np.array([
        np.digitize(self.libparams[label], bins=self.gridpoints[label], right=True)
        for label in self.labels
    ])
    self.X = pixel_coords.T

    start = datetime.now()
    self._kdt = KDTree(self.X, leafsize=1000)
    print('built KDTree: {}'.format(datetime.now() - start))
def rde(lr=1e-4, Rs=jnp.unique(jnp.abs(comm.const("16QAM", norm=True)))):
    '''Radius Directed adaptive Equalizer (RDE).

    Args:
        lr: fixed learning rate.
        Rs: radii of the target constellation (default: distinct magnitudes
            of the normalized 16QAM constellation).

    Returns:
        an ``AdaptiveFilter`` built from ``init``, ``update`` and ``static_map``.

    References:
        [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
        carrier phase recovery in a 16-QAM optical coherent system. Journal
        of lightwave technology, 27(15), pp.3042-3049.
    '''
    def init(w0=None, taps=19, dims=2, unitarize=False):
        if w0 is None:
            # Identity MIMO filter sized by `dims` (was hard-coded to 2x2,
            # which broke the diagonal assignment for any other dims).
            w0 = np.zeros((dims, dims, taps), dtype=np.complex64)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        elif unitarize:
            try:
                w0 = unitarize_mimo_weights(w0)
            except Exception:
                # Best effort: keep the weights as given if unitarization fails.
                pass
        return w0

    def update(w, inp):
        u, Rx, train = inp

        def loss_fn(w, u):
            v = mimo(w, u)[None, :]
            # Target squared radius: Rx^2 in training mode, otherwise the
            # radius in Rs closest to |v|.
            R2 = jnp.where(train,
                           Rx**2,
                           Rs[jnp.argmin(
                               jnp.abs(Rs[:, None] * v / jnp.abs(v) - v),
                               axis=0)]**2)
            l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0, :])**2))
            return l

        l, g = jax.value_and_grad(loss_fn)(w, u)
        out = (l, w)
        # Gradient step; conj() because w is complex.
        w = w - lr * g.conj()
        return w, out

    def static_map(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, static_map)
def get_theta_grid(self):
    """Build the coarse/fine theta grids and store them in ``self.theta_grid``.

    Coarse grid: 11 points on [0.01, 0.99]. Fine grid: 121 points on the same
    interval merged with the coarse points (sorted, duplicates removed). A
    ``VecOnGrid`` mapping coarse onto fine is stored under ``'v_theta'``.

    NOTE(review): the fine-grid key is spelled ``'theta_gird_fine'``; it is a
    runtime key other code may read, so it is left unchanged.
    """
    theta_min, theta_max = 0.01, 0.99  # preliminary
    n_coarse, n_fine = 11, 121

    tg = self.theta_grid = dict()
    coarse = np.linspace(theta_min, theta_max, n_coarse)
    tg['theta_grid_coarse'] = coarse
    tg['ntheta_coarse'] = n_coarse

    fine = np.unique(np.concatenate((np.linspace(theta_min, theta_max, n_fine), coarse)))
    tg['theta_gird_fine'] = fine
    tg['ntheta_fine'] = fine.size

    tg['v_theta'] = VecOnGrid(coarse, fine)
def compute_fun(R, neighbor, **kwargs):
    """Angular descriptors from a neighbor list, one block per species pair.

    ``displacement``, ``species`` and ``_all_pairs_angular`` come from the
    enclosing scope. Padding slots (``idx >= len(R)``) are masked out; for
    every species pair ``(i, j)`` with ``i <= j`` the angular terms whose two
    neighbors have those species are summed over the two neighbor axes.
    """
    D_fn = space.map_neighbor(partial(displacement, **kwargs))
    idx = neighbor.idx
    R_neigh = R[idx]
    species_neigh = species[idx]

    atom_types = np.unique(species)
    real_neighbor = idx < len(R)
    type_masks = [np.logical_and(real_neighbor, species_neigh == t) for t in atom_types]

    dR = D_fn(R, R_neigh)
    all_angular = _all_pairs_angular(dR, dR)

    out = []
    for i, first_mask in enumerate(type_masks):
        mask_i = first_mask[:, :, np.newaxis, np.newaxis]
        for second_mask in type_masks[i:]:
            mask_j = second_mask[:, np.newaxis, :, np.newaxis]
            out.append(np.sum(all_angular * mask_i * mask_j, axis=[1, 2]))
    return np.hstack(out)
def restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge, options, deriv_order=0, return_aux_data=False):
    """Closed-shell (restricted) Hartree-Fock SCF energy.

    Args:
        geom: nuclear coordinates; reshaped to (-1, 3) for the nuclear repulsion.
        basis_name, xyz_path, deriv_order: forwarded to ``compute_integrals``.
        nuclear_charges: nuclear charges; their sum minus ``charge`` gives the
            electron count (``nelectrons // 2`` doubly occupied orbitals).
        charge: total molecular charge.
        options: dict providing 'maxit', 'damping', 'damp_factor' and
            'spectral_shift' (also forwarded to ``compute_integrals``).
        return_aux_data: also return orbital data (and jit the JK build).

    Returns:
        ``E_scf``, or ``(E_scf, C, eps, G)`` when ``return_aux_data`` is True.
    """
    # Load keyword options
    maxit = options['maxit']
    damping = options['damping']
    damp_factor = options['damp_factor']
    spectral_shift = options['spectral_shift']
    convergence = 1e-10
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    ndocc = nelectrons // 2

    # Contract the last two axes of each (i, j) slice of G with D.
    jk_build = jax.vmap(jax.vmap(lambda x, y: jnp.tensordot(x, y, axes=[(0, 1), (0, 1)]),
                                 in_axes=(0, None)), in_axes=(0, None))
    # If we are doing MP2 or CCSD after, might as well use jit-compiled JK-build,
    # since HF will not be memory bottleneck.
    if return_aux_data:
        jk_build = jax.jit(jk_build)

    # Canonical orthogonalization via cholesky decomposition
    S, T, V, G = compute_integrals(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order, options)
    A = cholesky_orthogonalization(S)
    nbf = S.shape[0]

    # For slightly shifting eigenspectrum of transformed Fock for degenerate eigenvalues
    # (JAX cannot differentiate degenerate eigenvalue eigh)
    if spectral_shift:
        # Shifting eigenspectrum requires lower convergence.
        convergence = 1e-8
        fudge = jnp.asarray(np.linspace(0, 1, nbf)) * convergence
        shift = jnp.diag(fudge)
    else:
        shift = jnp.zeros_like(S)

    H = T + V
    Enuc = nuclear_repulsion(geom.reshape(-1, 3), nuclear_charges)
    D = jnp.zeros_like(H)

    def rhf_iter(F, D):
        E_scf = jnp.einsum('pq,pq->', F + H, D) + Enuc
        Fp = jnp.dot(A.T, jnp.dot(F, A))
        Fp = Fp + shift
        eps, C2 = jnp.linalg.eigh(Fp)
        C = jnp.dot(A, C2)
        Cocc = C[:, :ndocc]
        D = jnp.dot(Cocc, Cocc.T)
        return E_scf, D, C, eps

    iteration = 0
    E_scf = 1.0
    E_old = 0.0
    Dold = jnp.zeros_like(D)
    dRMS = 1.0

    # Converge according to energy and DIIS residual to ensure eigenvalues and eigenvectors are maximally converged.
    # This is crucial for numerical stability for higher order derivatives of correlated methods.
    while (abs(E_scf - E_old) > convergence) or (dRMS > convergence):
        E_old = E_scf * 1
        if damping and iteration < 10:
            # Linear mixing: keep `damp_factor` of the previous density. The
            # weights must sum to 1 (the old code used damp_factor twice).
            D = Dold * damp_factor + D * (1 - damp_factor)
            Dold = D * 1
        # Build JK matrix: 2 * J - K
        JK = 2 * jk_build(G, D)
        JK -= jk_build(G.transpose((0, 2, 1, 3)), D)
        # Build Fock
        F = H + JK
        # Update convergence error
        if iteration > 1:
            diis_e = jnp.einsum('ij,jk,kl->il', F, D, S) - jnp.einsum('ij,jk,kl->il', S, D, F)
            # Same A^T X A transform that rhf_iter applies to F.
            diis_e = A.T.dot(diis_e).dot(A)
            dRMS = jnp.mean(diis_e**2)**0.5
        # Compute energy, transform Fock and diagonalize, get new density
        E_scf, D, C, eps = rhf_iter(F, D)
        iteration += 1
        if iteration == maxit:
            break
    print(iteration, " RHF iterations performed")

    # If many orbitals are degenerate, warn that higher order derivatives may be unstable
    tmp = jnp.round(eps, 6)
    ndegen_orbs = tmp.shape[0] - jnp.unique(tmp).shape[0]
    if (ndegen_orbs / nbf) > 0.20:
        print("Hartree-Fock warning: More than 20% of orbitals have degeneracies. Higher order derivatives may be unstable due to eigendecomposition AD rule")
    if not return_aux_data:
        return E_scf
    return E_scf, C, eps, G