Example #1
0
def main(
    script: str,
    scenarios: Sequence[str],
    headless: bool,
    seed: int,
    vehicles_to_replace: int,
    episodes: int,
):
    """Replay recorded traffic histories with some vehicles hijacked by agents.

    Every scenario variation under ``scenarios`` is run for ``episodes``
    episodes. In each episode, up to ``vehicles_to_replace`` recorded
    vehicles are handed to agents built from an ``AgentSpec`` (currently
    ``ReplayCheckerAgent``). The traffic history provider is started at the
    earliest start time among the hijacked vehicles, and every ego mission's
    start time is shifted back by that same amount.

    Args:
        script: Name of the logger to use.
        scenarios: Scenario root paths to iterate over.
        headless: When False, an ``Envision`` client is attached to SMARTS.
        seed: Seed handed to ``random_seed`` before any sampling happens.
        vehicles_to_replace: Vehicles to hijack per episode; clamped to the
            number of discovered vehicle missions. Must be > 0.
        episodes: Episodes to run per scenario variation. Must be > 0.

    Raises:
        AssertionError: If ``vehicles_to_replace`` or ``episodes`` is not
            positive, if SMARTS has no ``TrafficHistoryProvider``, or if more
            than one vehicle is requested for a Waymo scenario whose ego
            vehicle id is known.
    """
    assert vehicles_to_replace > 0
    assert episodes > 0
    logger = logging.getLogger(script)
    logger.setLevel(logging.INFO)

    logger.debug("initializing SMARTS")

    smarts = SMARTS(
        agent_interfaces={},
        traffic_sim=None,
        envision=None if headless else Envision(),
    )
    # Tear the simulator down even when a scenario or episode raises part-way.
    try:
        random_seed(seed)
        traffic_history_provider = smarts.get_provider_by_type(
            TrafficHistoryProvider)
        assert traffic_history_provider

        scenario_list = Scenario.get_scenario_list(scenarios)
        scenarios_iterator = Scenario.variations_for_all_scenario_roots(
            scenario_list, [])
        for scenario in scenarios_iterator:
            logger.debug(f"working on scenario {scenario.name}")

            veh_missions = scenario.discover_missions_of_traffic_histories()
            if not veh_missions:
                logger.warning(
                    f"no vehicle missions found for scenario {scenario.name}.")
                continue
            veh_start_times = {
                v_id: mission.start_time
                for v_id, mission in veh_missions.items()
            }

            k = vehicles_to_replace
            if k > len(veh_missions):
                logger.warning(
                    f"vehicles_to_replace={vehicles_to_replace} is greater than "
                    f"the number of vehicle missions ({len(veh_missions)}).")
                k = len(veh_missions)

            # XXX replace with AgentSpec appropriate for IL model
            agent_spec = AgentSpec(
                interface=AgentInterface.from_type(AgentType.Imitation),
                agent_builder=ReplayCheckerAgent,
                agent_params=smarts.fixed_timestep_sec,
            )

            for episode in range(episodes):
                logger.info(f"starting episode {episode}...")

                # --- Choose which recorded vehicles to hijack -------------
                sample = set()
                if scenario.traffic_history.dataset_source == "Waymo":
                    # For Waymo, we only hijack the vehicle that was
                    # autonomous in the dataset.
                    waymo_ego_id = scenario.traffic_history.ego_vehicle_id
                    if waymo_ego_id is not None:
                        assert (
                            k == 1
                        ), f"do not specify -k > 1 when just hijacking Waymo ego vehicle (it was {k})"
                        sample = {str(waymo_ego_id)}
                    else:
                        logger.warning(
                            "Waymo ego vehicle id not mentioned in the dataset. Hijacking a random vehicle."
                        )

                if not sample:
                    # For other datasets (or a Waymo scenario without a known
                    # ego), hijack k recorded vehicles whose missions overlap.
                    sample = scenario.traffic_history.random_overlapping_sample(
                        veh_start_times, k)

                if len(sample) < k:
                    logger.warning(
                        f"Unable to choose {k} overlapping missions.  allowing non-overlapping."
                    )
                    # random.sample() rejects sets on Python >= 3.11; sorting
                    # also makes the draw independent of set/hash ordering so
                    # it is reproducible under the seed.
                    leftover = sorted(set(veh_start_times) - sample, key=str)
                    sample.update(random.sample(leftover, k - len(sample)))

                # Episode length in sim steps. Uses the simulator's own timestep
                # (previously a hard-coded 0.1, which only matched the default).
                # NOTE(review): exit times look absolute while the sim is offset
                # by history_start_time below, so this likely over-estimates.
                agent_spec.interface.max_episode_steps = max(
                    scenario.traffic_history.vehicle_final_exit_time(veh_id)
                    / smarts.fixed_timestep_sec
                    for veh_id in sample
                )
                logger.info(f"chose vehicles: {sample}")

                # Replay starts when the earliest hijacked vehicle enters. Using
                # min() avoids treating a legitimate 0.0 start time as "unset".
                history_start_time = min(
                    veh_start_times[veh_id] for veh_id in sample)
                agentid_to_vehid = {
                    f"ego-agent-IL-{veh_id}": veh_id for veh_id in sample
                }
                agent_interfaces = {
                    agent_id: agent_spec.interface
                    for agent_id in agentid_to_vehid
                }

                # --- Build one agent per hijacked vehicle ------------------
                agents = {}
                dones = {}
                ego_missions = {}
                for agent_id, veh_id in agentid_to_vehid.items():
                    agent = agent_spec.build_agent()
                    agent.load_data_for_vehicle(veh_id, scenario,
                                                history_start_time)
                    agents[agent_id] = agent
                    dones[agent_id] = False
                    mission = veh_missions[veh_id]
                    # Shift each mission so it is relative to the replay start.
                    ego_missions[agent_id] = replace(
                        mission,
                        start_time=mission.start_time - history_start_time)

                # Tell the traffic history provider to start traffic at the
                # point when the earliest agent enters, and offset all the
                # other agents' missions by this much too.
                traffic_history_provider.start_time = history_start_time
                scenario.set_ego_missions(ego_missions)
                logger.info(f"offsetting sim_time by: {history_start_time}")

                # Take control of vehicles with corresponding agent_ids.
                smarts.switch_ego_agents(agent_interfaces)

                # --- Simulation loop ---------------------------------------
                logger.info("starting simulation loop...")
                observations = smarts.reset(scenario)
                while not all(dones.values()):
                    actions = {
                        agent_id: agents[agent_id].act(agent_obs)
                        for agent_id, agent_obs in observations.items()
                    }
                    logger.debug(
                        "stepping @ sim_time=%s for agents=%s...",
                        smarts.elapsed_sim_time,
                        list(observations.keys()),
                    )
                    observations, _rewards, dones, _infos = smarts.step(actions)

                    for agent_id in agents:
                        if not dones.get(agent_id, False):
                            continue
                        events = observations[agent_id].events
                        if not events.reached_goal:
                            logger.warning(
                                "agent_id=%s exited @ sim_time=%s",
                                agent_id, smarts.elapsed_sim_time)
                            logger.warning("   ... with %s", events)
                        else:
                            logger.info(
                                "agent_id=%s reached goal @ sim_time=%s",
                                agent_id, smarts.elapsed_sim_time)
                            logger.debug("   ... with %s", events)
                        del observations[agent_id]
    finally:
        smarts.destroy()
Example #2
0
def main(
    script: str,
    scenarios: Sequence[str],
    headless: bool,
    envision_record_data_replay_path: str,
    seed: int,
    vehicles_to_replace_randomly: int,
    min_timestep_count: int,
    positional_radius: int,
    episodes: int,
):
    """Replay a traffic history and hand random recorded vehicles to agents mid-episode.

    Only the first variation yielded by ``Scenario.scenario_variations`` is
    used. Candidates are recorded vehicles with at least
    ``min_timestep_count`` trajectory rows and at least one row with a
    non-zero speed. Every step, a ``Trigger`` is updated with the elapsed sim
    time. Its condition holds once simulated time exceeds 2. Its action
    switches a random sample of the candidates to ``BasicAgent`` control. For
    each vehicle it first tries the positional mission and falls back to the
    traverse mission when that raises ``PlanningError``.

    Args:
        script: Name of the logger to use.
        scenarios: Scenario paths passed to ``Scenario.scenario_variations``.
        headless: When False, an ``Envision`` client is created.
        envision_record_data_replay_path: If non-empty, an ``Envision``
            client is created (even when headless) with this ``output_dir``.
        seed: Seed handed to ``random_seed`` before any sampling happens.
        vehicles_to_replace_randomly: Vehicles to hand over per episode,
            clamped to the number of candidates; a value <= 0 hands over
            every candidate.
        min_timestep_count: Minimum trajectory length for a candidate.
        positional_radius: Forwarded to
            ``scenario.create_dynamic_traffic_history_mission``.
        episodes: Number of episodes to run. Must be > 0.

    Raises:
        AssertionError: If ``episodes`` is not positive or no candidate
            vehicle is found.
    """
    assert episodes > 0
    logger = logging.getLogger(script)
    logger.setLevel(logging.INFO)
    logger.debug("initializing SMARTS")

    envision_client = None
    if not headless or envision_record_data_replay_path:
        envision_client = Envision(output_dir=envision_record_data_replay_path)

    smarts = SMARTS(
        agent_interfaces={},
        traffic_sim=None,
        envision=envision_client,
    )
    # Tear the simulator down even when an episode raises part-way.
    try:
        random_seed(seed)

        scenarios_iterator = Scenario.scenario_variations(scenarios, [])
        scenario = next(scenarios_iterator)
        traffic_history = scenario.traffic_history

        # The scenario is fixed for the whole run, so the candidate table is
        # built once instead of re-reading every trajectory on each episode.
        vehicle_candidates = []
        for raw_id in traffic_history.all_vehicle_ids():
            v_id = str(raw_id)
            traj = list(traffic_history.vehicle_trajectory(v_id))
            # Keep moving vehicles with at least the minimum number of timesteps.
            if len(traj) >= min_timestep_count and any(
                row.speed != 0 for row in traj
            ):
                vehicle_candidates.append(v_id)

        assert len(vehicle_candidates) > 0

        replace_count = vehicles_to_replace_randomly
        if replace_count > len(vehicle_candidates):
            logger.warning(
                f"vehicles_to_replace_randomly={replace_count} is greater than the number of vehicle candidates ({len(vehicle_candidates)})."
            )
            replace_count = len(vehicle_candidates)

        def should_trigger(ctx: Dict[str, Any]) -> bool:
            # Hand vehicles over once more than 2 units of sim time have elapsed.
            return ctx["elapsed_sim_time"] > 2

        def on_trigger(ctx: Dict[str, Any]):
            # Define agent specs to be assigned
            agent_spec = AgentSpec(
                interface=AgentInterface(waypoints=True, action=ActionSpaceType.Lane),
                agent_builder=BasicAgent,
            )

            # Select a random sample from candidates
            k = ctx.get("vehicles_to_replace_randomly", 0)
            if k <= 0:
                logger.warning(
                    "default (0) or negative value specified for replacement. Replacing all valid vehicle candidates."
                )
                sample = ctx["vehicle_candidates"]
            else:
                logger.info(
                    f"Choosing {k} vehicles randomly from {len(ctx['vehicle_candidates'])} valid vehicle candidates."
                )
                sample = random.sample(ctx["vehicle_candidates"], k)
            assert len(sample) != 0

            for veh_id in sample:
                # Map selected vehicles to agent ids & specs
                agent_id = f"agent-{veh_id}"
                ctx["agents"][agent_id] = agent_spec.build_agent()

                # Create missions based on current state and traffic history
                positional, traverse = scenario.create_dynamic_traffic_history_mission(
                    veh_id, ctx["elapsed_sim_time"], ctx["positional_radius"]
                )

                # Take control of vehicles immediately
                try:
                    # Try to assign a PositionalGoal at the last recorded timestep
                    smarts.add_agent_and_switch_control(
                        veh_id, agent_id, agent_spec.interface, positional
                    )
                except PlanningError:
                    logger.warning(
                        f"Unable to create PositionalGoal for vehicle {veh_id}, falling back to TraverseGoal"
                    )
                    smarts.add_agent_and_switch_control(
                        veh_id, agent_id, agent_spec.interface, traverse
                    )

        for episode in range(episodes):
            logger.info(f"starting episode {episode}...")

            # Initialize trigger and define initial context
            context = {
                "agents": {},
                "elapsed_sim_time": 0.0,
                "vehicle_candidates": vehicle_candidates,
                "vehicles_to_replace_randomly": replace_count,
                "positional_radius": positional_radius,
            }
            trigger = Trigger(should_trigger, on_trigger)

            dones = {}
            observations = smarts.reset(scenario)
            # `not dones` keeps stepping before the trigger has created agents.
            while not dones or not all(dones.values()):
                # Update context, then step the trigger to further update it.
                context["elapsed_sim_time"] = smarts.elapsed_sim_time
                trigger.update(context)

                # Get agents from current context
                agents = context["agents"]

                # Step simulation
                actions = {
                    agent_id: agents[agent_id].act(agent_obs)
                    for agent_id, agent_obs in observations.items()
                }
                logger.debug(
                    "stepping @ sim_time=%s for agents=%s...",
                    smarts.elapsed_sim_time,
                    list(observations.keys()),
                )
                observations, _rewards, dones, _infos = smarts.step(actions)

                for agent_id in agents:
                    if not dones.get(agent_id, False):
                        continue
                    events = observations[agent_id].events
                    if not events.reached_goal:
                        logger.warning(
                            "agent_id=%s exited @ sim_time=%s",
                            agent_id, smarts.elapsed_sim_time)
                        logger.warning("   ... with %s", events)
                    else:
                        logger.info(
                            "agent_id=%s reached goal @ sim_time=%s",
                            agent_id, smarts.elapsed_sim_time)
                        logger.debug("   ... with %s", events)
                    del observations[agent_id]
    finally:
        smarts.destroy()