def test_propogates_dependence_up_through_lambda(self):
     """Checks a Lambda whose result is the intrinsic is among the extracted nodes."""
     intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
     wrapping_fn = building_blocks.Lambda('x', tf.int32, intrinsic)

     consumers = tree_analysis.extract_nodes_consuming(
         wrapping_fn, dummy_intrinsic_predicate)

     self.assertIn(wrapping_fn, consumers)
 def test_propogates_dependence_up_through_selection(self):
     """Checks a Selection taken from the intrinsic is among the extracted nodes."""
     source = building_blocks.Intrinsic('dummy_intrinsic', [tf.int32])
     first_element = building_blocks.Selection(source, index=0)

     consumers = tree_analysis.extract_nodes_consuming(
         first_element, dummy_intrinsic_predicate)

     self.assertIn(first_element, consumers)
Beispiel #3
0
 def test_propogates_dependence_up_through_block_result(self):
   """Checks a Block whose result is the intrinsic is among the extracted nodes."""
   result = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
   # The single local binds an unrelated reference; only the result is the
   # intrinsic.
   local_value = building_blocks.Reference('int', tf.int32)
   enclosing = building_blocks.Block([('x', local_value)], result)

   consumers = tree_analysis.extract_nodes_consuming(
       enclosing, dummy_intrinsic_predicate)

   self.assertIn(enclosing, consumers)
Beispiel #4
0
 def test_propogates_dependence_up_through_tuple(self):
   """Checks a Tuple containing the intrinsic is among the extracted nodes."""
   intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
   plain_ref = building_blocks.Reference('int', tf.int32)
   # The intrinsic is the second element, after an unrelated reference.
   elements = [plain_ref, intrinsic]
   container = building_blocks.Tuple(elements)

   consumers = tree_analysis.extract_nodes_consuming(
       container, dummy_intrinsic_predicate)

   self.assertIn(container, consumers)
Beispiel #5
0
 def test_propogates_dependence_up_through_lambda(self):
     """Checks a Lambda whose result is the intrinsic is among the extracted nodes."""
     int_type = computation_types.TensorType(tf.int32)
     intrinsic = building_blocks.Intrinsic('whimsy_intrinsic', int_type)
     wrapping_fn = building_blocks.Lambda('x', tf.int32, intrinsic)

     consumers = tree_analysis.extract_nodes_consuming(
         wrapping_fn, whimsy_intrinsic_predicate)

     self.assertIn(wrapping_fn, consumers)
Beispiel #6
0
 def test_propogates_dependence_up_through_selection(self):
     """Checks a Selection taken from the intrinsic is among the extracted nodes."""
     struct_type = computation_types.StructType([tf.int32])
     source = building_blocks.Intrinsic('whimsy_intrinsic', struct_type)
     first_element = building_blocks.Selection(source, index=0)

     consumers = tree_analysis.extract_nodes_consuming(
         first_element, whimsy_intrinsic_predicate)

     self.assertIn(first_element, consumers)
Beispiel #7
0
 def test_propogates_dependence_up_through_call(self):
   """Checks a Call whose argument is the intrinsic is among the extracted nodes."""
   argument = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
   # Lambda that returns its own parameter 'x'.
   param_ref = building_blocks.Reference('x', tf.int32)
   identity_fn = building_blocks.Lambda('x', tf.int32, param_ref)
   invocation = building_blocks.Call(identity_fn, argument)

   consumers = tree_analysis.extract_nodes_consuming(
       invocation, dummy_intrinsic_predicate)

   self.assertIn(invocation, consumers)
Beispiel #8
0
 def test_propogates_dependence_up_through_tuple(self):
     """Checks a Struct containing the intrinsic is among the extracted nodes."""
     int_type = computation_types.TensorType(tf.int32)
     intrinsic = building_blocks.Intrinsic('whimsy_intrinsic', int_type)
     plain_ref = building_blocks.Reference('int', tf.int32)
     # The intrinsic is the second element, after an unrelated reference.
     elements = [plain_ref, intrinsic]
     container = building_blocks.Struct(elements)

     consumers = tree_analysis.extract_nodes_consuming(
         container, whimsy_intrinsic_predicate)

     self.assertIn(container, consumers)
 def test_propogates_dependence_up_through_block_locals(self):
   """Checks a Block that binds the intrinsic as a local is among the extracted nodes."""
   int_type = computation_types.TensorType(tf.int32)
   bound_value = building_blocks.Intrinsic('dummy_intrinsic', int_type)
   # The block's result is an unrelated reference; the intrinsic appears only
   # as the value bound to local 'x'.
   result = building_blocks.Reference('int', tf.int32)
   enclosing = building_blocks.Block([('x', bound_value)], result)

   consumers = tree_analysis.extract_nodes_consuming(
       enclosing, dummy_intrinsic_predicate)

   self.assertIn(enclosing, consumers)
Beispiel #10
0
 def test_propogates_dependence_up_through_call(self):
     """Checks a Call whose argument is the intrinsic is among the extracted nodes."""
     int_type = computation_types.TensorType(tf.int32)
     argument = building_blocks.Intrinsic('whimsy_intrinsic', int_type)
     # Lambda that returns its own parameter 'x'.
     param_ref = building_blocks.Reference('x', tf.int32)
     identity_fn = building_blocks.Lambda('x', tf.int32, param_ref)
     invocation = building_blocks.Call(identity_fn, argument)

     consumers = tree_analysis.extract_nodes_consuming(
         invocation, whimsy_intrinsic_predicate)

     self.assertIn(invocation, consumers)
  def test_propogates_dependence_into_binding_to_reference(self):
    """Checks a Reference to a local bound to the intrinsic is among the extracted nodes."""
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    result_ref = building_blocks.Reference('x', clients_int)
    zero = building_blocks.Intrinsic(intrinsic_defs.GENERIC_ZERO.uri,
                                     clients_int)

    def federated_zero_predicate(x):
      # Matches intrinsic nodes whose URI equals GENERIC_ZERO's URI.
      return x.is_intrinsic() and x.uri == intrinsic_defs.GENERIC_ZERO.uri

    # Local 'x' is bound to the intrinsic and the block's result refers to 'x'.
    bindings = [('x', zero)]
    enclosing = building_blocks.Block(bindings, result_ref)
    consumers = tree_analysis.extract_nodes_consuming(
        enclosing, federated_zero_predicate)

    self.assertIn(result_ref, consumers)
Beispiel #12
0
 def test_adds_no_nodes_to_set_with_constant_false_predicate(self):
     """Checks that a predicate returning False for every node yields an empty result."""

     def reject_all(_):
         return False

     tree = building_block_test_utils.create_nested_syntax_tree()
     self.assertEmpty(tree_analysis.extract_nodes_consuming(tree, reject_all))
Beispiel #13
0
 def test_adds_all_nodes_to_set_with_constant_true_predicate(self):
     """Checks that a predicate returning True for every node yields count(tree) nodes."""

     def accept_all(_):
         return True

     tree = building_block_test_utils.create_nested_syntax_tree()
     matched = tree_analysis.extract_nodes_consuming(tree, accept_all)
     # The result size is compared against the tree's count from
     # tree_analysis.count.
     self.assertLen(matched, tree_analysis.count(tree))
Beispiel #14
0
 def test_raises_on_none_predicate(self):
     """Checks that passing None as the predicate raises TypeError."""
     empty_struct = computation_types.StructType([])
     comp = building_blocks.Data('whimsy', empty_struct)
     self.assertRaises(
         TypeError, tree_analysis.extract_nodes_consuming, comp, None)
Beispiel #15
0
 def test_raises_on_none_comp(self):
     """Checks that passing None as the computation raises TypeError."""

     def accept_all(_):
         return True

     self.assertRaises(
         TypeError, tree_analysis.extract_nodes_consuming, None, accept_all)
 def test_raises_on_none_predicate(self):
     """Checks that passing None as the predicate raises TypeError."""
     comp = building_blocks.Data('dummy', [])
     self.assertRaises(
         TypeError, tree_analysis.extract_nodes_consuming, comp, None)