def _get_model(self, reducer=None, name=None):
    """Build a small `modules.CommNet` for tests.

    The edge, node-encoder and node model functions are `snt.Linear`
    factories with output sizes 15, 8 and 5 respectively.

    Args:
      reducer: Optional reducer forwarded to `modules.CommNet`. It is only
        passed when not None, leaving CommNet's own default (if any) in
        effect otherwise.
      name: Optional module name. It is only forwarded when truthy, so an
        empty string is treated the same as None.

    Returns:
      The constructed `modules.CommNet` instance.
    """
    optional = {}
    if reducer is not None:
        optional["reducer"] = reducer
    if name:
        optional["name"] = name
    return modules.CommNet(
        edge_model_fn=functools.partial(snt.Linear, output_size=15),
        node_encoder_model_fn=functools.partial(snt.Linear, output_size=8),
        node_model_fn=functools.partial(snt.Linear, output_size=5),
        **optional)
def __init__(self, name="GCrpNetworkTiny"):
    """Create the encoder, the three communication networks and the Q head.

    Args:
      name: Name passed to the parent module constructor; defaults to
        "GCrpNetworkTiny".
    """
    super(GCrpNetworkTiny, self).__init__(name=name)
    with self._enter_variable_scope():
        # NOTE(review): construction order is kept unchanged on purpose;
        # Sonnet may derive default submodule scope names from it.
        self._obsEncoder = modules.obsEncoder(encoder_fn=make_conv_model)
        # All three communication networks share the same edge model fn.
        edge_fn = make_edge_model
        self._network = modules.CommNet(
            edge_model_fn=edge_fn,
            node_model_fn=make_node_model)
        self._hnetwork = modules.HCommNet(
            edge_model_fn=edge_fn,
            node_model_fn=make_Hnode_model)
        self._Lnetwork = modules.LCommNet(
            edge_model_fn=edge_fn,
            node_model_fn=make_Lnode_model)
        self._qnet = modules.qEncoder(mlp_fn=get_q_model)
def test_higher_rank_outputs(self):
    """Tests that a CommNet can be built with higher rank inputs/outputs."""
    graph = self._get_shaped_input_graph()
    model_fns = self._get_shaped_model_fns()
    # The model functions are forwarded positionally, in the order the
    # helper returns them.
    network = modules.CommNet(*model_fns)
    self._assert_build_and_run(network, graph)