def test_get_upward_paths(graph_nodes, test_instance, subgraph_root=None,
                          num_samples=10):
  """Test the correctness of imagenet_spec.get_upward_paths_from.

  Randomly samples `num_samples` start nodes, with replacement, among the
  nodes of `graph_nodes` that have at least one parent. For each sampled node
  it checks that:
    1. Without an end node, every returned path ends at `subgraph_root` (when
       given) or otherwise at a node that has no parents.
    2. With an end node chosen among the start node's direct parents, at least
       one returned path is exactly [start_node, end_node].

  Args:
    graph_nodes: Iterable of graph nodes; each node must expose a `parents`
      collection.
    test_instance: Test case used for assertions; it must provide
      `assertEqual`, `assertLen` and `assertTrue` (e.g. an absltest.TestCase).
    subgraph_root: Optional node that every end-node-free upward path is
      expected to terminate at. If None, paths must terminate at a parentless
      node.
    num_samples: Number of start nodes to sample. Defaults to 10.
  """
  # Only nodes with at least one parent have a direct parent to use as an end
  # node. Filtering up front, instead of resampling until such a node is
  # drawn, avoids an infinite loop when the graph has no such node. Sampling
  # uniformly among the candidates matches the old rejection-sampling
  # distribution.
  candidates = [node for node in graph_nodes if node.parents]
  test_instance.assertTrue(
      candidates,
      'graph_nodes contains no node with parents; cannot test '
      'get_upward_paths_from.')

  for _ in range(num_samples):
    # Draw an index rather than calling np.random.choice on the node list, so
    # the node objects are never converted into a numpy array.
    start_node = candidates[np.random.randint(len(candidates))]

    # Without an end node, every path must climb all the way to the top.
    paths = imagenet_spec.get_upward_paths_from(start_node)
    for path in paths:
      last_node = path[-1]
      if subgraph_root is not None:
        test_instance.assertEqual(last_node, subgraph_root)
      else:
        # The last node must be a root, i.e. it has no parents.
        test_instance.assertLen(last_node.parents, 0)

    # With a direct parent as the end node, the two-node path
    # [start_node, end_node] must be among the returned paths.
    parents = list(start_node.parents)
    end_node = parents[np.random.randint(len(parents))]
    paths = imagenet_spec.get_upward_paths_from(start_node, end_node)
    found_direct_path = any(
        len(path) == 2 and path[0] == start_node and path[1] == end_node
        for path in paths)
    test_instance.assertTrue(found_direct_path)
def test_lowest_common_ancestor_(lca, height, leaf_a, leaf_b, test_instance,
                                 root=None):
  """Check that `lca` is the lowest common ancestor of two leaves at `height`.

  For each leaf, the longest path returned by
  imagenet_spec.get_upward_paths_from is used; on ties, np.argmax picks the
  first. The check asserts that:
    * `lca` lies on both longest paths;
    * if `root` is given, imagenet_spec.is_descendent(root, lca) is False,
      i.e. the LCA is not higher than the root;
    * every node shared by both longest paths has a height no lower than
      `height`, and `lca` itself has height exactly `height`.

  The height of a shared node is the larger of its 0-based positions along
  the two longest paths. Another shared node may tie with `lca`; the
  comparison allows equality.

  Args:
    lca: Candidate lowest common ancestor of `leaf_a` and `leaf_b`.
    height: Expected height of `lca`.
    leaf_a: First leaf node.
    leaf_b: Second leaf node.
    test_instance: Test case providing the assertion methods.
    root: Optional root node that `lca` must not be higher than.
  """

  def _longest_upward_path(leaf):
    # Longest upward path from `leaf`; np.argmax keeps the first on ties.
    candidate_paths = imagenet_spec.get_upward_paths_from(leaf)
    lengths = [len(path) for path in candidate_paths]
    return candidate_paths[np.argmax(lengths)]

  # The LCA must be an ancestor on the longest path of each leaf.
  path_a = _longest_upward_path(leaf_a)
  test_instance.assertIn(lca, path_a)
  path_b = _longest_upward_path(leaf_b)
  test_instance.assertIn(lca, path_b)

  if root is not None:
    test_instance.assertFalse(imagenet_spec.is_descendent(root, lca))

  # No shared ancestor may sit lower than the claimed LCA.
  for steps_from_a, ancestor in enumerate(path_a):
    if ancestor not in path_b:
      continue
    steps_from_b = path_b.index(ancestor)
    ancestor_height = max(steps_from_a, steps_from_b)
    if ancestor == lca:
      test_instance.assertEqual(ancestor_height, height)
    else:
      test_instance.assertGreaterEqual(ancestor_height, height)