def testSegmentSum(self):
  data = jnp.array([5, 1, 7, 2, 3, 4, 1, 3])
  segment_ids = jnp.array([0, 0, 0, 1, 2, 2, 3, 3])

  # test with explicit num_segments
  ans = ops.segment_sum(data, segment_ids, num_segments=4)
  expected = jnp.array([13, 2, 7, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # test with explicit num_segments larger than the highest index
  ans = ops.segment_sum(data, segment_ids, num_segments=5)
  expected = jnp.array([13, 2, 7, 4, 0])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # test without explicit num_segments
  ans = ops.segment_sum(data, segment_ids)
  expected = jnp.array([13, 2, 7, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # test that segment ids outside of [0, num_segments) are dropped
  segment_ids = jnp.array([0, 4, 8, 1, 2, -6, -1, 3])
  ans = ops.segment_sum(data, segment_ids, num_segments=4)
  expected = jnp.array([5, 2, 3, 3])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # test without explicit num_segments, which is then inferred as
  # max(segment_ids) + 1; out-of-range ids are again dropped
  segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
  ans = ops.segment_sum(data, segment_ids)
  expected = jnp.array([0, 0, 0, 13, 2, 7])
  self.assertAllClose(ans, expected, check_dtypes=False)
def testSegmentSum(self):
  data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
  segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

  # test with explicit num_segments
  ans = ops.segment_sum(data, segment_ids, num_segments=4)
  expected = onp.array([13, 2, 7, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # test without explicit num_segments
  ans = ops.segment_sum(data, segment_ids)
  expected = onp.array([13, 2, 7, 4])
  self.assertAllClose(ans, expected, check_dtypes=False)
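# A minimal standalone sketch (not part of the test suites above) of the two
# behaviors those tests exercise: when `num_segments` is omitted it is
# inferred as max(segment_ids) + 1, and ids outside [0, num_segments) are
# dropped rather than wrapped.
import jax.numpy as jnp
from jax import ops

data = jnp.array([5., 1., 7., 2.])
print(ops.segment_sum(data, jnp.array([0, 0, 1, 1])))                   # [6. 9.]
print(ops.segment_sum(data, jnp.array([0, 0, 1, 9]), num_segments=2))  # [6. 7.]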
def g_fn(R, neighbor):
  N, dim = R.shape
  g_R = []
  mask = partition.neighbor_list_mask(neighbor)

  if neighbor.format is partition.Dense:
    neighbor_species = species[neighbor.idx]
    R_neigh = R[neighbor.idx]
    d = space.map_neighbor(metric)
    _pairwise = vmap(vmap(pairwise, (0, None)), (0, None))
    for s in species_types:
      mask_s = mask * (neighbor_species == s)
      g_R += [jnp.sum(mask_s[:, :, jnp.newaxis] *
                      _pairwise(d(R, R_neigh), dim), axis=(1,))]
  elif neighbor.format is partition.Sparse:
    neighbor_species = species[neighbor.idx[1]]
    dr = space.map_bond(metric)(R[neighbor.idx[0]], R[neighbor.idx[1]])
    _pairwise = vmap(pairwise, (0, None))
    for s in species_types:
      mask_s = mask * (neighbor_species == s)
      g_R += [ops.segment_sum(mask_s[:, None] * _pairwise(dr, dim),
                              neighbor.idx[0], N)]
  else:
    raise NotImplementedError('Pair correlation function does not support '
                              'OrderedSparse neighbor lists.')
  return g_R
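# A hedged toy sketch of the per-species masking used above (illustrative
# names and numbers, not the surrounding code's API): select the bonds whose
# neighbor belongs to species s, then segment-sum them back onto the central
# particle.
import jax.numpy as jnp
from jax import ops

N = 3
senders = jnp.array([0, 0, 1, 2])          # central particle of each bond
receivers = jnp.array([1, 2, 0, 0])        # neighbor of each bond
species = jnp.array([0, 1, 1])             # per-particle species
contrib = jnp.array([1., 2., 3., 4.])      # one value per bond
for s in (0, 1):
  mask_s = (species[receivers] == s)
  print(ops.segment_sum(contrib * mask_s, senders, N))
# s=0 -> [0. 3. 4.]   (bonds whose neighbor is species 0)
# s=1 -> [3. 0. 0.]   (bonds whose neighbor is species 1)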
def fn_mapped(R: Array,
              neighbor: partition.NeighborList,
              **dynamic_kwargs) -> Array:
  d = partial(displacement_or_metric, **dynamic_kwargs)
  _species = dynamic_kwargs.get('species', species)

  normalization = 2.0

  if partition.is_sparse(neighbor.format):
    d = space.map_bond(d)
    dR = d(R[neighbor.idx[0]], R[neighbor.idx[1]])
    mask = neighbor.idx[0] < R.shape[0]
    if neighbor.format is partition.OrderedSparse:
      normalization = 1.0
  else:
    d = space.map_neighbor(d)
    R_neigh = R[neighbor.idx]
    dR = d(R, R_neigh)
    mask = neighbor.idx < R.shape[0]

  merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
  merged_kwargs = _neighborhood_kwargs_to_params(neighbor.format,
                                                 neighbor.idx,
                                                 _species,
                                                 merged_kwargs,
                                                 param_combinators)
  out = fn(dR, **merged_kwargs)
  if out.ndim > mask.ndim:
    ddim = out.ndim - mask.ndim
    mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
  out *= mask

  if reduce_axis is None:
    return util.high_precision_sum(out) / normalization

  if 0 in reduce_axis and 1 not in reduce_axis:
    raise ValueError('Cannot reduce over the particle axis (0) without also '
                     'reducing over the neighbor axis (1).')

  if not partition.is_sparse(neighbor.format):
    return util.high_precision_sum(out, reduce_axis) / normalization

  _reduce_axis = tuple(a - 1 for a in reduce_axis if a > 1)

  if 0 in reduce_axis:
    return util.high_precision_sum(out, (0,) + _reduce_axis)

  if neighbor.format is partition.OrderedSparse:
    raise ValueError('Cannot report per-particle values with a neighbor '
                     'list whose format is `OrderedSparse`. Please use '
                     'either `Dense` or `Sparse`.')

  out = util.high_precision_sum(out, _reduce_axis)
  return ops.segment_sum(out, neighbor.idx[0], R.shape[0]) / normalization
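# A hedged toy sketch (made-up data, not the surrounding code's API) of the
# per-particle reduction pattern `fn_mapped` uses for sparse neighbor lists:
# sum one value per bond onto its sender, then halve to undo the double
# counting of unordered pairs.
import jax.numpy as jnp
from jax import ops

N = 4
senders = jnp.array([0, 0, 1, 2, 3])       # bond -> central particle i
e_bond = jnp.array([1., 2., 3., 4., 5.])   # one value per bond
e_particle = ops.segment_sum(e_bond, senders, N) / 2.0
# e_particle == [1.5, 1.5, 2.0, 2.5]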
def sym_fn(R: Array,
           neighbor: NeighborList,
           mask_i: Array = None,
           mask_j: Array = None,
           **kwargs) -> Array:
  D_fn = partial(displacement, **kwargs)

  if neighbor.format is partition.Dense:
    D_fn = space.map_neighbor(D_fn)
    R_neigh = R[neighbor.idx]
    dR = D_fn(R, R_neigh)
    _all_pairs_angular = vmap(
        vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
    all_angular = _all_pairs_angular(dR, dR)

    mask_i = True if mask_i is None else mask_i[neighbor.idx]
    mask_j = True if mask_j is None else mask_j[neighbor.idx]

    mask_i = (neighbor.idx < R.shape[0]) & mask_i
    mask_i = mask_i[:, :, jnp.newaxis, jnp.newaxis]
    mask_j = (neighbor.idx < R.shape[0]) & mask_j
    mask_j = mask_j[:, jnp.newaxis, :, jnp.newaxis]

    return util.high_precision_sum(all_angular * mask_i * mask_j, axis=[1, 2])
  elif neighbor.format is partition.Sparse:
    D_fn = space.map_bond(D_fn)
    dR = D_fn(R[neighbor.idx[0]], R[neighbor.idx[1]])
    _all_pairs_angular = vmap(vmap(_batched_angular_fn, (0, None)), (None, 0))
    all_angular = _all_pairs_angular(dR, dR)

    N = R.shape[0]
    mask_i = True if mask_i is None else mask_i[neighbor.idx[1]]
    mask_j = True if mask_j is None else mask_j[neighbor.idx[1]]
    mask_i = (neighbor.idx[0] < N) & mask_i
    mask_j = (neighbor.idx[0] < N) & mask_j
    mask = mask_i[:, None] & mask_j[None, :]
    mask = mask[:, :, None, None]

    # Zero out contributions from padded or masked bond pairs before
    # flattening the all-pairs array.
    all_angular = jnp.reshape(all_angular * mask,
                              (-1,) + all_angular.shape[2:])
    neighbor_idx = jnp.repeat(neighbor.idx[0], len(neighbor.idx[0]))
    out = ops.segment_sum(all_angular, neighbor_idx, N)
    return out
  else:
    raise ValueError('Unsupported neighbor list format.')
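# A hedged sketch of the index bookkeeping above (toy sizes, illustrative
# names): flattening an all-pairs (n_bonds, n_bonds, ...) array row-major
# means flat element p * n_bonds + q came from row p, so repeating the
# per-bond central-atom index n_bonds times lines it up with the flattened
# pair axis before the segment_sum.
import jax.numpy as jnp

senders = jnp.array([0, 0, 1])                # per-bond central-atom index
flat_ids = jnp.repeat(senders, len(senders))  # [0 0 0 0 0 0 1 1 1]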
def count_cell_filling(position: Array,
                       box_size: Box,
                       minimum_cell_size: float) -> Array:
  """Counts the number of particles per cell in a spatial partition."""
  dim = int(position.shape[1])
  box_size, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)
  hash_multipliers = _compute_hash_constants(dim, cells_per_side)
  particle_index = jnp.array(position / cell_size, dtype=i32)
  particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1)
  filling = ops.segment_sum(jnp.ones_like(particle_hash),
                            particle_hash,
                            cell_count)
  return filling
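# A small hedged illustration of the counting pattern above (toy numbers, not
# the surrounding code's API): hash 2-D cell coordinates to a flat id, then
# count particles per cell with segment_sum.
import jax.numpy as jnp
from jax import ops

cells_per_side = 3
cell_idx = jnp.array([[0, 0], [2, 1], [0, 0], [1, 2]])  # (N, dim) cell coords
multipliers = jnp.array([1, cells_per_side])            # flat hash constants
hashes = jnp.sum(cell_idx * multipliers, axis=1)        # [0, 5, 0, 7]
filling = ops.segment_sum(jnp.ones_like(hashes), hashes, cells_per_side**2)
# filling[0] == 2, filling[5] == 1, filling[7] == 1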
def sparse_mean_pooling(node_feats: jnp.ndarray,
                        graph_idx: jnp.ndarray) -> jnp.ndarray:
  """Mean pooling function for sparse pattern graph data.

  Parameters
  ----------
  node_feats : ndarray of shape (N, in_feats)
    Batch input node features. N is the total number of nodes in the batch.
  graph_idx : ndarray of shape (N,)
    Graph index for each node in the batch. Nodes with the same graph index
    belong to the same graph.

  Returns
  -------
  ndarray of shape (batch_size, in_feats)
    Batch graph features.
  """
  num_nodes = node_feats.shape[0]
  # graph_idx is assumed to be sorted, so its last entry is the largest
  # graph index in the batch.
  batch_size = graph_idx[-1] + 1
  n_atom = segment_sum(jnp.ones(num_nodes), graph_idx, num_segments=batch_size)
  n_atom = jnp.expand_dims(n_atom, 1)
  sum_nodes = segment_sum(node_feats, graph_idx, num_segments=batch_size)
  return sum_nodes / n_atom
def sparse_sum_pooling(node_feats: jnp.ndarray,
                       graph_idx: jnp.ndarray) -> jnp.ndarray:
  """Sum pooling function for sparse pattern graph data.

  Parameters
  ----------
  node_feats : ndarray of shape (N, in_feats)
    Batch input node features. N is the total number of nodes in the batch.
  graph_idx : ndarray of shape (N,)
    Graph index for each node in the batch. Nodes with the same graph index
    belong to the same graph.

  Returns
  -------
  ndarray of shape (batch_size, in_feats)
    Batch graph features.
  """
  # graph_idx is assumed to be sorted, so its last entry is the largest
  # graph index in the batch.
  batch_size = graph_idx[-1] + 1
  return segment_sum(node_feats, graph_idx, num_segments=batch_size)
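# A hedged usage sketch for the two pooling helpers above, with made-up
# features and a sorted graph_idx (so graph_idx[-1] + 1 equals the number of
# graphs in the batch).
import jax.numpy as jnp

node_feats = jnp.array([[1., 1.], [3., 5.], [2., 4.]])
graph_idx = jnp.array([0, 0, 1])  # two nodes in graph 0, one in graph 1
print(sparse_sum_pooling(node_feats, graph_idx))   # [[4. 6.] [2. 4.]]
print(sparse_mean_pooling(node_feats, graph_idx))  # [[2. 3.] [2. 4.]]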
def g_fn(R, neighbor):
  N, dim = R.shape
  mask = partition.neighbor_list_mask(neighbor)
  if neighbor.format is partition.Dense:
    R_neigh = R[neighbor.idx]
    d = space.map_neighbor(metric)
    _pairwise = vmap(vmap(pairwise, (0, None)), (0, None))
    return jnp.sum(mask[:, :, None] * _pairwise(d(R, R_neigh), dim),
                   axis=(1,))
  elif neighbor.format is partition.Sparse:
    dr = space.map_bond(metric)(R[neighbor.idx[0]], R[neighbor.idx[1]])
    _pairwise = vmap(pairwise, (0, None))
    return ops.segment_sum(mask[:, None] * _pairwise(dr, dim),
                           neighbor.idx[0], N)
  else:
    raise NotImplementedError('Pair correlation function does not support '
                              'OrderedSparse neighbor lists.')
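# A hedged toy sketch (made-up data) of why padding is harmless in the sparse
# branch above: padded sparse neighbor entries point at index N, and
# segment_sum with num_segments=N drops out-of-range ids, while real entries
# accumulate onto their central particle.
import jax.numpy as jnp
from jax import ops

N = 3
senders = jnp.array([0, 1, 1, N])           # last entry is padding
contrib = jnp.array([.5, .25, .25, 99.])    # padded value must not leak
print(ops.segment_sum(contrib, senders, N))  # [0.5 0.5 0. ]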
def to_dense(neighbor: NeighborList) -> Array:
  """Converts a sparse neighbor list to dense ids. Cannot be used under `jit`."""
  if neighbor.format is not Sparse:
    raise ValueError('Can only convert sparse neighbor lists to dense ones.')

  receivers, senders = neighbor.idx
  mask = neighbor_list_mask(neighbor)

  receivers = receivers[mask]
  senders = senders[mask]

  N = len(neighbor.reference_position)
  count = ops.segment_sum(jnp.ones(len(receivers), i32), receivers, N)
  max_count = jnp.max(count)
  offset = jnp.tile(jnp.arange(max_count), N)[:len(senders)]
  hashes = senders * max_count + offset
  dense_idx = N * jnp.ones((N * max_count,), i32)
  dense_idx = dense_idx.at[hashes].set(receivers).reshape((N, max_count))
  return dense_idx
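# A hedged toy illustration of the counting step above (made-up indices): the
# per-particle neighbor counts from segment_sum determine the width of the
# dense neighbor table.
import jax.numpy as jnp
from jax import ops

receivers = jnp.array([0, 0, 1, 2, 2, 2])
count = ops.segment_sum(jnp.ones(6, jnp.int32), receivers, 3)  # [2 1 3]
max_count = int(jnp.max(count))  # dense table has shape (N, 3)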
def sym_fn(R: Array,
           neighbor: NeighborList,
           mask: Array = None,
           **kwargs) -> Array:
  _metric = partial(metric, **kwargs)

  if neighbor.format is partition.Dense:
    _metric = space.map_neighbor(_metric)
    R_neigh = R[neighbor.idx]
    mask = True if mask is None else mask[neighbor.idx]
    mask = (neighbor.idx < R.shape[0])[None, :, :] & mask
    dr = _metric(R, R_neigh)
    return util.high_precision_sum(radial_fn(etas, dr) * mask, axis=2).T
  elif neighbor.format is partition.Sparse:
    _metric = space.map_bond(_metric)
    dr = _metric(R[neighbor.idx[0]], R[neighbor.idx[1]])
    radial = radial_fn(etas, dr).T
    N = R.shape[0]
    mask = True if mask is None else mask[neighbor.idx[1]]
    mask = (neighbor.idx[0] < N) & mask
    return ops.segment_sum(radial * mask[:, None], neighbor.idx[0], N)
  else:
    raise ValueError('Unsupported neighbor list format.')
def cell_list_fn(position: Array,
                 capacity_overflow_update: Optional[
                     Tuple[int, bool, Callable[..., CellList]]] = None,
                 extra_capacity: int = 0,
                 **kwargs) -> CellList:
  N = position.shape[0]
  dim = position.shape[1]

  if dim != 2 and dim != 3:
    # NOTE(schsam): Do we want to check this in compute_fn as well?
    raise ValueError(
        f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

  _, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)

  if capacity_overflow_update is None:
    cell_capacity = _estimate_cell_capacity(position, box_size, cell_size,
                                            buffer_size_multiplier)
    cell_capacity += extra_capacity
    overflow = False
    update_fn = cell_list_fn
  else:
    cell_capacity, overflow, update_fn = capacity_overflow_update

  hash_multipliers = _compute_hash_constants(dim, cells_per_side)

  # Create cell list data.
  particle_id = lax.iota(i32, N)
  # NOTE(schsam): We use the convention that particles that are successfully
  # copied have their true id, whereas empty slots have id = N. Then, when we
  # copy data back from the grid, we copy it to an array of shape
  # [N + 1, output_dimension] and truncate it to an array of shape
  # [N, output_dimension], which ignores the empty slots.
  cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                            dtype=position.dtype)
  cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

  # It might be worth adding an occupied mask. However, that will involve
  # more compute since often we will do a mask for species that will include
  # an occupancy test. It seems easier to design around this empty_data_value
  # for now and revisit the issue if it comes up later.
  empty_kwarg_value = 10 ** 5
  cell_kwargs = {}
  # pytype: disable=attribute-error
  for k, v in kwargs.items():
    if not util.is_array(v):
      raise ValueError((f'Data must be specified as an ndarray. Found "{k}" '
                        f'with type {type(v)}.'))
    if v.shape[0] != position.shape[0]:
      raise ValueError(('Data must be specified per-particle (an ndarray '
                        f'with shape ({N}, ...)). Found "{k}" with '
                        f'shape {v.shape}.'))
    kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
    cell_kwargs[k] = empty_kwarg_value * jnp.ones(
        (cell_count * cell_capacity,) + kwarg_shape, v.dtype)
  # pytype: enable=attribute-error

  indices = jnp.array(position / cell_size, dtype=i32)
  hashes = jnp.sum(indices * hash_multipliers, axis=1)

  # Copy the particle data into the grid. Here we use a trick to allow us to
  # copy into all cells simultaneously using a single lax.scatter call. To do
  # this we first sort particles by their cell hash. We then assign each
  # particle a cell id = hash * cell_capacity + grid_id, where grid_id is a
  # flat list that repeats 0, ..., cell_capacity - 1. So long as there are
  # fewer than cell_capacity particles per cell, each particle is guaranteed
  # to get a cell id that is unique.
  sort_map = jnp.argsort(hashes)
  sorted_position = position[sort_map]
  sorted_hash = hashes[sort_map]
  sorted_id = particle_id[sort_map]

  sorted_kwargs = {}
  for k, v in kwargs.items():
    sorted_kwargs[k] = v[sort_map]

  sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
  sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

  cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
  sorted_id = jnp.reshape(sorted_id, (N, 1))
  cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
  cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim)
  cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

  for k, v in sorted_kwargs.items():
    if v.ndim == 1:
      v = jnp.reshape(v, v.shape + (1,))
    cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
    cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k], cells_per_side,
                                            dim)

  occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
  max_occupancy = jnp.max(occupancy)
  overflow = overflow | (max_occupancy >= cell_capacity)

  return CellList(cell_position, cell_id, cell_kwargs, overflow,
                  cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
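# A hedged toy illustration of the unique-slot trick described in the comment
# above (made-up numbers): after sorting by hash, particles in the same cell
# are contiguous, so hash * capacity + (iota % capacity) gives each particle
# its own slot as long as no cell holds more than `capacity` particles.
import jax.numpy as jnp
from jax import lax

capacity = 2
sorted_hash = jnp.array([0, 0, 1, 3])             # cell hashes after argsort
slot = jnp.mod(lax.iota(jnp.int32, 4), capacity)  # [0 1 0 1]
cell_slot_id = sorted_hash * capacity + slot      # [0 1 2 7], all unique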
def __call__(self,
             node_feats: jnp.ndarray,
             adj: jnp.ndarray,
             is_training: bool) -> jnp.ndarray:
  """Update node features.

  Parameters
  ----------
  node_feats : ndarray of shape (N, in_feats)
    Batch input node features. N is the total number of nodes in the batch.
  adj : ndarray of shape (2, E)
    Batch adjacency list. E is the total number of edges in the batch.
  is_training : bool
    Whether the model is training or not.

  Returns
  -------
  new_node_feats : ndarray of shape (N, out_feats)
    Batch new node features.
  """
  dropout = self.dropout if is_training else 0.0
  num_nodes = node_feats.shape[0]

  # affine transformation
  new_node_feats = jnp.dot(node_feats, self.w)
  if self.bias:
    new_node_feats += self.b

  # update nodes
  if self.normalize:
    # add self connections
    self_loop = jnp.tile(jnp.arange(num_nodes), (2, 1))
    adj = jnp.concatenate((adj, self_loop), axis=1)
    src_idx, dest_idx = adj[0], adj[1]

    # calculate the symmetric normalization 1 / sqrt(deg(src) * deg(dest))
    degree = segment_sum(jnp.ones(len(dest_idx)), dest_idx,
                         num_segments=num_nodes)
    deg_inv_sqrt = jax.lax.pow(degree, -0.5)
    norm = deg_inv_sqrt[src_idx] * deg_inv_sqrt[dest_idx]

    # update nodes
    source_feats = jnp.take(new_node_feats, src_idx, axis=0)
    source_feats = norm.reshape(-1, 1) * source_feats
    new_node_feats = segment_sum(source_feats, dest_idx,
                                 num_segments=num_nodes)
  else:
    src_idx, dest_idx = adj[0], adj[1]
    source_feats = jnp.take(new_node_feats, src_idx, axis=0)
    aggregated_messages = segment_sum(source_feats, dest_idx,
                                      num_segments=num_nodes)
    new_node_feats = jnp.add(aggregated_messages, new_node_feats)

  new_node_feats = self.activation(new_node_feats)

  if dropout != 0.0:
    new_node_feats = hk.dropout(hk.next_rng_key(), dropout, new_node_feats)

  if self.batch_norm:
    new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats,
                                                   is_training)
  return new_node_feats
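# A hedged toy check of the symmetric normalization above (made-up graph):
# each message from src to dest is scaled by 1 / sqrt(deg(src) * deg(dest)),
# with degrees computed by segment_sum over the destination index.
import jax.numpy as jnp
from jax.ops import segment_sum

adj = jnp.array([[0, 1, 1, 2],   # src
                 [1, 0, 2, 1]])  # dest
num_nodes = 3
degree = segment_sum(jnp.ones(adj.shape[1]), adj[1], num_segments=num_nodes)
# degree == [1., 2., 1.]
deg_inv_sqrt = degree ** -0.5
norm = deg_inv_sqrt[adj[0]] * deg_inv_sqrt[adj[1]]
# norm[2] == 1 / sqrt(deg(1) * deg(2)) == 1 / sqrt(2)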
def compute_centroids(self, coords):
  """Returns an array containing the centroids of each group."""
  return segment_sum(coords, self.scatter_inds) / jnp.expand_dims(
      self.group_sizes, axis=1)
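# A hedged toy version of the centroid computation above, with made-up groups:
# per-group coordinate sums divided by per-group sizes.
import jax.numpy as jnp
from jax.ops import segment_sum

coords = jnp.array([[0., 0.], [2., 2.], [4., 0.]])
scatter_inds = jnp.array([0, 0, 1])  # first two points form group 0
group_sizes = jnp.array([2., 1.])
centroids = segment_sum(coords, scatter_inds) / jnp.expand_dims(group_sizes, 1)
# centroids == [[1., 1.], [4., 0.]]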