def test_incompatible_higher_rank_inputs_no_raise(self):
   """A DeepSets makes no assumption on the shape of its input edges.

   Axes 1 and 2 of the edges tensor are swapped (permutation `[0, 2, 1, 3]`)
   and the resulting graph is handed, together with the module, to
   `_assert_build_and_run`.
   """
   base_graph = self._get_shaped_input_graph()
   # Only the node and global model factories are needed here; the first
   # element returned by the helper is deliberately discarded.
   _unused_fn, node_fn, global_fn = self._get_shaped_model_fns()
   swapped_edges = tf.transpose(base_graph.edges, [0, 2, 1, 3])
   graph_with_swapped_edges = base_graph.replace(edges=swapped_edges)
   deep_sets = modules.DeepSets(node_fn, global_fn)
   self._assert_build_and_run(deep_sets, graph_with_swapped_edges)
 def test_incompatible_higher_rank_inputs_raises(self):
   """An exception should be raised if the inputs have incompatible shapes.

   Axes 1 and 2 of the nodes tensor are swapped (permutation `[0, 2, 1, 3]`);
   calling the module on that graph is expected to raise a `ValueError`
   whose message matches "in both shapes must be equal".
   """
   base_graph = self._get_shaped_input_graph()
   # Only the node and global model factories are needed here; the first
   # element returned by the helper is deliberately discarded.
   _unused_fn, node_fn, global_fn = self._get_shaped_model_fns()
   swapped_nodes = tf.transpose(base_graph.nodes, [0, 2, 1, 3])
   graph_with_swapped_nodes = base_graph.replace(nodes=swapped_nodes)
   deep_sets = modules.DeepSets(node_fn, global_fn)
   expected_message = "in both shapes must be equal"
   with self.assertRaisesRegexp(ValueError, expected_message):
     deep_sets(graph_with_swapped_nodes)
 def test_incompatible_higher_rank_partial_outputs_no_raise(self):
   """There is no constraint on the size of the partial outputs.

   The node model is a `snt.Conv2D` given an explicit stride of `[1, 2]`,
   while the global model is a `snt.Conv2D` with no explicit stride; the
   module is then handed to `_assert_build_and_run` with the shaped graph.
   """
   graph = self._get_shaped_input_graph()
   # Each factory gets its own keyword dict (and its own list objects).
   node_conv_kwargs = {
       "output_channels": 10,
       "kernel_shape": [3, 3],
       "stride": [1, 2],
   }
   global_conv_kwargs = {"output_channels": 10, "kernel_shape": [3, 3]}
   make_node_model = functools.partial(snt.Conv2D, **node_conv_kwargs)
   make_global_model = functools.partial(snt.Conv2D, **global_conv_kwargs)
   deep_sets = modules.DeepSets(make_node_model, make_global_model)
   self._assert_build_and_run(deep_sets, graph)
 def _get_model(self, reducer=None, name=None):
   """Builds a `modules.DeepSets` with linear node and global models.

   The node model factory is `snt.Linear` with `output_size=5` and the
   global model factory is `snt.Linear` with `output_size=15`.

   Args:
     reducer: Optional value forwarded as the `reducer` keyword argument.
       It is only forwarded when truthy.
     name: Optional value forwarded as the `name` keyword argument. It is
       only forwarded when truthy.

   Returns:
     The constructed `modules.DeepSets` instance.
   """
   node_model_fn = functools.partial(snt.Linear, output_size=5)
   global_model_fn = functools.partial(snt.Linear, output_size=15)
   # Falsy optional arguments are omitted so that the constructor's own
   # defaults apply.
   optional_kwargs = {
       key: value
       for key, value in (("reducer", reducer), ("name", name))
       if value
   }
   return modules.DeepSets(
       node_model_fn=node_model_fn,
       global_model_fn=global_model_fn,
       **optional_kwargs)