def test_incompatible_higher_rank_inputs_no_raise(self, field_to_reshape):
   """Checks that RelationNetwork tolerates a field with permuted axes.

   The field named by `field_to_reshape` is transposed with permutation
   [0, 2, 1, 3] (so that field must be rank 4). The test passes if the
   network still builds and runs, since a RelationNetwork does not make
   assumptions on its input shapes.
   """
   def _swap_axes_1_and_2(value):
     return tf.transpose(value, [0, 2, 1, 3])

   graph = self._get_shaped_input_graph()
   edge_model_fn, _, global_model_fn = self._get_shaped_model_fns()
   reshaped_graph = graph.map(_swap_axes_1_and_2, [field_to_reshape])
   relation_network = modules.RelationNetwork(edge_model_fn, global_model_fn)
   self._assert_build_and_run(relation_network, reshaped_graph)
Exemple #2
0
 def __init__(self,
              latent_size=16,
              num_layers=2,
              name="MLPRelationNetwork"):
     """Wraps a RelationNetwork whose edge and global models are make_mlp_model.

     Args:
       latent_size: Accepted by the constructor but not used by this body.
         NOTE(review): make_mlp_model is passed through unchanged, so this
         value has no effect here. Confirm whether it should be forwarded.
       num_layers: Also accepted but unused here (same caveat as
         latent_size).
       name: Name of this module.
     """
     super(MLPRelationNetwork, self).__init__(name=name)
     mlp_factory = make_mlp_model
     # Create the inner network under this module's variable scope.
     with self._enter_variable_scope():
         self._network = modules.RelationNetwork(
             global_model_fn=mlp_factory,
             edge_model_fn=mlp_factory)
 def _get_model(self, reducer=tf.unsorted_segment_sum, name=None):
   """Returns a RelationNetwork with small linear edge and global models.

   The edge model is `snt.Linear(output_size=5)` and the global model is
   `snt.Linear(output_size=15)`.

   Args:
     reducer: Forwarded as `reducer` only when truthy. Otherwise it is
       omitted, so RelationNetwork's own default (if any) applies.
     name: Forwarded as `name` only when truthy. Otherwise it is omitted.
   """
   model_kwargs = dict(
       edge_model_fn=functools.partial(snt.Linear, output_size=5),
       global_model_fn=functools.partial(snt.Linear, output_size=15),
   )
   optional_kwargs = (("reducer", reducer), ("name", name))
   model_kwargs.update((key, value) for key, value in optional_kwargs if value)
   return modules.RelationNetwork(**model_kwargs)
Exemple #4
0
    def __init__(self,
                 edge_output_size=None,
                 node_output_size=None,
                 global_output_size=None,
                 network="GraphIndependent",
                 name="EncodeProcessDecode"):
        """Builds the encoder, core, decoder and output-transform submodules.

        Args:
          edge_output_size: Output size of the edge Linear layer. Pass None
            to give the output transform no edge function.
          node_output_size: Output size of the node Linear layer. Pass None
            to give no node function. Only used when
            `network == "GraphIndependent"`.
          global_output_size: Output size of the global Linear layer. Pass
            None to give no global function.
          network: Submodule family, either "GraphIndependent" or
            "RelationNetwork".
          name: Name of this module.

        Raises:
          ValueError: If `network` is not one of the supported values.
        """
        # Validate before building anything. An unknown value previously left
        # _encoder/_core/_decoder/_output_transform unset, so the error only
        # appeared later (presumably as an AttributeError) far from its cause.
        supported_networks = ("GraphIndependent", "RelationNetwork")
        if network not in supported_networks:
            raise ValueError(
                "Unsupported network {!r}; expected one of {}.".format(
                    network, supported_networks))

        super(EncodeProcessDecode, self).__init__(name=name)

        # Creation order (encoder, core, decoder) is kept as in the original.
        if network == "GraphIndependent":
            self._encoder = MLPGraphIndependent()
            self._core = MLPGraphNetwork()
            self._decoder = MLPGraphIndependent()
        else:  # "RelationNetwork"
            self._encoder = MLPRelationNetwork()
            self._core = MLPGraphNetwork()
            self._decoder = MLPRelationNetwork()

        def _linear_fn(output_size, layer_name):
            """Returns a Linear-layer factory, or None when output_size is None."""
            if output_size is None:
                return None
            return lambda: snt.Linear(output_size, name=layer_name)

        # Transforms the outputs into the appropriate shapes.
        edge_fn = _linear_fn(edge_output_size, "edge_output")
        node_fn = _linear_fn(node_output_size, "node_output")
        global_fn = _linear_fn(global_output_size, "global_output")

        with self._enter_variable_scope():
            if network == "GraphIndependent":
                self._output_transform = modules.GraphIndependent(
                    edge_fn, node_fn, global_fn)
            else:
                # node_fn is not used on this branch, so node_output_size does
                # not affect the output transform here.
                # NOTE(review): edge_fn/global_fn may be None here. Confirm
                # RelationNetwork accepts None factories.
                self._output_transform = modules.RelationNetwork(
                    edge_fn, global_fn)