Ejemplo n.º 1
0
 def _get_model(self, reducer=None, name=None):
   """Builds a `modules.CommNet` whose sub-models are `snt.Linear` layers.

   The edge model outputs 15 features, the node encoder 8 and the node
   model 5.

   Args:
     reducer: Optional reducer; forwarded to `modules.CommNet` only when it
       is not None.
     name: Optional module name; forwarded only when truthy (so an empty
       string is treated the same as None).

   Returns:
     The constructed `modules.CommNet`.
   """
   def _linear(size):
     # Deferred constructor: CommNet receives a callable, not a layer.
     return functools.partial(snt.Linear, output_size=size)

   model_kwargs = dict(
       edge_model_fn=_linear(15),
       node_encoder_model_fn=_linear(8),
       node_model_fn=_linear(5),
   )
   if reducer is not None:
     model_kwargs.update(reducer=reducer)
   if name:
     model_kwargs.update(name=name)
   return modules.CommNet(**model_kwargs)
Ejemplo n.º 2
0
    def __init__(self, name="GCrpNetworkTiny"):
        """Constructs the network and all of its sub-modules.

        Args:
          name: Module name passed to the parent constructor; defaults to
            "GCrpNetworkTiny".
        """
        super(GCrpNetworkTiny, self).__init__(name=name)
        edge_fn = make_edge_model  # shared by all three communication nets
        # Sub-modules are built inside self._enter_variable_scope(). The
        # construction order below is deliberately fixed: it may determine
        # auto-generated module/variable names (TODO confirm).
        with self._enter_variable_scope():
            self._obsEncoder = modules.obsEncoder(encoder_fn=make_conv_model)
            self._network = modules.CommNet(
                node_model_fn=make_node_model, edge_model_fn=edge_fn)
            self._hnetwork = modules.HCommNet(
                node_model_fn=make_Hnode_model, edge_model_fn=edge_fn)
            self._Lnetwork = modules.LCommNet(
                node_model_fn=make_Lnode_model, edge_model_fn=edge_fn)
            self._qnet = modules.qEncoder(mlp_fn=get_q_model)
Ejemplo n.º 3
0
 def test_higher_rank_outputs(self):
   """Checks that a CommNet can be built with higher-rank inputs/outputs."""
   shaped_graph = self._get_shaped_input_graph()
   # The model functions are passed positionally, so their order must match
   # CommNet's leading constructor parameters.
   model_fns = self._get_shaped_model_fns()
   comm_net = modules.CommNet(*model_fns)
   self._assert_build_and_run(comm_net, shaped_graph)