Exemplo n.º 1
0
def TfExDataSourceConstructor(*args):
    """Return the TfExample data source that matches FLAGS.file_type.

    'ECG' and 'EKG' both select the EKG data source and 'EEG' selects the EEG
    data source. Any other value logs a warning and falls back to EEG.

    Args:
      *args: Positional arguments forwarded unchanged to the selected data
        source constructor.

    Returns:
      A data_source.TfExampleEkgDataSource for 'ECG'/'EKG', otherwise a
      data_source.TfExampleEegDataSource.
    """
    file_type = FLAGS.file_type
    if file_type in ('ECG', 'EKG'):
        return data_source.TfExampleEkgDataSource(*args)
    if file_type != 'EEG':
        # Unrecognized types are tolerated rather than rejected; EEG is the
        # default reader.
        logging.warning('Unknown file type %s, using EEG', file_type)
    return data_source.TfExampleEegDataSource(*args)
 def setUp(self):
   """Builds a 20-sample EEG tf.train.Example and wraps it in a data source.

   The example has num_samples=20 and sampling_frequency_hz=1.0, plus:
   a 'test_feature' ramp 0, 2, ..., 38; a 'test_sub-REF' channel holding
   1..20; and a 'test_min-REF' channel holding twenty ones.
   """
   super(WaveformDataServiceTest, self).setUp()
   example = tf.train.Example()
   features = example.features.feature
   features['eeg_channel/num_samples'].int64_list.value.append(20)
   features['eeg_channel/sampling_frequency_hz'].float_list.value.append(1.0)
   features['test_feature'].float_list.value.extend(
       [2 * n for n in range(20)])
   features[('eeg_channel/EEG '
             'test_sub-REF/samples')].float_list.value.extend(
                 [n + 1 for n in range(20)])
   features[('eeg_channel/EEG '
             'test_min-REF/samples')].float_list.value.extend([1] * 20)
   self.waveform_data_source = data_source.TfExampleEegDataSource(
       example, 'test_key')
Exemplo n.º 3
0
 def setUp(self):
     """Builds a PredictionDataService over a small EEG example and two chunks.

     The waveform example has 20 samples at sampling_frequency_hz=1.0, a
     start_time timestamp of 10 s + 800000000 ns and patient id
     b'test patient'. The prediction outputs hold two 2-second chunks,
     'test chunk' (starting 10.8 s) and 'test chunk 2' (starting 12.8 s),
     each carrying one 'test label' label.
     """
     super(PredictionDataServiceTest, self).setUp()
     example = tf.train.Example()
     features = example.features.feature
     features['eeg_channel/num_samples'].int64_list.value.append(20)
     features['eeg_channel/sampling_frequency_hz'].float_list.value.append(
         1.0)
     features['test_feature'].float_list.value.extend(
         [2 * n for n in range(20)])
     features[('eeg_channel/EEG '
               'test_sub-REF/samples')].float_list.value.extend(
                   [n + 1 for n in range(20)])
     features[('eeg_channel/EEG '
               'test_min-REF/samples')].float_list.value.extend([1] * 20)
     start = timestamp_pb2.Timestamp(seconds=10, nanos=800000000)
     features['start_time'].bytes_list.value.append(start.SerializeToString())
     features['segment/patient_id'].bytes_list.value.append(b'test patient')
     waveforms = data_source.TfExampleEegDataSource(example, 'test key')

     outputs = prediction_output_pb2.PredictionOutputs()
     # (chunk_id, chunk_start_time.seconds) for each prediction output, in
     # the order they are added.
     chunk_specs = (('test chunk', 10), ('test chunk 2', 12))
     for chunk_id, start_seconds in chunk_specs:
         output = outputs.prediction_output.add()
         info = output.chunk_info
         info.chunk_id = chunk_id
         info.chunk_start_time.seconds = start_seconds
         info.chunk_start_time.nanos = 800000000
         info.chunk_size_sec = 2
         output.label.add().name = 'test label'
     self._pred = prediction_data_service.PredictionDataService(
         outputs, waveforms, 100)