def test_incompatible_higher_rank_inputs_no_raise(self):
  """DeepSets makes no assumption about the shape of its input edges.

  The edges of the shaped input graph are transposed with permutation
  [0, 2, 1, 3] (axes 1 and 2 swapped) while every other field is left
  untouched; the module is then expected to build and run normally.
  """
  shaped_graph = self._get_shaped_input_graph()
  # Only the second and third model functions are needed by DeepSets here.
  _unused_fn, node_fn, global_fn = self._get_shaped_model_fns()
  swapped_edges = tf.transpose(shaped_graph.edges, [0, 2, 1, 3])
  graph_with_swapped_edges = shaped_graph.replace(edges=swapped_edges)
  deep_sets = modules.DeepSets(node_fn, global_fn)
  self._assert_build_and_run(deep_sets, graph_with_swapped_edges)
def test_incompatible_higher_rank_inputs_raises(self):
  """An exception should be raised if the inputs have incompatible shapes.

  The nodes of the shaped input graph are transposed with permutation
  [0, 2, 1, 3] (axes 1 and 2 swapped); connecting DeepSets to the resulting
  graph is expected to raise a ValueError mentioning mismatched shapes.
  """
  shaped_graph = self._get_shaped_input_graph()
  # Only the second and third model functions are needed by DeepSets here.
  _unused_fn, node_fn, global_fn = self._get_shaped_model_fns()
  swapped_nodes = tf.transpose(shaped_graph.nodes, [0, 2, 1, 3])
  graph_with_swapped_nodes = shaped_graph.replace(nodes=swapped_nodes)
  deep_sets = modules.DeepSets(node_fn, global_fn)
  expected_message = "in both shapes must be equal"
  # NOTE(review): `assertRaisesRegexp` is the legacy alias of
  # `assertRaisesRegex` (removed from unittest in Python 3.12); it is kept
  # unchanged here — confirm the supported Python versions before renaming.
  with self.assertRaisesRegexp(ValueError, expected_message):
    deep_sets(graph_with_swapped_nodes)
def test_incompatible_higher_rank_partial_outputs_no_raise(self):
  """The sizes of the partial outputs are not constrained.

  The node model is a 3x3 convolution with stride [1, 2], whereas the global
  model is a 3x3 convolution that leaves the stride at its default; DeepSets
  is expected to build and run with this pair of models.
  """
  # Building the partials has no side effects: the Conv2D modules are only
  # instantiated later, when DeepSets invokes these factories.
  strided_node_fn = functools.partial(
      snt.Conv2D, output_channels=10, kernel_shape=[3, 3], stride=[1, 2])
  default_stride_global_fn = functools.partial(
      snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
  shaped_graph = self._get_shaped_input_graph()
  deep_sets = modules.DeepSets(strided_node_fn, default_stride_global_fn)
  self._assert_build_and_run(deep_sets, shaped_graph)
def _get_model(self, reducer=None, name=None):
  """Builds the DeepSets module used by the tests.

  The node model is a `snt.Linear` factory with `output_size=5` and the
  global model is a `snt.Linear` factory with `output_size=15`.

  Args:
    reducer: Optional value forwarded as the `reducer` keyword argument of
      `modules.DeepSets`. It is omitted from the call when falsy.
    name: Optional value forwarded as the `name` keyword argument of
      `modules.DeepSets`. It is omitted from the call when falsy.

  Returns:
    The `modules.DeepSets` instance constructed from these arguments.
  """
  deep_sets_kwargs = dict(
      node_model_fn=functools.partial(snt.Linear, output_size=5),
      global_model_fn=functools.partial(snt.Linear, output_size=15))
  # Forward optional arguments only when they are truthy, so that the
  # constructor's own defaults apply otherwise.
  for keyword, value in (("reducer", reducer), ("name", name)):
    if value:
      deep_sets_kwargs[keyword] = value
  return modules.DeepSets(**deep_sets_kwargs)