def upload_model(self, checkpoint_dir):
    """Upload every file under checkpoint_dir to S3.

    Each file lands under <s3_prefix>/<checkpoint_dir>/<filename>.
    Exits the simapp gracefully on any failure.
    """
    try:
        s3_client = self.get_client()
        uploaded = 0
        for root, _, file_names in os.walk("./" + checkpoint_dir):
            for file_name in file_names:
                local_file = os.path.abspath(os.path.join(root, file_name))
                dest_key = "%s/%s/%s" % (self.s3_prefix, checkpoint_dir, file_name)
                s3_client.upload_file(local_file, self.bucket, dest_key)
                uploaded += 1
    except botocore.exceptions.ClientError as e:
        # AWS-side rejection: reported as a user error (4xx).
        utils.json_format_logger(
            "Model failed to upload to {}, {}".format(
                self.bucket, e.response['Error']['Code']),
            **utils.build_user_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        utils.simapp_exit_gracefully()
    except Exception as e:
        # Everything else: reported as a system error (5xx).
        utils.json_format_logger(
            "Model failed to upload to {}, {}".format(self.bucket, e),
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
 def write_ip_config(self, ip):
     """Publish the given IP to S3 as JSON, then upload a 'done' marker object."""
     try:
         s3_client = self.get_client()
         payload = io.BytesIO(json.dumps({"IP": ip}).encode())
         done_marker = io.BytesIO(b'done')
         s3_client.upload_fileobj(payload, self.bucket, self.config_key)
         s3_client.upload_fileobj(done_marker, self.bucket, self.done_file_key)
     except botocore.exceptions.ClientError as e:
         # AWS-side rejection: user error (4xx).
         utils.json_format_logger(
             "Write ip config failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         # Everything else: system error (5xx).
         utils.json_format_logger(
             "Write ip config failed to upload to {}, {}".format(self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
Ejemplo n.º 3
0
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """
    block until there is a checkpoint in checkpoint_dir

    checkpoint_dir -- local directory polled for a checkpoint file
    data_store -- optional store whose load_from_store() refreshes the local
                  directory from remote storage before each check
    timeout -- number of polling attempts, NOT seconds: each attempt sleeps
               10 seconds, so the real wall-clock wait is up to
               timeout * 10 seconds. The messages below report the raw
               attempt count as "seconds".
    """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()

        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)

    # one last time
    if has_checkpoint(checkpoint_dir):
        return

    # No checkpoint ever appeared: log a system error and fail the job.
    utils.json_format_logger(
        "checkpoint never found in {}, Waited {} seconds. Job failed!".format(
            checkpoint_dir, timeout),
        **utils.build_system_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_503))
    # NOTE(review): there is no active exception at this point, so
    # traceback.print_exc() prints nothing useful here.
    traceback.print_exc()
    raise ValueError(
        ('Waited {timeout} seconds, but checkpoint never found in '
         '{checkpoint_dir}').format(
             timeout=timeout,
             checkpoint_dir=checkpoint_dir,
         ))
 def download_file(self, s3_key, local_path):
     """Fetch one object from S3 into local_path.

     Returns True on success and False when the object does not exist
     (a 404 is an expected outcome the caller decides how to handle).
     Any other error exits the simapp gracefully.
     """
     s3_client = self.get_client()
     try:
         s3_client.download_file(self.bucket, s3_key, local_path)
         return True
     except botocore.exceptions.ClientError as e:
         # Missing object: return False and let the caller decide what to do.
         if e.response['Error']['Code'] == "404":
             return False
         # Any other client error is fatal for the simapp.
         utils.json_format_logger(
             "Unable to download {} from {}: {}".format(
                 s3_key, self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to download {} from {}: {}".format(s3_key, self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def store_ip(self, ip_address):
     """Write ip_address to S3 as a JSON blob, then upload a 'done' marker."""
     try:
         s3_client = self._get_client()
         ip_blob = io.BytesIO(json.dumps({IP_KEY: ip_address}).encode())
         done_blob = io.BytesIO(b'done')
         s3_client.upload_fileobj(ip_blob, self.params.bucket, self.ip_data_key)
         s3_client.upload_fileobj(done_blob, self.params.bucket, self.ip_done_key)
     except botocore.exceptions.ClientError as e:
         # AWS-side rejection: user error (4xx).
         utils.json_format_logger(
             "Unable to store ip to {}, {}".format(
                 self.params.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         # Everything else: system error (5xx).
         utils.json_format_logger(
             "Unable to store ip to {}, {}".format(self.params.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
    def download_preset_if_present(self, local_path):
        """Download the preset file from S3 if one exists.

        Returns True when the preset was downloaded to local_path, and
        False when no preset object is present in the bucket. Exits the
        simapp gracefully on S3 errors.
        """
        try:
            s3_client = self._get_client()
            response = s3_client.list_objects(Bucket=self.params.bucket,
                                              Prefix=self.preset_data_key)

            # If we don't find a preset, return false
            if "Contents" not in response:
                return False

            s3_client.download_file(Bucket=self.params.bucket,
                                    Key=self.preset_data_key,
                                    Filename=local_path)
            # Bug fix: boto3's download_file returns None, so the original
            # returned a falsy None even after a successful download.
            # Signal success explicitly instead.
            return True
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to download presets from {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to download presets from {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
Ejemplo n.º 7
0
 def download_file(self, s3_key, local_path):
     """Download s3_key from the bucket to local_path.

     Returns True on success; False when the object is missing or when any
     error occurs (errors are logged but never raised to the caller).
     """
     s3_client = self.get_client()
     try:
         s3_client.download_file(self.bucket, s3_key, local_path)
         return True
     except botocore.exceptions.ClientError as e:
         if e.response['Error']['Code'] == "404":
             # Missing object: informational only, caller decides next action.
             logger.info(
                 "Exception [{}] occured on download file-{} from s3 bucket-{} key-{}"
                 .format(e.response['Error'], local_path, self.bucket,
                         s3_key))
             return False
         else:
             utils.json_format_logger(
                 "boto client exception error [{}] occured on download file-{} from s3 bucket-{} key-{}"
                 .format(e.response['Error'], local_path, self.bucket,
                         s3_key),
                 **utils.build_user_error_dict(
                     utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                     utils.SIMAPP_EVENT_ERROR_CODE_401))
             return False
     except Exception as e:
         # Typo fix in the log message: "occcured" -> "occured", matching the
         # sibling messages above.
         utils.json_format_logger(
             "Exception [{}] occured on download file-{} from s3 bucket-{} key-{}"
             .format(e, local_path, self.bucket, s3_key),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_401))
         return False
    def get_latest_checkpoint(self):
        """Poll S3 until the checkpoint state file can be fetched, then return
        the latest checkpoint number via _get_current_checkpoint_number.

        Blocks indefinitely while the trainer's lock file is present or the
        state file is not yet downloadable, sleeping between attempts.
        """
        try:
            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                s3_client = self._get_client()
                state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

                # wait until lock is removed
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
                if "Contents" not in response:
                    try:
                        # fetch checkpoint state file from S3
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=self._get_s3_key(state_file.filename),
                                                Filename=filename)
                    except Exception as e:
                        # State file not available yet -- back off and retry.
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue
                else:
                    # Trainer still holds the lock -- back off and retry.
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                return self._get_current_checkpoint_number(checkpoint_metadata_filepath=filename)

        except Exception as e:
            # NOTE(review): after logging, this falls through and implicitly
            # returns None -- callers must handle a None checkpoint number.
            utils.json_format_logger("Exception [{}] occured while getting latest checkpoint from S3.".format(e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
Ejemplo n.º 9
0
def log_info(message):
    '''Log *message* as a SIMAPP memory-backend system error (code 500).

    message - text to send to the log
    '''
    error_dict = build_system_error_dict(SIMAPP_MEMORY_BACKEND_EXCEPTION,
                                         SIMAPP_EVENT_ERROR_CODE_500)
    json_format_logger(message, **error_dict)
Ejemplo n.º 10
0
 def callback_image(self, data):
     """Queue the latest camera frame; silently drop it when the queue is full."""
     try:
         self.image_queue.put_nowait(data)
     except queue.Full:
         # Dropping frames is acceptable -- consumers only need recent images.
         pass
     except Exception as ex:
         utils.json_format_logger(
             "Error retrieving frame from gazebo: {}".format(ex),
             **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
Ejemplo n.º 11
0
    def download_model(self, checkpoint_dir):
        """Download the latest model checkpoint files from S3 into checkpoint_dir.

        Polls every 2 seconds while a trainer lock file exists or the
        checkpoint metadata is not yet downloadable. Returns True once all
        files of the latest checkpoint are downloaded, False on error.
        """
        s3_client = self.get_client()
        filename = "None"
        try:
            filename = os.path.abspath(
                os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(
                                                         self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        s3_client.download_file(
                            Bucket=self.bucket,
                            Key=self._get_s3_key("checkpoint"),
                            Filename=filename)
                    except Exception as e:
                        time.sleep(2)
                        continue
                else:
                    time.sleep(2)
                    continue

                ckpt = CheckpointState()
                if os.path.exists(filename):
                    # Fix: use a context manager so the metadata file handle
                    # is closed promptly (the original leaked it).
                    with open(filename, 'r') as metadata_file:
                        contents = metadata_file.read()
                    text_format.Merge(contents, ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(
                        Bucket=self.bucket, Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(
                                os.path.join(
                                    checkpoint_dir, obj["Key"].replace(
                                        self.model_checkpoints_prefix, "")))
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except Exception as e:
            utils.json_format_logger(
                "{} while downloading the model {} from S3".format(
                    e, filename),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            return False
Ejemplo n.º 12
0
    def download_model(self, checkpoint_dir):
        """Download the pretrained model checkpoint files from S3 into
        checkpoint_dir, polling every 2 seconds while a trainer lock file is
        present. Returns True once all files are fetched, False on error.
        """
        s3_client = self.get_client()
        logger.info("Downloading pretrained model from %s/%s %s" % (self.bucket, self.model_checkpoints_prefix, checkpoint_dir))
        filename = "None"
        try:
            filename = os.path.abspath(os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                logger.info("Model folder %s does not exist, creating" % filename)
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        key = self._get_s3_key("checkpoint")
                        logger.info("Downloading %s" % key)
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=key,
                                                Filename=filename)
                    except Exception as e:
                        logger.info("Something went wrong, will retry in 2 seconds %s" % e)
                        time.sleep(2)
                        continue
                else:
                    logger.info("Found a lock file %s , waiting" % self._get_s3_key(self.lock_file))
                    time.sleep(2)
                    continue

                ckpt = CheckpointState()
                if os.path.exists(filename):
                    # Fix: context manager closes the metadata file handle
                    # (the original left it open).
                    with open(filename, 'r') as metadata_file:
                        contents = metadata_file.read()
                    text_format.Merge(contents, ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                         Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(os.path.join(checkpoint_dir,
                                                                    obj["Key"].replace(self.model_checkpoints_prefix,
                                                                                       "")))

                            logger.info("Downloading model file %s" % filename)
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except Exception as e:
            utils.json_format_logger("{} while downloading the model {} from S3".format(e, filename),
                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
            return False
Ejemplo n.º 13
0
 def upload_file(self, s3_key, local_path):
     """Upload local_path to the bucket under s3_key; True on success,
     False after logging the failure."""
     s3_client = self.get_client()
     try:
         s3_client.upload_file(Filename=local_path,
                               Bucket=self.bucket,
                               Key=s3_key)
     except Exception as e:
         utils.json_format_logger(
             "{} on upload file-{} to s3 bucket-{} key-{}".format(
                 e, local_path, self.bucket, s3_key),
             **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
         return False
     return True
 def _get_current_checkpoint_number(self, checkpoint_metadata_filepath=None):
     """Return the checkpoint number parsed from the metadata file.

     The file's content is expected to start with '<number>_...'.
     Returns None when the file does not exist; logs and re-raises on any
     read/parse error.
     """
     try:
         if not os.path.exists(checkpoint_metadata_filepath):
             return None
         with open(checkpoint_metadata_filepath, 'r') as fp:
             data = fp.read()
             return int(data.split('_')[0])
     except Exception as e:
         utils.json_format_logger("Exception[{}] occured while reading checkpoint metadata".format(e),
                                  **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
         # Fix: bare raise preserves the original traceback ("raise e" resets
         # it to this line).
         raise
Ejemplo n.º 15
0
 def get_ip(self):
     """Block until the Redis IP is published to S3, then fetch and return it."""
     s3_client = self.get_client()
     self._wait_for_ip_upload()
     try:
         s3_client.download_file(self.bucket, self.config_key, 'ip.json')
         with open("ip.json") as ip_file:
             return json.load(ip_file)["IP"]
     except Exception as e:
         utils.json_format_logger("Exception [{}] occured, Cannot fetch IP of redis server running in SageMaker. Job failed!".format(e),
                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
         sys.exit(1)
Ejemplo n.º 16
0
def download_customer_reward_function(s3_client, reward_file_s3_key):
    """Fetch the customer reward function from S3; kill the job on failure."""
    local_destination = os.path.join(CUSTOM_FILES_PATH,
                                     "customer_reward_function.py")
    downloaded = s3_client.download_file(s3_key=reward_file_s3_key,
                                         local_path=local_destination)
    if downloaded:
        return
    utils.json_format_logger(
        "Could not download the customer reward function file. Job failed!",
        **utils.build_system_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_503))
    traceback.print_exc()
    sys.exit(1)
Ejemplo n.º 17
0
 def callback_image(self, data):
     """Queue a camera frame, dropping it (with an optional warning) when full."""
     try:
         self.image_queue.put_nowait(data)
     except queue.Full:
         # Only warn mid-episode (servo signals enabled), not during training.
         if self.allow_servo_step_signals:
             logger.info("Warning: dropping image due to queue full")
     except Exception as ex:
         utils.json_format_logger(
             "Error retrieving frame from gazebo: {}".format(ex),
             **utils.build_system_error_dict(
                 utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
Ejemplo n.º 18
0
    def racecar_reset(self):
        """Reset the car to the episode start position.

        Clears residual joint forces, teleports the car to the waypoint at
        self.start_ndist, flushes one stale frame from the image queue, and
        primes the next observation state. Errors are logged, not raised.
        """
        try:
            for joint in EFFORT_JOINTS:
                self.clear_forces_client(joint)
            prev_index, next_index = self.find_prev_next_waypoints(self.start_ndist)
            self.reset_car_client(self.start_ndist, next_index)
            # First clear the queue so that we set the state to the start image
            _ = self.image_queue.get(block=True, timeout=None)
            self.set_next_state()

        except Exception as ex:
            utils.json_format_logger("Unable to reset the car: {}".format(ex),
                         **utils.build_system_error_dict(utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
def download_customer_reward_function(s3_client, reward_file_s3_key):
    """Download the reward function from S3, exiting gracefully on failure."""
    local_destination = os.path.join(CUSTOM_FILES_PATH,
                                     "customer_reward_function.py")
    if s3_client.download_file(s3_key=reward_file_s3_key,
                               local_path=local_destination):
        return
    utils.json_format_logger(
        "Unable to download the reward function code.",
        **utils.build_user_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_400))
    traceback.print_exc()
    utils.simapp_exit_gracefully()
 def write_simtrace_data(self, jsondata):
     """Append one SIM_TRACE record (dict) as a CSV row to the upload buffer.

     No-op while the upload state is unknown; malformed records are fatal.
     """
     if self.data_state == SIMTRACE_DATA_UPLOAD_UNKNOWN_STATE:
         return
     try:
         row = [jsondata[key] for key in SIMTRACE_CSV_DATA_HEADER]
         self.csvwriter.writerow(row)
         self.total_upload_size += sys.getsizeof(row)
         logger.debug("csvdata={} size data={} csv={}".format(row, sys.getsizeof(row), sys.getsizeof(self.simtrace_csv_data.getvalue())))
     except Exception as ex:
         utils.json_format_logger("Invalid SIM_TRACE data format , Exception={}. Job failed!".format(ex),
                **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
         traceback.print_exc()
         utils.simapp_exit_gracefully()
Ejemplo n.º 21
0
    def _get_current_checkpoint(self):
        """Read the local checkpoint metadata file and return it as a
        CheckpointState proto, or None when the file does not exist.
        Logs and re-raises on any read/parse error.
        """
        try:
            checkpoint_metadata_filepath = os.path.abspath(
                os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
            if not os.path.exists(checkpoint_metadata_filepath):
                return None

            checkpoint = CheckpointState()
            # Fix: context manager closes the file (the original leaked the
            # handle); also "== False" replaced with idiomatic "not".
            with open(checkpoint_metadata_filepath, 'r') as metadata_file:
                contents = metadata_file.read()
            text_format.Merge(contents, checkpoint)
            return checkpoint
        except Exception as e:
            utils.json_format_logger("Exception[{}] occured while reading checkpoint metadata".format(e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_500))
            # Bare raise keeps the original traceback intact.
            raise
Ejemplo n.º 22
0
    def get_latest_checkpoint(self):
        """Poll S3 until the checkpoint metadata is downloadable, then return
        the latest checkpoint number.

        Blocks indefinitely while the trainer's lock file is present or the
        metadata object is not yet fetchable, sleeping between attempts.
        """
        try:
            filename = os.path.abspath(
                os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                s3_client = self._get_client()
                # Check if there's a lock file
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(self.params.lock_file))

                if "Contents" not in response:
                    try:
                        # If no lock is found, try getting the checkpoint
                        s3_client.download_file(
                            Bucket=self.params.bucket,
                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                            Filename=filename)
                    except Exception as e:
                        logger.info(
                            "Error occured while getting latest checkpoint %s. Waiting."
                            % e)
                        time.sleep(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        )
                        continue
                else:
                    # Trainer still holds the lock -- back off and retry.
                    time.sleep(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    )
                    continue

                # Metadata downloaded; loop again if it could not be parsed
                # into a checkpoint yet.
                checkpoint = self._get_current_checkpoint(
                    checkpoint_metadata_filepath=filename)
                if checkpoint:
                    checkpoint_number = self._get_checkpoint_number(checkpoint)
                    return checkpoint_number

        except Exception as e:
            # NOTE(review): after logging, this implicitly returns None --
            # callers must handle a None checkpoint number.
            utils.json_format_logger(
                "Exception [{}] occured while getting latest checkpoint from S3."
                .format(e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_503))
Ejemplo n.º 23
0
 def _wait_for_ip_upload(self, timeout=600):
     """Poll S3 once per second until the done-file marker appears.

     Logs progress every 5 seconds; after *timeout* seconds the job is
     failed with sys.exit(1).
     """
     s3_client = self.get_client()
     time_elapsed = 0
     while True:
         response = s3_client.list_objects(Bucket=self.bucket, Prefix=self.done_file_key)
         if "Contents" in response:
             return
         time.sleep(1)
         time_elapsed += 1
         if time_elapsed % 5 == 0:
             logger.info("Waiting for SageMaker Redis server IP... Time elapsed: %s seconds" % time_elapsed)
         if time_elapsed >= timeout:
             utils.json_format_logger("Cannot retrieve IP of redis server running in SageMaker. Job failed!",
                 **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
             sys.exit(1)
 def upload_finished_file(self):
     """Upload an empty 'finished' marker object to S3; exit gracefully on failure."""
     try:
         s3_client = self._get_client()
         s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(SyncFiles.FINISHED.value))
     except botocore.exceptions.ClientError as e:
         # AWS-side rejection: user error (4xx).
         utils.json_format_logger(
             "Unable to upload finished file to {}, {}".format(
                 self.params.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         # Everything else: system error (5xx).
         utils.json_format_logger(
             "Unable to upload finished file to {}, {}".format(
                 self.params.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def upload_file(self, s3_key, local_path):
     """Upload local_path under s3_key; True on success, False after logging
     any error."""
     s3_client = self.get_client()
     try:
         s3_client.upload_file(Filename=local_path,
                               Bucket=self.bucket,
                               Key=s3_key)
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to upload {} to {}: {}".format(
                 s3_key, self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         return False
     except Exception as e:
         utils.json_format_logger(
             "Unable to upload {} to {}: {}".format(s3_key, self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         return False
     return True
 def upload_hyperparameters(self, hyperparams_json):
     """Upload the hyperparameter JSON string to S3; exit gracefully on failure."""
     try:
         s3_client = self.get_client()
         blob = io.BytesIO(hyperparams_json.encode())
         s3_client.upload_fileobj(blob, self.bucket, self.hyperparameters_key)
     except botocore.exceptions.ClientError as e:
         # AWS-side rejection: user error (4xx).
         utils.json_format_logger(
             "Hyperparameters failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         # Everything else: system error (5xx).
         utils.json_format_logger(
             "Hyperparameters failed to upload to {}, {}".format(self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def get_ip(self):
     """Wait for the Redis IP to be published to S3, then fetch and return it."""
     s3_client = self.get_client()
     self._wait_for_ip_upload()
     try:
         s3_client.download_file(self.bucket, self.config_key, 'ip.json')
         with open("ip.json") as ip_file:
             return json.load(ip_file)["IP"]
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to retrieve redis ip from {}: {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to retrieve redis ip from {}: {}".format(self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """
    block until there is a checkpoint in checkpoint_dir

    checkpoint_dir -- local directory polled for a checkpoint file
    data_store -- optional store whose load_from_store() refreshes the local
                  directory from remote storage before each check
    timeout -- number of polling attempts, NOT seconds: each attempt sleeps
               10 seconds, so the real wait is up to timeout * 10 seconds;
               the message below reports the raw attempt count as "seconds".

    Unlike the sibling variant that raises ValueError on failure, this
    version logs a system error and exits the simapp gracefully.
    """
    for i in range(timeout):
        if data_store:
            data_store.load_from_store()

        if has_checkpoint(checkpoint_dir):
            return
        time.sleep(10)

    # one last time
    if has_checkpoint(checkpoint_dir):
        return

    utils.json_format_logger(
        "Checkpoint never found in {}, waited {} seconds.".format(
            checkpoint_dir, timeout),
        **utils.build_system_error_dict(
            utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
            utils.SIMAPP_EVENT_ERROR_CODE_500))
    # NOTE(review): there is no active exception here, so print_exc prints
    # nothing useful.
    traceback.print_exc()
    utils.simapp_exit_gracefully()
    def _wait_for_ip_upload(self, timeout=600):
        """Poll S3 once per second for the done-file marker.

        Returns as soon as the marker object exists. After *timeout* seconds,
        or on any S3 error, logs the failure and exits the simapp gracefully.
        """
        s3_client = self.get_client()
        time_elapsed = 0

        while time_elapsed < timeout:
            try:
                response = s3_client.list_objects(Bucket=self.bucket,
                                                  Prefix=self.done_file_key)
            except botocore.exceptions.ClientError as e:
                utils.json_format_logger(
                    "Unable to access {}: {}".format(
                        self.bucket, e.response['Error']['Code']),
                    **utils.build_user_error_dict(
                        utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_400))
                utils.simapp_exit_gracefully()
            except Exception as e:
                utils.json_format_logger(
                    "Unable to access {}: {}".format(self.bucket, e),
                    **utils.build_system_error_dict(
                        utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                utils.simapp_exit_gracefully()

            # NOTE(review): if simapp_exit_gracefully() ever returns instead of
            # terminating the process, 'response' may be unbound here -- confirm
            # that it always exits.
            if "Contents" in response:
                return
            time.sleep(1)
            time_elapsed += 1
            if time_elapsed % 5 == 0:
                logger.info(
                    "Waiting for SageMaker Redis server IP... Time elapsed: %s seconds"
                    % time_elapsed)

        utils.json_format_logger(
            "Timed out while attempting to retrieve the Redis IP",
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
Ejemplo n.º 30
0
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout):
    """Run the distributed training loop for the SageMaker training worker.

    Repeatedly fetches rollout experience from the rollout worker through the
    memory backend, trains the policy, and checkpoints the model until
    ``graph_manager.improve_steps.num_steps`` training steps have completed
    or a SIGTERM is received.

    Args:
        graph_manager: coach graph manager driving training (project type).
        task_parameters: coach task parameters used to build the graph.
        user_batch_size: batch size requested by the user; reduced to a
            power of two when the rollout contains too few steps for it.
        user_episode_per_rollout: episodes per rollout requested by the user.
    """
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            # Pull one rollout's worth of experience from the rollout worker.
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            # Temporarily size the consecutive-playing window to the number
            # of episodes actually received in this rollout.
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                # BUG FIX: the original guard was `any(rollout_steps.values()) <= 0`,
                # which compares the *boolean* result of any() to 0 and therefore
                # only fired when every worker reported zero steps.  A single
                # zero-step worker would slip past and crash math.log(0) below.
                # Exit if any worker produced no rollout data.
                if any(worker_steps <= 0 for worker_steps in rollout_steps.values()):
                    utils.json_format_logger("No rollout data retrieved from the rollout worker",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                # Set the batch size to the closest power of 2 such that we have at least two batches, this prevents coach from crashing
                # as  batch size less than 2 causes the batch list to become a scalar which causes an exception
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for NaNs in every agent's loss; abort training if any.
                rollout_has_nan = any(
                    np.isnan(agent.loss.get_mean())
                    for level in graph_manager.level_managers
                    for agent in level.agents.values())
                if rollout_has_nan:
                    utils.json_format_logger("NaN detected in loss function, aborting training.",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            # Restore the user-requested rollout sizing for the next iteration.
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger("Received SIGTERM. Checkpointing before exiting.",
                                         **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except ValueError as err:
        # Bad-checkpoint errors are treated as user errors (400);
        # any other ValueError is a system error (500).
        if utils.is_error_bad_ckpnt(err):
            utils.log_and_exit("User modified model: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_400)
        else:
            utils.log_and_exit("An error occured while training: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        utils.log_and_exit("An error occured while training: {}".format(ex),
                           utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                           utils.SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        # Always signal downstream consumers that training has finished,
        # even when exiting through an error path.
        graph_manager.data_store.upload_finished_file()