def TfExDataSourceConstructor(*args):
  """Build the TfExample data source that matches FLAGS.file_type.

  The match is exact and case-sensitive:
  - 'ECG' or 'EKG' selects the EKG data source.
  - 'EEG' selects the EEG data source.
  - Any other value logs a warning and falls back to the EEG data source.

  Args:
    *args: positional arguments forwarded unchanged to the chosen constructor.

  Returns:
    A data_source.TfExampleEegDataSource or data_source.TfExampleEkgDataSource
    instance.
  """
  file_type = FLAGS.file_type
  if file_type in ('ECG', 'EKG'):
    return data_source.TfExampleEkgDataSource(*args)
  if file_type != 'EEG':
    # Unrecognized type: warn, then fall through to the EEG default.
    logging.warning('Unknown file type %s, using EEG', file_type)
  return data_source.TfExampleEegDataSource(*args)
def setUp(self):
  """Wrap a small synthetic EEG tf.train.Example in a data source.

  The example has 20 samples at a sampling frequency of 1.0 Hz. It contains:
  - 'test_feature' holding 0, 2, ..., 38;
  - an 'EEG test_sub-REF' channel holding 1..20;
  - an 'EEG test_min-REF' channel holding twenty 1s.

  The result is stored on self.waveform_data_source under key 'test_key'.
  """
  super(WaveformDataServiceTest, self).setUp()
  num_samples = 20
  example = tf.train.Example()
  features = example.features.feature
  features['eeg_channel/num_samples'].int64_list.value.append(num_samples)
  features['eeg_channel/sampling_frequency_hz'].float_list.value.append(1.0)
  features['test_feature'].float_list.value.extend(
      [2 * k for k in range(num_samples)])
  features['eeg_channel/EEG test_sub-REF/samples'].float_list.value.extend(
      [k + 1 for k in range(num_samples)])
  features['eeg_channel/EEG test_min-REF/samples'].float_list.value.extend(
      [1] * num_samples)
  self.waveform_data_source = data_source.TfExampleEegDataSource(
      example, 'test_key')
def setUp(self):
  """Build a PredictionDataService over a synthetic EEG example.

  The example has 20 samples at a sampling frequency of 1.0 Hz. It contains:
  - 'test_feature' holding 0, 2, ..., 38;
  - an 'EEG test_sub-REF' channel holding 1..20;
  - an 'EEG test_min-REF' channel holding twenty 1s;
  - a serialized Timestamp start time of 10.8 s;
  - the patient id b'test patient'.

  Two 2-second prediction chunks start at 10.8 s and 12.8 s. Each carries a
  single 'test label'. The resulting service is stored on self._pred.
  """
  super(PredictionDataServiceTest, self).setUp()
  num_samples = 20
  example = tf.train.Example()
  features = example.features.feature
  features['eeg_channel/num_samples'].int64_list.value.append(num_samples)
  features['eeg_channel/sampling_frequency_hz'].float_list.value.append(1.0)
  features['test_feature'].float_list.value.extend(
      [2 * k for k in range(num_samples)])
  features['eeg_channel/EEG test_sub-REF/samples'].float_list.value.extend(
      [k + 1 for k in range(num_samples)])
  features['eeg_channel/EEG test_min-REF/samples'].float_list.value.extend(
      [1] * num_samples)

  # The start time is stored as the bytes of a serialized Timestamp proto.
  start_time = timestamp_pb2.Timestamp(seconds=10, nanos=800000000)
  features['start_time'].bytes_list.value.append(
      start_time.SerializeToString())
  features['segment/patient_id'].bytes_list.value.append(b'test patient')
  waveform_source = data_source.TfExampleEegDataSource(example, 'test key')

  pred_outputs = prediction_output_pb2.PredictionOutputs()
  # Two consecutive 2-second chunks, both offset by 0.8 s within the second.
  for chunk_id, start_seconds in (('test chunk', 10), ('test chunk 2', 12)):
    output = pred_outputs.prediction_output.add()
    chunk_info = output.chunk_info
    chunk_info.chunk_id = chunk_id
    chunk_info.chunk_start_time.seconds = start_seconds
    chunk_info.chunk_start_time.nanos = 800000000
    chunk_info.chunk_size_sec = 2
    output.label.add().name = 'test label'

  self._pred = prediction_data_service.PredictionDataService(
      pred_outputs, waveform_source, 100)