Example No. 1
  def testSegmentSum(self):
    data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])

    # test with explicit num_segments
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = jnp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test with explicit num_segments larger than the largest segment index.
    ans = ops.segment_sum(data, segment_ids, num_segments=5)
    expected = jnp.array([13, 2, 7, 4, 0])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test without explicit num_segments
    ans = ops.segment_sum(data, segment_ids)
    expected = jnp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test with negative segment ids and segment ids larger than num_segments;
    # ids outside [0, num_segments) are dropped and do not contribute.
    segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3])
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = jnp.array([5, 2, 3, 3])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test with negative segment ids and without explicit num_segments;
    # num_segments is then inferred from the largest index, and negative
    # ids do not contribute.
    segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
    ans = ops.segment_sum(data, segment_ids)
    expected = jnp.array([0, 0, 0, 13, 2, 7])
    self.assertAllClose(ans, expected, check_dtypes=False)
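For intuition, with in-range segment ids `segment_sum` behaves exactly like a
scatter-add into a zero-initialized output; out-of-range ids are handled
differently (the tests above show `segment_sum` drops them). A minimal sketch
of the equivalence:

    import jax.numpy as jnp
    from jax import ops

    data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])
    num_segments = 4

    by_segment_sum = ops.segment_sum(data, segment_ids, num_segments)
    by_scatter_add = jnp.zeros(num_segments, data.dtype).at[segment_ids].add(data)
    # both are [13, 2, 7, 4]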
Example No. 2
  def testSegmentSum(self):
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    # test with explicit num_segments
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test without explicit num_segments
    ans = ops.segment_sum(data, segment_ids)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)
Example No. 3
        def g_fn(R, neighbor):
            N, dim = R.shape
            g_R = []
            mask = partition.neighbor_list_mask(neighbor)
            if neighbor.format is partition.Dense:
                neighbor_species = species[neighbor.idx]
                R_neigh = R[neighbor.idx]
                d = space.map_neighbor(metric)
                _pairwise = vmap(vmap(pairwise, (0, None)), (0, None))
                for s in species_types:
                    mask_s = mask * (neighbor_species == s)
                    g_R += [
                        jnp.sum(mask_s[:, :, jnp.newaxis] *
                                _pairwise(d(R, R_neigh), dim),
                                axis=(1, ))
                    ]
            elif neighbor.format is partition.Sparse:
                neighbor_species = species[neighbor.idx[1]]
                dr = space.map_bond(metric)(R[neighbor.idx[0]],
                                            R[neighbor.idx[1]])
                _pairwise = vmap(pairwise, (0, None))
                for s in species_types:
                    mask_s = mask * (neighbor_species == s)
                    g_R += [
                        ops.segment_sum(mask_s[:, None] * _pairwise(dr, dim),
                                        neighbor.idx[0], N)
                    ]
            else:
                raise NotImplementedError(
                    'Pair correlation function does not support '
                    'OrderedSparse neighbor lists.')

            return g_R
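The `Sparse` branch above is the canonical `segment_sum` pattern: per-bond
values are scattered back to the particle that owns each bond via
`neighbor.idx[0]`. A stripped-down sketch with hypothetical values:

    import jax.numpy as jnp
    from jax import ops

    senders = jnp.array([0, 0, 1, 2])           # neighbor.idx[0]: bond owners
    per_bond = jnp.array([1.0, 2.0, 3.0, 4.0])  # e.g. masked _pairwise(dr, dim)
    N = 3
    per_particle = ops.segment_sum(per_bond, senders, N)
    # per_particle == [3., 3., 4.]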
Example No. 4
    def fn_mapped(R: Array, neighbor: partition.NeighborList,
                  **dynamic_kwargs) -> Array:
        d = partial(displacement_or_metric, **dynamic_kwargs)
        _species = dynamic_kwargs.get('species', species)

        normalization = 2.0

        if partition.is_sparse(neighbor.format):
            d = space.map_bond(d)
            dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
            mask = neighbor.idx[0] < R.shape[0]
            if neighbor.format is partition.OrderedSparse:
                normalization = 1.0
        else:
            d = space.map_neighbor(d)
            R_neigh = R[neighbor.idx]
            dR = d(R, R_neigh)
            mask = neighbor.idx < R.shape[0]

        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                       neighbor.idx, _species,
                                                       merged_kwargs,
                                                       param_combinators)
        out = fn(dR, **merged_kwargs)
        if out.ndim > mask.ndim:
            ddim = out.ndim - mask.ndim
            mask = jnp.reshape(mask, mask.shape + (1, ) * ddim)
        out *= mask

        if reduce_axis is None:
            return util.high_precision_sum(out) / normalization

        if 0 in reduce_axis and 1 not in reduce_axis:
            raise ValueError('Cannot reduce over axis 0 (particles) without '
                             'also reducing over axis 1 (neighbors).')

        if not partition.is_sparse(neighbor.format):
            return util.high_precision_sum(out, reduce_axis) / normalization

        _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

        if 0 in reduce_axis:
            return util.high_precision_sum(out, (0, ) + _reduce_axis)

        if neighbor.format is partition.OrderedSparse:
            raise ValueError(
                'Cannot report per-particle values with a neighbor '
                'list whose format is `OrderedSparse`. Please use '
                'either `Dense` or `Sparse`.')

        out = util.high_precision_sum(out, _reduce_axis)
        return ops.segment_sum(out, neighbor.idx[0],
                               R.shape[0]) / normalization
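The `OrderedSparse` error above follows from how that format stores pairs:
each pair appears only once (which is also why `normalization` is 1.0 for it
earlier), so scattering per-pair values to `idx[0]` would credit only the
first particle of each pair. An illustrative sketch with hypothetical values:

    import jax.numpy as jnp
    from jax import ops

    idx0 = jnp.array([0, 0, 1])       # pairs (0,1), (0,2), (1,2), stored once
    e_pair = jnp.array([1.0, 2.0, 3.0])
    per_particle = ops.segment_sum(e_pair, idx0, 3)
    # per_particle == [3., 3., 0.]: particle 2 is never credited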
Example No. 5
    def sym_fn(R: Array,
               neighbor: NeighborList,
               mask_i: Array = None,
               mask_j: Array = None,
               **kwargs) -> Array:
        D_fn = partial(displacement, **kwargs)

        if neighbor.format is partition.Dense:
            D_fn = space.map_neighbor(D_fn)

            R_neigh = R[neighbor.idx]

            dR = D_fn(R, R_neigh)
            _all_pairs_angular = vmap(
                vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
            all_angular = _all_pairs_angular(dR, dR)

            mask_i = True if mask_i is None else mask_i[neighbor.idx]
            mask_j = True if mask_j is None else mask_j[neighbor.idx]

            mask_i = (neighbor.idx < R.shape[0]) & mask_i
            mask_i = mask_i[:, :, jnp.newaxis, jnp.newaxis]
            mask_j = (neighbor.idx < R.shape[0]) & mask_j
            mask_j = mask_j[:, jnp.newaxis, :, jnp.newaxis]

            return util.high_precision_sum(all_angular * mask_i * mask_j,
                                           axis=[1, 2])
        elif neighbor.format is partition.Sparse:
            D_fn = space.map_bond(D_fn)
            dR = D_fn(R[neighbor.idx[0]], R[neighbor.idx[1]])
            _all_pairs_angular = vmap(vmap(_batched_angular_fn, (0, None)),
                                      (None, 0))
            all_angular = _all_pairs_angular(dR, dR)

            N = R.shape[0]
            mask_i = True if mask_i is None else mask_i[neighbor.idx[1]]
            mask_j = True if mask_j is None else mask_j[neighbor.idx[1]]
            mask_i = (neighbor.idx[0] < N) & mask_i
            mask_j = (neighbor.idx[0] < N) & mask_j

            mask = mask_i[:, None] & mask_j[None, :]
            mask = mask[:, :, None, None]
            all_angular = jnp.reshape(all_angular,
                                      (-1, ) + all_angular.shape[2:])
            neighbor_idx = jnp.repeat(neighbor.idx[0], len(neighbor.idx[0]))
            out = ops.segment_sum(all_angular, neighbor_idx, N)
            return out
        else:
            raise ValueError('Unsupported neighbor list format; expected '
                             'Dense or Sparse.')
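In the `Sparse` branch, `jnp.repeat` aligns the flattened (bond, bond) array
with the particle owning the leading bond of each pair. A shape-only sketch
with hypothetical values:

    import jax.numpy as jnp
    from jax import ops

    idx0 = jnp.array([0, 0, 1])           # particle owning each bond
    pair_vals = jnp.ones((3, 3))          # one value per pair of bonds
    flat = jnp.reshape(pair_vals, (-1,))  # row-major flatten
    seg = jnp.repeat(idx0, len(idx0))     # [0, 0, 0, 0, 0, 0, 1, 1, 1]
    out = ops.segment_sum(flat, seg, 2)
    # out == [6., 3.]: pairs led by a bond of particle 0 vs. particle 1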
Example No. 6
def count_cell_filling(position: Array, box_size: Box,
                       minimum_cell_size: float) -> Array:
    """Counts the number of particles per-cell in a spatial partition."""
    dim = int(position.shape[1])
    box_size, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    particle_index = jnp.array(position / cell_size, dtype=i32)
    particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1)

    filling = ops.segment_sum(jnp.ones_like(particle_hash), particle_hash,
                              cell_count)
    return filling
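The counting step reduces to a segment sum of ones over the cell hashes. A
self-contained sketch with made-up hashes:

    import jax.numpy as jnp
    from jax import ops

    particle_hash = jnp.array([0, 2, 2, 1, 0])  # cell hash per particle
    cell_count = 3
    filling = ops.segment_sum(jnp.ones_like(particle_hash), particle_hash,
                              cell_count)
    # filling == [2, 1, 2]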
Example No. 7
def sparse_mean_pooling(node_feats: jnp.ndarray,
                        graph_idx: jnp.ndarray) -> jnp.ndarray:
    """Mean pooling function for sparse pattern graph data.

    node_feats : ndarray of shape (N, in_feats)
        Batch input node features.
        N is the total number of nodes in the batch.
    graph_idx : ndarray of shape (N,)
        This idx indicate a graph number for node_feats in the batch.
        When the two nodes shows the same graph idx, these belong to the same graph.

    Returns
    -------
    ndarray of shape (batch_size, in_feats)
        Batch graph features.
    """
    num_nodes = node_feats.shape[0]
    batch_size = graph_idx[-1] + 1
    n_atom = segment_sum(jnp.ones(num_nodes),
                         graph_idx,
                         num_segments=batch_size)
    n_atom = jnp.expand_dims(n_atom, 1)
    sum_nodes = segment_sum(node_feats, graph_idx, num_segments=batch_size)
    return sum_nodes / n_atom
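A toy call (illustrative values only); note that `batch_size = graph_idx[-1] + 1`
assumes `graph_idx` is sorted so that its last entry is the largest:

    import jax.numpy as jnp

    node_feats = jnp.ones((5, 4))           # 5 nodes, 4 features each
    graph_idx = jnp.array([0, 0, 0, 1, 1])  # graphs of 3 and 2 nodes
    pooled = sparse_mean_pooling(node_feats, graph_idx)
    # pooled.shape == (2, 4); every entry is 1.0, the mean of ones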
Example No. 8
def sparse_sum_pooling(node_feats: jnp.ndarray,
                       graph_idx: jnp.ndarray) -> jnp.ndarray:
    """Sum pooling function for sparse pattern graph data.

    node_feats : ndarray of shape (N, in_feats)
        Batch input node features.
        N is the total number of nodes in the batch.
    graph_idx : ndarray of shape (N,)
        This idx indicate a graph number for node_feats in the batch.
        When the two nodes shows the same graph idx, these belong to the same graph.

    Returns
    -------
    ndarray of shape (batch_size, in_feats)
        Batch graph features.
    """
    batch_size = graph_idx[-1] + 1
    return segment_sum(node_feats, graph_idx, num_segments=batch_size)
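A matching toy call (same hypothetical batch as in the mean-pooling example):

    pooled = sparse_sum_pooling(jnp.ones((5, 4)), jnp.array([0, 0, 0, 1, 1]))
    # pooled == [[3., 3., 3., 3.], [2., 2., 2., 2.]]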
Example No. 9
 def g_fn(R, neighbor):
   N, dim = R.shape
   mask = partition.neighbor_list_mask(neighbor)
   if neighbor.format is partition.Dense:
     R_neigh = R[neighbor.idx]
     d = space.map_neighbor(metric)
     _pairwise = vmap(vmap(pairwise, (0, None)), (0, None))
     return jnp.sum(mask[:, :, None] *
                    _pairwise(d(R, R_neigh), dim), axis=(1,))
   elif neighbor.format is partition.Sparse:
     dr = space.map_bond(metric)(R[neighbor.idx[0]], R[neighbor.idx[1]])
     _pairwise = vmap(pairwise, (0, None))
     return ops.segment_sum(mask[:, None] * _pairwise(dr, dim),
                            neighbor.idx[0],
                            N)
   else:
     raise NotImplementedError('Pair correlation function does not support '
                               'OrderedSparse neighbor lists.')
Example No. 10
def to_dense(neighbor: NeighborList) -> Array:
  """Converts a sparse neighbor list to dense ids. Cannot be JIT."""
  if neighbor.format is not Sparse:
    raise ValueError('Can only convert sparse neighbor lists to dense ones.')

  receivers, senders = neighbor.idx
  mask = neighbor_list_mask(neighbor)

  receivers = receivers[mask]
  senders = senders[mask]

  N = len(neighbor.reference_position)
  count = ops.segment_sum(jnp.ones(len(receivers), i32), receivers, N)
  max_count = jnp.max(count)
  offset = jnp.tile(jnp.arange(max_count), N)[:len(senders)]
  hashes = senders * max_count + offset
  dense_idx = N * jnp.ones((N * max_count,), i32)
  dense_idx = dense_idx.at[hashes].set(receivers).reshape((N, max_count))
  return dense_idx
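The packing trick can be checked with plain arrays (hypothetical values; the
sketch assumes `senders` is sorted, so the tiled offsets give each edge a
unique slot):

    import jax.numpy as jnp

    receivers = jnp.array([1, 2, 0, 2])  # neighbor ids, one per edge
    senders = jnp.array([0, 0, 1, 1])    # sorted owner of each edge
    N, max_count = 3, 2
    offset = jnp.tile(jnp.arange(max_count), N)[:len(senders)]
    hashes = senders * max_count + offset  # unique slot per edge
    dense_idx = (N * jnp.ones((N * max_count,), jnp.int32)
                 ).at[hashes].set(receivers).reshape((N, max_count))
    # dense_idx == [[1, 2], [0, 2], [3, 3]]; N == 3 marks empty slots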
Example No. 11
 def sym_fn(R: Array,
            neighbor: NeighborList,
            mask: Array = None,
            **kwargs) -> Array:
     _metric = partial(metric, **kwargs)
     if neighbor.format is partition.Dense:
         _metric = space.map_neighbor(_metric)
         R_neigh = R[neighbor.idx]
         mask = True if mask is None else mask[neighbor.idx]
         mask = (neighbor.idx < R.shape[0])[None, :, :] & mask
         dr = _metric(R, R_neigh)
         return util.high_precision_sum(radial_fn(etas, dr) * mask,
                                        axis=2).T
     elif neighbor.format is partition.Sparse:
         _metric = space.map_bond(_metric)
         dr = _metric(R[neighbor.idx[0]], R[neighbor.idx[1]])
         radial = radial_fn(etas, dr).T
         N = R.shape[0]
         mask = True if mask is None else mask[neighbor.idx[1]]
         mask = (neighbor.idx[0] < N) & mask
         return ops.segment_sum(radial * mask[:, None], neighbor.idx[0], N)
     else:
          raise ValueError('Unsupported neighbor list format; expected '
                           'Dense or Sparse.')
Example No. 12
    def cell_list_fn(position: Array,
                     capacity_overflow_update: Optional[Tuple[
                         int, bool, Callable[..., CellList]]] = None,
                     extra_capacity: int = 0,
                     **kwargs) -> CellList:
        N = position.shape[0]
        dim = position.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        if capacity_overflow_update is None:
            cell_capacity = _estimate_cell_capacity(position, box_size,
                                                    cell_size,
                                                    buffer_size_multiplier)
            cell_capacity += extra_capacity
            overflow = False
            update_fn = cell_list_fn
        else:
            cell_capacity, overflow, update_fn = capacity_overflow_update

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(i32, N)
        # NOTE(schsam): We use the convention that particles that are
        # successfully copied have their true id, whereas empty slots have
        # id = N. Then, when we copy data back from the grid, we copy it to
        # an array of shape [N + 1, output_dimension] and truncate it to an
        # array of shape [N, output_dimension], which drops the empty slots.
        cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                                  dtype=position.dtype)
        cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        #  pytype: disable=attribute-error
        for k, v in kwargs.items():
            if not util.is_array(v):
                raise ValueError(
                    (f'Data must be specified as an ndarray. Found "{k}" '
                     f'with type {type(v)}.'))
            if v.shape[0] != position.shape[0]:
                raise ValueError(
                    ('Data must be specified per-particle (an ndarray '
                     f'with shape ({N}, ...)). Found "{k}" with '
                     f'shape {v.shape}.'))
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * jnp.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)
        #  pytype: enable=attribute-error
        indices = jnp.array(position / cell_size, dtype=i32)
        hashes = jnp.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick that lets
        # us copy into all cells simultaneously with a single scatter. We
        # first sort particles by their cell hash, then assign each particle
        # the cell id = hash * cell_capacity + grid_id, where grid_id cycles
        # through 0, ..., cell_capacity - 1. So long as no cell holds more
        # than cell_capacity particles, each particle is guaranteed to get a
        # unique cell id.
        sort_map = jnp.argsort(hashes)
        sorted_position = position[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
        sorted_id = jnp.reshape(sorted_id, (N, 1))
        cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
        cell_position = _unflatten_cell_buffer(cell_position, cells_per_side,
                                               dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = jnp.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
        max_occupancy = jnp.max(occupancy)
        overflow = overflow | (max_occupancy >= cell_capacity)

        return CellList(cell_position, cell_id, cell_kwargs, overflow,
                        cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
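The unique-slot construction described in the comment above can be checked in
isolation (hypothetical values):

    import jax.numpy as jnp

    hashes = jnp.array([2, 0, 2, 1, 0])   # cell hash per particle
    cell_capacity = 2
    sorted_hash = jnp.sort(hashes)        # [0, 0, 1, 2, 2]
    slot = jnp.arange(5) % cell_capacity  # [0, 1, 0, 1, 0]
    cell_slot_id = sorted_hash * cell_capacity + slot
    # cell_slot_id == [0, 1, 2, 5, 4]: all distinct, so one scatter fills
    # the grid. Uniqueness holds whenever no cell exceeds cell_capacity.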
Example No. 13
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Update node features.

        Parameters
        ----------
        node_feats : ndarray of shape (N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch.
        adj : ndarray of shape (2, E)
            Batch adjacency list.
            E is the total number of edges in the batch.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (N, out_feats)
            Batch new node features.
        """
        dropout = self.dropout if is_training else 0.0
        num_nodes = node_feats.shape[0]

        # affine transformation
        new_node_feats = jnp.dot(node_feats, self.w)
        if self.bias:
            new_node_feats += self.b

        # update nodes
        if self.normalize:
            # add self connection
            self_loop = jnp.tile(jnp.arange(num_nodes), (2, 1))
            adj = jnp.concatenate((adj, self_loop), axis=1)
            src_idx, dest_idx = adj[0], adj[1]

            # calculate the norm
            degree = segment_sum(jnp.ones(len(dest_idx)),
                                 dest_idx,
                                 num_segments=num_nodes)
            deg_inv_sqrt = jax.lax.pow(degree, -0.5)
            norm = deg_inv_sqrt[src_idx] * deg_inv_sqrt[dest_idx]

            # update nodes
            source_feats = jnp.take(new_node_feats, src_idx, axis=0)
            source_feats = norm.reshape(-1, 1) * source_feats
            new_node_feats = segment_sum(source_feats,
                                         dest_idx,
                                         num_segments=num_nodes)
        else:
            src_idx, dest_idx = adj[0], adj[1]
            source_feats = jnp.take(new_node_feats, src_idx, axis=0)
            aggregated_messages = segment_sum(source_feats,
                                              dest_idx,
                                              num_segments=num_nodes)
            new_node_feats = jnp.add(aggregated_messages, new_node_feats)

        new_node_feats = self.activation(new_node_feats)

        if dropout != 0.0:
            new_node_feats = hk.dropout(hk.next_rng_key(), dropout,
                                        new_node_feats)
        if self.batch_norm:
            new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats,
                                                           is_training)

        return new_node_feats
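The `normalize` branch implements the symmetric GCN normalization of Kipf &
Welling, norm = 1 / sqrt(deg(src) * deg(dst)), with self-loops added first.
A toy check of the degree and norm computation (hypothetical edge list):

    import jax
    import jax.numpy as jnp
    from jax.ops import segment_sum

    # 3 nodes; edges (0, 1) and (1, 0), plus a self-loop on every node.
    src_idx = jnp.array([0, 1, 0, 1, 2])
    dest_idx = jnp.array([1, 0, 0, 1, 2])
    degree = segment_sum(jnp.ones(5), dest_idx, num_segments=3)  # [2., 2., 1.]
    deg_inv_sqrt = jax.lax.pow(degree, -0.5)
    norm = deg_inv_sqrt[src_idx] * deg_inv_sqrt[dest_idx]
    # norm == [0.5, 0.5, 0.5, 0.5, 1.0]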
Example No. 14
 def compute_centroids(self, coords):
     """Returns an array containing the centroids of each group"""
     return segment_sum(coords, self.scatter_inds) / jnp.expand_dims(
         self.group_sizes, axis=1)
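A toy call (hypothetical group layout): `scatter_inds` maps each coordinate
to its group and `group_sizes` holds the per-group counts.

    import jax.numpy as jnp
    from jax.ops import segment_sum

    coords = jnp.array([[0., 0.], [2., 2.], [4., 4.]])
    scatter_inds = jnp.array([0, 0, 1])  # groups of sizes 2 and 1
    group_sizes = jnp.array([2., 1.])
    centroids = segment_sum(coords, scatter_inds) / jnp.expand_dims(group_sizes, 1)
    # centroids == [[1., 1.], [4., 4.]]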