Example #1
0
    def model(R, neighbor, **kwargs):
        """Run the graph network on the neighbor-list graph of ``R``.

        Args:
          R: Particle positions; ``R.shape[0]`` is used as the particle count.
          neighbor: Neighbor list whose ``idx`` array is used to gather the
            neighbor positions out of ``R``.
          **kwargs: Forwarded in full to ``displacement_fn`` (including a
            ``'nodes'`` entry, if one is present). A ``'nodes'`` entry also
            replaces the node state captured from the enclosing scope.

        Returns:
          The output of the ``network`` module applied to the assembled
          ``nn.GraphTuple``.
        """
        num_particles = R.shape[0]

        # Displacement from each particle to each of its listed neighbors.
        neighbor_displacement = space.map_neighbor(
            partial(displacement_fn, **kwargs))
        dR = neighbor_displacement(R, R[neighbor.idx])

        # Per-call node state wins; otherwise use the captured ``nodes``, and
        # failing that a single zero feature per particle.
        if 'nodes' not in kwargs:
            node_state = (nodes if nodes is not None
                          else jnp.zeros((num_particles, 1), R.dtype))
        else:
            node_state = _canonicalize_node_state(kwargs['nodes'])

        global_state = jnp.zeros((1, ), R.dtype)

        # Neighbors outside the cutoff are redirected to index
        # ``num_particles`` (one past the last particle), presumably read as
        # "no edge" by the network -- confirm against ``network``.
        in_range = space.square_distance(dR) < r_cutoff**2
        edge_idx = jnp.where(in_range, neighbor.idx, num_particles)

        net = network(n_recurrences=n_recurrences,
                      mlp_sizes=mlp_sizes,
                      mlp_kwargs=mlp_kwargs,
                      polymer_length=polymer_length,
                      polymer_dimensions=polymer_dimensions)

        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)  # pytype: disable=wrong-arg-count
        return net(graph)
Example #2
0
def _get_graphs():
    """Return two small two-node test graphs.

    Both graphs carry identical node, edge and global features. They differ
    only in the second row of ``edge_idx``: ``[0, 1]`` in the first graph and
    ``[2, 1]`` in the second. The value ``2`` is out of range for the two
    nodes, presumably to exercise padded/masked edges.
    """

    def make_graph(second_row):
        # Build fresh arrays on every call so the two graphs never alias.
        return nn.GraphTuple(
            nodes=np.array([[1.0], [2.0]]),
            edges=np.array([[[1.0], [2.0]], [[3.0], [4.0]]]),
            globals=np.array([1.0]),
            edge_idx=np.array([[0, 1], second_row]),
        )

    return [make_graph([0, 1]), make_graph([2, 1])]
Example #3
0
    def model(R, neighbor, **kwargs):
        """Run ``EnergyGraphNet`` on the neighbor-list graph of ``R``.

        Args:
          R: Particle positions; ``R.shape[0]`` is used as the particle count.
          neighbor: Neighbor list whose ``idx`` array is used to gather the
            neighbor positions out of ``R``.
          **kwargs: Forwarded in full to ``displacement_fn`` (including a
            ``'nodes'`` entry, if one is present). A ``'nodes'`` entry also
            replaces the node state captured from the enclosing scope.

        Returns:
          The output of ``EnergyGraphNet`` applied to the assembled
          ``nn.GraphTuple``.
        """
        num_particles = R.shape[0]

        # Displacement from each particle to each of its listed neighbors.
        neighbor_displacement = space.map_neighbor(
            partial(displacement_fn, **kwargs))
        dR = neighbor_displacement(R, R[neighbor.idx])

        # Per-call node state wins; otherwise use the captured ``nodes``, and
        # failing that a single zero feature per particle.
        if 'nodes' not in kwargs:
            node_state = (nodes if nodes is not None
                          else np.zeros((num_particles, 1), R.dtype))
        else:
            node_state = _canonicalize_node_state(kwargs['nodes'])

        global_state = np.zeros((1, ), R.dtype)

        # Neighbors outside the cutoff are redirected to index
        # ``num_particles`` (one past the last particle), presumably read as
        # "no edge" by the network -- confirm against ``EnergyGraphNet``.
        in_range = space.square_distance(dR) < r_cutoff**2
        edge_idx = np.where(in_range, neighbor.idx, num_particles)

        net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)
        return net(graph)
Example #4
0
    def model(R: Array, **kwargs) -> Array:
        """Run ``EnergyGraphNet`` on the dense (all-pairs) graph of ``R``.

        Every ordered pair of particles, including each particle with itself,
        starts as a candidate edge. Pairs whose squared distance is not below
        ``r_cutoff**2`` are masked out.

        Args:
          R: Particle positions; ``R.shape[0]`` is used as the particle count.
          **kwargs: Forwarded in full to ``displacement_fn`` (including a
            ``'nodes'`` entry, if one is present). A ``'nodes'`` entry also
            replaces the node state captured from the enclosing scope.

        Returns:
          The output of ``EnergyGraphNet`` applied to the assembled
          ``nn.GraphTuple``.
        """
        num_particles = R.shape[0]

        # Displacement between every ordered pair of particles.
        pairwise_displacement = space.map_product(
            partial(displacement_fn, **kwargs))
        dR = pairwise_displacement(R, R)
        within_cutoff = space.square_distance(dR) < r_cutoff**2

        # Per-call node state wins; otherwise use the captured ``nodes``, and
        # failing that a single zero feature per particle.
        if 'nodes' not in kwargs:
            node_state = (nodes if nodes is not None
                          else np.zeros((num_particles, 1), R.dtype))
        else:
            node_state = _canonicalize_node_state(kwargs['nodes'])

        # Row i lists every particle index 0..N-1 as a candidate neighbor of
        # particle i. Out-of-range pairs are replaced by index N (one past the
        # last particle), presumably read as "no edge" by the network.
        all_columns = np.broadcast_to(np.arange(num_particles),
                                      (num_particles, num_particles))
        edge_idx = np.where(within_cutoff, all_columns, num_particles)

        global_state = np.zeros((1, ), R.dtype)

        net = EnergyGraphNet(n_recurrences, mlp_sizes, mlp_kwargs)
        graph = nn.GraphTuple(node_state, dR, global_state, edge_idx)  # pytype: disable=wrong-arg-count
        return net(graph)