def test_observed_agents_selection(self):
    """Nodes of the observed graph are ordered by distance to the ego node.

    Observes the test world with an AgentLimit of 10 and walks the node list
    starting at the ego node (node 0), asserting that every node is at least
    as far from the ego node as the node before it.
    """
    params = ParameterServer()
    params["ML"]["GraphObserver"]["AgentLimit"] = 10
    observer = GraphObserver(params=params)

    obs, _ = self._get_observation(
      observer=observer,
      world=self.world,
      eval_id=self.eval_id)

    # GraphObserver.graph works on batches: add a batch dimension of size 1
    # and strip it again from the returned node features.
    batched_nodes, _, _ = GraphObserver.graph(
      tf.expand_dims(obs, 0), graph_dims=observer.graph_dimensions)
    nodes = batched_nodes[0]

    def position_of(node):
      # features 0 and 1 of a node are its x and y coordinates
      return Point2d(node[0].numpy(), node[1].numpy())

    ego_position = position_of(nodes[0])

    last_distance = 0
    for node in nodes:
      distance = Distance(position_of(node), ego_position)
      self.assertGreaterEqual(
        distance, last_distance,
        msg='Nodes are not sorted by distance relative to '
            'the ego node in ascending order.')
      last_distance = distance
  def test_request_subset_of_available_edge_features(self):
    """Requesting a valid subset of edge features enables exactly that subset."""
    requested_features = GraphObserver.available_edge_attributes()[:2]

    params = ParameterServer()
    params["ML"]["GraphObserver"]["EnabledEdgeFeatures"] = requested_features

    enabled_keys = GraphObserver(params=params)._enabled_edge_attribute_keys
    self.assertEqual(enabled_keys, requested_features)
  def test_request_partially_invalid_edge_features(self):
    """Unknown edge feature names are dropped, valid ones are kept in order.

    The expected list is built independently of the list handed to the
    ParameterServer. Popping the invalid entry from that very list after the
    observer has been created (as done previously) mutates shared input and
    could mask a bug where the observer keeps a reference to the unfiltered
    list.
    """
    params = ParameterServer()

    valid_features = GraphObserver.available_edge_attributes()[0:2]
    # concatenation creates a new list, so valid_features stays untouched
    requested_features = valid_features + ['invalid']
    params["ML"]["GraphObserver"]["EnabledEdgeFeatures"] = requested_features
    observer = GraphObserver(params=params)

    self.assertEqual(
      observer._enabled_edge_attribute_keys,
      valid_features)
  def test_observe_without_self_loops(self):
    """With SelfLoops disabled, the adjacency matrix diagonal is all zeros."""
    num_agents = 4
    params = ParameterServer()
    params["ML"]["GraphObserver"]["AgentLimit"] = num_agents
    params["ML"]["GraphObserver"]["SelfLoops"] = False
    observer = GraphObserver(params=params)
    obs, _ = self._get_observation(observer, self.world, self.eval_id)
    obs = tf.expand_dims(obs, 0) # add a batch dimension

    _, adjacency, _ = GraphObserver.graph(obs, graph_dims=observer.graph_dimensions)
    diagonal = tf.linalg.tensor_diag_part(adjacency[0])

    # one diagonal entry per agent slot (checked explicitly, since the
    # comparison below would broadcast a shorter tensor)
    self.assertEqual(diagonal.shape, [num_agents])
    # compare against zeros of the diagonal's own dtype: tf.zeros defaults to
    # float32 and tf.assert_equal rejects mismatched dtypes
    tf.assert_equal(diagonal, tf.zeros_like(diagonal))
Esempio n. 5
0
def visualize_graph(data_point, ax, visible_distance, normalization_ref):
  """Draw the graph stored in ``data_point["graph"]`` onto ``ax``.

  Node 0 is the ego agent: it is drawn red, and two ellipses are added
  around it. A yellow ellipse shows its visibility range and a small green
  ellipse marks its goal. All other nodes are drawn blue.

  Args:
    data_point: mapping whose "graph" entry is an observation accepted by
      ``GraphObserver.graph_from_observation``.
    ax: matplotlib axes to draw on.
    visible_distance: visibility distance of the ego agent; it is divided by
      ``normalization_ref["dx"][1]`` / ``["dy"][1]`` to size the ellipse
      (presumably to bring it into normalized coordinates — TODO confirm).
    normalization_ref: mapping with "dx" and "dy" entries; element [1] of
      each is used as the scaling reference.

  Returns:
    The return value of ``nx.draw``.
  """
  # Transform to nx.Graph
  observation = data_point["graph"]
  graph = GraphObserver.graph_from_observation(observation)

  # Current and goal positions of every node, keyed by node id
  pos = {}
  goal = {}
  for node_id in graph.nodes:
    features = graph.nodes[node_id]
    pos[node_id] = [features["x"].numpy(), features["y"].numpy()]
    goal[node_id] = [features["goal_x"].numpy(), features["goal_y"].numpy()]

  # Ellipse for the visibility range of the ego agent
  width = 4 * visible_distance / normalization_ref["dx"][1]
  height = 4 * visible_distance / normalization_ref["dy"][1]
  ax.add_patch(Ellipse(pos[0], width=width, height=height,
                       facecolor='yellow', zorder=-1))
  ax.add_patch(Ellipse(goal[0], width=0.2, height=0.2,
                       facecolor="green", zorder=-2))

  # Colors are listed in graph.nodes order (the order nx.draw uses for
  # node_color) and chosen by node id, so the red entry always belongs to
  # node 0 - the same node the ellipses above are centred on - even if
  # node 0 is not the first node iterated.
  node_colors = ["red" if node_id == 0 else "blue" for node_id in graph.nodes]
  return nx.draw(graph, pos=pos, with_labels=True, ax=ax,
                 node_color=node_colors)
Esempio n. 6
0
def configurable_setup(params, num_scenarios, graph_sac=True):
  """Build a GraphObserver and a SAC actor network on the highway blueprint.

  Args:
    params: ParameterServer instance shared by observer, blueprint and agent.
    num_scenarios: number of scenarios of the ContinuousHighwayBlueprint.
    graph_sac: if True (default) a BehaviorGraphSACAgent is built (it also
      receives the observer); otherwise a plain BehaviorSACAgent.

  Returns:
    observer: GraphObserver instance
    actor: actor network of the constructed agent (BehaviorGraphSACAgent or
      BehaviorSACAgent, depending on ``graph_sac``)
  """
  observer = GraphObserver(params=params)
  blueprint = ContinuousHighwayBlueprint(
    params, number_of_senarios=num_scenarios, random_seed=0)
  env = SingleAgentRuntime(blueprint=blueprint, observer=observer, render=False)

  if not graph_sac:
    agent = BehaviorSACAgent(environment=env, params=params)
  else:
    agent = BehaviorGraphSACAgent(
      environment=env, observer=observer, params=params)

  return observer, agent._agent._actor_network
Esempio n. 7
0
  def test_gnn_parameters(self):
    """GNN settings from the ParameterServer reach actor and critic GNNs."""
    expected_gnn_params = {
      "NumMpLayers": 4,
      "MpLayerNumUnits": 64,
      "message_calculation_class": "gnn_edge_mlp",
      "global_exchange_mode": "mean",
      "Library": GNNWrapper.SupportedLibrary.spektral,
    }

    params = ParameterServer()
    for key, value in expected_gnn_params.items():
      params["ML"]["BehaviorGraphSACAgent"]["GNN"][key] = value

    bp = ContinuousHighwayBlueprint(params, number_of_senarios=2500, random_seed=0)
    observer = GraphObserver(params=params)
    env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
    sac_agent = BehaviorGraphSACAgent(environment=env, observer=observer, params=params)

    gnns = [
      sac_agent._agent._actor_network._gnn,
      sac_agent._agent._critic_network_1._gnn,
    ]
    for gnn in gnns:
      for key, value in expected_gnn_params.items():
        self.assertEqual(gnn._params[key], value)
  def _init_call_func(self, observations, training=False):
    """Graph nets implementation.

    Converts a batch of flat graph observations into a graph_nets
    ``GraphsTuple``, passes it through every block in ``self._graph_blocks``
    (storing each intermediate graph in ``self._latent_trace``) and returns
    the final node values reshaped to ``[batch_size, -1, self._embedding_size]``.

    Args:
      observations: batch of observations in the GraphObserver format.
      training: accepted for interface compatibility; unused here.

    Returns:
      Node values produced by the last graph block, grouped per graph.
    """
    node_vals, edge_indices, node_to_graph, edge_vals = GraphObserver.graph(
      observations=observations,
      graph_dims=self._graph_dims,
      dense=True)
    batch_size = tf.shape(observations)[0]

    # Column 0 of the dense edge list is the sender, column 1 the receiver;
    # both index into the flattened node list of the whole batch.
    senders = tf.cast(edge_indices[:, 0], tf.int32)
    receivers = tf.cast(edge_indices[:, 1], tf.int32)
    node_to_graph = tf.cast(node_to_graph, tf.int32)

    # n_node / n_edge must match the nodes / edges actually present per
    # graph. The dense edge list only contains edges that are set in the
    # adjacency matrix (e.g. no self-loops when they are disabled, no edges
    # for empty agent slots), so a graph's edge count is in general NOT
    # node_count**2. Each edge is counted for the graph of its sender node.
    # min/maxlength = batch_size yields exactly one count per graph
    # (including zeros).
    node_counts = tf.math.bincount(
      node_to_graph, minlength=batch_size, maxlength=batch_size)
    edge_counts = tf.math.bincount(
      tf.gather(node_to_graph, senders),
      minlength=batch_size, maxlength=batch_size)

    input_graph = GraphsTuple(
      nodes=tf.cast(node_vals, tf.float32),
      edges=tf.cast(edge_vals, tf.float32),
      globals=tf.tile([[0.]], [batch_size, 1]),
      receivers=receivers,
      senders=senders,
      n_node=node_counts,
      n_edge=edge_counts)

    self._latent_trace = []
    latent = input_graph
    for gb in self._graph_blocks:
      latent = gb(latent)
      self._latent_trace.append(latent)
    node_values = tf.reshape(latent.nodes, [batch_size, -1, self._embedding_size])
    return node_values
  def test_agent_pruning(self):
    """
    Verify that the observer correctly handles the case where
    there are fewer agents in the world than set as the limit.
    tl;dr: check that all entries of the node features,
    adjacency matrix, and edge features not corresponding to
    actually existing agents are zeros.
    """
    num_agents = 25
    params = ParameterServer()
    params["ML"]["GraphObserver"]["AgentLimit"] = num_agents
    observer = GraphObserver(params=params)
    obs, world = self._get_observation(observer, self.world, self.eval_id)
    obs = tf.expand_dims(obs, 0) # add a batch dimension

    nodes, adjacency_matrix, edge_features = GraphObserver.graph(
      observations=obs,
      graph_dims=observer.graph_dimensions)

    self.assertEqual(nodes.shape, [1, num_agents, observer.feature_len])

    expected_num_agents = len(world.agents)

    # nodes that do not represent agents, but are contained
    # to fill up the required observation space.
    expected_n_fill_up_nodes = num_agents - expected_num_agents
    fill_up_nodes = nodes[0, expected_num_agents:]

    self.assertEqual(
      fill_up_nodes.shape,
      [expected_n_fill_up_nodes, observer.feature_len])

    # verify that entries for non-existing agents are all zeros. Counting
    # non-zero entries checks every element; a reduce_sum == 0 check would
    # also pass when positive and negative entries cancel out.
    self.assertEqual(tf.math.count_nonzero(fill_up_nodes).numpy(), 0)

    # the equivalent for edges: verify that for each zero entry
    # in the adjacency matrix, the corresponding edge feature
    # vector is a zero vector of correct length.
    zero_indices = tf.where(tf.equal(adjacency_matrix, 0))
    fill_up_edge_features = tf.gather_nd(edge_features, zero_indices)
    edge_feature_len = observer.graph_dimensions[2]
    # same dtype as the gathered features, since tf.equal needs matching dtypes
    zero_edge_feature_vectors = tf.zeros(
      [zero_indices.shape[0], edge_feature_len],
      dtype=fill_up_edge_features.dtype)

    self.assertTrue(tf.reduce_all(tf.equal(
      fill_up_edge_features,
      zero_edge_feature_vectors)))
Esempio n. 10
0
    def _call_spektral(self, observations, training=False):
        """Run the spektral backend on a batch of graph observations.

        The observations are split into node embeddings, adjacency matrix
        and edge features. The embeddings are then passed through every
        layer in ``self._convolutions`` in order; adjacency matrix and edge
        features are handed unchanged to each layer. ``training`` is unused.

        Returns:
            The node embeddings after the last convolution (the raw node
            features if ``self._convolutions`` is empty).
        """
        node_embeddings, adjacency, edge_attrs = GraphObserver.graph(
            observations=observations, graph_dims=self._graph_dims)

        for layer in self._convolutions:
            node_embeddings = layer([node_embeddings, adjacency, edge_attrs])

        return node_embeddings
 def setUp(self):
   """Build and reset a highway runtime shared by the test cases.

   Stores the runtime, its world, a default GraphObserver and the id of the
   first evaluated agent on ``self``.
   """
   params = ParameterServer()
   blueprint = ContinuousHighwayBlueprint(params, random_seed=0)
   env = self.env = SingleAgentRuntime(blueprint=blueprint, render=False)
   env.reset()
   self.world = env._world
   self.observer = GraphObserver(params)
   self.eval_id = env._scenario._eval_agent_ids[0]
  def test_parameter_server_usage(self):
    """GraphObserver picks up its settings from the ParameterServer."""
    agent_limit = 15
    visibility_radius = 100

    params = ParameterServer()
    params["ML"]["GraphObserver"]["AgentLimit"] = agent_limit
    params["ML"]["GraphObserver"]["VisibilityRadius"] = visibility_radius
    params["ML"]["GraphObserver"]["NormalizationEnabled"] = True
    observer = GraphObserver(params=params)

    self.assertEqual(observer._num_agents, agent_limit)
    self.assertEqual(observer._visibility_radius, visibility_radius)
    # disabled check: self.assertTrue(observer._add_self_loops)
    self.assertTrue(observer._normalize_observations)
Esempio n. 13
0
 def __init__(self,
              num_scenarios=3,
              dump_dir=None,
              render=False,
              params=None):
     """Inits DataGenerator with the parameters (see class definition).

     Args:
       num_scenarios: number of scenarios of the ContinuousHighwayBlueprint.
       dump_dir: stored as ``self._dump_dir`` (not used in this method).
       render: passed to the SingleAgentRuntime.
       params: ParameterServer instance. If None (default), a new
         ParameterServer is created for this instance.
     """
     self._dump_dir = dump_dir
     self._num_scenarios = num_scenarios
     # A ``params=ParameterServer()`` default is evaluated only once, when
     # the function is defined, so every instance created without explicit
     # params would share (and may write into) the same ParameterServer.
     self._params = ParameterServer() if params is None else params
     self._bp = ContinuousHighwayBlueprint(
       self._params, number_of_senarios=self._num_scenarios, random_seed=0)
     self._observer = GraphObserver(params=self._params)
     self._env = SingleAgentRuntime(blueprint=self._bp,
                                    observer=self._observer,
                                    render=render)
Esempio n. 14
0
def run_configuration(argv):
    """Train, visualize or evaluate a graph SAC agent on the highway blueprint.

    ``FLAGS.mode`` selects the action ("train", "visualize" or "evaluate");
    any other value does nothing. ``argv`` is unused.
    """
    # Choose the default parameter file matching the GNN library to use:
    #   tf2_gnn:  "examples/example_params/tfa_sac_gnn_tf2_gnn_default.json"
    #   spektral: "examples/example_params/tfa_sac_gnn_spektral_default.json"
    params = ParameterServer(
        filename="examples/example_params/tfa_sac_gnn_spektral_default.json")

    # NOTE: set your preferred checkpoint and summary paths, e.g.
    # params["ML"]["BehaviorTFAAgents"]["CheckpointPath"] = "YOUR_PATH"
    # params["ML"]["TFARunner"]["SummaryPath"] = "YOUR_PATH"

    # An MPViewer (optionally wrapped in a VideoRenderer) can be created here
    # for rendering; this configuration does not use one.

    blueprint = ContinuousHighwayBlueprint(
        params, number_of_senarios=2500, random_seed=0)
    observer = GraphObserver(params=params)
    env = SingleAgentRuntime(blueprint=blueprint, observer=observer, render=False)

    sac_agent = BehaviorGraphSACAgent(
        environment=env, observer=observer, params=params)
    env.ml_behavior = sac_agent
    runner = SACRunner(params=params, environment=env, agent=sac_agent)

    mode = FLAGS.mode
    if mode == "train":
        runner.SetupSummaryWriter()
        runner.Train()
    elif mode == "visualize":
        runner.Visualize(5)
    elif mode == "evaluate":
        runner.Evaluate()
Esempio n. 15
0
 def test_sac_graph_agent(self):
     """A BehaviorGraphSACAgent can drive the merging scenario's eval agent."""
     params = ParameterServer()
     blueprint = ContinuousMergingBlueprint(
         params, number_of_senarios=2500, random_seed=0)
     observer = GraphObserver(params=params)
     env = SingleAgentRuntime(blueprint=blueprint, observer=observer, render=False)
     agent = BehaviorGraphSACAgent(
         environment=env, observer=observer, params=params)
     env.ml_behavior = agent
     env.reset()

     # After reset, the evaluated agent's behavior model is the SAC agent.
     ego_id = env._scenario._eval_agent_ids[0]
     self.assertEqual(env._world.agents[ego_id].behavior_model, agent)

     # Smoke test: advance the world five steps of 0.2.
     for _ in range(5):
         env._world.Step(0.2)
Esempio n. 16
0
    def _call_tf2_gnn(self, observations, training=False):
        """Run the tf2_gnn backend on a batch of graph observations.

        Args:
          observations: batch of observations in the GraphObserver format.
          training: forwarded to the tf2_gnn model call.

        Returns:
          Node embeddings reshaped to ``[batch_size, -1, self.num_units]``.
        """
        # tf.shape yields the dynamic batch size. observations.shape[0] is
        # None when the static batch dimension is unknown (e.g. inside a
        # tf.function traced with a variable batch size), and
        # tf.constant(None) would fail.
        batch_size = tf.shape(observations)[0]

        # Dense mode returns node features, edge list and node-to-graph map
        # first; other call sites also unpack a fourth value (edge features),
        # which this backend does not use, so any extra values are ignored.
        embeddings, adj_list, node_to_graph_map, *_ = GraphObserver.graph(
            observations=observations, graph_dims=self._graph_dims, dense=True)

        gnn_input = GNNInput(
            node_features=embeddings,
            adjacency_lists=(adj_list, ),
            node_to_graph_map=node_to_graph_map,
            num_graphs=batch_size,
        )

        # tf2_gnn outputs a flattened node embeddings vector, so we
        # reshape it to have the embeddings of each node separately.
        flat_output = self._gnn(gnn_input, training=training)
        output = tf.reshape(flat_output, [batch_size, -1, self.num_units])

        return output
  def test_observation_conforms_to_spec(self):
    """
    The observer's observation lies inside its declared observation space,
    and the adjacency entries it contains are binary.
    """
    num_agents = 4
    params = ParameterServer()
    params["ML"]["GraphObserver"]["AgentLimit"] = num_agents
    observer = GraphObserver(params=params)
    obs, _ = self._get_observation(observer, self.world, self.eval_id)

    self.assertTrue(observer.observation_space.contains(obs))

    # The observation space cannot express that adjacency entries are 0/1,
    # so check it explicitly: the adjacency block follows the
    # num_agents * feature_len node features and has num_agents**2 entries.
    node_block_len = num_agents * observer.feature_len
    adjacency = obs[node_block_len:node_block_len + num_agents ** 2]
    for entry in adjacency:
      self.assertIn(entry, [0, 1])
Esempio n. 18
0
    def _configurable_setup(self, params_filename):
        """Configurable GNN setup depending on a given filename.

        Args:
          params_filename: str, corresponds to path of params file

        Returns:
          params: ParameterServer instance
          observer: GraphObserver instance
          actor: ActorNetwork of BehaviorGraphSACAgent
        """
        params = ParameterServer(filename=params_filename)
        observer = GraphObserver(params=params)
        blueprint = ContinuousHighwayBlueprint(
            params, number_of_senarios=2, random_seed=0)
        env = SingleAgentRuntime(
            blueprint=blueprint, observer=observer, render=False)
        agent = BehaviorGraphSACAgent(
            environment=env, observer=observer, params=params)
        # hand back the GNN SAC actor network
        return params, observer, agent._agent._actor_network
Esempio n. 19
0
def run_configuration(argv):
    """Train, visualize or evaluate a graph SAC agent on the merging blueprint.

    The agent's GNN is built with ``init_gnn='init_interaction_network'``.
    ``FLAGS.mode`` selects the action ("train", "visualize" or "evaluate");
    any other value does nothing. ``argv`` is unused.
    """
    params = ParameterServer()

    # NOTE: set your preferred checkpoint and summary paths, e.g.
    # params["ML"]["BehaviorTFAAgents"]["CheckpointPath"] = "YOUR_PATH"
    # params["ML"]["TFARunner"]["SummaryPath"] = "YOUR_PATH"

    # An MPViewer (optionally wrapped in a VideoRenderer) can be created here
    # for rendering; this configuration does not use one.

    blueprint = ContinuousMergingBlueprint(params, num_scenarios=2500, random_seed=0)
    observer = GraphObserver(params=params)
    env = SingleAgentRuntime(blueprint=blueprint, observer=observer, render=False)

    sac_agent = BehaviorGraphSACAgent(
        environment=env,
        observer=observer,
        params=params,
        init_gnn='init_interaction_network')
    env.ml_behavior = sac_agent
    runner = SACRunner(params=params, environment=env, agent=sac_agent)

    mode = FLAGS.mode
    if mode == "train":
        runner.SetupSummaryWriter()
        runner.Train()
    elif mode == "visualize":
        runner.Run(num_episodes=10, render=True)
    elif mode == "evaluate":
        runner.Run(num_episodes=250, render=False)
Esempio n. 20
0
  def _test_graph_dim_validation_accepts_observer_dims(self):
    """GNNWrapper accepts a GraphObserver's graph dimensions unchanged.

    The leading underscore means unittest does not collect this method.
    """
    observer = GraphObserver()
    dims = observer.graph_dimensions
    gnn = GNNWrapper(graph_dims=dims)

    # construction raised nothing; the wrapper must keep the given dims
    self.assertEqual(gnn._graph_dims, dims)
Esempio n. 21
0
def run_configuration(argv):
  """Train, visualize or evaluate a graph SAC agent in a counterfactual runtime.

  ``FLAGS.mode`` selects the action:
    * "train": set up the summary writer and train.
    * "visualize": run one rendered episode with ``_max_col_rate`` set to 0.
    * "evaluate": run 250 episodes for every max collision rate in
      ``np.arange(0, 1, 0.1)``, then save the runtime's and the runner's
      tracer into ``params["ML"]["ResultsFolder"]``.
  Any other mode does nothing. ``argv`` is unused.
  """
  import os  # only needed to build the evaluation result paths

  params = ParameterServer()
  # NOTE: Modify these paths to specify your preferred path for checkpoints and summaries
  # params["ML"]["BehaviorTFAAgents"]["CheckpointPath"] = "YOUR_PATH"
  # params["ML"]["TFARunner"]["SummaryPath"] = "YOUR_PATH"
  params["Visualization"]["Agents"]["Alpha"]["Other"] = 0.2
  params["Visualization"]["Agents"]["Alpha"]["Controlled"] = 0.2
  params["ML"]["VisualizeCfWorlds"] = False
  params["ML"]["VisualizeCfHeatmap"] = False
  # Output folder used by "evaluate" mode, e.g.:
  # params["ML"]["ResultsFolder"] = "YOUR_PATH"

  # An MPViewer can be created here for rendering; it is not used below.

  # create environment
  bp = ContinuousMergingBlueprint(params,
                                  num_scenarios=2500,
                                  random_seed=0)

  observer = GraphObserver(params=params)

  # Constant-acceleration behaviors (-5, 0, 5) for the counterfactual
  # runtime, each configured through its own child parameter set.
  behavior_model_pool = []
  for count, a in enumerate([-5., 0., 5.]):
    local_params = params.AddChild("local_"+str(count))
    local_params["BehaviorConstantAcceleration"]["ConstAcceleration"] = a
    behavior = BehaviorConstantAcceleration(local_params)
    behavior_model_pool.append(behavior)

  env = CounterfactualRuntime(
    blueprint=bp,
    observer=observer,
    render=False,
    params=params,
    behavior_model_pool=behavior_model_pool)
  sac_agent = BehaviorGraphSACAgent(environment=env,
                                    observer=observer,
                                    params=params)
  env.ml_behavior = sac_agent
  runner = SACRunner(params=params,
                     environment=env,
                     agent=sac_agent)

  if FLAGS.mode == "train":
    runner.SetupSummaryWriter()
    runner.Train()
  elif FLAGS.mode == "visualize":
    runner._environment._max_col_rate = 0.
    runner.Run(num_episodes=1, render=True)
  elif FLAGS.mode == "evaluate":
    for cr in np.arange(0, 1, 0.1):
      runner._environment._max_col_rate = cr
      runner.Run(num_episodes=250, render=False, max_col_rate=cr)
    # os.path.join works whether or not ResultsFolder ends with a separator;
    # plain string concatenation silently wrote outside the folder without one.
    results_folder = params["ML"]["ResultsFolder"]
    runner._environment._tracer.Save(
      os.path.join(results_folder, "evaluation_results_runtime.pckl"))
    runner._tracer.Save(
      os.path.join(results_folder, "evaluation_results_runner.pckl"))
    def test_gnn(self):
        """Smoke-test GraphObserver.graph on a zero observation of matching size.

        Also builds a two-graph graph_nets ``GraphsTuple`` from a hand-written
        data dict (5 nodes with 3 features, 6 edges with 2 features) and
        prints it for inspection.
        """
        # Node features for graph 0.
        nodes_0 = [
            [10.1, 20., 30.],  # Node 0
            [11., 21., 31.],  # Node 1
            [12., 22., 32.],  # Node 2
            [13., 23., 33.],  # Node 3
            [14., 24., 34.],  # Node 4
        ]

        # Edge features for graph 0.
        edges_0 = [
            [100., 200.],  # Edge 0
            [101.2, 201.],  # Edge 1
            [102., 202.],  # Edge 2
            [103., 203.],  # Edge 3
            [104., 204.],  # Edge 4
            [105., 205.],  # Edge 5
        ]

        # The sender and receiver node of each edge for graph 0
        # (entry i belongs to edge i).
        senders_0 = [0, 1, 1, 2, 2, 3]
        receivers_0 = [1, 2, 3, 0, 3, 4]

        data_dict_0 = {
            "globals": [],
            "nodes": nodes_0,
            "edges": edges_0,
            "senders": senders_0,
            "receivers": receivers_0
        }

        input_graph = utils_tf.data_dicts_to_graphs_tuple(
            [data_dict_0, data_dict_0])

        num_nodes = len(nodes_0)
        num_features = len(nodes_0[0])
        # length of one edge feature vector (2), not the number of edges (6)
        num_edge_features = len(edges_0[0])
        graph_dims = (num_nodes, num_features, num_edge_features)

        # Flat observation layout: node features (num_nodes x num_features),
        # adjacency matrix (num_nodes x num_nodes), then one edge feature
        # vector per node pair (num_nodes x num_nodes x num_edge_features).
        obs_len = (num_nodes * num_features
                   + num_nodes ** 2
                   + num_nodes ** 2 * num_edge_features)
        obs = np.zeros(shape=(1, obs_len))
        # NOTE: use dense
        params = ParameterServer()
        graph_observer = GraphObserver(params)
        graph_observer.feature_len = num_features
        graph_observer.edge_feature_len = num_edge_features

        nodes, _, _ = graph_observer.graph(obs, graph_dims=graph_dims)
        self.assertEqual(nodes.shape, [1, num_nodes, num_features])

        print(input_graph)
  def test_observation_to_graph_conversion(self):
    """Round-trip hand-built flat observations through GraphObserver.graph.

    Each observation is the concatenation of node features, adjacency matrix
    and per-pair edge features of a 5-node graph (with self-loops disabled).
    The test checks the node and edge features returned in the default mode,
    and the node list and edge list returned in dense mode.
    """
    params = ParameterServer()
    params["ML"]["GraphObserver"]["SelfLoops"] = False
    graph_observer = GraphObserver(params=params)

    num_nodes = 5
    num_features = 5
    num_edge_features = 4

    node_features = np.random.random_sample((num_nodes, num_features))
    edge_features = np.random.random_sample((num_nodes, num_nodes, num_edge_features))

    # note that edges are bidirectional, so
    # the matrix is symmetric (and has a zero diagonal: no self-loops)
    adjacency_list = [
      [0, 1, 1, 1, 0], # node 0 connects with 1, 2, 3
      [1, 0, 1, 1, 0], # node 1 connects with 0, 2, 3
      [1, 1, 0, 1, 0], # node 2 connects with 0, 1, 3
      [1, 1, 1, 0, 0], # node 3 connects with 0, 1, 2
      [0, 0, 0, 0, 0]  # node 4: empty slot -> all zeros
    ]

    # flat layout: node features, then adjacency matrix, then edge features
    observation = np.array(node_features)
    observation = np.append(observation, adjacency_list)
    observation = np.append(observation, edge_features)
    observation = observation.reshape(-1)
    observations = np.array([observation, observation])

    # 5*5 node features + 5*5 adjacency entries + 5*5*4 edge features = 150
    self.assertEqual(observations.shape, (2, 150))

    expected_nodes = tf.constant([node_features, node_features])
    expected_edge_features = tf.constant([edge_features, edge_features])

    graph_dims = (num_nodes, num_features, num_edge_features)
    # default (non-dense) mode: node features, adjacency matrix (unpacked as
    # `edges`, not asserted here) and edge features, each with a batch axis
    nodes, edges, edge_features = graph_observer.graph(observations, graph_dims)

    self.assertTrue(tf.reduce_all(tf.equal(nodes, expected_nodes)))
    self.assertTrue(tf.reduce_all(tf.equal(edge_features, expected_edge_features)))

    observations = np.array([observation, observation, observation])

    # in dense mode, the nodes of all graphs are in a single list
    expected_nodes = tf.constant([node_features, node_features, node_features])
    expected_nodes = tf.reshape(expected_nodes, [-1, num_features])

    # the edges encoded in the adjacency list above, as [sender, receiver]
    # pairs; node ids of graph k are offset by k * num_nodes (0, 5, 10)
    expected_dense_edges = tf.constant([
      # graph 1
      [0, 1], [0, 2], [0, 3],
      [1, 0], [1, 2], [1, 3],
      [2, 0], [2, 1], [2, 3],
      [3, 0], [3, 1], [3, 2],
      # graph 2
      [5, 6], [5, 7], [5, 8],
      [6, 5], [6, 7], [6, 8],
      [7, 5], [7, 6], [7, 8],
      [8, 5], [8, 6], [8, 7],
      # graph 3
      [10, 11], [10, 12], [10, 13],
      [11, 10], [11, 12], [11, 13],
      [12, 10], [12, 11], [12, 13],
      [13, 10], [13, 11], [13, 12]
    ], dtype=tf.int32)

    # graph index of every node in the flattened node list (padding slot
    # nodes included)
    expected_node_to_graph_map = tf.constant([
      0, 0, 0, 0, 0,
      1, 1, 1, 1, 1,
      2, 2, 2, 2, 2
    ])

    observations = tf.convert_to_tensor(observations)
    print(observations)  # debug output
    # dense mode returns node list, edge list, node-to-graph map and E
    # (presumably the dense edge features — TODO confirm)
    nodes, edges, node_to_graph_map, E =\
      GraphObserver.graph(observations, graph_dims, dense=True)

    self.assertTrue(tf.reduce_all(tf.equal(nodes, expected_nodes)))
    self.assertTrue(tf.reduce_all(tf.equal(edges, expected_dense_edges)))