def test_propogates_dependence_up_through_lambda(self):
  """Checks that a Lambda whose result is the matched intrinsic is a consumer.

  NOTE(review): a method with this exact name appears again later in this
  class. Python keeps only the last definition in a class body, so this one
  never runs. Consider deleting the duplicate or giving it a distinct name.
  """
  # Wrap the dtype in an explicit computation type, as the updated tests in
  # this file do. This is presumably required by the current Intrinsic
  # constructor.
  type_signature = computation_types.TensorType(tf.int32)
  dummy_intrinsic = building_blocks.Intrinsic('dummy_intrinsic', type_signature)
  lam = building_blocks.Lambda('x', tf.int32, dummy_intrinsic)
  dependent_nodes = tree_analysis.extract_nodes_consuming(
      lam, dummy_intrinsic_predicate)
  self.assertIn(lam, dependent_nodes)
def test_propogates_dependence_up_through_selection(self):
  """Checks that selecting from the matched intrinsic is reported as a consumer.

  NOTE(review): a method with this exact name appears again later in this
  class. Python keeps only the last definition in a class body, so this one
  never runs. Consider deleting the duplicate or giving it a distinct name.
  """
  # Use an explicit struct type instead of a raw Python list, matching the
  # updated selection test in this file.
  type_signature = computation_types.StructType([tf.int32])
  dummy_intrinsic = building_blocks.Intrinsic('dummy_intrinsic', type_signature)
  selection = building_blocks.Selection(dummy_intrinsic, index=0)
  dependent_nodes = tree_analysis.extract_nodes_consuming(
      selection, dummy_intrinsic_predicate)
  self.assertIn(selection, dependent_nodes)
def test_propogates_dependence_up_through_block_result(self):
  """Checks that a Block whose result is the matched intrinsic is a consumer.

  The Block's single local binds an unrelated Reference. The dependence must
  therefore come from the Block's result expression.
  """
  # Wrap the dtype in an explicit computation type, consistent with the
  # companion block-locals test in this file.
  type_signature = computation_types.TensorType(tf.int32)
  dummy_intrinsic = building_blocks.Intrinsic('dummy_intrinsic', type_signature)
  integer_reference = building_blocks.Reference('int', tf.int32)
  block = building_blocks.Block([('x', integer_reference)], dummy_intrinsic)
  dependent_nodes = tree_analysis.extract_nodes_consuming(
      block, dummy_intrinsic_predicate)
  self.assertIn(block, dependent_nodes)
def test_propogates_dependence_up_through_tuple(self):
  """Checks that a Struct containing the matched intrinsic is a consumer.

  NOTE(review): a method with this exact name appears again later in this
  class. Python keeps only the last definition in a class body, so this one
  never runs. Consider deleting the duplicate or giving it a distinct name.
  """
  type_signature = computation_types.TensorType(tf.int32)
  dummy_intrinsic = building_blocks.Intrinsic('dummy_intrinsic', type_signature)
  integer_reference = building_blocks.Reference('int', tf.int32)
  # `building_blocks.Struct` is the tuple-like building block used by the
  # updated tuple test in this file; the older `Tuple` name is not used there.
  tup = building_blocks.Struct([integer_reference, dummy_intrinsic])
  dependent_nodes = tree_analysis.extract_nodes_consuming(
      tup, dummy_intrinsic_predicate)
  self.assertIn(tup, dependent_nodes)
def test_propogates_dependence_up_through_lambda(self):
  """A Lambda returning the matched intrinsic must itself be a consumer."""
  intrinsic = building_blocks.Intrinsic(
      'whimsy_intrinsic', computation_types.TensorType(tf.int32))
  wrapping_lambda = building_blocks.Lambda('x', tf.int32, intrinsic)
  # `whimsy_intrinsic_predicate` is defined elsewhere in this file; presumably
  # it matches the 'whimsy_intrinsic' URI used above.
  consumers = tree_analysis.extract_nodes_consuming(
      wrapping_lambda, whimsy_intrinsic_predicate)
  self.assertIn(wrapping_lambda, consumers)
def test_propogates_dependence_up_through_selection(self):
  """Selecting element 0 of the matched intrinsic yields a consumer node."""
  struct_intrinsic = building_blocks.Intrinsic(
      'whimsy_intrinsic', computation_types.StructType([tf.int32]))
  first_element = building_blocks.Selection(struct_intrinsic, index=0)
  consumers = tree_analysis.extract_nodes_consuming(
      first_element, whimsy_intrinsic_predicate)
  self.assertIn(first_element, consumers)
def test_propogates_dependence_up_through_call(self):
  """Checks that calling an identity Lambda on the matched intrinsic is a consumer.

  NOTE(review): a method with this exact name appears again later in this
  class. Python keeps only the last definition in a class body, so this one
  never runs. Consider deleting the duplicate or giving it a distinct name.
  """
  # Wrap the dtype in an explicit computation type, as the updated tests in
  # this file do.
  type_signature = computation_types.TensorType(tf.int32)
  dummy_intrinsic = building_blocks.Intrinsic('dummy_intrinsic', type_signature)
  ref_to_x = building_blocks.Reference('x', tf.int32)
  identity_lambda = building_blocks.Lambda('x', tf.int32, ref_to_x)
  called_lambda = building_blocks.Call(identity_lambda, dummy_intrinsic)
  dependent_nodes = tree_analysis.extract_nodes_consuming(
      called_lambda, dummy_intrinsic_predicate)
  self.assertIn(called_lambda, dependent_nodes)
def test_propogates_dependence_up_through_tuple(self):
  """A Struct holding the matched intrinsic as an element is a consumer."""
  intrinsic = building_blocks.Intrinsic(
      'whimsy_intrinsic', computation_types.TensorType(tf.int32))
  plain_ref = building_blocks.Reference('int', tf.int32)
  # The intrinsic is the second element; the first is an unrelated reference.
  struct = building_blocks.Struct([plain_ref, intrinsic])
  consumers = tree_analysis.extract_nodes_consuming(
      struct, whimsy_intrinsic_predicate)
  self.assertIn(struct, consumers)
def test_propogates_dependence_up_through_block_locals(self):
  """A Block binding the matched intrinsic as a local is a consumer.

  The Block's result is an unrelated Reference. Any dependence must therefore
  be attributed through the local binding.
  """
  bound_intrinsic = building_blocks.Intrinsic(
      'dummy_intrinsic', computation_types.TensorType(tf.int32))
  result_ref = building_blocks.Reference('int', tf.int32)
  block = building_blocks.Block([('x', bound_intrinsic)], result_ref)
  consumers = tree_analysis.extract_nodes_consuming(
      block, dummy_intrinsic_predicate)
  self.assertIn(block, consumers)
def test_propogates_dependence_up_through_call(self):
  """Applying an identity Lambda to the matched intrinsic yields a consumer."""
  argument = building_blocks.Intrinsic(
      'whimsy_intrinsic', computation_types.TensorType(tf.int32))
  identity_fn = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  call = building_blocks.Call(identity_fn, argument)
  consumers = tree_analysis.extract_nodes_consuming(
      call, whimsy_intrinsic_predicate)
  self.assertIn(call, consumers)
def test_propogates_dependence_into_binding_to_reference(self):
  """A Reference to a name bound to a matched intrinsic is itself a consumer.

  The Block binds `x` to a GENERIC_ZERO intrinsic and returns a Reference to
  `x`. That Reference node must appear in the extracted set.
  """

  def is_generic_zero(comp):
    return comp.is_intrinsic() and comp.uri == intrinsic_defs.GENERIC_ZERO.uri

  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  x_ref = building_blocks.Reference('x', clients_int)
  zero = building_blocks.Intrinsic(intrinsic_defs.GENERIC_ZERO.uri, clients_int)
  block = building_blocks.Block([('x', zero)], x_ref)
  consumers = tree_analysis.extract_nodes_consuming(block, is_generic_zero)
  self.assertIn(x_ref, consumers)
def test_adds_no_nodes_to_set_with_constant_false_predicate(self):
  """A predicate that never matches must produce an empty result."""
  extracted = tree_analysis.extract_nodes_consuming(
      building_block_test_utils.create_nested_syntax_tree(),
      lambda _: False)
  self.assertEmpty(extracted)
def test_adds_all_nodes_to_set_with_constant_true_predicate(self):
  """A predicate that always matches must select as many nodes as `count`."""
  tree = building_block_test_utils.create_nested_syntax_tree()
  extracted = tree_analysis.extract_nodes_consuming(tree, lambda _: True)
  expected_size = tree_analysis.count(tree)
  self.assertLen(extracted, expected_size)
def test_raises_on_none_predicate(self):
  """Passing `None` as the predicate must raise TypeError.

  NOTE(review): a method with this exact name appears again later in this
  class. Python keeps only the last definition in a class body, so this one
  never runs.
  """
  empty_data = building_blocks.Data('whimsy', computation_types.StructType([]))
  with self.assertRaises(TypeError):
    tree_analysis.extract_nodes_consuming(empty_data, None)
def test_raises_on_none_comp(self):
  """Passing `None` as the computation must raise TypeError."""
  with self.assertRaises(TypeError):
    tree_analysis.extract_nodes_consuming(None, lambda _: True)
def test_raises_on_none_predicate(self):
  """Checks that passing `None` as the predicate raises TypeError.

  NOTE(review): this name is also defined earlier in this class. Being the last
  definition, this is the one that actually runs.
  """
  # Build the Data node with an explicit empty struct type, as the other
  # definition of this test does, instead of a raw Python list. The
  # construction sits outside `assertRaises`, so a rejected type here would
  # error the test rather than exercise the None-predicate check.
  data = building_blocks.Data('dummy', computation_types.StructType([]))
  with self.assertRaises(TypeError):
    tree_analysis.extract_nodes_consuming(data, None)