def upload_model(self, checkpoint_dir):
    """Upload every file found under ./<checkpoint_dir> to S3.

    Objects are written to <s3_prefix>/<checkpoint_dir>/<filename> in
    self.bucket. On any failure the error is logged in the structured
    SIMAPP format and the sim app exits gracefully.
    """
    try:
        s3_client = self.get_client()
        num_files = 0
        for root, _, files in os.walk("./" + checkpoint_dir):
            for filename in files:
                local_file = os.path.abspath(os.path.join(root, filename))
                s3_key = "%s/%s/%s" % (self.s3_prefix, checkpoint_dir, filename)
                s3_client.upload_file(local_file, self.bucket, s3_key)
                num_files += 1
    except botocore.exceptions.ClientError as e:
        # AWS rejected the request: report as a user-class (400) error.
        utils.json_format_logger(
            "Model failed to upload to {}, {}".format(
                self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        # Anything else: report as a system-class (500) error.
        utils.json_format_logger(
            "Model failed to upload to {}, {}".format(self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def write_ip_config(self, ip):
    """Publish the Redis server IP to S3.

    Writes a JSON blob {"IP": ip} to self.config_key, then a 'done'
    marker object to self.done_file_key so readers know the IP is final.
    Exits the sim app on failure.
    """
    try:
        s3_client = self.get_client()
        ip_json = json.dumps({"IP": ip})
        # Upload the payload first, then the marker, so the marker's
        # presence guarantees the config object already exists.
        s3_client.upload_fileobj(io.BytesIO(ip_json.encode()),
                                 self.bucket, self.config_key)
        s3_client.upload_fileobj(io.BytesIO(b'done'),
                                 self.bucket, self.done_file_key)
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Write ip config failed to upload to {}, {}".format(
                self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Write ip config failed to upload to {}, {}".format(
                self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def store_ip(self, ip_address):
    """Store ip_address in S3 as JSON plus a 'done' marker object.

    Writes {IP_KEY: ip_address} to self.ip_data_key and then b'done' to
    self.ip_done_key in self.params.bucket. Exits the sim app on failure.
    """
    try:
        s3_client = self._get_client()
        payload = json.dumps({IP_KEY: ip_address}).encode()
        # Payload first, marker second: the marker signals completion.
        s3_client.upload_fileobj(io.BytesIO(payload),
                                 self.params.bucket, self.ip_data_key)
        s3_client.upload_fileobj(io.BytesIO(b'done'),
                                 self.params.bucket, self.ip_done_key)
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to store ip to {}, {}".format(
                self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to store ip to {}, {}".format(self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def download_file(self, s3_key, local_path):
    """Download s3_key from self.bucket to local_path.

    Returns True on success and False when the object does not exist
    (HTTP 404); any other failure logs an error and exits the sim app.
    """
    s3_client = self.get_client()
    try:
        s3_client.download_file(self.bucket, s3_key, local_path)
        return True
    except botocore.exceptions.ClientError as e:
        # A missing object is an expected condition: return False and
        # let the caller decide the next action.
        if e.response['Error']['Code'] == "404":
            return False
        utils.json_format_logger(
            "Unable to download {} from {}: {}".format(
                s3_key, self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to download {} from {}: {}".format(
                s3_key, self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def download_preset_if_present(self, local_path):
    """Download the preset file from S3 if one exists.

    Returns False when no preset object is found under
    self.preset_data_key, True after a successful download.
    Exits the sim app on S3 errors.

    FIX: boto3's ``download_file`` returns ``None``, so the original
    ``success = s3_client.download_file(...); return success`` returned a
    falsy value even after a successful download, making callers believe
    the preset was never fetched. We now return True explicitly once the
    download completes without raising.
    """
    try:
        s3_client = self._get_client()
        response = s3_client.list_objects(Bucket=self.params.bucket,
                                          Prefix=self.preset_data_key)
        # If we don't find a preset, return False
        if "Contents" not in response:
            return False
        s3_client.download_file(Bucket=self.params.bucket,
                                Key=self.preset_data_key,
                                Filename=local_path)
        return True
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to download presets from {}, {}".format(
                self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to download presets from {}, {}".format(
                self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def exit_if_trainer_done(checkpoint_dir):
    '''Shut down the sim app if the trainer has signalled it is done.

    checkpoint_dir - directory where the done file would be downloaded to
    '''
    if not should_stop(checkpoint_dir):
        return
    logger.info("Received termination signal from trainer. Goodbye.")
    utils.simapp_exit_gracefully()
def exit_if_trainer_done(checkpoint_dir, s3_writer):
    '''Shut down the sim app if the trainer has signalled it is done.

    Before exiting, stops the mp4 capture subscription and flushes any
    pending s3_writer uploads so video artifacts are not lost.

    checkpoint_dir - directory where the done file would be downloaded to
    s3_writer - object whose upload_to_s3() pushes pending artifacts to S3
    '''
    if not should_stop(checkpoint_dir):
        return
    unsubscribe_from_save_mp4 = ServiceProxyWrapper(
        '/racecar/save_mp4/unsubscribe_from_save_mp4', Empty)
    unsubscribe_from_save_mp4(EmptyRequest())
    s3_writer.upload_to_s3()
    logger.info("Received termination signal from trainer. Goodbye.")
    utils.simapp_exit_gracefully()
def write_simtrace_data(self, jsondata):
    """Append one SIM_TRACE record (dict) as a CSV row to the in-memory buffer.

    Columns follow SIMTRACE_CSV_DATA_HEADER. No-op while the upload state
    is still SIMTRACE_DATA_UPLOAD_UNKNOWN_STATE. A malformed record is
    fatal: it is logged as a system error and the sim app exits.
    """
    if self.data_state == SIMTRACE_DATA_UPLOAD_UNKNOWN_STATE:
        return
    try:
        csvdata = [jsondata[key] for key in SIMTRACE_CSV_DATA_HEADER]
        self.csvwriter.writerow(csvdata)
        self.total_upload_size += sys.getsizeof(csvdata)
        logger.debug("csvdata={} size data={} csv={}".format(
            csvdata, sys.getsizeof(csvdata),
            sys.getsizeof(self.simtrace_csv_data.getvalue())))
    except Exception as ex:
        utils.json_format_logger(
            "Invalid SIM_TRACE data format , Exception={}. Job failed!".format(ex),
            **utils.build_system_error_dict(
                utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        traceback.print_exc()
        utils.simapp_exit_gracefully()
def download_customer_reward_function(s3_client, reward_file_s3_key):
    """Download the customer's reward function code from S3.

    Saves it as customer_reward_function.py under CUSTOM_FILES_PATH.
    A failed download is logged as a user error and the sim app exits.
    """
    reward_function_local_path = os.path.join(CUSTOM_FILES_PATH,
                                              "customer_reward_function.py")
    downloaded = s3_client.download_file(s3_key=reward_file_s3_key,
                                         local_path=reward_function_local_path)
    if downloaded:
        return
    utils.json_format_logger(
        "Unable to download the reward function code.",
        **utils.build_user_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_400))
    traceback.print_exc()
    utils.simapp_exit_gracefully()
def upload_finished_file(self):
    """Upload an empty FINISHED marker object to S3.

    Signals to other workers that this job is complete. Exits the sim
    app on failure.
    """
    try:
        client = self._get_client()
        client.upload_fileobj(
            Fileobj=io.BytesIO(b''),
            Bucket=self.params.bucket,
            Key=self._get_s3_key(SyncFiles.FINISHED.value))
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to upload finished file to {}, {}".format(
                self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to upload finished file to {}, {}".format(
                self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def _get_current_checkpoint(self):
    """Read the checkpoint metadata file and parse it into a CheckpointState.

    Returns the parsed CheckpointState protobuf, or None when no metadata
    file exists in self.params.checkpoint_dir. Any other error is logged
    as a system error and the sim app exits.

    FIX: the original used ``open(path, 'r').read()`` which leaks the file
    handle (it is only closed whenever the GC runs); a ``with`` block now
    closes it deterministically.
    """
    try:
        checkpoint_metadata_filepath = os.path.abspath(
            os.path.join(self.params.checkpoint_dir,
                         CHECKPOINT_METADATA_FILENAME))
        checkpoint = CheckpointState()
        if not os.path.exists(checkpoint_metadata_filepath):
            return None
        with open(checkpoint_metadata_filepath, 'r') as metadata_file:
            contents = metadata_file.read()
        # Parse the protobuf text format into the CheckpointState message.
        text_format.Merge(contents, checkpoint)
        return checkpoint
    except Exception as e:
        utils.json_format_logger(
            "Error when reading checkpoint metadata: {}".format(e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def upload_hyperparameters(self, hyperparams_json):
    """Upload the hyperparameters JSON string to S3 at self.hyperparameters_key.

    Exits the sim app on failure.
    """
    try:
        client = self.get_client()
        client.upload_fileobj(io.BytesIO(hyperparams_json.encode()),
                              self.bucket,
                              self.hyperparameters_key)
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Hyperparameters failed to upload to {}, {}".format(
                self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Hyperparameters failed to upload to {}, {}".format(
                self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """Block until a checkpoint appears in checkpoint_dir.

    Polls `timeout` times, sleeping 10 seconds between attempts (so the
    total wait is up to timeout * 10 seconds), refreshing from the data
    store before each check when one is provided. If no checkpoint ever
    appears, logs a system error and exits the sim app.

    FIX: the original error message reported "waited {timeout} seconds",
    but each of the `timeout` iterations sleeps 10 seconds, so the actual
    wait was ten times longer than reported; the message now reflects the
    real elapsed wait. The unused loop variable is also renamed to `_`.
    """
    for _ in range(timeout):
        if data_store:
            data_store.load_from_store()
        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)
    # one last time
    if has_checkpoint(checkpoint_dir):
        return
    utils.json_format_logger(
        "Checkpoint never found in {}, waited {} seconds.".format(
            checkpoint_dir, timeout * 10),
        **utils.build_system_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_500))
    traceback.print_exc()
    utils.simapp_exit_gracefully()
def get_ip(self):
    """Fetch the Redis server IP previously published to S3.

    Blocks until the 'done' marker is present, then downloads the config
    object to the local file ip.json and returns its "IP" field.
    Exits the sim app on failure.
    """
    s3_client = self.get_client()
    # Wait for the trainer to publish the IP before reading it.
    self._wait_for_ip_upload()
    try:
        s3_client.download_file(self.bucket, self.config_key, 'ip.json')
        with open("ip.json") as ip_file:
            return json.load(ip_file)["IP"]
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to retrieve redis ip from {}: {}".format(
                self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to retrieve redis ip from {}: {}".format(
                self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def _wait_for_ip_upload(self, timeout=600):
    """Poll S3 once per second until the IP 'done' marker object appears.

    Logs progress every 5 seconds. Exits the sim app if S3 is
    inaccessible or the marker never appears within `timeout` seconds.
    """
    s3_client = self.get_client()
    seconds_waited = 0
    while seconds_waited < timeout:
        try:
            response = s3_client.list_objects(Bucket=self.bucket,
                                              Prefix=self.done_file_key)
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to access {}: {}".format(
                    self.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to access {}: {}".format(self.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
        if "Contents" in response:
            return
        time.sleep(1)
        seconds_waited += 1
        if seconds_waited % 5 == 0:
            logger.info(
                "Waiting for SageMaker Redis server IP... Time elapsed: %s seconds"
                % seconds_waited)
    utils.json_format_logger(
        "Timed out while attempting to retrieve the Redis IP",
        **utils.build_system_error_dict(
            utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_500))
    utils.simapp_exit_gracefully()
def infer_reward_state(self, steering_angle, speed):
    """Compute the reward and episode-done flag for the current sim step.

    Reads the car and wheel poses from Gazebo, projects the car onto the
    track center line to derive progress/track-position metrics, invokes
    the customer reward function, updates per-step state on self
    (reward, done, prev_point, prev_progress, ...), emits a SIM_TRACE
    log line plus a simtrace CSV record, and finishes the episode when
    done. steering_angle and speed are assumed to be in radians and the
    sim's native speed unit respectively — TODO confirm with callers.
    """
    try:
        self.set_next_state()
    except Exception as ex:
        # Non-fatal: log and continue with the previous camera state.
        utils.json_format_logger(
            "Unable to retrieve image from queue: {}".format(ex),
            **utils.build_system_error_dict(
                utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))

    # Read model state from Gazebo
    model_state = utils.gazebo_service_call(self.get_model_state, 'racecar', '')
    model_orientation = Rotation.from_quat([
        model_state.pose.orientation.x, model_state.pose.orientation.y,
        model_state.pose.orientation.z, model_state.pose.orientation.w
    ])
    # Offset the reported pose so model_point is the front of the car.
    model_location = np.array([
        model_state.pose.position.x,
        model_state.pose.position.y,
        model_state.pose.position.z]) + \
        model_orientation.apply(RELATIVE_POSITION_OF_FRONT_OF_CAR)
    model_point = Point(model_location[0], model_location[1])
    # Yaw (z rotation) of the car in radians.
    model_heading = model_orientation.as_euler('zyx')[0]

    # Read the wheel locations from Gazebo
    left_rear_wheel_state = utils.gazebo_service_call(
        self.get_link_state, 'racecar::left_rear_wheel', '')
    left_front_wheel_state = utils.gazebo_service_call(
        self.get_link_state, 'racecar::left_front_wheel', '')
    right_rear_wheel_state = utils.gazebo_service_call(
        self.get_link_state, 'racecar::right_rear_wheel', '')
    right_front_wheel_state = utils.gazebo_service_call(
        self.get_link_state, 'racecar::right_front_wheel', '')
    wheel_points = [
        Point(left_rear_wheel_state.link_state.pose.position.x,
              left_rear_wheel_state.link_state.pose.position.y),
        Point(left_front_wheel_state.link_state.pose.position.x,
              left_front_wheel_state.link_state.pose.position.y),
        Point(right_rear_wheel_state.link_state.pose.position.x,
              right_rear_wheel_state.link_state.pose.position.y),
        Point(right_front_wheel_state.link_state.pose.position.x,
              right_front_wheel_state.link_state.pose.position.y)
    ]

    # Project the current location onto the center line and find nearest points
    current_ndist = self.center_line.project(model_point, normalized=True)
    prev_index, next_index = self.find_prev_next_waypoints(current_ndist)
    distance_from_prev = model_point.distance(
        Point(self.center_line.coords[prev_index]))
    distance_from_next = model_point.distance(
        Point(self.center_line.coords[next_index]))
    # Tuple indexing by bool: picks next_index when it is the closer one.
    closest_waypoint_index = (
        prev_index, next_index)[distance_from_next < distance_from_prev]

    # Compute distance from center and road width
    nearest_point_center = self.center_line.interpolate(current_ndist,
                                                        normalized=True)
    nearest_point_inner = self.inner_border.interpolate(
        self.inner_border.project(nearest_point_center))
    nearest_point_outer = self.outer_border.interpolate(
        self.outer_border.project(nearest_point_center))
    distance_from_center = nearest_point_center.distance(model_point)
    distance_from_inner = nearest_point_inner.distance(model_point)
    distance_from_outer = nearest_point_outer.distance(model_point)
    track_width = nearest_point_inner.distance(nearest_point_outer)
    # "Left" flips with driving direction.
    is_left_of_center = (distance_from_outer < distance_from_inner) if self.reverse_dir \
        else (distance_from_inner < distance_from_outer)

    # Convert current progress to be [0,100] starting at the initial waypoint
    if self.reverse_dir:
        current_progress = self.start_ndist - current_ndist
    else:
        current_progress = current_ndist - self.start_ndist
    if current_progress < 0.0:
        current_progress = current_progress + 1.0
    current_progress = 100 * current_progress
    if current_progress < self.prev_progress:
        # Either: (1) we wrapped around and have finished the track,
        delta1 = current_progress + 100 - self.prev_progress
        # or (2) for some reason the car went backwards (this should be rare)
        delta2 = self.prev_progress - current_progress
        current_progress = (self.prev_progress, 100)[delta1 < delta2]

    # Car is off track if all wheels are outside the borders
    wheel_on_track = [self.road_poly.contains(p) for p in wheel_points]
    all_wheels_on_track = all(wheel_on_track)
    any_wheels_on_track = any(wheel_on_track)

    # Simulation elapsed time, which may be faster or slower than wall time
    simulation_time = rospy.get_rostime()

    # Compute the reward
    reward = 0.0
    if any_wheels_on_track:
        done = False
        # Public contract of the customer reward function: angles are
        # converted to degrees before being exposed.
        params = {
            'all_wheels_on_track': all_wheels_on_track,
            'x': model_point.x,
            'y': model_point.y,
            'heading': model_heading * 180.0 / math.pi,
            'distance_from_center': distance_from_center,
            'progress': current_progress,
            'steps': self.steps,
            'speed': speed,
            'steering_angle': steering_angle * 180.0 / math.pi,
            'track_width': track_width,
            'waypoints': list(self.center_line.coords),
            'closest_waypoints': [prev_index, next_index],
            'is_left_of_center': is_left_of_center,
            'is_reversed': self.reverse_dir,
            'action': self.action_taken
        }
        try:
            reward = float(self.reward_function(params))
        except Exception as e:
            # Customer code failure is a user error and is fatal.
            utils.json_format_logger(
                "Error in the reward function: {}".format(e),
                **utils.build_user_error_dict(
                    utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            traceback.print_exc()
            utils.simapp_exit_gracefully()
    else:
        done = True
        reward = CRASHED

    # Reset if the car position hasn't changed in the last 2 steps
    prev_pnt_dist = min(model_point.distance(self.prev_point),
                        model_point.distance(self.prev_point_2))
    if prev_pnt_dist <= 0.0001 and self.steps % NUM_STEPS_TO_CHECK_STUCK == 0:
        done = True
        reward = CRASHED  # stuck

    # Simulation jobs are done when progress reaches 100
    if current_progress >= 100:
        done = True

    # Keep data from the previous step around
    self.prev_point_2 = self.prev_point
    self.prev_point = model_point
    self.prev_progress = current_progress

    # Set the reward and done flag
    self.reward = reward
    self.reward_in_episode += reward
    self.done = done

    # Trace logs to help us debug and visualize the training runs
    # btown TODO: This should be written to S3, not to CWL.
    logger.info(
        'SIM_TRACE_LOG:%d,%d,%.4f,%.4f,%.4f,%.2f,%.2f,%d,%.4f,%s,%s,%.4f,%d,%.2f,%s,%.4f\n'
        % (self.episodes, self.steps, model_location[0], model_location[1],
           model_heading, self.steering_angle, self.speed, self.action_taken,
           self.reward, self.done, all_wheels_on_track, current_progress,
           closest_waypoint_index, self.track_length, time.time(),
           float(simulation_time.secs) + float(simulation_time.nsecs) / 1e9))

    # Build a JSON-serializable record of the reward metrics for simtrace.
    reward_metrics = OrderedDict()
    reward_metrics['episode'] = self.episodes
    reward_metrics['steps'] = self.steps
    reward_metrics['X'] = model_location[0]
    reward_metrics['Y'] = model_location[1]
    reward_metrics['yaw'] = model_heading
    reward_metrics['steer'] = self.steering_angle
    reward_metrics['throttle'] = self.speed
    reward_metrics['action'] = self.action_taken
    reward_metrics['reward'] = self.reward
    reward_metrics['done'] = self.done
    reward_metrics['all_wheels_on_track'] = all_wheels_on_track
    reward_metrics['progress'] = current_progress
    reward_metrics['closest_waypoint'] = closest_waypoint_index
    reward_metrics['track_len'] = self.track_length
    reward_metrics['tstamp'] = time.time()
    self.simtrace_data.write_simtrace_data(reward_metrics)

    # Terminate this episode when ready
    if done and node_type == SIMULATION_WORKER:
        self.finish_episode(current_progress)
def save_to_store(self):
    """Upload the current local checkpoint, frozen graphs and markers to S3.

    Sequence: take a crude S3 "lock" (delete then re-upload an empty
    lock object), upload every file belonging to the current checkpoint
    plus the checkpoint-state file, mirror FINISHED/TRAINER_READY marker
    files if present locally, release the lock, upload intermediate
    frozen graphs per agent, and finally delete checkpoint files that
    are NUM_MODELS_TO_KEEP generations old. Exits the sim app on any
    S3 failure. NOTE(review): the lock is advisory only — delete+upload
    is not atomic, so a concurrent reader can still race; confirm the
    reader honors the lock file.
    """
    try:
        s3_client = self._get_client()
        checkpoint_dir = self.params.checkpoint_dir

        # remove lock file if it exists
        s3_client.delete_object(Bucket=self.params.bucket,
                                Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

        # acquire lock
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=self.params.bucket,
                                 Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

        state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
        ckpt_state = None
        if state_file.exists():
            ckpt_state = state_file.read()
            checkpoint_file = None
            num_files_uploaded = 0
            # Upload every file of the current checkpoint; remember the
            # state file so it is uploaded last (after the data files).
            for root, _, files in os.walk(checkpoint_dir):
                for filename in files:
                    if filename == CheckpointStateFile.checkpoint_state_filename:
                        checkpoint_file = (root, filename)
                        continue
                    if filename.startswith(ckpt_state.name):
                        abs_name = os.path.abspath(os.path.join(root, filename))
                        rel_name = os.path.relpath(abs_name, checkpoint_dir)
                        s3_client.upload_file(Filename=abs_name,
                                              Bucket=self.params.bucket,
                                              Key=self._get_s3_key(rel_name))
                        num_files_uploaded += 1
            logger.info("Uploaded {} files for checkpoint {}".format(
                num_files_uploaded, ckpt_state.num))

            abs_name = os.path.abspath(
                os.path.join(checkpoint_file[0], checkpoint_file[1]))
            rel_name = os.path.relpath(abs_name, checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(rel_name))

        # upload Finished if present
        if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
            s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                     Bucket=self.params.bucket,
                                     Key=self._get_s3_key(SyncFiles.FINISHED.value))

        # upload Ready if present
        if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
            s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                     Bucket=self.params.bucket,
                                     Key=self._get_s3_key(SyncFiles.TRAINER_READY.value))

        # release lock
        s3_client.delete_object(Bucket=self.params.bucket,
                                Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

        # Upload the frozen graph which is used for deployment
        if self.graph_manager:
            self.write_frozen_graph(self.graph_manager)
            # upload the model_<ID>.pb to S3. NOTE: there's no cleanup as
            # we don't know the best checkpoint
            for agent_params in self.graph_manager.agents_params:
                iteration_id = self.graph_manager.level_managers[0].agents[
                    agent_params.name].training_iteration
                frozen_graph_fpath = os.path.join(SM_MODEL_OUTPUT_DIR,
                                                  agent_params.name, "model.pb")
                frozen_name = "model_{}.pb".format(iteration_id)
                # Single-agent jobs put the graph at the prefix root;
                # multi-agent jobs nest it under the agent name.
                frozen_graph_s3_name = frozen_name if len(self.graph_manager.agents_params) == 1 \
                    else os.path.join(agent_params.name, frozen_name)
                s3_client.upload_file(Filename=frozen_graph_fpath,
                                      Bucket=self.params.bucket,
                                      Key=self._get_s3_key(frozen_graph_s3_name))
                logger.info("saved intermediate frozen graph: {}".format(
                    self._get_s3_key(frozen_graph_s3_name)))

        # Clean up old checkpoints
        if ckpt_state:
            checkpoint_number_to_delete = ckpt_state.num - NUM_MODELS_TO_KEEP

            # List all the old checkpoint files to be deleted
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(""))
            if "Contents" in response:
                for obj in response["Contents"]:
                    _, basename = os.path.split(obj["Key"])
                    # Checkpoint files are named "<num>_..." — delete the
                    # generation that just aged out of the retention window.
                    if basename.startswith("{}_".format(checkpoint_number_to_delete)):
                        s3_client.delete_object(Bucket=self.params.bucket,
                                                Key=obj["Key"])
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to upload checkpoint to {}, {}".format(
                self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to upload checkpoint to {}, {}".format(
                self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def download_model(self, checkpoint_dir):
    """Download the latest model checkpoint from S3 into checkpoint_dir.

    Polls until the uploader's lock file is absent, then fetches the
    'checkpoint' metadata file, parses it to find the current checkpoint
    path, and downloads every object under that prefix. Loops (sleeping
    2 s between attempts) until a complete checkpoint is retrieved; exits
    the sim app on S3 errors. NOTE(review): if no checkpoint is ever
    published this loops forever — confirm callers rely on an external
    timeout.
    """
    s3_client = self.get_client()
    # Placeholder so the error message below has a value even if the
    # failure happens before `filename` is assigned.
    filename = "None"
    try:
        filename = os.path.abspath(
            os.path.join(checkpoint_dir, "checkpoint"))
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        while True:
            # An existing lock file means an upload is in progress.
            response = s3_client.list_objects_v2(
                Bucket=self.bucket,
                Prefix=self._get_s3_key(self.lock_file))

            if "Contents" not in response:
                # If no lock is found, try getting the checkpoint
                try:
                    s3_client.download_file(
                        Bucket=self.bucket,
                        Key=self._get_s3_key("checkpoint"),
                        Filename=filename)
                except Exception as e:
                    # Metadata not ready yet — back off and retry.
                    time.sleep(2)
                    continue
            else:
                time.sleep(2)
                continue

            ckpt = CheckpointState()
            if os.path.exists(filename):
                contents = open(filename, 'r').read()
                text_format.Merge(contents, ckpt)
                rel_path = ckpt.model_checkpoint_path
                # Checkpoint paths look like "<num>_Step..." — extract num.
                checkpoint = int(rel_path.split('_Step')[0])

                response = s3_client.list_objects_v2(
                    Bucket=self.bucket,
                    Prefix=self._get_s3_key(rel_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        # Map the S3 key back to a local path under
                        # checkpoint_dir.
                        filename = os.path.abspath(
                            os.path.join(
                                checkpoint_dir,
                                obj["Key"].replace(
                                    self.model_checkpoints_prefix, "")))
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to download model {} from {}: {}".format(
                filename, self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to download model {} from {}: {}".format(
                filename, self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def load_from_store(self, expected_checkpoint_number=-1):
    """Download the newest checkpoint from S3 into the local checkpoint dir.

    Clears any stale local checkpoint files first, then polls S3:
    returns False as soon as a FINISHED marker is seen, waits while the
    trainer's lock file is present, and otherwise downloads the
    checkpoint metadata plus all checkpoint objects. When
    expected_checkpoint_number >= 0, keeps waiting until the available
    checkpoint is at least that new. Returns True on a successful
    download; exits the sim app on S3 errors.
    """
    try:
        filename = os.path.abspath(
            os.path.join(self.params.checkpoint_dir,
                         CHECKPOINT_METADATA_FILENAME))
        if not os.path.exists(self.params.checkpoint_dir):
            os.makedirs(self.params.checkpoint_dir)

        # CT: remove all prior checkpoint files to save local storage space
        if len(os.listdir(self.params.checkpoint_dir)) > 0:
            files_to_remove = os.listdir(self.params.checkpoint_dir)
            utils.json_format_logger('Removing %d old checkpoint files'
                                     % len(files_to_remove))
            for f in files_to_remove:
                try:
                    os.unlink(
                        os.path.abspath(
                            os.path.join(self.params.checkpoint_dir, f)))
                except Exception as e:
                    # Best-effort cleanup: swallow per-file errors.
                    pass

        while True:
            s3_client = self._get_client()

            # Check if there's a finished file
            response = s3_client.list_objects_v2(
                Bucket=self.params.bucket,
                Prefix=self._get_s3_key(SyncFiles.FINISHED.value))
            if "Contents" in response:
                finished_file_path = os.path.abspath(
                    os.path.join(self.params.checkpoint_dir,
                                 SyncFiles.FINISHED.value))
                s3_client.download_file(Bucket=self.params.bucket,
                                        Key=self._get_s3_key(
                                            SyncFiles.FINISHED.value),
                                        Filename=finished_file_path)
                # Training is over — nothing more to load.
                return False

            # Check if there's a lock file
            response = s3_client.list_objects_v2(
                Bucket=self.params.bucket,
                Prefix=self._get_s3_key(self.params.lock_file))

            if "Contents" not in response:
                try:
                    # If no lock is found, try getting the checkpoint
                    s3_client.download_file(
                        Bucket=self.params.bucket,
                        Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                        Filename=filename)
                except Exception as e:
                    # Metadata not available yet — wait and retry.
                    utils.json_format_logger(
                        "Sleeping {} seconds while lock file is present".
                        format(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        ))
                    time.sleep(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    )
                    continue
            else:
                utils.json_format_logger(
                    "Sleeping {} seconds while lock file is present".
                    format(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    ))
                time.sleep(
                    SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                )
                continue

            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)
                # if we get a checkpoint that is older that the expected
                # checkpoint, we wait for the new checkpoint to arrive.
                if checkpoint_number < expected_checkpoint_number:
                    time.sleep(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    )
                    continue
                # Found a checkpoint to be downloaded
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(
                        checkpoint.model_checkpoint_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        # Get the local filename of the checkpoint file
                        full_key_prefix = os.path.normpath(
                            self.key_prefix) + "/"
                        filename = os.path.abspath(
                            os.path.join(
                                self.params.checkpoint_dir,
                                obj["Key"].replace(full_key_prefix, "")))
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    return True
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to download checkpoint from {}, {}".format(
                self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to download checkpoint from {}, {}".format(
                self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def training_worker(graph_manager, checkpoint_dir, use_pretrained_model,
                    framework, memory_backend_params, user_batch_size,
                    user_episode_per_rollout):
    """Run the training loop: ingest rollouts, train, and checkpoint.

    Builds the coach graph (optionally restoring a pretrained model and
    switching the network framework to MXNet), then repeatedly fetches
    rollout episodes from the memory backend, trains on them, checks for
    NaN losses, and saves checkpoints until improve_steps is reached, a
    SIGTERM arrives (via DoorMan), or an error occurs. Always uploads
    the FINISHED marker on exit.
    """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = SM_MODEL_OUTPUT_DIR

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        # Single-agent graphs expose agent_params; multi-agent graphs
        # expose agents_params — handle both shapes.
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__['checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0
    graph_manager.memory_backend = deepracer_memory.DeepRacerTrainerBackEnd(
        memory_backend_params)

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = RunPhase.TRAIN
            graph_manager.fetch_from_worker(
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            # Align per-agent step counters with what actually arrived.
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                if rollout_steps <= 0:
                    utils.json_format_logger(
                        "No rollout data retrieved from the rollout worker",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                episode_batch_size = user_batch_size if rollout_steps > user_batch_size else 2**math.floor(
                    math.log(rollout_steps, 2))
                # Set the batch size to the closest power of 2 such that we
                # have at least two batches, this prevents coach from
                # crashing as batch size less than 2 causes the batch list
                # to become a scalar which causes an exception
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = RunPhase.UNDEFINED

                # Check for Nan's in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                #! TODO handle NaN's on a per agent level for distributed training
                if rollout_has_nan:
                    utils.json_format_logger(
                        "NaN detected in loss function, aborting training.",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                # SYNC mode checkpoints every iteration so rollout workers
                # stay in lockstep; otherwise checkpoint on coach's timer.
                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            # Restore the user-configured rollout sizes for the next fetch.
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger(
                    "Received SIGTERM. Checkpointing before exiting.",
                    **utils.build_system_error_dict(
                        utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        utils.json_format_logger(
            "An error occured while training: {}.".format(e),
            **utils.build_system_error_dict(
                utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        traceback.print_exc()
        utils.simapp_exit_gracefully()
    finally:
        # Always signal completion so downstream workers can stop waiting.
        graph_manager.data_store.upload_finished_file()
def main():
    """Entry point for the SageMaker training worker.

    Parses CLI args, downloads model metadata (and an optional custom preset)
    from S3, builds the Coach graph manager, publishes hyperparameters and the
    trainer IP to S3, wires up the Redis memory backend and S3 data store, and
    hands control to training_worker(). Any failure is logged as a SIMAPP
    system error (500) and the sim app exits gracefully.
    """
    screen.set_use_colors(False)
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('-pk', '--preset_s3_key',
                            help="(string) Name of a preset to download from S3",
                            type=str,
                            required=False)
        parser.add_argument('-ek', '--environment_s3_key',
                            help="(string) Name of an environment file to download from S3",
                            type=str,
                            required=False)
        parser.add_argument('--model_metadata_s3_key',
                            help="(string) Model Metadata File S3 Key",
                            type=str,
                            required=False)
        parser.add_argument('-c', '--checkpoint-dir',
                            help='(string) Path to a folder containing a checkpoint to write the model to.',
                            type=str,
                            default='./checkpoint')
        parser.add_argument('--pretrained-checkpoint-dir',
                            help='(string) Path to a folder for downloading a pre-trained model',
                            type=str,
                            default=PRETRAINED_MODEL_DIR)
        parser.add_argument('--s3_bucket',
                            help='(string) S3 bucket',
                            type=str,
                            default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET_PATH", "gsaur-test"))
        parser.add_argument('--s3_prefix',
                            help='(string) S3 prefix',
                            type=str,
                            default='sagemaker')
        parser.add_argument('--framework',
                            help='(string) tensorflow or mxnet',
                            type=str,
                            default='tensorflow')
        parser.add_argument('--pretrained_s3_bucket',
                            help='(string) S3 bucket for pre-trained model',
                            type=str)
        parser.add_argument('--pretrained_s3_prefix',
                            help='(string) S3 prefix for pre-trained model',
                            type=str,
                            default='sagemaker')
        parser.add_argument('--aws_region',
                            help='(string) AWS region',
                            type=str,
                            default=os.environ.get("AWS_REGION", "us-east-1"))
        args, unknown = parser.parse_known_args()

        start_redis_server()

        s3_client = SageS3Client(bucket=args.s3_bucket,
                                 s3_prefix=args.s3_prefix,
                                 aws_region=args.aws_region)

        # Load the model metadata and mirror it to the S3 model folder plus
        # the SageMaker model output directory.
        model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH, 'model_metadata.json')
        load_model_metadata(s3_client, args.model_metadata_s3_key, model_metadata_local_path)
        s3_client.upload_file(
            os.path.normpath("%s/model/model_metadata.json" % args.s3_prefix),
            model_metadata_local_path)
        shutil.copy2(model_metadata_local_path, SM_MODEL_OUTPUT_DIR)

        # Register the gym enviroment, this will give clients the ability to
        # creat the enviroment object
        register(id=defaults.ENV_ID,
                 entry_point=defaults.ENTRY_POINT,
                 max_episode_steps=defaults.MAX_STEPS,
                 reward_threshold=defaults.THRESHOLD)

        user_batch_size, user_episode_per_rollout = None, None
        success_custom_preset = False
        if args.preset_s3_key:
            preset_local_path = "./markov/presets/preset.py"
            success_custom_preset = s3_client.download_file(
                s3_key=args.preset_s3_key, local_path=preset_local_path)
            if not success_custom_preset:
                logger.info("Could not download the preset file. Using the default DeepRacer preset.")
            else:
                preset_location = "markov.presets.preset:graph_manager"
                graph_manager = short_dynamic_import(preset_location,
                                                     ignore_module_case=True)
                success_custom_preset = s3_client.upload_file(
                    s3_key=os.path.normpath("%s/presets/preset.py" % args.s3_prefix),
                    local_path=preset_local_path)
                if success_custom_preset:
                    agent_param_loc = "markov.presets.preset:agent_params"
                    agent_params = short_dynamic_import(agent_param_loc,
                                                        ignore_module_case=True)
                    user_batch_size = agent_params.network_wrappers['main'].batch_size
                    user_episode_per_rollout = agent_params.algorithm.num_consecutive_playing_steps.num_steps
                    logger.info("Using preset: %s" % args.preset_s3_key)

        if not success_custom_preset:
            from markov.sagemaker_graph_manager import get_graph_manager
            params_blob = os.environ.get('SM_TRAINING_ENV', '')
            if params_blob:
                params = json.loads(params_blob)
                sm_hyperparams_dict = params["hyperparameters"]
            else:
                sm_hyperparams_dict = {}
            # BUG FIX: the graph manager must be built *before* the returned
            # hyperparameters JSON is parsed (it was previously read before
            # assignment), and the trailing comma that turned user_batch_size
            # into a 1-tuple is removed.
            graph_manager, robomaker_hyperparams_json = get_graph_manager(**sm_hyperparams_dict)
            hyperparams = json.loads(robomaker_hyperparams_json)
            user_batch_size = hyperparams["batch_size"]
            user_episode_per_rollout = hyperparams["num_episodes_between_training"]
            s3_client.upload_hyperparameters(robomaker_hyperparams_json)
            logger.info("Uploaded hyperparameters.json to S3")

        # Publish this trainer's IP so rollout workers can find the Redis server.
        host_ip_address = get_ip_from_host()
        s3_client.write_ip_config(host_ip_address)
        logger.info("Uploaded IP address information to S3: %s" % host_ip_address)

        use_pretrained_model = args.pretrained_s3_bucket and args.pretrained_s3_prefix
        if use_pretrained_model:
            s3_client_pretrained = SageS3Client(bucket=args.pretrained_s3_bucket,
                                                s3_prefix=args.pretrained_s3_prefix,
                                                aws_region=args.aws_region)
            s3_client_pretrained.download_model(args.pretrained_checkpoint_dir)

        memory_backend_params = RedisPubSubMemoryBackendParameters(
            redis_address="localhost",
            redis_port=6379,
            run_type='trainer',
            channel=args.s3_prefix)

        ds_params_instance = S3BotoDataStoreParameters(
            bucket_name=args.s3_bucket,
            checkpoint_dir=args.checkpoint_dir,
            aws_region=args.aws_region,
            s3_folder=args.s3_prefix)
        graph_manager.data_store_params = ds_params_instance

        data_store = S3BotoDataStore(ds_params_instance)
        data_store.graph_manager = graph_manager
        graph_manager.data_store = data_store

        training_worker(graph_manager=graph_manager,
                        checkpoint_dir=args.checkpoint_dir,
                        use_pretrained_model=use_pretrained_model,
                        framework=args.framework,
                        memory_backend_params=memory_backend_params,
                        user_batch_size=user_batch_size,
                        user_episode_per_rollout=user_episode_per_rollout)
    except Exception as ex:
        utils.json_format_logger(
            "Training worker exited with exception: {}".format(ex),
            **utils.build_system_error_dict(
                utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def main():
    """Entry point for the RoboMaker rollout worker.

    Parses CLI args, downloads model metadata and the customer reward
    function from S3, fetches the trainer's IP and hyperparameters, builds
    the Coach graph manager (custom preset or default), connects the Redis
    memory backend and S3 data store, and hands control to rollout_worker().
    """
    screen.set_use_colors(False)

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint_dir',
                        help='(string) Path to a folder containing a checkpoint to restore the model from.',
                        type=str,
                        default='./checkpoint')
    parser.add_argument('--s3_bucket',
                        help='(string) S3 bucket',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_BUCKET", "gsaur-test"))
    parser.add_argument('--s3_prefix',
                        help='(string) S3 prefix',
                        type=str,
                        default=os.environ.get("SAGEMAKER_SHARED_S3_PREFIX", "sagemaker"))
    parser.add_argument('--num-workers',
                        help="(int) The number of workers started in this pool",
                        type=int,
                        default=1)
    parser.add_argument('-r', '--redis_ip',
                        help="(string) IP or host for the redis server",
                        default='localhost',
                        type=str)
    parser.add_argument('-rp', '--redis_port',
                        help="(int) Port of the redis server",
                        default=6379,
                        type=int)
    parser.add_argument('--aws_region',
                        help='(string) AWS region',
                        type=str,
                        default=os.environ.get("APP_REGION", "us-east-1"))
    parser.add_argument('--reward_file_s3_key',
                        help='(string) Reward File S3 Key',
                        type=str,
                        default=os.environ.get("REWARD_FILE_S3_KEY", None))
    parser.add_argument('--model_metadata_s3_key',
                        help='(string) Model Metadata File S3 Key',
                        type=str,
                        default=os.environ.get("MODEL_METADATA_FILE_S3_KEY", None))
    # BUG FIX: help text previously said "(string) AWS region" (copy-paste).
    parser.add_argument('--aws_endpoint_url',
                        help='(string) AWS endpoint URL',
                        type=str,
                        default=os.environ.get("AWS_ENDPOINT_URL", None))
    args = parser.parse_args()

    s3_client = SageS3Client(bucket=args.s3_bucket,
                             s3_prefix=args.s3_prefix,
                             aws_region=args.aws_region,
                             endpoint_url=args.aws_endpoint_url)

    logger.info("S3 bucket: %s" % args.s3_bucket)
    logger.info("S3 prefix: %s" % args.s3_prefix)

    # Load the model metadata
    model_metadata_local_path = os.path.join(CUSTOM_FILES_PATH, 'model_metadata.json')
    load_model_metadata(s3_client, args.model_metadata_s3_key, model_metadata_local_path)

    # Download reward function; without it the worker cannot score episodes.
    if not args.reward_file_s3_key:
        utils.json_format_logger(
            "Reward function code S3 key not available for S3 bucket {} and prefix {}"
            .format(args.s3_bucket, args.s3_prefix),
            **utils.build_system_error_dict(
                utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        traceback.print_exc()
        utils.simapp_exit_gracefully()
    download_customer_reward_function(s3_client, args.reward_file_s3_key)

    # Register the gym enviroment, this will give clients the ability to
    # creat the enviroment object
    register(id=defaults.ENV_ID,
             entry_point=defaults.ENTRY_POINT,
             max_episode_steps=defaults.MAX_STEPS,
             reward_threshold=defaults.THRESHOLD)

    redis_ip = s3_client.get_ip()
    logger.info("Received IP from SageMaker successfully: %s" % redis_ip)

    # Download hyperparameters from SageMaker
    hyperparameters_file_success = False
    hyperparams_s3_key = os.path.normpath(args.s3_prefix + "/ip/hyperparameters.json")
    hyperparameters_file_success = s3_client.download_file(
        s3_key=hyperparams_s3_key, local_path="hyperparameters.json")

    sm_hyperparams_dict = {}
    if hyperparameters_file_success:
        logger.info("Received Sagemaker hyperparameters successfully!")
        with open("hyperparameters.json") as fp:
            sm_hyperparams_dict = json.load(fp)
    else:
        logger.info("SageMaker hyperparameters not found.")

    preset_file_success, _ = download_custom_files_if_present(s3_client, args.s3_prefix)
    if preset_file_success:
        preset_location = os.path.join(CUSTOM_FILES_PATH, "preset.py")
        preset_location += ":graph_manager"
        graph_manager = short_dynamic_import(preset_location, ignore_module_case=True)
        logger.info("Using custom preset file!")
    else:
        from markov.sagemaker_graph_manager import get_graph_manager
        graph_manager, _ = get_graph_manager(**sm_hyperparams_dict)

    logger.info("Connecting to redis at %s:%d" % (redis_ip, args.redis_port))
    # BUG FIX: the port was hardcoded to 6379, silently ignoring --redis_port
    # even though the log line above reports args.redis_port.
    memory_backend_params = RedisPubSubMemoryBackendParameters(
        redis_address=redis_ip,
        redis_port=args.redis_port,
        run_type='worker',
        channel=args.s3_prefix)

    logger.info("Connecting to s3 boto data store at %s" % args.aws_endpoint_url)
    ds_params_instance = S3BotoDataStoreParameters(
        bucket_name=args.s3_bucket,
        checkpoint_dir=args.checkpoint_dir,
        aws_region=args.aws_region,
        s3_folder=args.s3_prefix,
        aws_endpoint_url=args.aws_endpoint_url)

    data_store = S3BotoDataStore(ds_params_instance)
    data_store.graph_manager = graph_manager
    graph_manager.data_store = data_store

    rollout_worker(graph_manager=graph_manager,
                   checkpoint_dir=args.checkpoint_dir,
                   data_store=data_store,
                   num_workers=args.num_workers,
                   memory_backend_params=memory_backend_params)
def should_stop(checkpoint_dir):
    """Return True when the trainer has dropped its 'finished' marker file.

    Checks checkpoint_dir for the SyncFiles.FINISHED marker; when present,
    logs the termination notice and triggers a graceful sim-app shutdown
    before returning True. Returns False when the marker is absent.
    """
    finished_marker = os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)
    if not os.path.exists(finished_marker):
        return False
    logger.info("Received termination signal from trainer. Goodbye.")
    utils.simapp_exit_gracefully()
    return True
# NOTE(review): this wiring block duplicates the data-store setup found at the
# tail of the rollout-worker main() earlier in the file; it references `args`,
# `graph_manager` and `memory_backend_params` from the surrounding scope —
# confirm whether this is the intended in-scope continuation or dead duplicate.
logger.info("Connecting to s3 boto data store at %s" % args.aws_endpoint_url)
ds_params_instance = S3BotoDataStoreParameters(
    bucket_name=args.s3_bucket,
    checkpoint_dir=args.checkpoint_dir,
    aws_region=args.aws_region,
    s3_folder=args.s3_prefix,
    aws_endpoint_url=args.aws_endpoint_url)

# Cross-link the data store and graph manager before starting rollouts.
data_store = S3BotoDataStore(ds_params_instance)
data_store.graph_manager = graph_manager
graph_manager.data_store = data_store

rollout_worker(graph_manager=graph_manager,
               checkpoint_dir=args.checkpoint_dir,
               data_store=data_store,
               num_workers=args.num_workers,
               memory_backend_params=memory_backend_params)


if __name__ == '__main__':
    try:
        main()
    except Exception as ex:
        # Top-level boundary: report any escape from main() as a SIMAPP
        # system error (500) and shut down gracefully.
        utils.json_format_logger(
            "Rollout worker exited with exception: {}".format(ex),
            **utils.build_system_error_dict(
                utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def save_to_store(self):
    """Upload the current checkpoint to S3 under a coarse lock.

    Protocol:
      1. Take a lock by (re)writing an empty lock file in S3 so readers do
         not observe a half-written checkpoint.
      2. Upload every file belonging to the current checkpoint number, then
         upload the checkpoint metadata file *last* so readers never see
         metadata pointing at missing files.
      3. Release the lock, upload the frozen graph used for deployment, and
         delete the files of an old checkpoint (current - 4).

    Boto ClientErrors are reported as user errors (400), anything else as a
    system error (500); both exit the sim app gracefully.
    """
    try:
        s3_client = self._get_client()

        # Delete any existing lock file
        s3_client.delete_object(Bucket=self.params.bucket,
                                Key=self._get_s3_key(self.params.lock_file))

        # We take a lock by writing a lock file to the same location in S3
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=self.params.bucket,
                                 Key=self._get_s3_key(self.params.lock_file))

        # Start writing the model checkpoints to S3
        checkpoint = self._get_current_checkpoint()
        if checkpoint:
            checkpoint_number = self._get_checkpoint_number(checkpoint)

            checkpoint_file = None
            for root, _, files in os.walk(self.params.checkpoint_dir):
                num_files_uploaded = 0
                for filename in files:
                    # Skip the checkpoint file that has the latest checkpoint
                    # number; it is uploaded last (see below).
                    if filename == CHECKPOINT_METADATA_FILENAME:
                        checkpoint_file = (root, filename)
                        continue
                    if not filename.startswith(str(checkpoint_number)):
                        continue
                    # Upload all the other files from the checkpoint directory
                    abs_name = os.path.abspath(os.path.join(root, filename))
                    rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
                    s3_client.upload_file(Filename=abs_name,
                                          Bucket=self.params.bucket,
                                          Key=self._get_s3_key(rel_name))
                    num_files_uploaded += 1
                logger.info("Uploaded {} files for checkpoint {}".format(
                    num_files_uploaded, checkpoint_number))

            # BUG FIX: previously checkpoint_file[0] raised an opaque
            # TypeError when the metadata file was absent; fail with an
            # explicit message instead (still caught by the generic handler
            # below, so the exit path is unchanged).
            if checkpoint_file is None:
                raise RuntimeError(
                    "Checkpoint metadata file '{}' not found in {}".format(
                        CHECKPOINT_METADATA_FILENAME, self.params.checkpoint_dir))

            # After all the checkpoint files have been uploaded, we upload
            # the version file.
            abs_name = os.path.abspath(
                os.path.join(checkpoint_file[0], checkpoint_file[1]))
            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(rel_name))

        # Release the lock by deleting the lock file from S3
        s3_client.delete_object(Bucket=self.params.bucket,
                                Key=self._get_s3_key(self.params.lock_file))

        # Upload the frozen graph which is used for deployment
        if self.graph_manager:
            utils.write_frozen_graph(self.graph_manager)
            # upload the model_<ID>.pb to S3. NOTE: there's no cleanup as we
            # don't know the best checkpoint
            iteration_id = self.graph_manager.level_managers[0].agents[
                'agent'].training_iteration
            frozen_graph_fpath = utils.SM_MODEL_OUTPUT_DIR + "/model.pb"
            frozen_graph_s3_name = "model_%s.pb" % iteration_id
            s3_client.upload_file(Filename=frozen_graph_fpath,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(frozen_graph_s3_name))
            logger.info("saved intermediate frozen graph: {}".format(
                self._get_s3_key(frozen_graph_s3_name)))

        # Clean up old checkpoints
        checkpoint = self._get_current_checkpoint()
        if checkpoint:
            checkpoint_number = self._get_checkpoint_number(checkpoint)
            # Sliding window: delete files of checkpoint (current - 4).
            checkpoint_number_to_delete = checkpoint_number - 4

            # List all the old checkpoint files to be deleted
            response = s3_client.list_objects_v2(
                Bucket=self.params.bucket,
                Prefix=self._get_s3_key(str(checkpoint_number_to_delete) + "_"))
            if "Contents" in response:
                for obj in response["Contents"]:
                    s3_client.delete_object(Bucket=self.params.bucket,
                                            Key=obj["Key"])
    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to upload checkpoint to {}, {}".format(
                self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to upload checkpoint to {}, {}".format(self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout):
    """Run the trainer-side loop of distributed training.

    Repeatedly fetches rollout data from workers via the memory backend,
    trains on the episodes actually received, checkpoints (eagerly in SYNC
    mode, occasionally otherwise), and restores the user-configured rollout
    sizes between iterations. Exits gracefully on bad-checkpoint ValueErrors
    (user error 400), on any other exception (system error 500), and always
    uploads the 'finished' marker on the way out.

    Args:
        graph_manager: Coach graph manager driving the training graph.
        task_parameters: Coach task parameters used to create the graph.
        user_batch_size (int): batch size requested by the user config.
        user_episode_per_rollout (int): episodes per rollout from user config.
    """
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            # Train on exactly the number of episodes received this rollout.
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()

                # BUG FIX: the original `any(rollout_steps.values()) <= 0`
                # compared the *boolean* result of any() against 0; the
                # behaviorally identical condition is written explicitly:
                # bail out when no agent produced any rollout steps.
                if not any(rollout_steps.values()):
                    utils.json_format_logger(
                        "No rollout data retrieved from the rollout worker",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                # Set the batch size to the closest power of 2 such that we have at least two batches, this prevents coach from crashing
                # as batch size less than 2 causes the batch list to become a scalar which causes an exception
                min_rollout_steps = min(rollout_steps.values())
                episode_batch_size = user_batch_size if min_rollout_steps > user_batch_size \
                    else 2 ** math.floor(math.log(min_rollout_steps, 2))
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for Nan's in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                if rollout_has_nan:
                    utils.json_format_logger(
                        "NaN detected in loss function, aborting training.",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            # Restore the user-configured rollout sizes for the next iteration.
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger(
                    "Received SIGTERM. Checkpointing before exiting.",
                    **utils.build_system_error_dict(
                        utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except ValueError as err:
        # Bad-checkpoint ValueErrors are user errors (a modified model);
        # everything else is a system error.
        if utils.is_error_bad_ckpnt(err):
            utils.log_and_exit("User modified model: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_400)
        else:
            utils.log_and_exit("An error occured while training: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        utils.log_and_exit("An error occured while training: {}".format(ex),
                           utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                           utils.SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        # Always tell rollout workers training is over, even on failure.
        graph_manager.data_store.upload_finished_file()
def load_from_store(self, expected_checkpoint_number=-1):
    """Download the latest checkpoint from S3 into the local checkpoint dir.

    Polls S3 until the trainer's lock file is absent and a checkpoint state
    file is available, then downloads the files of the checkpoint named in
    the state file (skipping frozen .pb graphs). Also mirrors the trainer's
    'finished' and 'ready' marker files locally when present. If the named
    checkpoint is missing from the bucket, falls back to the newest
    checkpoint found locally and rewrites the state file to point at it.

    Args:
        expected_checkpoint_number (int): minimum checkpoint number to
            accept; -1 (default) accepts whatever is currently published.

    Returns:
        True once a sync pass has completed. On S3 errors the sim app exits
        gracefully instead of raising.
    """
    try:
        if not os.path.exists(self.params.checkpoint_dir):
            os.makedirs(self.params.checkpoint_dir)

        while True:
            # A fresh client each poll iteration — presumably to avoid stale
            # connections during long waits; TODO confirm.
            s3_client = self._get_client()
            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

            # wait until lock is removed
            response = s3_client.list_objects_v2(
                Bucket=self.params.bucket,
                Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
            if "Contents" not in response:
                try:
                    # fetch checkpoint state file from S3
                    s3_client.download_file(
                        Bucket=self.params.bucket,
                        Key=self._get_s3_key(state_file.filename),
                        Filename=state_file.path)
                except Exception as e:
                    # State file not published yet: sleep and poll again.
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
            else:
                # Trainer holds the lock: sleep and poll again.
                time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                continue

            # check if there's a Finished file
            response = s3_client.list_objects_v2(
                Bucket=self.params.bucket,
                Prefix=self._get_s3_key(SyncFiles.FINISHED.value))
            if "Contents" in response:
                try:
                    finished_file_path = os.path.abspath(
                        os.path.join(self.params.checkpoint_dir,
                                     SyncFiles.FINISHED.value))
                    s3_client.download_file(
                        Bucket=self.params.bucket,
                        Key=self._get_s3_key(SyncFiles.FINISHED.value),
                        Filename=finished_file_path)
                except Exception as e:
                    # Best-effort mirror of the marker file; ignore failures.
                    pass

            # check if there's a Ready file
            response = s3_client.list_objects_v2(
                Bucket=self.params.bucket,
                Prefix=self._get_s3_key(SyncFiles.TRAINER_READY.value))
            if "Contents" in response:
                try:
                    ready_file_path = os.path.abspath(
                        os.path.join(self.params.checkpoint_dir,
                                     SyncFiles.TRAINER_READY.value))
                    s3_client.download_file(
                        Bucket=self.params.bucket,
                        Key=self._get_s3_key(SyncFiles.TRAINER_READY.value),
                        Filename=ready_file_path)
                except Exception as e:
                    # Best-effort mirror of the marker file; ignore failures.
                    pass

            checkpoint_state = state_file.read()
            if checkpoint_state is not None:
                # if we get a checkpoint that is older that the expected checkpoint, we wait for
                # the new checkpoint to arrive.
                if checkpoint_state.num < expected_checkpoint_number:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(""))
                if "Contents" in response:
                    # Check to see if the desired checkpoint is in the bucket
                    has_chkpnt = any(list(map(lambda obj: os.path.split(obj['Key'])[1].\
                                              startswith(checkpoint_state.name),
                                              response['Contents'])))
                    for obj in response["Contents"]:
                        full_key_prefix = os.path.normpath(self.key_prefix) + "/"
                        filename = os.path.abspath(
                            os.path.join(self.params.checkpoint_dir,
                                         obj["Key"].replace(full_key_prefix, "")))
                        dirname, basename = os.path.split(filename)
                        # Download all the checkpoints but not the frozen models since they
                        # are not necessary
                        _, file_extension = os.path.splitext(obj["Key"])
                        if file_extension != '.pb' \
                                and (basename.startswith(checkpoint_state.name) or not has_chkpnt):
                            if not os.path.exists(dirname):
                                os.makedirs(dirname)
                            s3_client.download_file(Bucket=self.params.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)

                    # Change the coach checkpoint file to point to the latest available checkpoint,
                    # also log that we are changing the checkpoint.
                    if not has_chkpnt:
                        all_ckpnts = _filter_checkpoint_files(os.listdir(self.params.checkpoint_dir))
                        if all_ckpnts:
                            logger.info("%s not in s3 bucket, downloading all checkpoints \
and using %s", checkpoint_state.name, all_ckpnts[-1])
                            state_file.write(all_ckpnts[-1])
                        else:
                            utils.json_format_logger(
                                "No checkpoint files found in {}".format(self.params.bucket),
                                **utils.build_user_error_dict(
                                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                    utils.SIMAPP_EVENT_ERROR_CODE_400))
                            utils.simapp_exit_gracefully()
            return True

    except botocore.exceptions.ClientError as e:
        utils.json_format_logger(
            "Unable to download checkpoint from {}, {}"
            .format(self.params.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        utils.json_format_logger(
            "Unable to download checkpoint from {}, {}"
            .format(self.params.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()