示例#1
0
  def __init__(self,
               edge_model_fn,
               node_model_fn,
               reducer=tf.unsorted_segment_sum,
               name="interaction_network"):
    """Builds the edge block and node block of an InteractionNetwork.

    Args:
      edge_model_fn: Callable forwarded to `blocks.EdgeBlock` for the per-edge
        computation. It must return a Sonnet module (or an equivalent; see
        `blocks.EdgeBlock`). Except for the first and last axes, the output of
        that module must have the same shape as the input nodes.
      node_model_fn: Callable forwarded to `blocks.NodeBlock` for the per-node
        computation. It must return a Sonnet module (or an equivalent; see
        `blocks.NodeBlock`).
      reducer: Aggregation function the NodeBlock uses for received edges.
        Defaults to `tf.unsorted_segment_sum`.
      name: Name of the module.
    """
    super(InteractionNetwork, self).__init__(name=name)

    # Node block: sent edges and globals are disabled; received edges are
    # reduced with `reducer`.
    node_block_kwargs = dict(
        use_sent_edges=False,
        use_globals=False,
        received_edges_reducer=reducer)

    # Sub-blocks are built inside this module's variable scope.
    with self._enter_variable_scope():
      self._edge_block = blocks.EdgeBlock(
          use_globals=False, edge_model_fn=edge_model_fn)
      self._node_block = blocks.NodeBlock(
          node_model_fn=node_model_fn, **node_block_kwargs)
示例#2
0
  def __init__(self,
               edge_model_fn,
               global_model_fn,
               reducer=tf.unsorted_segment_sum,
               name="relation_network"):
    """Builds the edge block and global block of a RelationNetwork.

    Args:
      edge_model_fn: Callable forwarded to EdgeBlock for the per-edge
        computation. It must return a Sonnet module (or an equivalent; see
        EdgeBlock).
      global_model_fn: Callable forwarded to GlobalBlock for the per-global
        computation. It must return a Sonnet module (or an equivalent; see
        GlobalBlock).
      reducer: Aggregation function the GlobalBlock uses for edges. Defaults
        to `tf.unsorted_segment_sum`.
      name: Name of the module.
    """
    super(RelationNetwork, self).__init__(name=name)

    # Edge block: edge and global inputs disabled, sender and receiver node
    # inputs enabled.
    edge_block_kwargs = dict(
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False)
    # Global block: only edges enabled, reduced with `reducer`.
    global_block_kwargs = dict(
        use_edges=True,
        use_nodes=False,
        use_globals=False,
        edges_reducer=reducer)

    # Sub-blocks are built inside this module's variable scope.
    with self._enter_variable_scope():
      self._edge_block = blocks.EdgeBlock(
          edge_model_fn=edge_model_fn, **edge_block_kwargs)

      self._global_block = blocks.GlobalBlock(
          global_model_fn=global_model_fn, **global_block_kwargs)
示例#3
0
  def __init__(self,
               edge_model_fn,
               node_encoder_model_fn,
               node_model_fn,
               reducer=tf.unsorted_segment_sum,
               name="comm_net"):
    """Builds the three sub-blocks of a CommNet.

    Args:
      edge_model_fn: Callable forwarded to EdgeBlock. It must return a Sonnet
        module (or an equivalent; see EdgeBlock).
      node_encoder_model_fn: Callable forwarded to the NodeBlock that performs
        the initial encoding of the nodes. It must return a Sonnet module (or
        an equivalent; see NodeBlock). Except for the first and last
        dimensions, its output shape should match that of the module built by
        `edge_model_fn`.
      node_model_fn: Callable forwarded to NodeBlock. It must return a Sonnet
        module (or an equivalent; see NodeBlock).
      reducer: Reduction applied when aggregating edges into nodes; a callable
        with the same signature as `tf.unsorted_segment_sum`.
      name: Name of the module.
    """
    super(CommNet, self).__init__(name=name)

    # Flags shared by both node blocks: node features on, sent edges and
    # globals off.
    shared_node_kwargs = dict(
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)

    # Sub-blocks are built inside this module's variable scope.
    with self._enter_variable_scope():
      # $\Psi_{com}(x_j)$ from Eq. (2) of arXiv:1706.06122 -- uses only the
      # sender node of each edge.
      self._edge_block = blocks.EdgeBlock(
          edge_model_fn=edge_model_fn,
          use_edges=False,
          use_receiver_nodes=False,
          use_sender_nodes=True,
          use_globals=False)
      # $\Phi(x_i)$ from Eq. (2) of arXiv:1706.06122 -- received edges are
      # disabled here, so the reducer passed along is not exercised by this
      # block's flags.
      self._node_encoder_block = blocks.NodeBlock(
          node_model_fn=node_encoder_model_fn,
          use_received_edges=False,
          name="node_encoder_block",
          **shared_node_kwargs)
      # $\Theta(..)$ from Eq. (2) of arXiv:1706.06122 -- combines each node
      # with its reduced received edges.
      self._node_block = blocks.NodeBlock(
          node_model_fn=node_model_fn,
          use_received_edges=True,
          **shared_node_kwargs)
示例#4
0
  def __init__(self,
               edge_model_fn,
               node_model_fn,
               global_model_fn,
               reducer=tf.unsorted_segment_sum,
               edge_block_opt=None,
               node_block_opt=None,
               global_block_opt=None,
               name="graph_network"):
    """Initializes the GraphNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see EdgeBlock for details).
      node_model_fn: A callable that will be passed to NodeBlock to perform
        per-node computations. The callable must return a Sonnet module (or
        equivalent; see NodeBlock for details).
      global_model_fn: A callable that will be passed to GlobalBlock to perform
        per-global computations. The callable must return a Sonnet module (or
        equivalent; see GlobalBlock for details).
      reducer: Reducer used by NodeBlock and GlobalBlock to aggregate nodes
        and edges. Defaults to tf.unsorted_segment_sum. It is overridden by
        any reducers specified in `node_block_opt` and `global_block_opt`.
      edge_block_opt: Additional options to be passed to the EdgeBlock. Can
        contain the keys `use_edges`, `use_receiver_nodes`,
        `use_sender_nodes`, `use_globals`. By default, these are all True.
      node_block_opt: Additional options to be passed to the NodeBlock. Can
        contain the keys `use_received_edges`, `use_sent_edges`, `use_nodes`,
        `use_globals` (all True by default), and `received_edges_reducer`,
        `sent_edges_reducer` (both default to `reducer`).
      global_block_opt: Additional options to be passed to the GlobalBlock.
        Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all
        True by default), and `edges_reducer`, `nodes_reducer` (both default
        to `reducer`).
      name: The module name.
    """
    super(GraphNetwork, self).__init__(name=name)
    # Resolve each block's options against the defaults documented above; the
    # `_make_default_*_block_opt` helpers are defined outside this method.
    edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
    node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
    global_block_opt = _make_default_global_block_opt(global_block_opt, reducer)

    # Sub-blocks are built inside this module's variable scope.
    with self._enter_variable_scope():
      self._edge_block = blocks.EdgeBlock(
          edge_model_fn=edge_model_fn, **edge_block_opt)
      self._node_block = blocks.NodeBlock(
          node_model_fn=node_model_fn, **node_block_opt)
      self._global_block = blocks.GlobalBlock(
          global_model_fn=global_model_fn, **global_block_opt)
示例#5
0
 def __init__(self, graph):
     """Creates the edge, node and global blocks, each built from `graph`.

     Args:
       graph: Object passed as the sole positional argument to
         `blocks.EdgeBlock`, `blocks.NodeBlock` and `blocks.GlobalBlock`.
     """
     super(GraphNetwork, self).__init__()
     # NOTE(review): other constructors in this file pass `edge_model_fn=`,
     # `node_model_fn=` and `global_model_fn=` to these blocks and build them
     # inside `self._enter_variable_scope()`. Confirm that this `blocks`
     # module really takes a graph positionally and needs no variable scope.
     self._edge_block = blocks.EdgeBlock(graph)
     self._node_block = blocks.NodeBlock(graph)
     self._global_block = blocks.GlobalBlock(graph)