def __init__(self,
             edge_model_fn,
             node_model_fn,
             reducer=tf.unsorted_segment_sum,
             name="interaction_network"):
    """Builds the edge block and node block of an InteractionNetwork.

    Args:
      edge_model_fn: Callable forwarded to `blocks.EdgeBlock` for the
        per-edge computation. It must return a Sonnet module (or
        equivalent; see `blocks.EdgeBlock` for details). The output shape
        of that module must match the shape of the input nodes, except for
        the first and last axis.
      node_model_fn: Callable forwarded to `blocks.NodeBlock` for the
        per-node computation. It must return a Sonnet module (or
        equivalent; see `blocks.NodeBlock` for details).
      reducer: Reducer the NodeBlock uses to aggregate edges. Defaults to
        tf.unsorted_segment_sum.
      name: The module name.
    """
    super(InteractionNetwork, self).__init__(name=name)

    # Only `use_globals` is overridden for the edge block; every other
    # EdgeBlock option is left at its default.
    edge_block_kwargs = dict(
        edge_model_fn=edge_model_fn,
        use_globals=False,
    )
    # Sent edges and globals are switched off for the node block; `reducer`
    # is applied to the received edges.
    node_block_kwargs = dict(
        node_model_fn=node_model_fn,
        use_sent_edges=False,
        use_globals=False,
        received_edges_reducer=reducer,
    )

    # Sub-blocks are instantiated while this module's variable scope is
    # entered.
    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
        self._node_block = blocks.NodeBlock(**node_block_kwargs)
def __init__(self,
             edge_model_fn,
             global_model_fn,
             reducer=tf.unsorted_segment_sum,
             name="relation_network"):
    """Builds the edge block and global block of a RelationNetwork.

    Args:
      edge_model_fn: Callable forwarded to EdgeBlock for the per-edge
        computation. It must return a Sonnet module (or equivalent; see
        EdgeBlock for details).
      global_model_fn: Callable forwarded to GlobalBlock for the per-global
        computation. It must return a Sonnet module (or equivalent; see
        GlobalBlock for details).
      reducer: Reducer the GlobalBlock uses to aggregate edges. Defaults to
        tf.unsorted_segment_sum.
      name: The module name.
    """
    super(RelationNetwork, self).__init__(name=name)

    # The edge model is fed receiver and sender nodes only (no edge
    # features, no globals).
    edge_block_kwargs = dict(
        edge_model_fn=edge_model_fn,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
    )
    # The global model is fed the reduced edges only (no nodes, no previous
    # globals).
    global_block_kwargs = dict(
        global_model_fn=global_model_fn,
        use_edges=True,
        use_nodes=False,
        use_globals=False,
        edges_reducer=reducer,
    )

    # Sub-blocks are instantiated while this module's variable scope is
    # entered.
    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
        self._global_block = blocks.GlobalBlock(**global_block_kwargs)
def __init__(self,
             edge_model_fn,
             node_encoder_model_fn,
             node_model_fn,
             reducer=tf.unsorted_segment_sum,
             name="comm_net"):
    """Builds the edge, node-encoder and node blocks of a CommNet.

    Args:
      edge_model_fn: Callable forwarded to EdgeBlock. It must return a
        Sonnet module (or equivalent; see EdgeBlock for details).
      node_encoder_model_fn: Callable forwarded to the NodeBlock that
        performs the first encoding of the nodes. It must return a Sonnet
        module (or equivalent; see NodeBlock for details). The output shape
        of that module should match the output shape of the module built by
        `edge_model_fn`, except for the first and last dimension.
      node_model_fn: Callable forwarded to NodeBlock. It must return a
        Sonnet module (or equivalent; see NodeBlock for details).
      reducer: Reduction applied when aggregating the edges in the nodes.
        This should be a callable whose signature matches
        tf.unsorted_segment_sum.
      name: The module name.
    """
    super(CommNet, self).__init__(name=name)

    # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122: the edge model
    # is fed sender nodes only.
    edge_block_kwargs = dict(
        edge_model_fn=edge_model_fn,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False,
    )
    # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122: the encoder is fed
    # nodes only. The reducer is still passed through, as in the other
    # node block.
    node_encoder_block_kwargs = dict(
        node_model_fn=node_encoder_model_fn,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer,
        name="node_encoder_block",
    )
    # Computes $\Theta(..)$ in Eq. (2) of 1706.06122: the node model is fed
    # nodes plus reduced received edges.
    node_block_kwargs = dict(
        node_model_fn=node_model_fn,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer,
    )

    # Sub-blocks are instantiated, in the same order as before (edge, node
    # encoder, node), while this module's variable scope is entered.
    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
        self._node_encoder_block = blocks.NodeBlock(
            **node_encoder_block_kwargs)
        self._node_block = blocks.NodeBlock(**node_block_kwargs)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             global_model_fn,
             reducer=tf.unsorted_segment_sum,
             edge_block_opt=None,
             node_block_opt=None,
             global_block_opt=None,
             name="graph_network"):
    """Builds the edge, node and global blocks of a GraphNetwork.

    Args:
      edge_model_fn: Callable forwarded to EdgeBlock for the per-edge
        computation. It must return a Sonnet module (or equivalent; see
        EdgeBlock for details).
      node_model_fn: Callable forwarded to NodeBlock for the per-node
        computation. It must return a Sonnet module (or equivalent; see
        NodeBlock for details).
      global_model_fn: Callable forwarded to GlobalBlock for the per-global
        computation. It must return a Sonnet module (or equivalent; see
        GlobalBlock for details).
      reducer: Reducer used by NodeBlock and GlobalBlock to aggregate nodes
        and edges. Defaults to tf.unsorted_segment_sum. Reducers given in
        `node_block_opt` and `global_block_opt`, if any, take precedence
        over this value.
      edge_block_opt: Extra options for the EdgeBlock. May contain the keys
        `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
        `use_globals`. By default, these are all True.
      node_block_opt: Extra options for the NodeBlock. May contain the keys
        `use_received_edges`, `use_sent_edges`, `use_nodes`, `use_globals`
        (all True by default), and `received_edges_reducer`,
        `sent_edges_reducer` (default to `reducer`).
      global_block_opt: Extra options for the GlobalBlock. May contain the
        keys `use_edges`, `use_nodes`, `use_globals` (all True by default),
        and `edges_reducer`, `nodes_reducer` (default to `reducer`).
      name: The module name.
    """
    super(GraphNetwork, self).__init__(name=name)

    # Resolve the per-block option dicts up front; the helpers receive the
    # caller's options (possibly None) and, for node/global, the reducer.
    edge_kwargs = _make_default_edge_block_opt(edge_block_opt)
    node_kwargs = _make_default_node_block_opt(node_block_opt, reducer)
    global_kwargs = _make_default_global_block_opt(global_block_opt, reducer)

    # Sub-blocks are instantiated while this module's variable scope is
    # entered.
    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn, **edge_kwargs)
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn, **node_kwargs)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_kwargs)
def __init__(self, graph):
    """Builds the edge, node and global blocks from a single `graph` argument.

    NOTE(review): unlike the other constructors visible in this file, this
    one takes no `name`, passes `graph` positionally to every block, and
    creates the blocks outside `self._enter_variable_scope()`. Confirm that
    this is intended, and what `graph` is expected to be.

    Args:
      graph: Value passed unchanged as the first positional argument to
        `blocks.EdgeBlock`, `blocks.NodeBlock` and `blocks.GlobalBlock`.
    """
    super(GraphNetwork, self).__init__()

    # Blocks are built in edge -> node -> global order.
    self._edge_block = blocks.EdgeBlock(graph)
    self._node_block = blocks.NodeBlock(graph)
    self._global_block = blocks.GlobalBlock(graph)