def test_propogates_dependence_up_through_lambda(self):
  """Checks a Lambda built around the dummy intrinsic is in the result.

  The Lambda itself is passed to `extract_nodes_consuming` together with
  `dummy_intrinsic_predicate`, and must appear in the returned collection.
  """
  intrinsic = computation_building_blocks.Intrinsic(
      'dummy_intrinsic', tf.int32)
  fn = computation_building_blocks.Lambda('x', tf.int32, intrinsic)

  consumers = tree_analysis.extract_nodes_consuming(
      fn, dummy_intrinsic_predicate)

  self.assertIn(fn, consumers)
def test_propogates_dependence_up_through_selection(self):
  """Checks a Selection from the dummy intrinsic is in the result.

  The intrinsic is given the type `[tf.int32]` and the Selection picks
  index 0 from it; the Selection must appear in the collection returned by
  `extract_nodes_consuming`.
  """
  intrinsic = computation_building_blocks.Intrinsic(
      'dummy_intrinsic', [tf.int32])
  first_element = computation_building_blocks.Selection(intrinsic, index=0)

  consumers = tree_analysis.extract_nodes_consuming(
      first_element, dummy_intrinsic_predicate)

  self.assertIn(first_element, consumers)
def test_propogates_dependence_up_through_tuple(self):
  """Checks a Tuple containing the dummy intrinsic is in the result.

  The Tuple holds an integer Reference and the dummy intrinsic; the Tuple
  itself must appear in the collection returned by
  `extract_nodes_consuming`.
  """
  intrinsic = computation_building_blocks.Intrinsic(
      'dummy_intrinsic', tf.int32)
  int_ref = computation_building_blocks.Reference('int', tf.int32)
  elements = [int_ref, intrinsic]
  pair = computation_building_blocks.Tuple(elements)

  consumers = tree_analysis.extract_nodes_consuming(
      pair, dummy_intrinsic_predicate)

  self.assertIn(pair, consumers)
def test_propogates_dependence_up_through_block_locals(self):
  """Checks a Block whose `('x', ...)` local is the dummy intrinsic.

  The Block's second argument is an integer Reference that is not the
  intrinsic; the Block itself must still appear in the collection returned
  by `extract_nodes_consuming`.
  """
  intrinsic = computation_building_blocks.Intrinsic(
      'dummy_intrinsic', tf.int32)
  int_ref = computation_building_blocks.Reference('int', tf.int32)
  local_bindings = [('x', intrinsic)]
  blk = computation_building_blocks.Block(local_bindings, int_ref)

  consumers = tree_analysis.extract_nodes_consuming(
      blk, dummy_intrinsic_predicate)

  self.assertIn(blk, consumers)
def test_propogates_dependence_up_through_call(self):
  """Checks a Call whose argument is the dummy intrinsic is in the result.

  An identity-style Lambda over `x` is called with the dummy intrinsic; the
  Call itself must appear in the collection returned by
  `extract_nodes_consuming`.
  """
  intrinsic = computation_building_blocks.Intrinsic(
      'dummy_intrinsic', tf.int32)
  identity_fn = computation_building_blocks.Lambda(
      'x', tf.int32, computation_building_blocks.Reference('x', tf.int32))
  call = computation_building_blocks.Call(identity_fn, intrinsic)

  consumers = tree_analysis.extract_nodes_consuming(
      call, dummy_intrinsic_predicate)

  self.assertIn(call, consumers)
def test_propogates_dependence_into_binding_to_reference(self):
  """Checks a Reference to a local bound to GENERIC_ZERO is in the result.

  A Block has the local `('x', <GENERIC_ZERO intrinsic>)` and the Reference
  `x` as its second argument. With a predicate matching only GENERIC_ZERO
  intrinsics, the Reference must appear in the collection returned by
  `extract_nodes_consuming`.
  """
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  x_ref = computation_building_blocks.Reference('x', clients_int)
  zero = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_ZERO.uri, clients_int)

  def is_generic_zero(comp):
    # Only Intrinsic nodes are inspected for their URI.
    if not isinstance(comp, computation_building_blocks.Intrinsic):
      return False
    return comp.uri == intrinsic_defs.GENERIC_ZERO.uri

  blk = computation_building_blocks.Block([('x', zero)], x_ref)

  consumers = tree_analysis.extract_nodes_consuming(blk, is_generic_zero)

  self.assertIn(x_ref, consumers)
def test_adds_no_nodes_to_set_with_constant_false_predicate(self):
  """Checks an always-False predicate yields an empty result."""
  tree = computation_test_utils.create_nested_syntax_tree()

  matched = tree_analysis.extract_nodes_consuming(tree, lambda _: False)

  self.assertEmpty(matched)
def test_adds_all_nodes_to_set_with_constant_true_predicate(self):
  """Checks an always-True predicate yields as many nodes as the tree has.

  The expected size is whatever `tree_analysis.count` returns for the same
  tree.
  """
  tree = computation_test_utils.create_nested_syntax_tree()

  matched = tree_analysis.extract_nodes_consuming(tree, lambda _: True)

  self.assertLen(matched, tree_analysis.count(tree))
def test_raises_on_none_predicate(self):
  """Checks passing `None` as the predicate raises `TypeError`."""
  comp = computation_building_blocks.Data('dummy', [])
  self.assertRaises(
      TypeError, tree_analysis.extract_nodes_consuming, comp, None)
def test_raises_on_none_comp(self):
  """Checks passing `None` as the computation raises `TypeError`."""
  self.assertRaises(
      TypeError, tree_analysis.extract_nodes_consuming, None, lambda _: True)