Example #1
  def test_gnn_parameters(self):
    params = ParameterServer()
    params["ML"]["BehaviorGraphSACAgent"]["GNN"]["NumMpLayers"] = 4
    params["ML"]["BehaviorGraphSACAgent"]["GNN"]["MpLayerNumUnits"] = 64
    params["ML"]["BehaviorGraphSACAgent"]["GNN"]["message_calculation_class"] = "gnn_edge_mlp"
    params["ML"]["BehaviorGraphSACAgent"]["GNN"]["global_exchange_mode"] = "mean"
    
    gnn_library = GNNWrapper.SupportedLibrary.spektral
    params["ML"]["BehaviorGraphSACAgent"]["GNN"]["Library"] = gnn_library

    
    bp = ContinuousHighwayBlueprint(params, number_of_senarios=2500, random_seed=0)
    observer = GraphObserver(params=params)
    env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
    sac_agent = BehaviorGraphSACAgent(environment=env, observer=observer, params=params)

    actor_gnn = sac_agent._agent._actor_network._gnn
    critic_gnn = sac_agent._agent._critic_network_1._gnn

    for gnn in [actor_gnn, critic_gnn]:
      self.assertEqual(gnn._params["NumMpLayers"], 4)
      self.assertEqual(gnn._params["MpLayerNumUnits"], 64)
      self.assertEqual(gnn._params["message_calculation_class"], "gnn_edge_mlp")
      self.assertEqual(gnn._params["global_exchange_mode"], "mean")
      self.assertEqual(gnn._params["Library"], gnn_library)
Example #2
def configurable_setup(params, num_scenarios, graph_sac=True):
  """Configurable GNN setup depending on a given filename

  Args:
    params: ParameterServer instance

  Returns: 
    observer: GraphObserver instance
    actor: ActorNetwork of BehaviorGraphSACAgent
  """
  observer = GraphObserver(params=params)
  bp = ContinuousHighwayBlueprint(params,
                                  number_of_senarios=num_scenarios,
                                  random_seed=0)
  env = SingleAgentRuntime(blueprint=bp, observer=observer,
                           render=False)
  if graph_sac:
    # Get GNN SAC actor net
    sac_agent = BehaviorGraphSACAgent(environment=env, observer=observer,
                                      params=params)
  else:
    sac_agent = BehaviorSACAgent(environment=env, params=params)

  actor = sac_agent._agent._actor_network
  return observer, actor
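
A possible way to exercise the helper above (a sketch only, reusing the imports listed after Example #1, plus BehaviorSACAgent):

params = ParameterServer()
# Build a small graph-SAC setup and inspect its actor network.
observer, actor = configurable_setup(params, num_scenarios=2, graph_sac=True)
print(type(actor).__name__)
# With graph_sac=False the helper falls back to a plain BehaviorSACAgent.
_, dense_actor = configurable_setup(params, num_scenarios=2, graph_sac=False)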
Example #3
def run_configuration(argv):
    # Uncomment one of the following default parameter filename definitions,
    # depending on which GNN library you'd like to use.

    # File with standard parameters for tf2_gnn use:
    # param_filename = "examples/example_params/tfa_sac_gnn_tf2_gnn_default.json"

    # File with standard parameters for spektral use:
    param_filename = "examples/example_params/tfa_sac_gnn_spektral_default.json"
    params = ParameterServer(filename=param_filename)

    # NOTE: Modify these paths to specify your preferred path for checkpoints and summaries
    # params["ML"]["BehaviorTFAAgents"]["CheckpointPath"] = "YOUR_PATH"
    # params["ML"]["TFARunner"]["SummaryPath"] = "YOUR_PATH"

    #viewer = MPViewer(
    #  params=params,
    #  x_range=[-35, 35],
    #  y_range=[-35, 35],
    #  follow_agent_id=True)

    #viewer = VideoRenderer(
    #  renderer=viewer,
    #  world_step_time=0.2,
    #  fig_path="/your_path_here/training/video/")

    # create environment
    bp = ContinuousHighwayBlueprint(params,
                                    number_of_senarios=2500,
                                    random_seed=0)

    observer = GraphObserver(params=params)

    env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)

    sac_agent = BehaviorGraphSACAgent(environment=env,
                                      observer=observer,
                                      params=params)
    env.ml_behavior = sac_agent
    runner = SACRunner(params=params, environment=env, agent=sac_agent)

    if FLAGS.mode == "train":
        runner.SetupSummaryWriter()
        runner.Train()
    elif FLAGS.mode == "visualize":
        runner.Visualize(5)
    elif FLAGS.mode == "evaluate":
        runner.Evaluate()
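
run_configuration reads FLAGS.mode, which is not defined in the snippet; in the bark-ml example scripts this typically comes from absl. A sketch of a matching definition and entry point (the flag name is taken from the code above, the rest is assumed):

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_enum("mode", "visualize",
                  ["train", "visualize", "evaluate"],
                  "Mode in which the configuration should be executed.")

if __name__ == '__main__':
  app.run(run_configuration)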
Example #4
  def test_sac_graph_agent(self):
    params = ParameterServer()
    bp = ContinuousMergingBlueprint(params,
                                    number_of_senarios=2500,
                                    random_seed=0)
    observer = GraphObserver(params=params)
    env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
    sac_agent = BehaviorGraphSACAgent(environment=env,
                                      observer=observer,
                                      params=params)
    env.ml_behavior = sac_agent
    env.reset()
    eval_id = env._scenario._eval_agent_ids[0]
    self.assertEqual(env._world.agents[eval_id].behavior_model, sac_agent)
    for _ in range(0, 5):
      env._world.Step(0.2)
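
Beyond the single assertion, the same setup can be rolled out over a few scenarios. A short sketch using only calls that already appear in the test:

# Each env.reset() samples a new scenario from the blueprint; the world is then
# stepped directly, with the SAC agent acting as the ego behavior model.
for _ in range(3):
  env.reset()
  for _ in range(10):
    env._world.Step(0.2)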
Example #5
    def _configurable_setup(self, params_filename):
        """Configurable GNN setup depending on a given filename

    Args:
      params_filename: str, corresponds to path of params file

    Returns:
      params: ParameterServer instance
      observer: GraphObserver instance
      actor: ActorNetwork of BehaviorGraphSACAgent
    """
        params = ParameterServer(filename=params_filename)
        observer = GraphObserver(params=params)
        bp = ContinuousHighwayBlueprint(params,
                                        number_of_senarios=2,
                                        random_seed=0)
        env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
        # Get GNN SAC actor net
        sac_agent = BehaviorGraphSACAgent(environment=env,
                                          observer=observer,
                                          params=params)
        actor = sac_agent._agent._actor_network
        return params, observer, actor
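
A sketch of how a test might call this helper for both shipped GNN backends, using the default parameter files referenced in Example #3 (the test name and the paths relative to the bark-ml checkout are assumptions):

    def test_default_param_files(self):
        # One setup per GNN library; both should yield a working actor network.
        params_s, observer_s, actor_s = self._configurable_setup(
            "examples/example_params/tfa_sac_gnn_spektral_default.json")
        params_t, observer_t, actor_t = self._configurable_setup(
            "examples/example_params/tfa_sac_gnn_tf2_gnn_default.json")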
Example #6
def run_configuration(argv):
    params = ParameterServer()

    # NOTE: Modify these paths to specify your preferred path for checkpoints and summaries
    # params["ML"]["BehaviorTFAAgents"]["CheckpointPath"] = "/Users/hart/Development/bark-ml/checkpoints_merge_spektral_att2/"
    # params["ML"]["TFARunner"]["SummaryPath"] = "/Users/hart/Development/bark-ml/checkpoints_merge_spektral_att2/"

    #viewer = MPViewer(
    #  params=params,
    #  x_range=[-35, 35],
    #  y_range=[-35, 35],
    #  follow_agent_id=True)
    #viewer = VideoRenderer(
    #  renderer=viewer,
    #  world_step_time=0.2,
    #  fig_path="/your_path_here/training/video/")

    # create environment
    bp = ContinuousMergingBlueprint(params, num_scenarios=2500, random_seed=0)

    observer = GraphObserver(params=params)

    env = SingleAgentRuntime(blueprint=bp, observer=observer, render=False)
    sac_agent = BehaviorGraphSACAgent(environment=env,
                                      observer=observer,
                                      params=params,
                                      init_gnn='init_interaction_network')
    env.ml_behavior = sac_agent
    runner = SACRunner(params=params, environment=env, agent=sac_agent)

    if FLAGS.mode == "train":
        runner.SetupSummaryWriter()
        runner.Train()
    elif FLAGS.mode == "visualize":
        runner.Run(num_episodes=10, render=True)
    elif FLAGS.mode == "evaluate":
        runner.Run(num_episodes=250, render=False)
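
This example selects the GNN architecture explicitly through the init_gnn keyword; the other examples rely on the agent's default. Side by side (the keyword value is taken from the code above, the default call mirrors Examples #3 and #4):

# default GNN construction, as used in the other examples
sac_agent = BehaviorGraphSACAgent(environment=env, observer=observer, params=params)

# interaction-network-style GNN, as selected in this example
sac_agent = BehaviorGraphSACAgent(environment=env, observer=observer, params=params,
                                  init_gnn='init_interaction_network')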
Example #7
def run_configuration(argv):
  params = ParameterServer()
  # NOTE: Modify these paths to specify your preferred path for checkpoints and summaries
  # params["ML"]["BehaviorTFAAgents"]["CheckpointPath"] = "/Users/hart/Development/bark-ml/checkpoints/"
  # params["ML"]["TFARunner"]["SummaryPath"] = "/Users/hart/Development/bark-ml/checkpoints/"
  params["Visualization"]["Agents"]["Alpha"]["Other"] = 0.2
  params["Visualization"]["Agents"]["Alpha"]["Controlled"] = 0.2
  params["Visualization"]["Agents"]["Alpha"]["Controlled"] = 0.2
  params["ML"]["VisualizeCfWorlds"] = False
  params["ML"]["VisualizeCfHeatmap"] = False
  # params["ML"]["ResultsFolder"] = "/Users/hart/Development/bark-ml/results/data/"

  # viewer = MPViewer(
  #   params=params,
  #   x_range=[-35, 35],
  #   y_range=[-35, 35],
  #   follow_agent_id=True)


  # create environment
  bp = ContinuousMergingBlueprint(params,
                                  num_scenarios=2500,
                                  random_seed=0)

  observer = GraphObserver(params=params)

  behavior_model_pool = []
  for count, a in enumerate([-5., 0., 5.]):
    local_params = params.AddChild("local_"+str(count))
    local_params["BehaviorConstantAcceleration"]["ConstAcceleration"] = a
    behavior = BehaviorConstantAcceleration(local_params)
    behavior_model_pool.append(behavior)

  env = CounterfactualRuntime(
    blueprint=bp,
    observer=observer,
    render=False,
    params=params,
    behavior_model_pool=behavior_model_pool)
  sac_agent = BehaviorGraphSACAgent(environment=env,
                                    observer=observer,
                                    params=params)
  env.ml_behavior = sac_agent
  runner = SACRunner(params=params,
                     environment=env,
                     agent=sac_agent)

  if FLAGS.mode == "train":
    runner.SetupSummaryWriter()
    runner.Train()
  elif FLAGS.mode == "visualize":
    runner._environment._max_col_rate = 0.
    runner.Run(num_episodes=1, render=True)
  elif FLAGS.mode == "evaluate":
    for cr in np.arange(0, 1, 0.1):
      runner._environment._max_col_rate = cr
      runner.Run(num_episodes=250, render=False, max_col_rate=cr)
    runner._environment._tracer.Save(
      params["ML"]["ResultsFolder"] + "evaluation_results_runtime.pckl")
    goal_reached = runner._tracer.success_rate
    runner._tracer.Save(
      params["ML"]["ResultsFolder"] + "evaluation_results_runner.pckl")