def __init__(self, bucket, s3_prefix, region_name='us-east-1',
             local_path="./custom_files/agent/ip.json",
             max_retry_attempts=5, backoff_time_sec=1.0):
    '''IP address file upload, download, and parse

    Args:
        bucket (str): S3 bucket
        s3_prefix (str): S3 prefix
        region_name (str): S3 region name
        local_path (str): IP address json file local path
        max_retry_attempts (int): maximum retry attempts
        backoff_time_sec (float): retry backoff time in seconds
    '''
    if not s3_prefix or not bucket:
        log_and_exit("Ip config S3 prefix or bucket not available for S3. \
                     bucket: {}, prefix: {}".format(bucket, s3_prefix),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    self._bucket = bucket
    self._s3_ip_done_key = os.path.normpath(
        os.path.join(s3_prefix, IP_DONE_POSTFIX))
    self._s3_ip_address_key = os.path.normpath(
        os.path.join(s3_prefix, IP_ADDRESS_POSTFIX))
    self._local_path = local_path
    self._s3_client = S3Client(region_name,
                               max_retry_attempts,
                               backoff_time_sec)
    self._ip_file = None
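# Illustrative usage sketch (not part of the original source). It assumes this
# constructor belongs to the IpConfig class that main() below instantiates; the
# bucket/prefix values are placeholders, and persist() mirrors the call in main().
def _example_ip_config_persist():
    ip_config = IpConfig(bucket="my-bucket",
                         s3_prefix="sagemaker/run-0",
                         region_name="us-east-1")
    ip_config.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())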
def __init__(self, bucket, s3_prefix, region_name,
             max_sample_count=None, sampling_frequency=None,
             max_retry_attempts=5, backoff_time_sec=1.0):
    '''Sample Collector class to collect samples and persist them to S3.

    Args:
        bucket (str): S3 bucket string
        s3_prefix (str): S3 prefix string
        region_name (str): S3 region name
        max_sample_count (int): max sample count
        sampling_frequency (int): sampling frequency
        max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
        backoff_time_sec (float): backoff second between each retry
    '''
    self.max_sample_count = max_sample_count or 0
    self.sampling_frequency = sampling_frequency or 1
    if self.sampling_frequency < 1:
        err_msg = "sampling_frequency must be greater than or equal to 1. (Given: {})".format(
            self.sampling_frequency)
        raise GenericTrainerException(err_msg)
    self.s3_prefix = s3_prefix
    self._cur_sample_count = 0
    self._cur_frequency = 0
    self._bucket = bucket
    self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
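# Illustrative usage sketch (not part of the original source). main() below only
# constructs a SampleCollector when the "max_sample_count" hyperparameter is positive;
# the values here are placeholders.
def _example_sample_collector():
    return SampleCollector(bucket="my-bucket",
                           s3_prefix="sagemaker/run-0",
                           region_name="us-east-1",
                           max_sample_count=100,
                           sampling_frequency=2)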
def __init__(self, upload_type, bucket, s3_prefix, region_name="us-east-1",
             local_path="./custom_files/iteration_data/agent/file",
             s3_endpoint_url=None,
             max_retry_attempts=5, backoff_time_sec=1.0):
    '''This class handles all S3 simtrace and video uploads

    Args:
        upload_type (str): upload type, simtrace or video
        bucket (str): S3 bucket string
        s3_prefix (str): S3 prefix string
        region_name (str): S3 region name
        local_path (str): file local path
        s3_endpoint_url (str): S3 endpoint URL
        max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
        backoff_time_sec (float): backoff second between each retry
    '''
    self._upload_type = upload_type
    self._bucket = bucket
    self._s3_key = os.path.normpath(
        os.path.join(s3_prefix, SIMTRACE_VIDEO_POSTFIX_DICT[self._upload_type]))
    self._local_path = local_path
    self._upload_num = 0
    self._s3_client = S3Client(region_name,
                               s3_endpoint_url,
                               max_retry_attempts,
                               backoff_time_sec)
def __init__(self, bucket, s3_prefix, region_name='us-east-1',
             local_dir='./checkpoint/agent',
             max_retry_attempts=5, backoff_time_sec=1.0,
             output_head_format=FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[
                 TrainingAlgorithm.CLIPPED_PPO.value]):
    '''This class is for tensorflow model upload and download

    Args:
        bucket (str): S3 bucket string
        s3_prefix (str): S3 prefix string
        region_name (str): S3 region name
        local_dir (str): local file directory
        max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
        backoff_time_sec (float): backoff second between each retry
        output_head_format (str): output head format for the specific algorithm and
                                  action space which will be used to store the frozen graph
    '''
    if not bucket or not s3_prefix:
        log_and_exit("checkpoint S3 prefix or bucket not available for S3. \
                     bucket: {}, prefix {}".format(bucket, s3_prefix),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    self._bucket = bucket
    self._local_dir = os.path.normpath(
        CHECKPOINT_LOCAL_DIR_FORMAT.format(local_dir))
    self._s3_key_dir = os.path.normpath(os.path.join(s3_prefix, CHECKPOINT_POSTFIX_DIR))
    self._delete_queue = queue.Queue()
    self._s3_client = S3Client(region_name,
                               max_retry_attempts,
                               backoff_time_sec)
    self.output_head_format = output_head_format
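# Illustrative usage sketch (not part of the original source). The class name is not
# visible in this snippet, so TensorflowModel below is an assumed placeholder; the
# output_head_format lookup mirrors how main() below resolves it from the model
# metadata's training algorithm.
def _example_model_store(training_algorithm):
    return TensorflowModel(  # assumed class name
        bucket="my-bucket",
        s3_prefix="sagemaker/run-0",
        region_name="us-east-1",
        output_head_format=FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[training_algorithm])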
def __init__(self, bucket, s3_key, region_name='us-east-1', s3_endpoint_url=None,
             max_retry_attempts=5, backoff_time_sec=1.0):
    '''metrics upload

    Args:
        bucket (str): S3 bucket
        s3_key (str): S3 key
        region_name (str): S3 region name
        s3_endpoint_url (str): S3 endpoint URL
        max_retry_attempts (int): maximum retry attempts
        backoff_time_sec (float): retry backoff time in seconds
    '''
    if not s3_key or not bucket:
        log_and_exit("Metrics S3 key or bucket not available for S3. \
                     bucket: {}, key: {}".format(bucket, s3_key),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    self._bucket = bucket
    self._s3_key = s3_key
    self._s3_client = S3Client(region_name,
                               s3_endpoint_url,
                               max_retry_attempts,
                               backoff_time_sec)
def __init__(self, bucket, s3_key, region_name="us-east-1", local_path="./custom_files/agent/customer_reward_function.py", max_retry_attempts=5, backoff_time_sec=1.0): '''reward function upload, download, and parse Args: bucket (str): S3 bucket string s3_key (str): S3 key string region_name (str): S3 region name local_path (str): file local path max_retry_attempts (int): maximum number of retry attempts for S3 download/upload backoff_time_sec (float): backoff second between each retry ''' # check s3 key and bucket exist for reward function if not s3_key or not bucket: log_and_exit("Reward function code S3 key or bucket not available for S3. \ bucket: {}, key: {}".format(bucket, s3_key), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) self._bucket = bucket # Strip the s3://<bucket> from uri, if s3_key past in as uri self._s3_key = s3_key.replace('s3://{}/'.format(self._bucket), '') self._local_path_processed = local_path # if _local_path_processed is test.py then _local_path_preprocessed is test_preprocessed.py self._local_path_preprocessed = ("_preprocessed.py").join(local_path.split(".py")) # if local _local_path_processed is ./custom_files/agent/customer_reward_function.py, # then the import path should be custom_files.agent.customer_reward_function by # remove ".py", remove "./", and replace "/" and "." self._import_path = local_path.replace(".py", "").replace("./", "").replace("/", ".") self._reward_function = None self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
def __init__(
    self,
    agent_type,
    bucket,
    s3_key,
    region_name="us-east-1",
    local_path="params.yaml",
    max_retry_attempts=5,
    backoff_time_sec=1.0,
):
    """yaml upload, download, and parse

    Args:
        agent_type (str): rollout for training, evaluation for eval
        bucket (str): S3 bucket string
        s3_key (str): S3 key string
        region_name (str): S3 region name
        local_path (str): file local path
        max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
        backoff_time_sec (float): backoff second between each retry
    """
    if not bucket or not s3_key:
        log_and_exit(
            "yaml file S3 key or bucket not available for S3. \
            bucket: {}, key: {}".format(bucket, s3_key),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
    self._bucket = bucket
    # Strip the s3://<bucket> from the uri, if s3_key is passed in as a uri
    self._s3_key = s3_key.replace("s3://{}/".format(self._bucket), "")
    self._local_path = local_path
    self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
    self._agent_type = agent_type
    if self._agent_type == AgentType.ROLLOUT.value:
        self._model_s3_bucket_yaml_key = YamlKey.SAGEMAKER_SHARED_S3_BUCKET_YAML_KEY.value
        self._model_s3_prefix_yaml_key = YamlKey.SAGEMAKER_SHARED_S3_PREFIX_YAML_KEY.value
        self._mandatory_yaml_key = TRAINING_MANDATORY_YAML_KEY
    elif self._agent_type == AgentType.EVALUATION.value:
        self._model_s3_bucket_yaml_key = YamlKey.MODEL_S3_BUCKET_YAML_KEY.value
        self._model_s3_prefix_yaml_key = YamlKey.MODEL_S3_PREFIX_YAML_KEY.value
        self._mandatory_yaml_key = EVAL_MANDATORY_YAML_KEY
    elif self._agent_type == AgentType.VIRTUAL_EVENT.value:
        self._mandatory_yaml_key = VIRUTAL_EVENT_MANDATORY_YAML_KEY
    else:
        log_and_exit(
            "Unknown agent type in launch file: {}".format(self._agent_type),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
    self._yaml_values = None
    self._is_multicar = False
    self._is_f1 = False
    self._model_s3_buckets = list()
    self._model_metadata_s3_keys = list()
    self._body_shell_types = list()
    self._kinesis_webrtc_signaling_channel_name = None
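# Illustrative usage sketch (not part of the original source). The class name is not
# visible in this snippet, so YamlConfig below is an assumed placeholder; the agent_type
# value selects the rollout branch handled above, and the s3_key is a placeholder.
def _example_yaml_config():
    return YamlConfig(agent_type=AgentType.ROLLOUT.value,  # assumed class name
                      bucket="my-bucket",
                      s3_key="sagemaker/run-0/training_params.yaml",
                      region_name="us-east-1")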
def __init__(self, bucket, s3_key, region_name="us-east-1", s3_endpoint_url=None, local_path="./custom_files/agent/model_metadata.json", max_retry_attempts=5, backoff_time_sec=1.0): '''Model metadata upload, download, and parse Args: bucket (str): S3 bucket string s3_key: (str): S3 key string. region_name (str): S3 region name local_path (str): file local path max_retry_attempts (int): maximum number of retry attempts for S3 download/upload backoff_time_sec (float): backoff second between each retry ''' # check s3 key and s3 bucket exist if not bucket or not s3_key: log_and_exit( "model_metadata S3 key or bucket not available for S3. \ bucket: {}, key {}".format(bucket, s3_key), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) self._bucket = bucket # Strip the s3://<bucket> from uri, if s3_key past in as uri self._s3_key = s3_key.replace('s3://{}/'.format(self._bucket), '') self._local_path = local_path self._local_dir = os.path.dirname(self._local_path) self._model_metadata = None self._s3_client = S3Client(region_name, s3_endpoint_url, max_retry_attempts, backoff_time_sec)
def __init__(self, syncfile_type, bucket, s3_prefix, region_name="us-east-1",
             s3_endpoint_url=None, local_dir='./checkpoint',
             max_retry_attempts=5, backoff_time_sec=1.0):
    '''This class is for the rl coach sync files: .finished, .lock, and .ready

    Args:
        syncfile_type (str): sync file type
        bucket (str): S3 bucket string
        s3_prefix (str): S3 prefix string
        region_name (str): S3 region name
        s3_endpoint_url (str): S3 endpoint URL
        local_dir (str): local file directory
        max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
        backoff_time_sec (float): backoff second between each retry
    '''
    if not bucket or not s3_prefix:
        log_and_exit("checkpoint S3 prefix or bucket not available for S3. \
                     bucket: {}, prefix {}".format(bucket, s3_prefix),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    self._syncfile_type = syncfile_type
    self._bucket = bucket
    # sync file s3 key
    self._s3_key = os.path.normpath(os.path.join(
        s3_prefix, SYNC_FILES_POSTFIX_DICT[syncfile_type]))
    # sync file local path
    self._local_path = os.path.normpath(
        SYNC_FILES_LOCAL_PATH_FORMAT_DICT[syncfile_type].format(local_dir))
    self._s3_client = S3Client(region_name,
                               s3_endpoint_url,
                               max_retry_attempts,
                               backoff_time_sec)
def __init__( self, bucket, s3_prefix, region_name="us-east-1", local_dir="./checkpoint/agent", max_retry_attempts=5, backoff_time_sec=1.0, log_and_cont: bool = False, ): """This class is for RL coach checkpoint file Args: bucket (str): S3 bucket string. s3_prefix (str): S3 prefix string. region_name (str): S3 region name. Defaults to 'us-east-1'. local_dir (str, optional): Local file directory. Defaults to '.checkpoint/agent'. max_retry_attempts (int, optional): Maximum number of retry attempts for S3 download/upload. Defaults to 5. backoff_time_sec (float, optional): Backoff second between each retry. Defaults to 1.0. log_and_cont (bool, optional): Log the error and continue with the flow. Defaults to False. """ if not bucket or not s3_prefix: log_and_exit( "checkpoint S3 prefix or bucket not available for S3. \ bucket: {}, prefix {}".format(bucket, s3_prefix), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500, ) self._bucket = bucket # coach checkpoint s3 key self._s3_key = os.path.normpath( os.path.join(s3_prefix, COACH_CHECKPOINT_POSTFIX)) # coach checkpoint local path self._local_path = os.path.normpath( COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir)) # coach checkpoint local temp path self._temp_local_path = os.path.normpath( TEMP_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir)) # old coach checkpoint s3 key to handle backward compatibility self._old_s3_key = os.path.normpath( os.path.join(s3_prefix, OLD_COACH_CHECKPOINT_POSTFIX)) # old coach checkpoint local path to handle backward compatibility self._old_local_path = os.path.normpath( OLD_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir)) # coach checkpoint state file from rl coach self._coach_checkpoint_state_file = CheckpointStateFile( os.path.dirname(self._local_path)) self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec, log_and_cont)
def __init__( self, bucket, s3_prefix, region_name="us-east-1", local_dir=".checkpoint/agent", max_retry_attempts=0, backoff_time_sec=1.0, log_and_cont: bool = False, ): """This class is for deepracer checkpoint json file upload and download Args: bucket (str): S3 bucket string. s3_prefix (str): S3 prefix string. region_name (str): S3 region name. Defaults to 'us-east-1'. local_dir (str, optional): Local file directory. Defaults to '.checkpoint/agent'. max_retry_attempts (int, optional): Maximum number of retry attempts for S3 download/upload. Defaults to 0. backoff_time_sec (float, optional): Backoff second between each retry. Defaults to 1.0. log_and_cont (bool, optional): Log the error and continue with the flow. Defaults to False. """ if not bucket or not s3_prefix: log_and_exit( "checkpoint S3 prefix or bucket not available for S3. \ bucket: {}, prefix {}".format(bucket, s3_prefix), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500, ) self._bucket = bucket # deepracer checkpoint json s3 key self._s3_key = os.path.normpath( os.path.join(s3_prefix, DEEPRACER_CHECKPOINT_KEY_POSTFIX)) # deepracer checkpoint json local path self._local_path = os.path.normpath( DEEPRACER_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir)) self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec, log_and_cont=log_and_cont)
def download_custom_files_if_present(s3_bucket, s3_prefix, aws_region, s3_endpoint_url):
    '''download custom environment and preset files

    Args:
        s3_bucket (str): s3 bucket string
        s3_prefix (str): s3 prefix string
        aws_region (str): aws region string
        s3_endpoint_url (str): s3 endpoint url

    Returns:
        tuple (bool, bool): whether the preset and environment files were
        downloaded successfully
    '''
    success_environment_download, success_preset_download = False, False
    # create the client outside the try blocks so both downloads can use it
    s3_client = S3Client(region_name=aws_region,
                         s3_endpoint_url=s3_endpoint_url,
                         max_retry_attempts=0)
    try:
        environment_file_s3_key = os.path.normpath(
            s3_prefix + "/environments/deepracer_racetrack_env.py")
        environment_local_path = os.path.join(CUSTOM_FILES_PATH,
                                              "deepracer_racetrack_env.py")
        s3_client.download_file(bucket=s3_bucket,
                                s3_key=environment_file_s3_key,
                                local_path=environment_local_path)
        success_environment_download = True
    except botocore.exceptions.ClientError:
        pass
    try:
        preset_file_s3_key = os.path.normpath(s3_prefix + "/presets/preset.py")
        preset_local_path = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        s3_client.download_file(bucket=s3_bucket,
                                s3_key=preset_file_s3_key,
                                local_path=preset_local_path)
        success_preset_download = True
    except botocore.exceptions.ClientError:
        pass
    return success_preset_download, success_environment_download
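# Illustrative usage sketch (not part of the original source). Note that the return
# order is (preset, environment), matching the return statement above; the values
# here are placeholders.
def _example_download_custom_files():
    success_preset, success_environment = download_custom_files_if_present(
        s3_bucket="my-bucket",
        s3_prefix="sagemaker/run-0",
        aws_region="us-east-1",
        s3_endpoint_url=None)
    return success_preset, success_environment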
def main():
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-pk', '--preset_s3_key',
                        help="(string) Name of a preset to download from S3",
                        type=str,
                        required=False)
    parser.add_argument('-ek', '--environment_s3_key',
                        help="(string) Name of an environment file to download from S3",
                        type=str,
                        required=False)
    parser.add_argument('--model_metadata_s3_key',
                        help="(string) Model Metadata File S3 Key",
                        type=str,
                        required=False)
    parser.add_argument('-c', '--checkpoint_dir',
                        help='(string) Path to a folder containing a checkpoint to write the model to.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--pretrained_checkpoint_dir',
                        help='(string) Path to a folder for downloading a pre-trained model',
                        type=str,
                        default=PRETRAINED_MODEL_DIR)
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET_PATH", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--framework',
                        help='(string) tensorflow or mxnet',
                        type=str,
                        default='tensorflow')
    parser.add_argument('--pretrained_s3_bucket',
                        help='(string) S3 bucket for pre-trained model',
                        type=str)
    parser.add_argument('--pretrained_s3_prefix',
                        help='(string) S3 prefix for pre-trained model',
                        type=str,
                        default='sagemaker')
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("AWS_REGION", "us-east-1"))

    args, _ = parser.parse_known_args()

    s3_client = S3Client(region_name=args.aws_region, max_retry_attempts=0)

    # download model metadata
    # TODO: replace 'agent' with name of each agent
    model_metadata_download = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=args.model_metadata_s3_key,
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_info = model_metadata_download.get_model_metadata_info()
    network_type = model_metadata_info[ModelMetadataKeys.NEURAL_NETWORK.value]
    version = model_metadata_info[ModelMetadataKeys.VERSION.value]

    # upload model metadata
    model_metadata_upload = ModelMetadata(
        bucket=args.s3_bucket,
        s3_key=get_s3_key(args.s3_prefix, MODEL_METADATA_S3_POSTFIX),
        region_name=args.aws_region,
        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format('agent'))
    model_metadata_upload.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())

    shutil.copy2(model_metadata_download.local_path, SM_MODEL_OUTPUT_DIR)

    success_custom_preset = False
    if args.preset_s3_key:
        preset_local_path = "./markov/presets/preset.py"
        try:
            s3_client.download_file(bucket=args.s3_bucket,
                                    s3_key=args.preset_s3_key,
                                    local_path=preset_local_path)
            success_custom_preset = True
        except botocore.exceptions.ClientError:
            pass
        if not success_custom_preset:
            logger.info("Could not download the preset file. Using the default DeepRacer preset.")
        else:
            preset_location = "markov.presets.preset:graph_manager"
            graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
            s3_client.upload_file(bucket=args.s3_bucket,
                                  s3_key=os.path.normpath("%s/presets/preset.py" % args.s3_prefix),
                                  local_path=preset_local_path,
                                  s3_kms_extra_args=utils.get_s3_kms_extra_args())
            if success_custom_preset:
                logger.info("Using preset: %s" % args.preset_s3_key)

    if not success_custom_preset:
        params_blob = os.environ.get('SM_TRAINING_ENV', '')
        if params_blob:
            params = json.loads(params_blob)
            sm_hyperparams_dict = params["hyperparameters"]
        else:
            sm_hyperparams_dict = {}

        #! TODO each agent should have own config
        agent_config = {
            'model_metadata': model_metadata_download,
            ConfigParams.CAR_CTRL_CONFIG.value: {
                ConfigParams.LINK_NAME_LIST.value: [],
                ConfigParams.VELOCITY_LIST.value: {},
                ConfigParams.STEERING_LIST.value: {},
                ConfigParams.CHANGE_START.value: None,
                ConfigParams.ALT_DIR.value: None,
                ConfigParams.MODEL_METADATA.value: model_metadata_download,
                ConfigParams.REWARD.value: None,
                ConfigParams.AGENT_NAME.value: 'racecar'
            }
        }

        agent_list = list()
        agent_list.append(create_training_agent(agent_config))

        graph_manager, robomaker_hyperparams_json = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=None,
            run_type=str(RunType.TRAINER))

        # Upload hyperparameters to SageMaker shared s3 bucket
        hyperparameters = Hyperparameters(
            bucket=args.s3_bucket,
            s3_key=get_s3_key(args.s3_prefix, HYPERPARAMETER_S3_POSTFIX),
            region_name=args.aws_region)
        hyperparameters.persist(hyperparams_json=robomaker_hyperparams_json,
                                s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # Attach sample collector to graph_manager only if sample count > 0
        max_sample_count = int(sm_hyperparams_dict.get("max_sample_count", 0))
        if max_sample_count > 0:
            sample_collector = SampleCollector(
                bucket=args.s3_bucket,
                s3_prefix=args.s3_prefix,
                region_name=args.aws_region,
                max_sample_count=max_sample_count,
                sampling_frequency=int(sm_hyperparams_dict.get("sampling_frequency", 1)))
            graph_manager.sample_collector = sample_collector

    # persist IP config from sagemaker to s3
    ip_config = IpConfig(bucket=args.s3_bucket,
                         s3_prefix=args.s3_prefix,
                         region_name=args.aws_region)
    ip_config.persist(s3_kms_extra_args=utils.get_s3_kms_extra_args())

    training_algorithm = model_metadata_download.training_algorithm
    output_head_format = FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[training_algorithm]

    use_pretrained_model = args.pretrained_s3_bucket and args.pretrained_s3_prefix
    # Handle backward compatibility
    if use_pretrained_model:
        # checkpoint s3 instance for pretrained model
        # TODO: replace 'agent' for multiagent training
        checkpoint = Checkpoint(bucket=args.pretrained_s3_bucket,
                                s3_prefix=args.pretrained_s3_prefix,
                                region_name=args.aws_region,
                                agent_name='agent',
                                checkpoint_dir=args.pretrained_checkpoint_dir,
                                output_head_format=output_head_format)
        # make coach checkpoint compatible
        if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
            checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
        # get best model checkpoint string
        model_checkpoint_name = checkpoint.deepracer_checkpoint_json.get_deepracer_best_checkpoint()
        # Select the best checkpoint model by uploading rl coach .coach_checkpoint file
        checkpoint.rl_coach_checkpoint.update(
            model_checkpoint_name=model_checkpoint_name,
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

        # add checkpoint into checkpoint_dict
        checkpoint_dict = {'agent': checkpoint}
        # load pretrained model
        ds_params_instance_pretrained = S3BotoDataStoreParameters(
            checkpoint_dict=checkpoint_dict)
        data_store_pretrained = S3BotoDataStore(ds_params_instance_pretrained,
                                                graph_manager, True)
        data_store_pretrained.load_from_store()

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address="localhost",
        redis_port=6379,
        run_type=str(RunType.TRAINER),
        channel=args.s3_prefix,
        network_type=network_type)

    graph_manager.memory_backend_params = memory_backend_params

    # checkpoint s3 instance for training model
    checkpoint = Checkpoint(bucket=args.s3_bucket,
                            s3_prefix=args.s3_prefix,
                            region_name=args.aws_region,
                            agent_name='agent',
                            checkpoint_dir=args.checkpoint_dir,
                            output_head_format=output_head_format)
    checkpoint_dict = {'agent': checkpoint}

    ds_params_instance = S3BotoDataStoreParameters(checkpoint_dict=checkpoint_dict)

    graph_manager.data_store_params = ds_params_instance
    graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager)

    task_parameters = TaskParameters()
    task_parameters.experiment_path = SM_MODEL_OUTPUT_DIR
    task_parameters.checkpoint_save_secs = 20
    if use_pretrained_model:
        task_parameters.checkpoint_restore_path = args.pretrained_checkpoint_dir
    task_parameters.checkpoint_save_dir = args.checkpoint_dir

    training_worker(
        graph_manager=graph_manager,
        task_parameters=task_parameters,
        user_batch_size=json.loads(robomaker_hyperparams_json)["batch_size"],
        user_episode_per_rollout=json.loads(robomaker_hyperparams_json)["num_episodes_between_training"],
        training_algorithm=training_algorithm)
def __init__(self, queue_url, aws_region='us-east-1', race_duration=180,
             number_of_trials=3, number_of_resets=10000, penalty_seconds=2.0,
             off_track_penalty=2.0, collision_penalty=5.0, is_continuous=False,
             race_type="TIME_TRIAL"):
    # constructor arguments
    self._model_updater = ModelUpdater.get_instance()
    self._deepracer_path = rospkg.RosPack().get_path(
        DeepRacerPackages.DEEPRACER_SIMULATION_ENVIRONMENT)
    body_shell_path = os.path.join(self._deepracer_path, "meshes", "f1")
    self._valid_body_shells = \
        set(".".join(f.split(".")[:-1]) for f in os.listdir(body_shell_path)
            if os.path.isfile(os.path.join(body_shell_path, f)))
    self._valid_body_shells.add(const.BodyShellType.DEFAULT.value)
    self._valid_car_colors = set(e.value for e in const.CarColorType if "f1" not in e.value)
    self._num_sectors = int(rospy.get_param("NUM_SECTORS", "3"))
    self._queue_url = queue_url
    self._region = aws_region
    self._number_of_trials = number_of_trials
    self._number_of_resets = number_of_resets
    self._penalty_seconds = penalty_seconds
    self._off_track_penalty = off_track_penalty
    self._collision_penalty = collision_penalty
    self._is_continuous = is_continuous
    self._race_type = race_type
    self._is_save_simtrace_enabled = False
    self._is_save_mp4_enabled = False
    self._is_event_end = False
    self._done_condition = any
    self._race_duration = race_duration
    self._enable_domain_randomization = False

    # sqs client
    # The boto client errors out after polling for 1 hour.
    self._sqs_client = SQSClient(queue_url=self._queue_url,
                                 region_name=self._region,
                                 max_num_of_msg=MAX_NUM_OF_SQS_MESSAGE,
                                 wait_time_sec=SQS_WAIT_TIME_SEC,
                                 session=refreshed_session(self._region))
    self._s3_client = S3Client(region_name=self._region)

    # tracking current state information
    self._track_data = TrackData.get_instance()
    self._start_lane = self._track_data.center_line
    # keep track of the racer specific info, e.g. s3 locations, alias, car color etc.
    self._current_racer = None
    # keep track of the current race car we are using. It is always "racecar".
    car_model_state = ModelState()
    car_model_state.model_name = "racecar"
    self._current_car_model_state = car_model_state
    self._last_body_shell_type = None
    self._last_sensors = None
    self._racecar_model = AgentModel()
    # keep track of the current control agent we are using
    self._current_agent = None
    # keep track of the current control graph manager
    self._current_graph_manager = None
    # Keep track of previous model's name
    self._prev_model_name = None
    self._hide_position_idx = 0
    self._hide_positions = get_hide_positions(race_car_num=1)
    self._run_phase_subject = RunPhaseSubject()
    self._simtrace_video_s3_writers = []
    self._local_model_directory = './checkpoint'

    # virtual event only has a single agent, so set agent_name to "agent"
    self._agent_name = "agent"

    # camera manager
    self._camera_manager = CameraManager.get_instance()
    # set up the virtual event top and follow cameras in CameraManager.
    # virtual event camera configuration does not need to wait for the car to spawn
    # because the follow-car camera is not tracking any car initially
    self._main_cameras, self._sub_camera = configure_camera(
        namespaces=[VIRTUAL_EVENT], is_wait_for_model=False)
    self._spawn_cameras()
    # pop out all cameras after configuration to prevent the cameras from moving
    self._camera_manager.pop(namespace=VIRTUAL_EVENT)

    dummy_metrics_s3_config = {
        MetricsS3Keys.METRICS_BUCKET.value: "dummy-bucket",
        MetricsS3Keys.METRICS_KEY.value: "dummy-key",
        MetricsS3Keys.REGION.value: self._region
    }
    self._eval_metrics = EvalMetrics(
        agent_name=self._agent_name,
        s3_dict_metrics=dummy_metrics_s3_config,
        is_continuous=self._is_continuous,
        pause_time_before_start=PAUSE_TIME_BEFORE_START)

    # upload a default best sector time (inf for each sector) for all sectors
    # if there is no best sector time already in S3.
    # The yaml file's S3 bucket and prefix environment variables are reused here because
    # this path is SimApp-only. For virtual event, no S3 bucket and prefix are passed
    # through the yaml file; everything is passed through SQS. For simplicity, reuse the
    # yaml S3 bucket and prefix environment variables.
    virtual_event_best_sector_time = VirtualEventBestSectorTime(
        bucket=os.environ.get("YAML_S3_BUCKET", ''),
        s3_key=get_s3_key(os.environ.get("YAML_S3_PREFIX", ''), SECTOR_TIME_S3_POSTFIX),
        region_name=os.environ.get("APP_REGION", "us-east-1"),
        local_path=SECTOR_TIME_LOCAL_PATH)
    response = virtual_event_best_sector_time.list()
    # this handles situations such as a robomaker job crash, so the next robomaker job
    # can pick up the best sector time left over from the crashed job
    if "Contents" not in response:
        virtual_event_best_sector_time.persist(
            body=json.dumps({SECTOR_X_FORMAT.format(idx + 1): float("inf")
                             for idx in range(self._num_sectors)}),
            s3_kms_extra_args=utils.get_s3_kms_extra_args())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', self._run_phase_subject)

    # setup mp4 services
    self._setup_mp4_services()