def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """ block until there is a checkpoint in checkpoint_dir """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()
        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)

    # one last time
    if has_checkpoint(checkpoint_dir):
        return

    # each retry waits 10 seconds, so the total wait is timeout * 10 seconds
    utils.json_format_logger(
        "Checkpoint never found in {}, waited {} seconds. Job failed!".format(checkpoint_dir, timeout * 10),
        **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                        utils.SIMAPP_EVENT_ERROR_CODE_503))
    traceback.print_exc()
    raise ValueError(('Waited {timeout} seconds, but checkpoint never found in '
                      '{checkpoint_dir}').format(timeout=timeout * 10,
                                                 checkpoint_dir=checkpoint_dir))
def upload_model(self, checkpoint_dir):
    try:
        s3_client = self.get_client()
        num_files = 0
        for root, _, files in os.walk("./" + checkpoint_dir):
            for filename in files:
                abs_name = os.path.abspath(os.path.join(root, filename))
                s3_client.upload_file(abs_name,
                                      self.bucket,
                                      "%s/%s/%s" % (self.s3_prefix, checkpoint_dir, filename))
                num_files += 1
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Model failed to upload to {}, {}".format(self.bucket, e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Model failed to upload to {}, {}".format(self.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def write_ip_config(self, ip):
    try:
        s3_client = self.get_client()
        data = {"IP": ip}
        json_blob = json.dumps(data)
        file_handle = io.BytesIO(json_blob.encode())
        file_handle_done = io.BytesIO(b'done')
        s3_client.upload_fileobj(file_handle, self.bucket, self.config_key)
        s3_client.upload_fileobj(file_handle_done, self.bucket, self.done_file_key)
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Write ip config failed to upload to {}, {}".format(self.bucket,
                                                                                     e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Write ip config failed to upload to {}, {}".format(self.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def download_file(self, s3_key, local_path):
    s3_client = self.get_client()
    try:
        s3_client.download_file(self.bucket, s3_key, local_path)
        return True
    except botocore.exceptions.ClientError as e:
        # It is possible that the file isn't there, in which case we should
        # return False and let the client decide the next action
        if e.response['Error']['Code'] == "404":
            return False
        else:
            utils.json_format_logger("Unable to download {} from {}: {}".format(s3_key, self.bucket,
                                                                                e.response['Error']['Code']),
                                     **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                   utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Unable to download {} from {}: {}".format(s3_key, self.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
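A minimal usage sketch of the boolean contract above; the caller, key, and local path are hypothetical and not from the source, and s3_client stands for an instance of this same S3 client class. Only a 404 reaches the caller as False, because any other failure exits the job inside download_file via simapp_exit_gracefully.

# Hypothetical caller: a missing object (404 -> False) is treated as "not there yet",
# mirroring how download_customer_reward_function below consumes the return value.
if not s3_client.download_file(s3_key="model/model_metadata.json",
                               local_path="./custom_files/model_metadata.json"):
    logger.info("model_metadata.json not found in S3, using defaults")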
def get_latest_checkpoint(self):
    try:
        filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
        if not os.path.exists(self.params.checkpoint_dir):
            os.makedirs(self.params.checkpoint_dir)

        while True:
            s3_client = self._get_client()
            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

            # wait until lock is removed
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
            if "Contents" not in response:
                try:
                    # fetch checkpoint state file from S3
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(state_file.filename),
                                            Filename=filename)
                except Exception as e:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
            else:
                time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                continue

            return self._get_current_checkpoint_number(checkpoint_metadata_filepath=filename)
    except Exception as e:
        utils.json_format_logger("Exception [{}] occurred while getting latest checkpoint from S3.".format(e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_503))
def callback_image(self, data):
    try:
        self.image_queue.put_nowait(data)
    except queue.Full:
        pass
    except Exception as ex:
        utils.json_format_logger("Error retrieving frame from gazebo: {}".format(ex),
                                 **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
def log_info(message):
    ''' Helper method that logs the exception message
        message - Message to send to the log
    '''
    json_format_logger(message,
                       **build_system_error_dict(SIMAPP_MEMORY_BACKEND_EXCEPTION,
                                                 SIMAPP_EVENT_ERROR_CODE_500))
def download_model(self, checkpoint_dir):
    s3_client = self.get_client()
    filename = "None"
    try:
        filename = os.path.abspath(os.path.join(checkpoint_dir, "checkpoint"))
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        while True:
            response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                 Prefix=self._get_s3_key(self.lock_file))
            if "Contents" not in response:
                # If no lock is found, try getting the checkpoint
                try:
                    s3_client.download_file(Bucket=self.bucket,
                                            Key=self._get_s3_key("checkpoint"),
                                            Filename=filename)
                except Exception as e:
                    time.sleep(2)
                    continue
            else:
                time.sleep(2)
                continue

            ckpt = CheckpointState()
            if os.path.exists(filename):
                contents = open(filename, 'r').read()
                text_format.Merge(contents, ckpt)
                rel_path = ckpt.model_checkpoint_path
                checkpoint = int(rel_path.split('_Step')[0])

                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(rel_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        filename = os.path.abspath(
                            os.path.join(checkpoint_dir,
                                         obj["Key"].replace(self.model_checkpoints_prefix, "")))
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return True
    except Exception as e:
        utils.json_format_logger("{} while downloading the model {} from S3".format(e, filename),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        return False
def download_model(self, checkpoint_dir):
    s3_client = self.get_client()
    logger.info("Downloading pretrained model from %s/%s %s" % (self.bucket,
                                                                self.model_checkpoints_prefix,
                                                                checkpoint_dir))
    filename = "None"
    try:
        filename = os.path.abspath(os.path.join(checkpoint_dir, "checkpoint"))
        if not os.path.exists(checkpoint_dir):
            logger.info("Model folder %s does not exist, creating" % checkpoint_dir)
            os.makedirs(checkpoint_dir)

        while True:
            response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                 Prefix=self._get_s3_key(self.lock_file))
            if "Contents" not in response:
                # If no lock is found, try getting the checkpoint
                try:
                    key = self._get_s3_key("checkpoint")
                    logger.info("Downloading %s" % key)
                    s3_client.download_file(Bucket=self.bucket, Key=key, Filename=filename)
                except Exception as e:
                    logger.info("Something went wrong, will retry in 2 seconds %s" % e)
                    time.sleep(2)
                    continue
            else:
                logger.info("Found a lock file %s, waiting" % self._get_s3_key(self.lock_file))
                time.sleep(2)
                continue

            ckpt = CheckpointState()
            if os.path.exists(filename):
                contents = open(filename, 'r').read()
                text_format.Merge(contents, ckpt)
                rel_path = ckpt.model_checkpoint_path
                checkpoint = int(rel_path.split('_Step')[0])

                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(rel_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        filename = os.path.abspath(
                            os.path.join(checkpoint_dir,
                                         obj["Key"].replace(self.model_checkpoints_prefix, "")))
                        logger.info("Downloading model file %s" % filename)
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return True
    except Exception as e:
        utils.json_format_logger("{} while downloading the model {} from S3".format(e, filename),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        return False
def upload_file(self, s3_key, local_path):
    s3_client = self.get_client()
    try:
        s3_client.upload_file(Filename=local_path, Bucket=self.bucket, Key=s3_key)
        return True
    except Exception as e:
        utils.json_format_logger("{} on upload file-{} to s3 bucket-{} key-{}".format(e, local_path,
                                                                                      self.bucket, s3_key),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        return False
def _get_current_checkpoint_number(self, checkpoint_metadata_filepath=None):
    try:
        if not os.path.exists(checkpoint_metadata_filepath):
            return None
        with open(checkpoint_metadata_filepath, 'r') as fp:
            data = fp.read()
            return int(data.split('_')[0])
    except Exception as e:
        utils.json_format_logger("Exception[{}] occurred while reading checkpoint metadata".format(e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        raise e
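For reference, a sketch of the parsing this relies on, assuming the downloaded metadata holds a checkpoint name such as "10_Step-4000.ckpt" (the same "<number>_Step..." naming that download_model splits on); the example value is illustrative, not from the source.

# Hypothetical contents of the downloaded "latest_ckpt" metadata file.
data = "10_Step-4000.ckpt"
checkpoint_number = int(data.split('_')[0])  # -> 10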
def get_ip(self):
    s3_client = self.get_client()
    self._wait_for_ip_upload()
    try:
        s3_client.download_file(self.bucket, self.config_key, 'ip.json')
        with open("ip.json") as f:
            ip = json.load(f)["IP"]
        return ip
    except Exception as e:
        utils.json_format_logger("Exception [{}] occurred, cannot fetch IP of redis server running "
                                 "in SageMaker. Job failed!".format(e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_503))
        sys.exit(1)
def download_customer_reward_function(s3_client, reward_file_s3_key):
    reward_function_local_path = os.path.join(CUSTOM_FILES_PATH, "customer_reward_function.py")
    success_reward_function_download = s3_client.download_file(s3_key=reward_file_s3_key,
                                                               local_path=reward_function_local_path)
    if not success_reward_function_download:
        utils.json_format_logger("Could not download the customer reward function file. Job failed!",
                                 **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_503))
        traceback.print_exc()
        sys.exit(1)
def _wait_for_ip_upload(self, timeout=600):
    s3_client = self.get_client()
    time_elapsed = 0
    while time_elapsed < timeout:
        try:
            response = s3_client.list_objects(Bucket=self.bucket, Prefix=self.done_file_key)
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger("Unable to access {}: {}".format(self.bucket, e.response['Error']['Code']),
                                     **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                   utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger("Unable to access {}: {}".format(self.bucket, e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()

        if "Contents" in response:
            return
        time.sleep(1)
        time_elapsed += 1
        if time_elapsed % 5 == 0:
            logger.info("Waiting for SageMaker Redis server IP... Time elapsed: %s seconds" % time_elapsed)

    utils.json_format_logger("Timed out while attempting to retrieve the Redis IP",
                             **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
    utils.simapp_exit_gracefully()
def write_simtrace_data(self, jsondata):
    if self.data_state != SIMTRACE_DATA_UPLOAD_UNKNOWN_STATE:
        try:
            csvdata = []
            for key in SIMTRACE_CSV_DATA_HEADER:
                csvdata.append(jsondata[key])
            self.csvwriter.writerow(csvdata)
            self.total_upload_size += sys.getsizeof(csvdata)
            logger.debug("csvdata={} size data={} csv={}".format(csvdata,
                                                                 sys.getsizeof(csvdata),
                                                                 sys.getsizeof(self.simtrace_csv_data.getvalue())))
        except Exception as ex:
            utils.json_format_logger("Invalid SIM_TRACE data format, Exception={}. Job failed!".format(ex),
                                     **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))
            traceback.print_exc()
            utils.simapp_exit_gracefully()
def callback_image(self, data):
    try:
        self.image_queue.put_nowait(data)
    except queue.Full:
        # Only warn if it's the middle of an episode, not during training
        if self.allow_servo_step_signals:
            logger.info("Warning: dropping image due to queue full")
    except Exception as ex:
        utils.json_format_logger("Error retrieving frame from gazebo: {}".format(ex),
                                 **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
def racecar_reset(self):
    try:
        for joint in EFFORT_JOINTS:
            self.clear_forces_client(joint)
        prev_index, next_index = self.find_prev_next_waypoints(self.start_ndist)
        self.reset_car_client(self.start_ndist, next_index)
        # First clear the queue so that we set the state to the start image
        _ = self.image_queue.get(block=True, timeout=None)
        self.set_next_state()
    except Exception as ex:
        utils.json_format_logger("Unable to reset the car: {}".format(ex),
                                 **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
def _get_current_checkpoint(self):
    try:
        checkpoint_metadata_filepath = os.path.abspath(
            os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
        checkpoint = CheckpointState()
        if not os.path.exists(checkpoint_metadata_filepath):
            return None

        contents = open(checkpoint_metadata_filepath, 'r').read()
        text_format.Merge(contents, checkpoint)
        return checkpoint
    except Exception as e:
        utils.json_format_logger("Exception[{}] occurred while reading checkpoint metadata".format(e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        raise e
def get_latest_checkpoint(self):
    try:
        filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
        if not os.path.exists(self.params.checkpoint_dir):
            os.makedirs(self.params.checkpoint_dir)

        while True:
            s3_client = self._get_client()

            # Check if there's a lock file
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(self.params.lock_file))
            if "Contents" not in response:
                try:
                    # If no lock is found, try getting the checkpoint
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                                            Filename=filename)
                except Exception as e:
                    logger.info("Error occurred while getting latest checkpoint %s. Waiting." % e)
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
            else:
                time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                continue

            checkpoint = self._get_current_checkpoint(checkpoint_metadata_filepath=filename)
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)
                return checkpoint_number
    except Exception as e:
        utils.json_format_logger("Exception [{}] occurred while getting latest checkpoint from S3.".format(e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_503))
def _wait_for_ip_upload(self, timeout=600):
    s3_client = self.get_client()
    time_elapsed = 0
    while True:
        response = s3_client.list_objects(Bucket=self.bucket, Prefix=self.done_file_key)
        if "Contents" not in response:
            time.sleep(1)
            time_elapsed += 1
            if time_elapsed % 5 == 0:
                logger.info("Waiting for SageMaker Redis server IP... Time elapsed: %s seconds" % time_elapsed)
            if time_elapsed >= timeout:
                # raise RuntimeError("Cannot retrieve IP of redis server running in SageMaker")
                utils.json_format_logger("Cannot retrieve IP of redis server running in SageMaker. Job failed!",
                                         **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                         utils.SIMAPP_EVENT_ERROR_CODE_503))
                sys.exit(1)
        else:
            return
def upload_finished_file(self):
    try:
        s3_client = self._get_client()
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=self.params.bucket,
                                 Key=self._get_s3_key(SyncFiles.FINISHED.value))
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Unable to upload finished file to {}, {}".format(self.params.bucket,
                                                                                   e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Unable to upload finished file to {}, {}".format(self.params.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def upload_file(self, s3_key, local_path):
    s3_client = self.get_client()
    try:
        s3_client.upload_file(Filename=local_path, Bucket=self.bucket, Key=s3_key)
        return True
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Unable to upload {} to {}: {}".format(s3_key, self.bucket,
                                                                        e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
    except Exception as e:
        utils.json_format_logger("Unable to upload {} to {}: {}".format(s3_key, self.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
    return False
def upload_hyperparameters(self, hyperparams_json):
    try:
        s3_client = self.get_client()
        file_handle = io.BytesIO(hyperparams_json.encode())
        s3_client.upload_fileobj(file_handle, self.bucket, self.hyperparameters_key)
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Hyperparameters failed to upload to {}, {}".format(self.bucket,
                                                                                     e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Hyperparameters failed to upload to {}, {}".format(self.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """ block until there is a checkpoint in checkpoint_dir """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()
        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)

    # one last time
    if has_checkpoint(checkpoint_dir):
        return

    # each retry waits 10 seconds, so the total wait is timeout * 10 seconds
    utils.json_format_logger("Checkpoint never found in {}, waited {} seconds.".format(checkpoint_dir, timeout * 10),
                             **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
    traceback.print_exc()
    utils.simapp_exit_gracefully()
def get_ip(self):
    s3_client = self.get_client()
    self._wait_for_ip_upload()
    try:
        s3_client.download_file(self.bucket, self.config_key, 'ip.json')
        with open("ip.json") as f:
            ip = json.load(f)["IP"]
        return ip
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Unable to retrieve redis ip from {}: {}".format(self.bucket,
                                                                                  e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Unable to retrieve redis ip from {}: {}".format(self.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def training_worker(graph_manager, task_parameters, user_batch_size, user_episode_per_rollout):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                if not any(rollout_steps.values()):
                    utils.json_format_logger("No rollout data retrieved from the rollout worker",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                # Set the batch size to the closest power of 2 such that we have at least two batches, this
                # prevents coach from crashing as a batch size less than 2 causes the batch list to become
                # a scalar which causes an exception
                episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size \
                    else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for NaNs in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                if rollout_has_nan:
                    utils.json_format_logger("NaN detected in loss function, aborting training.",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger("Received SIGTERM. Checkpointing before exiting.",
                                         **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            utils.log_and_exit("User modified model: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_400)
        else:
            utils.log_and_exit("An error occurred while training: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        utils.log_and_exit("An error occurred while training: {}".format(ex),
                           utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                           utils.SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()
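A small worked example of the batch-size fallback used above (the numbers are illustrative, not from the source): when the rollout contains fewer steps than the requested batch size, the batch size drops to the largest power of two not exceeding the collected step count.

import math

user_batch_size = 64      # hypothetical requested batch size
min_rollout_steps = 40    # hypothetical steps collected by the rollout worker

episode_batch_size = user_batch_size if min_rollout_steps > user_batch_size \
    else 2**math.floor(math.log(min_rollout_steps, 2))
print(episode_batch_size)  # -> 32, the largest power of two <= 40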
def save_to_store(self):
    try:
        s3_client = self._get_client()

        # Delete any existing lock file
        s3_client.delete_object(Bucket=self.params.bucket, Key=self._get_s3_key(self.params.lock_file))

        # We take a lock by writing a lock file to the same location in S3
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=self.params.bucket,
                                 Key=self._get_s3_key(self.params.lock_file))

        # Start writing the model checkpoints to S3
        checkpoint = self._get_current_checkpoint()
        if checkpoint:
            checkpoint_number = self._get_checkpoint_number(checkpoint)

            checkpoint_file = None
            for root, _, files in os.walk(self.params.checkpoint_dir):
                num_files_uploaded = 0
                for filename in files:
                    # Skip the checkpoint file that has the latest checkpoint number
                    if filename == CHECKPOINT_METADATA_FILENAME:
                        checkpoint_file = (root, filename)
                        continue
                    if not filename.startswith(str(checkpoint_number)):
                        continue

                    # Upload all the other files from the checkpoint directory
                    abs_name = os.path.abspath(os.path.join(root, filename))
                    rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                    s3_client.upload_file(Filename=abs_name,
                                          Bucket=self.params.bucket,
                                          Key=self._get_s3_key(rel_name))
                    num_files_uploaded += 1
                logger.info("Uploaded {} files for checkpoint {}".format(num_files_uploaded, checkpoint_number))

            # After all the checkpoint files have been uploaded, we upload the version file.
            abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(rel_name))

        # Release the lock by deleting the lock file from S3
        s3_client.delete_object(Bucket=self.params.bucket, Key=self._get_s3_key(self.params.lock_file))

        # Upload the frozen graph which is used for deployment
        if self.graph_manager:
            utils.write_frozen_graph(self.graph_manager)
            # upload the model_<ID>.pb to S3. NOTE: there's no cleanup as we don't know the best checkpoint
            iteration_id = self.graph_manager.level_managers[0].agents['agent'].training_iteration
            frozen_graph_fpath = utils.SM_MODEL_OUTPUT_DIR + "/model.pb"
            frozen_graph_s3_name = "model_%s.pb" % iteration_id
            s3_client.upload_file(Filename=frozen_graph_fpath,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(frozen_graph_s3_name))
            logger.info("saved intermediate frozen graph: {}".format(self._get_s3_key(frozen_graph_s3_name)))

        # Clean up old checkpoints
        checkpoint = self._get_current_checkpoint()
        if checkpoint:
            checkpoint_number = self._get_checkpoint_number(checkpoint)
            checkpoint_number_to_delete = checkpoint_number - 4

            # List all the old checkpoint files to be deleted
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(str(checkpoint_number_to_delete) + "_"))
            if "Contents" in response:
                num_files = 0
                for obj in response["Contents"]:
                    s3_client.delete_object(Bucket=self.params.bucket, Key=obj["Key"])
                    num_files += 1
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Unable to upload checkpoint to {}, {}".format(self.params.bucket,
                                                                                e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Unable to upload checkpoint to {}, {}".format(self.params.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def load_from_store(self, expected_checkpoint_number=-1):
    try:
        filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
        if not os.path.exists(self.params.checkpoint_dir):
            os.makedirs(self.params.checkpoint_dir)

        # CT: remove all prior checkpoint files to save local storage space
        if len(os.listdir(self.params.checkpoint_dir)) > 0:
            files_to_remove = os.listdir(self.params.checkpoint_dir)
            utils.json_format_logger('Removing %d old checkpoint files' % len(files_to_remove))
            for f in files_to_remove:
                try:
                    os.unlink(os.path.abspath(os.path.join(self.params.checkpoint_dir, f)))
                except Exception as e:
                    # swallow errors
                    pass

        while True:
            s3_client = self._get_client()

            # Check if there's a finished file
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(SyncFiles.FINISHED.value))
            if "Contents" in response:
                finished_file_path = os.path.abspath(os.path.join(self.params.checkpoint_dir,
                                                                  SyncFiles.FINISHED.value))
                s3_client.download_file(Bucket=self.params.bucket,
                                        Key=self._get_s3_key(SyncFiles.FINISHED.value),
                                        Filename=finished_file_path)
                return False

            # Check if there's a lock file
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(self.params.lock_file))
            if "Contents" not in response:
                try:
                    # If no lock is found, try getting the checkpoint
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                                            Filename=filename)
                except Exception as e:
                    utils.json_format_logger("Sleeping {} seconds while checkpoint metadata is unavailable".format(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND))
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
            else:
                utils.json_format_logger("Sleeping {} seconds while lock file is present".format(
                    SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND))
                time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                continue

            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)
                # If we get a checkpoint that is older than the expected checkpoint, we wait for
                # the new checkpoint to arrive.
                if checkpoint_number < expected_checkpoint_number:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                # Found a checkpoint to be downloaded
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(checkpoint.model_checkpoint_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        # Get the local filename of the checkpoint file
                        full_key_prefix = os.path.normpath(self.key_prefix) + "/"
                        filename = os.path.abspath(os.path.join(self.params.checkpoint_dir,
                                                                obj["Key"].replace(full_key_prefix, "")))
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return True
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger("Unable to download checkpoint from {}, {}".format(self.params.bucket,
                                                                                    e.response['Error']['Code']),
                                 **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                               utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger("Unable to download checkpoint from {}, {}".format(self.params.bucket, e),
                                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def training_worker(graph_manager, checkpoint_dir, use_pretrained_model, framework, memory_backend_params):
    """ initialize the graph, then run the training loop and periodically save checkpoints """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = SM_MODEL_OUTPUT_DIR

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__['checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0
    graph_manager.memory_backend = deepracer_memory.DeepRacerTrainerBackEnd(memory_backend_params)

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            if graph_manager.should_train():
                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for NaNs in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                #! TODO handle NaNs on a per agent level for distributed training
                if rollout_has_nan:
                    utils.json_format_logger("NaN detected in loss function, aborting training. Job failed!",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_503))
                    sys.exit(1)

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            if door_man.terminate_now:
                utils.json_format_logger("Received SIGTERM. Checkpointing before exiting.",
                                         **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        utils.json_format_logger("An error occurred while training: {}. Job failed!".format(e),
                                 **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                 utils.SIMAPP_EVENT_ERROR_CODE_503))
        traceback.print_exc()
        sys.exit(1)
    finally:
        graph_manager.data_store.upload_finished_file()
def __init__(self):
    # Create the observation space
    self.observation_space = spaces.Box(low=0, high=255,
                                        shape=(TRAINING_IMAGE_SIZE[1], TRAINING_IMAGE_SIZE[0], 3),
                                        dtype=np.uint8)
    # Create the action space
    self.action_space = spaces.Box(low=np.array([-1, 0]), high=np.array([+1, +1]), dtype=np.float32)

    if node_type == SIMULATION_WORKER:
        # ROS initialization
        rospy.init_node('rl_coach', anonymous=True)

        # wait for required services
        rospy.wait_for_service('/deepracer_simulation_environment/get_waypoints')
        rospy.wait_for_service('/deepracer_simulation_environment/reset_car')
        rospy.wait_for_service('/gazebo/get_model_state')
        rospy.wait_for_service('/gazebo/get_link_state')
        rospy.wait_for_service('/gazebo/clear_joint_forces')

        self.get_model_state = rospy.ServiceProxy('/gazebo/get_model_state', GetModelState)
        self.get_link_state = rospy.ServiceProxy('/gazebo/get_link_state', GetLinkState)
        self.clear_forces_client = rospy.ServiceProxy('/gazebo/clear_joint_forces', JointRequest)
        self.reset_car_client = rospy.ServiceProxy('/deepracer_simulation_environment/reset_car', ResetCarSrv)
        get_waypoints_client = rospy.ServiceProxy('/deepracer_simulation_environment/get_waypoints', GetWaypointSrv)

        # Create the publishers for sending speed and steering info to the car
        self.velocity_pub_dict = OrderedDict()
        self.steering_pub_dict = OrderedDict()
        for topic in VELOCITY_TOPICS:
            self.velocity_pub_dict[topic] = rospy.Publisher(topic, Float64, queue_size=1)
        for topic in STEERING_TOPICS:
            self.steering_pub_dict[topic] = rospy.Publisher(topic, Float64, queue_size=1)

        # Read in parameters
        self.world_name = rospy.get_param('WORLD_NAME')
        self.job_type = rospy.get_param('JOB_TYPE')
        self.aws_region = rospy.get_param('AWS_REGION')
        self.metrics_s3_bucket = rospy.get_param('METRICS_S3_BUCKET')
        self.metrics_s3_object_key = rospy.get_param('METRICS_S3_OBJECT_KEY')
        self.metrics = []
        self.simulation_job_arn = 'arn:aws:robomaker:' + self.aws_region + ':' + \
                                  rospy.get_param('ROBOMAKER_SIMULATION_JOB_ACCOUNT_ID') + \
                                  ':simulation-job/' + rospy.get_param('AWS_ROBOMAKER_SIMULATION_JOB_ID')

        if self.job_type == TRAINING_JOB:
            from custom_files.customer_reward_function import reward_function
            self.reward_function = reward_function
            self.metric_name = rospy.get_param('METRIC_NAME')
            self.metric_namespace = rospy.get_param('METRIC_NAMESPACE')
            self.training_job_arn = rospy.get_param('TRAINING_JOB_ARN')
            self.target_number_of_episodes = rospy.get_param('NUMBER_OF_EPISODES')
            self.target_reward_score = rospy.get_param('TARGET_REWARD_SCORE')
        else:
            from markov.defaults import reward_function
            self.reward_function = reward_function
            self.number_of_trials = 0
            self.target_number_of_trials = rospy.get_param('NUMBER_OF_TRIALS')

        # Request the waypoints
        waypoints = None
        try:
            resp = get_waypoints_client()
            waypoints = np.array(resp.waypoints).reshape(resp.row, resp.col)
        except Exception as ex:
            utils.json_format_logger("Unable to retrieve waypoints: {}".format(ex),
                                     **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))

        is_loop = np.all(waypoints[0, :] == waypoints[-1, :])
        if is_loop:
            self.center_line = LinearRing(waypoints[:, 0:2])
            self.inner_border = LinearRing(waypoints[:, 2:4])
            self.outer_border = LinearRing(waypoints[:, 4:6])
            self.road_poly = Polygon(self.outer_border, [self.inner_border])
        else:
            self.center_line = LineString(waypoints[:, 0:2])
            self.inner_border = LineString(waypoints[:, 2:4])
            self.outer_border = LineString(waypoints[:, 4:6])
            self.road_poly = Polygon(np.vstack((self.outer_border, np.flipud(self.inner_border))))
        self.center_dists = [self.center_line.project(Point(p), normalized=True)
                             for p in self.center_line.coords[:-1]] + [1.0]
        self.track_length = self.center_line.length

        # Queue used to maintain image consumption synchronicity
        self.image_queue = queue.Queue(IMG_QUEUE_BUF_SIZE)
        rospy.Subscriber('/camera/zed/rgb/image_rect_color', sensor_image, self.callback_image)

        # Initialize state data
        self.episodes = 0
        self.start_ndist = 0.0
        self.reverse_dir = False
        self.change_start = rospy.get_param('CHANGE_START_POSITION', (self.job_type == TRAINING_JOB))
        self.alternate_dir = rospy.get_param('ALTERNATE_DRIVING_DIRECTION', False)
        self.is_simulation_done = False
        self.steering_angle = 0
        self.speed = 0
        self.action_taken = 0
        self.prev_progress = 0
        self.prev_point = Point(0, 0)
        self.prev_point_2 = Point(0, 0)
        self.next_state = None
        self.reward = None
        self.reward_in_episode = 0
        self.done = False
        self.steps = 0
        self.simulation_start_time = 0
        self.allow_servo_step_signals = False