Example #1
    def __init__(self,
                 bucket,
                 s3_prefix,
                 region_name='us-east-1',
                 local_path="./custom_files/agent/ip.json",
                 max_retry_attempts=5,
                 backoff_time_sec=1.0):
        '''ip upload, download, and parse

        Args:
            bucket (str): s3 bucket
            s3_prefix (str): s3 prefix
            region_name (str): s3 region name
            local_path (str): ip address json file local path
            max_retry_attempts (int): maximum retry attempts
            backoff_time_sec (float): retry backoff time in seconds

        '''
        if not s3_prefix or not bucket:
            log_and_exit(
                "Ip config S3 prefix or bucket not available for S3. \
                         bucket: {}, prefix: {}".format(bucket, s3_prefix),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500)
        self._bucket = bucket
        self._s3_ip_done_key = os.path.normpath(
            os.path.join(s3_prefix, IP_DONE_POSTFIX))
        self._s3_ip_address_key = os.path.normpath(
            os.path.join(s3_prefix, IP_ADDRESS_POSTFIX))
        self._local_path = local_path
        self._s3_client = S3Client(region_name, max_retry_attempts,
                                   backoff_time_sec)
        self._ip_file = None
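
A minimal usage sketch for the constructor above, assuming it belongs to the IpConfig class that Example #13 instantiates; the bucket and prefix values are placeholders, and the persist() call mirrors Example #13.

# Sketch only: assumes the __init__ above is IpConfig, as used in Example #13.
ip_config = IpConfig(bucket="my-bucket",           # placeholder bucket name
                     s3_prefix="sagemaker/run-1",  # placeholder prefix
                     region_name="us-east-1")
# Upload the SageMaker IP address/done files, passing the same KMS extra args
# that Example #13 supplies via utils.get_s3_kms_extra_args().
ip_config.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())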
Example #2
    def __init__(self, bucket, s3_prefix, region_name,
                 max_sample_count=None, sampling_frequency=None,
                 max_retry_attempts=5, backoff_time_sec=1.0):
        '''Sample Collector class to collect samples and persist them to S3.

        Args:
            bucket (str): S3 bucket string
            s3_prefix (str): S3 prefix string
            region_name (str): S3 region name
            max_sample_count (int): max sample count
            sampling_frequency (int): sampling frequency
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry
        '''
        self.max_sample_count = max_sample_count or 0
        self.sampling_frequency = sampling_frequency or 1
        if self.sampling_frequency < 1:
            err_msg = "sampling_frequency must be larger or equal to 1. (Given: {})".format(self.sampling_frequency)
            raise GenericTrainerException(err_msg)
        self.s3_prefix = s3_prefix

        self._cur_sample_count = 0
        self._cur_frequency = 0
        self._bucket = bucket
        self._s3_client = S3Client(region_name,
                                   max_retry_attempts,
                                   backoff_time_sec)
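
A short construction sketch, assuming the __init__ above belongs to the SampleCollector class that Example #13 instantiates; the bucket and prefix values are placeholders. Note that None (or 0) for sampling_frequency falls back to 1, while a negative value raises GenericTrainerException.

# Sketch only: assumes the __init__ above is SampleCollector (see Example #13).
sample_collector = SampleCollector(bucket="my-bucket",        # placeholder values
                                   s3_prefix="sagemaker/run-1",
                                   region_name="us-east-1",
                                   max_sample_count=100,
                                   sampling_frequency=5)       # persist every 5th sample
# Example #13 then attaches the collector to the graph manager:
# graph_manager.sample_collector = sample_collector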
Example #3
    def __init__(self,
                 upload_type,
                 bucket,
                 s3_prefix,
                 region_name="us-east-1",
                 local_path="./custom_files/iteration_data/\
                 agent/file",
                 s3_endpoint_url=None,
                 max_retry_attempts=5,
                 backoff_time_sec=1.0):
        '''This class is for all s3 simtrace and video upload

        Args:
            upload_type (str): upload simtrace or video type
            bucket (str): S3 bucket string
            s3_prefix (str): S3 prefix string
            region_name (str): S3 region name
            s3_endpoint_url (str): S3 endpoint URL
            local_path (str): file local path
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry

        '''
        self._upload_type = upload_type
        self._bucket = bucket
        self._s3_key = os.path.normpath(
            os.path.join(s3_prefix,
                         SIMTRACE_VIDEO_POSTFIX_DICT[self._upload_type]))
        self._local_path = local_path
        self._upload_num = 0
        self._s3_client = S3Client(region_name, s3_endpoint_url,
                                   max_retry_attempts, backoff_time_sec)
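
A hedged sketch of how the constructor above derives its S3 key: s3_prefix is joined with SIMTRACE_VIDEO_POSTFIX_DICT[upload_type], so upload_type must be a key of that dict. The class name SimtraceVideo and the "simtrace" key are guesses, since the snippet only shows __init__.

# Sketch only: SimtraceVideo is a hypothetical name for the class above, and
# "simtrace" is a hypothetical key of SIMTRACE_VIDEO_POSTFIX_DICT.
writer = SimtraceVideo(upload_type="simtrace",
                       bucket="my-bucket",            # placeholder values
                       s3_prefix="sagemaker/run-1",
                       region_name="us-east-1",
                       local_path="./custom_files/iteration_data/agent/file")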
Example #4
    def __init__(self, bucket, s3_prefix, region_name='us-east-1',
                 local_dir='./checkpoint/agent',
                 max_retry_attempts=5, backoff_time_sec=1.0,
                 output_head_format=FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[
                                    TrainingAlgorithm.CLIPPED_PPO.value]):
        '''This class is for tensorflow model upload and download

        Args:
            bucket (str): S3 bucket string
            s3_prefix (str): S3 prefix string
            region_name (str): S3 region name
            local_dir (str): local file directory
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry
            output_head_format (str): output head format for the specific algorithm and action space
                                      which will be used to store the frozen graph
        '''
        if not bucket or not s3_prefix:
            log_and_exit("checkpoint S3 prefix or bucket not available for S3. \
                         bucket: {}, prefix {}"
                         .format(bucket, s3_prefix),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        self._bucket = bucket
        self._local_dir = os.path.normpath(
            CHECKPOINT_LOCAL_DIR_FORMAT.format(local_dir))
        self._s3_key_dir = os.path.normpath(os.path.join(s3_prefix,
                                                         CHECKPOINT_POSTFIX_DIR))
        self._delete_queue = queue.Queue()
        self._s3_client = S3Client(region_name,
                                   max_retry_attempts,
                                   backoff_time_sec)
        self.output_head_format = output_head_format
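
The output_head_format default keys into FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING by training algorithm, and Example #13 performs the same lookup before wiring up its checkpoint objects. A hedged sketch follows; the class name TensorflowModel is a guess for the class whose __init__ is shown above.

# Sketch only: TensorflowModel is a hypothetical name for the class above.
# Pick the frozen-graph output head for the algorithm in use, as Example #13 does.
output_head_format = FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[
    TrainingAlgorithm.CLIPPED_PPO.value]
model = TensorflowModel(bucket="my-bucket",          # placeholder values
                        s3_prefix="sagemaker/run-1",
                        region_name="us-east-1",
                        local_dir="./checkpoint/agent",
                        output_head_format=output_head_format)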
Example #5
    def __init__(self,
                 bucket,
                 s3_key,
                 region_name='us-east-1',
                 s3_endpoint_url=None,
                 max_retry_attempts=5,
                 backoff_time_sec=1.0):
        '''metrics upload

        Args:
            bucket (str): s3 bucket
            s3_key (str): s3 key
            region_name (str): s3 region name
            s3_endpoint_url (str): s3 endpoint url
            max_retry_attempts (int): maximum retry attempts
            backoff_time_sec (float): retry backoff time in seconds
        '''
        if not s3_key or not bucket:
            log_and_exit(
                "Metrics S3 key or bucket not available for S3. \
                         bucket: {}, key: {}".format(bucket, s3_key),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500)
        self._bucket = bucket
        self._s3_key = s3_key
        self._s3_client = S3Client(region_name, s3_endpoint_url,
                                   max_retry_attempts, backoff_time_sec)
Example #6
    def __init__(self, bucket, s3_key, region_name="us-east-1",
                 local_path="./custom_files/agent/customer_reward_function.py",
                 max_retry_attempts=5, backoff_time_sec=1.0):
        '''reward function upload, download, and parse

        Args:
            bucket (str): S3 bucket string
            s3_key (str): S3 key string
            region_name (str): S3 region name
            local_path (str): file local path
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry
        '''

        # check s3 key and bucket exist for reward function
        if not s3_key or not bucket:
            log_and_exit("Reward function code S3 key or bucket not available for S3. \
                         bucket: {}, key: {}".format(bucket, s3_key),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        self._bucket = bucket
        # Strip the s3://<bucket> from the URI, if s3_key is passed in as a URI
        self._s3_key = s3_key.replace('s3://{}/'.format(self._bucket), '')
        self._local_path_processed = local_path
        # if _local_path_processed is test.py then _local_path_preprocessed is test_preprocessed.py
        self._local_path_preprocessed = ("_preprocessed.py").join(local_path.split(".py"))
        # if _local_path_processed is ./custom_files/agent/customer_reward_function.py,
        # then the import path should be custom_files.agent.customer_reward_function, obtained by
        # removing ".py", removing "./", and replacing "/" with "."
        self._import_path = local_path.replace(".py", "").replace("./", "").replace("/", ".")
        self._reward_function = None
        self._s3_client = S3Client(region_name,
                                   max_retry_attempts,
                                   backoff_time_sec)
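
A worked illustration of the path handling described by the comments above, using plain string operations on the default local_path.

# With the default local_path the constructor derives:
local_path = "./custom_files/agent/customer_reward_function.py"
preprocessed = ("_preprocessed.py").join(local_path.split(".py"))
# -> "./custom_files/agent/customer_reward_function_preprocessed.py"
import_path = local_path.replace(".py", "").replace("./", "").replace("/", ".")
# -> "custom_files.agent.customer_reward_function"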
Example #7
    def __init__(
        self,
        agent_type,
        bucket,
        s3_key,
        region_name="us-east-1",
        local_path="params.yaml",
        max_retry_attempts=5,
        backoff_time_sec=1.0,
    ):
        """yaml upload, download, and parse

        Args:
            agent_type (str): rollout for training, evaluation for eval
            bucket (str): S3 bucket string
            s3_key (str): S3 key string
            region_name (str): S3 region name
            local_path (str): file local path
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry

        """
        if not bucket or not s3_key:
            log_and_exit(
                "yaml file S3 key or bucket not available for S3. \
                         bucket: {}, key: {}".format(
                    bucket, s3_key
                ),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500,
            )
        self._bucket = bucket
        # Strip the s3://<bucket> from the URI, if s3_key is passed in as a URI
        self._s3_key = s3_key.replace("s3://{}/".format(self._bucket), "")
        self._local_path = local_path
        self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
        self._agent_type = agent_type
        if self._agent_type == AgentType.ROLLOUT.value:
            self._model_s3_bucket_yaml_key = YamlKey.SAGEMAKER_SHARED_S3_BUCKET_YAML_KEY.value
            self._model_s3_prefix_yaml_key = YamlKey.SAGEMAKER_SHARED_S3_PREFIX_YAML_KEY.value
            self._mandatory_yaml_key = TRAINING_MANDATORY_YAML_KEY
        elif self._agent_type == AgentType.EVALUATION.value:
            self._model_s3_bucket_yaml_key = YamlKey.MODEL_S3_BUCKET_YAML_KEY.value
            self._model_s3_prefix_yaml_key = YamlKey.MODEL_S3_PREFIX_YAML_KEY.value
            self._mandatory_yaml_key = EVAL_MANDATORY_YAML_KEY
        elif self._agent_type == AgentType.VIRTUAL_EVENT.value:
            self._mandatory_yaml_key = VIRUTAL_EVENT_MANDATORY_YAML_KEY
        else:
            log_and_exit(
                "Unknown agent type in launch file: {}".format(self._agent_type),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500,
            )
        self._yaml_values = None
        self._is_multicar = False
        self._is_f1 = False
        self._model_s3_buckets = list()
        self._model_metadata_s3_keys = list()
        self._body_shell_types = list()
        self._kinesis_webrtc_signaling_channel_name = None
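
A hedged sketch of constructing the yaml helper above for an evaluation job; the class name YamlConfig is a guess (only __init__ is shown), and the bucket/key values are placeholders. The AgentType values are the ones the constructor itself branches on.

# Sketch only: YamlConfig is a hypothetical name for the class above.
yaml_config = YamlConfig(agent_type=AgentType.EVALUATION.value,
                         bucket="my-bucket",                        # placeholder values
                         s3_key="sagemaker/run-1/eval_params.yaml",
                         region_name="us-east-1",
                         local_path="params.yaml")
# For EVALUATION the mandatory keys come from EVAL_MANDATORY_YAML_KEY; an unknown
# agent_type makes the constructor call log_and_exit with a 500 error code.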
Example #8
    def __init__(self,
                 bucket,
                 s3_key,
                 region_name="us-east-1",
                 s3_endpoint_url=None,
                 local_path="./custom_files/agent/model_metadata.json",
                 max_retry_attempts=5,
                 backoff_time_sec=1.0):
        '''Model metadata upload, download, and parse

        Args:
            bucket (str): S3 bucket string
            s3_key (str): S3 key string
            region_name (str): S3 region name
            s3_endpoint_url (str): S3 endpoint URL
            local_path (str): file local path
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry

        '''
        # check s3 key and s3 bucket exist
        if not bucket or not s3_key:
            log_and_exit(
                "model_metadata S3 key or bucket not available for S3. \
                         bucket: {}, key {}".format(bucket, s3_key),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500)
        self._bucket = bucket
        # Strip the s3://<bucket> from the URI, if s3_key is passed in as a URI
        self._s3_key = s3_key.replace('s3://{}/'.format(self._bucket), '')
        self._local_path = local_path
        self._local_dir = os.path.dirname(self._local_path)
        self._model_metadata = None
        self._s3_client = S3Client(region_name, s3_endpoint_url,
                                   max_retry_attempts, backoff_time_sec)
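
This constructor matches the ModelMetadata usage in Example #13; the sketch below repeats that usage with placeholder bucket and key values.

# Mirrors Example #13 (placeholder bucket/key values).
model_metadata = ModelMetadata(
    bucket="my-bucket",
    s3_key="sagemaker/run-1/model/model_metadata.json",
    region_name="us-east-1",
    local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
info = model_metadata.get_model_metadata_info()
network_type = info[ModelMetadataKeys.NEURAL_NETWORK.value]
version = info[ModelMetadataKeys.VERSION.value]
# Re-upload the file with the KMS extra args, as Example #13 does.
model_metadata.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())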
Example #9
    def __init__(self, syncfile_type, bucket, s3_prefix, region_name="us-east-1",
                 s3_endpoint_url=None,
                 local_dir='./checkpoint',
                 max_retry_attempts=5, backoff_time_sec=1.0):
        '''This class is for rl coach sync file: .finished, .lock, and .ready

        Args:
            syncfile_type (str): sync file type
            bucket (str): S3 bucket string
            s3_prefix (str): S3 prefix string
            region_name (str): S3 region name
            s3_endpoint_url (str): S3 endpoint URL
            local_dir (str): local file directory
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry
        '''
        if not bucket or not s3_prefix:
            log_and_exit("checkpoint S3 prefix or bucket not available for S3. \
                         bucket: {}, prefix {}"
                         .format(bucket, s3_prefix),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        self._syncfile_type = syncfile_type
        self._bucket = bucket
        # rl coach sync file s3 key
        self._s3_key = os.path.normpath(os.path.join(
            s3_prefix,
            SYNC_FILES_POSTFIX_DICT[syncfile_type]))
        # rl coach sync file local path
        self._local_path = os.path.normpath(
            SYNC_FILES_LOCAL_PATH_FORMAT_DICT[syncfile_type].format(local_dir))
        self._s3_client = S3Client(region_name,
                                   s3_endpoint_url,
                                   max_retry_attempts,
                                   backoff_time_sec)
Example #10
    def __init__(
        self,
        bucket,
        s3_prefix,
        region_name="us-east-1",
        local_dir="./checkpoint/agent",
        max_retry_attempts=5,
        backoff_time_sec=1.0,
        log_and_cont: bool = False,
    ):
        """This class is for RL coach checkpoint file

        Args:
            bucket (str): S3 bucket string.
            s3_prefix (str): S3 prefix string.
            region_name (str): S3 region name.
                               Defaults to 'us-east-1'.
            local_dir (str, optional): Local file directory.
                                       Defaults to './checkpoint/agent'.
            max_retry_attempts (int, optional): Maximum number of retry attempts for S3 download/upload.
                                                Defaults to 5.
            backoff_time_sec (float, optional): Backoff second between each retry.
                                                Defaults to 1.0.
            log_and_cont (bool, optional): Log the error and continue with the flow.
                                           Defaults to False.
        """
        if not bucket or not s3_prefix:
            log_and_exit(
                "checkpoint S3 prefix or bucket not available for S3. \
                         bucket: {}, prefix {}".format(bucket, s3_prefix),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500,
            )
        self._bucket = bucket
        # coach checkpoint s3 key
        self._s3_key = os.path.normpath(
            os.path.join(s3_prefix, COACH_CHECKPOINT_POSTFIX))
        # coach checkpoint local path
        self._local_path = os.path.normpath(
            COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # coach checkpoint local temp path
        self._temp_local_path = os.path.normpath(
            TEMP_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # old coach checkpoint s3 key to handle backward compatibility
        self._old_s3_key = os.path.normpath(
            os.path.join(s3_prefix, OLD_COACH_CHECKPOINT_POSTFIX))
        # old coach checkpoint local path to handle backward compatibility
        self._old_local_path = os.path.normpath(
            OLD_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # coach checkpoint state file from rl coach
        self._coach_checkpoint_state_file = CheckpointStateFile(
            os.path.dirname(self._local_path))
        self._s3_client = S3Client(region_name, max_retry_attempts,
                                   backoff_time_sec, log_and_cont)
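
Example #13 reaches this object through a Checkpoint wrapper as checkpoint.rl_coach_checkpoint; the calls below are the ones it makes there (version, model_checkpoint_name, and checkpoint all come from Example #13).

# As used in Example #13 via checkpoint.rl_coach_checkpoint.
if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
    # rewrite the old-format coach checkpoint file for backward compatibility
    checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
# point the .coach_checkpoint file at the selected model checkpoint
checkpoint.rl_coach_checkpoint.update(
    model_checkpoint_name=model_checkpoint_name,
    s3_kms_extra_args=utils.get_s3_kms_extra_args())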
Example #11
    def __init__(
        self,
        bucket,
        s3_prefix,
        region_name="us-east-1",
        local_dir=".checkpoint/agent",
        max_retry_attempts=0,
        backoff_time_sec=1.0,
        log_and_cont: bool = False,
    ):
        """This class is for deepracer checkpoint json file upload and download

        Args:
            bucket (str): S3 bucket string.
            s3_prefix (str): S3 prefix string.
            region_name (str): S3 region name.
                               Defaults to 'us-east-1'.
            local_dir (str, optional): Local file directory.
                                       Defaults to '.checkpoint/agent'.
            max_retry_attempts (int, optional): Maximum number of retry attempts for S3 download/upload.
                                                Defaults to 0.
            backoff_time_sec (float, optional): Backoff second between each retry.
                                                Defaults to 1.0.
            log_and_cont (bool, optional): Log the error and continue with the flow.
                                           Defaults to False.
        """
        if not bucket or not s3_prefix:
            log_and_exit(
                "checkpoint S3 prefix or bucket not available for S3. \
                         bucket: {}, prefix {}".format(bucket, s3_prefix),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500,
            )
        self._bucket = bucket
        # deepracer checkpoint json s3 key
        self._s3_key = os.path.normpath(
            os.path.join(s3_prefix, DEEPRACER_CHECKPOINT_KEY_POSTFIX))
        # deepracer checkpoint json local path
        self._local_path = os.path.normpath(
            DEEPRACER_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        self._s3_client = S3Client(region_name,
                                   max_retry_attempts,
                                   backoff_time_sec,
                                   log_and_cont=log_and_cont)
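
Example #13 reads the best checkpoint name through this object (exposed as checkpoint.deepracer_checkpoint_json on its Checkpoint wrapper); a one-line sketch of that call:

# As used in Example #13 via the Checkpoint wrapper.
model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()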
Example #12
def download_custom_files_if_present(s3_bucket, s3_prefix, aws_region,
                                     s3_endpoint_url):
    '''download custom environment and preset files

    Args:
        s3_bucket (str): s3 bucket string
        s3_prefix (str): s3 prefix string
        aws_region (str): aws region string
        s3_endpoint_url (str): s3 endpoint url

    Returns:
        tuple (bool, bool): tuple of bools indicating whether the preset and
        environment files were downloaded successfully
    '''
    success_environment_download, success_preset_download = False, False
    try:
        s3_client = S3Client(region_name=aws_region,
                             s3_endpoint_url=s3_endpoint_url,
                             max_retry_attempts=0)
        environment_file_s3_key = os.path.normpath(
            s3_prefix + "/environments/deepracer_racetrack_env.py")
        environment_local_path = os.path.join(CUSTOM_FILES_PATH,
                                              "deepracer_racetrack_env.py")
        s3_client.download_file(bucket=s3_bucket,
                                s3_key=environment_file_s3_key,
                                local_path=environment_local_path)
        success_environment_download = True
    except botocore.exceptions.ClientError:
        pass
    try:
        preset_file_s3_key = os.path.normpath(s3_prefix + "/presets/preset.py")
        preset_local_path = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        s3_client.download_file(bucket=s3_bucket,
                                s3_key=preset_file_s3_key,
                                local_path=preset_local_path)
        success_preset_download = True
    except botocore.exceptions.ClientError:
        pass
    return success_preset_download, success_environment_download
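
A call sketch with placeholder arguments; the tuple unpacking follows the function's own return statement, which yields the preset flag first.

# Placeholder values; returns (success_preset_download, success_environment_download).
preset_ok, env_ok = download_custom_files_if_present(
    s3_bucket="my-bucket",
    s3_prefix="sagemaker/run-1",
    aws_region="us-east-1",
    s3_endpoint_url=None)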
Example #13
def main():
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-pk',
                        '--preset_s3_key',
                        help="(string) Name of a preset to download from S3",
                        type=str,
                        required=False)
    parser.add_argument(
        '-ek',
        '--environment_s3_key',
        help="(string) Name of an environment file to download from S3",
        type=str,
        required=False)
    parser.add_argument('--model_metadata_s3_key',
                        help="(string) Model Metadata File S3 Key",
                        type=str,
                        required=False)
    parser.add_argument(
        '-c',
        '--checkpoint_dir',
        help=
        '(string) Path to a folder containing a checkpoint to write the model to.',
        type=str,
        default='./checkpoint')
    parser.add_argument(
        '--pretrained_checkpoint_dir',
        help='(string) Path to a folder for downloading a pre-trained model',
        type=str,
        default=PRETRAINED_MODEL_DIR)
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get(
                            "SAGEMAKER_SHARED_S3_BUCKET_PATH", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--framework',
                        help='(string) tensorflow or mxnet',
                        type=str,
                        default='tensorflow')
    parser.add_argument('--pretrained_s3_bucket',
                        help='(string) S3 bucket for pre-trained model',
                        type=str)
    parser.add_argument('--pretrained_s3_prefix',
                        help='(string) S3 prefix for pre-trained model',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("AWS_REGION", "us-east-1"))

    args, _ = parser.parse_known_args()

    s3_client = S3Client(region_name=args.aws_region, max_retry_attempts=0)

    # download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata_download = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_info = model_metadata_download.get_model_metadata_info()
    network_type = model_metadata_info[ModelMetadataKeys.NEURAL_NETWORK.value]
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

    # upload model metadata
    model_metadata_upload = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, MODEL_METADATA_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_upload.persist(
        s3_kms_extra_args=utils.get_s3_kms_extra_args())

    shutil.copy2(model_metadata_download.local_path, SM_MODEL_OUTPUT_DIR)

    success_custom_preset = False
    if args.preset_s3_key:
        preset_local_path = "./markov/presets/preset.py"
        try:
            s3_client.download_file(bucket=args.s3_bucket,
                                    s3_key=args.preset_s3_key,
                                    local_path=preset_local_path)
            success_custom_preset = True
        except botocore.exceptions.ClientError:
            pass
        if not success_custom_preset:
            logger.info(
                "Could not download the preset file. Using the default DeepRacer preset."
            )
        else:
            preset_location = "markov.presets.preset:graph_manager"
            graph_manager = short_dynamic_import(preset_location,
                                                 ignore_module_case=True)
            s3_client.upload_file(
                bucket=args.s3_bucket,
                s3_key=os.path.normpath("%s/presets/preset.py" %
                                        args.s3_prefix),
                local_path=preset_local_path,
                s3_kms_extra_args=utils.get_s3_kms_extra_args())
            if success_custom_preset:
                logger.info("Using preset: %s" % args.preset_s3_key)

    if not success_custom_preset:
        params_blob = os.environ.get('SM_TRAINING_ENV', '')
        if params_blob:
            params = json.loads(params_blob)
            sm_hyperparams_dict = params["hyperparameters"]
        else:
            sm_hyperparams_dict = {}

        #! TODO each agent should have own config
        agent_config = {
            'model_metadata': model_metadata_download,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [],
                ConfigParams.VELOCITY_LIST.value: {},
                ConfigParams.STEERING_LIST.value: {},
                ConfigParams.CHANGE_START.value: None,
                ConfigParams.ALT_DIR.value: None,
                ConfigParams.MODEL_METADATA.value: model_metadata_download,
                ConfigParams.REWARD.value: None,
                ConfigParams.AGENT_NAME.value: 'racecar'
            }
        }

        agent_list = list()
        agent_list.append(create_training_agent(agent_config))

        graph_manager, robomaker_hyperparams_json = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=None,
            run_type=str(RunType.TRAINER))

        # Upload hyperparameters to SageMaker shared s3 bucket
        hyperparameters = Hyperparameters(bucket=args.s3_bucket,
                                          s3_key=get_s3_key(
                                              args.s3_prefix,
                                              HYPERPARAMETER_S3_POSTFIX),
                                          region_name=args.aws_region)
        hyperparameters.persist(
            hyperparams_json=robomaker_hyperparams_json,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # Attach sample collector to graph_manager only if sample count > 0
        max_sample_count = int(sm_hyperparams_dict.get("max_sample_count", 0))
        if max_sample_count > 0:
            sample_collector = SampleCollector(
                bucket=args.s3_bucket,
                s3_prefix=args.s3_prefix,
                region_name=args.aws_region,
                max_sample_count=max_sample_count,
                sampling_frequency=int(
                    sm_hyperparams_dict.get("sampling_frequency", 1)))
            graph_manager.sample_collector = sample_collector

    # persist IP config from sagemaker to s3
    ip_config = IpConfig(bucket=args.s3_bucket,
                         s3_prefix=args.s3_prefix,
                         region_name=args.aws_region)
    ip_config.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())

    training_algorithm = model_metadata_download.training_algorithm
    output_head_format = FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[
        training_algorithm]

    use_pretrained_model = args.pretrained_s3_bucket and args.pretrained_s3_prefix
    # Handle backward compatibility
    if use_pretrained_model:
        # checkpoint s3 instance for pretrained model
        # TODO: replace 'agent' for multiagent training
        checkpoint = Checkpoint(bucket=args.pretrained_s3_bucket,
                                s3_prefix=args.pretrained_s3_prefix,
                                region_name=args.aws_region,
                                agent_name='agent',
                                checkpoint_dir=args.pretrained_checkpoint_dir,
                                output_head_format=output_head_format)
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible(
        ):
            checkpoint.rl_coach_checkpoint.make_compatible(
                checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint(
        )
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())
        # add checkpoint into checkpoint_dict
        checkpoint_dict = {'agent': checkpoint}
        # load pretrained model
        ds_params_instance_pretrained = S3BotoDataStoreParameters(
            checkpoint_dict=checkpoint_dict)
        data_store_pretrained = S3BotoDataStore(ds_params_instance_pretrained,
                                                graph_manager, True)
        data_store_pretrained.load_from_store()

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address="localhost",
        redis_port=6379,
        run_type=str(RunType.TRAINER),
        channel=args.s3_prefix,
        network_type=network_type)

    graph_manager.memory_backend_params = memory_backend_params

    # checkpoint s3 instance for training model
    checkpoint = Checkpoint(bucket=args.s3_bucket,
                            s3_prefix=args.s3_prefix,
                            region_name=args.aws_region,
                            agent_name='agent',
                            checkpoint_dir=args.checkpoint_dir,
                            output_head_format=output_head_format)
    checkpoint_dict = {'agent': checkpoint}
    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store_params = ds_params_instance

    graph_manager.data_store = S3BotoDataStore(ds_params_instance,
                                               graph_manager)

    task_parameters = TaskParameters()
    task_parameters.experiment_path = SM_MODEL_OUTPUT_DIR
    task_parameters.checkpoint_save_secs = 20
    if use_pretrained_model:
        task_parameters.checkpoint_restore_path = args.pretrained_checkpoint_dir
    task_parameters.checkpoint_save_dir = args.checkpoint_dir

    training_worker(
        graph_manager=graph_manager,
        task_parameters=task_parameters,
        user_batch_size=json.loads(robomaker_hyperparams_json)["batch_size"],
        user_episode_per_rollout=json.loads(
            robomaker_hyperparams_json)["num_episodes_between_training"],
        training_algorithm=training_algorithm)
Example #14
    def __init__(self,
                 queue_url,
                 aws_region='us-east-1',
                 race_duration=180,
                 number_of_trials=3,
                 number_of_resets=10000,
                 penalty_seconds=2.0,
                 off_track_penalty=2.0,
                 collision_penalty=5.0,
                 is_continuous=False,
                 race_type="TIME_TRIAL"):
        # constructor arguments
        self._model_updater = ModelUpdater.get_instance()
        self._deepracer_path = rospkg.RosPack().get_path(
            DeepRacerPackages.DEEPRACER_SIMULATION_ENVIRONMENT)
        body_shell_path = os.path.join(self._deepracer_path, "meshes", "f1")
        self._valid_body_shells = \
            set(".".join(f.split(".")[:-1]) for f in os.listdir(body_shell_path) if os.path.isfile(
                os.path.join(body_shell_path, f)))
        self._valid_body_shells.add(const.BodyShellType.DEFAULT.value)
        self._valid_car_colors = set(e.value for e in const.CarColorType
                                     if "f1" not in e.value)
        self._num_sectors = int(rospy.get_param("NUM_SECTORS", "3"))
        self._queue_url = queue_url
        self._region = aws_region
        self._number_of_trials = number_of_trials
        self._number_of_resets = number_of_resets
        self._penalty_seconds = penalty_seconds
        self._off_track_penalty = off_track_penalty
        self._collision_penalty = collision_penalty
        self._is_continuous = is_continuous
        self._race_type = race_type
        self._is_save_simtrace_enabled = False
        self._is_save_mp4_enabled = False
        self._is_event_end = False
        self._done_condition = any
        self._race_duration = race_duration
        self._enable_domain_randomization = False

        # sqs client
        # The boto client errors out after polling for 1 hour.
        self._sqs_client = SQSClient(queue_url=self._queue_url,
                                     region_name=self._region,
                                     max_num_of_msg=MAX_NUM_OF_SQS_MESSAGE,
                                     wait_time_sec=SQS_WAIT_TIME_SEC,
                                     session=refreshed_session(self._region))
        self._s3_client = S3Client(region_name=self._region)
        # tracking current state information
        self._track_data = TrackData.get_instance()
        self._start_lane = self._track_data.center_line
        # keep track of the racer specific info, e.g. s3 locations, alias, car color etc.
        self._current_racer = None
        # keep track of the current race car we are using. It is always "racecar".
        car_model_state = ModelState()
        car_model_state.model_name = "racecar"
        self._current_car_model_state = car_model_state
        self._last_body_shell_type = None
        self._last_sensors = None
        self._racecar_model = AgentModel()
        # keep track of the current control agent we are using
        self._current_agent = None
        # keep track of the current control graph manager
        self._current_graph_manager = None
        # Keep track of previous model's name
        self._prev_model_name = None
        self._hide_position_idx = 0
        self._hide_positions = get_hide_positions(race_car_num=1)
        self._run_phase_subject = RunPhaseSubject()
        self._simtrace_video_s3_writers = []

        self._local_model_directory = './checkpoint'

        # virtual event only have single agent, so set agent_name to "agent"
        self._agent_name = "agent"

        # camera manager
        self._camera_manager = CameraManager.get_instance()

        # setting up virtual event top and follow camera in CameraManager
        # virtual event configure camera does not need to wait for car to spawn because
        # follow car camera is not tracking any car initially
        self._main_cameras, self._sub_camera = configure_camera(
            namespaces=[VIRTUAL_EVENT], is_wait_for_model=False)
        self._spawn_cameras()

        # pop out all cameras after configuration to prevent camera from moving
        self._camera_manager.pop(namespace=VIRTUAL_EVENT)

        dummy_metrics_s3_config = {
            MetricsS3Keys.METRICS_BUCKET.value: "dummy-bucket",
            MetricsS3Keys.METRICS_KEY.value: "dummy-key",
            MetricsS3Keys.REGION.value: self._region
        }

        self._eval_metrics = EvalMetrics(
            agent_name=self._agent_name,
            s3_dict_metrics=dummy_metrics_s3_config,
            is_continuous=self._is_continuous,
            pause_time_before_start=PAUSE_TIME_BEFORE_START)

        # upload a default best sector time of inf for every sector
        # if no best sector time exists in s3 yet

        # use the s3 bucket and prefix for the yaml file stored as environment variables because
        # this is for SimApp use only. For virtual event, no s3 bucket and prefix are passed
        # through the yaml file; everything is passed through sqs. For simplicity, reuse the
        # yaml s3 bucket and prefix environment variables.
        virtual_event_best_sector_time = VirtualEventBestSectorTime(
            bucket=os.environ.get("YAML_S3_BUCKET", ''),
            s3_key=get_s3_key(os.environ.get("YAML_S3_PREFIX", ''),
                              SECTOR_TIME_S3_POSTFIX),
            region_name=os.environ.get("APP_REGION", "us-east-1"),
            local_path=SECTOR_TIME_LOCAL_PATH)
        response = virtual_event_best_sector_time.list()
        # this handles situations such as a robomaker job crash, so the next robomaker job
        # can pick up the best sector time left over from the crashed job
        if "Contents" not in response:
            virtual_event_best_sector_time.persist(
                body=json.dumps({
                    SECTOR_X_FORMAT.format(idx + 1): float("inf")
                    for idx in range(self._num_sectors)
                }),
                s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # ROS service to indicate all the robomaker markov packages are ready for consumption
        signal_robomaker_markov_package_ready()

        PhaseObserver('/agent/training_phase', self._run_phase_subject)

        # setup mp4 services
        self._setup_mp4_services()