def test_reads_with_traces_no_unconfirmed_exposures(self):
  """Checks the first 5-record batch of the test data (unconfirmed off).

  Labels and grouping are compared as multisets (order-insensitive) via
  assertCountEqual.
  """
  loader = abesim_data_loader.AbesimExposureDataLoader(
      _get_test_data_path(), unconfirmed_exposures=False)
  batch = loader.get_next_batch(batch_size=5)
  _, labels, grouping = batch
  expected_pairs = (
      (labels, [0, 1, 0, 0, 1]),
      (grouping, [0] * 5),
  )
  for actual, expected in expected_pairs:
    self.assertCountEqual(actual, expected)
def test_reads_with_traces_batch_size_too_large(self):
  """Requests batches of 16 twice and checks the returned sizes.

  The first request yields 16 labels/grouping entries; the second request
  for 16 yields only 14.
  """
  loader = abesim_data_loader.AbesimExposureDataLoader(
      _get_test_data_path(), unconfirmed_exposures=False)
  for expected_len in (16, 14):
    _, labels, grouping = loader.get_next_batch(batch_size=16)
    self.assertLen(labels, expected_len)
    self.assertLen(grouping, expected_len)
def test_reads_without_traces_or_triple_raises(self):
  """Records from _get_exposures_missing_required_data make loading fail.

  The records are written to a temporary riegeli file; asking the loader
  for a single-record batch must raise ValueError.
  """
  path = _get_tmp_file_name()
  sink = io.FileIO(path, mode='wb')
  with riegeli.RecordWriter(
      sink, options='transpose',
      metadata=riegeli.RecordsMetadata()) as record_writer:
    record_writer.write_messages(_get_exposures_missing_required_data())

  loader = abesim_data_loader.AbesimExposureDataLoader(
      path, unconfirmed_exposures=False)
  with self.assertRaises(ValueError):
    loader.get_next_batch(batch_size=1)
def test_reads_without_traces(self):
  """Loads records produced by _get_exposures_without_proximity_traces.

  After writing them to a temporary riegeli file, a single-record batch
  must contain the exposure ([1.0], 0, 30) with label 1 and grouping 1.
  """
  path = _get_tmp_file_name()
  sink = io.FileIO(path, mode='wb')
  with riegeli.RecordWriter(
      sink, options='transpose',
      metadata=riegeli.RecordsMetadata()) as record_writer:
    record_writer.write_messages(_get_exposures_without_proximity_traces())

  loader = abesim_data_loader.AbesimExposureDataLoader(
      path, unconfirmed_exposures=False)
  exposures, labels, grouping = loader.get_next_batch(batch_size=1)

  # Each batch component is compared order-insensitively.
  self.assertCountEqual(exposures, [([1.0], 0, 30)])
  self.assertCountEqual(labels, [1])
  self.assertCountEqual(grouping, [1])