def partial_trace(A, A_label):
    """Partial trace of tensor ``A`` over every label that appears twice.

    Indices of ``A`` that share a label in ``A_label`` are contracted (traced)
    pairwise; the remaining (free) indices keep their original relative order.

    Args:
        A: tensor to trace.
        A_label: array with one label per index of ``A``. Assumes each repeated
            label appears exactly twice (a label appearing three or more times
            is not supported -- TODO confirm with callers).

    Returns:
        ``(B, B_label, cont_label)``: the traced tensor, the labels of its
        remaining indices, and the sorted unique labels that were contracted.
        When no label repeats, ``A`` and ``A_label`` are returned unchanged
        together with an empty list.
    """
    num_cont = len(A_label) - len(np.unique(A_label))
    if num_cont <= 0:
        return A, A_label, []

    # Positions of each repeated label, one [first, second] pair per label.
    dup_list = []
    for ele in np.unique(A_label):
        if np.sum(A_label == ele) > 1:
            dup_list.append([np.where(A_label == ele)[0]])

    # Fortran order lists all first occurrences, then all second occurrences,
    # so cont_ind[:num_cont] and cont_ind[num_cont:] are paired entry by entry.
    cont_ind = np.array(dup_list).reshape(2 * num_cont, order='F')
    free_ind = onp.delete(np.arange(len(A_label)), cont_ind)

    cont_dim = np.prod(np.array(A.shape)[cont_ind[:num_cont]])
    free_dim = np.array(A.shape)[free_ind]

    B_label = onp.delete(A_label, cont_ind)
    cont_label = np.unique(A_label[cont_ind])

    # Fuse the free indices into one axis and each half of the contracted
    # pairs into one axis each, then trace the two contracted axes in a single
    # call instead of looping over cont_dim in Python.
    A = A.transpose(np.append(free_ind, cont_ind)).reshape(
        np.prod(free_dim), cont_dim, cont_dim)
    B = np.trace(A, axis1=1, axis2=2)

    return B.reshape(free_dim), B_label, cont_label
Пример #2
0
def rde(lr: Union[float, Schedule] = 2**-15,
        train: Union[bool, Schedule] = False,
        Rs: Array = jnp.unique(jnp.abs(comm.const("16QAM", norm=True))),
        const: Optional[Array] = None) -> AdaptiveFilter:
    """Radius Directed adaptive Equalizer

    Args:
      lr: learning rate. scalar or Schedule
      train: schedule training mode, which can be a bool for global control within one call
        or an array of bool to switch training on iteration basis
      Rs: the radii of the target constellation; ignored when ``const`` is given
      const: Optional; constellation whose distinct magnitudes replace ``Rs``

    Returns:
      an ``AdaptiveFilter`` object built from ``init``, ``update`` and ``apply``

    References:
      - [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
        carrier phase recovery in a 16-QAM optical coherent system. Journal
        of lightwave technology, 27(15), pp.3042-3049.
    """
    lr = cxopt.make_schedule(lr)
    train = cxopt.make_schedule(train)

    if const is not None:
        Rs = jnp.array(jnp.unique(jnp.abs(const)))

    def init(dims=2, w0=None, taps=32, dtype=np.complex64):
        # Unless weights are given, start from an identity MIMO filter: a
        # single unit tap at the centre of every diagonal path.
        if w0 is None:
            w0 = np.zeros((dims, dims, taps), dtype=dtype)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        return w0

    def loss_fn(w, u, x, i):
        v = r2c(mimo(w, u)[None, :])
        # Target squared radius: |x|^2 in training mode, otherwise the squared
        # radius in Rs closest to |v| (|R * v/|v| - v| equals |R - |v||).
        R2 = jnp.where(
            train(i),
            jnp.abs(x)**2,
            Rs[jnp.argmin(jnp.abs(Rs[:, None] * v / jnp.abs(v) - v),
                          axis=0)]**2)
        l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0, :])**2))
        return l

    def update(i, w, inp):
        u, x = inp
        l, g = jax.value_and_grad(loss_fn)(w, u, x, i)
        out = (w, l)  # weights before this step, and the loss they produced
        w = w - lr(i) * g.conj()
        return w, out

    def apply(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, apply)
Пример #3
0
    def bdd_message_func(self, edges):
        """Message function for the block-diagonal-decomposition regularizer.

        Each relation weight ``self.weight[etype]`` is reshaped into
        ``num_bases`` blocks of shape ``(submat_in, submat_out)``; the source
        feature of every edge is split into ``num_bases`` chunks of length
        ``submat_in`` and each chunk is multiplied by its own block.

        Args:
            edges: edge batch exposing ``src['h']`` (source features),
                ``data['type']`` (relation id per edge) and optionally
                ``data['norm']`` (per-edge scaling multiplied into the result).

        Returns:
            dict with key ``'msg'`` holding an array of shape
            ``(num_edges, out_feat)``.

        Raises:
            TypeError: if the source features are a 1-D int64 array (integer
                ID features), which this decomposition does not support.
        """
        h = edges.src['h']
        if h.dtype == jnp.int64 and len(h.shape) == 1:
            raise TypeError('Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before put msg into edge
        if self.low_mem:
            # One relation type at a time, so no per-edge weight tensor is built.
            etypes = jnp.unique(edges.data['type'])
            msg = jnp.zeros((h.shape[0], self.out_feat))
            for etype in etypes:
                loc = edges.data['type'] == etype
                w = self.weight[etype].reshape((self.num_bases, self.submat_in, self.submat_out))
                src = h[loc].reshape((-1, self.num_bases, self.submat_in))
                sub_msg = jnp.einsum('abc,bcd->abd', src, w)
                sub_msg = sub_msg.reshape((-1, self.out_feat))
                # ``jax.ops.index_update`` was removed from JAX; ``.at[].set``
                # is the functional-update replacement.
                msg = msg.at[loc].set(sub_msg)
        else:
            # Gather one weight row per edge and split it into its blocks.
            weight = jnp.take(
                self.weight,
                edges.data['type'],
                0,
            ).reshape(
                (-1, self.submat_in, self.submat_out),
            )

            node = h.reshape((-1, 1, self.submat_in))
            msg = jax.lax.batch_matmul(node, weight).reshape((-1, self.out_feat))
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
Пример #4
0
    def basis_message_func(self, edges):
        """Message function for the basis regularizer.

        When ``num_bases < num_rels`` the per-relation weights are rebuilt as
        linear combinations of the bases (``w_comp @ bases``, reshaped to
        ``(num_rels, in_feat, out_feat)``); otherwise ``self.weight`` is used
        directly.

        Args:
            edges: edge batch exposing ``src['h']`` (source features),
                ``data['type']`` (relation id per edge) and optionally
                ``data['norm']`` (per-edge scaling multiplied into the result).

        Returns:
            dict with key ``'msg'`` holding the per-edge messages.
        """
        if self.num_bases < self.num_rels:
            # generate all weights from bases
            bases = self.weight.reshape((self.num_bases,
                                         self.in_feat * self.out_feat))
            weight = jnp.matmul(self.w_comp, bases).reshape((
                self.num_rels, self.in_feat, self.out_feat))
        else:
            weight = self.weight

        h = edges.src['h']
        # calculate msg @ W_r before put msg into edge
        # if src is jnp.int64 we expect it is an index select
        if h.dtype != jnp.int64 and self.low_mem:
            etypes = jnp.unique(edges.data['type'])
            msg = jnp.zeros((h.shape[0], self.out_feat))
            for etype in etypes:
                loc = edges.data['type'] == etype
                sub_msg = jnp.matmul(h[loc], weight[etype])
                # ``jax.ops.index_update`` was removed from JAX; ``.at[].set``
                # is the functional-update replacement.
                msg = msg.at[loc].set(sub_msg)
        else:
            # put W_r into edges then do msg @ W_r
            msg = utils.bmm_maybe_select(h, weight, edges.data['type'])

        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
Пример #5
0
def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
    """The pre-to-post synaptic mean computation.

    Parameters
    ----------
    pre_values: float, jax.numpy.ndarray, JaxArray, Variable
      The pre-synaptic values. A scalar is used for every connection.
    post_num: int
      Output dimension. The number of post-synaptic neurons.
    post_ids: jax.numpy.ndarray, JaxArray
      The connected post-synaptic neuron ids.
    pre_ids: jax.numpy.ndarray, JaxArray, optional
      The connected pre-synaptic neuron ids. Needed when ``pre_values`` is not
      a scalar (presumably enforced by ``_raise_pre_ids_is_none``).

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
      The value with the size of post-synaptic neurons. For a scalar
      ``pre_values`` every neuron listed in ``post_ids`` receives that value
      (the mean of identical values) and all other entries stay 0.
    """
    out = jnp.zeros(post_num, dtype=profile.float_)
    pre_values = as_device_array(pre_values)
    post_ids = as_device_array(post_ids)
    if jnp.ndim(pre_values) == 0:
        # Setting the same value at duplicate indices is idempotent, so no
        # jnp.unique is needed; unique has a data-dependent output size and
        # cannot be traced under jit.
        return out.at[post_ids].set(pre_values)
    _raise_pre_ids_is_none(pre_ids)
    pre_ids = as_device_array(pre_ids)
    pre_values = pre2syn(pre_values, pre_ids)
    return syn2post_mean(pre_values, post_ids, post_num)
Пример #6
0
    def sum_from_unique(
            cls,
            input: np.array,
            mean: bool = True) -> Tuple[np.array, np.array, "SparseReduce"]:
        """Group the positions of ``input`` by value, bucketed by group size.

        Args:
            input: array of values to group.
            mean: forwarded unchanged to ``SparseReduce``.

        Returns:
            ``(unique_values, counts, reducer)``. The unique values and their
            counts are ordered by increasing group size; ``reducer`` is a
            ``SparseReduce`` built from one 2-D index array per distinct group
            size, where each row lists the positions of one unique value.
        """
        values, counts = np.unique(input, return_counts=True)
        positions = [np.argwhere(input == value).flatten() for value in values]
        sizes = np.array([pos.size for pos in positions])
        order = np.argsort(sizes)
        sorted_sizes = sizes[order]
        sorted_positions = [positions[k] for k in order]

        # Boundaries between runs of equal group size in the sorted order,
        # framed by 0 and the total number of groups.
        bounds = np.concatenate((
            [0],
            np.flatnonzero(sorted_sizes[:-1] - sorted_sizes[1:] != 0) + 1,
            [len(sizes)],
        ))

        buckets = [
            np.array([sorted_positions[k] for k in range(lo, hi)])
            for lo, hi in zip(bounds[:-1], bounds[1:])
        ]
        return values[order], counts[order], SparseReduce(buckets, mean)
Пример #7
0
def _log_compare(mat, cats, significance_test=scipy.stats.ttest_ind):
    """Calculates pairwise log ratios between all features and performs a
    significiance test (i.e. t-test) to determine if there is a significant
    difference in feature ratios with respect to the variable of interest.

    Parameters
    ----------
    mat: np.array
       rows correspond to samples and columns correspond to
       features (i.e. OTUs)
    cats: np.array, float
       Vector of categories
    significance_test: function
        statistical test to run

    Returns:
    --------
    log_ratio : np.array
        log ratio pvalue matrix
    """
    r, c = mat.shape
    log_ratio = np.zeros((c, c))
    log_mat = np.log(mat)
    cs = np.unique(cats)

    def func(x):
        return significance_test(*[x[cats == k] for k in cs])

    for i in range(c - 1):
        ratio = (log_mat[:, i].T - log_mat[:, i + 1:].T).T
        m, p = np.apply_along_axis(func, axis=0, arr=ratio)
        log_ratio[i, i + 1:] = np.squeeze(np.array(p.T))
    return log_ratio
Пример #8
0
 def compute_fun(R, **kwargs):
   """Angular descriptors for every unordered pair of species types.

   For species types ``t_i <= t_j`` (in ``np.unique`` order) the angular terms
   between displacements from type-``t_i`` atoms and type-``t_j`` atoms are
   summed over axes 1 and 2, and all blocks are stacked horizontally.

   NOTE(review): ``displacement``, ``species``, ``space`` and
   ``_all_pairs_angular`` come from the enclosing scope (not visible here).
   """
   pair_displacement = space.map_product(partial(displacement, **kwargs))
   atom_types = np.unique(species)
   per_type = [pair_displacement(R[species == t, :], R) for t in atom_types]

   blocks = []
   for i, d_i in enumerate(per_type):
     for d_j in per_type[i:]:
       blocks.append(np.sum(_all_pairs_angular(d_i, d_j), axis=[1, 2]))
   return np.hstack(blocks)
Пример #9
0
def prob_inf_house_size_iter(state, hh_sizes_, house_dist):
  """Probability of infection given household size, averaged over iterations.

  Per iteration, Bayes' rule gives
    P(inf | size) = P(size | inf) * P(inf) / P(size)
                  = (#infected of that size / #infected) * (#infected / pop) / house_dist
                  = #infected of that size / (pop * house_dist).

  @param state : state of each individual at the end of each iteration of the
    simulation; entries > 0 count as infected
  @type : array of shape (# of iterations, population size)
  @param hh_sizes_ : size of each individual's household
  @type : array-like of length = population size
  @param house_dist : distribution of household sizes; entry k is taken to
    correspond to the k-th smallest household size present in ``hh_sizes_``
    (NOTE(review): confirm this alignment with callers)
  @type : list or 1D array
  @return : (probability of infection per household size averaged over
    iterations, mean probability of infection)
  @type : tuple
  """
  hh_sizes = np.asarray(hh_sizes_)
  # Convert so that 1/house_dist-style arithmetic also works for plain lists.
  house_dist = np.asarray(house_dist)
  # Map every individual to the position of its household size among the
  # sorted distinct sizes, so counts stay aligned with house_dist even when a
  # size has no infected members in some iteration.
  sizes = np.unique(hh_sizes)
  size_idx = np.searchsorted(sizes, hh_sizes)
  pop = len(state[0])

  prob_rows = []
  inf_probs = []
  for iteration_state in state:
    infected = np.asarray(iteration_state) > 0
    inf_counts = np.bincount(size_idx, weights=infected.astype(float),
                             minlength=len(sizes))
    prob_rows.append(inf_counts / (pop * house_dist))  # Bayes rule (simplified)
    inf_probs.append(np.sum(infected) / pop)

  prob_hh_size = np.stack(prob_rows) if prob_rows else np.zeros((0, len(house_dist)))
  # Returns the probability of infection given household size
  return np.average(prob_hh_size, axis=0), np.average(np.array(inf_probs))
Пример #10
0
def arithmetic_encoding_num_bits(v: jnp.ndarray) -> int:
  """Estimate the number of bits needed to store ``v`` via arithmetic coding.

  ``v`` is passed through ``jnp.nan_to_num`` and flattened. The estimate is the
  histogram cost (``_hist_bits``) plus ``v.size`` symbols at the entropy given
  by ``_entropy``, plus a fixed overhead of ``2 * 32 + 2`` bits.
  """
  flat = jnp.nan_to_num(v).flatten()
  symbols = jnp.unique(flat)
  per_symbol_bits = _entropy(flat, symbols)
  header_bits = _hist_bits(flat, symbols)
  return header_bits + (flat.size * per_symbol_bits) + (2 * 32) + 2
Пример #11
0
def unique(x, return_index=False, return_inverse=False,
           return_counts=False, axis=None):
  """``jnp.unique`` for ``JaxArray`` (or raw array) inputs.

  Returns a ``JaxArray`` of the unique values. When any of ``return_index``,
  ``return_inverse`` or ``return_counts`` is set, ``jnp.unique`` returns a
  tuple; each element of that tuple is wrapped in its own ``JaxArray``
  instead of wrapping the tuple itself.
  """
  if isinstance(x, JaxArray):
    x = x.value
  res = jnp.unique(x,
                   return_index=return_index,
                   return_inverse=return_inverse,
                   return_counts=return_counts,
                   axis=axis)
  if isinstance(res, tuple):
    return tuple(JaxArray(r) for r in res)
  return JaxArray(res)
Пример #12
0
def pair_correlation(displacement_or_metric: Union[DisplacementFn, MetricFn],
                     radii: Array,
                     sigma: float,
                     species: Array = None):
    r"""Computes the pair correlation function at a mesh of distances.

    The pair correlation function measures the number of particles at a given
    distance from a central particle. It is defined by
    $g(r) = \langle \sum_{i \neq j} \delta(r - |r_i - r_j|) \rangle$. We make
    the Gaussian approximation
    $\delta(r) \approx \frac{1}{\sqrt{2\pi\sigma^2}} e^{-r^2 / (2\sigma^2)}$;
    the implementation omits the constant prefactor and divides each
    contribution by $r^{d-1}$, where $d$ is the spatial dimension of ``R``.

    Args:
      displacement_or_metric: A function that computes the displacement or
        distance between two points.
      radii: An array of radii at which we would like to compute g(r).
      sigma: A float specifying the width of the approximating Gaussian.
      species: An optional integer array specifying the species of each
        particle. If species is None then we compute a single g(r) for all
        particles, otherwise we compute one g(r) for each species.

    Returns:
      A function `g_fn` that computes the pair correlation function for a
      collection of particles. With ``species`` given, `g_fn` returns a list
      with one entry per species type (in ``jnp.unique`` order).

    Raises:
      TypeError: if ``species`` is given but is not an integer ``jnp.ndarray``.
    """
    d = space.canonicalize_displacement_or_metric(displacement_or_metric)
    d = space.map_product(d)

    def pairwise(dr, dim):
        return jnp.exp(-f32(0.5) *
                       (dr - radii)**2 / sigma**2) / radii**(dim - 1)

    pairwise = vmap(vmap(pairwise, (0, None)), (0, None))

    if species is None:

        def g_fn(R):
            dim = R.shape[-1]
            # Exclude self-pairs (i == j).
            mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
            return jnp.sum(mask[:, :, jnp.newaxis] * pairwise(d(R, R), dim),
                           axis=(1, ))
    else:
        if not (isinstance(species, jnp.ndarray) and is_integer(species)):
            raise TypeError('Malformed species; expecting array of integers.')
        species_types = jnp.unique(species)

        def g_fn(R):
            dim = R.shape[-1]
            g_R = []
            # Exclude self-pairs (i == j).
            mask = 1 - jnp.eye(R.shape[0], dtype=R.dtype)
            for s in species_types:
                Rs = R[species == s]
                mask_s = mask[:, species == s, jnp.newaxis]
                g_R += [jnp.sum(mask_s * pairwise(d(Rs, R), dim), axis=(1, ))]
            return g_R

    return g_fn
Пример #13
0
    def heavy_atoms(self):
        """Return the number of non-hydrogen atoms in ``self.Z``.

        Entries equal to 1 are counted as hydrogens; every other entry of
        ``self.Z`` counts as a heavy atom. When no hydrogens are present, a
        notice naming ``self.filename`` is printed.
        """
        # Count hydrogens directly: a dict keyed by the entries of a jax array
        # (as returned by jnp.unique) cannot be built, since jax arrays are
        # unhashable.
        num_hydrogens = int(jnp.sum(jnp.asarray(self.Z) == 1))
        if num_hydrogens == 0:
            print("In file %s no hydrogens were reported" % self.filename)
        heavy_atoms = self.Z.size - num_hydrogens
        return heavy_atoms
Пример #14
0
def check_inputs(connect_list, flat_connect, dims_list, cont_order):
    """Check the consistency of NCON inputs.

    Args:
        connect_list: one sequence of index labels per tensor.
        flat_connect: numpy array of the labels of all tensors, presumably the
            concatenation of ``connect_list`` (it is aligned with the flattened
            ``dims_list``). Positive labels are contracted, negative are open.
        dims_list: one sequence of dimensions per tensor.
        cont_order: order in which the positive labels are contracted.

    Returns:
        True if every check passes.

    Raises:
        ValueError: describing the first inconsistency found.
    """
    positive = flat_connect[flat_connect > 0]
    negative = flat_connect[flat_connect < 0]

    # one index sublist per tensor
    num_tensors, num_sublists = len(dims_list), len(connect_list)
    if num_tensors != num_sublists:
        raise ValueError(('NCON error: %i tensors given but %i index sublists given')
            %(num_tensors, num_sublists))

    # one label per tensor index
    for tensor_no, (dims, labels) in enumerate(zip(dims_list, connect_list)):
        if len(dims) != len(labels):
            raise ValueError(('NCON error: number of indices does not match number of labels on tensor %i: '
                              '%i-indices versus %i-labels')%(tensor_no, len(dims), len(labels)))

    # the contraction order must be a permutation of the positive labels
    if not np.array_equal(np.sort(cont_order), np.unique(positive)):
        raise ValueError(('NCON error: invalid contraction order'))

    # open labels must be exactly -1, -2, ..., -len(negative), each used once
    for label in np.arange(-1, -len(negative) - 1, -1):
        hits = np.count_nonzero(negative == label)
        if hits == 0:
            raise ValueError(('NCON error: no index labelled %i') %(label))
        if hits > 1:
            raise ValueError(('NCON error: more than one index labelled %i')%(label))

    # every contracted label appears exactly twice, on indices of equal size
    flat_dims = np.array([dim for dims in dims_list for dim in dims])
    for label in np.unique(positive):
        hits = np.count_nonzero(positive == label)
        if hits == 1:
            raise ValueError(('NCON error: only one index labelled %i')%(label))
        if hits > 2:
            raise ValueError(('NCON error: more than two indices labelled %i')%(label))

        first, second = flat_dims[flat_connect == label]
        if first != second:
            raise ValueError(('NCON error: tensor dimension mismatch on index labelled %i: '
                              'dim-%i versus dim-%i')%(label, first, second))

    return True
Пример #15
0
 def sum_from_unique(
         cls,
         input: Array,
         mean: bool = True) -> Tuple[np.array, np.array, "LinearReduce"]:
     """Build a linear reduction that sums (or averages) equal-valued entries.

     Args:
         input: 1-D array of values to group.
         mean: if True, each row of the reduction matrix averages its group;
             otherwise it sums the group.

     Returns:
         ``(unique_values, counts, LinearReduce(m))`` where ``m`` has shape
         ``(len(unique_values), len(input))`` and ``m[i, j]`` is
         ``1 / counts[i]`` (or 1 when ``mean`` is False) if
         ``input[j] == unique_values[i]`` and 0 otherwise.
     """
     un, cts = np.unique(input, return_counts=True)
     # One broadcasted comparison replaces a Python loop of per-group scatter
     # updates into a zero matrix.
     matches = np.asarray(input)[None, :] == un[:, None]
     # Same default float dtype a np.zeros matrix would have.
     m = matches.astype(np.zeros(()).dtype)
     if mean:
         m = m / cts[:, None]
     return un, cts, LinearReduce(m)
Пример #16
0
def get_group_zellner(groups, X, isgmom=False):
    """Block-diagonal group Zellner prior matrices.

    Note that V=(XtX)^-1 and Vinv=XtX, up to a per-group scaling: for every
    distinct label ``g`` in ``groups`` with ``p_g`` columns, the diagonal block
    ``Vinv[g, g] = X_g^T X_g * n / q_g`` is written, where ``q_g = p_g`` if
    ``isgmom`` else ``p_g + 2``, and ``V[g, g]`` holds its inverse. Entries
    between different groups are zero.

    Args:
        groups: group label of each column of ``X`` (length ``p``).
        X: design matrix of shape ``(n, p)``.
        isgmom: selects ``q_g = p_g`` (True) or ``q_g = p_g + 2`` (False).

    Returns:
        ``(V, Vinv)``, both of shape ``(p, p)``.
    """
    n, p = X.shape
    Vinv = jnp.zeros((p, p))
    V = jnp.zeros((p, p))
    columns = jnp.arange(p)
    for group, p_j in zip(*jnp.unique(groups, return_counts=True)):
        idx = columns[groups == group]
        X_j = X[:, idx]
        # Element-wise select instead of the legacy per-branch lax.cond call;
        # both branches are trivial, so nothing is gained by branching.
        p_term = jnp.where(isgmom, p_j, p_j + 2)
        aux = jnp.dot(X_j.T, X_j) * n / p_term
        block = jnp.ix_(idx, idx)
        Vinv = Vinv.at[block].set(aux)
        V = V.at[block].set(jnp.linalg.inv(aux))
    return V, Vinv
Пример #17
0
def random_adjacency(key: jnp.ndarray,
                     num_nodes: int,
                     num_edges: int,
                     dtype=jnp.float32) -> COO:
    """
    Get the adjacency matrix of a random connected undirected graph.

    Note that `num_edges` is only approximate. The process of creating edges is:
    - sample `num_edges` random edges
    - remove self-edges
    - add ring edges (guaranteeing connectivity)
    - add reverse edges
    - filter duplicates

    Args:
        key: `jax.random.PRNGKey`.
        num_nodes: number of nodes in returned graph.
        num_edges: number of random internal edges initially added.
        dtype: dtype of returned JAXSparse.

    Returns:
        COO, shape (num_nodes, num_nodes), weights all ones.
    """
    shape = num_nodes, num_nodes

    # Sample flat indices (row * num_nodes + col) uniformly from
    # [0, num_nodes**2). randint draws integers exactly; a float32 uniform
    # followed by a cast cannot represent every index once num_nodes**2
    # exceeds 2**24 and can round up to the excluded upper bound.
    internal_indices = jax.random.randint(
        key,
        shape=(num_edges, ),
        minval=0,
        maxval=num_nodes**2,
        dtype=jnp.int32,
    )
    # remove randomly sampled self-edges.
    self_edges = (internal_indices // num_nodes) == (internal_indices %
                                                     num_nodes)
    internal_indices = internal_indices[jnp.logical_not(self_edges)]

    # add a ring so we know the graph is connected
    r = jnp.arange(num_nodes, dtype=jnp.int32)
    ring_indices = r * num_nodes + (r + 1) % num_nodes
    indices = jnp.concatenate((internal_indices, ring_indices))

    # add reverse indices so the graph is undirected
    coords = jnp.unravel_index(indices, shape)
    coords_rev = coords[-1::-1]
    indices_rev = jnp.ravel_multi_index(coords_rev, shape)
    indices = jnp.concatenate((indices, indices_rev))

    # filter out duplicates (jnp.unique also sorts, giving row-major order)
    indices = jnp.unique(indices)
    row, col = jnp.unravel_index(indices, shape)
    return COO((jnp.ones((row.size, ), dtype=dtype), row, col), shape=shape)
Пример #18
0
def create_spatiotemporal_grid(X, Y):
    """
    Create a grid of data sized [T, R, num_spatial_dims].

    Column 0 of ``X`` is time and the remaining columns are spatial
    coordinates (any number >= 1). Every combination of a unique time and a
    unique spatial location becomes a grid cell; cells with no observation
    get ``Y = nan``. Full duplicates (rows of ``X`` matching in all
    dimensions) are removed, keeping the first occurrence.

    Args:
        X: inputs of shape (N, 1 + num_spatial_dims).
        Y: outputs of shape (N,) or (N, 1).

    Returns:
        ``(t, R_grid, Y_grid)`` with shapes (T, 1), (T, R, num_spatial_dims)
        and (T, R, 1), where T is the number of unique times and R the number
        of unique spatial locations.

    Raises:
        NotImplementedError: if ``X`` has no spatial columns.
    """
    if Y.ndim < 2:
        Y = Y[:, None]
    num_spatial_dims = X.shape[1] - 1
    if num_spatial_dims < 1:
        raise NotImplementedError
    # Sort rows by column 0, then 1, 2, ...; lexsort treats the LAST key as
    # the primary one, so the columns are passed in reverse order.
    sort_ind = nnp.lexsort(X.T[::-1])
    X = X[sort_ind]
    Y = Y[sort_ind]
    unique_time = np.unique(X[:, 0])
    unique_space = nnp.unique(X[:, 1:], axis=0)
    N_t = unique_time.shape[0]
    N_r = unique_space.shape[0]
    # Every unique location repeated once per unique time.
    R_flat = np.tile(unique_space, [N_t, 1]).reshape(-1, num_spatial_dims)
    Y_dummy = np.nan * np.zeros([N_t * N_r, 1])
    time_duplicate = np.tile(unique_time, [N_r, 1]).T.flatten()
    X_dummy = np.block([time_duplicate[:, None], R_flat])
    # Real rows come first, so return_index keeps a real observation over the
    # nan dummy whenever both exist for the same (time, location).
    X_all = np.vstack([X, X_dummy])
    Y_all = np.vstack([Y, Y_dummy])
    X_unique, ind = nnp.unique(X_all, axis=0, return_index=True)
    Y_unique = Y_all[ind]
    grid_shape = (unique_time.shape[0], ) + unique_space.shape
    R_grid = X_unique[:, 1:].reshape(grid_shape)
    Y_grid = Y_unique.reshape(grid_shape[:-1] + (1, ))
    return unique_time[:, None], R_grid, Y_grid
Пример #19
0
  def compute_fun(R, **kwargs):
    """Radial symmetry functions of every atom, one block per neighbour species.

    NOTE(review): ``metric``, ``species``, ``etas``, ``cutoff_distance``,
    ``space``, ``vmap`` and ``_behler_parrinello_cutoff_fn`` come from the
    enclosing scope (not visible here).
    """
    pair_metric = space.map_product(partial(metric, **kwargs))

    def radial_fn(eta, dr):
      return (np.exp(-eta * dr**2) *
              _behler_parrinello_cutoff_fn(dr, cutoff_distance))

    per_eta = vmap(radial_fn, (0, None))

    def return_radial(atom_type):
      """Returns the radial symmetry functions for neighbor type atom_type."""
      dr = pair_metric(R, R[species == atom_type, :])
      return np.sum(per_eta(etas, dr), axis=1).T

    return np.hstack([return_radial(t) for t in np.unique(species)])
Пример #20
0
def union_bad_ants(JDs):
    """Return all the bad antennas for the specified JDs

    Reads ``bad_ants_idr2.pkl`` from the directory containing this module and
    looks up each Julian Day in the mapping it holds.

    :param: Julian Days
    :type: ndarray, list

    :return: Sorted union of bad antennas for JDs (empty for no JDs)
    :rtype: ndarray

    A JD missing from the pickle raises ``KeyError``.
    """
    bad_ants_fn = os.path.join(os.path.dirname(__file__), 'bad_ants_idr2.pkl')
    # The pickle is loaded from the package directory, not from user input.
    with open(bad_ants_fn, 'rb') as f:
        bad_ants_dict = pickle.load(f)
    # Concatenate once instead of re-copying with np.append per JD
    # (quadratic); np.unique already returns its result sorted.
    per_day = [np.ravel(bad_ants_dict[JD]) for JD in JDs]
    bad_ants = np.concatenate([np.array([], dtype=int)] + per_day)
    return np.unique(bad_ants)
Пример #21
0
    def updateMeans(self, X, clusters, means):
        """Average the current means with the per-cluster means of ``X``.

        Rows of ``X`` are grouped by their cluster id, each group's column
        mean is taken, and the result is averaged element-wise with ``means``.

        NOTE(review): group means come out ordered by ascending cluster id,
        one per id actually present in ``clusters``; ``means`` presumably has
        one row per cluster in the same order (TODO confirm no cluster can be
        empty).
        """
        n_features = X.shape[1]
        labelled = jnp.hstack((X, clusters.reshape(clusters.shape[0], 1)))
        labelled = labelled[labelled[:, n_features].argsort()]

        # First-occurrence positions of each cluster id in the sorted rows
        # (except the first) are where one cluster's block ends.
        boundaries = jnp.unique(labelled[:, n_features], return_index=True)[1][1:]
        blocks = jnp.split(labelled[:, :n_features], boundaries)

        cluster_means = jnp.array([jnp.mean(block, axis=0) for block in blocks])
        return (means + cluster_means) / 2
Пример #22
0
def _load_dataset():
    """Load COVTYPE as standardised features plus binary labels.

    Each feature column is standardised and an intercept column of ones is
    appended. Labels become booleans: True where the label equals the most
    frequent label value, False otherwise.

    Returns:
        ``(features, labels)``.
    """
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()

    # normalize features and add intercept
    features = (features - features.mean(0)) / features.std(0)
    features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

    # make binary feature. argmax(counts) is a position in the sorted unique
    # label values, so map it back to the label value before comparing.
    unique_labels, counts = jnp.unique(labels, return_counts=True)
    specific_category = unique_labels[jnp.argmax(counts)]
    labels = labels == specific_category

    N = features.shape[0]
    print("Data shape:", features.shape)
    print("Label distribution: {} has label 1, {} has label 0".format(
        labels.sum(), N - labels.sum()))
    return features, labels
Пример #23
0
def plot_gmm_changepoints(ax, gmm_output, timesteps=None, changepoints=None):
    """Plot the observations, latent states and change points of a GMM run.

    Args:
        ax: sequence of axes (presumably matplotlib ``Axes``). ``ax[0]`` gets
            the observed series and ``ax[1]`` the latent states; change points
            are marked on every axis in ``ax``.
        gmm_output: mapping with ``"observed"`` and ``"latent"`` sequences.
        timesteps: x coordinates; defaults to ``0 .. T-1``.
        changepoints: x positions to mark with dotted vertical lines. Defaults
            to every timestep whose latent state differs from the previous one.

    NOTE(review): the y ticks ``0 .. n_states-1`` assume states are labelled
    with consecutive integers starting at 0.
    """
    X = gmm_output["observed"]
    states = gmm_output["latent"]
    n_states = len(jnp.unique(states))
    T = len(X)
    timesteps = jnp.arange(T) if timesteps is None else timesteps
    if changepoints is None:
        # A change point is a timestep whose latent state differs from the
        # previous timestep's state.
        latent = jnp.asarray(states)
        changepoints = jnp.asarray(timesteps)[1:][latent[1:] != latent[:-1]]

    ax[0].plot(timesteps, X, marker="o", markersize=3, linewidth=1, c="tab:gray")

    ax[1].scatter(timesteps, states, c="tab:gray")
    ax[1].set_yticks(jnp.arange(n_states))
    for y in range(n_states):
        ax[1].axhline(y=y, c="tab:gray", alpha=0.3)

    for changepoint in changepoints:
        for axi in ax:
            axi.axvline(x=changepoint, c="tab:red", linestyle="dotted")

    for axi in ax:
        axi.set_xlim(timesteps[0], timesteps[-1])
Пример #24
0
def glmm(dept, male, applications, admit=None):
    """Binomial model of ``admit`` out of ``applications`` with
    department-varying intercept and ``male`` slope.

    Per-department effects ``v`` (one intercept and one slope offset per
    department) are correlated through an LKJ prior and sampled with a
    non-centred parameterisation. When ``admit`` is None, the admission
    probabilities are also recorded in a ``'probs'`` Delta site for
    predictive use.
    """
    # Population-level intercept (index 0) and male slope (index 1).
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    # Scale and correlation of the department effects.
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho

    # non-centered parameterization: v = scale_tril @ z with z ~ Normal(0, 1)
    n_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((n_dept, 2)), 1))
    effects = jnp.dot(scale_tril, z.T).T

    intercept = v_mu[0] + effects[dept, 0]
    slope = v_mu[1] + effects[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit',
                   dist.Binomial(applications, logits=logits),
                   obs=admit)
Пример #25
0
    def compute_fun(R: Array, neighbor: NeighborList, **kwargs) -> Array:
        """Radial symmetry functions from a neighbor list, one block per species.

        For each species type (in ``np.unique`` order) the radial terms of
        every atom's neighbors of that type are summed; the blocks are stacked
        horizontally.

        NOTE(review): ``metric``, ``species``, ``etas``, ``cutoff_distance``,
        ``space``, ``vmap``, ``util`` and ``_behler_parrinello_cutoff_fn`` come
        from the enclosing scope (not visible here).
        """
        _metric = space.map_neighbor(partial(metric, **kwargs))

        def radial_fn(eta, dr):
            return (np.exp(-eta * dr**2) *
                    _behler_parrinello_cutoff_fn(dr, cutoff_distance))

        # None of these depend on the neighbor species being selected, so
        # compute them once instead of once per atom type.
        R_neigh = R[neighbor.idx]
        species_neigh = species[neighbor.idx]
        # Entries with idx >= len(R) are excluded (presumably list padding).
        in_range = neighbor.idx < R.shape[0]
        dr = _metric(R, R_neigh)
        radial = vmap(radial_fn, (0, None))(etas, dr)

        def return_radial(atom_type):
            """Returns the radial symmetry functions for neighbor type atom_type."""
            mask = np.logical_and(in_range, species_neigh == atom_type)
            return util.high_precision_sum(radial * mask[np.newaxis, :, :],
                                           axis=2).T

        return np.hstack(
            [return_radial(atom_type) for atom_type in np.unique(species)])
Пример #26
0
 def lib_as_grid(self):
     """Convert the library parameters to pixel indices in each dimension,
     and build and store a KDTree for the pixel coordinates.

     For every label in ``self.labels`` the sorted unique values of
     ``self.libparams[label]`` become that dimension's grid points
     (``self.gridpoints``) and their spacings are kept in ``self.binwidths``.
     Each library entry is then replaced by its grid index per dimension
     (``self.X``, one row per entry and one column per label, assuming each
     ``libparams[label]`` is 1-D) and a KDTree over ``self.X`` is stored in
     ``self._kdt``.
     """
     self.gridpoints = {}
     self.binwidths = {}
     for label in self.labels:
         grid = np.unique(self.libparams[label])
         self.gridpoints[label] = grid
         self.binwidths[label] = np.diff(grid)

     # With right=True and bins equal to the unique values, digitize returns
     # the position of each value within its grid.
     pixel_rows = [
         np.digitize(self.libparams[label], bins=self.gridpoints[label], right=True)
         for label in self.labels
     ]
     self.X = np.array(pixel_rows).T

     started = datetime.now()
     self._kdt = KDTree(self.X, leafsize=1000)
     print('built KDTree: {}'.format(datetime.now() - started))
Пример #27
0
def rde(lr=1e-4, Rs=jnp.unique(jnp.abs(comm.const("16QAM", norm=True)))):
    '''Radius Directed adaptive Equalizer.

    Args:
      lr: constant learning rate.
      Rs: radii of the target constellation.

    Returns:
      an ``AdaptiveFilter`` built from ``init``, ``update`` and ``static_map``.

    References:
    [1] Fatadin, I., Ives, D. and Savory, S.J., 2009. Blind equalization and
        carrier phase recovery in a 16-QAM optical coherent system. Journal
        of lightwave technology, 27(15), pp.3042-3049.
    '''
    def init(w0=None, taps=19, dims=2, unitarize=False):
        # Without w0: identity filter of shape (dims, dims, taps) with a unit
        # centre tap on each diagonal path. With w0 and unitarize: best-effort
        # unitarization of the given weights.
        if w0 is None:
            # Sized by ``dims``: a fixed (2, 2, taps) array made the centre-tap
            # assignment below fail for dims > 2.
            w0 = np.zeros((dims, dims, taps), dtype=np.complex64)
            ctap = (taps + 1) // 2 - 1
            w0[np.arange(dims), np.arange(dims), ctap] = 1.
        elif unitarize:
            try:
                w0 = unitarize_mimo_weights(w0)
            except Exception:
                # Keep the given weights if they cannot be unitarized, without
                # also swallowing KeyboardInterrupt/SystemExit.
                pass
        return w0

    def update(w, inp):
        u, Rx, train = inp

        def loss_fn(w, u):
            v = mimo(w, u)[None,:]
            # Target squared radius: Rx**2 when training, otherwise the squared
            # radius in Rs closest to |v| (|R * v/|v| - v| equals |R - |v||).
            R2 = jnp.where(train,
                           Rx**2,
                           Rs[jnp.argmin(
                               jnp.abs(Rs[:,None] * v / jnp.abs(v) - v),
                               axis=0)]**2)
            l = jnp.sum(jnp.abs(R2 - jnp.abs(v[0,:])**2))
            return l

        l, g = jax.value_and_grad(loss_fn)(w, u)
        out = (l, w)  # loss and the weights that produced it (pre-step)
        w = w - lr * g.conj()
        return w, out

    def static_map(ws, yf):
        return jax.vmap(mimo)(ws, yf)

    return AdaptiveFilter(init, update, static_map)
Пример #28
0
    def get_theta_grid(self):
        """Build the coarse and fine theta grids into ``self.theta_grid``.

        Keys written:
          - ``'theta_grid_coarse'``: 11 evenly spaced points on [0.01, 0.99];
            ``'ntheta_coarse'`` is its length.
          - ``'theta_grid_fine'``: sorted distinct values of a 121-point grid on
            the same interval joined with the coarse grid; ``'ntheta_fine'`` is
            its length. The same array is also stored under the legacy
            misspelled key ``'theta_gird_fine'`` for existing readers.
          - ``'v_theta'``: ``VecOnGrid(coarse grid, fine grid)``.
        """
        self.theta_grid = dict()
        tg = self.theta_grid
        ntheta = 11
        ntheta_fine = 121  # preliminary
        theta_min = 0.01
        theta_max = 0.99

        tg['theta_grid_coarse'] = np.linspace(theta_min, theta_max, ntheta)
        tg['ntheta_coarse'] = ntheta

        # np.unique both merges the two grids and sorts the result.
        tfine = np.unique(
            np.concatenate(
                (np.linspace(theta_min, theta_max,
                             ntheta_fine), tg['theta_grid_coarse'])))

        tg['theta_grid_fine'] = tfine
        # Legacy misspelled key, kept so existing readers keep working.
        tg['theta_gird_fine'] = tfine
        tg['ntheta_fine'] = tfine.size
        tg['v_theta'] = VecOnGrid(tg['theta_grid_coarse'],
                                  tg['theta_grid_fine'])
Пример #29
0
    def compute_fun(R, neighbor, **kwargs):
        """Angular symmetry functions from a neighbor list, one block per
        unordered pair of neighbor species types.

        NOTE(review): ``displacement``, ``species``, ``space`` and
        ``_all_pairs_angular`` come from the enclosing scope (not visible here).
        """
        pair_disp = space.map_neighbor(partial(displacement, **kwargs))

        neighbor_R = R[neighbor.idx]
        neighbor_species = species[neighbor.idx]
        atom_types = np.unique(species)

        # Entries with idx >= len(R) are excluded from every species mask.
        in_range = neighbor.idx < len(R)
        type_masks = [np.logical_and(in_range, neighbor_species == t)
                      for t in atom_types]

        dR = pair_disp(R, neighbor_R)
        angular = _all_pairs_angular(dR, dR)

        blocks = []
        for i, first_mask in enumerate(type_masks):
            left = first_mask[:, :, np.newaxis, np.newaxis]
            for second_mask in type_masks[i:]:
                right = second_mask[:, np.newaxis, :, np.newaxis]
                blocks.append(np.sum(angular * left * right, axis=[1, 2]))
        return np.hstack(blocks)
Пример #30
0
def restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge, options, deriv_order=0, return_aux_data=False):
    """Restricted Hartree-Fock SCF energy.

    Args:
        geom: nuclear coordinates; reshaped to ``(-1, 3)`` for the nuclear
            repulsion energy.
        basis_name, xyz_path: forwarded to ``compute_integrals``.
        nuclear_charges: charge of each nucleus.
        charge: total molecular charge. The electron count is
            ``sum(nuclear_charges) - charge``; half of it (rounded down)
            orbitals are doubly occupied.
        options: dict providing ``'maxit'``, ``'damping'``, ``'damp_factor'``
            and ``'spectral_shift'``.
        deriv_order: forwarded to ``compute_integrals``.
        return_aux_data: when True the J/K build is JIT-compiled and the MO
            coefficients, orbital energies and two-electron integrals are
            returned too.

    Returns:
        ``E_scf``, or ``(E_scf, C, eps, G)`` when ``return_aux_data`` is True.
    """
    # Load keyword options
    maxit = options['maxit']
    damping = options['damping']
    damp_factor = options['damp_factor']
    spectral_shift = options['spectral_shift']
    convergence = 1e-10

    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    ndocc = nelectrons // 2

    # jk_build(G, D)[p, q] = sum_{r, s} G[p, q, r, s] * D[r, s]
    jk_build = jax.vmap(jax.vmap(lambda x, y: jnp.tensordot(x, y, axes=[(0, 1), (0, 1)]),
                                 in_axes=(0, None)), in_axes=(0, None))
    # If we are doing MP2 or CCSD after, might as well use jit-compiled JK-build, since HF will not be memory bottleneck
    if return_aux_data:
        jk_build = jax.jit(jk_build)

    # Canonical orthogonalization via cholesky decomposition
    S, T, V, G = compute_integrals(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order, options)
    A = cholesky_orthogonalization(S)

    nbf = S.shape[0]

    # For slightly shifting eigenspectrum of transformed Fock for degenerate eigenvalues
    # (JAX cannot differentiate degenerate eigenvalue eigh)
    if spectral_shift:
        # Shifting eigenspectrum requires lower convergence.
        convergence = 1e-8
        fudge = jnp.asarray(np.linspace(0, 1, nbf)) * convergence
        shift = jnp.diag(fudge)
    else:
        shift = jnp.zeros_like(S)

    H = T + V
    Enuc = nuclear_repulsion(geom.reshape(-1,3),nuclear_charges)
    D = jnp.zeros_like(H)

    def rhf_iter(F,D):
        # Energy of the current density, then a new density from the
        # ndocc lowest eigenvectors of the orthogonalized Fock matrix.
        E_scf = jnp.einsum('pq,pq->', F + H, D) + Enuc
        Fp = jnp.dot(A.T, jnp.dot(F, A))
        Fp = Fp + shift
        eps, C2 = jnp.linalg.eigh(Fp)
        C = jnp.dot(A,C2)
        Cocc = C[:, :ndocc]
        D = jnp.dot(Cocc, Cocc.T)
        return E_scf, D, C, eps

    iteration = 0
    E_scf = 1.0
    E_old = 0.0
    Dold = jnp.zeros_like(D)
    dRMS = 1.0

    # Converge according to energy and DIIS residual to ensure eigenvalues and eigenvectors are maximally converged.
    # This is crucial for numerical stability for higher order derivatives of correlated methods.
    while ((abs(E_scf - E_old) > convergence) or (dRMS > convergence)):
        E_old = E_scf * 1
        if damping:
            if iteration < 10:
                # Convex mix of old and new densities, which preserves the
                # electron count; weighting both by damp_factor scaled the
                # density by 2 * damp_factor (identical only at 0.5).
                D = Dold * damp_factor + D * (1 - damp_factor)
                Dold = D * 1
        # Build JK matrix: 2 * J - K
        JK = 2 * jk_build(G, D)
        JK -= jk_build(G.transpose((0,2,1,3)), D)
        # Build Fock
        F = H + JK
        # Update convergence error
        if iteration > 1:
            diis_e = jnp.einsum('ij,jk,kl->il', F, D, S) - jnp.einsum('ij,jk,kl->il', S, D, F)
            # Orthogonalize the residual exactly as F is in rhf_iter
            # (A^T X A); A X A would only agree when A is symmetric.
            diis_e = A.T.dot(diis_e).dot(A)
            dRMS = jnp.mean(diis_e**2)**0.5
        # Compute energy, transform Fock and diagonalize, get new density
        E_scf, D, C, eps = rhf_iter(F,D)
        iteration += 1
        if iteration == maxit:
            break
    print(iteration, " RHF iterations performed")

    # If many orbitals are degenerate, warn that higher order derivatives may be unstable
    tmp = jnp.round(eps,6)
    ndegen_orbs =  tmp.shape[0] - jnp.unique(tmp).shape[0]
    if (ndegen_orbs / nbf) > 0.20:
        print("Hartree-Fock warning: More than 20% of orbitals have degeneracies. Higher order derivatives may be unstable due to eigendecomposition AD rule")
    if not return_aux_data:
        return E_scf
    else:
        return E_scf, C, eps, G