예제 #1
0
    def __init__(self,
                 queue_url,
                 aws_region='us-east-1',
                 race_duration=180,
                 number_of_trials=3,
                 number_of_resets=10000,
                 penalty_seconds=2.0,
                 off_track_penalty=2.0,
                 collision_penalty=5.0,
                 is_continuous=False,
                 race_type="TIME_TRIAL"):
        """Set up the state, clients, cameras and metrics for a virtual event.

        Apart from storing the race configuration, construction has several
        side effects: it lists the body-shell mesh files on disk, creates SQS
        and S3 clients, configures and spawns the virtual event cameras,
        uploads a default best-sector-time object to S3 when none is listed
        there, signals that the markov packages are ready, registers a phase
        observer and sets up the mp4 services.

        Args:
            queue_url (str): URL of the SQS queue polled for racer information.
            aws_region (str): Region for the SQS/S3 clients and the dummy
                metrics S3 config.
            race_duration: Stored as ``self._race_duration`` (presumably
                seconds -- TODO confirm against its consumers).
            number_of_trials (int): Stored as ``self._number_of_trials``.
            number_of_resets (int): Stored as ``self._number_of_resets``.
            penalty_seconds (float): Stored as ``self._penalty_seconds``.
            off_track_penalty (float): Stored as ``self._off_track_penalty``.
            collision_penalty (float): Stored as ``self._collision_penalty``.
            is_continuous (bool): Stored and forwarded to ``EvalMetrics``.
            race_type (str): Race type label, "TIME_TRIAL" by default.
        """
        # constructor arguments
        self._model_updater = ModelUpdater.get_instance()
        self._deepracer_path = rospkg.RosPack().get_path(
            DeepRacerPackages.DEEPRACER_SIMULATION_ENVIRONMENT)
        body_shell_path = os.path.join(self._deepracer_path, "meshes", "f1")
        # Valid body shells: names of the regular files under meshes/f1 with
        # their last extension stripped (a name without "." maps to ""), plus
        # the default body shell type added below.
        self._valid_body_shells = \
            set(".".join(f.split(".")[:-1]) for f in os.listdir(body_shell_path) if os.path.isfile(
                os.path.join(body_shell_path, f)))
        self._valid_body_shells.add(const.BodyShellType.DEFAULT.value)
        # Accepted car colors: every CarColorType value that does not contain "f1".
        self._valid_car_colors = set(e.value for e in const.CarColorType
                                     if "f1" not in e.value)
        self._num_sectors = int(rospy.get_param("NUM_SECTORS", "3"))
        self._queue_url = queue_url
        self._region = aws_region
        self._number_of_trials = number_of_trials
        self._number_of_resets = number_of_resets
        self._penalty_seconds = penalty_seconds
        self._off_track_penalty = off_track_penalty
        self._collision_penalty = collision_penalty
        self._is_continuous = is_continuous
        self._race_type = race_type
        self._is_save_simtrace_enabled = False
        self._is_save_mp4_enabled = False
        self._is_event_end = False
        # The builtin ``any`` is stored as the done condition (presumably
        # applied to per-agent done flags -- confirm at the call site).
        self._done_condition = any
        self._race_duration = race_duration
        self._enable_domain_randomization = False

        # sqs client
        # The boto client errors out after polling for 1 hour; a refreshed
        # session is passed in (presumably so credentials can be renewed --
        # confirm in refreshed_session).
        self._sqs_client = SQSClient(queue_url=self._queue_url,
                                     region_name=self._region,
                                     max_num_of_msg=MAX_NUM_OF_SQS_MESSAGE,
                                     wait_time_sec=SQS_WAIT_TIME_SEC,
                                     session=refreshed_session(self._region))
        self._s3_client = S3Client(region_name=self._region)
        # tracking current state information
        self._track_data = TrackData.get_instance()
        self._start_lane = self._track_data.center_line
        # keep track of the racer specific info, e.g. s3 locations, alias, car color etc.
        self._current_racer = None
        # keep track of the current race car we are using. It is always "racecar".
        car_model_state = ModelState()
        car_model_state.model_name = "racecar"
        self._current_car_model_state = car_model_state
        self._last_body_shell_type = None
        self._last_sensors = None
        self._racecar_model = AgentModel()
        # keep track of the current control agent we are using
        self._current_agent = None
        # keep track of the current control graph manager
        self._current_graph_manager = None
        # Keep track of previous model's name
        self._prev_model_name = None
        self._hide_position_idx = 0
        self._hide_positions = get_hide_positions(race_car_num=1)
        self._run_phase_subject = RunPhaseSubject()
        self._simtrace_video_s3_writers = []

        self._local_model_directory = './checkpoint'

        # A virtual event has only a single agent, so agent_name is "agent".
        self._agent_name = "agent"

        # camera manager
        self._camera_manager = CameraManager.get_instance()

        # Set up the virtual event top and follow cameras in CameraManager.
        # Configuring the cameras does not need to wait for a car to spawn,
        # because the follow camera is not tracking any car initially.
        self._main_cameras, self._sub_camera = configure_camera(
            namespaces=[VIRTUAL_EVENT], is_wait_for_model=False)
        self._spawn_cameras()

        # pop out all cameras after configuration to prevent camera from moving
        self._camera_manager.pop(namespace=VIRTUAL_EVENT)

        # Placeholder metrics location: only the region is real here.
        dummy_metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value: "dummy-bucket",
            MetricsS3Keys.METRICS_KEY.value: "dummy-key",
            MetricsS3Keys.REGION.value: self._region
        }

        self._eval_metrics = EvalMetrics(
            agent_name=self._agent_name,
            s3_dict_metrics=dummy_metrics_s3_config,
            is_continuous=self._is_continuous,
            pause_time_before_start=PAUSE_TIME_BEFORE_START)

        # Upload a default best sector time (inf for every sector) if no best
        # sector time object exists in S3 yet.

        # The S3 bucket and prefix of the yaml file are read from environment
        # variables because this is SimApp-only. For a virtual event, no S3
        # bucket or prefix is passed through the yaml file; everything is
        # passed through SQS. For simplicity, the yaml bucket/prefix
        # environment variables are reused.
        virtual_event_best_sector_time = VirtualEventBestSectorTime(
            bucket=os.environ.get("YAML_S3_BUCKET", ''),
            s3_key=get_s3_key(os.environ.get("YAML_S3_PREFIX", ''),
                              SECTOR_TIME_S3_POSTFIX),
            region_name=os.environ.get("APP_REGION", "us-east-1"),
            local_path=SECTOR_TIME_LOCAL_PATH)
        response = virtual_event_best_sector_time.list()
        # This handles cases such as a RoboMaker job crash: the next RoboMaker
        # job can pick up the best sector time left over by the crashed job.
        if "Contents" not in response:
            virtual_event_best_sector_time.persist(
                body=json.dumps({
                    SECTOR_X_FORMAT.format(idx + 1): float("inf")
                    for idx in range(self._num_sectors)
                }),
                s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # ROS service to indicate all the robomaker markov packages are ready for consumption
        signal_robomaker_markov_package_ready()

        PhaseObserver('/agent/training_phase', self._run_phase_subject)

        # setup mp4 services
        self._setup_mp4_services()
예제 #2
0
def main():
    """Run a tournament race between one or more models.

    Command-line arguments default to ROS parameters. The function validates
    that every per-agent S3 location list has the same length, builds one
    rollout agent per model (plus obstacle and bot-car agents), runs the race
    through ``tournament_worker``, writes the race report to a local file and
    finally terminates the tournament race node.

    Raises:
        GenericRolloutException: If ``--number_of_resets`` is non-zero but
            smaller than ``MIN_RESET_COUNT``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--preset',
                        help="(string) Name of a preset to run \
                             (class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='list(string) S3 bucket',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_BUCKET",
                                                ["gsaur-test"]))
    parser.add_argument('--s3_prefix',
                        help='list(string) S3 prefix',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_PREFIX",
                                                ["sagemaker"]))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument(
        '-c',
        '--local_model_directory',
        help='(string) Path to a folder containing a checkpoint \
                             to restore the model from.',
        type=str,
        default='./checkpoint')
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "EVALUATION"))
    # ``type=bool`` would turn any non-empty string (including "False") into
    # True, so the command-line text is parsed with str2bool like the default.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(
                            rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY",
                                                      2.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY",
                                                      5.0)))

    args = parser.parse_args()
    arg_s3_bucket = args.s3_bucket
    arg_s3_prefix = args.s3_prefix
    logger.info("S3 bucket: %s \n S3 prefix: %s", arg_s3_bucket, arg_s3_prefix)

    # tournament_worker: names to be displayed in MP4.
    # This is racer alias in tournament worker case.
    display_names = rospy.get_param('DISPLAY_NAME', "")

    metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET')
    metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY')

    arg_s3_bucket = utils.force_list(arg_s3_bucket)
    arg_s3_prefix = utils.force_list(arg_s3_prefix)
    metrics_s3_buckets = utils.force_list(metrics_s3_buckets)
    metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys)

    validate_list = [
        arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets,
        metrics_s3_object_keys
    ]

    # Simtrace and MP4 uploads are optional; their prefixes are only read
    # (and validated) when the corresponding bucket is configured.
    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    # Every per-agent list must provide exactly one entry per model.
    expected_len = len(validate_list[0])
    if not all(len(item) == expected_len for item in validate_list):
        utils.log_and_exit(
            "Eval worker error: Incorrect arguments passed: {}".format(
                validate_list), utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_500)
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException(
            "number of resets is less than {}".format(MIN_RESET_COUNT))

    # Instantiate Cameras
    if len(arg_s3_bucket) == 1:
        configure_camera(namespaces=['racecar'])
    else:
        configure_camera(namespaces=[
            'racecar_{}'.format(agent_index)
            for agent_index in range(len(arg_s3_bucket))
        ])

    agent_list = list()
    s3_bucket_dict = dict()
    s3_prefix_dict = dict()
    s3_writers = list()

    # tournament_worker: list of required S3 locations
    simtrace_s3_bucket_dict = dict()
    simtrace_s3_prefix_dict = dict()
    metrics_s3_bucket_dict = dict()
    metrics_s3_object_key_dict = dict()
    mp4_s3_bucket_dict = dict()
    mp4_s3_object_prefix_dict = dict()

    # One run-phase subject shared by every agent. It is the one observed by
    # PhaseObserver and handed to the graph manager below; a subject per agent
    # would leave all but the last agent without phase updates.
    run_phase_subject = RunPhaseSubject()
    # Hyperparameters of the last agent are used for the graph manager.
    sm_hyperparams_dict = {}
    # Loop-invariant region used by the S3 writers.
    aws_region = rospy.get_param('AWS_REGION', args.aws_region)
    is_multi_agent = len(arg_s3_bucket) > 1

    for agent_index, s3_bucket_val in enumerate(arg_s3_bucket):
        s3_prefix_val = arg_s3_prefix[agent_index]
        if is_multi_agent:
            agent_name = 'agent_{}'.format(agent_index)
            racecar_name = 'racecar_{}'.format(agent_index)
        else:
            agent_name = 'agent'
            racecar_name = 'racecar'
        s3_bucket_dict[agent_name] = s3_bucket_val
        s3_prefix_dict[agent_name] = s3_prefix_val

        # tournament_worker: remap key with agent_name instead of agent_index for list of S3 locations.
        metrics_s3_bucket_dict[agent_name] = metrics_s3_buckets[agent_index]
        metrics_s3_object_key_dict[agent_name] = metrics_s3_object_keys[
            agent_index]
        # Only index the optional locations when configured: otherwise the
        # bucket is None and its prefix variable was never assigned.
        if simtrace_s3_bucket:
            simtrace_s3_bucket_dict[agent_name] = simtrace_s3_bucket[
                agent_index]
            simtrace_s3_prefix_dict[agent_name] = simtrace_s3_object_prefix[
                agent_index]
        if mp4_s3_bucket:
            mp4_s3_bucket_dict[agent_name] = mp4_s3_bucket[agent_index]
            mp4_s3_object_prefix_dict[agent_name] = mp4_s3_object_prefix[
                agent_index]

        s3_client = SageS3Client(bucket=s3_bucket_val,
                                 s3_prefix=s3_prefix_val,
                                 aws_region=args.aws_region)

        # Load the model metadata
        custom_files_dir = os.path.join(CUSTOM_FILES_PATH, agent_name)
        if not os.path.exists(custom_files_dir):
            os.makedirs(custom_files_dir)
        model_metadata_local_path = os.path.join(custom_files_dir,
                                                 'model_metadata.json')
        utils.load_model_metadata(
            s3_client,
            os.path.normpath("%s/model/model_metadata.json" % s3_prefix_val),
            model_metadata_local_path)
        # Handle backward compatibility
        _, _, version = parse_model_metadata(model_metadata_local_path)
        if float(version) < float(utils.SIMAPP_VERSION) and \
                not utils.has_current_ckpnt_name(s3_bucket_val, s3_prefix_val,
                                                 args.aws_region):
            utils.make_compatible(s3_bucket_val, s3_prefix_val,
                                  args.aws_region,
                                  SyncFiles.TRAINER_READY.value)

        # Select the optimal model
        utils.do_model_selection(s3_bucket=s3_bucket_val,
                                 s3_prefix=s3_prefix_val,
                                 region=args.aws_region)

        # Download hyperparameters from SageMaker
        if not os.path.exists(agent_name):
            os.makedirs(agent_name)
        hyperparams_s3_key = os.path.normpath(s3_prefix_val +
                                              "/ip/hyperparameters.json")
        hyperparams_local_path = os.path.join(agent_name,
                                              "hyperparameters.json")
        hyperparameters_file_success = s3_client.download_file(
            s3_key=hyperparams_s3_key, local_path=hyperparams_local_path)
        sm_hyperparams_dict = {}
        if hyperparameters_file_success:
            logger.info("Received Sagemaker hyperparameters successfully!")
            with open(hyperparams_local_path) as hp_file:
                sm_hyperparams_dict = json.load(hp_file)
        else:
            logger.info("SageMaker hyperparameters not found.")

        agent_config = {
            'model_metadata': model_metadata_local_path,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace('racecar', racecar_name)
                    for link_name in LINK_NAMES
                ],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace('racecar', racecar_name)
                    for velocity_topic in VELOCITY_TOPICS
                ],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace('racecar', racecar_name)
                    for steering_topic in STEERING_TOPICS
                ],
                ConfigParams.CHANGE_START.value:
                utils.str2bool(rospy.get_param('CHANGE_START_POSITION',
                                               False)),
                ConfigParams.ALT_DIR.value:
                utils.str2bool(
                    rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
                ConfigParams.ACTION_SPACE_PATH.value:
                'custom_files/' + agent_name + '/model_metadata.json',
                ConfigParams.REWARD.value:
                reward_function,
                ConfigParams.AGENT_NAME.value:
                racecar_name,
                ConfigParams.VERSION.value:
                version,
                ConfigParams.NUMBER_OF_RESETS.value:
                args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value:
                args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value:
                args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value:
                args.is_continuous,
                ConfigParams.RACE_TYPE.value:
                args.race_type,
                ConfigParams.COLLISION_PENALTY.value:
                args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value:
                args.off_track_penalty
            }
        }

        metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value:
            metrics_s3_buckets[agent_index],
            MetricsS3Keys.METRICS_KEY.value:
            metrics_s3_object_keys[agent_index],
            # Use the region argument (or its default) rather than
            # rospy.get_param('AWS_REGION').
            MetricsS3Keys.REGION.value:
            args.aws_region,
            # Use the model bucket argument (or its default) rather than
            # rospy.get_param('MODEL_S3_BUCKET').
            MetricsS3Keys.STEP_BUCKET.value:
            s3_bucket_val,
            # Use the model prefix argument (or its default) rather than
            # rospy.get_param('MODEL_S3_PREFIX').
            MetricsS3Keys.STEP_KEY.value:
            os.path.join(s3_prefix_val,
                         EVALUATION_SIMTRACE_DATA_S3_OBJECT_KEY)
        }
        s3_writer_job_info = []
        if simtrace_s3_bucket:
            s3_writer_job_info.append(
                IterationData(
                    'simtrace', simtrace_s3_bucket[agent_index],
                    simtrace_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        SIM_TRACE_EVALUATION_LOCAL_FILE.value)))
        if mp4_s3_bucket:
            s3_writer_job_info.extend([
                IterationData(
                    'pip', mp4_s3_bucket[agent_index],
                    mp4_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        CAMERA_PIP_MP4_VALIDATION_LOCAL_PATH.value)),
                IterationData(
                    '45degree', mp4_s3_bucket[agent_index],
                    mp4_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        CAMERA_45DEGREE_MP4_VALIDATION_LOCAL_PATH.value)),
                IterationData(
                    'topview', mp4_s3_bucket[agent_index],
                    mp4_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        CAMERA_TOPVIEW_MP4_VALIDATION_LOCAL_PATH.value))
            ])

        s3_writers.append(S3Writer(job_info=s3_writer_job_info))
        agent_list.append(
            create_rollout_agent(agent_config,
                                 EvalMetrics(agent_name, metrics_s3_config),
                                 run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())
    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)

    graph_manager, _ = get_graph_manager(hp_dict=sm_hyperparams_dict,
                                         agent_list=agent_list,
                                         run_phase_subject=run_phase_subject)

    ds_params_instance = S3BotoDataStoreParameters(
        aws_region=args.aws_region,
        bucket_names=s3_bucket_dict,
        base_checkpoint_dir=args.local_model_directory,
        s3_folders=s3_prefix_dict)

    graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                               graph_manager=graph_manager,
                                               ignore_lock=True)
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    tournament_worker(graph_manager=graph_manager,
                      number_of_trials=args.number_of_trials,
                      task_parameters=task_parameters,
                      s3_writers=s3_writers,
                      is_continuous=args.is_continuous)

    # tournament_worker: write race report to local file.
    write_race_report(graph_manager,
                      model_s3_bucket_map=s3_bucket_dict,
                      model_s3_prefix_map=s3_prefix_dict,
                      metrics_s3_bucket_map=metrics_s3_bucket_dict,
                      metrics_s3_key_map=metrics_s3_object_key_dict,
                      simtrace_s3_bucket_map=simtrace_s3_bucket_dict,
                      simtrace_s3_prefix_map=simtrace_s3_prefix_dict,
                      mp4_s3_bucket_map=mp4_s3_bucket_dict,
                      mp4_s3_prefix_map=mp4_s3_object_prefix_dict,
                      display_names=display_names)

    # tournament_worker: terminate tournament_race_node.
    terminate_tournament_race()
예제 #3
0
class VirtualEventManager(object):
    """
        This is the manager that manages the live virtual event.
    """
    def __init__(self,
                 queue_url,
                 aws_region='us-east-1',
                 race_duration=180,
                 number_of_trials=3,
                 number_of_resets=10000,
                 penalty_seconds=2.0,
                 off_track_penalty=2.0,
                 collision_penalty=5.0,
                 is_continuous=False,
                 race_type="TIME_TRIAL"):
        """Set up the state, clients, cameras and metrics for a virtual event.

        Apart from storing the race configuration, construction has several
        side effects: it lists the body-shell mesh files on disk, creates SQS
        and S3 clients, configures and spawns the virtual event cameras,
        uploads a default best-sector-time object to S3 when none is listed
        there, signals that the markov packages are ready, registers a phase
        observer and sets up the mp4 services.

        Args:
            queue_url (str): URL of the SQS queue polled for racer information.
            aws_region (str): Region for the SQS/S3 clients and the dummy
                metrics S3 config.
            race_duration: Stored as ``self._race_duration`` (presumably
                seconds -- TODO confirm against its consumers).
            number_of_trials (int): Stored as ``self._number_of_trials``.
            number_of_resets (int): Stored as ``self._number_of_resets``.
            penalty_seconds (float): Stored as ``self._penalty_seconds``.
            off_track_penalty (float): Stored as ``self._off_track_penalty``.
            collision_penalty (float): Stored as ``self._collision_penalty``.
            is_continuous (bool): Stored and forwarded to ``EvalMetrics``.
            race_type (str): Race type label, "TIME_TRIAL" by default.
        """
        # constructor arguments
        self._model_updater = ModelUpdater.get_instance()
        self._deepracer_path = rospkg.RosPack().get_path(
            DeepRacerPackages.DEEPRACER_SIMULATION_ENVIRONMENT)
        body_shell_path = os.path.join(self._deepracer_path, "meshes", "f1")
        # Valid body shells: names of the regular files under meshes/f1 with
        # their last extension stripped (a name without "." maps to ""), plus
        # the default body shell type added below.
        self._valid_body_shells = \
            set(".".join(f.split(".")[:-1]) for f in os.listdir(body_shell_path) if os.path.isfile(
                os.path.join(body_shell_path, f)))
        self._valid_body_shells.add(const.BodyShellType.DEFAULT.value)
        # Accepted car colors: every CarColorType value that does not contain "f1".
        self._valid_car_colors = set(e.value for e in const.CarColorType
                                     if "f1" not in e.value)
        self._num_sectors = int(rospy.get_param("NUM_SECTORS", "3"))
        self._queue_url = queue_url
        self._region = aws_region
        self._number_of_trials = number_of_trials
        self._number_of_resets = number_of_resets
        self._penalty_seconds = penalty_seconds
        self._off_track_penalty = off_track_penalty
        self._collision_penalty = collision_penalty
        self._is_continuous = is_continuous
        self._race_type = race_type
        self._is_save_simtrace_enabled = False
        self._is_save_mp4_enabled = False
        self._is_event_end = False
        # The builtin ``any`` is stored as the done condition (presumably
        # applied to per-agent done flags -- confirm at the call site).
        self._done_condition = any
        self._race_duration = race_duration
        self._enable_domain_randomization = False

        # sqs client
        # The boto client errors out after polling for 1 hour; a refreshed
        # session is passed in (presumably so credentials can be renewed --
        # confirm in refreshed_session).
        self._sqs_client = SQSClient(queue_url=self._queue_url,
                                     region_name=self._region,
                                     max_num_of_msg=MAX_NUM_OF_SQS_MESSAGE,
                                     wait_time_sec=SQS_WAIT_TIME_SEC,
                                     session=refreshed_session(self._region))
        self._s3_client = S3Client(region_name=self._region)
        # tracking current state information
        self._track_data = TrackData.get_instance()
        self._start_lane = self._track_data.center_line
        # keep track of the racer specific info, e.g. s3 locations, alias, car color etc.
        self._current_racer = None
        # keep track of the current race car we are using. It is always "racecar".
        car_model_state = ModelState()
        car_model_state.model_name = "racecar"
        self._current_car_model_state = car_model_state
        self._last_body_shell_type = None
        self._last_sensors = None
        self._racecar_model = AgentModel()
        # keep track of the current control agent we are using
        self._current_agent = None
        # keep track of the current control graph manager
        self._current_graph_manager = None
        # Keep track of previous model's name
        self._prev_model_name = None
        self._hide_position_idx = 0
        self._hide_positions = get_hide_positions(race_car_num=1)
        self._run_phase_subject = RunPhaseSubject()
        self._simtrace_video_s3_writers = []

        self._local_model_directory = './checkpoint'

        # A virtual event has only a single agent, so agent_name is "agent".
        self._agent_name = "agent"

        # camera manager
        self._camera_manager = CameraManager.get_instance()

        # Set up the virtual event top and follow cameras in CameraManager.
        # Configuring the cameras does not need to wait for a car to spawn,
        # because the follow camera is not tracking any car initially.
        self._main_cameras, self._sub_camera = configure_camera(
            namespaces=[VIRTUAL_EVENT], is_wait_for_model=False)
        self._spawn_cameras()

        # pop out all cameras after configuration to prevent camera from moving
        self._camera_manager.pop(namespace=VIRTUAL_EVENT)

        # Placeholder metrics location: only the region is real here.
        dummy_metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value: "dummy-bucket",
            MetricsS3Keys.METRICS_KEY.value: "dummy-key",
            MetricsS3Keys.REGION.value: self._region
        }

        self._eval_metrics = EvalMetrics(
            agent_name=self._agent_name,
            s3_dict_metrics=dummy_metrics_s3_config,
            is_continuous=self._is_continuous,
            pause_time_before_start=PAUSE_TIME_BEFORE_START)

        # Upload a default best sector time (inf for every sector) if no best
        # sector time object exists in S3 yet.

        # The S3 bucket and prefix of the yaml file are read from environment
        # variables because this is SimApp-only. For a virtual event, no S3
        # bucket or prefix is passed through the yaml file; everything is
        # passed through SQS. For simplicity, the yaml bucket/prefix
        # environment variables are reused.
        virtual_event_best_sector_time = VirtualEventBestSectorTime(
            bucket=os.environ.get("YAML_S3_BUCKET", ''),
            s3_key=get_s3_key(os.environ.get("YAML_S3_PREFIX", ''),
                              SECTOR_TIME_S3_POSTFIX),
            region_name=os.environ.get("APP_REGION", "us-east-1"),
            local_path=SECTOR_TIME_LOCAL_PATH)
        response = virtual_event_best_sector_time.list()
        # This handles cases such as a RoboMaker job crash: the next RoboMaker
        # job can pick up the best sector time left over by the crashed job.
        if "Contents" not in response:
            virtual_event_best_sector_time.persist(
                body=json.dumps({
                    SECTOR_X_FORMAT.format(idx + 1): float("inf")
                    for idx in range(self._num_sectors)
                }),
                s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # ROS service to indicate all the robomaker markov packages are ready for consumption
        signal_robomaker_markov_package_ready()

        PhaseObserver('/agent/training_phase', self._run_phase_subject)

        # setup mp4 services
        self._setup_mp4_services()

    @property
    def current_racer(self):
        """ Get the current racer object.

        Returns:
            RacerInformation: Information about current racer that was passed-in from the queue.
        """
        return self._current_racer

    @property
    def is_event_end(self):
        """Return True if the service has signaled event end

        Returns:
            boolean: Is it the time to kill everything and die
        """
        return self._is_event_end

    def _spawn_cameras(self):
        '''Spawn the virtual event follow-car camera and the top sub camera.

        Every camera registered under the VIRTUAL_EVENT namespace is popped
        from the CameraManager first. The follow camera model is then spawned
        at the start pose of racer 0 in a single-racer layout, and the top
        camera model is spawned without an explicit pose.
        '''
        CameraManager.get_instance().pop(namespace=VIRTUAL_EVENT)

        LOG.info(
            "[virtual event manager] Spawning virtual event follow car camera model"
        )
        start_pose = self._track_data.get_racecar_start_pose(
            racecar_idx=0,
            racer_num=1,
            start_position=get_start_positions(1)[0])
        models_root = os.path.join(self._deepracer_path, "models")
        follow_camera = self._main_cameras[VIRTUAL_EVENT]
        follow_camera.spawn_model(
            start_pose, os.path.join(models_root, "camera", "model.sdf"))

        LOG.info("[virtual event manager] Spawning sub camera model")
        top_camera_sdf = os.path.join(models_root, "top_camera", "model.sdf")
        self._sub_camera.spawn_model(None, top_camera_sdf)

    def poll_next_racer(self):
        """
            Block until a valid racer message is received from SQS, then store it.

            Each poll asks the SQS client for messages (MAX_NUM_OF_SQS_MESSAGE
            per call, waiting up to SQS_WAIT_TIME_SEC). Only the first message of
            a non-empty response is used. A GenericNonFatalException raised while
            validating or parsing a message is logged and polling continues. On
            success, ``self._current_racer`` holds the parsed message, with every
            JSON object converted into a namedtuple whose fields are its keys.
        """
        def _as_racer_record(fields):
            # Build a namedtuple type from this object's keys and instantiate it.
            return namedtuple(RACER_INFO_OBJECT, fields.keys())(*fields.values())

        while True:
            messages = self._sqs_client.get_messages()
            # An empty response means nothing arrived during this poll.
            if not messages:
                continue
            racer_json = messages[0]
            try:
                # Check the payload against the schema before storing anything.
                validate_json_input(racer_json, RACER_INFO_JSON_SCHEMA)
                self._current_racer = json.loads(
                    racer_json, object_hook=_as_racer_record)
                LOG.info(
                    "[virtual event manager] Received next racer's information %s",
                    self._current_racer)
                return
            except GenericNonFatalException as err:
                err.log_except_and_continue()

    def setup_race(self):
        """
            Setting up the race for the current racer.

        Prepares the simulation for ``self._current_racer`` in order:
        unpauses physics, hides the current racecar model, resets the
        follow camera to the start line, downloads the model metadata,
        respawns the racecar when the body shell type or sensors changed,
        downloads the checkpoint, sets up simtrace/mp4 writers, metrics and
        the graph manager, and finally applies the body shell visuals or the
        car color.

        Returns:
            bool: True if setup race is successful.
                  False if a non fatal exception occurred; in that case the
                  error status has been uploaded via upload_race_status and
                  the per-racer state has been cleaned up.

        Any other exception is reported through log_and_exit.
        """

        LOG.info("[virtual event manager] Setting up race for racer")
        try:
            self._model_updater.unpause_physics()
            LOG.info(
                "[virtual event manager] Unpause physics in current world to setup race."
            )
            # step 1: hide the racecar to a position that camera cannot see
            self._hide_racecar_model(
                model_name=self._current_car_model_state.model_name)

            # step 2: set camera to starting position after the previous car
            # has been moved out of view in step 1
            initial_pose = self._track_data.get_racecar_start_pose(
                racecar_idx=0,
                racer_num=1,
                start_position=get_start_positions(1)[0])
            self._main_cameras[VIRTUAL_EVENT].reset_pose(car_pose=initial_pose)
            LOG.info("[virtual event manager] Reset camera to starting line.")

            # step 3: download model metadata from s3
            # (raises GenericNonFatalException on failure)
            sensors, version, model_metadata = self._download_model_metadata()

            # step 4: check whether body shell and sensors have been updated
            # to decide whether need to delete and re-spawn. Then, update
            # shell or color accordingly.
            # An unknown or missing bodyShellType falls back to the default shell.
            if hasattr(self._current_racer, "carConfig") and \
                    hasattr(self._current_racer.carConfig, "bodyShellType"):
                body_shell_type = self._current_racer.carConfig.bodyShellType \
                    if self._current_racer.carConfig.bodyShellType in self._valid_body_shells \
                    else const.BodyShellType.DEFAULT.value
            else:
                body_shell_type = const.BodyShellType.DEFAULT.value

            # check whether need to delete and respawn racecar
            # re-spawn if sensor or body shell type changed
            if self._last_body_shell_type != body_shell_type or \
                    self._last_sensors != sensors:
                # delete last racecar
                self._racecar_model.delete()
                # respawn a new racecar at the current hide position so the
                # camera does not see it before the race starts
                hide_pose = Pose()
                hide_pose.position.x = self._hide_positions[
                    self._hide_position_idx][0]
                hide_pose.position.y = self._hide_positions[
                    self._hide_position_idx][1]
                # sensor-dependent spawn arguments are passed as lowercase
                # "true"/"false" strings; lidar parameters come from the
                # LIDAR_360_DEGREE_* constants
                self._racecar_model.spawn(
                    name=self._current_car_model_state.model_name,
                    pose=hide_pose,
                    include_second_camera="true"
                    if Input.STEREO.value in sensors else "false",
                    include_lidar_sensor=str(
                        any(["lidar" in sensor.lower()
                             for sensor in sensors])).lower(),
                    body_shell_type=body_shell_type,
                    lidar_360_degree_sample=str(LIDAR_360_DEGREE_SAMPLE),
                    lidar_360_degree_horizontal_resolution=str(
                        LIDAR_360_DEGREE_HORIZONTAL_RESOLUTION),
                    lidar_360_degree_min_angle=str(LIDAR_360_DEGREE_MIN_ANGLE),
                    lidar_360_degree_max_angle=str(LIDAR_360_DEGREE_MAX_ANGLE),
                    lidar_360_degree_min_range=str(LIDAR_360_DEGREE_MIN_RANGE),
                    lidar_360_degree_max_range=str(LIDAR_360_DEGREE_MAX_RANGE),
                    lidar_360_degree_range_resolution=str(
                        LIDAR_360_DEGREE_RANGE_RESOLUTION),
                    lidar_360_degree_noise_mean=str(
                        LIDAR_360_DEGREE_NOISE_MEAN),
                    lidar_360_degree_noise_stddev=str(
                        LIDAR_360_DEGREE_NOISE_STDDEV))
                # remember what was spawned so the next racer can skip the
                # respawn when nothing changed
                self._last_body_shell_type = body_shell_type
                self._last_sensors = sensors

            # step 5: download checkpoint, setup simtrace, mp4, clear metrics, and setup graph manager
            # download checkpoint from s3
            checkpoint = self._download_checkpoint(version)
            # setup the simtrace and mp4 writers if the s3 locations are available
            self._setup_simtrace_mp4_writers()
            # reset the metrics s3 location for the current racer
            self._reset_metrics_loc()
            # setup agents
            agent_list = self._get_agent_list(model_metadata, version)
            # after _setup_graph_manager finishes, physics is paused
            # physics will be unpaused again when race start
            self._setup_graph_manager(checkpoint, agent_list)
            LOG.info(
                "[virtual event manager] Graph manager successfully created the graph: setup race successful."
            )

            # step 6: update body shell or color
            # treat amazon van digital reward specially by also hiding the collision wheel
            # NOTE(review): the branch below actually tests for const.F1 in the
            # body shell type; the "amazon van" wording above may be stale —
            # confirm. For F1 shells, wheels are kept only for "with_wheel"
            # shell types (matched case-insensitively).
            visuals = self._model_updater.get_model_visuals(
                self._current_car_model_state.model_name)
            if const.F1 in body_shell_type:
                self._model_updater.hide_visuals(
                    visuals=visuals,
                    ignore_keywords=["f1_body_link"] if "with_wheel"
                    in body_shell_type.lower() else ["wheel", "f1_body_link"])
            else:
                # non-F1 shells get the racer's car color; unknown or missing
                # colors fall back to DEFAULT_COLOR
                if hasattr(self._current_racer, "carConfig") and \
                        hasattr(self._current_racer.carConfig, "carColor"):
                    car_color = self._current_racer.carConfig.carColor if self._current_racer.carConfig.carColor in self._valid_car_colors \
                        else DEFAULT_COLOR
                else:
                    car_color = DEFAULT_COLOR
                self._model_updater.update_color(visuals, car_color)
            return True
        except GenericNonFatalException as ex:
            # report the failure for this racer and clean up so the work
            # loop can continue with the next racer
            ex.log_except_and_continue()
            self.upload_race_status(status_code=ex.error_code,
                                    error_name=ex.error_name,
                                    error_details=ex.error_msg)
            self._clean_up_race()
            return False
        except Exception as ex:
            log_and_exit(
                "[virtual event manager] Something really wrong happened when setting up the race. {}"
                .format(ex), SIMAPP_VIRTUAL_EVENT_RACE_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500)

    def _download_checkpoint(self, version):
        """Fetch the racer's checkpoint from S3 and select the best model.

        Args:
            version (float): SimApp version read from the model metadata.

        Returns:
            Checkpoint: The checkpoint wrapper whose rl coach checkpoint has
            been updated to point at the best DeepRacer checkpoint.
        """
        input_model = self._current_racer.inputModel
        checkpoint = Checkpoint(bucket=input_model.s3BucketName,
                                s3_prefix=input_model.s3KeyPrefix,
                                region_name=self._region,
                                agent_name=self._agent_name,
                                checkpoint_dir=self._local_model_directory)
        # Models older than SIMAPP_VERSION_2 may need their coach checkpoint
        # converted before it can be used.
        needs_conversion = (
            version < SIMAPP_VERSION_2
            and not checkpoint.rl_coach_checkpoint.is_compatible())
        if needs_conversion:
            checkpoint.rl_coach_checkpoint.make_compatible(
                checkpoint.syncfile_ready)
        best_checkpoint_name = \
            checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # The KMS key is optional in the racer information.
        kms_key_arn = getattr(input_model, 's3KmsKeyArn', None)
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=best_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_extra_args(kms_key_arn))
        return checkpoint

    def _download_model_metadata(self):
        """Download the current racer's model metadata from S3.

        Raises:
            GenericNonFatalException: A non fatal exception the caller catches
                before proceeding with the work loop. botocore ClientErrors are
                reported with SIMAPP_EVENT_ERROR_CODE_400 /
                SIMAPP_EVENT_USER_ERROR; any other failure (including missing
                keys in the metadata info) with SIMAPP_EVENT_ERROR_CODE_500 /
                SIMAPP_EVENT_SYSTEM_ERROR. The original error is chained as
                the exception's ``__cause__``.

        Returns:
            tuple: (sensors, simapp_version, model_metadata) - the sensors and
            SimApp version read from the metadata info, plus the ModelMetadata
            instance itself.
        """
        input_model = self._current_racer.inputModel
        model_metadata_s3_key = get_s3_key(input_model.s3KeyPrefix,
                                           MODEL_METADATA_S3_POSTFIX)

        def _error_msg(kind, err):
            # Implicit literal concatenation keeps the message on one logical
            # line; a backslash continuation inside a string literal would
            # embed the source indentation into the message text.
            return ("[s3] {} Error: Failed to download model_metadata file: "
                    "s3_bucket: {}, s3_key: {}, {}.".format(
                        kind, input_model.s3BucketName,
                        model_metadata_s3_key, err))

        try:
            model_metadata = ModelMetadata(
                bucket=input_model.s3BucketName,
                s3_key=model_metadata_s3_key,
                region_name=self._region,
                local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(
                    self._agent_name))
            model_metadata_info = model_metadata.get_model_metadata_info()
            sensors = model_metadata_info[ModelMetadataKeys.SENSOR.value]
            simapp_version = model_metadata_info[
                ModelMetadataKeys.VERSION.value]
        except botocore.exceptions.ClientError as err:
            raise GenericNonFatalException(
                error_msg=_error_msg("Client", err),
                error_code=SIMAPP_EVENT_ERROR_CODE_400,
                error_name=SIMAPP_EVENT_USER_ERROR) from err
        except Exception as err:
            raise GenericNonFatalException(
                error_msg=_error_msg("System", err),
                error_code=SIMAPP_EVENT_ERROR_CODE_500,
                error_name=SIMAPP_EVENT_SYSTEM_ERROR) from err
        return sensors, simapp_version, model_metadata

    def start_race(self):
        """Run the evaluation (race) for the current racer.

        Requests mp4 recording when it is enabled, configures the environment
        randomizer and unpauses physics. The first evaluation always goes
        through ``_evaluate``. In non-continuous mode it is followed by
        ``number_of_trials - 1`` further evaluations run directly on the
        graph manager.
        """
        racer_alias = self._current_racer.racerAlias
        LOG.info("[virtual event manager] Starting race for racer %s",
                 racer_alias)
        if self._is_save_mp4_enabled:
            # racecar_color is not used for virtual event image editing, so
            # the default color is sent.
            edit_request = VirtualEventVideoEditSrvRequest(
                display_name=racer_alias,
                racecar_color=DEFAULT_COLOR)
            self._subscribe_to_save_mp4(edit_request)

        configure_environment_randomizer()
        self._model_updater.unpause_physics()
        LOG.info("[virtual event manager] Unpaused physics in current world.")

        # Both modes start with the same evaluation.
        self._evaluate()
        if not self._is_continuous:
            remaining_trials = self._number_of_trials - 1
            for _ in range(remaining_trials):
                self._current_graph_manager.evaluate(EnvironmentSteps(1))

    def _evaluate(self):
        """Run one evaluation for the current racer. The step order matters; do not change it.

        First, manually reset the graph manager's internal state: reset the
        metrics, wait for the sensors to be ready and move the agent to the
        starting position. Second, add the virtual event cameras to the camera
        manager so they follow this racer. Third, call evaluate without
        resetting again.

        Why this order is critical:

        The manual reset in the first step
            1. resets agent metrics (for example, adds the pause seconds to the
               total race seconds) and clears parameters left over from the
               previous racer;
            2. waits for the sensor topics to become ready; the LIDAR topic in
               particular can take longer to come up;
            3. resets the agent to the start line.
        The cameras are attached to the racecar only after this reset. The
        racecar is initially spawned at a hide location, so adding the cameras
        before reset_internal_state would record video at the hide position.

        Finally, evaluate is called with reset_before_eval=False because the
        reset has already been done manually. Resetting a second time would
        double count the preparation time before the race.
        """
        self._current_graph_manager.reset_internal_state(
            force_environment_reset=True)
        # Update CameraManager by adding cameras into the current namespace. By doing so
        # a single follow car camera will follow the current active racecar.
        self._camera_manager.add(self._main_cameras[VIRTUAL_EVENT],
                                 self._current_car_model_state.model_name)
        self._camera_manager.add(self._sub_camera,
                                 self._current_car_model_state.model_name)

        # reset_before_eval=False: the manual reset above must not be repeated
        self._current_graph_manager.evaluate(EnvironmentSteps(1),
                                             reset_before_eval=False)

    def finish_race(self):
        """
            Finish the race for the current racer.

        In order: pauses physics, stops mp4 recording when it is enabled,
        removes the racer's model object from the track data, pops the racer's
        camera namespace, uploads simtrace/mp4 files and a 200 race status,
        stores the model name in ``_prev_model_name`` and cleans up the
        per-racer state.
        """
        # pause physics of the world
        self._model_updater.pause_physics()
        # NOTE(review): fixed 1 second wait after pausing; presumably lets the
        # simulation settle before teardown — confirm.
        time.sleep(1)
        # unsubscribe mp4
        if self._is_save_mp4_enabled:
            self._unsubscribe_from_save_mp4(EmptyRequest())
        self._track_data.remove_object(
            name=self._current_car_model_state.model_name)
        LOG.info("[virtual event manager] Finish Race - remove object %s.",
                 self._current_car_model_state.model_name)
        # pop out current racecar from camera namespace to prevent camera from moving
        self._camera_manager.pop(
            namespace=self._current_car_model_state.model_name)
        # upload simtrace and mp4 into s3 bucket
        self._save_simtrace_mp4()
        self.upload_race_status(status_code=200)
        # keep track of the previous model name
        self._prev_model_name = self._current_car_model_state.model_name
        # clean up local trace of current race (must run last: it clears
        # self._current_racer, which the uploads above read)
        self._clean_up_race()

    def _clean_up_race(self):
        """Drop every piece of per-racer state so the next racer starts fresh.

        Clears the S3 writers, deletes the files in the local model directory,
        resets the default TensorFlow graph, clears the current racer, agent
        and graph manager, and rewinds the hide-position index to 0.
        """
        self._simtrace_video_s3_writers = []
        # remove files (e.g. checkpoints) downloaded for the finished racer
        self._clean_local_directory()
        # a fresh default graph avoids errors with the global session
        tf.reset_default_graph()
        self._current_racer = self._current_agent = None
        self._current_graph_manager = None
        self._hide_position_idx = 0

    def _save_simtrace_mp4(self):
        """Get the appropriate kms key and save the simtrace and mp4 files.
        """
        # TODO: It might be theorically possible to have different kms keys for simtrace and mp4
        # but we are using the same key now since that's what happens in real life
        # consider refactor the simtrace_video_s3_writers later.
        if hasattr(self._current_racer.outputMp4, 's3KmsKeyArn'):
            simtrace_mp4_kms = self._current_racer.outputMp4.s3KmsKeyArn
        elif hasattr(self._current_racer.outputSimTrace, 's3KmsKeyArn'):
            simtrace_mp4_kms = self._current_racer.outputSimTrace.s3KmsKeyArn
        else:
            simtrace_mp4_kms = None
        for s3_writer in self._simtrace_video_s3_writers:
            s3_writer.persist(utils.get_s3_extra_args(simtrace_mp4_kms))

    def _reset_metrics_loc(self):
        """Re-point the evaluation metrics at the newly loaded racer's S3 output.

        The metrics bucket and key prefix come from the racer's
        ``outputMetrics``. The region is the manager's region. The current
        simtrace flag is forwarded as well.
        """
        metrics_output = self._current_racer.outputMetrics
        metrics_s3_config = {}
        metrics_s3_config[MetricsS3Keys.METRICS_BUCKET.value] = \
            metrics_output.s3BucketName
        metrics_s3_config[MetricsS3Keys.METRICS_KEY.value] = \
            metrics_output.s3KeyPrefix
        metrics_s3_config[MetricsS3Keys.REGION.value] = self._region
        self._eval_metrics.reset_metrics(
            s3_dict_metrics=metrics_s3_config,
            is_save_simtrace_enabled=self._is_save_simtrace_enabled)

    def _setup_simtrace_mp4_writers(self):
        """Register S3 writers for whichever simtrace/mp4 outputs the racer requested.

        Both save flags are first reset to False. Each flag becomes True only
        when the matching output (``outputSimTrace`` / ``outputMp4``) exists on
        the racer. Writers are appended to ``_simtrace_video_s3_writers`` in
        this order: the simtrace writer, then the PIP, 45-degree and top-view
        mp4 writers.
        """
        self._is_save_simtrace_enabled = False
        self._is_save_mp4_enabled = False
        racer = self._current_racer

        if hasattr(racer, 'outputSimTrace'):
            simtrace_output = racer.outputSimTrace
            simtrace_writer = SimtraceVideo(
                upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                bucket=simtrace_output.s3BucketName,
                s3_prefix=simtrace_output.s3KeyPrefix,
                region_name=self._region,
                local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(
                    self._agent_name))
            self._simtrace_video_s3_writers.append(simtrace_writer)
            self._is_save_simtrace_enabled = True

        if hasattr(racer, 'outputMp4'):
            mp4_output = racer.outputMp4
            # (upload type, local path format) for each recorded video
            mp4_targets = (
                (SimtraceVideoNames.PIP.value,
                 CAMERA_PIP_MP4_LOCAL_PATH_FORMAT),
                (SimtraceVideoNames.DEGREE45.value,
                 CAMERA_45DEGREE_LOCAL_PATH_FORMAT),
                (SimtraceVideoNames.TOPVIEW.value,
                 CAMERA_TOPVIEW_LOCAL_PATH_FORMAT),
            )
            self._simtrace_video_s3_writers.extend([
                SimtraceVideo(upload_type=upload_type,
                              bucket=mp4_output.s3BucketName,
                              s3_prefix=mp4_output.s3KeyPrefix,
                              region_name=self._region,
                              local_path=path_format.format(self._agent_name))
                for upload_type, path_format in mp4_targets
            ])
            self._is_save_mp4_enabled = True

    def _setup_graph_manager(self, checkpoint, agent_list):
        """Build the graph manager for the current racer and create its graph.

        Args:
            checkpoint (Checkpoint): The model checkpoint that was just downloaded.
            agent_list (list[Agent]): Agents the graph manager is created for.

        The graph manager is stored in ``self._current_graph_manager`` and gets
        an S3BotoDataStore built from ``checkpoint``. Its env seed is set to 0.
        The data store waits for checkpoints and adjusts the checkpoint
        variables. The graph is then created, restoring from the local model
        directory.
        """
        graph_manager, _ = get_graph_manager(
            hp_dict={},
            agent_list=agent_list,
            run_phase_subject=self._run_phase_subject,
            enable_domain_randomization=self._enable_domain_randomization,
            done_condition=self._done_condition,
            pause_physics=self._model_updater.pause_physics_service,
            unpause_physics=self._model_updater.unpause_physics_service)
        self._current_graph_manager = graph_manager

        data_store_params = S3BotoDataStoreParameters(
            checkpoint_dict={self._agent_name: checkpoint})
        graph_manager.data_store = S3BotoDataStore(
            params=data_store_params,
            graph_manager=graph_manager,
            ignore_lock=True,
            log_and_cont=True)
        graph_manager.env_params.seed = 0

        graph_manager.data_store.wait_for_checkpoints()
        graph_manager.data_store.modify_checkpoint_variables()

        restore_params = TaskParameters()
        restore_params.checkpoint_restore_path = self._local_model_directory
        graph_manager.create_graph(
            task_parameters=restore_params,
            stop_physics=self._model_updater.pause_physics_service,
            start_physics=self._model_updater.unpause_physics_service,
            empty_service_call=EmptyRequest)

    def _get_agent_list(self, model_metadata, version):
        """Create the agents used for the current racer's evaluation.

        Args:
            model_metadata (ModelMetadata): Current racer's model metadata.
            version: SimApp version read from the model metadata.

        Returns:
            list: In order, the rollout agent for the racer, the obstacles
            agent and the bot cars agent.
        """
        model_name = self._current_car_model_state.model_name

        def _retarget(names):
            # Link/topic templates use 'racecar'; substitute this racer's model.
            return [name.replace('racecar', model_name) for name in names]

        car_ctrl_config = {
            ConfigParams.LINK_NAME_LIST.value: _retarget(LINK_NAMES),
            ConfigParams.VELOCITY_LIST.value: _retarget(VELOCITY_TOPICS),
            ConfigParams.STEERING_LIST.value: _retarget(STEERING_TOPICS),
            ConfigParams.CHANGE_START.value: utils.str2bool(
                rospy.get_param('CHANGE_START_POSITION', False)),
            ConfigParams.ALT_DIR.value: utils.str2bool(
                rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
            ConfigParams.MODEL_METADATA.value: model_metadata,
            ConfigParams.REWARD.value: reward_function,
            ConfigParams.AGENT_NAME.value: model_name,
            ConfigParams.VERSION.value: version,
            ConfigParams.NUMBER_OF_RESETS.value: self._number_of_resets,
            ConfigParams.PENALTY_SECONDS.value: self._penalty_seconds,
            ConfigParams.NUMBER_OF_TRIALS.value: self._number_of_trials,
            ConfigParams.IS_CONTINUOUS.value: self._is_continuous,
            ConfigParams.RACE_TYPE.value: self._race_type,
            ConfigParams.COLLISION_PENALTY.value: self._collision_penalty,
            ConfigParams.OFF_TRACK_PENALTY.value: self._off_track_penalty,
            # hard-coded to the first start position
            ConfigParams.START_POSITION.value: get_start_positions(1)[0],
            ConfigParams.DONE_CONDITION.value: self._done_condition,
            ConfigParams.IS_VIRTUAL_EVENT.value: True,
            ConfigParams.RACE_DURATION.value: self._race_duration,
        }
        agent_config = {
            'model_metadata': model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: car_ctrl_config,
        }

        return [
            create_rollout_agent(agent_config, self._eval_metrics,
                                 self._run_phase_subject),
            create_obstacles_agent(),
            create_bot_cars_agent(
                pause_time_before_start=PAUSE_TIME_BEFORE_START),
        ]

    def _setup_mp4_services(self):
        """Wait for the virtual event mp4 ROS services and build proxies for them.

        Sets ``_subscribe_to_save_mp4`` (VirtualEventVideoEditSrv) and
        ``_unsubscribe_from_save_mp4`` (Empty) once both services are available.
        """
        subscribe_name = "/{}/save_mp4/subscribe_to_save_mp4".format(VIRTUAL_EVENT)
        unsubscribe_name = "/{}/save_mp4/unsubscribe_from_save_mp4".format(
            VIRTUAL_EVENT)
        for service_name in (subscribe_name, unsubscribe_name):
            rospy.wait_for_service(service_name)
        self._subscribe_to_save_mp4 = ServiceProxyWrapper(
            subscribe_name, VirtualEventVideoEditSrv)
        self._unsubscribe_from_save_mp4 = ServiceProxyWrapper(
            unsubscribe_name, Empty)

    def _clean_local_directory(self):
        """Delete every file below the local model directory after a race.

        Only files are removed; the directory and its subdirectories remain.
        """
        LOG.info(
            "[virtual event manager] cleaning up the local directory after race ends."
        )
        stale_files = (os.path.join(dir_path, file_name)
                       for dir_path, _, file_names in os.walk(
                           self._local_model_directory)
                       for file_name in file_names)
        for stale_file in stale_files:
            os.remove(stale_file)

    def _hide_racecar_model(self, model_name):
        """Move a racecar model to the current hide position (the pit parking spot).

        Args:
            model_name (str): Name of the model to move.
        """
        # Yaw depends on the track direction: 0 for counter-clockwise tracks,
        # pi otherwise.
        if self._track_data.is_ccw:
            yaw = 0.0
        else:
            yaw = math.pi
        hide_position = self._hide_positions[self._hide_position_idx]
        self._model_updater.set_model_position(model_name,
                                               hide_position,
                                               yaw,
                                               is_blocking=True)
        LOG.info("[virtual event manager] hide {}".format(model_name))

    def upload_race_status(self,
                           status_code,
                           error_name=None,
                           error_details=None):
        """Upload the race status JSON file for the current racer to S3.

        The status is written as JSON to S3_RACE_STATUS_FILE_NAME under the
        racer's ``outputStatus`` key prefix. The racer's KMS key is used when
        one is provided.

        Args:
            status_code: Status code for the race. finish_race passes 200;
                setup_race passes the non fatal exception's error_code.
            error_name (str, optional): Name of the error for 4xx/5xx statuses.
                Defaults to None.
            error_details (str, optional): Detailed error message for 4xx/5xx
                statuses. Defaults to None.

        The error fields are included only when both error_name and
        error_details are given.
        """
        # persist s3 status file
        status = {RaceStatusKeys.STATUS_CODE.value: status_code}
        if error_name is not None and error_details is not None:
            status[RaceStatusKeys.ERROR_NAME.value] = error_name
            status[RaceStatusKeys.ERROR_DETAILS.value] = error_details
        status_json = json.dumps(status)
        output_status = self._current_racer.outputStatus
        s3_key = os.path.normpath(
            os.path.join(output_status.s3KeyPrefix, S3_RACE_STATUS_FILE_NAME))
        # the KMS key is optional in the racer information
        race_status_kms = getattr(output_status, 's3KmsKeyArn', None)
        self._s3_client.upload_fileobj(
            bucket=output_status.s3BucketName,
            s3_key=s3_key,
            fileobj=io.BytesIO(status_json.encode()),
            s3_kms_extra_args=utils.get_s3_extra_args(race_status_kms))
        # Implicit concatenation (not a backslash inside the literal) keeps
        # source indentation out of the log text; %s args are formatted lazily.
        LOG.info(
            "[virtual event manager] Successfully uploaded race status file to "
            "s3 bucket %s with s3 key %s.", output_status.s3BucketName, s3_key)
예제 #4
0
def main():
    """ Main function for evaluation worker

    Parses the command line arguments (defaults are read from ROS params),
    creates one rollout agent per model S3 bucket/prefix pair plus the
    obstacle and bot-car agents, builds the graph manager and runs
    ``evaluation_worker``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--preset',
                        help="(string) Name of a preset to run "
                             "(class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='list(string) S3 bucket',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_BUCKET",
                                                ["gsaur-test"]))
    parser.add_argument('--s3_prefix',
                        help='list(string) S3 prefix',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_PREFIX",
                                                ["sagemaker"]))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument(
        '-c',
        '--local_model_directory',
        help='(string) Path to a folder containing a checkpoint '
             'to restore the model from.',
        type=str,
        default='./checkpoint')
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "EVALUATION"))
    # argparse's type=bool maps every non-empty string (even "False") to True,
    # so the command line value is parsed with the same helper as the default.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continuous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(
                            rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY",
                                                      2.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY",
                                                      5.0)))

    args = parser.parse_args()
    logger.info("S3 bucket: %s \n S3 prefix: %s", args.s3_bucket, args.s3_prefix)

    metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET')
    metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY')

    # each setting may be a single value or a list; normalize to lists
    arg_s3_bucket = utils.force_list(args.s3_bucket)
    arg_s3_prefix = utils.force_list(args.s3_prefix)
    metrics_s3_buckets = utils.force_list(metrics_s3_buckets)
    metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys)

    # every list in validate_list must hold exactly one entry per agent
    validate_list = [
        arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets,
        metrics_s3_object_keys
    ]

    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    num_agents = len(arg_s3_bucket)
    if not all(len(item) == num_agents for item in validate_list):
        log_and_exit(
            "Eval worker error: Incorrect arguments passed: {}".format(
                validate_list), SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500)
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException(
            "number of resets is less than {}".format(MIN_RESET_COUNT))

    is_single_agent = num_agents == 1

    # Instantiate Cameras
    if is_single_agent:
        configure_camera(namespaces=['racecar'])
    else:
        configure_camera(namespaces=[
            'racecar_{}'.format(agent_index)
            for agent_index in range(num_agents)
        ])

    agent_list = list()
    checkpoint_dict = dict()
    simtrace_video_s3_writers = []
    start_positions = get_start_positions(num_agents)
    done_condition = utils.str_to_done_condition(
        rospy.get_param("DONE_CONDITION", any))
    park_positions = utils.pos_2d_str_to_list(
        rospy.get_param("PARK_POSITIONS", []))
    # if not pass in park positions for all done condition case, use default
    if not park_positions:
        park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket]
    # The simtrace/mp4 writers use the AWS_REGION ROS param when it is set and
    # fall back to --aws_region; the value is the same for every agent.
    aws_region = rospy.get_param('AWS_REGION', args.aws_region)
    # One run phase subject shared by every rollout agent, the phase observer
    # and the graph manager, so that all agents receive the phase changes
    # (a per-agent subject would leave all but the last one unobserved).
    run_phase_subject = RunPhaseSubject()
    for agent_index in range(num_agents):
        agent_name = 'agent' if is_single_agent else 'agent_{}'.format(agent_index)
        racecar_name = 'racecar' if is_single_agent else 'racecar_{}'.format(agent_index)
        s3_bucket = arg_s3_bucket[agent_index]
        s3_prefix = arg_s3_prefix[agent_index]

        # download model metadata
        model_metadata = ModelMetadata(
            bucket=s3_bucket,
            s3_key=get_s3_key(s3_prefix, MODEL_METADATA_S3_POSTFIX),
            region_name=args.aws_region,
            local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(agent_name))
        model_metadata_info = model_metadata.get_model_metadata_info()
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]

        # checkpoint s3 instance
        checkpoint = Checkpoint(bucket=s3_bucket,
                                s3_prefix=s3_prefix,
                                region_name=args.aws_region,
                                agent_name=agent_name,
                                checkpoint_dir=args.local_model_directory)
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and \
                not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(
                checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = \
            checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

        checkpoint_dict[agent_name] = checkpoint

        agent_config = {
            'model_metadata': model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace('racecar', racecar_name)
                    for link_name in LINK_NAMES
                ],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace('racecar', racecar_name)
                    for velocity_topic in VELOCITY_TOPICS
                ],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace('racecar', racecar_name)
                    for steering_topic in STEERING_TOPICS
                ],
                ConfigParams.CHANGE_START.value:
                utils.str2bool(rospy.get_param('CHANGE_START_POSITION',
                                               False)),
                ConfigParams.ALT_DIR.value:
                utils.str2bool(
                    rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
                ConfigParams.MODEL_METADATA.value:
                model_metadata,
                ConfigParams.REWARD.value:
                reward_function,
                ConfigParams.AGENT_NAME.value:
                racecar_name,
                ConfigParams.VERSION.value:
                version,
                ConfigParams.NUMBER_OF_RESETS.value:
                args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value:
                args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value:
                args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value:
                args.is_continuous,
                ConfigParams.RACE_TYPE.value:
                args.race_type,
                ConfigParams.COLLISION_PENALTY.value:
                args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value:
                args.off_track_penalty,
                ConfigParams.START_POSITION.value:
                start_positions[agent_index],
                ConfigParams.DONE_CONDITION.value:
                done_condition
            }
        }

        metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value:
            metrics_s3_buckets[agent_index],
            MetricsS3Keys.METRICS_KEY.value:
            metrics_s3_object_keys[agent_index],
            # Replaced rospy.get_param('AWS_REGION') to be equal to the argument being passed
            # or default argument set
            MetricsS3Keys.REGION.value:
            args.aws_region
        }

        if simtrace_s3_bucket:
            simtrace_video_s3_writers.append(
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                    bucket=simtrace_s3_bucket[agent_index],
                    s3_prefix=simtrace_s3_object_prefix[agent_index],
                    region_name=aws_region,
                    local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(
                        agent_name)))
        if mp4_s3_bucket:
            # one writer per camera video: pip, 45 degree and top view
            simtrace_video_s3_writers.extend([
                SimtraceVideo(
                    upload_type=upload_type,
                    bucket=mp4_s3_bucket[agent_index],
                    s3_prefix=mp4_s3_object_prefix[agent_index],
                    region_name=aws_region,
                    local_path=local_path_format.format(agent_name))
                for upload_type, local_path_format in (
                    (SimtraceVideoNames.PIP.value,
                     CAMERA_PIP_MP4_LOCAL_PATH_FORMAT),
                    (SimtraceVideoNames.DEGREE45.value,
                     CAMERA_45DEGREE_LOCAL_PATH_FORMAT),
                    (SimtraceVideoNames.TOPVIEW.value,
                     CAMERA_TOPVIEW_LOCAL_PATH_FORMAT))
            ])

        agent_list.append(
            create_rollout_agent(
                agent_config,
                EvalMetrics(agent_name, metrics_s3_config, args.is_continuous),
                run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)
    enable_domain_randomization = utils.str2bool(
        rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))

    sm_hyperparams_dict = {}

    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)

    graph_manager, _ = get_graph_manager(
        hp_dict=sm_hyperparams_dict,
        agent_list=agent_list,
        run_phase_subject=run_phase_subject,
        enable_domain_randomization=enable_domain_randomization,
        done_condition=done_condition,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics)

    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                               graph_manager=graph_manager,
                                               ignore_lock=True)
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(graph_manager=graph_manager,
                      number_of_trials=args.number_of_trials,
                      task_parameters=task_parameters,
                      simtrace_video_s3_writers=simtrace_video_s3_writers,
                      is_continuous=args.is_continuous,
                      park_positions=park_positions,
                      race_type=args.race_type,
                      pause_physics=pause_physics,
                      unpause_physics=unpause_physics)
예제 #5
0
class VirtualEventManager(object):
    """
        This is the manager that manages the live virtual event.
    """
    def __init__(self, queue_url, aws_region='us-east-1',
                 race_duration=180, number_of_trials=3, number_of_resets=10000,
                 penalty_seconds=2.0, off_track_penalty=2.0, collision_penalty=5.0,
                 is_continuous=False, race_type="TIME_TRIAL",
                 body_shell_type="deepracer"):
        """Create the manager and prepare the simulation for a virtual event.

        Besides storing the race configuration, this constructor creates the
        SQS and S3 clients, configures the virtual event cameras, creates the
        EvalMetrics for the single "agent", seeds a default best sector time
        file in S3 when none exists, signals that the robomaker markov
        packages are ready, registers the phase observer and sets up the mp4
        services. The order of these steps is kept as-is.

        Args:
            queue_url (str): URL of the SQS queue racer information is polled from.
            aws_region (str): Region for the SQS/S3 clients and the metrics config.
            race_duration (int): Stored as ``self._race_duration`` (units are
                not shown here; presumably seconds — TODO confirm).
            number_of_trials (int): Number of evaluations run per racer for a
                non-continuous race.
            number_of_resets (int): Stored as ``self._number_of_resets``.
            penalty_seconds (float): Stored as ``self._penalty_seconds``.
            off_track_penalty (float): Stored as ``self._off_track_penalty``.
            collision_penalty (float): Stored as ``self._collision_penalty``.
            is_continuous (bool): Whether the race is continuous; also passed
                to EvalMetrics.
            race_type (str): Stored as ``self._race_type``.
            body_shell_type (str): Body shell of the car; the car color is only
                updated when the lower-cased value does not contain ``const.F1``.
        """
        # constructor arguments
        self._body_shell_type = body_shell_type
        self._num_sectors = int(rospy.get_param("NUM_SECTORS", "3"))
        self._queue_url = queue_url
        self._region = aws_region
        self._number_of_trials = number_of_trials
        self._number_of_resets = number_of_resets
        self._penalty_seconds = penalty_seconds
        self._off_track_penalty = off_track_penalty
        self._collision_penalty = collision_penalty
        self._is_continuous = is_continuous
        self._race_type = race_type
        self._is_save_simtrace_enabled = False
        self._is_save_mp4_enabled = False
        self._is_event_end = False
        # the builtin ``any`` is stored as the done condition (presumably the
        # race is done when any agent is done — TODO confirm)
        self._done_condition = any
        self._race_duration = race_duration
        self._enable_domain_randomization = False

        # sqs client
        # The boto client errors out after polling for 1 hour.
        # NOTE(review): refreshed_session presumably works around that by
        # refreshing credentials — confirm
        self._sqs_client = SQSClient(queue_url=self._queue_url,
                                     region_name=self._region,
                                     max_num_of_msg=MAX_NUM_OF_SQS_MESSAGE,
                                     wait_time_sec=SQS_WAIT_TIME_SEC,
                                     session=refreshed_session(self._region))
        self._s3_client = S3Client(region_name=self._region)
        self._model_updater = ModelUpdater.get_instance()

        # tracking current state information
        self._track_data = TrackData.get_instance()
        self._start_lane = self._track_data.center_line
        # keep track of the racer specific info, e.g. s3 locations, alias, car color etc.
        self._current_racer = None
        # keep track of the current race car we are using, e.g. racecar_0
        self._current_car_model_state = None
        # keep track of the current control agent we are using
        self._current_agent = None
        # keep track of the current control graph manager
        self._current_graph_manager = None
        # Keep track of previous model's name
        self._prev_model_name = None
        # index into self._park_positions; updated per race in start_race
        self._park_position_idx = 0
        # one hide/park position per entry of SENSOR_MODEL_MAP
        self._park_positions = get_hide_positions(len(SENSOR_MODEL_MAP))
        self._run_phase_subject = RunPhaseSubject()
        self._simtrace_video_s3_writers = []

        self._local_model_directory = './checkpoint'

        # virtual event only have single agent, so set agent_name to "agent"
        self._agent_name = "agent"

        # camera manager
        self._camera_manager = CameraManager.get_instance()

        # setting up virtual event top and follow camera in CameraManager
        # virtual event configure camera does not need to wait for car to spawn because
        # follow car camera is not tracking any car initially
        self._main_cameras, self._sub_camera = configure_camera(namespaces=[VIRTUAL_EVENT],
                                                                is_wait_for_model=False)

        # pop out all cameras after configuration to prevent camera from moving
        self._camera_manager.pop(namespace=VIRTUAL_EVENT)

        # placeholder metrics location; setup_race resets the metrics s3
        # location for each racer (via _reset_metrics_loc)
        dummy_metrics_s3_config = {MetricsS3Keys.METRICS_BUCKET.value: "dummy-bucket",
                                   MetricsS3Keys.METRICS_KEY.value: "dummy-key",
                                   MetricsS3Keys.REGION.value: self._region}

        self._eval_metrics = EvalMetrics(agent_name=self._agent_name,
                                         s3_dict_metrics=dummy_metrics_s3_config,
                                         is_continuous=self._is_continuous,
                                         pause_time_before_start=PAUSE_TIME_BEFORE_START)

        # upload a default best sector time for all sectors with time inf for each sector
        # if there is not best sector time existed in s3

        # use the s3 bucket and prefix for yaml file stored as environment variable because
        # here is SimApp use only. For virtual event there is no s3 bucket and prefix passed
        # through yaml file. All are passed through sqs. For simplicity, reuse the yaml s3 bucket
        # and prefix environment variable.
        virtual_event_best_sector_time = VirtualEventBestSectorTime(
            bucket=os.environ.get("YAML_S3_BUCKET", ''),
            s3_key=get_s3_key(os.environ.get("YAML_S3_PREFIX", ''), SECTOR_TIME_S3_POSTFIX),
            region_name=os.environ.get("APP_REGION", "us-east-1"),
            local_path=SECTOR_TIME_LOCAL_PATH)
        response = virtual_event_best_sector_time.list()
        # this is used to handle situation such as robomaker job crash, so the next robomaker job
        # can catch the best sector time left over from crashed job
        # (a listing without "Contents" means no sector time file exists yet)
        if "Contents" not in response:
            virtual_event_best_sector_time.persist(body=json.dumps(
                {SECTOR_X_FORMAT.format(idx + 1): float("inf")
                 for idx in range(self._num_sectors)}),
                s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # ROS service to indicate all the robomaker markov packages are ready for consumption
        signal_robomaker_markov_package_ready()

        PhaseObserver('/agent/training_phase', self._run_phase_subject)

        # setup mp4 services
        self._setup_mp4_services()

    @property
    def current_racer(self):
        """ Get the current racer object.

        Returns:
            RacerInformation: Information about current racer that was passed-in from the queue.
        """
        return self._current_racer

    @property
    def is_event_end(self):
        """Return True if the service has signaled event end

        Returns:
            boolean: Is it the time to kill everything and die
        """
        return self._is_event_end

    def poll_next_racer(self):
        """
            Poll from sqs the next racer information.

        Blocks until exactly one valid racer message has been received and
        stores the parsed racer in ``self._current_racer``. Messages that fail
        validation are logged and skipped. Polling failures add to an error
        counter; once it reaches MAX_NUM_OF_SQS_ERROR, log_and_exit is called.
        """
        error_counter = 0
        while True:
            # Polling MAX_NUM_OF_SQS_MESSAGE=1 message from sqs
            # with wait time specified in SQS_WAIT_TIME_SEC
            response = self._sqs_client.get_messages()
            # If polling message is successful, it will return a list of payloads
            # If polling message failed, it will return an integer
            # 1=ClientError Or FailedToDeleteMessage
            # 2=SystemError
            if isinstance(response, int):
                error_counter += response
                if error_counter >= MAX_NUM_OF_SQS_ERROR:
                    # something went really wrong with the sqs queue...
                    # (adjacent literals: no source indentation inside the message)
                    log_and_exit("[virtual event manager] Too many exceptions (num={}) in "
                                 "receiving message from sqs queue: {}".format(
                                     error_counter, self._queue_url),
                                 SIMAPP_SQS_RECEIVE_MESSAGE_EXCEPTION,
                                 SIMAPP_EVENT_ERROR_CODE_500)
            elif isinstance(response, list) and len(response) == 1:
                message_body = response[0]
                try:
                    # validate the current racer information.
                    validate_json_input(message_body, RACER_INFO_JSON_SCHEMA)
                    # Parse JSON into an racer information object
                    # with attributes corresponding to dict keys
                    self._current_racer = json.loads(
                        message_body,
                        object_hook=lambda d: namedtuple(RACER_INFO_OBJECT, d.keys())(*d.values()))
                except GenericNonFatalException as ex:
                    # invalid message: log it and keep polling
                    ex.log_except_and_continue()
                    continue
                LOG.info("[virtual event manager] Received next racer's information %s",
                         self._current_racer)
                # only stop polling after making sure the message is valid
                return

    def setup_race(self):
        """
            Setting up the race for the current racer.

        Briefly unpauses physics to reset the camera to the start line and to
        park the previous car (if any) at the pit location, pauses physics
        again, then downloads the current racer's model metadata and
        checkpoint, selects the race car, sets up the simtrace/mp4 writers and
        the metrics location, and builds the graph manager.

        Returns:
            bool: True if setup race is successful.
                  False if a non fatal exception occurred; in that case the
                  error is uploaded as the race status and the race is
                  cleaned up.
            Any other exception is reported through log_and_exit with
            SIMAPP_EVENT_ERROR_CODE_500 (the method then returns None if
            log_and_exit returns).
        """

        LOG.info("[virtual event manager] Setting up race for racer")
        try:
            # unpause the physics in current world
            self._model_updater.unpause_physics()
            LOG.info("[virtual event manager] Unpaused physics in current world.")
            # set camera to starting position
            initial_pose = self._track_data.get_racecar_start_pose(
                racecar_idx=0,
                racer_num=1,
                start_position=get_start_positions(1)[0])
            self._main_cameras[VIRTUAL_EVENT].reset_pose(
                car_pose=initial_pose)
            LOG.info("[virtual event manager] Reset camera to starting line.")
            if self._prev_model_name is not None:
                # NOTE: it's by design that we immediately park the previous car to pit
                # location right after unpause physics. This prevents any unwanted
                # leftover behavior to happen
                self._park_at_pit_location(self._prev_model_name)
                LOG.info("[virtual event manager] Parked previous model %s to pit location.",
                         self._prev_model_name)
            self._model_updater.pause_physics()
            LOG.info("[virtual event manager] Paused physics in current world.")
            # download model metadata from s3
            sensors, version, model_metadata = self._download_model_metadata()
            # based on model metadata, select racecar
            self._current_car_model_state = self._get_car_model_state(sensors)
            # download checkpoint from s3
            checkpoint = self._download_checkpoint(version)
            # setup the simtrace and mp4 writers if the s3 locations are available
            self._setup_simtrace_mp4_writers()
            # reset the metrics s3 location for the current racer
            self._reset_metrics_loc()
            # setup agents
            agent_list = self._get_agent_list(model_metadata, version)
            self._setup_graph_manager(checkpoint, agent_list)
            LOG.info("[virtual event manager] Graph manager successfully created the graph: setup race successful.")
            return True
        except GenericNonFatalException as ex:
            # NOTE(review): an exception raised inside this handler (e.g. by
            # upload_race_status) is NOT caught by the ``except Exception``
            # clause below; it propagates to the caller and skips
            # _clean_up_race.
            ex.log_except_and_continue()
            self.upload_race_status(status_code=ex.error_code,
                                    error_name=ex.error_name,
                                    error_details=ex.error_msg)
            self._clean_up_race()
            return False
        except Exception as ex:
            log_and_exit("[virtual event manager] Something really wrong happened when setting up the race. {}".format(ex),
                         SIMAPP_VIRTUAL_EVENT_RACE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)

    def _download_checkpoint(self, version):
        """Set up the Checkpoint object and select the best checkpoint.

        The best checkpoint name from the deepracer checkpoint json is written
        into the rl coach ``.coach_checkpoint`` file, encrypted with the racer's
        input model KMS key when one is provided.

        Args:
            version (float): SimApp version of the model (from model metadata).

        Returns:
            Checkpoint: Checkpoint class instance
        """
        input_model = self._current_racer.inputModel
        # download checkpoint from s3
        checkpoint = Checkpoint(bucket=input_model.s3BucketName,
                                s3_prefix=input_model.s3KeyPrefix,
                                region_name=self._region,
                                agent_name=self._agent_name,
                                checkpoint_dir=self._local_model_directory)
        # make coach checkpoint compatible (only needed for models older than
        # SIMAPP_VERSION_2)
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # the KMS key is an optional field of the racer's input model
        model_kms = getattr(input_model, 's3KmsKeyArn', None)
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_extra_args(model_kms))
        return checkpoint

    def _download_model_metadata(self):
        """ Attempt to download model metadata from s3.

        Raises:
            GenericNonFatalException: An non fatal exception which we will
                                      catch and proceed with work loop.
                                      A botocore ClientError is reported with
                                      SIMAPP_EVENT_ERROR_CODE_400 /
                                      SIMAPP_EVENT_USER_ERROR, any other
                                      failure with SIMAPP_EVENT_ERROR_CODE_500 /
                                      SIMAPP_EVENT_SYSTEM_ERROR.

        Returns:
            tuple: (sensors, version, model_metadata) — the sensor entry and the
            SimApp version read from the model metadata, and the ModelMetadata
            instance itself.
        """
        model_metadata_s3_key = get_s3_key(self._current_racer.inputModel.s3KeyPrefix,
                                           MODEL_METADATA_S3_POSTFIX)
        try:
            model_metadata = ModelMetadata(bucket=self._current_racer.inputModel.s3BucketName,
                                           s3_key=model_metadata_s3_key,
                                           region_name=self._region,
                                           local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(self._agent_name))
            model_metadata_info = model_metadata.get_model_metadata_info()
            sensors = model_metadata_info[ModelMetadataKeys.SENSOR.value]
            simapp_version = model_metadata_info[ModelMetadataKeys.VERSION.value]
        except botocore.exceptions.ClientError as err:
            # adjacent literals: no source indentation inside the message
            error_msg = ("[s3] Client Error: Failed to download model_metadata file: "
                         "s3_bucket: {}, s3_key: {}, {}.".format(
                             self._current_racer.inputModel.s3BucketName,
                             model_metadata_s3_key,
                             err))
            raise GenericNonFatalException(error_msg=error_msg,
                                           error_code=SIMAPP_EVENT_ERROR_CODE_400,
                                           error_name=SIMAPP_EVENT_USER_ERROR)
        except Exception as err:
            error_msg = ("[s3] System Error: Failed to download model_metadata file: "
                         "s3_bucket: {}, s3_key: {}, {}.".format(
                             self._current_racer.inputModel.s3BucketName,
                             model_metadata_s3_key,
                             err))
            raise GenericNonFatalException(error_msg=error_msg,
                                           error_code=SIMAPP_EVENT_ERROR_CODE_500,
                                           error_name=SIMAPP_EVENT_SYSTEM_ERROR)
        return sensors, simapp_version, model_metadata

    def start_race(self):
        """
            Start the race (evaluation) for the current racer.

        Updates the car color (skipped when the lower-cased body shell type
        contains ``const.F1``), requests the mp4 recording when enabled,
        attaches the main and sub cameras to the current car, unpauses physics
        and runs the evaluation: once for a continuous race, otherwise
        ``self._number_of_trials`` times.
        """
        LOG.info("[virtual event manager] Starting race for racer %s", self._current_racer.racerAlias)
        model_name = self._current_car_model_state.model_name
        # update the car on current model if does not use f1 or tron type of shell
        if const.F1 not in self._body_shell_type.lower():
            self._model_updater.update_model_color(model_name,
                                                   self._current_racer.carConfig.carColor)
        # send request
        if self._is_save_mp4_enabled:
            self._subscribe_to_save_mp4(VirtualEventVideoEditSrvRequest(
                display_name=self._current_racer.racerAlias,
                racecar_color=self._current_racer.carConfig.carColor))

        # Update CameraManager by adding cameras into the current namespace. By doing so
        # a single follow car camera will follow the current active racecar.
        self._camera_manager.add(self._main_cameras[VIRTUAL_EVENT], model_name)
        self._camera_manager.add(self._sub_camera, model_name)

        configure_environment_randomizer()
        # strip index for park position
        self._park_position_idx = get_racecar_idx(model_name)
        # unpause the physics in current world
        self._model_updater.unpause_physics()
        LOG.info("[virtual event manager] Unpaused physics in current world.")
        # NOTE(review): nothing disables the links of the previous car here,
        # even when it differs from the current car ("we don't want the car
        # to float around"); add that step if it is actually required.
        LOG.info("[virtual event manager] Unpaused model for current car.")
        # set the park position in track and do evaluation.
        # Before each evaluation episode (single lap for non-continuous race and complete race for
        # continuous race), a new copy of park_positions needs to be loaded into track_data because
        # a park position will be popped from park_positions when a racer car needs to be parked.
        num_episodes = 1 if self._is_continuous else self._number_of_trials
        for _ in range(num_episodes):
            self._track_data.park_positions = [self._park_positions[self._park_position_idx]]
            self._current_graph_manager.evaluate(EnvironmentSteps(1))

    def finish_race(self):
        """
            Finish the race for the current racer.

            Pauses physics, stops the mp4 subscription (when mp4 saving is
            enabled), removes the current car from the track data and from the
            camera namespace, persists the simtrace/mp4 writers, uploads a
            status code of 200, records the car model name as the previous
            model and finally resets all per-racer state.
        """
        # pause physics of the world
        self._model_updater.pause_physics()
        # NOTE(review): presumably gives the pause time to settle before teardown — confirm
        time.sleep(1)
        # unsubscribe mp4
        if self._is_save_mp4_enabled:
            self._unsubscribe_from_save_mp4(EmptyRequest())
        self._track_data.remove_object(name=self._current_car_model_state.model_name)
        LOG.info("[virtual event manager] Finish Race - remove object %s.",
                 self._current_car_model_state.model_name)
        # pop out current racecar from camera namespace to prevent camera from moving
        self._camera_manager.pop(namespace=self._current_car_model_state.model_name)
        # upload simtrace and mp4 into s3 bucket
        self._save_simtrace_mp4()
        self.upload_race_status(status_code=200)
        # keep track of the previous model name; start_race compares it with the
        # next racer's model name. It must be read before _clean_up_race(),
        # which sets _current_car_model_state to None.
        self._prev_model_name = self._current_car_model_state.model_name
        # clean up local trace of current race
        self._clean_up_race()

    def _clean_up_race(self):
        """Reset every piece of per-racer state so the next racer starts fresh.

        Empties the simtrace/video writer list, deletes the local checkpoint
        files, resets the default TensorFlow graph and clears the current racer,
        agent, car model state, graph manager and park position index.
        """
        self._simtrace_video_s3_writers = []
        # remove downloaded checkpoints and other local artifacts
        self._clean_local_directory()
        # a fresh default graph avoids clashes with the global TF session
        tf.reset_default_graph()
        # forget everything about the racer that just finished
        self._current_racer = self._current_agent = None
        self._current_car_model_state = self._current_graph_manager = None
        self._park_position_idx = 0

    def _save_simtrace_mp4(self):
        """Get the appropriate kms key and save the simtrace and mp4 files.
        """
        # TODO: It might be theorically possible to have different kms keys for simtrace and mp4
        # but we are using the same key now since that's what happens in real life
        # consider refactor the simtrace_video_s3_writers later.
        if hasattr(self._current_racer.outputMp4, 's3KmsKeyArn'):
            simtrace_mp4_kms = self._current_racer.outputMp4.s3KmsKeyArn
        elif hasattr(self._current_racer.outputSimTrace, 's3KmsKeyArn'):
            simtrace_mp4_kms = self._current_racer.outputSimTrace.s3KmsKeyArn
        else:
            simtrace_mp4_kms = None
        for s3_writer in self._simtrace_video_s3_writers:
            s3_writer.persist(utils.get_s3_extra_args(simtrace_mp4_kms))

    def _reset_metrics_loc(self):
        """Point the evaluation metrics at the S3 location of the newly loaded racer.

        Uses the racer's ``outputMetrics`` bucket and key prefix together with the
        manager's region, and forwards the current simtrace-saving flag.
        """
        output_metrics = self._current_racer.outputMetrics
        metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value: output_metrics.s3BucketName,
            MetricsS3Keys.METRICS_KEY.value: output_metrics.s3KeyPrefix,
            MetricsS3Keys.REGION.value: self._region,
        }
        self._eval_metrics.reset_metrics(
            s3_dict_metrics=metrics_s3_config,
            is_save_simtrace_enabled=self._is_save_simtrace_enabled)

    def _setup_simtrace_mp4_writers(self):
        """Setup the simtrace and mp4 writers if the locations are passed in.
        """
        self._is_save_simtrace_enabled = False
        self._is_save_mp4_enabled = False
        if hasattr(self._current_racer, 'outputSimTrace'):
            self._simtrace_video_s3_writers.append(
                SimtraceVideo(upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                              bucket=self._current_racer.outputSimTrace.s3BucketName,
                              s3_prefix=self._current_racer.outputSimTrace.s3KeyPrefix,
                              region_name=self._region,
                              local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(self._agent_name)))
            self._is_save_simtrace_enabled = True
        if hasattr(self._current_racer, 'outputMp4'):
            self._simtrace_video_s3_writers.extend([
                SimtraceVideo(upload_type=SimtraceVideoNames.PIP.value,
                              bucket=self._current_racer.outputMp4.s3BucketName,
                              s3_prefix=self._current_racer.outputMp4.s3KeyPrefix,
                              region_name=self._region,
                              local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format(self._agent_name)),
                SimtraceVideo(upload_type=SimtraceVideoNames.DEGREE45.value,
                              bucket=self._current_racer.outputMp4.s3BucketName,
                              s3_prefix=self._current_racer.outputMp4.s3KeyPrefix,
                              region_name=self._region,
                              local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format(self._agent_name)),
                SimtraceVideo(upload_type=SimtraceVideoNames.TOPVIEW.value,
                              bucket=self._current_racer.outputMp4.s3BucketName,
                              s3_prefix=self._current_racer.outputMp4.s3KeyPrefix,
                              region_name=self._region,
                              local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format(self._agent_name))])
            self._is_save_mp4_enabled = True

    def _setup_graph_manager(self, checkpoint, agent_list):
        """Build the graph manager for the current racer and restore its model.

        Creates the graph manager with empty hyperparameters and attaches an S3
        data store keyed by the agent name. It then waits for the checkpoint and
        adjusts the checkpoint variables. Finally it creates the graph, restoring
        from the local model directory.

        Args:
            checkpoint (Checkpoint): The model checkpoints we just downloaded.
            agent_list (list[Agent]): List of agents we want to setup graph manager for.
        """
        model_updater = self._model_updater
        self._current_graph_manager, _ = get_graph_manager(
            hp_dict={},
            agent_list=agent_list,
            run_phase_subject=self._run_phase_subject,
            enable_domain_randomization=self._enable_domain_randomization,
            done_condition=self._done_condition,
            pause_physics=model_updater.pause_physics_service,
            unpause_physics=model_updater.unpause_physics_service)
        graph_manager = self._current_graph_manager

        ds_params_instance = S3BotoDataStoreParameters(
            checkpoint_dict={self._agent_name: checkpoint})
        graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                                   graph_manager=graph_manager,
                                                   ignore_lock=True,
                                                   log_and_cont=True)
        graph_manager.env_params.seed = 0

        data_store = graph_manager.data_store
        data_store.wait_for_checkpoints()
        data_store.modify_checkpoint_variables()

        task_parameters = TaskParameters()
        task_parameters.checkpoint_restore_path = self._local_model_directory
        graph_manager.create_graph(task_parameters=task_parameters,
                                   stop_physics=model_updater.pause_physics_service,
                                   start_physics=model_updater.unpause_physics_service,
                                   empty_service_call=EmptyRequest)

    def _get_agent_list(self, model_metadata, version):
        """Create the agents that take part in the current racer's evaluation.

        Args:
            model_metadata (ModelMetadata): Current racer's model metadata
            version (str): The current racer's simapp version in the model metadata

        Returns:
            agent_list (list): The rollout agent for the current racer, followed by
                the obstacles agent and the bot cars agent.
        """
        racecar_name = self._current_car_model_state.model_name

        def _for_current_car(default_names):
            # LINK_NAMES / VELOCITY_TOPICS / STEERING_TOPICS refer to a car named
            # 'racecar'; re-point them at the car model that is actually in use.
            return [default_name.replace('racecar', racecar_name) for default_name in default_names]

        car_ctrl_config = {
            ConfigParams.LINK_NAME_LIST.value: _for_current_car(LINK_NAMES),
            ConfigParams.VELOCITY_LIST.value: _for_current_car(VELOCITY_TOPICS),
            ConfigParams.STEERING_LIST.value: _for_current_car(STEERING_TOPICS),
            ConfigParams.CHANGE_START.value: utils.str2bool(rospy.get_param('CHANGE_START_POSITION', False)),
            ConfigParams.ALT_DIR.value: utils.str2bool(rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
            ConfigParams.MODEL_METADATA.value: model_metadata,
            ConfigParams.REWARD.value: reward_function,
            ConfigParams.AGENT_NAME.value: racecar_name,
            ConfigParams.VERSION.value: version,
            ConfigParams.NUMBER_OF_RESETS.value: self._number_of_resets,
            ConfigParams.PENALTY_SECONDS.value: self._penalty_seconds,
            ConfigParams.NUMBER_OF_TRIALS.value: self._number_of_trials,
            ConfigParams.IS_CONTINUOUS.value: self._is_continuous,
            ConfigParams.RACE_TYPE.value: self._race_type,
            ConfigParams.COLLISION_PENALTY.value: self._collision_penalty,
            ConfigParams.OFF_TRACK_PENALTY.value: self._off_track_penalty,
            # hard-coded to the first start position
            ConfigParams.START_POSITION.value: get_start_positions(1)[0],
            ConfigParams.DONE_CONDITION.value: self._done_condition,
            ConfigParams.IS_VIRTUAL_EVENT.value: True,
            ConfigParams.RACE_DURATION.value: self._race_duration,
        }
        agent_config = {'model_metadata': model_metadata,
                        ConfigParams.CAR_CTRL_CONFIG.value: car_ctrl_config}

        return [create_rollout_agent(agent_config, self._eval_metrics, self._run_phase_subject),
                create_obstacles_agent(),
                create_bot_cars_agent()]

    def _setup_mp4_services(self):
        """
        Wait for the virtual-event save_mp4 subscribe/unsubscribe ROS services
        and store service proxies for both of them.
        """
        subscribe_service = "/{}/save_mp4/subscribe_to_save_mp4".format(VIRTUAL_EVENT)
        unsubscribe_service = "/{}/save_mp4/unsubscribe_from_save_mp4".format(VIRTUAL_EVENT)
        # block until both services are available before building proxies
        for service_name in (subscribe_service, unsubscribe_service):
            rospy.wait_for_service(service_name)
        self._subscribe_to_save_mp4 = ServiceProxyWrapper(subscribe_service, VirtualEventVideoEditSrv)
        self._unsubscribe_from_save_mp4 = ServiceProxyWrapper(unsubscribe_service, Empty)

    def _get_car_model_state(self, sensors: list) -> ModelState:
        """Pick the racecar model that matches the sensors in the model metadata.

        Stereo camera and lidar (either ``LIDAR`` or ``SECTOR_LIDAR``) are
        detected independently. The combination selects one of the four
        ``SENSOR_MODEL_MAP`` entries.

        Args:
            sensors (list): sensors in the model metadata

        Returns:
            ModelState: a model state object with the current racecar name
        """
        has_stereo = Input.STEREO.value in sensors
        if has_stereo:
            LOG.info("[virtual event manager] stereo camera present")
        has_lidar = Input.LIDAR.value in sensors or Input.SECTOR_LIDAR.value in sensors
        if has_lidar:
            LOG.info("[virtual event manager] lidar present")

        # (stereo?, lidar?) -> SENSOR_MODEL_MAP key
        sensor_model_key = {
            (True, True): 'stereo_camera_lidar',
            (True, False): 'stereo_camera',
            (False, True): 'single_camera_lidar',
            (False, False): 'single_camera',
        }[(has_stereo, has_lidar)]

        car_model_state = ModelState()
        car_model_state.model_name = SENSOR_MODEL_MAP[sensor_model_key]
        return car_model_state

    def _clean_local_directory(self):
        """Delete every file under the local model directory once the race ends.

        Only files are removed; the directory tree itself is left in place.
        """
        LOG.info("[virtual event manager] cleaning up the local directory after race ends.")
        stale_files = (os.path.join(dir_path, file_name)
                       for dir_path, _sub_dirs, file_names in os.walk(self._local_model_directory)
                       for file_name in file_names)
        for stale_file in stale_files:
            os.remove(stale_file)

    def _park_at_pit_location(self, model_name):
        """Move ``model_name`` to the park position at ``_park_position_idx``.

        The car is given yaw 0.0 when the track is counter-clockwise and pi
        otherwise. The position update is requested with ``is_blocking=True``.
        """
        if self._track_data.is_ccw:
            yaw = 0.0
        else:
            yaw = math.pi
        park_position = self._park_positions[self._park_position_idx]
        self._model_updater.set_model_position(model_name, park_position, yaw,
                                               is_blocking=True)
        LOG.info("[virtual event manager] parked car to pit position.")

    def upload_race_status(self, status_code, error_name=None, error_details=None):
        """Upload race status into s3.

        The status is serialized as JSON and written to
        ``<outputStatus.s3KeyPrefix>/S3_RACE_STATUS_FILE_NAME`` in the racer's
        ``outputStatus.s3BucketName``. It uses ``outputStatus.s3KmsKeyArn`` for
        the S3 extra args when the racer provides one. The error fields are only
        included when BOTH ``error_name`` and ``error_details`` are given.

        Args:
            status_code (int): Status code for race (``finish_race`` passes 200).
            error_name (str, optional): The name of the error if is 4xx or 5xx.
                                        Defaults to None.
            error_details (str, optional): The detail message of the error
                                           if is 4xx or 5xx.
                                           Defaults to None.
        """
        # persist s3 status file
        status = {RaceStatusKeys.STATUS_CODE.value: status_code}
        if error_name is not None and error_details is not None:
            status[RaceStatusKeys.ERROR_NAME.value] = error_name
            status[RaceStatusKeys.ERROR_DETAILS.value] = error_details
        status_json = json.dumps(status)

        output_status = self._current_racer.outputStatus
        s3_key = os.path.normpath(os.path.join(output_status.s3KeyPrefix,
                                               S3_RACE_STATUS_FILE_NAME))
        # the KMS key is optional on the status output
        race_status_kms = getattr(output_status, 's3KmsKeyArn', None)
        self._s3_client.upload_fileobj(bucket=output_status.s3BucketName,
                                       s3_key=s3_key,
                                       fileobj=io.BytesIO(status_json.encode()),
                                       s3_kms_extra_args=utils.get_s3_extra_args(race_status_kms))
        # Adjacent literals (rather than a backslash continuation inside the
        # string) keep the indentation out of the logged message; lazy %s args
        # defer formatting to the logger.
        LOG.info("[virtual event manager] Successfully uploaded race status file to "
                 "s3 bucket %s with s3 key %s.", output_status.s3BucketName, s3_key)
def main():
    """Run a single-agent evaluation worker.

    Parses the command line (defaults come from ROS params) and downloads the
    model metadata from S3. If the metadata version is older than
    ``utils.SIMAPP_VERSION`` and no current checkpoint name exists, the S3
    checkpoints are made compatible. It then loads the optional SageMaker
    hyperparameters, builds the rollout/obstacle/bot-car agents and the graph
    manager backed by an S3 data store, and hands everything to
    ``evaluation_worker``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--preset',
                        help="(string) Name of a preset to run \
                             (class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=rospy.get_param("MODEL_S3_BUCKET",
                                                "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default=rospy.get_param("MODEL_S3_PREFIX",
                                                "sagemaker"))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument(
        '-c',
        '--local_model_directory',
        help='(string) Path to a folder containing a checkpoint \
                             to restore the model from.',
        type=str,
        default='./checkpoint')

    args = parser.parse_args()
    logger.info("S3 bucket: %s \n S3 prefix: %s", args.s3_bucket,
                args.s3_prefix)

    s3_client = SageS3Client(bucket=args.s3_bucket,
                             s3_prefix=args.s3_prefix,
                             aws_region=args.aws_region)

    # Load the model metadata
    model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH,
                                             'model_metadata.json')
    utils.load_model_metadata(
        s3_client,
        os.path.normpath("%s/model/model_metadata.json" % args.s3_prefix),
        model_metadata_local_path)
    # Handle backward compatibility
    _, _, version = parse_model_metadata(model_metadata_local_path)
    if float(version) < float(utils.SIMAPP_VERSION) and \
            not utils.has_current_ckpnt_name(args.s3_bucket, args.s3_prefix, args.aws_region):
        utils.make_compatible(args.s3_bucket, args.s3_prefix, args.aws_region,
                              SyncFiles.TRAINER_READY.value)
    # Download hyperparameters from SageMaker
    hyperparams_s3_key = os.path.normpath(args.s3_prefix +
                                          "/ip/hyperparameters.json")
    hyperparameters_file_success = s3_client.download_file(
        s3_key=hyperparams_s3_key, local_path="hyperparameters.json")
    sm_hyperparams_dict = {}
    if hyperparameters_file_success:
        logger.info("Received Sagemaker hyperparameters successfully!")
        # JSON is UTF-8 by specification; don't depend on the locale default
        with open("hyperparameters.json", encoding="utf-8") as file:
            sm_hyperparams_dict = json.load(file)
    else:
        logger.info("SageMaker hyperparameters not found.")

    #! TODO each agent should have own config
    # NOTE(review): this re-parses the same local metadata file parsed above;
    # presumably the same function under another import path — confirm and dedupe.
    _, _, version = utils_parse_model_metadata.parse_model_metadata(
        model_metadata_local_path)
    agent_config = {
        'model_metadata': model_metadata_local_path,
        'car_ctrl_cnfig': {
            ConfigParams.LINK_NAME_LIST.value:
            LINK_NAMES,
            ConfigParams.VELOCITY_LIST.value:
            VELOCITY_TOPICS,
            ConfigParams.STEERING_LIST.value:
            STEERING_TOPICS,
            ConfigParams.CHANGE_START.value:
            utils.str2bool(rospy.get_param('CHANGE_START_POSITION', False)),
            ConfigParams.ALT_DIR.value:
            utils.str2bool(
                rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
            ConfigParams.ACTION_SPACE_PATH.value:
            'custom_files/model_metadata.json',
            ConfigParams.REWARD.value:
            reward_function,
            ConfigParams.AGENT_NAME.value:
            'racecar',
            ConfigParams.VERSION.value:
            version
        }
    }

    #! TODO each agent should have own s3 bucket
    metrics_s3_config = {
        MetricsS3Keys.METRICS_BUCKET.value:
        rospy.get_param('METRICS_S3_BUCKET'),
        MetricsS3Keys.METRICS_KEY.value:
        rospy.get_param('METRICS_S3_OBJECT_KEY'),
        MetricsS3Keys.REGION.value:
        rospy.get_param('AWS_REGION'),
        MetricsS3Keys.STEP_BUCKET.value:
        rospy.get_param('MODEL_S3_BUCKET'),
        MetricsS3Keys.STEP_KEY.value:
        os.path.join(rospy.get_param('MODEL_S3_PREFIX'),
                     EVALUATION_SIMTRACE_DATA_S3_OBJECT_KEY)
    }

    agent_list = list()
    agent_list.append(
        create_rollout_agent(agent_config, EvalMetrics(metrics_s3_config)))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    graph_manager, _ = get_graph_manager(sm_hyperparams_dict, agent_list)

    ds_params_instance = S3BotoDataStoreParameters(
        aws_region=args.aws_region,
        bucket_name=args.s3_bucket,
        checkpoint_dir=args.local_model_directory,
        s3_folder=args.s3_prefix)

    data_store = S3BotoDataStore(ds_params_instance)
    data_store.graph_manager = graph_manager
    graph_manager.data_store = data_store
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(
        graph_manager=graph_manager,
        data_store=data_store,
        number_of_trials=args.number_of_trials,
        task_parameters=task_parameters,
    )
예제 #7
0
def main():
    """ Main function for evaluation worker """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--preset',
                        help="(string) Name of a preset to run \
                             (class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='list(string) S3 bucket',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_BUCKET",
                                                ["gsaur-test"]))
    parser.add_argument('--s3_prefix',
                        help='list(string) S3 prefix',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_PREFIX",
                                                ["sagemaker"]))
    parser.add_argument('--s3_endpoint_url',
                        help='(string) S3 endpoint URL',
                        type=str,
                        default=rospy.get_param("S3_ENDPOINT_URL", None))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument(
        '-c',
        '--local_model_directory',
        help='(string) Path to a folder containing a checkpoint \
                             to restore the model from.',
        type=str,
        default='./checkpoint')
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "EVALUATION"))
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=bool,
                        default=utils.str2bool(
                            rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY",
                                                      2.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY",
                                                      5.0)))
    parser.add_argument('--round_robin_advance_dist',
                        help='(float) round robin distance 0-1',
                        type=float,
                        default=float(
                            rospy.get_param("ROUND_ROBIN_ADVANCE_DIST", 0.05)))
    parser.add_argument('--start_position_offset',
                        help='(float) offset start 0-1',
                        type=float,
                        default=float(
                            rospy.get_param("START_POSITION_OFFSET", 0.0)))

    args = parser.parse_args()
    arg_s3_bucket = args.s3_bucket
    arg_s3_prefix = args.s3_prefix
    logger.info("S3 bucket: %s \n S3 prefix: %s \n S3 endpoint URL: %s",
                args.s3_bucket, args.s3_prefix, args.s3_endpoint_url)

    metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET')
    metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY')

    arg_s3_bucket, arg_s3_prefix = utils.force_list(
        arg_s3_bucket), utils.force_list(arg_s3_prefix)
    metrics_s3_buckets = utils.force_list(metrics_s3_buckets)
    metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys)

    validate_list = [
        arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets,
        metrics_s3_object_keys
    ]

    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    if not all([lambda x: len(x) == len(validate_list[0]), validate_list]):
        log_and_exit(
            "Eval worker error: Incorrect arguments passed: {}".format(
                validate_list), SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500)
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException(
            "number of resets is less than {}".format(MIN_RESET_COUNT))

    # Instantiate Cameras
    if len(arg_s3_bucket) == 1:
        configure_camera(namespaces=['racecar'])
    else:
        configure_camera(namespaces=[
            'racecar_{}'.format(str(agent_index))
            for agent_index in range(len(arg_s3_bucket))
        ])

    agent_list = list()
    s3_bucket_dict = dict()
    s3_prefix_dict = dict()
    s3_writers = list()
    start_positions = get_start_positions(len(arg_s3_bucket))
    done_condition = utils.str_to_done_condition(
        rospy.get_param("DONE_CONDITION", any))
    park_positions = utils.pos_2d_str_to_list(
        rospy.get_param("PARK_POSITIONS", []))
    # if not pass in park positions for all done condition case, use default
    if not park_positions:
        park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket]
    for agent_index, _ in enumerate(arg_s3_bucket):
        agent_name = 'agent' if len(arg_s3_bucket) == 1 else 'agent_{}'.format(
            str(agent_index))
        racecar_name = 'racecar' if len(
            arg_s3_bucket) == 1 else 'racecar_{}'.format(str(agent_index))
        s3_bucket_dict[agent_name] = arg_s3_bucket[agent_index]
        s3_prefix_dict[agent_name] = arg_s3_prefix[agent_index]

        # download model metadata
        model_metadata = ModelMetadata(
            bucket=arg_s3_bucket[agent_index],
            s3_key=get_s3_key(arg_s3_prefix[agent_index],
                              MODEL_METADATA_S3_POSTFIX),
            region_name=args.aws_region,
            s3_endpoint_url=args.s3_endpoint_url,
            local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(agent_name))
        _, _, version = model_metadata.get_model_metadata_info()

        # Select the optimal model
        utils.do_model_selection(s3_bucket=arg_s3_bucket[agent_index],
                                 s3_prefix=arg_s3_prefix[agent_index],
                                 region=args.aws_region,
                                 s3_endpoint_url=args.s3_endpoint_url)

        agent_config = {
            'model_metadata': model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace('racecar', racecar_name)
                    for link_name in LINK_NAMES
                ],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace('racecar', racecar_name)
                    for velocity_topic in VELOCITY_TOPICS
                ],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace('racecar', racecar_name)
                    for steering_topic in STEERING_TOPICS
                ],
                ConfigParams.CHANGE_START.value:
                utils.str2bool(rospy.get_param('CHANGE_START_POSITION',
                                               False)),
                ConfigParams.ALT_DIR.value:
                utils.str2bool(
                    rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
                ConfigParams.ACTION_SPACE_PATH.value:
                model_metadata.local_path,
                ConfigParams.REWARD.value:
                reward_function,
                ConfigParams.AGENT_NAME.value:
                racecar_name,
                ConfigParams.VERSION.value:
                version,
                ConfigParams.NUMBER_OF_RESETS.value:
                args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value:
                args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value:
                args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value:
                args.is_continuous,
                ConfigParams.RACE_TYPE.value:
                args.race_type,
                ConfigParams.COLLISION_PENALTY.value:
                args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value:
                args.off_track_penalty,
                ConfigParams.START_POSITION.value:
                start_positions[agent_index],
                ConfigParams.DONE_CONDITION.value:
                done_condition,
                ConfigParams.ROUND_ROBIN_ADVANCE_DIST.value:
                args.round_robin_advance_dist,
                ConfigParams.START_POSITION_OFFSET.value:
                args.start_position_offset
            }
        }

        metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value:
            metrics_s3_buckets[agent_index],
            MetricsS3Keys.METRICS_KEY.value:
            metrics_s3_object_keys[agent_index],
            MetricsS3Keys.ENDPOINT_URL.value:
            rospy.get_param('S3_ENDPOINT_URL', None),
            # Replaced rospy.get_param('AWS_REGION') to be equal to the argument being passed
            # or default argument set
            MetricsS3Keys.REGION.value:
            args.aws_region
        }
        aws_region = rospy.get_param('AWS_REGION', args.aws_region)
        s3_writer_job_info = []
        if simtrace_s3_bucket:
            s3_writer_job_info.append(
                IterationData(
                    'simtrace', simtrace_s3_bucket[agent_index],
                    simtrace_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        SIM_TRACE_EVALUATION_LOCAL_FILE.value)))
        if mp4_s3_bucket:
            s3_writer_job_info.extend([
                IterationData(
                    'pip', mp4_s3_bucket[agent_index],
                    mp4_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        CAMERA_PIP_MP4_VALIDATION_LOCAL_PATH.value)),
                IterationData(
                    '45degree', mp4_s3_bucket[agent_index],
                    mp4_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        CAMERA_45DEGREE_MP4_VALIDATION_LOCAL_PATH.value)),
                IterationData(
                    'topview', mp4_s3_bucket[agent_index],
                    mp4_s3_object_prefix[agent_index], aws_region,
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH, agent_name,
                        IterationDataLocalFileNames.
                        CAMERA_TOPVIEW_MP4_VALIDATION_LOCAL_PATH.value))
            ])

        s3_writers.append(
            S3Writer(job_info=s3_writer_job_info,
                     s3_endpoint_url=args.s3_endpoint_url))
        run_phase_subject = RunPhaseSubject()
        agent_list.append(
            create_rollout_agent(
                agent_config,
                EvalMetrics(agent_name, metrics_s3_config, args.is_continuous),
                run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)
    enable_domain_randomization = utils.str2bool(
        rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))

    sm_hyperparams_dict = {}
    graph_manager, _ = get_graph_manager(
        hp_dict=sm_hyperparams_dict,
        agent_list=agent_list,
        run_phase_subject=run_phase_subject,
        enable_domain_randomization=enable_domain_randomization,
        done_condition=done_condition)

    ds_params_instance = S3BotoDataStoreParameters(
        aws_region=args.aws_region,
        bucket_names=s3_bucket_dict,
        base_checkpoint_dir=args.local_model_directory,
        s3_folders=s3_prefix_dict,
        s3_endpoint_url=args.s3_endpoint_url)

    graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                               graph_manager=graph_manager,
                                               ignore_lock=True)
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(graph_manager=graph_manager,
                      number_of_trials=args.number_of_trials,
                      task_parameters=task_parameters,
                      s3_writers=s3_writers,
                      is_continuous=args.is_continuous,
                      park_positions=park_positions)