def get_action_dict(self, action):
    """Map a model action into its steering/speed dictionary form.

    Args:
        action (int or list): model metadata action_space index for
            discrete action spaces, or [steering, speed] float values
            for continuous action spaces.

    Returns:
        dict (str, float): dictionary containing {steering_angle: value, speed: value}
    """
    if self.action_space_type == ActionSpaceTypes.DISCRETE.value:
        # Discrete: the metadata action space already stores one dict per index.
        return self._model_metadata[ModelMetadataKeys.ACTION_SPACE.value][action]
    if self.action_space_type == ActionSpaceTypes.CONTINUOUS.value:
        # Continuous: build the dict from the [steering, speed] pair.
        steering, speed = action[0], action[1]
        return {ModelMetadataKeys.STEERING_ANGLE.value: steering,
                ModelMetadataKeys.SPEED.value: speed}
    log_and_exit("Unknown action_space_type found while getting action dict. \
                 action_space_type: {}".format(self.action_space_type),
                 SIMAPP_SIMULATION_WORKER_EXCEPTION,
                 SIMAPP_EVENT_ERROR_CODE_500)
def __call__(self, *argv):
    """Invoke the stored ROS service client, retrying transient failures.

    Args:
        argv (list): Arguments to pass into the client object
    """
    attempts = 0
    while True:
        try:
            return self.client(*argv)
        except TypeError as err:
            # Bad arguments can never succeed on retry -- exit immediately.
            log_and_exit("Invalid arguments for client {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        except Exception as ex:
            attempts += 1
            if attempts > self._max_retry_attempts:
                # Give the cancel-job signal time to propagate before exiting.
                time.sleep(ROBOMAKER_CANCEL_JOB_WAIT_TIME)
                log_and_exit("Unable to call service {}".format(ex),
                             SIMAPP_SIMULATION_WORKER_EXCEPTION,
                             SIMAPP_EVENT_ERROR_CODE_500)
            # Log each failed attempt with the retry budget and retry.
            logger.info(ROS_SERVICE_ERROR_MSG_FORMAT.format(
                self._service_name, str(attempts),
                str(self._max_retry_attempts), ex))
def do_model_selection(s3_bucket, s3_prefix, region, checkpoint_type=BEST_CHECKPOINT):
    '''Set the coach checkpoint file to point at the best model based on reward and progress.

    Args:
        s3_bucket (str): DeepRacer s3 bucket
        s3_prefix (str): Prefix for the training job for which to select the best model
        region (str): Name of the aws region where the job ran
        checkpoint_type (str): which checkpoint to select (BEST_CHECKPOINT by default)

    Returns:
        bool: status of model selection. True if successfully selected model
        otherwise False.
    '''
    try:
        s3_extra_args = get_s3_kms_extra_args()
        model_checkpoint = get_deepracer_checkpoint(s3_bucket=s3_bucket,
                                                    s3_prefix=s3_prefix,
                                                    region=region,
                                                    checkpoint_type=checkpoint_type)
        if model_checkpoint is None:
            return False
        # Write the selected checkpoint name locally, upload it as the coach
        # checkpoint file, then clean up the local copy.
        local_path = os.path.abspath(os.path.join(os.getcwd(), 'coach_checkpoint'))
        # 'w' (was nonstandard '+w'): the file is only written here, never read
        with open(local_path, 'w') as new_ckpnt:
            new_ckpnt.write(model_checkpoint)
        s3_client = boto3.Session().client('s3', region_name=region,
                                           config=get_boto_config())
        s3_client.upload_file(Filename=local_path,
                              Bucket=s3_bucket,
                              Key=os.path.join(s3_prefix, CHKPNT_KEY_SUFFIX),
                              ExtraArgs=s3_extra_args)
        os.remove(local_path)
        return True
    except botocore.exceptions.ClientError as err:
        log_and_exit("Unable to upload checkpoint: {}, {}"
                     .format(s3_bucket, err.response['Error']['Code']),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except Exception as ex:
        log_and_exit("Exception in uploading checkpoint: {}"
                     .format(ex),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
def subscribe_to_save_mp4(self):
    """Subscribe to the camera Image topics and start the mp4 video writers.

    For every camera in self.camera_infos this creates one cv2.VideoWriter
    (writing to the camera's local_path) and one rospy.Subscriber on the
    camera's topic, and ensures a per-camera lock exists to serialize frame
    writes in the image callback.
    """
    try:
        for camera_enum in self.camera_infos:
            name = camera_enum['name']
            local_path, topic_name = camera_enum['local_path'], camera_enum['topic_name']
            # One video writer per camera, sharing the configured codec/fps/frame size.
            self.cv2_video_writers[name] = cv2.VideoWriter(local_path, self.fourcc,
                                                           self.fps, self.frame_size)
            self.mp4_subscription[name] = rospy.Subscriber(topic_name, Image,
                                                           callback=self._subscribe_to_image_topic,
                                                           callback_args=name)
            if name not in self.mp4_subscription_lock_map:
                self.mp4_subscription_lock_map[name] = threading.Lock()
            else:
                # Lock already exists from a previous subscription: release it
                # so the image callback can acquire it again.
                self.mp4_subscription_lock_map[name].release()
    except Exception as err_msg:
        log_and_exit("Exception in the handler function to subscribe to save_mp4 download: {}".format(err_msg),
                     SIMAPP_SIMULATION_SAVE_TO_MP4_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
def __init__(self, bucket, s3_key, region_name="us-east-1", s3_endpoint_url=None, local_path="./custom_files/agent/customer_reward_function.py", max_retry_attempts=5, backoff_time_sec=1.0): '''reward function upload, download, and parse Args: bucket (str): S3 bucket string s3_key (str): S3 key string region_name (str): S3 region name local_path (str): file local path max_retry_attempts (int): maximum number of retry attempts for S3 download/upload backoff_time_sec (float): backoff second between each retry ''' # check s3 key and bucket exist for reward function if not s3_key or not bucket: log_and_exit("Reward function code S3 key or bucket not available for S3. \ bucket: {}, key: {}".format(bucket, s3_key), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) self._bucket = bucket # Strip the s3://<bucket> from uri, if s3_key past in as uri self._s3_key = s3_key.replace('s3://{}/'.format(self._bucket), '') self._local_path_processed = local_path # if _local_path_processed is test.py then _local_path_preprocessed is test_preprocessed.py self._local_path_preprocessed = ("_preprocessed.py").join(local_path.split(".py")) # if local _local_path_processed is ./custom_files/agent/customer_reward_function.py, # then the import path should be custom_files.agent.customer_reward_function by # remove ".py", remove "./", and replace "/" and "." self._import_path = local_path.replace(".py", "").replace("./", "").replace("/", ".") self._reward_function = None self._s3_client = S3Client(region_name, s3_endpoint_url, max_retry_attempts, backoff_time_sec)
def _get_deepracer_checkpoint(self, checkpoint_type):
    '''Return the checkpoint name recorded in the deepracer checkpoint json.

    Args:
        checkpoint_type (str): BEST_CHECKPOINT/LAST_CHECKPOINT string

    Returns:
        str or None: the recorded checkpoint name, or None when the json is
        missing from S3 (404) or cannot be parsed.
    '''
    try:
        # Download deepracer checkpoint json from S3
        self._download()
    except botocore.exceptions.ClientError as err:
        if err.response['Error']['Code'] == "404":
            # A missing json is an expected condition (e.g. nothing recorded yet)
            LOG.info("Unable to find deepracer checkpoint json")
            return None
        else:
            log_and_exit(
                "Unable to download deepracer checkpoint json: {}, {}".
                format(self._bucket, err.response['Error']['Code']),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_400)
    except Exception as ex:
        log_and_exit(
            "Can't download deepracer checkpoint json: {}".format(ex),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500)
    try:
        # Parse the requested entry out of the json; remove the local copy
        # once it has been read.
        with open(self._local_path) as deepracer_checkpoint_file:
            checkpoint_name = json.load(
                deepracer_checkpoint_file)[checkpoint_type]["name"]
            if not checkpoint_name:
                raise Exception("No deepracer checkpoint json recorded")
        os.remove(self._local_path)
    except Exception as ex:
        LOG.info(
            "Unable to parse deepracer checkpoint json: {}".format(ex))
        return None
    return checkpoint_name
def make_compatible(s3_bucket, s3_prefix, region, ready_file, s3_endpoint_url=None):
    '''Moves and creates all the necessary files to make models trained by
    coach 0.11 compatible with coach 1.0.

    Args:
        s3_bucket (str): DeepRacer s3 bucket
        s3_prefix (str): Prefix for the training job for which to select the best model
        region (str): Name of the aws region where the job ran
        ready_file (str): name of the .ready marker file to upload under model/
        s3_endpoint_url (str): optional S3 endpoint url (None for the default endpoint)
    '''
    try:
        session = boto3.Session()
        s3_client = session.client('s3', region_name=region,
                                   endpoint_url=s3_endpoint_url,
                                   config=get_boto_config())
        s3_extra_args = get_s3_kms_extra_args()
        # Download the coach-0.11 style checkpoint file and pull the checkpoint
        # name out of the double quotes on its first line.
        old_checkpoint = os.path.join(os.getcwd(), 'checkpoint')
        s3_client.download_file(Bucket=s3_bucket,
                                Key=os.path.join(s3_prefix, 'model/checkpoint'),
                                Filename=old_checkpoint)
        with open(old_checkpoint) as old_checkpoint_file:
            # renamed from misspelled "chekpoint"
            checkpoint_value = re.findall(r'"(.*?)"',
                                          old_checkpoint_file.readline())
        if len(checkpoint_value) != 1:
            log_and_exit("No checkpoint file found",
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        os.remove(old_checkpoint)
        # Upload ready file so that the system can grab the checkpoints
        # (format the key first, then join -- clearer than formatting the joined path)
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=s3_bucket,
                                 Key=os.path.join(s3_prefix,
                                                  "model/{}".format(ready_file)),
                                 ExtraArgs=s3_extra_args)
        # Upload the new checkpoint file
        new_checkpoint = os.path.join(os.getcwd(), 'coach_checkpoint')
        with open(new_checkpoint, 'w+') as new_checkpoint_file:
            new_checkpoint_file.write(checkpoint_value[0])
        s3_client.upload_file(Filename=new_checkpoint,
                              Bucket=s3_bucket,
                              Key=os.path.join(s3_prefix, CHKPNT_KEY_SUFFIX),
                              ExtraArgs=s3_extra_args)
        os.remove(new_checkpoint)
    except botocore.exceptions.ClientError as e:
        log_and_exit(
            "Unable to make model compatible: {}, {}".format(
                s3_bucket, e.response['Error']['Code']),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_400)
    except Exception as e:
        log_and_exit("Unable to make model compatible: {}".format(e),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
def __init__( self, bucket, s3_key, region_name="us-east-1", local_path="./custom_files/agent/model_metadata.json", max_retry_attempts=5, backoff_time_sec=1.0, ): """Model metadata upload, download, and parse Args: bucket (str): S3 bucket string s3_key: (str): S3 key string. region_name (str): S3 region name local_path (str): file local path max_retry_attempts (int): maximum number of retry attempts for S3 download/upload backoff_time_sec (float): backoff second between each retry """ # check s3 key and s3 bucket exist if not bucket or not s3_key: log_and_exit( "model_metadata S3 key or bucket not available for S3. \ bucket: {}, key {}".format(bucket, s3_key), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500, ) self._bucket = bucket # Strip the s3://<bucket> from uri, if s3_key past in as uri self._s3_key = s3_key.replace("s3://{}/".format(self._bucket), "") self._local_path = local_path self._local_dir = os.path.dirname(self._local_path) self._model_metadata = None self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
def _download(self): '''wait for ip config to be ready first and then download it''' # check and make local directory local_dir = os.path.dirname(self._local_path) if local_dir and not os.path.exists(local_dir): os.makedirs(local_dir) # Download the ip file with retry try: # Wait for sagemaker to produce the redis ip self._wait_for_ip_config() self._s3_client.download_file(bucket=self._bucket, s3_key=self._s3_ip_address_key, local_path=self._local_path) LOG.info("[s3] Successfully downloaded ip config from \ s3 key {} to local {}.".format(self._s3_ip_address_key, self._local_path)) with open(self._local_path) as file: self._ip_file = json.load(file)["IP"] except botocore.exceptions.ClientError as err: log_and_exit("Failed to download ip file: s3_bucket: {}, s3_key: {}, {}"\ .format(self._bucket, self._s3_ip_address_key, err), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500)
def signal_ready(self):
    """Upload the rl coach .ready file for every checkpoint.

    Takes the S3 sync-file lock around the uploads: delete any stale lock,
    persist the lock, upload each checkpoint's .ready file, then release
    the lock by deleting it.
    """
    try:
        # remove lock file if it exists
        self.syncfile_lock.delete()
        # acquire lock
        self.syncfile_lock.persist(s3_kms_extra_args=get_s3_kms_extra_args())
        # upload .ready for each checkpoint (dict keys are unused, so
        # iterate values() instead of items())
        for checkpoint in self.params.checkpoint_dict.values():
            checkpoint.syncfile_ready.persist(
                s3_kms_extra_args=get_s3_kms_extra_args())
        # release lock by deleting it
        self.syncfile_lock.delete()
    except botocore.exceptions.ClientError:
        log_and_exit(
            "Unable to upload .ready",
            SIMAPP_S3_DATA_STORE_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_400,
        )
    except Exception as ex:
        log_and_exit(
            "Exception in uploading .ready file: {}".format(ex),
            SIMAPP_S3_DATA_STORE_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
def __init__(
    self,
    bucket,
    s3_prefix,
    region_name="us-east-1",
    local_path="./custom_files/agent/ip.json",
    max_retry_attempts=5,
    backoff_time_sec=1.0,
):
    """ip upload, download, and parse

    Args:
        bucket (str): s3 bucket
        s3_prefix (str): s3 prefix
        region_name (str): s3 region name
        local_path (str): ip address json file local path
        max_retry_attempts (int): maximum retry attempts
        backoff_time_sec (float): retry backoff time in seconds
    """
    # both bucket and prefix are required to locate the ip objects in S3
    if not s3_prefix or not bucket:
        log_and_exit(
            "Ip config S3 prefix or bucket not available for S3. \
            bucket: {}, prefix: {}".format(bucket, s3_prefix),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
    self._bucket = bucket
    # key of the "ip done" marker object
    self._s3_ip_done_key = os.path.normpath(
        os.path.join(s3_prefix, IP_DONE_POSTFIX))
    # key of the json object holding the ip address
    self._s3_ip_address_key = os.path.normpath(
        os.path.join(s3_prefix, IP_ADDRESS_POSTFIX))
    self._local_path = local_path
    self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
    # parsed "IP" value; populated by _download()
    self._ip_file = None
def __init__(self, bucket, s3_prefix, region_name='us-east-1',
             local_dir='./checkpoint/agent',
             max_retry_attempts=5, backoff_time_sec=1.0,
             output_head_format=FROZEN_HEAD_OUTPUT_GRAPH_FORMAT_MAPPING[
                 TrainingAlgorithm.CLIPPED_PPO.value]):
    '''This class is for tensorflow model upload and download

    Args:
        bucket (str): S3 bucket string
        s3_prefix (str): S3 prefix string
        region_name (str): S3 region name
        local_dir (str): local file directory
        max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
        backoff_time_sec (float): backoff second between each retry
        output_head_format (str): output head format for the specific algorithm
            and action space which will be used to store the frozen graph
    '''
    # both bucket and prefix are required to locate checkpoints in S3
    if not bucket or not s3_prefix:
        log_and_exit(
            "checkpoint S3 prefix or bucket not available for S3. \
            bucket: {}, prefix {}".format(bucket, s3_prefix),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500)
    self._bucket = bucket
    # local directory that will hold the downloaded checkpoint files
    self._local_dir = os.path.normpath(
        CHECKPOINT_LOCAL_DIR_FORMAT.format(local_dir))
    # S3 key directory that holds the checkpoint files
    self._s3_key_dir = os.path.normpath(
        os.path.join(s3_prefix, CHECKPOINT_POSTFIX_DIR))
    # queue of checkpoints pending deletion
    self._delete_queue = queue.Queue()
    self._s3_client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
    self.output_head_format = output_head_format
def make_compatible(self, syncfile_ready):
    """Update the coach checkpoint file to make it compatible.

    Downloads the old-format coach checkpoint, extracts the quoted checkpoint
    name from its first line, uploads the .ready sync file, then re-uploads
    the checkpoint name in the new format.

    Args:
        syncfile_ready (RlCoachSyncFile): RlCoachSyncFile class instance for .ready file
    """
    try:
        # download old coach checkpoint
        self._s3_client.download_file(bucket=self._bucket,
                                      s3_key=self._old_s3_key,
                                      local_path=self._old_local_path)
        # parse old coach checkpoint: the first line holds the checkpoint
        # name in double quotes
        with open(self._old_local_path) as old_coach_checkpoint_file:
            coach_checkpoint_value = re.findall(
                r'"(.*?)"', old_coach_checkpoint_file.readline())
        if len(coach_checkpoint_value) != 1:
            log_and_exit(
                "No checkpoint file found",
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_400,
            )
        # remove old local coach checkpoint
        os.remove(self._old_local_path)
        # Upload ready file so that the system can grab the checkpoints
        syncfile_ready.persist(s3_kms_extra_args=get_s3_kms_extra_args())
        # write new temp coach checkpoint file
        with open(self._temp_local_path, "w+") as new_coach_checkpoint_file:
            new_coach_checkpoint_file.write(coach_checkpoint_value[0])
        # upload new temp coach checkpoint file
        self._persist_temp_coach_checkpoint(
            s3_kms_extra_args=get_s3_kms_extra_args())
        # remove new temp local coach checkpoint
        os.remove(self._temp_local_path)
    except botocore.exceptions.ClientError as e:
        log_and_exit(
            "Unable to make model compatible: {}, {}".format(
                self._bucket, e.response["Error"]["Code"]),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_400,
        )
    except Exception as e:
        log_and_exit(
            "Exception in making model compatible: {}".format(e),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
def is_compatible(self):
    """Check whether the rl coach checkpoint is compatible by checking whether
    a .coach_checkpoint file is present in the expected s3 bucket.

    Returns:
        bool: True if the coach checkpoint file is present, False otherwise
    """
    try:
        coach_checkpoint_dir, coach_checkpoint_filename = os.path.split(
            self._s3_key)
        response = self._s3_client.list_objects_v2(
            bucket=self._bucket, prefix=coach_checkpoint_dir)
        if "Contents" not in response:
            # Customer deleted checkpoint file.
            log_and_exit(
                "No objects found: {}".format(self._bucket),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_400,
            )
        # generator expression instead of any(list(map(lambda ...))):
        # short-circuits on first match and avoids the intermediate list
        return any(
            os.path.split(obj["Key"])[1] == coach_checkpoint_filename
            for obj in response["Contents"])
    except botocore.exceptions.ClientError as e:
        log_and_exit(
            "No objects found: {}, {}".format(self._bucket,
                                              e.response["Error"]["Code"]),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_400,
        )
    except Exception as e:
        log_and_exit(
            "Exception in checking for current checkpoint key: {}".format(
                e),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500,
        )
def download_file(self, s3_key, local_path):
    """Download an object from this bucket to a local path.

    Args:
        s3_key (str): S3 key of the object to download
        local_path (str): destination path on local disk

    Returns:
        bool: True on success, False when the object does not exist (404).
    """
    s3_client = self.get_client()
    try:
        s3_client.download_file(self.bucket, s3_key, local_path)
        return True
    except botocore.exceptions.ClientError as err:
        # It is possible that the file isn't there in which case we should
        # return false and let the client decide the next action
        if err.response['Error']['Code'] == "404":
            return False
        else:
            log_and_exit("Unable to download file",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
    except botocore.exceptions.ConnectTimeoutError as ex:
        log_and_exit("Issue with your current VPC stack and IAM roles.\
                      You might need to reset your account resources: {}".format(ex),
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except Exception as ex:
        log_and_exit("Exception in downloading file: {}".format(ex),
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
s3_bucket=s3_bucket, s3_prefix=s3_prefix, aws_region=aws_region) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--s3_bucket', help='(string) S3 bucket', type=str) parser.add_argument('--s3_prefix', help='(string) S3 prefix', type=str) parser.add_argument('--aws_region', help='(string) AWS region', type=str) args = parser.parse_args() try: validate(s3_bucket=args.s3_bucket, s3_prefix=args.s3_prefix, aws_region=args.aws_region) except ValueError as err: if utils.is_user_error(err): log_and_exit("User modified model/model_metadata: {}".format(err), SIMAPP_VALIDATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_400) else: log_and_exit("Validation worker value error: {}".format(err), SIMAPP_VALIDATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) except Exception as ex: log_and_exit("Validation worker exited with exception: {}".format(ex), SIMAPP_VALIDATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500)
def get(self, coach_checkpoint_state_file):
    '''get tensorflow model specified in the rl coach checkpoint state file

    If the rl coach checkpoint state file specified checkpoint is missing,
    it will download the last checkpoint and overwrite the local rl coach
    checkpoint state file to point at it.

    Args:
        coach_checkpoint_state_file (CheckpointStateFile): CheckpointStateFile instance
    '''
    has_checkpoint = False
    last_checkpoint_number = -1
    last_checkpoint_name = None
    # list everything in tensorflow model s3 bucket dir
    # to find the checkpoint specified in .coach_checkpoint
    # or use the last
    checkpoint_name = str(coach_checkpoint_state_file.read())
    for page in self._s3_client.paginate(bucket=self._bucket,
                                         prefix=self._s3_key_dir):
        if "Contents" in page:
            # Check to see if the desired tensorflow model is in the bucket
            # for example if obj is (dir)/487_Step-2477372.ckpt.data-00000-of-00001
            # curr_checkpoint_number: 487
            # curr_checkpoint_name: 487_Step-2477372.ckpt.data-00000-of-00001
            for obj in page['Contents']:
                curr_checkpoint_name = os.path.split(obj['Key'])[1]
                # if found the checkpoint name stored in .coach_checkpoint file
                # break inner loop for file search
                if curr_checkpoint_name.startswith(checkpoint_name):
                    has_checkpoint = True
                    break
                # if the file name does not start with a number (not ckpt file)
                # continue for next file
                if not utils.is_int_repr(
                        curr_checkpoint_name.split("_")[0]):
                    continue
                # if the file name start with a number, update the last checkpoint name
                # and number
                curr_checkpoint_number = int(
                    curr_checkpoint_name.split("_")[0])
                if curr_checkpoint_number > last_checkpoint_number:
                    last_checkpoint_number = curr_checkpoint_number
                    last_checkpoint_name = curr_checkpoint_name.rsplit(
                        '.', 1)[0]
        # break out from pagination if find the checkpoint
        if has_checkpoint:
            break
    # update checkpoint_name to the last_checkpoint_name and overwrite local
    # .coach_checkpoint file to contain the last checkpoint
    if not has_checkpoint:
        if last_checkpoint_name:
            coach_checkpoint_state_file.write(
                SingleCheckpoint(num=last_checkpoint_number,
                                 name=last_checkpoint_name))
            LOG.info("%s not in s3 bucket, downloading %s checkpoints",
                     checkpoint_name, last_checkpoint_name)
            checkpoint_name = last_checkpoint_name
        else:
            log_and_exit("No checkpoint files",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
    # download the desired checkpoint file
    for page in self._s3_client.paginate(bucket=self._bucket,
                                         prefix=self._s3_key_dir):
        if "Contents" in page:
            for obj in page['Contents']:
                s3_key = obj["Key"]
                _, file_name = os.path.split(s3_key)
                local_path = os.path.normpath(
                    os.path.join(self._local_dir, file_name))
                _, file_extension = os.path.splitext(s3_key)
                # skip frozen graph .pb files; only fetch files that belong
                # to the selected checkpoint
                if file_extension != '.pb' and file_name.startswith(
                        checkpoint_name):
                    self._download(s3_key=s3_key, local_path=local_path)
def copy_best_frozen_graph_to_sm_output_dir(self, best_checkpoint_number,
                                            last_checkpoint_number,
                                            source_dir, dest_dir):
    """Copy the frozen model for the current best checkpoint from the source
    directory to the destination directory, and prune stale frozen models
    from the source directory.

    Args:
        best_checkpoint_number (int): checkpoint number of the best model
            (-1 when no deepracer_checkpoints.json was found)
        last_checkpoint_number (int): checkpoint number of the last model
        source_dir (str): Source directory where the frozen models are present
        dest_dir (str): Sagemaker output directory where we store the frozen models for best checkpoint
    """
    dest_dir_pb_files = [
        filename for filename in os.listdir(dest_dir)
        if os.path.isfile(os.path.join(dest_dir, filename))
        and filename.endswith(".pb")
    ]
    source_dir_pb_files = [
        filename for filename in os.listdir(source_dir)
        if os.path.isfile(os.path.join(source_dir, filename))
        and filename.endswith(".pb")
    ]
    LOG.info(
        "Best checkpoint number: {}, Last checkpoint number: {}".format(
            best_checkpoint_number, last_checkpoint_number))
    best_model_name = 'model_{}.pb'.format(best_checkpoint_number)
    last_model_name = 'model_{}.pb'.format(last_checkpoint_number)
    if len(source_dir_pb_files) < 1:
        log_and_exit(
            "Could not find any frozen model file in the local directory",
            SIMAPP_S3_DATA_STORE_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500)
    try:
        # Could not find the deepracer_checkpoints.json file or there are no model.pb files in destination
        if best_checkpoint_number == -1 or len(dest_dir_pb_files) == 0:
            if len(source_dir_pb_files) > 1:
                LOG.info("More than one model.pb found in the source directory. Choosing the "
                         "first one to copy to destination: {}".format(
                             source_dir_pb_files[0]))
            # copy the frozen model present in the source directory
            LOG.info("Copying the frozen checkpoint from {} to {}.".format(
                os.path.join(source_dir, source_dir_pb_files[0]),
                os.path.join(dest_dir, "model.pb")))
            shutil.copy(os.path.join(source_dir, source_dir_pb_files[0]),
                        os.path.join(dest_dir, "model.pb"))
        else:
            # Delete the current .pb files in the destination directory
            for filename in dest_dir_pb_files:
                os.remove(os.path.join(dest_dir, filename))
            # Copy the frozen model for the current best checkpoint to the destination directory
            LOG.info("Copying the frozen checkpoint from {} to {}.".format(
                os.path.join(source_dir, best_model_name),
                os.path.join(dest_dir, "model.pb")))
            shutil.copy(os.path.join(source_dir, best_model_name),
                        os.path.join(dest_dir, "model.pb"))
        # Loop through the current list of frozen models in source directory and
        # delete the iterations lower than last_checkpoint_iteration except best_model
        for filename in source_dir_pb_files:
            if filename not in [best_model_name, last_model_name]:
                # NOTE(review): len(...split(".pb")) is always >= 1, so the
                # second clause is always truthy; presumably it was meant to
                # verify the "model_<N>.pb" format -- confirm intent.
                if len(filename.split("_")[1]) > 1 and len(
                        filename.split("_")[1].split(".pb")):
                    file_iteration = int(
                        filename.split("_")[1].split(".pb")[0])
                    if file_iteration < last_checkpoint_number:
                        os.remove(os.path.join(source_dir, filename))
                else:
                    LOG.error(
                        "Frozen model name not in the right format in the source directory: {}, {}"
                        .format(filename, source_dir))
    except FileNotFoundError as err:
        log_and_exit("No such file or directory: {}".format(err),
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
def rename(self, coach_checkpoint_state_file, agent_name):
    '''rename the tensorflow model specified in the rl coach checkpoint state
    file to include agent name

    Loads every variable of the checkpoint in a TF1 session, rewrites the
    "agent/" (or "agent_#/") prefix of each variable name to "{agent_name}/",
    saves the renamed checkpoint into a temp folder, then swaps it into the
    local checkpoint directory.

    Args:
        coach_checkpoint_state_file (CheckpointStateFile): CheckpointStateFile instance
        agent_name (str): agent name
    '''
    try:
        LOG.info(
            "Renaming checkpoint from checkpoint_dir: {} for agent: {}".
            format(self._local_dir, agent_name))
        checkpoint_name = str(coach_checkpoint_state_file.read())
        tf_checkpoint_file = os.path.join(self._local_dir, "checkpoint")
        # write a tensorflow 'checkpoint' index file pointing at the model
        # so tf.contrib can locate the variables
        with open(tf_checkpoint_file, "w") as outfile:
            outfile.write(
                "model_checkpoint_path: \"{}\"".format(checkpoint_name))

        with tf.Session() as sess:
            for var_name, _ in tf.contrib.framework.list_variables(
                    self._local_dir):
                # Load the variable
                var = tf.contrib.framework.load_variable(
                    self._local_dir, var_name)
                new_name = var_name
                # Set the new name
                # Replace agent/ or agent_#/ with {agent_name}/
                new_name = re.sub('agent/|agent_\d+/',
                                  '{}/'.format(agent_name), new_name)
                # Rename the variable
                var = tf.Variable(var, name=new_name)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            renamed_checkpoint_path = os.path.join(TEMP_RENAME_FOLDER,
                                                   checkpoint_name)
            LOG.info('Saving updated checkpoint to {}'.format(
                renamed_checkpoint_path))
            saver.save(sess, renamed_checkpoint_path)
        # Remove the tensorflow 'checkpoint' file
        os.remove(tf_checkpoint_file)
        # Remove the old checkpoint from the checkpoint dir
        for file_name in os.listdir(self._local_dir):
            if checkpoint_name in file_name:
                os.remove(os.path.join(self._local_dir, file_name))
        # Copy the new checkpoint with renamed variable to the checkpoint dir
        for file_name in os.listdir(TEMP_RENAME_FOLDER):
            full_file_name = os.path.join(
                os.path.abspath(TEMP_RENAME_FOLDER), file_name)
            if os.path.isfile(
                    full_file_name) and file_name != "checkpoint":
                shutil.copy(full_file_name, self._local_dir)
        # Remove files from temp_rename_folder
        shutil.rmtree(TEMP_RENAME_FOLDER)
        # reset the default graph so renamed variables do not leak into
        # later sessions
        tf.reset_default_graph()
    # If either of the checkpoint files (index, meta or data) not found
    except tf.errors.NotFoundError as err:
        log_and_exit("No checkpoint found: {}".format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    # Thrown when user modifies model, checkpoints get corrupted/truncated
    except tf.errors.DataLossError as err:
        log_and_exit(
            "User modified ckpt, unrecoverable dataloss or corruption: {}".
            format(err),
            SIMAPP_SIMULATION_WORKER_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_400)
    except tf.errors.OutOfRangeError as err:
        log_and_exit("User modified ckpt: {}".format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit(
                "Couldn't find 'checkpoint' file or checkpoints in given \
                directory ./checkpoint: {}".format(err),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("ValueError in rename checkpoint: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("Exception in rename checkpoint: {}".format(ex),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout, training_algorithm):
    """Run the rl coach training loop until improve_steps is reached.

    Fetches rollout data from the rollout workers, sizes the per-network
    batch size from the collected rollout steps, trains, checkpoints, and
    handles SIGTERM / NaN losses by exiting through log_and_exit.

    Args:
        graph_manager: rl coach graph manager driving training
        task_parameters: rl coach TaskParameters for graph creation
        user_batch_size (int): batch size requested in the hyperparameters
        user_episode_per_rollout (int): episodes per rollout from the hyperparameters
        training_algorithm (str): TrainingAlgorithm value (e.g. SAC or CLIPPED_PPO)
    """
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)
        # save initial checkpoint
        graph_manager.save_checkpoint()
        # training loop
        steps = 0
        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()
        # To handle SIGTERM
        door_man = utils.DoorMan()
        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only IS_PROFILER_ON is true
            with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET,
                                s3_prefix=PROFILER_S3_PREFIX,
                                output_local_path=TRAINING_WORKER_PROFILER_PATH,
                                enable_profiling=IS_PROFILER_ON):
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(
                    graph_manager.agent_params.algorithm.
                    num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED
                episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout(
                )
                # size the playing/copy intervals to the episodes actually collected
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout
                        # TODO: Refactor the flow to remove conditional checks for specific algorithms
                        # ------------------------sac only---------------------------------------------
                        if training_algorithm == TrainingAlgorithm.SAC.value:
                            rollout_steps = graph_manager.memory_backend.get_rollout_steps(
                            )
                            # NOTE: you can train even more iterations than rollout_steps by increasing the number below for SAC
                            agent.ap.algorithm.num_consecutive_training_steps = list(
                                rollout_steps.values())[0]  # rollout_steps[agent]
                        # -------------------------------------------------------------------------------
                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps(
                    )
                    # NOTE(review): any() returns a bool, so "any(...) <= 0" is
                    # True only when every rollout step count is falsy (0);
                    # likely intended "if not any(rollout_steps.values()):" -- confirm.
                    if any(rollout_steps.values()) <= 0:
                        log_and_exit(
                            "No rollout data retrieved from the rollout worker",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)
                    # TODO: Refactor the flow to remove conditional checks for specific algorithms
                    # DH: for SAC, check if experience replay memory has enough transitions
                    logger.info("setting trainig algorithm")
                    if training_algorithm == TrainingAlgorithm.SAC.value:
                        replay_mem_size = min([
                            agent.memory.num_transitions()
                            for level in graph_manager.level_managers
                            for agent in level.agents.values()
                        ])
                        episode_batch_size = user_batch_size if replay_mem_size > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    else:
                        logger.info("it is CPPO")
                        episode_batch_size = user_batch_size if min(
                            rollout_steps.values(
                            )) > user_batch_size else 2**math.floor(
                                math.log(min(rollout_steps.values()), 2))
                    # Set the batch size to the closest power of 2 such that we have at least two batches, this prevents coach from crashing
                    # as batch size less than 2 causes the batch list to become a scalar which causes an exception
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            for net_key in agent.ap.network_wrappers:
                                agent.ap.network_wrappers[
                                    net_key].batch_size = episode_batch_size
                    steps += 1
                    graph_manager.phase = core_types.RunPhase.TRAIN
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED
                    # Check for Nan's in all agents
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                    if rollout_has_nan:
                        log_and_exit(
                            "NaN detected in loss function, aborting training.",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)
                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()
                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()
                # restore the user-configured rollout sizing for the next iteration
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout
                if door_man.terminate_now:
                    log_and_exit(
                        "Received SIGTERM. Checkpointing before exiting.",
                        SIMAPP_TRAINING_WORKER_EXCEPTION,
                        SIMAPP_EVENT_ERROR_CODE_500)
                    graph_manager.save_checkpoint()
                    break
    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit("User modified model: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        else:
            log_and_exit("An error occured while training: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occured while training: {}".format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        # always signal completion so downstream consumers are unblocked
        graph_manager.data_store.upload_finished_file()
checkpoint_dict=checkpoint_dict) graph_manager.data_store_params = ds_params_instance graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager) task_parameters = TaskParameters() task_parameters.experiment_path = SM_MODEL_OUTPUT_DIR task_parameters.checkpoint_save_secs = 20 if use_pretrained_model: task_parameters.checkpoint_restore_path = args.pretrained_checkpoint_dir task_parameters.checkpoint_save_dir = args.checkpoint_dir training_worker( graph_manager=graph_manager, task_parameters=task_parameters, user_batch_size=json.loads(robomaker_hyperparams_json)["batch_size"], user_episode_per_rollout=json.loads( robomaker_hyperparams_json)["num_episodes_between_training"], training_algorithm=training_algorithm) if __name__ == '__main__': try: main() except Exception as ex: log_and_exit("Training worker exited with exception: {}".format(ex), SIMAPP_TRAINING_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500)
def main():
    """Entry point for the rollout worker.

    Parses CLI/ROS-param configuration, downloads the model metadata and the
    customer reward function from S3, builds the (single) rollout agent plus
    obstacle/bot-car agents, wires up metrics / simtrace / mp4 writers, then
    hands control to ``rollout_worker`` connected to the trainer via Redis.

    Exits the sim app through ``log_and_exit`` on unrecoverable setup errors.
    """
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint_dir',
                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=rospy.get_param("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default=rospy.get_param("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"))
    parser.add_argument('--num_workers',
                        help="(int) The number of workers started in this pool",
                        type=int,
                        default=int(rospy.get_param("NUM_WORKERS", 1)))
    parser.add_argument('--rollout_idx',
                        help="(int) The index of current rollout worker",
                        type=int,
                        default=0)
    parser.add_argument('-r', '--redis_ip',
                        help="(string) IP or host for the redis server",
                        default='localhost',
                        type=str)
    parser.add_argument('-rp', '--redis_port',
                        help="(int) Port of the redis server",
                        default=6379,
                        type=int)
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=rospy.get_param("AWS_REGION", "us-east-1"))
    parser.add_argument('--reward_file_s3_key',
                        help='(string) Reward File S3 Key',
                        type=str,
                        default=rospy.get_param("REWARD_FILE_S3_KEY", None))
    parser.add_argument('--model_metadata_s3_key',
                        help='(string) Model Metadata File S3 Key',
                        type=str,
                        default=rospy.get_param("MODEL_METADATA_FILE_S3_KEY", None))
    # For training job, reset is not allowed. penalty_seconds, off_track_penalty, and
    # collision_penalty will all be 0 by default
    parser.add_argument('--number_of_resets',
                        help='(integer) Number of resets',
                        type=int,
                        default=int(rospy.get_param("NUMBER_OF_RESETS", 0)))
    parser.add_argument('--penalty_seconds',
                        help='(float) penalty second',
                        type=float,
                        default=float(rospy.get_param("PENALTY_SECONDS", 0.0)))
    parser.add_argument('--job_type',
                        help='(string) job type',
                        type=str,
                        default=rospy.get_param("JOB_TYPE", "TRAINING"))
    # BUG FIX: type=bool treats any non-empty CLI string (even "False") as True;
    # utils.str2bool parses the string the same way the ROS-param default does.
    parser.add_argument('--is_continuous',
                        help='(boolean) is continous after lap completion',
                        type=utils.str2bool,
                        default=utils.str2bool(rospy.get_param("IS_CONTINUOUS", False)))
    parser.add_argument('--race_type',
                        help='(string) Race type',
                        type=str,
                        default=rospy.get_param("RACE_TYPE", "TIME_TRIAL"))
    parser.add_argument('--off_track_penalty',
                        help='(float) off track penalty second',
                        type=float,
                        default=float(rospy.get_param("OFF_TRACK_PENALTY", 0.0)))
    parser.add_argument('--collision_penalty',
                        help='(float) collision penalty second',
                        type=float,
                        default=float(rospy.get_param("COLLISION_PENALTY", 0.0)))

    args = parser.parse_args()

    s3_client = SageS3Client(bucket=args.s3_bucket,
                             s3_prefix=args.s3_prefix,
                             aws_region=args.aws_region)
    logger.info("S3 bucket: %s", args.s3_bucket)
    logger.info("S3 prefix: %s", args.s3_prefix)

    # Load the model metadata
    model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH, 'model_metadata.json')
    utils.load_model_metadata(s3_client, args.model_metadata_s3_key,
                              model_metadata_local_path)

    # Download and import reward function
    if not args.reward_file_s3_key:
        log_and_exit("Reward function code S3 key not available for S3 bucket {} and prefix {}"
                     .format(args.s3_bucket, args.s3_prefix),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    download_customer_reward_function(s3_client, args.reward_file_s3_key)
    try:
        # Import is local on purpose: the module only exists after the download above.
        from custom_files.customer_reward_function import reward_function
    except Exception as e:
        log_and_exit("Failed to import user's reward_function: {}".format(e),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)

    # Instantiate Cameras
    configure_camera(namespaces=['racecar'])

    preset_file_success, _ = download_custom_files_if_present(s3_client, args.s3_prefix)

    #! TODO each agent should have own config
    _, _, version = utils_parse_model_metadata.parse_model_metadata(model_metadata_local_path)
    agent_config = {
        'model_metadata': model_metadata_local_path,
        ConfigParams.CAR_CTRL_CONFIG.value: {
            ConfigParams.LINK_NAME_LIST.value: LINK_NAMES,
            ConfigParams.VELOCITY_LIST.value: VELOCITY_TOPICS,
            ConfigParams.STEERING_LIST.value: STEERING_TOPICS,
            ConfigParams.CHANGE_START.value:
                utils.str2bool(rospy.get_param('CHANGE_START_POSITION', True)),
            ConfigParams.ALT_DIR.value:
                utils.str2bool(rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)),
            ConfigParams.ACTION_SPACE_PATH.value: 'custom_files/model_metadata.json',
            ConfigParams.REWARD.value: reward_function,
            ConfigParams.AGENT_NAME.value: 'racecar',
            ConfigParams.VERSION.value: version,
            ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets,
            ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds,
            ConfigParams.NUMBER_OF_TRIALS.value: None,
            ConfigParams.IS_CONTINUOUS.value: args.is_continuous,
            ConfigParams.RACE_TYPE.value: args.race_type,
            ConfigParams.COLLISION_PENALTY.value: args.collision_penalty,
            ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty
        }
    }

    #! TODO each agent should have own s3 bucket
    step_metrics_prefix = rospy.get_param('SAGEMAKER_SHARED_S3_PREFIX')
    if args.num_workers > 1:
        step_metrics_prefix = os.path.join(step_metrics_prefix, str(args.rollout_idx))
    metrics_s3_config = {
        MetricsS3Keys.METRICS_BUCKET.value: rospy.get_param('METRICS_S3_BUCKET'),
        MetricsS3Keys.METRICS_KEY.value: rospy.get_param('METRICS_S3_OBJECT_KEY'),
        MetricsS3Keys.REGION.value: rospy.get_param('AWS_REGION')
    }
    metrics_s3_model_cfg = {
        MetricsS3Keys.METRICS_BUCKET.value: args.s3_bucket,
        MetricsS3Keys.METRICS_KEY.value: os.path.join(args.s3_prefix,
                                                      DEEPRACER_CHKPNT_KEY_SUFFIX),
        MetricsS3Keys.REGION.value: args.aws_region
    }
    run_phase_subject = RunPhaseSubject()
    agent_list = list()
    agent_list.append(
        create_rollout_agent(
            agent_config,
            TrainingMetrics(agent_name='agent',
                            s3_dict_metrics=metrics_s3_config,
                            s3_dict_model=metrics_s3_model_cfg,
                            ckpnt_dir=args.checkpoint_dir,
                            run_phase_sink=run_phase_subject,
                            # Only worker 0 picks/publishes the best model.
                            use_model_picker=(args.rollout_idx == 0)),
            run_phase_subject))
    agent_list.append(create_obstacles_agent())
    agent_list.append(create_bot_cars_agent())

    # ROS service to indicate all the robomaker markov packages are ready for consumption
    signal_robomaker_markov_package_ready()

    PhaseObserver('/agent/training_phase', run_phase_subject)

    aws_region = rospy.get_param('AWS_REGION', args.aws_region)
    simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None)
    # Only worker 0 uploads mp4 videos.
    mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None) if args.rollout_idx == 0 else None
    if simtrace_s3_bucket:
        simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX')
        if args.num_workers > 1:
            simtrace_s3_object_prefix = os.path.join(simtrace_s3_object_prefix,
                                                     str(args.rollout_idx))
    if mp4_s3_bucket:
        mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX')

    s3_writer_job_info = []
    if simtrace_s3_bucket:
        s3_writer_job_info.append(
            IterationData(
                'simtrace', simtrace_s3_bucket, simtrace_s3_object_prefix, aws_region,
                os.path.join(
                    ITERATION_DATA_LOCAL_FILE_PATH, 'agent',
                    IterationDataLocalFileNames.SIM_TRACE_TRAINING_LOCAL_FILE.value)))
    if mp4_s3_bucket:
        s3_writer_job_info.extend([
            IterationData(
                'pip', mp4_s3_bucket, mp4_s3_object_prefix, aws_region,
                os.path.join(
                    ITERATION_DATA_LOCAL_FILE_PATH, 'agent',
                    IterationDataLocalFileNames.CAMERA_PIP_MP4_VALIDATION_LOCAL_PATH.value)),
            IterationData(
                '45degree', mp4_s3_bucket, mp4_s3_object_prefix, aws_region,
                os.path.join(
                    ITERATION_DATA_LOCAL_FILE_PATH, 'agent',
                    IterationDataLocalFileNames.CAMERA_45DEGREE_MP4_VALIDATION_LOCAL_PATH.value)),
            IterationData(
                'topview', mp4_s3_bucket, mp4_s3_object_prefix, aws_region,
                os.path.join(
                    ITERATION_DATA_LOCAL_FILE_PATH, 'agent',
                    IterationDataLocalFileNames.CAMERA_TOPVIEW_MP4_VALIDATION_LOCAL_PATH.value))
        ])
    s3_writer = S3Writer(job_info=s3_writer_job_info)

    redis_ip = s3_client.get_ip()
    logger.info("Received IP from SageMaker successfully: %s", redis_ip)

    # Download hyperparameters from SageMaker
    hyperparams_s3_key = os.path.normpath(args.s3_prefix + "/ip/hyperparameters.json")
    hyperparameters_file_success = s3_client.download_file(
        s3_key=hyperparams_s3_key, local_path="hyperparameters.json")
    sm_hyperparams_dict = {}
    if hyperparameters_file_success:
        logger.info("Received Sagemaker hyperparameters successfully!")
        with open("hyperparameters.json") as filepointer:
            sm_hyperparams_dict = json.load(filepointer)
    else:
        logger.info("SageMaker hyperparameters not found.")

    enable_domain_randomization = utils.str2bool(
        rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False))
    if preset_file_success:
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        graph_manager, _ = get_graph_manager(
            hp_dict=sm_hyperparams_dict,
            agent_list=agent_list,
            run_phase_subject=run_phase_subject,
            enable_domain_randomization=enable_domain_randomization)

    # If num_episodes_between_training is smaller than num_workers then cancel worker early.
    episode_steps_per_rollout = \
        graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    # Reduce number of workers if allocated more than num_episodes_between_training
    if args.num_workers > episode_steps_per_rollout:
        logger.info("Excess worker allocated. Reducing from {} to {}...".format(
            args.num_workers, episode_steps_per_rollout))
        args.num_workers = episode_steps_per_rollout
    if args.rollout_idx >= episode_steps_per_rollout or args.rollout_idx >= args.num_workers:
        err_msg_format = "Exiting excess worker..."
        err_msg_format += "(rollout_idx[{}] >= num_workers[{}] or num_episodes_between_training[{}])"
        logger.info(err_msg_format.format(args.rollout_idx, args.num_workers,
                                          episode_steps_per_rollout))
        # Close the down the job
        utils.cancel_simulation_job(
            os.environ.get('AWS_ROBOMAKER_SIMULATION_JOB_ARN'),
            rospy.get_param('AWS_REGION'))

    memory_backend_params = DeepRacerRedisPubSubMemoryBackendParameters(
        redis_address=redis_ip,
        redis_port=6379,
        run_type=str(RunType.ROLLOUT_WORKER),
        channel=args.s3_prefix,
        num_workers=args.num_workers,
        rollout_idx=args.rollout_idx)

    graph_manager.memory_backend_params = memory_backend_params

    ds_params_instance = S3BotoDataStoreParameters(
        aws_region=args.aws_region,
        bucket_names={'agent': args.s3_bucket},
        base_checkpoint_dir=args.checkpoint_dir,
        s3_folders={'agent': args.s3_prefix})

    graph_manager.data_store = S3BotoDataStore(ds_params_instance, graph_manager)

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = args.checkpoint_dir

    rollout_worker(graph_manager=graph_manager,
                   num_workers=args.num_workers,
                   rollout_idx=args.rollout_idx,
                   task_parameters=task_parameters,
                   s3_writer=s3_writer)
task_parameters.checkpoint_restore_path = args.checkpoint_dir rollout_worker(graph_manager=graph_manager, num_workers=args.num_workers, rollout_idx=args.rollout_idx, task_parameters=task_parameters, s3_writer=s3_writer) if __name__ == '__main__': try: rospy.init_node('rl_coach', anonymous=True) main() except ValueError as err: if utils.is_error_bad_ckpnt(err): log_and_exit("User modified model: {}".format(err), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_400) else: log_and_exit("Rollout worker value error: {}".format(err), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) except GenericRolloutError as ex: ex.log_except_and_exit() except GenericRolloutException as ex: ex.log_except_and_exit() except Exception as ex: log_and_exit("Rollout worker exited with exception: {}".format(ex), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500)
def main(): """ Main function for tournament worker """ parser = argparse.ArgumentParser() parser.add_argument('-p', '--preset', help="(string) Name of a preset to run \ (class name from the 'presets' directory.)", type=str, required=False) parser.add_argument('--s3_bucket', help='list(string) S3 bucket', type=str, nargs='+', default=rospy.get_param("MODEL_S3_BUCKET", ["gsaur-test"])) parser.add_argument('--s3_prefix', help='list(string) S3 prefix', type=str, nargs='+', default=rospy.get_param("MODEL_S3_PREFIX", ["sagemaker"])) parser.add_argument('--aws_region', help='(string) AWS region', type=str, default=rospy.get_param("AWS_REGION", "us-east-1")) parser.add_argument('--number_of_trials', help='(integer) Number of trials', type=int, default=int(rospy.get_param("NUMBER_OF_TRIALS", 10))) parser.add_argument( '-c', '--local_model_directory', help='(string) Path to a folder containing a checkpoint \ to restore the model from.', type=str, default='./checkpoint') parser.add_argument('--number_of_resets', help='(integer) Number of resets', type=int, default=int(rospy.get_param("NUMBER_OF_RESETS", 0))) parser.add_argument('--penalty_seconds', help='(float) penalty second', type=float, default=float(rospy.get_param("PENALTY_SECONDS", 2.0))) parser.add_argument('--job_type', help='(string) job type', type=str, default=rospy.get_param("JOB_TYPE", "EVALUATION")) parser.add_argument('--is_continuous', help='(boolean) is continous after lap completion', type=bool, default=utils.str2bool( rospy.get_param("IS_CONTINUOUS", False))) parser.add_argument('--race_type', help='(string) Race type', type=str, default=rospy.get_param("RACE_TYPE", "TIME_TRIAL")) parser.add_argument('--off_track_penalty', help='(float) off track penalty second', type=float, default=float(rospy.get_param("OFF_TRACK_PENALTY", 2.0))) parser.add_argument('--collision_penalty', help='(float) collision penalty second', type=float, default=float(rospy.get_param("COLLISION_PENALTY", 5.0))) args = 
parser.parse_args() arg_s3_bucket = args.s3_bucket arg_s3_prefix = args.s3_prefix logger.info("S3 bucket: %s \n S3 prefix: %s", arg_s3_bucket, arg_s3_prefix) # tournament_worker: names to be displayed in MP4. # This is racer alias in tournament worker case. display_names = utils.get_video_display_name() metrics_s3_buckets = rospy.get_param('METRICS_S3_BUCKET') metrics_s3_object_keys = rospy.get_param('METRICS_S3_OBJECT_KEY') arg_s3_bucket, arg_s3_prefix = utils.force_list( arg_s3_bucket), utils.force_list(arg_s3_prefix) metrics_s3_buckets = utils.force_list(metrics_s3_buckets) metrics_s3_object_keys = utils.force_list(metrics_s3_object_keys) validate_list = [ arg_s3_bucket, arg_s3_prefix, metrics_s3_buckets, metrics_s3_object_keys ] simtrace_s3_bucket = rospy.get_param('SIMTRACE_S3_BUCKET', None) mp4_s3_bucket = rospy.get_param('MP4_S3_BUCKET', None) if simtrace_s3_bucket: simtrace_s3_object_prefix = rospy.get_param('SIMTRACE_S3_PREFIX') simtrace_s3_bucket = utils.force_list(simtrace_s3_bucket) simtrace_s3_object_prefix = utils.force_list(simtrace_s3_object_prefix) validate_list.extend([simtrace_s3_bucket, simtrace_s3_object_prefix]) if mp4_s3_bucket: mp4_s3_object_prefix = rospy.get_param('MP4_S3_OBJECT_PREFIX') mp4_s3_bucket = utils.force_list(mp4_s3_bucket) mp4_s3_object_prefix = utils.force_list(mp4_s3_object_prefix) validate_list.extend([mp4_s3_bucket, mp4_s3_object_prefix]) if not all([lambda x: len(x) == len(validate_list[0]), validate_list]): log_and_exit( "Tournament worker error: Incorrect arguments passed: {}".format( validate_list), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) if args.number_of_resets != 0 and args.number_of_resets < MIN_RESET_COUNT: raise GenericRolloutException( "number of resets is less than {}".format(MIN_RESET_COUNT)) # Instantiate Cameras if len(arg_s3_bucket) == 1: configure_camera(namespaces=['racecar']) else: configure_camera(namespaces=[ 'racecar_{}'.format(str(agent_index)) for agent_index in 
range(len(arg_s3_bucket)) ]) agent_list = list() s3_bucket_dict = dict() s3_prefix_dict = dict() s3_writers = list() start_positions = [ START_POS_OFFSET * idx for idx in reversed(range(len(arg_s3_bucket))) ] done_condition = utils.str_to_done_condition( rospy.get_param("DONE_CONDITION", any)) park_positions = utils.pos_2d_str_to_list( rospy.get_param("PARK_POSITIONS", [])) # if not pass in park positions for all done condition case, use default if not park_positions: park_positions = [DEFAULT_PARK_POSITION for _ in arg_s3_bucket] # tournament_worker: list of required S3 locations simtrace_s3_bucket_dict = dict() simtrace_s3_prefix_dict = dict() metrics_s3_bucket_dict = dict() metrics_s3_obect_key_dict = dict() mp4_s3_bucket_dict = dict() mp4_s3_object_prefix_dict = dict() for agent_index, s3_bucket_val in enumerate(arg_s3_bucket): agent_name = 'agent' if len(arg_s3_bucket) == 1 else 'agent_{}'.format( str(agent_index)) racecar_name = 'racecar' if len( arg_s3_bucket) == 1 else 'racecar_{}'.format(str(agent_index)) s3_bucket_dict[agent_name] = arg_s3_bucket[agent_index] s3_prefix_dict[agent_name] = arg_s3_prefix[agent_index] # tournament_worker: remap key with agent_name instead of agent_index for list of S3 locations. 
simtrace_s3_bucket_dict[agent_name] = simtrace_s3_bucket[agent_index] simtrace_s3_prefix_dict[agent_name] = simtrace_s3_object_prefix[ agent_index] metrics_s3_bucket_dict[agent_name] = metrics_s3_buckets[agent_index] metrics_s3_obect_key_dict[agent_name] = metrics_s3_object_keys[ agent_index] mp4_s3_bucket_dict[agent_name] = mp4_s3_bucket[agent_index] mp4_s3_object_prefix_dict[agent_name] = mp4_s3_object_prefix[ agent_index] s3_client = SageS3Client(bucket=arg_s3_bucket[agent_index], s3_prefix=arg_s3_prefix[agent_index], aws_region=args.aws_region) # Load the model metadata if not os.path.exists(os.path.join(CUSTOM_FILES_PATH, agent_name)): os.makedirs(os.path.join(CUSTOM_FILES_PATH, agent_name)) model_metadata_local_path = os.path.join( os.path.join(CUSTOM_FILES_PATH, agent_name), 'model_metadata.json') utils.load_model_metadata( s3_client, os.path.normpath("%s/model/model_metadata.json" % arg_s3_prefix[agent_index]), model_metadata_local_path) # Handle backward compatibility _, _, version = parse_model_metadata(model_metadata_local_path) if float(version) < float(SIMAPP_VERSION) and \ not utils.has_current_ckpnt_name(arg_s3_bucket[agent_index], arg_s3_prefix[agent_index], args.aws_region): utils.make_compatible(arg_s3_bucket[agent_index], arg_s3_prefix[agent_index], args.aws_region, SyncFiles.TRAINER_READY.value) # Select the optimal model utils.do_model_selection(s3_bucket=arg_s3_bucket[agent_index], s3_prefix=arg_s3_prefix[agent_index], region=args.aws_region) # Download hyperparameters from SageMaker if not os.path.exists(agent_name): os.makedirs(agent_name) hyperparameters_file_success = False hyperparams_s3_key = os.path.normpath(arg_s3_prefix[agent_index] + "/ip/hyperparameters.json") hyperparameters_file_success = s3_client.download_file( s3_key=hyperparams_s3_key, local_path=os.path.join(agent_name, "hyperparameters.json")) sm_hyperparams_dict = {} if hyperparameters_file_success: logger.info("Received Sagemaker hyperparameters successfully!") with 
open(os.path.join(agent_name, "hyperparameters.json")) as file: sm_hyperparams_dict = json.load(file) else: logger.info("SageMaker hyperparameters not found.") agent_config = { 'model_metadata': model_metadata_local_path, ConfigParams.CAR_CTRL_CONFIG.value: { ConfigParams.LINK_NAME_LIST.value: [ link_name.replace('racecar', racecar_name) for link_name in LINK_NAMES ], ConfigParams.VELOCITY_LIST.value: [ velocity_topic.replace('racecar', racecar_name) for velocity_topic in VELOCITY_TOPICS ], ConfigParams.STEERING_LIST.value: [ steering_topic.replace('racecar', racecar_name) for steering_topic in STEERING_TOPICS ], ConfigParams.CHANGE_START.value: utils.str2bool(rospy.get_param('CHANGE_START_POSITION', False)), ConfigParams.ALT_DIR.value: utils.str2bool( rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)), ConfigParams.ACTION_SPACE_PATH.value: 'custom_files/' + agent_name + '/model_metadata.json', ConfigParams.REWARD.value: reward_function, ConfigParams.AGENT_NAME.value: racecar_name, ConfigParams.VERSION.value: version, ConfigParams.NUMBER_OF_RESETS.value: args.number_of_resets, ConfigParams.PENALTY_SECONDS.value: args.penalty_seconds, ConfigParams.NUMBER_OF_TRIALS.value: args.number_of_trials, ConfigParams.IS_CONTINUOUS.value: args.is_continuous, ConfigParams.RACE_TYPE.value: args.race_type, ConfigParams.COLLISION_PENALTY.value: args.collision_penalty, ConfigParams.OFF_TRACK_PENALTY.value: args.off_track_penalty, ConfigParams.START_POSITION.value: start_positions[agent_index], ConfigParams.DONE_CONDITION.value: done_condition } } metrics_s3_config = { MetricsS3Keys.METRICS_BUCKET.value: metrics_s3_buckets[agent_index], MetricsS3Keys.METRICS_KEY.value: metrics_s3_object_keys[agent_index], # Replaced rospy.get_param('AWS_REGION') to be equal to the argument being passed # or default argument set MetricsS3Keys.REGION.value: args.aws_region } aws_region = rospy.get_param('AWS_REGION', args.aws_region) s3_writer_job_info = [] if simtrace_s3_bucket: 
s3_writer_job_info.append( IterationData( 'simtrace', simtrace_s3_bucket[agent_index], simtrace_s3_object_prefix[agent_index], aws_region, os.path.join( ITERATION_DATA_LOCAL_FILE_PATH, agent_name, IterationDataLocalFileNames. SIM_TRACE_EVALUATION_LOCAL_FILE.value))) if mp4_s3_bucket: s3_writer_job_info.extend([ IterationData( 'pip', mp4_s3_bucket[agent_index], mp4_s3_object_prefix[agent_index], aws_region, os.path.join( ITERATION_DATA_LOCAL_FILE_PATH, agent_name, IterationDataLocalFileNames. CAMERA_PIP_MP4_VALIDATION_LOCAL_PATH.value)), IterationData( '45degree', mp4_s3_bucket[agent_index], mp4_s3_object_prefix[agent_index], aws_region, os.path.join( ITERATION_DATA_LOCAL_FILE_PATH, agent_name, IterationDataLocalFileNames. CAMERA_45DEGREE_MP4_VALIDATION_LOCAL_PATH.value)), IterationData( 'topview', mp4_s3_bucket[agent_index], mp4_s3_object_prefix[agent_index], aws_region, os.path.join( ITERATION_DATA_LOCAL_FILE_PATH, agent_name, IterationDataLocalFileNames. CAMERA_TOPVIEW_MP4_VALIDATION_LOCAL_PATH.value)) ]) s3_writers.append(S3Writer(job_info=s3_writer_job_info)) run_phase_subject = RunPhaseSubject() agent_list.append( create_rollout_agent( agent_config, EvalMetrics(agent_name, metrics_s3_config, args.is_continuous), run_phase_subject)) agent_list.append(create_obstacles_agent()) agent_list.append(create_bot_cars_agent()) # ROS service to indicate all the robomaker markov packages are ready for consumption signal_robomaker_markov_package_ready() PhaseObserver('/agent/training_phase', run_phase_subject) enable_domain_randomization = utils.str2bool( rospy.get_param('ENABLE_DOMAIN_RANDOMIZATION', False)) graph_manager, _ = get_graph_manager( hp_dict=sm_hyperparams_dict, agent_list=agent_list, run_phase_subject=run_phase_subject, enable_domain_randomization=enable_domain_randomization, done_condition=done_condition) ds_params_instance = S3BotoDataStoreParameters( aws_region=args.aws_region, bucket_names=s3_bucket_dict, base_checkpoint_dir=args.local_model_directory, 
s3_folders=s3_prefix_dict) graph_manager.data_store = S3BotoDataStore(params=ds_params_instance, graph_manager=graph_manager, ignore_lock=True) graph_manager.env_params.seed = 0 task_parameters = TaskParameters() task_parameters.checkpoint_restore_path = args.local_model_directory tournament_worker(graph_manager=graph_manager, number_of_trials=args.number_of_trials, task_parameters=task_parameters, s3_writers=s3_writers, is_continuous=args.is_continuous, park_positions=park_positions) # tournament_worker: write race report to local file. write_race_report(graph_manager, model_s3_bucket_map=s3_bucket_dict, model_s3_prefix_map=s3_prefix_dict, metrics_s3_bucket_map=metrics_s3_bucket_dict, metrics_s3_key_map=metrics_s3_obect_key_dict, simtrace_s3_bucket_map=simtrace_s3_bucket_dict, simtrace_s3_prefix_map=simtrace_s3_prefix_dict, mp4_s3_bucket_map=mp4_s3_bucket_dict, mp4_s3_prefix_map=mp4_s3_object_prefix_dict, display_names=display_names) # tournament_worker: terminate tournament_race_node. terminate_tournament_race()
simtrace_s3_prefix_map=simtrace_s3_prefix_dict, mp4_s3_bucket_map=mp4_s3_bucket_dict, mp4_s3_prefix_map=mp4_s3_object_prefix_dict, display_names=display_names) # tournament_worker: terminate tournament_race_node. terminate_tournament_race() if __name__ == '__main__': try: rospy.init_node('rl_coach', anonymous=True) main() except ValueError as err: if utils.is_error_bad_ckpnt(err): log_and_exit("User modified model: {}".format(err), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_400) else: log_and_exit("Tournament worker value error: {}".format(err), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500) except GenericRolloutError as ex: ex.log_except_and_exit() except GenericRolloutException as ex: ex.log_except_and_exit() except Exception as ex: log_and_exit("Tournament worker error: {}".format(ex), SIMAPP_SIMULATION_WORKER_EXCEPTION, SIMAPP_EVENT_ERROR_CODE_500)
def setup_race(self):
    """Set up the race for the current racer.

    Walks through six sequential steps: unpause physics, hide the previous
    car, reset the main camera to the start line, download model metadata,
    (re)spawn the racecar if its sensors or body shell changed, download the
    checkpoint and build the graph manager, then apply shell/color visuals.

    Returns:
        bool: True if setup race is successful.
              False if a non-fatal exception occurred (race is cleaned up and
              a failure status is uploaded); fatal errors exit the sim app.
    """
    LOG.info("[virtual event manager] Setting up race for racer")
    try:
        self._model_updater.unpause_physics()
        LOG.info(
            "[virtual event manager] Unpause physics in current world to setup race."
        )
        # step 1: hide the racecar to a position that camera cannot see
        self._hide_racecar_model(
            model_name=self._current_car_model_state.model_name)
        # step 2: set camera to starting position after previous car is deleted
        initial_pose = self._track_data.get_racecar_start_pose(
            racecar_idx=0,
            racer_num=1,
            start_position=get_start_positions(1)[0])
        self._main_cameras[VIRTUAL_EVENT].reset_pose(car_pose=initial_pose)
        LOG.info("[virtual event manager] Reset camera to starting line.")
        # step 3: download model metadata from s3
        sensors, version, model_metadata = self._download_model_metadata()
        # step 4: check whether body shell and sensors have been updated
        # to decide whether need to delete and re-spawn. Then, update
        # shell or color accordingly
        # Fall back to the default shell when the racer has no carConfig or
        # requests a shell not in the valid set.
        if hasattr(self._current_racer, "carConfig") and \
                hasattr(self._current_racer.carConfig, "bodyShellType"):
            body_shell_type = self._current_racer.carConfig.bodyShellType \
                if self._current_racer.carConfig.bodyShellType in self._valid_body_shells \
                else const.BodyShellType.DEFAULT.value
        else:
            body_shell_type = const.BodyShellType.DEFAULT.value
        # check whether need to delete and respawn racecar
        # re-spawn if sensor or body shell type changed
        if self._last_body_shell_type != body_shell_type or \
                self._last_sensors != sensors:
            # delete last racecar
            self._racecar_model.delete()
            # respawn a new racecar
            hide_pose = Pose()
            hide_pose.position.x = self._hide_positions[
                self._hide_position_idx][0]
            hide_pose.position.y = self._hide_positions[
                self._hide_position_idx][1]
            # NOTE(review): spawn arguments are passed as lowercase strings
            # ("true"/"false"), presumably for the underlying spawn service /
            # xacro args — confirm against the spawner interface.
            self._racecar_model.spawn(
                name=self._current_car_model_state.model_name,
                pose=hide_pose,
                include_second_camera="true"
                if Input.STEREO.value in sensors else "false",
                include_lidar_sensor=str(
                    any(["lidar" in sensor.lower()
                         for sensor in sensors])).lower(),
                body_shell_type=body_shell_type,
                lidar_360_degree_sample=str(LIDAR_360_DEGREE_SAMPLE),
                lidar_360_degree_horizontal_resolution=str(
                    LIDAR_360_DEGREE_HORIZONTAL_RESOLUTION),
                lidar_360_degree_min_angle=str(LIDAR_360_DEGREE_MIN_ANGLE),
                lidar_360_degree_max_angle=str(LIDAR_360_DEGREE_MAX_ANGLE),
                lidar_360_degree_min_range=str(LIDAR_360_DEGREE_MIN_RANGE),
                lidar_360_degree_max_range=str(LIDAR_360_DEGREE_MAX_RANGE),
                lidar_360_degree_range_resolution=str(
                    LIDAR_360_DEGREE_RANGE_RESOLUTION),
                lidar_360_degree_noise_mean=str(
                    LIDAR_360_DEGREE_NOISE_MEAN),
                lidar_360_degree_noise_stddev=str(
                    LIDAR_360_DEGREE_NOISE_STDDEV))
            # Remember the spawned configuration so the next racer can skip
            # the delete/re-spawn when nothing changed.
            self._last_body_shell_type = body_shell_type
            self._last_sensors = sensors
        # step 5: download checkpoint, setup simtrace, mp4, clear metrics, and setup graph manager
        # download checkpoint from s3
        checkpoint = self._download_checkpoint(version)
        # setup the simtrace and mp4 writers if the s3 locations are available
        self._setup_simtrace_mp4_writers()
        # reset the metrics s3 location for the current racer
        self._reset_metrics_loc()
        # setup agents
        agent_list = self._get_agent_list(model_metadata, version)
        # after _setup_graph_manager finishes, physics is paused
        # physics will be unpaused again when race start
        self._setup_graph_manager(checkpoint, agent_list)
        LOG.info(
            "[virtual event manager] Graph manager successfully created the graph: setup race successful."
        )
        # step 6: update body shell or color
        # treat amazon van digital reward specially by also hiding the collision wheel
        visuals = self._model_updater.get_model_visuals(
            self._current_car_model_state.model_name)
        if const.F1 in body_shell_type:
            self._model_updater.hide_visuals(
                visuals=visuals,
                ignore_keywords=["f1_body_link"]
                if "with_wheel" in body_shell_type.lower() else
                ["wheel", "f1_body_link"])
        else:
            # Non-F1 shells only get a color update; invalid/missing colors
            # fall back to DEFAULT_COLOR.
            if hasattr(self._current_racer, "carConfig") and \
                    hasattr(self._current_racer.carConfig, "carColor"):
                car_color = self._current_racer.carConfig.carColor \
                    if self._current_racer.carConfig.carColor in self._valid_car_colors \
                    else DEFAULT_COLOR
            else:
                car_color = DEFAULT_COLOR
            self._model_updater.update_color(visuals, car_color)
        return True
    except GenericNonFatalException as ex:
        # Non-fatal: report the failure and leave the manager alive for the
        # next racer.
        ex.log_except_and_continue()
        self.upload_race_status(status_code=ex.error_code,
                                error_name=ex.error_name,
                                error_details=ex.error_msg)
        self._clean_up_race()
        return False
    except Exception as ex:
        # Fatal: anything unexpected aborts the whole sim app.
        log_and_exit(
            "[virtual event manager] Something really wrong happened when setting up the race. {}"
            .format(ex), SIMAPP_VIRTUAL_EVENT_RACE_EXCEPTION,
            SIMAPP_EVENT_ERROR_CODE_500)
def log_except_and_exit(self):
    """Log this validation failure to CloudWatch and terminate the sim app."""
    failure_message = "Validation worker failed: {}".format(self.msg)
    log_and_exit(failure_message,
                 SIMAPP_VALIDATION_WORKER_EXCEPTION,
                 SIMAPP_EVENT_ERROR_CODE_500)
def validate(s3_bucket, s3_prefix, aws_region):
    """Validate a trained model by running it through a dry-run graph.

    Downloads model metadata and the checkpoint from S3, builds a training
    agent and graph manager around them, and delegates to ``_validate`` with
    transition data derived from the model's sensor list.

    Args:
        s3_bucket (str): S3 bucket holding the model.
        s3_prefix (str): S3 prefix of the training job.
        aws_region (str): AWS region the job ran in.
    """
    screen.set_use_colors(False)
    screen.log_title(" S3 bucket: {} \n S3 prefix: {}".format(
        s3_bucket, s3_prefix))

    # download model metadata
    model_metadata = ModelMetadata(bucket=s3_bucket,
                                   s3_key=get_s3_key(
                                       s3_prefix, MODEL_METADATA_S3_POSTFIX),
                                   region_name=aws_region,
                                   local_path=MODEL_METADATA_LOCAL_PATH)

    # Create model local path (idempotent so re-runs don't crash)
    os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)

    try:
        # Handle backward compatibility
        model_metadata_info = model_metadata.get_model_metadata_info()
        observation_list = model_metadata_info[ModelMetadataKeys.SENSOR.value]
        version = model_metadata_info[ModelMetadataKeys.VERSION.value]
    except Exception as ex:
        log_and_exit("Failed to parse model_metadata file: {}".format(ex),
                     SIMAPP_VALIDATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)

    # Below get_transition_data function must called before create_training_agent function
    # to avoid 500 in case unsupported Sensor is received.
    # create_training_agent will exit with 500 if unsupported sensor is received,
    # and get_transition_data function below will exit with 400 if unsupported sensor is received.
    # We want to return 400 in model validation case if unsupported sensor is received.
    # Thus, call this get_transition_data function before create_traning_agent function!
    transitions = get_transition_data(observation_list)

    # BUG FIX: was `region_name=args.aws_region` — `args` is not defined in
    # this function (NameError at runtime); use the aws_region parameter.
    checkpoint = Checkpoint(bucket=s3_bucket,
                            s3_prefix=s3_prefix,
                            region_name=aws_region,
                            agent_name='agent',
                            checkpoint_dir=LOCAL_MODEL_DIR)
    # make coach checkpoint compatible
    if version < SIMAPP_VERSION_2 and not checkpoint.rl_coach_checkpoint.is_compatible():
        checkpoint.rl_coach_checkpoint.make_compatible(checkpoint.syncfile_ready)
    # add checkpoint into checkpoint_dict
    checkpoint_dict = {'agent': checkpoint}

    agent_config = {
        'model_metadata': model_metadata,
        ConfigParams.CAR_CTRL_CONFIG.value: {
            ConfigParams.LINK_NAME_LIST.value: [],
            ConfigParams.VELOCITY_LIST.value: {},
            ConfigParams.STEERING_LIST.value: {},
            ConfigParams.CHANGE_START.value: None,
            ConfigParams.ALT_DIR.value: None,
            ConfigParams.MODEL_METADATA.value: model_metadata,
            ConfigParams.REWARD.value: None,
            ConfigParams.AGENT_NAME.value: 'racecar'
        }
    }

    agent_list = list()
    agent_list.append(create_training_agent(agent_config))

    sm_hyperparams_dict = {}
    graph_manager, _ = get_graph_manager(hp_dict=sm_hyperparams_dict,
                                         agent_list=agent_list,
                                         run_phase_subject=None)

    ds_params_instance = S3BotoDataStoreParameters(
        checkpoint_dict=checkpoint_dict)

    graph_manager.data_store = S3BotoDataStore(ds_params_instance,
                                               graph_manager,
                                               ignore_lock=True)

    task_parameters = TaskParameters()
    task_parameters.checkpoint_restore_path = LOCAL_MODEL_DIR

    _validate(graph_manager=graph_manager,
              task_parameters=task_parameters,
              transitions=transitions,
              s3_bucket=s3_bucket,
              s3_prefix=s3_prefix,
              aws_region=aws_region)
def main():
    """ Main function for tournament.

    Runs a single-elimination tournament: pops pairs of candidate models off a
    queue, races them via the tournament_race_node.py subprocess, appends the
    winner back to the queue, and persists queue/report state to S3 so the
    simulation job can be restarted between races. When one candidate remains,
    uploads the final JSON report and cancels the simulation job.
    """
    try:
        # parse argument — positional CLI args supplied by the job launcher
        s3_region = sys.argv[1]
        s3_bucket = sys.argv[2]
        s3_prefix = sys.argv[3]
        s3_yaml_name = sys.argv[4]

        # create boto3 session/client and download yaml/json file
        session = boto3.session.Session()
        s3_endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
        s3_client = S3Client(region_name=s3_region, s3_endpoint_url=s3_endpoint_url)

        # Intermediate tournament files (pickles let a restarted job resume mid-bracket)
        queue_pickle_name = 'tournament_candidate_queue.pkl'
        queue_pickle_s3_key = os.path.normpath(os.path.join(s3_prefix, queue_pickle_name))
        local_queue_pickle_path = os.path.abspath(os.path.join(os.getcwd(), queue_pickle_name))

        report_pickle_name = 'tournament_report.pkl'
        report_pickle_s3_key = os.path.normpath(os.path.join(s3_prefix, report_pickle_name))
        local_report_pickle_path = os.path.abspath(os.path.join(os.getcwd(), report_pickle_name))

        final_report_name = 'tournament_report.json'
        final_report_s3_key = os.path.normpath(os.path.join(s3_prefix, final_report_name))

        # Best-effort download of prior state; on the first run these objects
        # don't exist yet, so failures are deliberately ignored.
        try:
            s3_client.download_file(bucket=s3_bucket,
                                    s3_key=queue_pickle_s3_key,
                                    local_path=local_queue_pickle_path)
            s3_client.download_file(bucket=s3_bucket,
                                    s3_key=report_pickle_s3_key,
                                    local_path=local_report_pickle_path)
        except:
            pass

        # download yaml file
        yaml_file = YamlFile(
            agent_type=AgentType.TOURNAMENT.value,
            bucket=s3_bucket,
            s3_key=get_s3_key(s3_prefix, s3_yaml_name),
            region_name=s3_region,
            s3_endpoint_url=s3_endpoint_url,
            local_path=YAML_LOCAL_PATH_FORMAT.format(s3_yaml_name))
        yaml_dict = yaml_file.get_yaml_values()

        if os.path.exists(local_queue_pickle_path):
            # Resume: restore queue and report from a previous (restarted) run
            with open(local_queue_pickle_path, 'rb') as f:
                tournament_candidate_queue = pickle.load(f)
            with open(local_report_pickle_path, 'rb') as f:
                tournament_report = pickle.load(f)
            logger.info('tournament_candidate_queue loaded from existing file')
        else:
            # Fresh start: build one 12-field candidate tuple per model in the yaml
            logger.info('tournament_candidate_queue initialized')
            tournament_candidate_queue = deque()
            for agent_idx, _ in enumerate(
                    yaml_dict[YamlKey.MODEL_S3_BUCKET_YAML_KEY.value]):
                tournament_candidate_queue.append((
                    yaml_dict[YamlKey.MODEL_S3_BUCKET_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.MODEL_S3_PREFIX_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.MODEL_METADATA_FILE_S3_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.METRICS_S3_BUCKET_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.METRICS_S3_PREFIX_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.SIMTRACE_S3_BUCKET_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.SIMTRACE_S3_PREFIX_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.MP4_S3_BUCKET_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.MP4_S3_PREFIX_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.DISPLAY_NAME_YAML_KEY.value][agent_idx],
                    # TODO: Deprecate the DISPLAY_NAME and use only the RACER_NAME without if else check
                    "" if None in yaml_dict.get(YamlKey.RACER_NAME_YAML_KEY.value, [None]) \
                    else yaml_dict[YamlKey.RACER_NAME_YAML_KEY.value][agent_idx],
                    yaml_dict[YamlKey.BODY_SHELL_TYPE_YAML_KEY.value][agent_idx]
                ))
            tournament_report = {"race_results": []}

        # Number of completed races doubles as the next race's index
        race_idx = len(tournament_report["race_results"])
        while len(tournament_candidate_queue) > 1:
            car1 = tournament_candidate_queue.popleft()
            car2 = tournament_candidate_queue.popleft()
            # Unpack the 12-field candidate tuples (same order as built above)
            (car1_model_s3_bucket, car1_s3_prefix, car1_model_metadata,
             car1_metrics_bucket, car1_metrics_s3_key, car1_simtrace_bucket,
             car1_simtrace_prefix, car1_mp4_bucket, car1_mp4_prefix,
             car1_display_name, car1_racer_name, car1_body_shell_type) = car1
            (car2_model_s3_bucket, car2_s3_prefix, car2_model_metadata,
             car2_metrics_bucket, car2_metrics_s3_key, car2_simtrace_bucket,
             car2_simtrace_prefix, car2_mp4_bucket, car2_mp4_prefix,
             car2_display_name, car2_racer_name, car2_body_shell_type) = car2

            race_yaml_dict = generate_race_yaml(yaml_dict=yaml_dict, car1=car1,
                                                car2=car2, race_idx=race_idx)
            if s3_endpoint_url is not None:
                race_yaml_dict["S3_ENDPOINT_URL"] = s3_endpoint_url

            race_model_s3_buckets = [car1_model_s3_bucket, car2_model_s3_bucket]
            race_model_metadatas = [car1_model_metadata, car2_model_metadata]
            body_shell_types = [car1_body_shell_type, car2_body_shell_type]

            # List of directories created (removed after the race completes)
            dirs_to_delete = list()
            yaml_dir = os.path.abspath(os.path.join(os.getcwd(), str(race_idx)))
            os.makedirs(yaml_dir)
            dirs_to_delete.append(yaml_dir)
            race_yaml_path = os.path.abspath(
                os.path.join(yaml_dir, 'evaluation_params.yaml'))
            with open(race_yaml_path, 'w') as race_yaml_file:
                yaml.dump(race_yaml_dict, race_yaml_file)

            # List of racecar names that should include second camera while launching
            racecars_with_stereo_cameras = list()
            # List of racecar names that should include lidar while launching
            racecars_with_lidars = list()
            # List of SimApp versions
            simapp_versions = list()
            for agent_index, model_s3_bucket in enumerate(race_model_s3_buckets):
                racecar_name = 'racecar_' + str(agent_index)
                json_key = race_model_metadatas[agent_index]
                # download model metadata
                try:
                    model_metadata = ModelMetadata(
                        bucket=model_s3_bucket,
                        s3_key=json_key,
                        region_name=s3_region,
                        s3_endpoint_url=s3_endpoint_url,
                        local_path=MODEL_METADATA_LOCAL_PATH_FORMAT.format(racecar_name))
                    dirs_to_delete.append(model_metadata.local_dir)
                except Exception as e:
                    log_and_exit(
                        "Failed to download model_metadata file: s3_bucket: {}, s3_key: {}, {}"
                        .format(model_s3_bucket, json_key, e),
                        SIMAPP_SIMULATION_WORKER_EXCEPTION,
                        SIMAPP_EVENT_ERROR_CODE_500)
                sensors, _, simapp_version = model_metadata.get_model_metadata_info()
                simapp_versions.append(str(simapp_version))
                if Input.STEREO.value in sensors:
                    racecars_with_stereo_cameras.append(racecar_name)
                if Input.LIDAR.value in sensors or Input.SECTOR_LIDAR.value in sensors:
                    racecars_with_lidars.append(racecar_name)

            # Launch the race as a separate process; sensor/version lists are
            # passed as comma-joined CLI arguments.
            cmd = [
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             "tournament_race_node.py"),
                str(race_idx),
                race_yaml_path,
                ','.join(racecars_with_stereo_cameras),
                ','.join(racecars_with_lidars),
                ','.join(simapp_versions),
                ','.join(body_shell_types)
            ]
            try:
                return_code, _, stderr = run_cmd(cmd_args=cmd,
                                                 shell=False,
                                                 stdout=None,
                                                 stderr=None)
            except KeyboardInterrupt:
                logger.info("KeyboardInterrupt raised, SimApp must be faulted! exiting...")
                return

            # Retrieve winner and append tournament report.
            # NOTE(review): race_report.pkl is presumably written by the race
            # subprocess launched above — verify against tournament_race_node.py.
            with open('race_report.pkl', 'rb') as f:
                race_report = pickle.load(f)
            race_report['race_idx'] = race_idx
            winner = car1 if race_report['winner'] == car1_display_name else car2
            logger.info("race {}'s winner: {}".format(race_idx, race_report['winner']))

            # Winner advances to the next round
            tournament_candidate_queue.append(winner)
            tournament_report["race_results"].append(race_report)

            # Clean up directories created
            for dir_to_delete in dirs_to_delete:
                shutil.rmtree(dir_to_delete, ignore_errors=True)
            race_idx += 1

            s3_extra_args = get_s3_kms_extra_args()
            # Persist latest queue and report to use after job restarts.
            with open(local_queue_pickle_path, 'wb') as f:
                pickle.dump(tournament_candidate_queue, f, protocol=2)
            s3_client.upload_file(bucket=s3_bucket,
                                  s3_key=queue_pickle_s3_key,
                                  local_path=local_queue_pickle_path,
                                  s3_kms_extra_args=s3_extra_args)
            with open(local_report_pickle_path, 'wb') as f:
                pickle.dump(tournament_report, f, protocol=2)
            s3_client.upload_file(bucket=s3_bucket,
                                  s3_key=report_pickle_s3_key,
                                  local_path=local_report_pickle_path,
                                  s3_kms_extra_args=s3_extra_args)

            # If there is more than 1 candidates then restart the simulation job otherwise
            # tournament is finished, persists final report and ends the job.
            if len(tournament_candidate_queue) > 1:
                restart_simulation_job(
                    os.environ.get('AWS_ROBOMAKER_SIMULATION_JOB_ARN'),
                    s3_region)
                break
            else:
                # Persist final tournament report in json format
                # and terminate the job by canceling it
                s3_client.put_object(bucket=s3_bucket,
                                     s3_key=final_report_s3_key,
                                     body=json.dumps(tournament_report),
                                     s3_kms_extra_args=s3_extra_args)
                cancel_simulation_job(
                    os.environ.get('AWS_ROBOMAKER_SIMULATION_JOB_ARN'),
                    s3_region)
    except ValueError as ex:
        log_and_exit("User modified model_metadata.json: {}".format(ex),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except Exception as e:
        log_and_exit("Tournament node failed: {}".format(e),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
def log_except_and_exit(self):
    """Report this rollout-worker failure to CloudWatch and terminate the sim app."""
    failure_message = "Rollout worker failed: {}".format(self.msg)
    log_and_exit(failure_message,
                 SIMAPP_SIMULATION_WORKER_EXCEPTION,
                 SIMAPP_EVENT_ERROR_CODE_400)