Example #1
def _check_dtypes_match(xs, ys):
  """Asserts that corresponding leaves of two PyTrees have matching dtypes."""
  def _assert_dtypes_match(x, y):
    if config.x64_enabled:
      assert _dtype(x) == _dtype(y)
    else:
      # Without x64, compare after canonicalization (e.g. float64 -> float32).
      assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
              _dtypes.canonicalize_dtype(_dtype(y)))
  tree_all(tree_multimap(_assert_dtypes_match, xs, ys))
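Every example in this list is an instance of the same idiom: map a boolean predicate over the leaves of a PyTree with tree_map, then reduce the result with tree_all. A minimal self-contained sketch of the idiom (in current JAX, tree_multimap has been folded into jax.tree_util.tree_map, which accepts multiple trees):

import jax.numpy as jnp
from jax.tree_util import tree_all, tree_map

xs = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
ys = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}

# tree_map over two trees pairs their leaves positionally; tree_all
# reduces the resulting tree of booleans with all().
assert tree_all(tree_map(lambda x, y: x.dtype == y.dtype, xs, ys))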
Example #2
def is_on_cpu(x: PyTree) -> bool:
  def _arr_is_on_cpu(x: np.ndarray) -> bool:
    # TODO(romann): revisit when https://github.com/google/jax/issues/1431 and
    # https://github.com/google/jax/issues/1432 are fixed.
    if hasattr(x, 'device_buffer'):
      return 'cpu' in str(x.device_buffer.device()).lower()

    if isinstance(x, np.ndarray):
      return True

    raise NotImplementedError(type(x))

  return tree_all(tree_map(_arr_is_on_cpu, x))
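The device_buffer attribute probed above predates the unified jax.Array type, hence the TODO. A hypothetical modern equivalent of the same check, assuming a recent JAX release where every array exposes .devices() (the names here are illustrative, not the original helpers):

import jax
import numpy as onp
from jax.tree_util import tree_all, tree_map

def is_on_cpu_modern(tree) -> bool:
  def _leaf_on_cpu(x) -> bool:
    if isinstance(x, jax.Array):
      # A committed jax.Array reports the device(s) holding its buffers.
      return all(d.platform == 'cpu' for d in x.devices())
    if isinstance(x, onp.ndarray):
      return True  # Plain NumPy arrays always live in host memory.
    raise NotImplementedError(type(x))
  return tree_all(tree_map(_leaf_on_cpu, tree))

assert is_on_cpu_modern({'a': onp.zeros(3)})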
Example #3
def _is_np_ndarray(x):
    return tree_all(tree_map(lambda y: isinstance(y, np.ndarray), x))
Example #4
def all_none(x: Any, attr: Optional[str] = None) -> bool:
  get_fn = (lambda x: x) if attr is None else lambda x: getattr(x, attr)
  return tree_all(tree_map(lambda x: get_fn(x) is None, x))
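A hypothetical usage sketch (the definition is restated so the snippet runs standalone; the toy Kernel class and its ntk field are illustrative only):

from typing import Any, Optional
from jax.tree_util import tree_all, tree_map

def all_none(x: Any, attr: Optional[str] = None) -> bool:
  get_fn = (lambda v: v) if attr is None else (lambda v: getattr(v, attr))
  return tree_all(tree_map(lambda v: get_fn(v) is None, x))

class Kernel:
  # Illustrative leaf object with an optional field.
  def __init__(self, ntk):
    self.ntk = ntk

# Leaves are the Kernel objects; the predicate inspects their .ntk field.
assert all_none([Kernel(None), Kernel(None)], attr='ntk')
assert not all_none([Kernel(None), Kernel(1.0)], attr='ntk')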
Example #5
def _is_np_ndarray(x) -> bool:
    if x is None:
        return False
    return tree_all(
        tree_map(lambda y: isinstance(y, (onp.ndarray, np.ndarray)), x))
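This variant accepts both kinds of arrays; presumably onp is original NumPy and np is jax.numpy, the aliasing convention used in neural_tangents. A quick standalone check under that assumption:

import numpy as onp
import jax.numpy as np
from jax.tree_util import tree_all, tree_map

def _is_np_ndarray(x) -> bool:
    if x is None:
        return False
    return tree_all(
        tree_map(lambda y: isinstance(y, (onp.ndarray, np.ndarray)), x))

# np.ndarray matches JAX arrays, onp.ndarray matches host NumPy arrays.
assert _is_np_ndarray({'a': onp.zeros(2), 'b': np.zeros(2)})
assert not _is_np_ndarray(None)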
Example #6
def _is_on_cpu(x):
    # Utility function from neural_tangents
    return tree_all(tree_map(_arr_is_on_cpu, x))
Example #7
    def _ApplyGraphNet(
            graph: ShardedEdgesGraphsTuple) -> ShardedEdgesGraphsTuple:
        """Applies a configured GraphNetwork to a sharded graph.

    This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

    There is one difference. For the nodes update the class aggregates over the
    sender edges and receiver edges separately. This is a bit more general
    the algorithm described in the paper. The original behaviour can be
    recovered by using only the receiver edge aggregations for the update.

    In addition this implementation supports softmax attention over incoming
    edge features.


    Many popular Graph Neural Networks can be implemented as special cases of
    GraphNets, for more information please see the paper.

    Args:
      graph: a `GraphsTuple` containing the graph.

    Returns:
      Updated `GraphsTuple`.
    """
        # pylint: disable=g-long-lambda
        nodes, device_edges, device_receivers, device_senders, receivers, senders, globals_, device_n_edge, n_node, n_edge, device_graph_idx = graph
        # Equivalent to jnp.sum(n_node), but jittable.
        sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
        sum_device_n_edge = device_senders.shape[0]
        if not tree.tree_all(
                tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
            raise ValueError(
                'All node arrays in nest must contain the same number of nodes.'
            )

        sent_attributes = tree.tree_map(lambda n: n[device_senders], nodes)
        received_attributes = tree.tree_map(lambda n: n[device_receivers],
                                            nodes)
        # Here we scatter the global features to the corresponding edges,
        # giving us tensors of shape [num_edges, global_feat].
        global_edge_attributes = tree.tree_map(
            lambda g: jnp.repeat(g[device_graph_idx],
                                 device_n_edge,
                                 axis=0,
                                 total_repeat_length=sum_device_n_edge),
            globals_)

        if update_edge_fn:
            device_edges = update_edge_fn(device_edges, sent_attributes,
                                          received_attributes,
                                          global_edge_attributes)

        if attention_logit_fn:
            logits = attention_logit_fn(device_edges, sent_attributes,
                                        received_attributes,
                                        global_edge_attributes)
            tree_calculate_weights = functools.partial(utils.segment_softmax,
                                                       segment_ids=receivers,
                                                       num_segments=sum_n_node)
            weights = tree.tree_map(tree_calculate_weights, logits)
            device_edges = attention_reduce_fn(device_edges, weights)

        if update_node_fn:
            # Aggregations over nodes are assumed to take place over devices
            # specified by the axis_groups (e.g. with sharded_segment_sum).
            sent_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(
                    e, device_senders, sum_n_node, axis_groups), device_edges)
            received_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(
                    e, device_receivers, sum_n_node, axis_groups),
                device_edges)
            # Here we scatter the global features to the corresponding nodes,
            # giving us tensors of shape [num_nodes, global_feat].
            global_attributes = tree.tree_map(
                lambda g: jnp.repeat(
                    g, n_node, axis=0, total_repeat_length=sum_n_node),
                globals_)
            nodes = update_node_fn(nodes, sent_attributes, received_attributes,
                                   global_attributes)

        if update_global_fn:
            n_graph = n_node.shape[0]
            graph_idx = jnp.arange(n_graph)
            # To aggregate nodes and edges from each graph to global features,
            # we first construct tensors that map the node to the corresponding graph.
            # For example, if you have `n_node=[1,2]`, we construct the tensor
            # [0, 1, 1]. We then do the same for edges.
            node_gr_idx = jnp.repeat(graph_idx,
                                     n_node,
                                     axis=0,
                                     total_repeat_length=sum_n_node)
            edge_gr_idx = jnp.repeat(device_graph_idx,
                                     device_n_edge,
                                     axis=0,
                                     total_repeat_length=sum_device_n_edge)
            # We use the aggregation function to pool the nodes/edges per graph.
            node_attributes = tree.tree_map(
                lambda n: aggregate_nodes_for_globals_fn(
                    n, node_gr_idx, n_graph), nodes)
            edge_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_globals_fn(
                    e, edge_gr_idx, n_graph, axis_groups), device_edges)
            # These pooled nodes and edges are the inputs to the global update fn.
            globals_ = update_global_fn(node_attributes, edge_attributes,
                                        globals_)
        # pylint: enable=g-long-lambda
        return ShardedEdgesGraphsTuple(nodes=nodes,
                                       device_edges=device_edges,
                                       device_senders=device_senders,
                                       device_receivers=device_receivers,
                                       receivers=receivers,
                                       senders=senders,
                                       device_graph_idx=device_graph_idx,
                                       globals=globals_,
                                       n_node=n_node,
                                       n_edge=n_edge,
                                       device_n_edge=device_n_edge)
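tree_all appears in this long function only as the shape guard near the top. That guard stands on its own; a runnable sketch of just that check, with tree aliased to jax.tree_util as the snippet's usage suggests:

import jax.numpy as jnp
from jax import tree_util as tree

nodes = {'h': jnp.ones((5, 16)), 'pos': jnp.ones((5, 3))}
# Leading dimension of the first leaf: a jittable stand-in for jnp.sum(n_node).
sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
if not tree.tree_all(
        tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
    raise ValueError(
        'All node arrays in nest must contain the same number of nodes.')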
Example #8
def is_array(x):
  return tree_all(tree_map(lambda x: isinstance(x, np.ndarray), x))
Example #9
def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
    assert_close = partial(_assert_numpy_close,
                           atol=atol,
                           rtol=rtol,
                           err_msg=err_msg)
    tree_all(tree_multimap(assert_close, xs, ys))
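_assert_numpy_close is internal to JAX's test utilities. A self-contained sketch of the same check built on numpy.testing.assert_allclose as a stand-in (the atol/rtol defaults below are assumptions, not the original's), using the multi-tree form of tree_map that replaced tree_multimap:

from functools import partial
from jax.tree_util import tree_all, tree_map
from numpy.testing import assert_allclose

def check_close_sketch(xs, ys, atol=0.0, rtol=1e-7, err_msg=''):
    assert_close = partial(assert_allclose,
                           atol=atol, rtol=rtol, err_msg=err_msg)
    # The assertion fires on every leaf pair during the eager tree_map;
    # tree_all just mirrors the original idiom.
    tree_all(tree_map(assert_close, xs, ys))

check_close_sketch({'a': [1.0, 2.0]}, {'a': [1.0, 2.0 + 1e-9]})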
Example #10
def check_eq(xs, ys, err_msg=''):
    assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
    tree_all(tree_multimap(assert_close, xs, ys))
Example #11
def _is_on_cpu(x):
    return tree_all(tree_map(_arr_is_on_cpu, x))
Example #12
def all_none(x, attr: Optional[str] = None) -> bool:
    get_fn = (lambda x: x) if attr is None else lambda x: getattr(x, attr)
    return tree_all(tree_map(lambda x: get_fn(x) is None, x))