def sample(self, data):
        """Pickle the given data locally and upload the pickle to S3.

        The call returns without doing anything when either gate is closed:

        - the number of persisted samples has reached ``max_sample_count``;
        - fewer than ``sampling_frequency`` calls have been made since the
          last persisted sample.

        Otherwise the data is written to ``sample_<index>.pkl`` (relative to
        the current working directory) and uploaded under
        ``<s3_prefix>/samples/``. The local file is left in place afterwards.

        Args:
            data (object): The sample data to pickle and upload to S3

        Raises:
            GenericTrainerException: if the data cannot be pickled to disk or
                if the upload to S3 fails.
        """
        # Local import so this method does not depend on a file-level import.
        import posixpath

        if self._cur_sample_count >= self.max_sample_count:
            return
        self._cur_frequency += 1
        if self._cur_frequency < self.sampling_frequency:
            return

        pickle_filename = 'sample_{}.pkl'.format(self._cur_sample_count)
        try:
            with open(pickle_filename, 'wb') as out_f:
                pickle.dump(data, out_f, protocol=2)
        except Exception as ex:
            raise GenericTrainerException(
                'Failed to dump the sample data: {}'.format(ex))

        try:
            # S3 object keys always use forward slashes. os.path.normpath is
            # platform-dependent (it emits backslashes on Windows), so the
            # key is normalised with posixpath instead.
            self._s3_client.upload_file(
                bucket=self._bucket,
                s3_key=posixpath.normpath(
                    "%s/samples/%s" % (self.s3_prefix, pickle_filename)),
                local_path=pickle_filename,
                s3_kms_extra_args=dict())
        except Exception as ex:
            raise GenericTrainerException(
                'Failed to upload the sample pickle file to S3: {}'.format(ex))
        # The counters only advance after a successful upload, so a failed
        # attempt is retried on the next call with the same sample index.
        self._cur_frequency = 0
        self._cur_sample_count += 1
    def sample(self, data):
        """Pickle the given data locally and upload the pickle to S3.

        The call returns without doing anything when either gate is closed:

        - the number of persisted samples has reached ``max_sample_count``;
        - fewer than ``sampling_frequency`` calls have been made since the
          last persisted sample.

        Otherwise the data is written to ``sample_<index>.pkl`` (relative to
        the current working directory) and uploaded under
        ``<s3_prefix>/samples/``. The local file is left in place afterwards.

        Args:
            data (object): The sample data to pickle and upload to S3

        Raises:
            GenericTrainerException: if the data cannot be pickled to disk or
                if the upload to S3 fails.
        """
        # Local import so this method does not depend on a file-level import.
        import posixpath

        if self._cur_sample_count >= self.max_sample_count:
            return
        self._cur_frequency += 1
        if self._cur_frequency < self.sampling_frequency:
            return

        pickle_filename = 'sample_{}.pkl'.format(self._cur_sample_count)
        try:
            with open(pickle_filename, 'wb') as out_f:
                pickle.dump(data, out_f, protocol=2)
        except Exception as ex:
            raise GenericTrainerException(
                'Failed to dump the sample data: {}'.format(ex))

        try:
            # S3 object keys always use forward slashes. os.path.normpath is
            # platform-dependent (it emits backslashes on Windows), so the
            # key is normalised with posixpath instead.
            self.s3_client.upload_file(
                posixpath.normpath(
                    "%s/samples/%s" % (self.s3_prefix, pickle_filename)),
                pickle_filename)
        except Exception as ex:
            raise GenericTrainerException(
                'Failed to upload the sample pickle file to S3: {}'.format(ex))
        # The counters only advance after a successful upload, so a failed
        # attempt is retried on the next call with the same sample index.
        self._cur_frequency = 0
        self._cur_sample_count += 1
 def get_observation_space(self):
     """Return the observation space for ``Input.LEFT_CAMERA``.

     Delegates to the module-level ``get_observation_space`` helper.

     Raises:
         GenericTrainerException: wraps any exception other than
             ``GenericError``, carrying the original message.
     """
     try:
         space = get_observation_space(Input.LEFT_CAMERA.value)
     except GenericError as known_err:
         # GenericError handles itself: it is logged (and, judging by the
         # method name, presumably terminates the process).
         known_err.log_except_and_exit(SIMAPP_TRAINING_WORKER_EXCEPTION)
     except Exception as unexpected:
         raise GenericTrainerException('{}'.format(unexpected))
     else:
         return space
 def get_input_embedders(self, network_type):
     """Return the embedders produced by ``get_observation_embedder``.

     Args:
         network_type: accepted for interface compatibility; it is not
             forwarded to ``get_observation_embedder``.

     Raises:
         GenericTrainerException: wraps any exception other than
             ``GenericError``, carrying the original message.
     """
     try:
         embedders = get_observation_embedder()
     except GenericError as known_err:
         # GenericError handles itself: it is logged (and, judging by the
         # method name, presumably terminates the process).
         known_err.log_except_and_exit(SIMAPP_TRAINING_WORKER_EXCEPTION)
     except Exception as unexpected:
         raise GenericTrainerException('{}'.format(unexpected))
     else:
         return embedders
 def get_input_embedders(self, network_type):
     """Return the lidar embedders for ``Input.SECTOR_LIDAR``.

     Args:
         network_type: forwarded unchanged to ``get_lidar_embedders``.

     Raises:
         GenericTrainerException: wraps any exception other than
             ``GenericError``, carrying the original message.
     """
     try:
         embedders = get_lidar_embedders(network_type,
                                         Input.SECTOR_LIDAR.value)
     except GenericError as known_err:
         # GenericError handles itself: it is logged (and, judging by the
         # method name, presumably terminates the process).
         known_err.log_except_and_exit(SIMAPP_TRAINING_WORKER_EXCEPTION)
     except Exception as unexpected:
         raise GenericTrainerException('{}'.format(unexpected))
     else:
         return embedders
    def __init__(
        self,
        bucket,
        s3_prefix,
        region_name,
        max_sample_count=None,
        sampling_frequency=None,
        max_retry_attempts=5,
        backoff_time_sec=1.0,
    ):
        """Set up a collector that gathers samples and persists them to S3.

        Args:
            bucket (str): S3 bucket string
            s3_prefix (str): S3 prefix string
            region_name (str): S3 region name, passed to ``S3Client``
            max_sample_count (int): max sample count; ``None`` (or any other
                falsy value) is stored as 0
            sampling_frequency (int): sampling frequency; ``None`` (or any
                other falsy value) is stored as 1
            max_retry_attempts (int): maximum number of retry attempts for
                S3 download/upload, passed to ``S3Client``
            backoff_time_sec (float): backoff seconds between each retry,
                passed to ``S3Client``

        Raises:
            GenericTrainerException: if the resulting sampling frequency is
                smaller than 1.
        """
        # Public configuration; falsy inputs fall back to the defaults.
        self.max_sample_count = max_sample_count if max_sample_count else 0
        self.sampling_frequency = sampling_frequency if sampling_frequency else 1
        if self.sampling_frequency < 1:
            raise GenericTrainerException(
                "sampling_frequency must be larger or equal to 1. (Given: {})".format(
                    self.sampling_frequency))
        self.s3_prefix = s3_prefix

        # Internal bookkeeping: samples persisted so far and calls seen since
        # the last persisted sample.
        self._cur_sample_count = 0
        self._cur_frequency = 0
        self._bucket = bucket
        client = S3Client(region_name, max_retry_attempts, backoff_time_sec)
        self._s3_client = client
Ejemplo n.º 7
0
 def get_input_embedders(self, network_type):
     """Return the embedders produced by ``get_front_camera_embedders``.

     Args:
         network_type: forwarded unchanged to ``get_front_camera_embedders``.

     Raises:
         GenericTrainerException: wraps any exception other than
             ``GenericError``, carrying the original message.
     """
     try:
         embedders = get_front_camera_embedders(network_type)
     except GenericError as known_err:
         # GenericError handles itself: it is logged (and, judging by the
         # method name, presumably terminates the process).
         known_err.log_except_and_exit(SIMAPP_TRAINING_WORKER_EXCEPTION)
     except Exception as unexpected:
         raise GenericTrainerException("{}".format(unexpected))
     else:
         return embedders
 def get_observation_space(self):
     """Return the observation space for ``Input.DISCRETIZED_SECTOR_LIDAR``.

     Delegates to the module-level ``get_observation_space`` helper, passing
     this object's ``model_metadata`` as the second argument.

     Raises:
         GenericTrainerException: wraps any exception other than
             ``GenericError``, carrying the original message.
     """
     try:
         sensor_key = Input.DISCRETIZED_SECTOR_LIDAR.value
         space = get_observation_space(sensor_key, self.model_metadata)
     except GenericError as known_err:
         # GenericError handles itself: it is logged (and, judging by the
         # method name, presumably terminates the process).
         known_err.log_except_and_exit(SIMAPP_TRAINING_WORKER_EXCEPTION)
     except Exception as unexpected:
         raise GenericTrainerException('{}'.format(unexpected))
     else:
         return space
    def __init__(self,
                 s3_client,
                 s3_prefix,
                 max_sample_count=None,
                 sampling_frequency=None):
        """Set up a collector that gathers samples and persists them to S3.

        Args:
            s3_client: client object stored as ``self.s3_client``
            s3_prefix (str): S3 prefix string
            max_sample_count (int): max sample count; ``None`` (or any other
                falsy value) is stored as 0
            sampling_frequency (int): sampling frequency; ``None`` (or any
                other falsy value) is stored as 1

        Raises:
            GenericTrainerException: if the resulting sampling frequency is
                smaller than 1.
        """
        # Public configuration; falsy inputs fall back to the defaults.
        self.max_sample_count = max_sample_count if max_sample_count else 0
        self.sampling_frequency = sampling_frequency if sampling_frequency else 1
        if self.sampling_frequency < 1:
            raise GenericTrainerException(
                "sampling_frequency must be larger or equal to 1. (Given: {})".format(
                    self.sampling_frequency))
        self.s3_client = s3_client
        self.s3_prefix = s3_prefix

        # Internal bookkeeping: samples persisted so far and calls seen since
        # the last persisted sample.
        self._cur_sample_count = 0
        self._cur_frequency = 0
 def create_sensor(racecar_name, sensor_type, config_dict):
     '''Factory method for creating sensors
         racecar_name - Not used by this factory.
         sensor_type - Value identifying the desired sensor; it is compared
                       against the values of the Input enum members.
         config_dict - Not used by this factory.

         Returns a new, argument-less instance of the matching sensor class.
         Raises GenericTrainerException when sensor_type matches none of the
         known Input values.
     '''
     # Ordered (input value, sensor class) pairs; the first match wins.
     sensor_table = (
         (Input.CAMERA.value, Camera),
         (Input.LEFT_CAMERA.value, LeftCamera),
         (Input.STEREO.value, DualCamera),
         (Input.LIDAR.value, Lidar),
         (Input.SECTOR_LIDAR.value, SectorLidar),
         (Input.OBSERVATION.value, Observation),
     )
     for input_value, sensor_cls in sensor_table:
         if sensor_type == input_value:
             return sensor_cls()
     raise GenericTrainerException("Unknown sensor")
Ejemplo n.º 11
0
def test_deepracer_exceptions():
    """Check that each user-defined DeepRacer exception can be raised and caught.

    Every exception type listed below is raised with its own class name as
    the message. The test then verifies that ``pytest.raises`` catches that
    exact type and that the message passed to the constructor is visible in
    the string form of the exception (i.e. the ``Exception`` superclass
    carries the message through).

    Exception types exercised:
        RewardFunctionError
        GenericTrainerException
        GenericTrainerError
        GenericRolloutException
        GenericRolloutError
        GenericValidatorException
        GenericValidatorError
        GenericException
        GenericError
    """
    cases = (
        (RewardFunctionError, "RewardFunctionError"),
        (GenericTrainerException, "GenericTrainerException"),
        (GenericTrainerError, "GenericTrainerError"),
        (GenericRolloutException, "GenericRolloutException"),
        (GenericRolloutError, "GenericRolloutError"),
        (GenericValidatorException, "GenericValidatorException"),
        (GenericValidatorError, "GenericValidatorError"),
        (GenericException, "GenericException"),
        (GenericError, "GenericError"),
    )
    for exc_type, message in cases:
        # The message must appear somewhere in the raised exception's text.
        with pytest.raises(exc_type, match=r".*{}.*".format(message)):
            raise exc_type(message)