Exemplo n.º 1
0
    def _download_checkpoint(self, version):
        """Set up the Checkpoint object and select the best checkpoint.

        Args:
            version (float): SimApp version

        Returns:
            Checkpoint: Checkpoint class instance
        """
        input_model = self._current_racer.inputModel

        # checkpoint s3 instance for the current racer's model
        checkpoint = Checkpoint(
            bucket=input_model.s3BucketName,
            s3_prefix=input_model.s3KeyPrefix,
            region_name=self._region,
            agent_name=self._agent_name,
            checkpoint_dir=self._local_model_directory)

        # Only models older than SIMAPP_VERSION_2 may need their coach
        # checkpoint converted; the compatibility check is skipped otherwise.
        if version < SIMAPP_VERSION_2:
            if not checkpoint.rl_coach_checkpoint.is_compatible():
                checkpoint.rl_coach_checkpoint.make_compatible(
                    checkpoint.syncfile_ready)

        # name of the best model checkpoint as recorded in the deepracer
        # checkpoint json
        best_checkpoint_name = (
            checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint())

        # The KMS key is optional on the input model: fall back to None when
        # the attribute is absent.
        kms_key_arn = getattr(input_model, 's3KmsKeyArn', None)

        # Select the best checkpoint model by updating the rl coach
        # .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=best_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_extra_args(kms_key_arn))
        return checkpoint
Exemplo n.º 2
0
def validate(s3_bucket, s3_prefix, aws_region):
    """Validate the model stored under the given S3 bucket and prefix.

    Reads the model metadata, creates the local model directory, builds a
    graph manager around a single training agent, attaches an S3 data store
    for the model checkpoint and hands everything over to ``_validate``.

    Args:
        s3_bucket (str): S3 bucket that holds the model
        s3_prefix (str): S3 prefix of the model inside the bucket
        aws_region (str): AWS region used for the S3 access
    """
    screen.set_use_colors(False)
    screen.log_title(" S3 bucket: {} \n S3 prefix: {}".format(
        s3_bucket, s3_prefix))

    # model metadata s3 instance
    model_metadata = ModelMetadata(bucket=s3_bucket,
                                   s3_key=get_s3_key(
                                       s3_prefix, MODEL_METADATA_S3_POSTFIX),
                                   region_name=aws_region,
                                   local_path=MODEL_METADATA_LOCAL_PATH)

    # Create model local path
    os.makedirs(LOCAL_MODEL_DIR)

    try:
        # Handle backward compatibility
        model_metadata_info = model_metadata.get_model_metadata_info()
        observation_list = model_metadata_info[ModelMetadataKeys.SENSOR.value]
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]
    except Exception as ex:
        log_and_exit("Failed to parse model_metadata file: {}".format(ex),
                     SIMAPP_VALIDATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)

    # The get_transition_data function below must be called before the
    # create_training_agent function to avoid a 500 in case an unsupported
    # sensor is received.
    # create_training_agent will exit with 500 if an unsupported sensor is
    # received, and get_transition_data will exit with 400 in that case.
    # We want to return 400 in the model validation case, thus call
    # get_transition_data before create_training_agent!
    transitions = get_transition_data(observation_list)

    # Use the aws_region parameter here (the original read it from an
    # ``args`` name that is not defined in this function).
    checkpoint = Checkpoint(bucket=s3_bucket,
                            s3_prefix=s3_prefix,
                            region_name=aws_region,
                            agent_name='agent',
                            checkpoint_dir=LOCAL_MODEL_DIR)
    # make coach checkpoint compatible
    if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible(
    ):
        checkpoint.rl_coach_checkpoint.make_compatible(
            checkpoint.syncfile_ready)
    # add checkpoint into checkpoint_dict
    checkpoint_dict = {'agent': checkpoint}

    agent_config = {
        'model_metadata': model_metadata,
        ConfigParams.CAR_CTRL_CONFIG.value: {
            ConfigParams.LINK_NAME_LIST.value: [],
            ConfigParams.VELOCITY_LIST.value: {},
            ConfigParams.STEERING_LIST.value: {},
            ConfigParams.CHANGE_START.value: None,
            ConfigParams.ALT_DIR.value: None,
            ConfigParams.MODEL_METADATA.value: model_metadata,
            ConfigParams.REWARD.value: None,
            ConfigParams.AGENT_NAME.value: 'racecar'
        }
    }

    agent_list = [create_training_agent(agent_config)]

    # No hyperparameters are supplied for validation
    sm_hyperparams_dict = {}
    graph_manager, _ = get_graph_manager(hp_dict=sm_hyperparams_dict,
                                         agent_list=agent_list,
                                         run_phase_subject=None)

    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(ds_params_instance,
                                               graph_manager,
                                               ignore_lock=True)

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = LOCAL_MODEL_DIR
    _validate(graph_manager=graph_manager,
              task_parameters=task_parameters,
              transitions=transitions,
              s3_bucket=s3_bucket,
              s3_prefix=s3_prefix,
              aws_region=aws_region)
Exemplo n.º 3
0
def main():
    """Entry point of the training worker.

    Parses the command line arguments, copies the model metadata to the
    training S3 prefix, builds the graph manager (from a custom preset when
    one can be downloaded from S3, otherwise from the default configuration
    and the SageMaker hyperparameters), persists hyperparameters and IP
    config to S3, optionally loads a pre-trained checkpoint and finally calls
    ``training_worker``.
    """
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-pk',
                        '--preset_s3_key',
                        help="(string) Name of a preset to download from S3",
                        type=str,
                        required=False)
    parser.add_argument(
        '-ek',
        '--environment_s3_key',
        help="(string) Name of an environment file to download from S3",
        type=str,
        required=False)
    parser.add_argument('--model_metadata_s3_key',
                        help="(string) Model Metadata File S3 Key",
                        type=str,
                        required=False)
    parser.add_argument(
        '-c',
        '--checkpoint_dir',
        help=
        '(string) Path to a folder containing a checkpoint to write the model to.',
        type=str,
        default='./checkpoint')
    parser.add_argument(
        '--pretrained_checkpoint_dir',
        help='(string) Path to a folder for downloading a pre-trained model',
        type=str,
        default=PRETRAINED_MODEL_DIR)
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get(
                            "SAGEMAKER_SHARED_S3_BUCKET_PATH", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--framework',
                        help='(string) tensorflow or mxnet',
                        type=str,
                        default='tensorflow')
    parser.add_argument('--pretrained_s3_bucket',
                        help='(string) S3 bucket for pre-trained model',
                        type=str)
    parser.add_argument('--pretrained_s3_prefix',
                        help='(string) S3 prefix for pre-trained model',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("AWS_REGION", "us-east-1"))

    # Unknown arguments are tolerated and ignored
    args, _ = parser.parse_known_args()

    s3_client = S3Client(region_name=args.aws_region, max_retry_attempts=0)

    # download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata_download = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_info = model_metadata_download.get_model_metadata_info()
    network_type = model_metadata_info[ModelMetadataKeys.NEURAL_NETWORK.value]
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

    # upload model metadata
    model_metadata_upload = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, MODEL_METADATA_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_upload.persist(
        s3_kms_extra_args=utils.get_s3_kms_extra_args())

    shutil.copy2(model_metadata_download.local_path, SM_MODEL_OUTPUT_DIR)

    # Only produced by get_graph_manager on the default (non custom preset)
    # path; stays None when a custom preset supplies the graph manager.
    robomaker_hyperparams_json = None

    success_custom_preset = False
    if args.preset_s3_key:
        preset_local_path = "./markov/presets/preset.py"
        try:
            s3_client.download_file(bucket=args.s3_bucket,
                                    s3_key=args.preset_s3_key,
                                    local_path=preset_local_path)
            success_custom_preset = True
        except botocore.exceptions.ClientError:
            # Best effort: a failed download falls back to the default preset
            pass
        if not success_custom_preset:
            logger.info(
                "Could not download the preset file. Using the default DeepRacer preset."
            )
        else:
            preset_location = "markov.presets.preset:graph_manager"
            graph_manager = short_dynamic_import(preset_location,
                                                 ignore_module_case=True)
            s3_client.upload_file(
                bucket=args.s3_bucket,
                s3_key=os.path.normpath("%s/presets/preset.py" %
                                        args.s3_prefix),
                local_path=preset_local_path,
                s3_kms_extra_args=utils.get_s3_kms_extra_args())
            logger.info("Using preset: %s", args.preset_s3_key)

    if not success_custom_preset:
        params_blob = os.environ.get('SM_TRAINING_ENV', '')
        if params_blob:
            params = json.loads(params_blob)
            sm_hyperparams_dict = params["hyperparameters"]
        else:
            sm_hyperparams_dict = {}

        #! TODO each agent should have own config
        agent_config = {
            'model_metadata': model_metadata_download,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [],
                ConfigParams.VELOCITY_LIST.value: {},
                ConfigParams.STEERING_LIST.value: {},
                ConfigParams.CHANGE_START.value: None,
                ConfigParams.ALT_DIR.value: None,
                ConfigParams.MODEL_METADATA.value: model_metadata_download,
                ConfigParams.REWARD.value: None,
                ConfigParams.AGENT_NAME.value: 'racecar'
            }
        }

        agent_list = list()
        agent_list.append(create_training_agent(agent_config))

        graph_manager, robomaker_hyperparams_json = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=None,
            run_type=str(RunType.TRAINER))

        # Upload hyperparameters to SageMaker shared s3 bucket
        hyperparameters = Hyperparameters(bucket=args.s3_bucket,
                                          s3_key=get_s3_key(
                                              args.s3_prefix,
                                              HYPERPARAMETER_S3_POSTFIX),
                                          region_name=args.aws_region)
        hyperparameters.persist(
            hyperparams_json=robomaker_hyperparams_json,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # Attach sample collector to graph_manager only if sample count > 0
        max_sample_count = int(sm_hyperparams_dict.get("max_sample_count", 0))
        if max_sample_count > 0:
            sample_collector = SampleCollector(
                bucket=args.s3_bucket,
                s3_prefix=args.s3_prefix,
                region_name=args.aws_region,
                max_sample_count=max_sample_count,
                sampling_frequency=int(
                    sm_hyperparams_dict.get("sampling_frequency", 1)))
            graph_manager.sample_collector = sample_collector

    # persist IP config from sagemaker to s3
    ip_config = IpConfig(bucket=args.s3_bucket,
                         s3_prefix=args.s3_prefix,
                         region_name=args.aws_region)
    ip_config.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())

    training_algorithm = model_metadata_download.training_algorithm
    output_head_format = FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[
        training_algorithm]

    # A pre-trained model is used only when both bucket and prefix are set
    use_pretrained_model = args.pretrained_s3_bucket and args.pretrained_s3_prefix
    # Handle backward compatibility
    if use_pretrained_model:
        # checkpoint s3 instance for pretrained model
        # TODO: replace 'agent' for multiagent training
        checkpoint = Checkpoint(bucket=args.pretrained_s3_bucket,
                                s3_prefix=args.pretrained_s3_prefix,
                                region_name=args.aws_region,
                                agent_name='agent',
                                checkpoint_dir=args.pretrained_checkpoint_dir,
                                output_head_format=output_head_format)
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible(
        ):
            checkpoint.rl_coach_checkpoint.make_compatible(
                checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint(
        )
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())
        # add checkpoint into checkpoint_dict
        checkpoint_dict = {'agent': checkpoint}
        # load pretrained model
        ds_params_instance_pretrained = S3BotoDataStoreParameters(
            checkpoint_dict=checkpoint_dict)
        data_store_pretrained = S3BotoDataStore(ds_params_instance_pretrained,
                                                graph_manager, True)
        data_store_pretrained.load_from_store()

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address="localhost",
        redis_port=6379,
        run_type=str(RunType.TRAINER),
        channel=args.s3_prefix,
        network_type=network_type)

    graph_manager.memory_backend_params = memory_backend_params

    # checkpoint s3 instance for training model
    checkpoint = Checkpoint(bucket=args.s3_bucket,
                            s3_prefix=args.s3_prefix,
                            region_name=args.aws_region,
                            agent_name='agent',
                            checkpoint_dir=args.checkpoint_dir,
                            output_head_format=output_head_format)
    checkpoint_dict = {'agent': checkpoint}
    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store_params = ds_params_instance

    graph_manager.data_store = S3BotoDataStore(ds_params_instance,
                                               graph_manager)

    task_parameters = TaskParameters()
    task_parameters.experiment_path = SM_MODEL_OUTPUT_DIR
    task_parameters.checkpoint_save_secs = 20
    if use_pretrained_model:
        task_parameters.checkpoint_restore_path = args.pretrained_checkpoint_dir
    task_parameters.checkpoint_save_dir = args.checkpoint_dir

    # The user batch size / episodes per rollout are only known when the
    # default graph manager produced the hyperparameters json. With a custom
    # preset they are passed as None instead of failing with a NameError.
    # NOTE(review): confirm training_worker accepts None for these two values.
    user_batch_size = None
    user_episode_per_rollout = None
    if robomaker_hyperparams_json is not None:
        robomaker_hyperparams = json.loads(robomaker_hyperparams_json)
        user_batch_size = robomaker_hyperparams["batch_size"]
        user_episode_per_rollout = robomaker_hyperparams[
            "num_episodes_between_training"]

    training_worker(
        graph_manager=graph_manager,
        task_parameters=task_parameters,
        user_batch_size=user_batch_size,
        user_episode_per_rollout=user_episode_per_rollout,
        training_algorithm=training_algorithm)
Exemplo n.º 4
0
def main():
    """Entry point of the rollout worker.

    Parses the command line arguments (defaults come from ROS parameters),
    downloads the reward function, model metadata, IP config and
    hyperparameters, builds the agents and the graph manager, wires up the
    redis memory backend and the S3 data store and finally calls
    ``rollout_worker``.
    """
    screen.set_use_colors(False)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c',
        '--checkpoint_dir',
        help=
        '(string) Path to a folder containing a checkpoint to restore the model from.',
        type=str,
        default='./checkpoint')
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=rospy.get_param("SAGEMAKER_SHARED_S3_BUCKET",
                                                "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default=rospy.get_param("SAGEMAKER_SHARED_S3_PREFIX",
                                                "sagemaker"))
    parser.add_argument(
        '--num_workers',
        help="(int) The number of workers started in this pool",
        type=int,
        default=int(rospy.get_param("NUM_WORKERS", 1)))
    parser.add_argument('--rollout_idx',
                        help="(int) The index of current rollout worker",
                        type=int,
                        default=0)
    parser.add_argument('-r',
                        '--redis_ip',
                        help="(string) IP or host for the redis server",
                        default='localhost',
                        type=str)
    parser.add_argument('-rp',
                        '--redis_port',
                        help="(int) Port of the redis server",
                        default=6379,
                        type=int)
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--reward_file_s3_key',
                        help='(string) Reward File S3 Key',
                        type=str,
                        default=rospy.get_param("REWARD_FILE_S3_KEY", None))
    parser.add_argument('--model_metadata_s3_key',
                        help='(string) Model Metadata File S3 Key',
                        type=str,
                        default=rospy.get_param("MODEL_METADATA_FILE_S3_KEY",
                                                None))
    # For training job, reset is not allowed. penalty_seconds, off_track_penalty, and
    # collision_penalty will all be 0 by default
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 0.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "TRAINING"))
    # type=bool would turn every non-empty command line value (including
    # "False") into True, so the value is converted with the same helper that
    # is already applied to the default.
    # NOTE(review): assumes utils.str2bool handles the strings
    # "True"/"False" - confirm against its implementation.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(
                            rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY",
                                                      0.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY",
                                                      0.0)))

    args = parser.parse_args()

    logger.info("S3 bucket: %s", args.s3_bucket)
    logger.info("S3 prefix: %s", args.s3_prefix)

    # Download and import reward function
    # TODO: replace 'agent' with name of each agent for multi-agent training
    reward_function_file = RewardFunction(
        bucket=args.s3_bucket,
        s3_key=args.reward_file_s3_key,
        region_name=args.aws_region,
        local_path=REWARD_FUCTION_LOCAL_PATH_FORMAT.format('agent'))
    reward_function = reward_function_file.get_reward_function()

    # Instantiate Cameras
    configure_camera(namespaces=['racecar'])

    preset_file_success, _ = download_custom_files_if_present(
        s3_bucket=args.s3_bucket,
        s3_prefix=args.s3_prefix,
        aws_region=args.aws_region)

    # download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_info = model_metadata.get_model_metadata_info()
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

    agent_config = {
        'model_metadata': model_metadata,
        ConfigParams.CAR_CTRL_CONFIG.value: {
            ConfigParams.LINK_NAME_LIST.value:
            LINK_NAMES,
            ConfigParams.VELOCITY_LIST.value:
            VELOCITY_TOPICS,
            ConfigParams.STEERING_LIST.value:
            STEERING_TOPICS,
            ConfigParams.CHANGE_START.value:
            utils.str2bool(rospy.get_param('CHANGE_START_POSITION', True)),
            ConfigParams.ALT_DIR.value:
            utils.str2bool(
                rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
            ConfigParams.MODEL_METADATA.value:
            model_metadata,
            ConfigParams.REWARD.value:
            reward_function,
            ConfigParams.AGENT_NAME.value:
            'racecar',
            ConfigParams.VERSION.value:
            version,
            ConfigParams.NUMBER_OF_RESETS.value:
            args.number_of_resets,
            ConfigParams.PENALTY_SECONDS.value:
            args.penalty_seconds,
            ConfigParams.NUMBER_OF_TRIALS.value:
            None,
            ConfigParams.IS_CONTINUOUS.value:
            args.is_continuous,
            ConfigParams.RACE_TYPE.value:
            args.race_type,
            ConfigParams.COLLISION_PENALTY.value:
            args.collision_penalty,
            ConfigParams.OFF_TRACK_PENALTY.value:
            args.off_track_penalty
        }
    }

    #! TODO each agent should have own s3 bucket
    metrics_key = rospy.get_param('METRICS_S3_OBJECT_KEY')
    # Every worker but the first one gets its rollout index inserted before
    # the file extension of the metrics key.
    if args.num_workers > 1 and args.rollout_idx > 0:
        key_tuple = os.path.splitext(metrics_key)
        metrics_key = "{}_{}{}".format(key_tuple[0], str(args.rollout_idx),
                                       key_tuple[1])
    metrics_s3_config = {
        MetricsS3Keys.METRICS_BUCKET.value:
        rospy.get_param('METRICS_S3_BUCKET'),
        MetricsS3Keys.METRICS_KEY.value: metrics_key,
        MetricsS3Keys.REGION.value: rospy.get_param('AWS_REGION')
    }

    run_phase_subject = RunPhaseSubject()

    agent_list = list()

    #TODO: replace agent for multi agent training
    # checkpoint s3 instance
    # TODO replace agent with agent_0 and so on for multiagent case
    checkpoint = Checkpoint(bucket=args.s3_bucket,
                            s3_prefix=args.s3_prefix,
                            region_name=args.aws_region,
                            agent_name='agent',
                            checkpoint_dir=args.checkpoint_dir)

    agent_list.append(
        create_rollout_agent(
            agent_config,
            TrainingMetrics(
                agent_name='agent',
                s3_dict_metrics=metrics_s3_config,
                deepracer_checkpoint_json=checkpoint.deepracer_checkpoint_json,
                ckpnt_dir=os.path.join(args.checkpoint_dir, 'agent'),
                run_phase_sink=run_phase_subject,
                use_model_picker=(args.rollout_idx == 0)), run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())
    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)

    aws_region = rospy.get_param('AWS_REGION', args.aws_region)
    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    # The mp4 bucket is only looked up for the first rollout worker
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET',
                                    None) if args.rollout_idx == 0 else None
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        if args.num_workers > 1:
            simtrace_s3_object_prefix = os.path.join(simtrace_s3_object_prefix,
                                                     str(args.rollout_idx))
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')

    simtrace_video_s3_writers = []
    #TODO: replace 'agent' with 'agent_0' for multi agent training and
    # mp4_s3_object_prefix, mp4_s3_bucket will be a list, so need to access with index
    if simtrace_s3_bucket:
        simtrace_video_s3_writers.append(
            SimtraceVideo(
                upload_type=SimtraceVideoNames.SIMTRACE_TRAINING.value,
                bucket=simtrace_s3_bucket,
                s3_prefix=simtrace_s3_object_prefix,
                region_name=aws_region,
                local_path=SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format(
                    'agent')))
    if mp4_s3_bucket:
        simtrace_video_s3_writers.extend([
            SimtraceVideo(
                upload_type=SimtraceVideoNames.PIP.value,
                bucket=mp4_s3_bucket,
                s3_prefix=mp4_s3_object_prefix,
                region_name=aws_region,
                local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format('agent')),
            SimtraceVideo(
                upload_type=SimtraceVideoNames.DEGREE45.value,
                bucket=mp4_s3_bucket,
                s3_prefix=mp4_s3_object_prefix,
                region_name=aws_region,
                local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format('agent')),
            SimtraceVideo(
                upload_type=SimtraceVideoNames.TOPVIEW.value,
                bucket=mp4_s3_bucket,
                s3_prefix=mp4_s3_object_prefix,
                region_name=aws_region,
                local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format('agent'))
        ])

    # TODO: replace 'agent' with specific agent name for multi agent training
    ip_config = IpConfig(bucket=args.s3_bucket,
                         s3_prefix=args.s3_prefix,
                         region_name=args.aws_region,
                         local_path=IP_ADDRESS_LOCAL_PATH.format('agent'))
    # The redis address comes from the IP config, not from --redis_ip
    redis_ip = ip_config.get_ip_config()

    # Download hyperparameters from SageMaker shared s3 bucket
    # TODO: replace 'agent' with name of each agent
    hyperparameters = Hyperparameters(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, HYPERPARAMETER_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=HYPERPARAMETER_LOCAL_PATH_FORMAT.format('agent'))
    sm_hyperparams_dict = hyperparameters.get_hyperparameters_dict()

    enable_domain_randomization = utils.str2bool(
        rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))
    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)

    if preset_file_success:
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location,
                                             ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        graph_manager, _ = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=run_phase_subject,
            enable_domain_randomization=enable_domain_randomization,
            pause_physics=pause_physics,
            unpause_physics=unpause_physics)

    # If num_episodes_between_training is smaller than num_workers then cancel worker early.
    episode_steps_per_rollout = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    # Reduce number of workers if allocated more than num_episodes_between_training
    if args.num_workers > episode_steps_per_rollout:
        logger.info(
            "Excess worker allocated. Reducing from {} to {}...".format(
                args.num_workers, episode_steps_per_rollout))
        args.num_workers = episode_steps_per_rollout
    if args.rollout_idx >= episode_steps_per_rollout or args.rollout_idx >= args.num_workers:
        err_msg_format = "Exiting excess worker..."
        err_msg_format += "(rollout_idx[{}] >= num_workers[{}] or num_episodes_between_training[{}])"
        logger.info(
            err_msg_format.format(args.rollout_idx, args.num_workers,
                                  episode_steps_per_rollout))
        # Close down the job
        utils.cancel_simulation_job()

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address=redis_ip,
        redis_port=6379,
        run_type=str(RunType.ROLLOUT_WORKER),
        channel=args.s3_prefix,
        num_workers=args.num_workers,
        rollout_idx=args.rollout_idx)

    graph_manager.memory_backend_params = memory_backend_params

    checkpoint_dict = {'agent': checkpoint}
    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(ds_params_instance,
                                               graph_manager)

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.checkpoint_dir

    rollout_worker(graph_manager=graph_manager,
                   num_workers=args.num_workers,
                   rollout_idx=args.rollout_idx,
                   task_parameters=task_parameters,
                   simtrace_video_s3_writers=simtrace_video_s3_writers,
                   pause_physics=pause_physics,
                   unpause_physics=unpause_physics)
Exemplo n.º 5
0
def main():
    """Main function for evaluation worker.

    Parses the command line arguments (most defaults are read from the ROS
    parameter server), checks that every per-agent S3 argument list has the
    same length, builds one rollout agent per entry in ``--s3_bucket`` (plus
    the obstacles and bot-car agents), wires up the graph manager with an
    S3 backed data store and finally hands control to ``evaluation_worker``.

    Raises:
        GenericRolloutException: if ``--number_of_resets`` is non-zero but
            smaller than ``MIN_RESET_COUNT``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--preset',
                        help="(string) Name of a preset to run \
                             (class name from the 'presets' directory.)",
                        type=str,
                        required=False)
    parser.add_argument('--s3_bucket',
                        help='list(string) S3 bucket',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_BUCKET",
                                                ["gsaur-test"]))
    parser.add_argument('--s3_prefix',
                        help='list(string) S3 prefix',
                        type=str,
                        nargs='+',
                        default=rospy.get_param("MODEL_S3_PREFIX",
                                                ["sagemaker"]))
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--number_of_trials',
                        help='(integer) Number of trials',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_TRIALS", 10)))
    parser.add_argument(
        '-c',
        '--local_model_directory',
        help='(string) Path to a folder containing a checkpoint \
                             to restore the model from.',
        type=str,
        default='./checkpoint')
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 2.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "EVALUATION"))
    # argparse applies `type` to the raw command line string; `type=bool`
    # would turn every non-empty string (including "False") into True, so
    # reuse the converter that is already applied to the default value.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(
                            rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY",
                                                      2.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY",
                                                      5.0)))

    args = parser.parse_args()
    arg_s3_bucket = args.s3_bucket
    arg_s3_prefix = args.s3_prefix
    logger.info("S3 bucket: %s \n S3 prefix: %s", arg_s3_bucket, arg_s3_prefix)

    metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET')
    metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY')

    # Every per-agent parameter is normalised to a list so that it can be
    # indexed by agent_index further down.
    arg_s3_bucket, arg_s3_prefix = utils.force_list(
        arg_s3_bucket), utils.force_list(arg_s3_prefix)
    metrics_s3_buckets = utils.force_list(metrics_s3_buckets)
    metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys)

    validate_list = [
        arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets,
        metrics_s3_object_keys
    ]

    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None)
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket)
        simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix)
        validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix])
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')
        mp4_s3_bucket = utils.force_list(mp4_s3_bucket)
        mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix)
        validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix])

    # All per-agent lists must have the same number of entries. The previous
    # check passed a two-element list (a lambda and validate_list) to all(),
    # which is always truthy, so mismatched arguments were never rejected.
    expected_len = len(validate_list[0])
    if not all(len(param_list) == expected_len
               for param_list in validate_list):
        log_and_exit(
            "Eval worker error: Incorrect arguments passed: {}".format(
                validate_list), SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500)
    if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT:
        raise GenericRolloutException(
            "number of resets is less than {}".format(MIN_RESET_COUNT))

    num_agents = len(arg_s3_bucket)
    is_single_agent = num_agents == 1

    # Instantiate Cameras
    if is_single_agent:
        configure_camera(namespaces=['racecar'])
    else:
        configure_camera(namespaces=[
            'racecar_{}'.format(agent_index)
            for agent_index in range(num_agents)
        ])

    agent_list = []
    s3_bucket_dict = {}
    s3_prefix_dict = {}
    checkpoint_dict = {}
    simtrace_video_s3_writers = []
    start_positions = get_start_positions(num_agents)
    done_condition = utils.str_to_done_condition(
        rospy.get_param("DONE_CONDITION", any))
    park_positions = utils.pos_2d_str_to_list(
        rospy.get_param("PARK_POSITIONS", []))
    # if not pass in park positions for all done condition case, use default
    if not park_positions:
        park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket]
    # Region used by the simtrace/mp4 uploaders; it does not depend on the
    # agent, so it is looked up once instead of once per loop iteration.
    aws_region = rospy.get_param('AWS_REGION', args.aws_region)
    for agent_index, _ in enumerate(arg_s3_bucket):
        # A single agent keeps the un-suffixed names; multiple agents are
        # numbered by their position in the bucket list.
        agent_name = 'agent' if is_single_agent else 'agent_{}'.format(
            agent_index)
        racecar_name = 'racecar' if is_single_agent else 'racecar_{}'.format(
            agent_index)
        s3_bucket_dict[agent_name] = arg_s3_bucket[agent_index]
        s3_prefix_dict[agent_name] = arg_s3_prefix[agent_index]

        # download model metadata
        model_metadata = ModelMetadata(
            bucket=arg_s3_bucket[agent_index],
            s3_key=get_s3_key(arg_s3_prefix[agent_index],
                              MODEL_METADATA_S3_POSTFIX),
            region_name=args.aws_region,
            local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(agent_name))
        model_metadata_info = model_metadata.get_model_metadata_info()
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]

        # checkpoint s3 instance
        checkpoint = Checkpoint(bucket=arg_s3_bucket[agent_index],
                                s3_prefix=arg_s3_prefix[agent_index],
                                region_name=args.aws_region,
                                agent_name=agent_name,
                                checkpoint_dir=args.local_model_directory)
        # make coach checkpoint compatible
        if (version < SIMAPP_VERSION_2
                and not checkpoint.rl_coach_checkpoint.is_compatible()):
            checkpoint.rl_coach_checkpoint.make_compatible(
                checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = (
            checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint())
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

        checkpoint_dict[agent_name] = checkpoint

        agent_config = {
            'model_metadata': model_metadata,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                # Link and topic names are templated on 'racecar'; retarget
                # them to this agent's racecar namespace.
                ConfigParams.LINK_NAME_LIST.value: [
                    link_name.replace('racecar', racecar_name)
                    for link_name in LINK_NAMES
                ],
                ConfigParams.VELOCITY_LIST.value: [
                    velocity_topic.replace('racecar', racecar_name)
                    for velocity_topic in VELOCITY_TOPICS
                ],
                ConfigParams.STEERING_LIST.value: [
                    steering_topic.replace('racecar', racecar_name)
                    for steering_topic in STEERING_TOPICS
                ],
                ConfigParams.CHANGE_START.value:
                utils.str2bool(rospy.get_param('CHANGE_START_POSITION',
                                               False)),
                ConfigParams.ALT_DIR.value:
                utils.str2bool(
                    rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
                ConfigParams.MODEL_METADATA.value:
                model_metadata,
                ConfigParams.REWARD.value:
                reward_function,
                ConfigParams.AGENT_NAME.value:
                racecar_name,
                ConfigParams.VERSION.value:
                version,
                ConfigParams.NUMBER_OF_RESETS.value:
                args.number_of_resets,
                ConfigParams.PENALTY_SECONDS.value:
                args.penalty_seconds,
                ConfigParams.NUMBER_OF_TRIALS.value:
                args.number_of_trials,
                ConfigParams.IS_CONTINUOUS.value:
                args.is_continuous,
                ConfigParams.RACE_TYPE.value:
                args.race_type,
                ConfigParams.COLLISION_PENALTY.value:
                args.collision_penalty,
                ConfigParams.OFF_TRACK_PENALTY.value:
                args.off_track_penalty,
                ConfigParams.START_POSITION.value:
                start_positions[agent_index],
                ConfigParams.DONE_CONDITION.value:
                done_condition
            }
        }

        metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value:
            metrics_s3_buckets[agent_index],
            MetricsS3Keys.METRICS_KEY.value:
            metrics_s3_object_keys[agent_index],
            # Replaced rospy.get_param('AWS_REGION') to be equal to the argument being passed
            # or default argument set
            MetricsS3Keys.REGION.value:
            args.aws_region
        }

        if simtrace_s3_bucket:
            simtrace_video_s3_writers.append(
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.SIMTRACE_EVAL.value,
                    bucket=simtrace_s3_bucket[agent_index],
                    s3_prefix=simtrace_s3_object_prefix[agent_index],
                    region_name=aws_region,
                    local_path=SIMTRACE_EVAL_LOCAL_PATH_FORMAT.format(
                        agent_name)))
        if mp4_s3_bucket:
            # One writer per camera view: picture-in-picture, 45 degree and
            # top view.
            simtrace_video_s3_writers.extend([
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.PIP.value,
                    bucket=mp4_s3_bucket[agent_index],
                    s3_prefix=mp4_s3_object_prefix[agent_index],
                    region_name=aws_region,
                    local_path=CAMERA_PIP_MP4_LOCAL_PATH_FORMAT.format(
                        agent_name)),
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.DEGREE45.value,
                    bucket=mp4_s3_bucket[agent_index],
                    s3_prefix=mp4_s3_object_prefix[agent_index],
                    region_name=aws_region,
                    local_path=CAMERA_45DEGREE_LOCAL_PATH_FORMAT.format(
                        agent_name)),
                SimtraceVideo(
                    upload_type=SimtraceVideoNames.TOPVIEW.value,
                    bucket=mp4_s3_bucket[agent_index],
                    s3_prefix=mp4_s3_object_prefix[agent_index],
                    region_name=aws_region,
                    local_path=CAMERA_TOPVIEW_LOCAL_PATH_FORMAT.format(
                        agent_name))
            ])

        # NOTE(review): a new RunPhaseSubject is created per agent, but only
        # the one from the last iteration is used after the loop - confirm
        # this is intended.
        run_phase_subject = RunPhaseSubject()
        agent_list.append(
            create_rollout_agent(
                agent_config,
                EvalMetrics(agent_name, metrics_s3_config, args.is_continuous),
                run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)
    enable_domain_randomization = utils.str2bool(
        rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))

    # No hyperparameters are overridden for evaluation.
    sm_hyperparams_dict = {}

    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)

    graph_manager, _ = get_graph_manager(
        hp_dict=sm_hyperparams_dict,
        agent_list=agent_list,
        run_phase_subject=run_phase_subject,
        enable_domain_randomization=enable_domain_randomization,
        done_condition=done_condition,
        pause_physics=pause_physics,
        unpause_physics=unpause_physics)

    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(params=ds_params_instance,
                                               graph_manager=graph_manager,
                                               ignore_lock=True)
    graph_manager.env_params.seed = 0

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.local_model_directory

    evaluation_worker(graph_manager=graph_manager,
                      number_of_trials=args.number_of_trials,
                      task_parameters=task_parameters,
                      simtrace_video_s3_writers=simtrace_video_s3_writers,
                      is_continuous=args.is_continuous,
                      park_positions=park_positions,
                      race_type=args.race_type,
                      pause_physics=pause_physics,
                      unpause_physics=unpause_physics)