def test_propogates_dependence_up_through_lambda(self):
  """Checks that a Lambda whose result is a 'dummy_intrinsic' is reported."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
  wrapping_lambda = computation_building_blocks.Lambda('x', tf.int32,
                                                       intrinsic)
  consumers = tree_analysis.extract_nodes_consuming(wrapping_lambda,
                                                    dummy_intrinsic_predicate)
  # The lambda itself must appear: its result is the node the predicate targets.
  self.assertIn(wrapping_lambda, consumers)
 def test_propogates_dependence_up_through_selection(self):
   """Checks that selecting from a tuple-typed 'dummy_intrinsic' is reported."""
   tuple_typed_intrinsic = computation_building_blocks.Intrinsic(
       'dummy_intrinsic', [tf.int32])
   first_element = computation_building_blocks.Selection(
       tuple_typed_intrinsic, index=0)
   consumers = tree_analysis.extract_nodes_consuming(
       first_element, dummy_intrinsic_predicate)
   # The Selection's source is the intrinsic, so the Selection consumes it.
   self.assertIn(first_element, consumers)
 def test_propogates_dependence_up_through_tuple(self):
   """Checks that a Tuple containing a 'dummy_intrinsic' element is reported."""
   intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                     tf.int32)
   int_ref = computation_building_blocks.Reference('int', tf.int32)
   # Only the second element is the intrinsic; one such element suffices.
   elements = [int_ref, intrinsic]
   tup = computation_building_blocks.Tuple(elements)
   consumers = tree_analysis.extract_nodes_consuming(tup,
                                                     dummy_intrinsic_predicate)
   self.assertIn(tup, consumers)
 def test_propogates_dependence_up_through_block_locals(self):
   """Checks that a Block binding a 'dummy_intrinsic' local is reported.

   The Block's result does not use the local; the binding alone is expected
   to make the Block a consumer.
   """
   bound_value = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                       tf.int32)
   block_result = computation_building_blocks.Reference('int', tf.int32)
   local_bindings = [('x', bound_value)]
   block = computation_building_blocks.Block(local_bindings, block_result)
   consumers = tree_analysis.extract_nodes_consuming(block,
                                                     dummy_intrinsic_predicate)
   self.assertIn(block, consumers)
# Example #5 (0)
 def test_propogates_dependence_up_through_call(self):
   """Checks that calling an identity Lambda on a 'dummy_intrinsic' is reported."""
   argument = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
   identity_fn = computation_building_blocks.Lambda(
       'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
   # The Call's argument is the intrinsic, so the Call consumes it.
   call = computation_building_blocks.Call(identity_fn, argument)
   consumers = tree_analysis.extract_nodes_consuming(call,
                                                     dummy_intrinsic_predicate)
   self.assertIn(call, consumers)
  def test_propogates_dependence_into_binding_to_reference(self):
    """Checks that a Reference to a matching local binding is reported.

    The Block binds `x` to a GENERIC_ZERO intrinsic and returns a Reference to
    `x`. That Reference must appear among the consuming nodes.
    """
    federated_int = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)
    zero_uri = intrinsic_defs.GENERIC_ZERO.uri

    def is_federated_zero(comp):
      # Match only Intrinsic nodes whose URI is GENERIC_ZERO's.
      if not isinstance(comp, computation_building_blocks.Intrinsic):
        return False
      return comp.uri == zero_uri

    ref_to_x = computation_building_blocks.Reference('x', federated_int)
    federated_zero = computation_building_blocks.Intrinsic(zero_uri,
                                                           federated_int)
    block = computation_building_blocks.Block([('x', federated_zero)],
                                              ref_to_x)
    consumers = tree_analysis.extract_nodes_consuming(block, is_federated_zero)
    self.assertIn(ref_to_x, consumers)
# Example #7 (0)
 def test_adds_no_nodes_to_set_with_constant_false_predicate(self):
   """Checks that a predicate matching nothing yields an empty result."""

   def never_matches(unused_comp):
     return False

   tree = computation_test_utils.create_nested_syntax_tree()
   consumers = tree_analysis.extract_nodes_consuming(tree, never_matches)
   self.assertEmpty(consumers)
# Example #8 (0)
 def test_adds_all_nodes_to_set_with_constant_true_predicate(self):
   """Checks that a predicate matching everything yields count() nodes."""

   def always_matches(unused_comp):
     return True

   tree = computation_test_utils.create_nested_syntax_tree()
   consumers = tree_analysis.extract_nodes_consuming(tree, always_matches)
   expected_size = tree_analysis.count(tree)
   self.assertLen(consumers, expected_size)
# Example #9 (0)
 def test_raises_on_none_predicate(self):
   """Checks that passing None as the predicate raises TypeError."""
   comp = computation_building_blocks.Data('dummy', [])
   with self.assertRaises(TypeError):
     tree_analysis.extract_nodes_consuming(comp, None)
# Example #10 (0)
 def test_raises_on_none_comp(self):
   """Checks that passing None as the tree raises TypeError."""

   def accept_everything(unused_comp):
     return True

   with self.assertRaises(TypeError):
     tree_analysis.extract_nodes_consuming(None, accept_everything)