Example #1
0
def CreateDataResponse(request):
  """Builds the DataResponse and extra HTTP headers answering a DataRequest.

  The waveform metadata and chunk are always filled in. Prediction metadata
  and chunk are added only when the request yields prediction outputs.

  Args:
    request: A DataRequest instance received from client.
  Returns:
    A (DataResponse, dict) pair: the populated response proto and a dict of
    extra headers that always holds a 'Cache-Control' entry.
  """
  source = GetWaveformDataSource(request)
  predictions = GetPredictionsOutputs(request)

  response = data_pb2.DataResponse()
  metadata = waveform_data_service.GetMetadata(source, _MAX_SAMPLES_CLIENT)
  response.waveform_metadata.CopyFrom(metadata)
  chunk = waveform_data_service.GetChunk(source, request, _MAX_SAMPLES_CLIENT)
  response.waveform_chunk.CopyFrom(chunk)

  if predictions:
    prediction_service = prediction_data_service.PredictionDataService(
        predictions, source, _MAX_SAMPLES_CLIENT)
    response.prediction_metadata.CopyFrom(prediction_service.GetMetadata())
    response.prediction_chunk.CopyFrom(prediction_service.GetChunk(request))

  # With only an SSTable file pattern (no key), the cache serves the TF Example
  # stored under whichever key is iterated first. Key order is not guaranteed,
  # so such a response is not idempotent and must not be cached.
  unkeyed_sstable = request.tf_ex_sstable_path and not request.sstable_key
  cache_control = 'no-cache' if unkeyed_sstable else 'public'

  return response, {'Cache-Control': cache_control}
  def testGetChunkReturns_IndexRaisesValueError(self, mock_create):
    """GetChunk raises ValueError for a request with chunk_start=21.

    Builds a 10-second chunk request starting at 21 on channel index 0. It
    expects GetChunk to raise ValueError without ever invoking the patched
    `mock_create`.
    """
    # NOTE(review): presumably chunk_start=21 lies outside the data held by
    # self.waveform_data_source; confirm against the fixture set up for this
    # test case.
    request = data_pb2.DataRequest()
    request.chunk_duration_secs = 10
    request.chunk_start = 21
    channel_data_id = request.channel_data_ids.add()
    channel_data_id.single_channel.index = 0
    request.low_cut = 1.0
    request.high_cut = 70.0
    request.notch = 60.0

    # Keep only the call under test inside assertRaises. Otherwise a ValueError
    # raised while building the request would make this test pass spuriously.
    with self.assertRaises(ValueError):
      waveform_data_service.GetChunk(self.waveform_data_source, request, 10)
    mock_create.assert_not_called()
Example #3
0
def CreateDataResponse(request):
    """Builds a DataResponse for a DataRequest that points at a TF Example file.

    Only `tf_ex_file_path` sources are loadable. Predictions are attached when
    `prediction_file_path` is set and the loaded outputs are truthy.

    Args:
      request: A DataRequest instance received from client.
    Returns:
      A DataResponse instance.
    Raises:
      NotImplementedError: The request names an SSTable or EDF source, neither
        of which can be loaded yet.
      IOError: The request names none of the supported data paths.
    """
    response = data_pb2.DataResponse()

    # Unsupported sources are rejected first, in the same priority order:
    # SSTable, then EDF, then a missing TF Example file path.
    if request.tf_ex_sstable_path:
        raise NotImplementedError('Loading SSTables')
    if request.edf_path:
        raise NotImplementedError('Loading EDF')
    if not request.tf_ex_file_path:
        raise IOError('No path provided')

    tf_example = FetchTfExFromFile(request.tf_ex_file_path)
    source = TfExDataSourceConstructor(tf_example, '')

    predictions = None
    if request.prediction_file_path:
        predictions = FetchPredictionsFromFile(request.prediction_file_path)

    response.waveform_metadata.CopyFrom(
        waveform_data_service.GetMetadata(source, _MAX_SAMPLES))
    response.waveform_chunk.CopyFrom(
        waveform_data_service.GetChunk(source, request, _MAX_SAMPLES))

    if predictions:
        prediction_service = prediction_data_service.PredictionDataService(
            predictions, source, _MAX_SAMPLES)
        response.prediction_metadata.CopyFrom(
            prediction_service.GetMetadata())
        response.prediction_chunk.CopyFrom(
            prediction_service.GetChunk(request))

    return response
  def testGetChunk(self, mock_create):
    """GetChunk passes its arguments through and surfaces the created table.

    The patched `mock_create` returns the canned pair ('test data', 1). The
    test checks that it was last called with (source, request, 10), that the
    response's waveform_datatable is 'test data', and that the requested
    channel index 0 is echoed back.
    """
    mock_create.return_value = ('test data', 1)

    req = data_pb2.DataRequest()
    req.chunk_duration_secs = 10
    req.chunk_start = 0
    req.channel_data_ids.add().single_channel.index = 0
    req.low_cut = 1.0
    req.high_cut = 70.0
    req.notch = 60.0

    chunk = waveform_data_service.GetChunk(self.waveform_data_source, req, 10)

    mock_create.assert_called_with(self.waveform_data_source, req, 10)
    self.assertEqual('test data', chunk.waveform_datatable)
    self.assertEqual(0, chunk.channel_data_ids[0].single_channel.index)