def model(R, neighbor, **kwargs):
  """Evaluate the graph network on positions `R` using a neighbor list.

  Args:
    R: positions passed to `displacement_fn`; `R.shape[0]` is taken as the
      number of nodes N.
    neighbor: neighbor list object; only `neighbor.idx` is read.
    **kwargs: forwarded to `displacement_fn`. If a `nodes` entry is present it
      is canonicalized and used as the node features instead of the `nodes`
      captured from the enclosing scope.

  Returns:
    The output of the `network(...)` instance applied to the graph.
  """
  num_nodes = R.shape[0]

  # Bind per-call kwargs into the displacement, then vectorize it over the
  # neighbor axis so each row i is compared against R[neighbor.idx[i]].
  neighbor_displacement = space.map_neighbor(
      partial(displacement_fn, **kwargs))
  dR = neighbor_displacement(R, R[neighbor.idx])

  # Node features: per-call `nodes` win; otherwise use the closure's `nodes`,
  # or a single zero feature per node when none were supplied.
  if 'nodes' not in kwargs:
    _nodes = (nodes if nodes is not None
              else jnp.zeros((num_nodes, 1), R.dtype))
  else:
    _nodes = _canonicalize_node_state(kwargs['nodes'])
  _globals = jnp.zeros((1, ), R.dtype)

  # Neighbors outside the cutoff are replaced by the out-of-range index N,
  # which serves as the "no edge" sentinel.
  within_cutoff = space.square_distance(dR) < r_cutoff ** 2
  edge_idx = jnp.where(within_cutoff, neighbor.idx, num_nodes)

  net = network(n_recurrences=n_recurrences,
                mlp_sizes=mlp_sizes,
                mlp_kwargs=mlp_kwargs,
                polymer_length=polymer_length,
                polymer_dimensions=polymer_dimensions)
  return net(nn.GraphTuple(_nodes, dR, _globals, edge_idx))  # pytype: disable=wrong-arg-count
def single_pair_angular_symmetry_function(dR12, dR13, eta, lam, zeta,
                                          cutoff_distance):
  """Computes the angular symmetry function due to one pair of neighbors.

  For a central atom with neighbor displacements `dR12` and `dR13`, the third
  side of the triangle is `dR23 = dR12 - dR13`. The contribution is

    2**(1 - zeta) * (1 + lam * a)**zeta
      * exp(-eta * (r12**2 + r13**2 + r23**2))
      * fc(r12) * fc(r13) * fc(r23)

  where `a = quantity.angle_between_two_vectors(dR12, dR13)` (presumably the
  cosine of the enclosed angle; confirm against `quantity`) and `fc` is
  `_behler_parrinello_cutoff_fn(., cutoff_distance)`.

  Args:
    dR12: displacement to the first neighbor.
    dR13: displacement to the second neighbor.
    eta: width of the Gaussian on the summed squared side lengths.
    lam: coefficient multiplying the angular term.
    zeta: exponent applied to the angular term.
    cutoff_distance: cutoff passed to `_behler_parrinello_cutoff_fn`.

  Returns:
    The angular symmetry function value for this neighbor pair.
  """
  dR23 = dR12 - dR13

  r12 = space.distance(dR12)
  r13 = space.distance(dR13)
  r23 = space.distance(dR23)
  squared_perimeter_terms = (space.square_distance(dR12)
                             + space.square_distance(dR13)
                             + space.square_distance(dR23))

  # Product of the smooth cutoff over all three sides, folded left from 1.0.
  triplet_cutoff = 1.0
  for side in (r12, r13, r23):
    triplet_cutoff = triplet_cutoff * _behler_parrinello_cutoff_fn(
        side, cutoff_distance)

  normalization = 2.0 ** (1.0 - zeta)
  angular = (
      1.0 + lam * quantity.angle_between_two_vectors(dR12, dR13)) ** zeta
  radial = np.exp(-eta * squared_perimeter_terms)
  return normalization * angular * radial * triplet_cutoff
def _displacement_or_metric_to_metric_sq(displacement_or_metric): """Checks whether or not a displacement or metric was provided.""" for dim in range(1, 4): try: R = ShapedArray((dim,), f32) dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0) if len(dR_or_dr.shape) == 0: return lambda Ra, Rb, **kwargs: \ displacement_or_metric(Ra, Rb, **kwargs) ** 2 else: return lambda Ra, Rb, **kwargs: space.square_distance( displacement_or_metric(Ra, Rb, **kwargs)) except TypeError: continue except ValueError: continue raise ValueError( 'Canonicalize displacement not implemented for spatial dimension larger' 'than 4.')
def model(R, neighbor, **kwargs):
  """Apply an `EnergyGraphNet` to positions `R` using a neighbor list.

  Args:
    R: positions passed to `displacement_fn`; `R.shape[0]` is taken as the
      number of nodes N.
    neighbor: neighbor list object; only `neighbor.idx` is read.
    **kwargs: forwarded to `displacement_fn`. A `nodes` entry, if present, is
      canonicalized and overrides the `nodes` captured from the enclosing
      scope.

  Returns:
    The output of `EnergyGraphNet` for the assembled graph (presumably a
    scalar energy; confirm against `EnergyGraphNet`).
  """
  atom_count = R.shape[0]

  # Displacements from each node to each of its listed neighbors.
  displacement = space.map_neighbor(partial(displacement_fn, **kwargs))
  neighbor_positions = R[neighbor.idx]
  dR = displacement(R, neighbor_positions)

  if 'nodes' in kwargs:
    _nodes = _canonicalize_node_state(kwargs['nodes'])
  elif nodes is None:
    # No node features anywhere: one zero-valued feature per node.
    _nodes = np.zeros((atom_count, 1), R.dtype)
  else:
    _nodes = nodes
  _globals = np.zeros((1, ), R.dtype)

  # Neighbors beyond r_cutoff are redirected to the out-of-range index N,
  # used as the "no edge" sentinel.
  in_range = space.square_distance(dR) < r_cutoff ** 2
  edge_idx = np.where(in_range, neighbor.idx, atom_count)

  net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
  return net(nn.GraphTuple(_nodes, dR, _globals, edge_idx))
def model(R: Array, **kwargs) -> Array:
  """Apply an `EnergyGraphNet` to positions `R` over all node pairs.

  Every node is a candidate neighbor of every node (including itself); pairs
  whose squared distance is not below `r_cutoff ** 2` are masked out.

  Args:
    R: positions passed to `displacement_fn`; `R.shape[0]` is taken as the
      number of nodes N.
    **kwargs: forwarded to `displacement_fn`. A `nodes` entry, if present, is
      canonicalized and overrides the `nodes` captured from the enclosing
      scope.

  Returns:
    The output of `EnergyGraphNet` for the assembled graph.
  """
  node_count = R.shape[0]

  # Full N x N table of displacements between every pair of nodes.
  pairwise_displacement = space.map_product(
      partial(displacement_fn, **kwargs))
  dR = pairwise_displacement(R, R)
  squared_dist = space.square_distance(dR)

  if 'nodes' in kwargs:
    _nodes = _canonicalize_node_state(kwargs['nodes'])
  elif nodes is None:
    _nodes = np.zeros((node_count, 1), R.dtype)
  else:
    _nodes = nodes

  # Row i lists candidate neighbors 0..N-1; out-of-cutoff entries become the
  # out-of-range index N, used as the "no edge" sentinel.
  candidate_idx = np.broadcast_to(
      np.arange(node_count)[np.newaxis, :], (node_count, node_count))
  edge_idx = np.where(squared_dist < r_cutoff ** 2, candidate_idx, node_count)

  _globals = np.zeros((1, ), R.dtype)
  net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
  return net(nn.GraphTuple(_nodes, dR, _globals, edge_idx))  # pytype: disable=wrong-arg-count