Beispiel #1
0
def test_PairwiseIterator_input_is_string():
    """Test that a string input is paired up character by character."""
    pairs = labtypes.PairwiseIterator('hello')
    expected_pairs = [('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')]
    # Pull exactly one pair per expectation; exhaustion is not checked here.
    for expected in expected_pairs:
        assert next(pairs) == expected
Beispiel #2
0
def test_PairwiseIterator_input_is_list():
    """Test that a list input yields consecutive pairs and then runs out."""
    pairs = labtypes.PairwiseIterator([0, 1, 2, 3])
    for expected in [(0, 1), (1, 2), (2, 3)]:
        assert next(pairs) == expected
    # Four elements give three pairs, so a fourth request must fail.
    with pytest.raises(StopIteration):
        next(pairs)
Beispiel #3
0
def AddShortestPath(rand: np.random.RandomState, graph: nx.Graph,
                    min_length: int = 1,
                    weight_name: str = "distance") -> nx.DiGraph:
  """Picks a random node pair and marks a shortest path between them.

  Candidate pairs are those whose shortest-path length, as reported by
  `nx.all_pairs_shortest_path_length`, is at least `min_length`. Every
  distinct length receives the same total probability, which is then split
  evenly among the pairs of that length. The path for the chosen pair is
  computed with `nx.shortest_path` using `weight_name` as the edge weight.

  Args:
    rand: A `np.random.RandomState` used to draw the node pair.
    graph: A `nx.Graph`.
    min_length: (optional) An `int` minimum number of edges in the shortest
      path. Default= 1.
    weight_name: (optional) A `str` name of the edge attribute passed as
      `weight` to `nx.shortest_path`. Default= "distance".

  Returns:
    A `nx.DiGraph` built with `graph.to_directed()` in which every node
    carries boolean "start", "end" and "solution" attributes and every edge
    carries a boolean "solution" attribute.

  Raises:
    ValueError: All shortest paths are below the minimum length
  """
  # Shortest-path length for every (source, target) pair that is long enough.
  pair_to_hops = {
      (source, target): hops
      for source, hops_by_target in nx.all_pairs_shortest_path_length(graph)
      for target, hops in hops_by_target.items()
      if hops >= min_length
  }
  if not pair_to_hops:
    raise ValueError("All shortest paths are below the minimum length")
  candidates = list(pair_to_hops)

  # Give each distinct length an equal share of the probability mass, then
  # divide that share evenly among the pairs having that length.
  pairs_per_length = collections.Counter(pair_to_hops.values())
  mass_per_length = 1.0 / len(pairs_per_length)
  probabilities = [
      mass_per_length / pairs_per_length[hops]
      for hops in pair_to_hops.values()
  ]

  # Draw the endpoints and compute the weighted shortest path between them.
  chosen = rand.choice(len(candidates), p=probabilities)
  start, end = candidates[chosen]
  path = nx.shortest_path(
      graph, source=start, target=end, weight=weight_name)

  # A directed copy is needed to record the path's direction (start -> end).
  directed_graph = graph.to_directed()

  # Flag the endpoints first, then fill in the negative flags everywhere else.
  directed_graph.add_node(start, start=True)
  directed_graph.add_node(end, end=True)
  not_start = list(labtypes.SetDiff(directed_graph.nodes(), [start]))
  directed_graph.add_nodes_from(not_start, start=False)
  not_end = list(labtypes.SetDiff(directed_graph.nodes(), [end]))
  directed_graph.add_nodes_from(not_end, end=False)
  off_path_nodes = list(labtypes.SetDiff(directed_graph.nodes(), path))
  directed_graph.add_nodes_from(off_path_nodes, solution=False)
  directed_graph.add_nodes_from(path, solution=True)

  # Edges get the same treatment: consecutive path nodes form solution edges.
  path_edges = list(labtypes.PairwiseIterator(path))
  off_path_edges = list(labtypes.SetDiff(directed_graph.edges(), path_edges))
  directed_graph.add_edges_from(off_path_edges, solution=False)
  directed_graph.add_edges_from(path_edges, solution=True)

  return directed_graph
Beispiel #4
0
def test_PairwiseIterator_empty_list():
    """Test that an empty list yields no pairs at all."""
    pairs = list(labtypes.PairwiseIterator([]))
    assert pairs == []
Beispiel #5
0
def test_PairwiseIterator_input_is_iterator():
    """Test that a `range` input yields consecutive pairs."""
    generator = labtypes.PairwiseIterator(range(4))
    # Only the first three pairs are checked, one `next` call per pair.
    for expected in [(0, 1), (1, 2), (2, 3)]:
        assert next(generator) == expected