예제 #1
0
def test_SetDiff_input_ranges():
    """SetDiff accepts `range` objects as its two operands."""
    shorter, longer = range(3), range(4)
    difference = labtypes.SetDiff(shorter, longer)
    # Only 3 is present in exactly one of the two ranges.
    assert difference == {3}
예제 #2
0
파일: graph_util.py 프로젝트: SpringRi/phd
def AddShortestPath(rand: np.random.RandomState, graph: nx.Graph,
                    min_length: int = 1,
                    weight_name: str = "distance") -> nx.DiGraph:
  """Picks a random node pair and marks a shortest path between them.

  Node pairs are sampled so that every distinct hop count (of at least
  `min_length`) is equally likely, and pairs sharing a hop count are equally
  likely among themselves. The path between the chosen pair is then computed
  using the `weight_name` edge attribute.

  Args:
    rand: The `np.random.RandomState` used to draw the node pair.
    graph: A `nx.Graph`.
    min_length: (optional) An `int` minimum number of edges on the unweighted
      shortest path between the sampled end points. Default= 1.
    weight_name: (optional) Name of the edge attribute passed as `weight` to
      `nx.shortest_path`. Default= "distance".

  Returns:
    A directed copy of `graph` (a `nx.DiGraph`) in which every node carries
    boolean "start", "end" and "solution" attributes and every edge carries a
    boolean "solution" attribute.

  Raises:
    ValueError: All shortest paths are below the minimum length
  """
  # Hop count for every ordered (source, target) pair that is long enough.
  hops_by_pair = {
      (source, target): hops
      for source, hops_from_source in nx.all_pairs_shortest_path_length(graph)
      for target, hops in hops_from_source.items()
      if hops >= min_length
  }
  if not hops_by_pair:
    raise ValueError("All shortest paths are below the minimum length")
  candidate_pairs = list(hops_by_pair)

  # Give each distinct hop count the same total probability mass, then split
  # that mass evenly between the pairs which share the hop count.
  pairs_per_length = collections.Counter(hops_by_pair.values())
  mass_per_length = 1.0 / len(pairs_per_length)
  weights = []
  for pair in candidate_pairs:
    weights.append(mass_per_length / pairs_per_length[hops_by_pair[pair]])

  # Draw the end points and find the weighted shortest path joining them.
  chosen = rand.choice(len(candidate_pairs), p=weights)
  start, end = candidate_pairs[chosen]
  path = nx.shortest_path(graph, source=start, target=end, weight=weight_name)

  # The result is directed so that the path has an orientation.
  directed_graph = graph.to_directed()

  # Flag the two end points first, then give every remaining node an explicit
  # False for each attribute it was not assigned.
  directed_graph.add_node(start, start=True)
  directed_graph.add_node(end, end=True)
  negative_flags = (
      ([start], {"start": False}),
      ([end], {"end": False}),
      (path, {"solution": False}),
  )
  for excluded, attrs in negative_flags:
    remaining = list(labtypes.SetDiff(directed_graph.nodes(), excluded))
    directed_graph.add_nodes_from(remaining, **attrs)
  directed_graph.add_nodes_from(path, solution=True)

  # Edges get the same treatment: everything off the path is marked False,
  # then the path edges themselves are marked True.
  # NOTE(review): PairwiseIterator presumably yields consecutive node pairs
  # of the path -- confirm against labtypes.
  path_edges = list(labtypes.PairwiseIterator(path))
  off_path_edges = list(labtypes.SetDiff(directed_graph.edges(), path_edges))
  directed_graph.add_edges_from(off_path_edges, solution=False)
  directed_graph.add_edges_from(path_edges, solution=True)

  return directed_graph
예제 #3
0
def test_SetDiff_unmatching_types():
    """SetDiff handles operands whose elements have different types."""
    numbers = [1, 2, 3]
    letters = ['a', 'b']
    # The operands share no element, so every element is kept.
    assert labtypes.SetDiff(numbers, letters) == {1, 2, 3, 'a', 'b'}
예제 #4
0
def test_SetDiff_overlapping_inputs():
    """Elements common to both operands are dropped from the result."""
    subset = [1, 2]
    superset = [1, 2, 3]
    result = labtypes.SetDiff(subset, superset)
    assert result == {3}
예제 #5
0
def test_SetDiff_matching_inputs():
    """Two operands with identical contents produce an empty set."""
    left = [1, 2, 3]
    right = [1, 2, 3]
    result = labtypes.SetDiff(left, right)
    assert result == set()
예제 #6
0
def test_SetDiff_one_input_is_empty():
    """An empty operand on either side leaves the other operand's elements."""
    expected = {1, 2, 3}
    # Empty second operand.
    assert labtypes.SetDiff([1, 2, 3], []) == expected
    # Empty first operand.
    assert labtypes.SetDiff([], [1, 2, 3]) == expected
예제 #7
0
def test_SetDiff_empty_inputs():
    """Two empty operands produce an empty set."""
    result = labtypes.SetDiff([], [])
    assert result == set()