def _get_neighborhood_matrix_params(format: partition.NeighborListFormat,
                                    idx: Array,
                                    params: Array,
                                    combinator: Callable) -> Array:
  """Lay out a (per-particle or per-pair) parameter over a neighbor list.

  Args:
    format: Neighbor list format. Sparse formats are handled bond by bond via
      `idx[0]` / `idx[1]`; other formats are handled row by row over `idx`.
    idx: Neighbor indices. Presumably shape `(2, n_bonds)` for sparse formats
      and `(N, max_neighbors)` for dense ones -- TODO confirm.
    params: One of
      * a 0-d array, returned unchanged;
      * a 1-d array of per-particle values, merged pairwise with
        `combinator(params[i], params[j])`;
      * a 2-d array looked up directly as `params[i, j]` by particle index;
      * a Python `int`/`float` (or anything `jnp.issubdtype` classifies as
        integer/floating), returned unchanged.
    combinator: Binary function that merges two per-particle values into a
      single pairwise value. Only used when `params` is 1-d.

  Returns:
    `params` arranged to match the neighbor list, or `params` itself for
    scalars.

  Raises:
    NotImplementedError: If `params` is an array of rank > 2, or is neither
      an array nor a recognised scalar.
  """
  if util.is_array(params):
    ndim = len(params.shape)
    if ndim == 0:
      return params
    if ndim == 1:
      if partition.is_sparse(format):
        return space.map_bond(combinator)(params[idx[0]], params[idx[1]])
      # Dense: broadcast each row's own value against the values gathered
      # for the entries of that row of `idx`.
      return combinator(params[:, None], params[idx])
    if ndim == 2:
      def query(id_a, id_b):
        return params[id_a, id_b]
      if partition.is_sparse(format):
        return space.map_bond(query)(idx[0], idx[1])
      # Dense: pair each row index i with every entry idx[i, k], i.e. look up
      # params[i, idx[i, k]].
      query = vmap(vmap(query, (None, 0)))
      return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
    raise NotImplementedError(
        'Neighborhood parameters must be arrays of rank 0, 1 or 2; '
        f'got shape {params.shape}.')
  if (isinstance(params, (int, float)) or
      jnp.issubdtype(params, jnp.integer) or
      jnp.issubdtype(params, jnp.floating)):
    return params
  raise NotImplementedError(
      f'Unsupported neighborhood parameter type: {type(params)}.')
def fn_mapped(R: Array, neighbor: partition.NeighborList, **dynamic_kwargs
              ) -> Array:
  """Evaluate `fn` over every pair in `neighbor` and reduce the result.

  `displacement_or_metric`, `species`, `kwargs`, `param_combinators`, `fn`
  and `reduce_axis` are free variables resolved from the enclosing scope.

  Args:
    R: Particle positions; `R.shape[0]` is the particle count.
    neighbor: Neighbor list whose `format` and `idx` select the pairs.
    **dynamic_kwargs: Bound into `displacement_or_metric` via `partial`, and
      merged with the static `kwargs` (via `merge_dicts`) before being
      converted to per-pair parameters and passed to `fn`. A `species` entry
      overrides the statically configured `species`.

  Returns:
    If `reduce_axis` is None, the normalized sum over all pairs. Otherwise
    the normalized sum over `reduce_axis`. For sparse formats whose
    `reduce_axis` omits axis 0, this is an array of length `R.shape[0]`
    accumulated with `segment_sum` over `neighbor.idx[0]`.

  Raises:
    ValueError: If `reduce_axis` contains 0 but not 1, or if per-particle
      output is requested from an `OrderedSparse` neighbor list.
  """
  d = partial(displacement_or_metric, **dynamic_kwargs)
  _species = dynamic_kwargs.get('species', species)

  # Divisor applied to summed pair contributions. Dense and Sparse use 2.0
  # and OrderedSparse uses 1.0 -- presumably because the former store each
  # pair in both directions; confirm against `partition`.
  normalization = 2.0

  if partition.is_sparse(neighbor.format):
    d = space.map_bond(d)
    dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
    # Indices >= R.shape[0] do not refer to a real particle; zero them out.
    mask = neighbor.idx[0] < R.shape[0]
    if neighbor.format is partition.OrderedSparse:
      normalization = 1.0
  else:
    d = space.map_neighbor(d)
    R_neigh = R[neighbor.idx]
    dR = d(R, R_neigh)
    mask = neighbor.idx < R.shape[0]

  merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
  merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                 neighbor.idx,
                                                 _species,
                                                 merged_kwargs,
                                                 param_combinators)
  out = fn(dR, **merged_kwargs)
  # `fn` may return extra trailing dimensions; append singleton axes so the
  # mask broadcasts over them.
  if out.ndim > mask.ndim:
    ddim = out.ndim - mask.ndim
    mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
  out *= mask

  if reduce_axis is None:
    return util.high_precision_sum(out) / normalization

  if 0 in reduce_axis and 1 not in reduce_axis:
    raise ValueError(
        'reduce_axis must include axis 1 whenever it includes axis 0; '
        f'got reduce_axis={reduce_axis}.')

  if not partition.is_sparse(neighbor.format):
    return util.high_precision_sum(out, reduce_axis) / normalization

  # Sparse outputs have a single pair axis (0) in place of the dense (0, 1)
  # axes, so drop axes 0/1 and shift any higher axes down by one.
  _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

  if 0 in reduce_axis:
    # Summing over every pair. Apply the same normalization as the
    # `reduce_axis is None` and dense paths. Previously it was omitted here,
    # which made `Sparse` results twice the `Dense` ones.
    return util.high_precision_sum(out, (0,) + _reduce_axis) / normalization

  if neighbor.format is partition.OrderedSparse:
    raise ValueError(
      'Cannot report per-particle values with a neighbor '
      'list whose format is `OrderedSparse`. Please use '
      'either `Dense` or `Sparse`.')

  # Per-particle output: reduce the trailing axes, then accumulate each
  # pair's contribution onto its first index.
  out = util.high_precision_sum(out, _reduce_axis)
  return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization
def _get_neighborhood_species_params(format: partition.NeighborListFormat,
                                     idx: Array,
                                     species: Array,
                                     params: Array) -> Array:
  """Get parameters for interactions between species pairs.

  Args:
    format: Neighbor list format. Sparse formats are looked up bond by bond
      via `idx[0]` / `idx[1]`; other formats row by row over `idx`.
    idx: Neighbor indices used to gather each partner's species.
    species: Per-particle species labels, used as indices into `params`.
    params: A non-array value or 0-d array (returned unchanged), or a 2-d
      table indexed as `params[species_a, species_b]`.

  Returns:
    `params[species[i], species[j]]` for every pair in the neighbor list, or
    `params` itself when it is not a 2-d array.

  Raises:
    ValueError: If `params` is an array whose rank is neither 0 nor 2.
  """
  # TODO(schsam): We should do better error checking here.
  if not util.is_array(params):
    return params

  rank = len(params.shape)
  if rank == 0:
    return params
  if rank != 2:
    raise ValueError(
        'Params must be a scalar or a 2d array if using a species lookup.'
    )

  def pair_param(species_a, species_b):
    return params[species_a, species_b]

  if partition.is_sparse(format):
    return space.map_bond(pair_param)(species[idx[0]], species[idx[1]])

  # Dense: the inner vmap walks one row of neighbor species against a fixed
  # species; the outer vmap walks the rows.
  row_lookup = vmap(pair_param, (None, 0))
  return vmap(row_lookup)(species, species[idx])