def model(R, neighbor, **kwargs):
    """Run the graph network on a neighbor-list graph built from ``R``.

    Candidate edges come from ``neighbor.idx``. Any candidate whose squared
    displacement is not strictly below ``r_cutoff**2`` has its index replaced
    by ``N = R.shape[0]`` (one past the last particle). The network presumably
    treats that index as a masked edge — TODO confirm.

    Several names used below are not defined in this function: ``displacement_fn``,
    ``nodes``, ``r_cutoff``, ``network``, ``n_recurrences``, ``mlp_sizes``,
    ``mlp_kwargs``, ``polymer_length`` and ``polymer_dimensions``. They are
    presumably captured from an enclosing factory function.

    Args:
      R: Particle positions. ``R.shape[0]`` is used as the particle count.
      neighbor: Neighbor-list object. Only its ``idx`` attribute is read.
      **kwargs: Forwarded unchanged to ``displacement_fn`` (including any
        ``'nodes'`` entry). If ``'nodes'`` is present, it is also passed
        through ``_canonicalize_node_state`` and used as the node features.

    Returns:
      Whatever the network returns when applied to the assembled
      ``nn.GraphTuple``.
    """
    n_particles = R.shape[0]

    # Displacement from each particle to each of its listed neighbors.
    neighbor_displacement = space.map_neighbor(partial(displacement_fn, **kwargs))
    dR = neighbor_displacement(R, R[neighbor.idx])

    # Node features: an explicit kwarg wins, then the captured `nodes`,
    # otherwise a single zero feature per particle.
    if 'nodes' in kwargs:
        node_state = _canonicalize_node_state(kwargs['nodes'])
    elif nodes is None:
        node_state = jnp.zeros((n_particles, 1), R.dtype)
    else:
        node_state = nodes
    global_state = jnp.zeros((1, ), R.dtype)

    # Keep neighbors strictly inside the cutoff; redirect the rest to the
    # out-of-range index `n_particles`.
    within_cutoff = space.square_distance(dR) < r_cutoff**2
    edge_idx = jnp.where(within_cutoff, neighbor.idx, n_particles)

    graph_net = network(
        n_recurrences=n_recurrences,
        mlp_sizes=mlp_sizes,
        mlp_kwargs=mlp_kwargs,
        polymer_length=polymer_length,
        polymer_dimensions=polymer_dimensions,
    )
    return graph_net(nn.GraphTuple(node_state, dR, global_state, edge_idx))  # pytype: disable=wrong-arg-count
def _get_graphs():
    """Return two small two-node ``nn.GraphTuple`` fixtures.

    Both graphs have the same node, edge and global features. They differ only
    in ``edge_idx``: in the second graph, row 1 starts with index 2, which
    equals the number of nodes. That presumably exercises the
    out-of-range/masked-edge path — TODO confirm against the tests using it.
    """

    def _make(edge_idx_rows):
        # Build fresh arrays on every call so the two fixtures never share
        # (and therefore never alias) any array.
        return nn.GraphTuple(
            nodes=np.array([[1.0], [2.0]]),
            edges=np.array([[[1.0], [2.0]], [[3.0], [4.0]]]),
            globals=np.array([1.0]),
            edge_idx=np.array(edge_idx_rows),
        )

    return [
        _make([[0, 1], [0, 1]]),
        _make([[0, 1], [2, 1]]),
    ]
def model(R, neighbor, **kwargs):
    """Run ``EnergyGraphNet`` on a neighbor-list graph built from ``R``.

    Candidate edges come from ``neighbor.idx``. Candidates whose squared
    displacement is not strictly below ``r_cutoff**2`` are redirected to index
    ``N = R.shape[0]`` (one past the last particle), presumably treated as
    masked edges by the network — TODO confirm.

    ``displacement_fn``, ``nodes``, ``r_cutoff``, ``n_recurrences``,
    ``mlp_sizes`` and ``mlp_kwargs`` are not defined here. They are presumably
    captured from an enclosing factory function.

    Args:
      R: Particle positions. ``R.shape[0]`` is used as the particle count.
      neighbor: Neighbor-list object. Only its ``idx`` attribute is read.
      **kwargs: Forwarded unchanged to ``displacement_fn``. An optional
        ``'nodes'`` entry is also canonicalized and used as node features.

    Returns:
      The output of ``EnergyGraphNet`` on the assembled ``nn.GraphTuple``.
    """
    n_particles = R.shape[0]

    # Per-neighbor displacement vectors.
    neighbor_displacement = space.map_neighbor(partial(displacement_fn, **kwargs))
    dR = neighbor_displacement(R, R[neighbor.idx])

    # Priority for node features: kwarg, then captured `nodes`, then zeros.
    if 'nodes' in kwargs:
        node_state = _canonicalize_node_state(kwargs['nodes'])
    elif nodes is None:
        node_state = np.zeros((n_particles, 1), R.dtype)
    else:
        node_state = nodes
    global_state = np.zeros((1, ), R.dtype)

    # Mask out-of-cutoff neighbors by pointing them at `n_particles`.
    within_cutoff = space.square_distance(dR) < r_cutoff**2
    edge_idx = np.where(within_cutoff, neighbor.idx, n_particles)

    energy_net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
    return energy_net(nn.GraphTuple(node_state, dR, global_state, edge_idx))
def model(R: Array, **kwargs) -> Array:
    """Run ``EnergyGraphNet`` on a dense all-pairs graph built from ``R``.

    Every particle ``j`` is a candidate neighbor of every particle ``i``. Pairs
    whose squared displacement is not strictly below ``r_cutoff**2`` get edge
    index ``N = R.shape[0]`` (one past the last particle), presumably treated
    as masked edges by the network — TODO confirm.

    NOTE(review): the diagonal pair ``(i, i)`` is kept whenever its squared
    distance is below the cutoff. That is presumably always the case if
    ``displacement_fn(x, x)`` is zero, so self-edges appear in this dense
    graph. Confirm this is intended.

    ``displacement_fn``, ``nodes``, ``r_cutoff``, ``n_recurrences``,
    ``mlp_sizes`` and ``mlp_kwargs`` are not defined here. They are presumably
    captured from an enclosing factory function.

    Args:
      R: Particle positions. ``R.shape[0]`` is used as the particle count.
      **kwargs: Forwarded unchanged to ``displacement_fn``. An optional
        ``'nodes'`` entry is also canonicalized and used as node features.

    Returns:
      The output of ``EnergyGraphNet`` on the assembled ``nn.GraphTuple``.
    """
    n_particles = R.shape[0]

    # All-pairs displacements and their squared norms.
    pair_displacement = space.map_product(partial(displacement_fn, **kwargs))
    dR = pair_displacement(R, R)
    dr_2 = space.square_distance(dR)

    # Node features: kwarg first, then captured `nodes`, then one zero feature.
    if 'nodes' in kwargs:
        node_state = _canonicalize_node_state(kwargs['nodes'])
    elif nodes is None:
        node_state = np.zeros((n_particles, 1), R.dtype)
    else:
        node_state = nodes

    # Row i lists receivers 0..N-1 (a 1-D arange broadcast to N rows); pairs
    # outside the cutoff are redirected to `n_particles`.
    candidate_idx = np.broadcast_to(np.arange(n_particles), (n_particles, n_particles))
    edge_idx = np.where(dr_2 < r_cutoff**2, candidate_idx, n_particles)

    global_state = np.zeros((1, ), R.dtype)

    energy_net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
    return energy_net(nn.GraphTuple(node_state, dR, global_state, edge_idx))  # pytype: disable=wrong-arg-count