Exemplo n.º 1
0
 def test_reads_with_traces_no_unconfirmed_exposures(self):
     """Reads the bundled test data with unconfirmed exposures disabled.

     The first batch of 5 must contain the expected labels and grouping
     values. Only the multiset is compared (assertCountEqual), not the order.
     """
     loader = abesim_data_loader.AbesimExposureDataLoader(
         _get_test_data_path(), unconfirmed_exposures=False)
     batch = loader.get_next_batch(batch_size=5)
     _, got_labels, got_grouping = batch
     self.assertCountEqual(got_labels, [0, 1, 0, 0, 1])
     self.assertCountEqual(got_grouping, [0] * 5)
Exemplo n.º 2
0
 def test_reads_with_traces_batch_size_too_large(self):
     """Requests two consecutive batches of 16 from the bundled test data.

     The first batch is expected to be full (16 entries). The second batch is
     expected to hold only 14 entries, presumably the remainder of the test
     data.
     """
     loader = abesim_data_loader.AbesimExposureDataLoader(
         _get_test_data_path(), unconfirmed_exposures=False)
     # Same request size each time; only the expected result length differs.
     for expected_len in (16, 14):
         _, got_labels, got_grouping = loader.get_next_batch(batch_size=16)
         self.assertLen(got_labels, expected_len)
         self.assertLen(got_grouping, expected_len)
Exemplo n.º 3
0
    def test_reads_without_traces_or_triple_raises(self):
        """Loading exposures that lack the required data raises ValueError.

        The records come from _get_exposures_missing_required_data(). Per the
        test name, they have neither proximity traces nor a triple.
        """
        path = _get_tmp_file_name()
        # Serialize the fixture records into a transposed riegeli file.
        sink = io.FileIO(path, mode='wb')
        with riegeli.RecordWriter(
                sink,
                options='transpose',
                metadata=riegeli.RecordsMetadata()) as writer:
            writer.write_messages(_get_exposures_missing_required_data())

        loader = abesim_data_loader.AbesimExposureDataLoader(
            path, unconfirmed_exposures=False)
        # Reading even a single record must fail.
        with self.assertRaises(ValueError):
            loader.get_next_batch(batch_size=1)
Exemplo n.º 4
0
    def test_reads_without_traces(self):
        """Reads exposures that carry no proximity traces.

        After the records from _get_exposures_without_proximity_traces() are
        written, a batch of size 1 must yield exactly one exposure tuple
        ([1.0], 0, 30), with label 1 and grouping value 1.
        """
        path = _get_tmp_file_name()
        # Serialize the fixture records into a transposed riegeli file.
        sink = io.FileIO(path, mode='wb')
        with riegeli.RecordWriter(
                sink,
                options='transpose',
                metadata=riegeli.RecordsMetadata()) as writer:
            writer.write_messages(_get_exposures_without_proximity_traces())

        loader = abesim_data_loader.AbesimExposureDataLoader(
            path, unconfirmed_exposures=False)
        got_exposures, got_labels, got_grouping = loader.get_next_batch(
            batch_size=1)
        self.assertCountEqual(got_exposures, [([1.0], 0, 30)])
        self.assertCountEqual(got_labels, [1])
        self.assertCountEqual(got_grouping, [1])