Example #1
0
def _get_neighborhood_matrix_params(format: partition.NeighborListFormat,
                                    idx: Array, params: Array,
                                    combinator: Callable) -> Array:
    """Expand a parameter into one value per neighbor-list entry.

    Args:
        format: Neighbor list format. For sparse formats (as reported by
            `partition.is_sparse`) pairs are addressed through `idx[0]` and
            `idx[1]`; otherwise `idx` is indexed row-wise, one row per
            particle.
        idx: Neighbor indices taken from the neighbor list.
        params: Parameter to expand. Supported forms:
            * 1D array of per-particle values, merged pairwise with
              `combinator`;
            * 2D array, looked up as `params[i, j]` for each pair (i, j);
            * 0D array, or an int / float / integer- or floating-dtype
              scalar, returned unchanged.
        combinator: Callable `combinator(p_i, p_j)` that merges two
            per-particle values; only used when `params` is 1D.

    Returns:
        Per-pair parameter values, or `params` itself for scalar inputs.

    Raises:
        NotImplementedError: If `params` is an array with more than two
            dimensions, or is neither an array nor a numeric scalar.
    """
    if util.is_array(params):
        ndim = len(params.shape)
        if ndim == 0:
            return params
        if ndim == 1:
            if partition.is_sparse(format):
                return space.map_bond(combinator)(params[idx[0]],
                                                  params[idx[1]])
            # Dense: broadcast each particle's own value (as a column)
            # against the values of the entries in its row of `idx`.
            return combinator(params[:, None], params[idx])
        if ndim == 2:

            def query(id_a, id_b):
                return params[id_a, id_b]

            if partition.is_sparse(format):
                return space.map_bond(query)(idx[0], idx[1])
            # Dense: pair row index i with every entry of row i of `idx`
            # (outer vmap over rows, inner vmap over the entries of a row).
            query = vmap(vmap(query, (None, 0)))
            return query(jnp.arange(idx.shape[0], dtype=jnp.int32), idx)
        raise NotImplementedError(
            'Neighborhood parameter arrays must have at most 2 dimensions; '
            f'got shape {params.shape}.')
    if (isinstance(params, (int, float))
            or jnp.issubdtype(params, jnp.integer)
            or jnp.issubdtype(params, jnp.floating)):
        return params
    raise NotImplementedError(
        f'Unsupported neighborhood parameter type: {type(params)}.')
Example #2
0
    def fn_mapped(R: Array, neighbor: partition.NeighborList,
                  **dynamic_kwargs) -> Array:
        """Apply `fn` to every neighbor-list pair of `R` and reduce the result.

        Uses `displacement_or_metric`, `species`, `kwargs`,
        `param_combinators`, `fn` and `reduce_axis` from the enclosing scope.

        Args:
            R: Particle positions. `R.shape[0]` is the particle count, and
                neighbor indices not below it are masked out.
            neighbor: Neighbor list providing `idx` and `format`.
            **dynamic_kwargs: Call-time keyword arguments.
                * They are bound into `displacement_or_metric`.
                * A `'species'` entry overrides the enclosing `species`.
                * They are merged over the enclosing `kwargs` before being
                  expanded into per-pair parameters for `fn`.

        Returns:
            * `reduce_axis is None`: masked sum of `fn` over everything,
              divided by `normalization`.
            * Dense formats: masked output summed over `reduce_axis`,
              divided by `normalization`.
            * Sparse formats with axis 0 in `reduce_axis`: sum over all pairs
              (and any trailing axes requested), divided by `normalization`.
            * Sparse formats otherwise: per-particle values. Each pair's
              contribution is accumulated onto particle `idx[0]`, then
              divided by `normalization`.

        Raises:
            ValueError: If `reduce_axis` contains axis 0 but not axis 1, or if
                per-particle values are requested from an `OrderedSparse`
                neighbor list.
        """
        d = partial(displacement_or_metric, **dynamic_kwargs)
        _species = dynamic_kwargs.get('species', species)

        # Non-ordered formats presumably list each pair in both directions
        # (i->j and j->i), so pair sums are halved; OrderedSparse overrides
        # this to 1.0 below.
        normalization = 2.0

        if partition.is_sparse(neighbor.format):
            d = space.map_bond(d)
            dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
            # Entries whose index is not a valid particle are unused slots.
            mask = neighbor.idx[0] < R.shape[0]
            if neighbor.format is partition.OrderedSparse:
                normalization = 1.0
        else:
            d = space.map_neighbor(d)
            R_neigh = R[neighbor.idx]
            dR = d(R, R_neigh)
            mask = neighbor.idx < R.shape[0]

        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                       neighbor.idx, _species,
                                                       merged_kwargs,
                                                       param_combinators)
        out = fn(dR, **merged_kwargs)
        # `fn` may return extra trailing axes; append singleton axes to the
        # mask so that it broadcasts over them.
        if out.ndim > mask.ndim:
            ddim = out.ndim - mask.ndim
            mask = jnp.reshape(mask, mask.shape + (1, ) * ddim)
        out *= mask

        if reduce_axis is None:
            return util.high_precision_sum(out) / normalization

        if 0 in reduce_axis and 1 not in reduce_axis:
            raise ValueError(
                'reduce_axis must include axis 1 whenever it includes axis 0; '
                f'got {reduce_axis}.')

        if not partition.is_sparse(neighbor.format):
            return util.high_precision_sum(out, reduce_axis) / normalization

        # The sparse output has a single pair axis where the dense layout
        # presumably has (particle, neighbor) axes 0 and 1, so every later
        # axis shifts down by one.
        _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

        if 0 in reduce_axis:
            # A sum over all pairs must be normalized exactly like the
            # `reduce_axis is None` and dense branches above. Otherwise a
            # `Sparse` list disagrees with a `Dense` list for the same
            # reduce_axis.
            return (util.high_precision_sum(out, (0, ) + _reduce_axis)
                    / normalization)

        if neighbor.format is partition.OrderedSparse:
            raise ValueError(
                'Cannot report per-particle values with a neighbor '
                'list whose format is `OrderedSparse`. Please use '
                'either `Dense` or `Sparse`.')

        out = util.high_precision_sum(out, _reduce_axis)
        return ops.segment_sum(out, neighbor.idx[0],
                               R.shape[0]) / normalization
Example #3
0
def _get_neighborhood_species_params(format: partition.NeighborListFormat,
                                     idx: Array, species: Array,
                                     params: Array) -> Array:
    """Get parameters for interactions between species pairs.

    Args:
        format: Neighbor list format. Sparse formats address pairs through
            `idx[0]` / `idx[1]`; other formats index `idx` row-wise, one row
            per particle.
        idx: Neighbor indices taken from the neighbor list.
        species: Per-particle species labels, used as indices into `params`.
        params: Either a 2D table read as `params[species_i, species_j]`, or a
            scalar. Values that are not arrays are passed through unchanged.

    Returns:
        The table entry for every neighbor pair, or `params` unchanged when it
        is a 0D array or not an array at all.

    Raises:
        ValueError: If `params` is an array whose rank is neither 0 nor 2.
    """
    # TODO(schsam): We should do better error checking here.
    if not util.is_array(params):
        return params

    rank = len(params.shape)
    if rank == 0:
        return params
    if rank != 2:
        raise ValueError(
            'Params must be a scalar or a 2d array if using a species lookup.'
        )

    def table_entry(kind_a, kind_b):
        return params[kind_a, kind_b]

    if partition.is_sparse(format):
        return space.map_bond(table_entry)(species[idx[0]], species[idx[1]])

    # Dense: pair each particle's species with the species of every entry in
    # its row of `idx` (outer vmap over particles, inner over neighbors).
    pairwise = vmap(vmap(table_entry, (None, 0)))
    return pairwise(species, species[idx])