Beispiel #1
0
    def model(R, neighbor, **kwargs):
        """Evaluate the graph network on a neighbor-list graph built from R.

        `displacement_fn`, `nodes`, `r_cutoff`, `network` and the network
        hyper-parameters are read from the enclosing scope, which is not
        visible in this snippet.

        Args:
          R: Array whose leading axis enumerates the particles.
          neighbor: Object exposing an `idx` array; it is used both to gather
            neighbor rows from `R` and as the source of edge indices.
          **kwargs: Forwarded to `displacement_fn`. If it contains a 'nodes'
            entry, that entry (after `_canonicalize_node_state`) is used as the
            node features instead of the enclosing-scope `nodes`.

        Returns:
          Whatever the constructed network returns for the assembled
          `nn.GraphTuple`.
        """
        n_particles = R.shape[0]

        # Displacement between each particle and every neighbor listed for it.
        neighbor_displacement = space.map_neighbor(
            partial(displacement_fn, **kwargs))
        dR = neighbor_displacement(R, R[neighbor.idx])

        # Neighbor slots at or beyond the cutoff have their index replaced by
        # N. NOTE(review): presumably N is a padding/sentinel index that the
        # network masks out -- confirm against the network implementation.
        inside_cutoff = space.square_distance(dR) < r_cutoff**2
        edge_idx = jnp.where(inside_cutoff, neighbor.idx, n_particles)

        # Node features: call-time override first, then the enclosing-scope
        # default, and finally a zero column per particle.
        if 'nodes' in kwargs:
            node_state = _canonicalize_node_state(kwargs['nodes'])
        elif nodes is None:
            node_state = jnp.zeros((n_particles, 1), R.dtype)
        else:
            node_state = nodes

        global_state = jnp.zeros((1, ), R.dtype)

        graph_net = network(n_recurrences=n_recurrences,
                            mlp_sizes=mlp_sizes,
                            mlp_kwargs=mlp_kwargs,
                            polymer_length=polymer_length,
                            polymer_dimensions=polymer_dimensions)

        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)  # pytype: disable=wrong-arg-count
        return graph_net(graph)
Beispiel #2
0
def single_pair_angular_symmetry_function(dR12, dR13, eta, lam, zeta,
                                          cutoff_distance):
    """Evaluate the angular symmetry-function term for one pair of neighbors.

    The result is the product of four factors:
    `2 ** (1 - zeta)`, `(1 + lam * angle_term) ** zeta`,
    `exp(-eta * (|dR12|^2 + |dR13|^2 + |dR23|^2))` and a cutoff factor for
    each of the three side lengths, where `dR23 = dR12 - dR13`.

    Args:
      dR12: Displacement from the central particle to the first neighbor.
      dR13: Displacement from the central particle to the second neighbor.
      eta: Coefficient of the summed squared distances in the exponential.
      lam: Coefficient multiplying the angular term.
      zeta: Exponent applied to the angular factor.
      cutoff_distance: Passed to `_behler_parrinello_cutoff_fn` for each side.

    Returns:
      The value of the symmetry-function term for this triplet.
    """
    dR23 = dR12 - dR13
    sides = (dR12, dR13, dR23)

    # Squared side lengths of the triangle formed by the three particles.
    sq12, sq13, sq23 = (space.square_distance(side) for side in sides)
    sum_of_squares = sq12 + sq13 + sq23

    # One cutoff factor per side, accumulated starting from 1.0.
    cutoff_product = 1.0
    for side in sides:
        cutoff_product = cutoff_product * _behler_parrinello_cutoff_fn(
            space.distance(side), cutoff_distance)

    prefactor = 2.0 ** (1.0 - zeta)
    # NOTE(review): the helper's return value is used directly as the angular
    # term; confirm whether it yields the angle itself or its cosine.
    angular = (
        1.0 + lam * quantity.angle_between_two_vectors(dR12, dR13)) ** zeta
    radial = np.exp(-eta * sum_of_squares)
    return prefactor * angular * radial * cutoff_product
Beispiel #3
0
def _displacement_or_metric_to_metric_sq(displacement_or_metric):
  """Checks whether or not a displacement or metric was provided."""
  for dim in range(1, 4):
    try:
      R = ShapedArray((dim,), f32)
      dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
      if len(dR_or_dr.shape) == 0:
        return lambda Ra, Rb, **kwargs: \
          displacement_or_metric(Ra, Rb, **kwargs) ** 2
      else:
        return lambda Ra, Rb, **kwargs: space.square_distance(
          displacement_or_metric(Ra, Rb, **kwargs))
    except TypeError:
      continue
    except ValueError:
      continue
  raise ValueError(
    'Canonicalize displacement not implemented for spatial dimension larger'
    'than 4.')
Beispiel #4
0
    def model(R, neighbor, **kwargs):
        """Evaluate an `EnergyGraphNet` on a neighbor-list graph built from R.

        `displacement_fn`, `nodes`, `r_cutoff`, `n_recurrences`, `mlp_sizes`
        and `mlp_kwargs` are read from the enclosing scope, which is not
        visible in this snippet.

        Args:
          R: Array whose leading axis enumerates the particles.
          neighbor: Object exposing an `idx` array; it is used both to gather
            neighbor rows from `R` and as the source of edge indices.
          **kwargs: Forwarded to `displacement_fn`. If it contains a 'nodes'
            entry, that entry (after `_canonicalize_node_state`) is used as the
            node features instead of the enclosing-scope `nodes`.

        Returns:
          Whatever the network returns for the assembled `nn.GraphTuple`.
        """
        n_particles = R.shape[0]

        # Displacement between each particle and every neighbor listed for it.
        neighbor_displacement = space.map_neighbor(
            partial(displacement_fn, **kwargs))
        dR = neighbor_displacement(R, R[neighbor.idx])

        # Neighbor slots at or beyond the cutoff have their index replaced by
        # N. NOTE(review): presumably N is a padding/sentinel index that the
        # network masks out -- confirm against the network implementation.
        inside_cutoff = space.square_distance(dR) < r_cutoff**2
        edge_idx = np.where(inside_cutoff, neighbor.idx, n_particles)

        # Node features: call-time override first, then the enclosing-scope
        # default, and finally a zero column per particle.
        if 'nodes' in kwargs:
            node_state = _canonicalize_node_state(kwargs['nodes'])
        elif nodes is None:
            node_state = np.zeros((n_particles, 1), R.dtype)
        else:
            node_state = nodes

        global_state = np.zeros((1, ), R.dtype)

        graph_net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)
        return graph_net(graph)
Beispiel #5
0
    def model(R: Array, **kwargs) -> Array:
        """Evaluate an `EnergyGraphNet` on a dense all-pairs graph built from R.

        `displacement_fn`, `nodes`, `r_cutoff`, `n_recurrences`, `mlp_sizes`
        and `mlp_kwargs` are read from the enclosing scope, which is not
        visible in this snippet.

        Args:
          R: Array whose leading axis enumerates the particles.
          **kwargs: Forwarded to `displacement_fn`. If it contains a 'nodes'
            entry, that entry (after `_canonicalize_node_state`) is used as the
            node features instead of the enclosing-scope `nodes`.

        Returns:
          Whatever the network returns for the assembled `nn.GraphTuple`.
        """
        n_particles = R.shape[0]

        # Displacements for every ordered pair of particles.
        pairwise_displacement = space.map_product(
            partial(displacement_fn, **kwargs))
        dR = pairwise_displacement(R, R)
        inside_cutoff = space.square_distance(dR) < r_cutoff**2

        # Start with every row listing all particle indices 0..N-1, then
        # replace the entries at or beyond the cutoff with N.
        # NOTE(review): presumably N is a padding/sentinel index that the
        # network masks out -- confirm against the network implementation.
        candidate_idx = np.broadcast_to(
            np.arange(n_particles)[np.newaxis, :],
            (n_particles, n_particles))
        edge_idx = np.where(inside_cutoff, candidate_idx, n_particles)

        # Node features: call-time override first, then the enclosing-scope
        # default, and finally a zero column per particle.
        if 'nodes' in kwargs:
            node_state = _canonicalize_node_state(kwargs['nodes'])
        elif nodes is None:
            node_state = np.zeros((n_particles, 1), R.dtype)
        else:
            node_state = nodes

        global_state = np.zeros((1, ), R.dtype)

        graph_net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)  # pytype: disable=wrong-arg-count
        return graph_net(graph)