Example #1
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """
    block until there is a checkpoint in checkpoint_dir
    """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()

        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)

    # one last time
    if has_checkpoint(checkpoint_dir):
        return

    utils.json_format_logger(
        "checkpoint never found in {}, Waited {} seconds. Job failed!".format(
            checkpoint_dir, timeout),
        **utils.build_system_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_503))
    traceback.print_exc()
    raise ValueError(
        ('Waited {timeout} seconds, but checkpoint never found in '
         '{checkpoint_dir}').format(
             timeout=timeout,
             checkpoint_dir=checkpoint_dir,
         ))
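This example relies on a has_checkpoint helper that is not shown on this page. A minimal sketch, assuming the trainer drops a TensorFlow-style "checkpoint" metadata file into the directory once a usable checkpoint exists (both the helper body and the marker-file convention are assumptions):

import os

def has_checkpoint(checkpoint_dir):
    # Treat the presence of the "checkpoint" metadata file as the signal
    # that the trainer has written at least one usable checkpoint.
    return os.path.isfile(os.path.join(checkpoint_dir, "checkpoint"))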
Example #2
 def upload_model(self, checkpoint_dir):
     try:
         s3_client = self.get_client()
         num_files = 0
         for root, _, files in os.walk("./" + checkpoint_dir):
             for filename in files:
                 abs_name = os.path.abspath(os.path.join(root, filename))
                 s3_client.upload_file(
                     abs_name, self.bucket, "%s/%s/%s" %
                     (self.s3_prefix, checkpoint_dir, filename))
                 num_files += 1
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Model failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Model failed to upload to {}, {}".format(self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Example #3
 def write_ip_config(self, ip):
     try:
         s3_client = self.get_client()
         data = {"IP": ip}
         json_blob = json.dumps(data)
         file_handle = io.BytesIO(json_blob.encode())
         file_handle_done = io.BytesIO(b'done')
         s3_client.upload_fileobj(file_handle, self.bucket, self.config_key)
         s3_client.upload_fileobj(file_handle_done, self.bucket,
                                  self.done_file_key)
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Write ip config failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Write ip config failed to upload to {}, {}".format(
                 self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Example #4
 def download_file(self, s3_key, local_path):
     s3_client = self.get_client()
     try:
         s3_client.download_file(self.bucket, s3_key, local_path)
         return True
     except botocore.exceptions.ClientError as e:
         # It is possible that the file isn't there, in which case we should
         # return False and let the caller decide the next action
         if e.response['Error']['Code'] == "404":
             return False
         else:
             utils.json_format_logger(
                 "Unable to download {} from {}: {}".format(
                     s3_key, self.bucket, e.response['Error']['Code']),
                 **utils.build_user_error_dict(
                     utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                     utils.SIMAPP_EVENT_ERROR_CODE_400))
             utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to download {} from {}: {}".format(
                 s3_key, self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Example #5
    def get_latest_checkpoint(self):
        try:
            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                s3_client = self._get_client()
                state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

                # wait until lock is removed
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
                if "Contents" not in response:
                    try:
                        # fetch checkpoint state file from S3
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=self._get_s3_key(state_file.filename),
                                                Filename=filename)
                    except Exception as e:
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue
                else:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                return self._get_current_checkpoint_number(checkpoint_metadata_filepath=filename)

        except Exception as e:
            utils.json_format_logger("Exception [{}] occured while getting latest checkpoint from S3.".format(e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
Example #6
 def callback_image(self, data):
     try:
         self.image_queue.put_nowait(data)
     except queue.Full:
         pass
     except Exception as ex:
         utils.json_format_logger("Error retrieving frame from gazebo: {}".format(ex),
                    **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
Example #7
def log_info(message):
    ''' Helper method that logs a message as a system error
        message - Message to send to the log
    '''
    json_format_logger(
        message,
        **build_system_error_dict(SIMAPP_MEMORY_BACKEND_EXCEPTION,
                                  SIMAPP_EVENT_ERROR_CODE_500))
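The json_format_logger and build_system_error_dict / build_user_error_dict helpers appear in nearly every example on this page but are defined in the project's utils module, which is not shown. A minimal sketch consistent with how they are called here (the field names are assumptions, not the project's actual schema):

import json
import logging

logger = logging.getLogger(__name__)

def build_system_error_dict(exception_type, error_code):
    # Tag the log entry as a system-side failure with a SIMAPP error code.
    return {"exceptionType": exception_type,
            "eventType": "system_error",
            "errorCode": error_code}

def json_format_logger(message, **kwargs):
    # Emit the message plus any error metadata as one JSON log line so
    # that downstream log processors can parse it mechanically.
    entry = dict(kwargs, message=message)
    logger.error(json.dumps(entry))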
Example #8
    def download_model(self, checkpoint_dir):
        s3_client = self.get_client()
        filename = "None"
        try:
            filename = os.path.abspath(
                os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(
                                                         self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        s3_client.download_file(
                            Bucket=self.bucket,
                            Key=self._get_s3_key("checkpoint"),
                            Filename=filename)
                    except Exception as e:
                        time.sleep(2)
                        continue
                else:
                    time.sleep(2)
                    continue

                ckpt = CheckpointState()
                if os.path.exists(filename):
                    contents = open(filename, 'r').read()
                    text_format.Merge(contents, ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(
                        Bucket=self.bucket, Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(
                                os.path.join(
                                    checkpoint_dir, obj["Key"].replace(
                                        self.model_checkpoints_prefix, "")))
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except Exception as e:
            utils.json_format_logger(
                "{} while downloading the model {} from S3".format(
                    e, filename),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            return False
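The checkpoint number above comes from parsing the TensorFlow "checkpoint" metadata file with text_format.Merge and splitting the model path on '_Step'. A small self-contained illustration of that parsing, using a made-up metadata line (the import path assumes TensorFlow 1.x internals, as the example itself does):

from google.protobuf import text_format
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState

# Hypothetical contents of the "checkpoint" file; the leading "<N>_Step"
# prefix of the model path encodes the checkpoint number extracted above.
sample = 'model_checkpoint_path: "45_Step-2820.ckpt"\n'
ckpt = CheckpointState()
text_format.Merge(sample, ckpt)
assert int(ckpt.model_checkpoint_path.split('_Step')[0]) == 45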
Example #9
    def download_model(self, checkpoint_dir):
        s3_client = self.get_client()
        logger.info("Downloading pretrained model from %s/%s %s" % (self.bucket, self.model_checkpoints_prefix, checkpoint_dir))
        filename = "None"
        try:
            filename = os.path.abspath(os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                logger.info("Model folder %s does not exist, creating" % filename)
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        key = self._get_s3_key("checkpoint")
                        logger.info("Downloading %s" % key)
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=key,
                                                Filename=filename)
                    except Exception as e:
                        logger.info("Something went wrong, will retry in 2 seconds %s" % e)
                        time.sleep(2)
                        continue
                else:
                    logger.info("Found a lock file %s , waiting" % self._get_s3_key(self.lock_file))
                    time.sleep(2)
                    continue

                ckpt = CheckpointState()
                if os.path.exists(filename):
                    contents = open(filename, 'r').read()
                    text_format.Merge(contents, ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                         Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(os.path.join(checkpoint_dir,
                                                                    obj["Key"].replace(self.model_checkpoints_prefix,
                                                                                       "")))

                            logger.info("Downloading model file %s" % filename)
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except Exception as e:
            utils.json_format_logger("{} while downloading the model {} from S3".format(e, filename),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
            return False
Example #10
 def upload_file(self, s3_key, local_path):
     s3_client = self.get_client()
     try:
         s3_client.upload_file(Filename=local_path,
                               Bucket=self.bucket,
                               Key=s3_key)
         return True
     except Exception as e:
         utils.json_format_logger("{} on upload file-{} to s3 bucket-{} key-{}".format(e, local_path, self.bucket, s3_key),
                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
         return False
Example #11
 def _get_current_checkpoint_number(self, checkpoint_metadata_filepath=None):
     try:
         if not os.path.exists(checkpoint_metadata_filepath):
             return None
         with open(checkpoint_metadata_filepath, 'r') as fp:
             data = fp.read()
             return int(data.split('_')[0])
     except Exception as e:
         utils.json_format_logger("Exception[{}] occured while reading checkpoint metadata".format(e),
                                  **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
         raise e
Example #12
 def get_ip(self):
     s3_client = self.get_client()
     self._wait_for_ip_upload()
     try:
         s3_client.download_file(self.bucket, self.config_key, 'ip.json')
         with open("ip.json") as f:
             ip = json.load(f)["IP"]
         return ip
     except Exception as e:
         utils.json_format_logger("Exception [{}] occured, Cannot fetch IP of redis server running in SageMaker. Job failed!".format(e),
                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
         sys.exit(1)
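Examples #3 and #12 are the two halves of a simple S3 handshake: the trainer uploads ip.json plus an empty "done" marker, and the rollout worker polls for the marker before reading the address. A standalone sketch of the reader side in plain boto3 (bucket and key names are parameters here; this is an illustration, not the project's code):

import json
import time
import boto3

def read_ip_config(bucket, config_key, done_file_key, poll_secs=1):
    # Block until the "done" marker exists, then fetch and parse ip.json.
    s3 = boto3.client("s3")
    while "Contents" not in s3.list_objects_v2(Bucket=bucket, Prefix=done_file_key):
        time.sleep(poll_secs)
    body = s3.get_object(Bucket=bucket, Key=config_key)["Body"].read()
    return json.loads(body)["IP"]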
Example #13
def download_customer_reward_function(s3_client, reward_file_s3_key):
    reward_function_local_path = os.path.join(CUSTOM_FILES_PATH,
                                              "customer_reward_function.py")
    success_reward_function_download = s3_client.download_file(
        s3_key=reward_file_s3_key, local_path=reward_function_local_path)
    if not success_reward_function_download:
        utils.json_format_logger(
            "Could not download the customer reward function file. Job failed!",
            **utils.build_system_error_dict(
                utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_503))
        traceback.print_exc()
        sys.exit(1)
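The downloaded customer_reward_function.py is expected to expose a reward_function(params) entry point, where params is a dict of car and track state and the return value is a float reward. A minimal sketch of that interface (all_wheels_on_track is one of DeepRacer's documented params; the body is illustrative only):

def reward_function(params):
    # Reward staying on the track; params is supplied by the simulator.
    if params.get("all_wheels_on_track", False):
        return 1.0
    return 1e-3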
Example #14
    def _wait_for_ip_upload(self, timeout=600):
        s3_client = self.get_client()
        time_elapsed = 0

        while time_elapsed < timeout:
            try:
                response = s3_client.list_objects(Bucket=self.bucket,
                                                  Prefix=self.done_file_key)
            except botocore.exceptions.ClientError as e:
                utils.json_format_logger(
                    "Unable to access {}: {}".format(
                        self.bucket, e.response['Error']['Code']),
                    **utils.build_user_error_dict(
                        utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_400))
                utils.simapp_exit_gracefully()
            except Exception as e:
                utils.json_format_logger(
                    "Unable to access {}: {}".format(self.bucket, e),
                    **utils.build_system_error_dict(
                        utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                utils.simapp_exit_gracefully()

            if "Contents" in response:
                return
            time.sleep(1)
            time_elapsed += 1
            if time_elapsed % 5 == 0:
                logger.info(
                    "Waiting for SageMaker Redis server IP... Time elapsed: %s seconds"
                    % time_elapsed)

        utils.json_format_logger(
            "Timed out while attempting to retrieve the Redis IP",
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
Example #15
 def write_simtrace_data(self, jsondata):
     if self.data_state != SIMTRACE_DATA_UPLOAD_UNKNOWN_STATE:
         try:
             csvdata = []
             for key in SIMTRACE_CSV_DATA_HEADER:
                 csvdata.append(jsondata[key])
             self.csvwriter.writerow(csvdata)
             self.total_upload_size += sys.getsizeof(csvdata)
             logger.debug("csvdata={} size data={} csv={}".format(csvdata, sys.getsizeof(csvdata), sys.getsizeof(self.simtrace_csv_data.getvalue())))
         except Exception as ex:
             utils.json_format_logger("Invalid SIM_TRACE data format , Exception={}. Job failed!".format(ex),
                    **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
             traceback.print_exc()
             utils.simapp_exit_gracefully()
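write_simtrace_data assumes self.csvwriter and self.simtrace_csv_data were set up elsewhere in the class. One plausible initialization consistent with the attribute names and the getvalue() call above (the header fields are placeholders, not the real SIM_TRACE schema):

import csv
import io

SIMTRACE_CSV_DATA_HEADER = ["episode", "steps", "x", "y", "reward", "done"]  # assumed fields

simtrace_csv_data = io.StringIO()         # in-memory buffer later uploaded to S3
csvwriter = csv.writer(simtrace_csv_data)
csvwriter.writerow(SIMTRACE_CSV_DATA_HEADER)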
Example #16
 def callback_image(self, data):
     try:
         self.image_queue.put_nowait(data)
     except queue.Full:
         # Only warn if it's the middle of an episode, not during training
         if self.allow_servo_step_signals:
             logger.info("Warning: dropping image due to queue full")
     except Exception as ex:
         utils.json_format_logger(
             "Error retrieving frame from gazebo: {}".format(ex),
             **utils.build_system_error_dict(
                 utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
Example #17
    def racecar_reset(self):
        try:
            for joint in EFFORT_JOINTS:
                self.clear_forces_client(joint)
            prev_index, next_index = self.find_prev_next_waypoints(self.start_ndist)
            self.reset_car_client(self.start_ndist, next_index)
            # First clear the queue so that we set the state to the start image
            _ = self.image_queue.get(block=True, timeout=None)
            self.set_next_state()

        except Exception as ex:
            utils.json_format_logger("Unable to reset the car: {}".format(ex),
                         **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
Example #18
    def _get_current_checkpoint(self):
        try:
            checkpoint_metadata_filepath = os.path.abspath(
                os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
            checkpoint = CheckpointState()
            if not os.path.exists(checkpoint_metadata_filepath):
                return None

            contents = open(checkpoint_metadata_filepath, 'r').read()
            text_format.Merge(contents, checkpoint)
            return checkpoint
        except Exception as e:
            utils.json_format_logger("Exception[{}] occured while reading checkpoint metadata".format(e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
            raise e
Example #19
    def get_latest_checkpoint(self):
        try:
            filename = os.path.abspath(
                os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                s3_client = self._get_client()
                # Check if there's a lock file
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(self.params.lock_file))

                if "Contents" not in response:
                    try:
                        # If no lock is found, try getting the checkpoint
                        s3_client.download_file(
                            Bucket=self.params.bucket,
                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                            Filename=filename)
                    except Exception as e:
                        logger.info(
                            "Error occured while getting latest checkpoint %s. Waiting."
                            % e)
                        time.sleep(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        )
                        continue
                else:
                    time.sleep(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    )
                    continue

                checkpoint = self._get_current_checkpoint(
                    checkpoint_metadata_filepath=filename)
                if checkpoint:
                    checkpoint_number = self._get_checkpoint_number(checkpoint)
                    return checkpoint_number

        except Exception as e:
            utils.json_format_logger(
                "Exception [{}] occured while getting latest checkpoint from S3."
                .format(e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_503))
Example #20
 def _wait_for_ip_upload(self, timeout=600):
     s3_client = self.get_client()
     time_elapsed = 0
     while True:
         response = s3_client.list_objects(Bucket=self.bucket, Prefix=self.done_file_key)
         if "Contents" not in response:
             time.sleep(1)
             time_elapsed += 1
             if time_elapsed % 5 == 0:
                 logger.info("Waiting for SageMaker Redis server IP... Time elapsed: %s seconds" % time_elapsed)
             if time_elapsed >= timeout:
                 #raise RuntimeError("Cannot retrieve IP of redis server running in SageMaker")
                 utils.json_format_logger("Cannot retrieve IP of redis server running in SageMaker. Job failed!",
                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
                 sys.exit(1)
         else:
             return
Example #21
 def upload_finished_file(self):
     try:
         s3_client = self._get_client()
         s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(SyncFiles.FINISHED.value))
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger("Unable to upload finished file to {}, {}"
                                  .format(self.params.bucket, e.response['Error']['Code']),
                                  **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger("Unable to upload finished file to {}, {}"
                                  .format(self.params.bucket, e),
                                  **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                  utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Example #22
 def upload_file(self, s3_key, local_path):
     s3_client = self.get_client()
     try:
         s3_client.upload_file(Filename=local_path,
                               Bucket=self.bucket,
                               Key=s3_key)
         return True
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to upload {} to {}: {}".format(
                 s3_key, self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
     except Exception as e:
         utils.json_format_logger(
             "Unable to upload {} to {}: {}".format(s3_key, self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
     return False
Example #23
 def upload_hyperparameters(self, hyperparams_json):
     try:
         s3_client = self.get_client()
         file_handle = io.BytesIO(hyperparams_json.encode())
         s3_client.upload_fileobj(file_handle, self.bucket,
                                  self.hyperparameters_key)
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Hyperparameters failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Hyperparameters failed to upload to {}, {}".format(
                 self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Example #24
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """
    block until there is a checkpoint in checkpoint_dir
    """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()

        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)

    # one last time
    if has_checkpoint(checkpoint_dir):
        return

    utils.json_format_logger(
        "Checkpoint never found in {}, waited {} seconds.".format(
            checkpoint_dir, timeout),
        **utils.build_system_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_500))
    traceback.print_exc()
    utils.simapp_exit_gracefully()
Example #25
 def get_ip(self):
     s3_client = self.get_client()
     self._wait_for_ip_upload()
     try:
         s3_client.download_file(self.bucket, self.config_key, 'ip.json')
         with open("ip.json") as f:
             ip = json.load(f)["IP"]
         return ip
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to retrieve redis ip from {}: {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to retrieve redis ip from {}: {}".format(
                 self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Example #26
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                if any(steps <= 0 for steps in rollout_steps.values()):
                    utils.json_format_logger("No rollout data retrieved from the rollout worker",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                # Set the batch size to the closest power of 2 such that we have at least
                # two batches; this prevents coach from crashing, as a batch size less
                # than 2 causes the batch list to become a scalar, which raises an exception.
                episode_batch_size = (user_batch_size
                                      if min(rollout_steps.values()) > user_batch_size
                                      else 2**math.floor(math.log(min(rollout_steps.values()), 2)))
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for Nan's in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                if rollout_has_nan:
                    utils.json_format_logger("NaN detected in loss function, aborting training.",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger("Received SIGTERM. Checkpointing before exiting.",
                                         **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            utils.log_and_exit("User modified model: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_400)
        else:
            utils.log_and_exit("An error occured while training: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        utils.log_and_exit("An error occured while training: {}".format(ex),
                           utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                           utils.SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()
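The batch-size adjustment in the loop above is easy to sanity-check in isolation. A small worked example of the same rule (the function name is ours, for illustration):

import math

def episode_batch_size(user_batch_size, min_rollout_steps):
    # Keep the user's batch size when there is enough rollout data; otherwise
    # round down to a power of two so coach never sees a batch of size < 2.
    if min_rollout_steps > user_batch_size:
        return user_batch_size
    return 2 ** math.floor(math.log(min_rollout_steps, 2))

assert episode_batch_size(64, 100) == 64   # enough data: keep the user setting
assert episode_batch_size(64, 50) == 32    # short on data: largest power of 2 <= 50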
Example #27
    def save_to_store(self):
        try:
            s3_client = self._get_client()

            # Delete any existing lock file
            s3_client.delete_object(Bucket=self.params.bucket,
                                    Key=self._get_s3_key(
                                        self.params.lock_file))

            # We take a lock by writing a lock file to the same location in S3
            s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                     Bucket=self.params.bucket,
                                     Key=self._get_s3_key(
                                         self.params.lock_file))

            # Start writing the model checkpoints to S3
            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)

            checkpoint_file = None
            for root, _, files in os.walk(self.params.checkpoint_dir):
                num_files_uploaded = 0
                for filename in files:
                    # Skip the checkpoint file that has the latest checkpoint number
                    if filename == CHECKPOINT_METADATA_FILENAME:
                        checkpoint_file = (root, filename)
                        continue

                    if not filename.startswith(str(checkpoint_number)):
                        continue

                    # Upload all the other files from the checkpoint directory
                    abs_name = os.path.abspath(os.path.join(root, filename))
                    rel_name = os.path.relpath(abs_name,
                                               self.params.checkpoint_dir)
                    s3_client.upload_file(Filename=abs_name,
                                          Bucket=self.params.bucket,
                                          Key=self._get_s3_key(rel_name))
                    num_files_uploaded += 1
            logger.info("Uploaded {} files for checkpoint {}".format(
                num_files_uploaded, checkpoint_number))

            # After all the checkpoint files have been uploaded, we upload the version file.
            abs_name = os.path.abspath(
                os.path.join(checkpoint_file[0], checkpoint_file[1]))
            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(rel_name))

            # Release the lock by deleting the lock file from S3
            s3_client.delete_object(Bucket=self.params.bucket,
                                    Key=self._get_s3_key(
                                        self.params.lock_file))

            # Upload the frozen graph which is used for deployment
            if self.graph_manager:
                utils.write_frozen_graph(self.graph_manager)
                # upload the model_<ID>.pb to S3. NOTE: there's no cleanup as we don't know the best checkpoint
                iteration_id = self.graph_manager.level_managers[0].agents[
                    'agent'].training_iteration
                frozen_graph_fpath = utils.SM_MODEL_OUTPUT_DIR + "/model.pb"
                frozen_graph_s3_name = "model_%s.pb" % iteration_id
                s3_client.upload_file(
                    Filename=frozen_graph_fpath,
                    Bucket=self.params.bucket,
                    Key=self._get_s3_key(frozen_graph_s3_name))
                logger.info("saved intermediate frozen graph: {}".format(
                    self._get_s3_key(frozen_graph_s3_name)))

            # Clean up old checkpoints
            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)
                checkpoint_number_to_delete = checkpoint_number - 4

                # List all the old checkpoint files to be deleted
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(
                        str(checkpoint_number_to_delete) + "_"))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        s3_client.delete_object(Bucket=self.params.bucket,
                                                Key=obj["Key"])
                        num_files += 1
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to upload checkpoint to {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to upload checkpoint to {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
Example #28
    def load_from_store(self, expected_checkpoint_number=-1):
        try:
            filename = os.path.abspath(
                os.path.join(self.params.checkpoint_dir,
                             CHECKPOINT_METADATA_FILENAME))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            # CT: remove all prior checkpoint files to save local storage space
            if len(os.listdir(self.params.checkpoint_dir)) > 0:
                files_to_remove = os.listdir(self.params.checkpoint_dir)
                utils.json_format_logger('Removing %d old checkpoint files' %
                                         len(files_to_remove))
                for f in files_to_remove:
                    try:
                        os.unlink(
                            os.path.abspath(
                                os.path.join(self.params.checkpoint_dir, f)))
                    except Exception as e:
                        # swallow errors
                        pass

            while True:
                s3_client = self._get_client()
                # Check if there's a finished file
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(SyncFiles.FINISHED.value))
                if "Contents" in response:
                    finished_file_path = os.path.abspath(
                        os.path.join(self.params.checkpoint_dir,
                                     SyncFiles.FINISHED.value))
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(
                                                SyncFiles.FINISHED.value),
                                            Filename=finished_file_path)
                    return False

                # Check if there's a lock file
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(self.params.lock_file))

                if "Contents" not in response:
                    try:
                        # If no lock is found, try getting the checkpoint
                        s3_client.download_file(
                            Bucket=self.params.bucket,
                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                            Filename=filename)
                    except Exception as e:
                        utils.json_format_logger(
                            "Checkpoint metadata not yet available, sleeping {} seconds".format(
                                SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND))
                        time.sleep(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue
                else:
                    utils.json_format_logger(
                        "Sleeping {} seconds while lock file is present".
                        format(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        ))
                    time.sleep(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    )
                    continue

                checkpoint = self._get_current_checkpoint()
                if checkpoint:
                    checkpoint_number = self._get_checkpoint_number(checkpoint)

                    # If we get a checkpoint that is older than the expected
                    # checkpoint, wait for the new checkpoint to arrive.
                    if checkpoint_number < expected_checkpoint_number:
                        time.sleep(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        )
                        continue

                    # Found a checkpoint to be downloaded
                    response = s3_client.list_objects_v2(
                        Bucket=self.params.bucket,
                        Prefix=self._get_s3_key(
                            checkpoint.model_checkpoint_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            # Get the local filename of the checkpoint file
                            full_key_prefix = os.path.normpath(
                                self.key_prefix) + "/"
                            filename = os.path.abspath(
                                os.path.join(
                                    self.params.checkpoint_dir,
                                    obj["Key"].replace(full_key_prefix, "")))
                            s3_client.download_file(Bucket=self.params.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to download checkpoint from {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to download checkpoint from {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
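save_to_store and load_from_store coordinate through an empty lock object: the writer uploads the lock key, writes the checkpoint files, then deletes the lock, while readers poll until the lock disappears. The pattern in isolation, as a hedged sketch (S3 offers no atomic test-and-set, so this is best-effort mutual exclusion, which matches the single-writer setup here):

import io
import boto3

def with_s3_lock(bucket, lock_key, critical_section):
    # Upload an empty object as the lock, run the critical section, and
    # always delete the lock afterwards so readers are not blocked forever.
    s3 = boto3.client("s3")
    s3.upload_fileobj(Fileobj=io.BytesIO(b""), Bucket=bucket, Key=lock_key)
    try:
        critical_section()
    finally:
        s3.delete_object(Bucket=bucket, Key=lock_key)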
Example #29
def training_worker(graph_manager, checkpoint_dir, use_pretrained_model,
                    framework, memory_backend_params):
    """
    initialize the graph, optionally restoring a pretrained checkpoint, and run the training loop
    """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = SM_MODEL_OUTPUT_DIR

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__[
            'checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0

    graph_manager.memory_backend = deepracer_memory.DeepRacerTrainerBackEnd(
        memory_backend_params)

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(
                graph_manager.agent_params.algorithm.
                num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            if graph_manager.should_train():
                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for Nan's in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                #! TODO handle NaN's on a per agent level for distributed training
                if rollout_has_nan:
                    utils.json_format_logger(
                        "NaN detected in loss function, aborting training. Job failed!",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_503))
                    sys.exit(1)

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()
                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            if door_man.terminate_now:
                utils.json_format_logger(
                    "Received SIGTERM. Checkpointing before exiting.",
                    **utils.build_system_error_dict(
                        utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        utils.json_format_logger(
            "An error occured while training: {}. Job failed!.".format(e),
            **utils.build_system_error_dict(
                utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_503))
        traceback.print_exc()
        sys.exit(1)
    finally:
        graph_manager.data_store.upload_finished_file()
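Both training_worker variants rely on a DoorMan object whose terminate_now flag is polled once per loop iteration so a SIGTERM can trigger a final checkpoint before exit. A minimal sketch of such a latch (the class body is assumed; only the name and the terminate_now attribute are taken from the examples):

import signal

class DoorMan:
    # Latch SIGTERM instead of dying mid-update; the training loop polls
    # terminate_now and checkpoints before exiting.
    def __init__(self):
        self.terminate_now = False
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        self.terminate_now = True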
Example #30
    def __init__(self):

        # Create the observation space
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(TRAINING_IMAGE_SIZE[1],
                                                   TRAINING_IMAGE_SIZE[0], 3),
                                            dtype=np.uint8)
        # Create the action space
        self.action_space = spaces.Box(low=np.array([-1, 0]),
                                       high=np.array([+1, +1]),
                                       dtype=np.float32)

        if node_type == SIMULATION_WORKER:

            # ROS initialization
            rospy.init_node('rl_coach', anonymous=True)

            # wait for required services
            rospy.wait_for_service(
                '/deepracer_simulation_environment/get_waypoints')
            rospy.wait_for_service(
                '/deepracer_simulation_environment/reset_car')
            rospy.wait_for_service('/gazebo/get_model_state')
            rospy.wait_for_service('/gazebo/get_link_state')
            rospy.wait_for_service('/gazebo/clear_joint_forces')

            self.get_model_state = rospy.ServiceProxy(
                '/gazebo/get_model_state', GetModelState)
            self.get_link_state = rospy.ServiceProxy('/gazebo/get_link_state',
                                                     GetLinkState)
            self.clear_forces_client = rospy.ServiceProxy(
                '/gazebo/clear_joint_forces', JointRequest)
            self.reset_car_client = rospy.ServiceProxy(
                '/deepracer_simulation_environment/reset_car', ResetCarSrv)
            get_waypoints_client = rospy.ServiceProxy(
                '/deepracer_simulation_environment/get_waypoints',
                GetWaypointSrv)

            # Create the publishers for sending speed and steering info to the car
            self.velocity_pub_dict = OrderedDict()
            self.steering_pub_dict = OrderedDict()

            for topic in VELOCITY_TOPICS:
                self.velocity_pub_dict[topic] = rospy.Publisher(topic,
                                                                Float64,
                                                                queue_size=1)

            for topic in STEERING_TOPICS:
                self.steering_pub_dict[topic] = rospy.Publisher(topic,
                                                                Float64,
                                                                queue_size=1)

            # Read in parameters
            self.world_name = rospy.get_param('WORLD_NAME')
            self.job_type = rospy.get_param('JOB_TYPE')
            self.aws_region = rospy.get_param('AWS_REGION')
            self.metrics_s3_bucket = rospy.get_param('METRICS_S3_BUCKET')
            self.metrics_s3_object_key = rospy.get_param(
                'METRICS_S3_OBJECT_KEY')
            self.metrics = []
            self.simulation_job_arn = 'arn:aws:robomaker:' + self.aws_region + ':' + \
                                      rospy.get_param('ROBOMAKER_SIMULATION_JOB_ACCOUNT_ID') + \
                                      ':simulation-job/' + rospy.get_param('AWS_ROBOMAKER_SIMULATION_JOB_ID')

            if self.job_type == TRAINING_JOB:
                from custom_files.customer_reward_function import reward_function
                self.reward_function = reward_function
                self.metric_name = rospy.get_param('METRIC_NAME')
                self.metric_namespace = rospy.get_param('METRIC_NAMESPACE')
                self.training_job_arn = rospy.get_param('TRAINING_JOB_ARN')
                self.target_number_of_episodes = rospy.get_param(
                    'NUMBER_OF_EPISODES')
                self.target_reward_score = rospy.get_param(
                    'TARGET_REWARD_SCORE')
            else:
                from markov.defaults import reward_function
                self.reward_function = reward_function
                self.number_of_trials = 0
                self.target_number_of_trials = rospy.get_param(
                    'NUMBER_OF_TRIALS')

            # Request the waypoints
            waypoints = None
            try:
                resp = get_waypoints_client()
                waypoints = np.array(resp.waypoints).reshape(
                    resp.row, resp.col)
            except Exception as ex:
                utils.json_format_logger(
                    "Unable to retrieve waypoints: {}".format(ex),
                    **utils.build_system_error_dict(
                        utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))

            is_loop = np.all(waypoints[0, :] == waypoints[-1, :])
            if is_loop:
                self.center_line = LinearRing(waypoints[:, 0:2])
                self.inner_border = LinearRing(waypoints[:, 2:4])
                self.outer_border = LinearRing(waypoints[:, 4:6])
                self.road_poly = Polygon(self.outer_border,
                                         [self.inner_border])
            else:
                self.center_line = LineString(waypoints[:, 0:2])
                self.inner_border = LineString(waypoints[:, 2:4])
                self.outer_border = LineString(waypoints[:, 4:6])
                self.road_poly = Polygon(
                    np.vstack(
                        (self.outer_border, np.flipud(self.inner_border))))
            self.center_dists = [
                self.center_line.project(Point(p), normalized=True)
                for p in self.center_line.coords[:-1]
            ] + [1.0]
            self.track_length = self.center_line.length
            # Queue used to maintain image consumption synchronicity
            self.image_queue = queue.Queue(IMG_QUEUE_BUF_SIZE)
            rospy.Subscriber('/camera/zed/rgb/image_rect_color', sensor_image,
                             self.callback_image)

            # Initialize state data
            self.episodes = 0
            self.start_ndist = 0.0
            self.reverse_dir = False
            self.change_start = rospy.get_param(
                'CHANGE_START_POSITION', (self.job_type == TRAINING_JOB))
            self.alternate_dir = rospy.get_param('ALTERNATE_DRIVING_DIRECTION',
                                                 False)
            self.is_simulation_done = False
            self.steering_angle = 0
            self.speed = 0
            self.action_taken = 0
            self.prev_progress = 0
            self.prev_point = Point(0, 0)
            self.prev_point_2 = Point(0, 0)
            self.next_state = None
            self.reward = None
            self.reward_in_episode = 0
            self.done = False
            self.steps = 0
            self.simulation_start_time = 0
            self.allow_servo_step_signals = False
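The waypoint handling above assumes a six-column array: columns 0-1 are the center line, 2-3 the inner border, and 4-5 the outer border. A toy illustration of that layout and of the normalized projection used to build center_dists (the coordinates are made up):

import numpy as np
from shapely.geometry import LineString, Point

waypoints = np.array([[0.0, 0.0, 0.0, -1.0, 0.0, 1.0],
                      [1.0, 0.0, 1.0, -1.0, 1.0, 1.0],
                      [2.0, 0.0, 2.0, -1.0, 2.0, 1.0]])
center_line = LineString(waypoints[:, 0:2])
# Fraction of the track length at which a point projects onto the center line.
print(center_line.project(Point(1.5, 0.2), normalized=True))  # -> 0.75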