Example #1
0
    def test_scalars_str_single_int(self, mock_dat, mock_caffe):
        """Write a single integer scalar to LMDB and verify the stored record.

        ``tol.scalars_to_lmdb`` is called with one random integer. The test
        checks three things:

        - the LMDB directory was created;
        - it holds exactly one entry;
        - that entry's key is ``tol.IDX_FMT.format(0)`` and its value is the
          mocked serialization ``self.PREFIX + self.STR_MAPPINGS[x]``.
        """
        # single random integer in [low, high); assumes self.STR_MAPPINGS
        # has an entry for every value in this range -- TODO confirm
        x = np.random.randint(-2, 11)
        expected = self.PREFIX + self.STR_MAPPINGS[x]

        # mock caffe calls made by our module so the serialized datum is the
        # known string above
        mock_dat.return_value.SerializeToString.return_value = expected
        mock_caffe.io.array_to_datum.return_value = caffe.proto.caffe_pb2.Datum()

        # use the module and test it
        path_lmdb = os.path.join(self.dir_tmp, 'test_scalars_str_single_lmdb')
        tol.scalars_to_lmdb(x, path_lmdb)
        assert_true(os.path.isdir(path_lmdb), "failed to save LMDB")

        c = 0
        # Manage the environment as well as the read transaction so the
        # LMDB handle is closed before the temp directory is cleaned up.
        with lmdb.open(path_lmdb, readonly=True) as env, env.begin() as txn:
            for key, value in txn.cursor():
                assert_equal(key, tol.IDX_FMT.format(c), "Unexpected key.")
                assert_equal(value, expected, "Unexpected content.")
                c += 1

        assert_equal(c, 1, "Unexpected number of samples.")
Example #2
0
    def test_scalars_lut(self, mock_dat, mock_caffe):
        """Write integer scalars through a lookup function and verify LMDB.

        ``tol.scalars_to_lmdb`` is called with a 10x1 integer array and
        ``lut=lut``. The test checks that the LMDB directory exists and holds
        one entry per element of ``x``. Entry ``c`` must have key
        ``tol.IDX_FMT.format(c)`` and value ``ser_vals[c]``.
        """
        # 10x1 array of random integers in [low, high)
        x = np.random.randint(-1, 4, size=(10, 1))

        def lut(value):
            return value - 1

        # Build the expected values with ``lut`` itself so they cannot drift
        # out of sync with the mapping passed to the module under test.
        ser_vals = [self.PREFIX + self.STR_MAPPINGS[lut(v)] for v in x.ravel()]

        # mock caffe calls made by our module; each SerializeToString call
        # returns the next entry of ser_vals (presumably one call per
        # sample, in order -- confirm against tol.scalars_to_lmdb)
        mock_dat.return_value.SerializeToString = MagicMock(side_effect=ser_vals)
        mock_caffe.io.array_to_datum.return_value = caffe.proto.caffe_pb2.Datum()

        # use the module and test it
        path_lmdb = os.path.join(self.dir_tmp, 'xlut_lmdb')
        tol.scalars_to_lmdb(x, path_lmdb, lut=lut)
        assert_true(os.path.isdir(path_lmdb), "failed to save LMDB")

        c = 0
        # Manage the environment as well as the read transaction so the
        # LMDB handle is closed before the temp directory is cleaned up.
        with lmdb.open(path_lmdb, readonly=True) as env, env.begin() as txn:
            for key, value in txn.cursor():
                assert_equal(key, tol.IDX_FMT.format(c), "Unexpected key.")
                assert_equal(value, ser_vals[c], "Unexpected content.")
                c += 1

        assert_equal(c, x.size, "Unexpected number of samples.")