def preprocess(image, model_name):
    """Prepare ``image`` as input for the network selected by ``model_name``.

    Args:
        image: Image array to preprocess. The caller's object is never
            modified; every branch rebinds ``image`` to a new result.
        model_name: One of ``"mv2"``, ``"vgg16"`` or ``"resnet50"``.

    Returns:
        For ``"mv2"``, the image divided by 128 and passed through
        ``transforms.resize``; for ``"vgg16"`` / ``"resnet50"``, the result
        of ``vgg.prepare`` / ``resnet.prepare``. All branches request a
        (224, 224) size.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    target_size = (224, 224)

    if model_name == "mv2":
        # Divide out of place: ``image /= 128.`` would silently overwrite the
        # caller's array (and is rejected by NumPy for integer arrays).
        image = image / 128.
        image = transforms.resize(image, target_size)
    elif model_name == "vgg16":
        image = vgg.prepare(image, size=target_size)
    elif model_name == "resnet50":
        image = resnet.prepare(image, size=target_size)
    else:
        # ValueError subclasses Exception, so handlers written against the
        # previous bare ``Exception`` keep working.
        raise ValueError("illegal model: {!r}".format(model_name))
    return image
def test_prepare(self):
    """Check the shape and dtype of ``vgg.prepare`` outputs.

    Five random arrays of different shapes and dtypes are prepared. Calls
    that use the default ``size`` must yield a ``(3, 224, 224)`` result,
    calls with ``size=None`` must yield ``(3, 160, 120)``, and every result
    must have dtype ``self.dtype``.
    """
    def random_image(shape, dtype):
        # Uniform values in [0, 255) cast to the requested dtype.
        return numpy.random.uniform(0, 255, shape).astype(dtype)

    # All inputs are created up front, in the same order as they are checked.
    images = [
        random_image((320, 240, 3), numpy.uint8),
        random_image((320, 240), numpy.uint8),
        random_image((160, 120, 3), self.dtype),
        random_image((1, 160, 120), self.dtype),
        random_image((3, 160, 120), numpy.uint8),
    ]

    # (extra keyword arguments for prepare, expected output shape)
    expectations = [
        ({}, (3, 224, 224)),
        ({}, (3, 224, 224)),
        ({"size": None}, (3, 160, 120)),
        ({}, (3, 224, 224)),
        ({"size": None}, (3, 160, 120)),
    ]

    for image, (kwargs, expected_shape) in zip(images, expectations):
        prepared = vgg.prepare(image, **kwargs)
        assert prepared.shape == expected_shape
        assert prepared.dtype == self.dtype
def test_prepare(self):
    """Verify that ``vgg.prepare`` returns float32 arrays of the expected shape.

    Inputs are random uint8 and float32 arrays of several shapes. With the
    default ``size`` the output shape must be ``(3, 224, 224)``; with
    ``size=None`` it must be ``(3, 160, 120)``.
    """
    uint8, float32 = numpy.uint8, numpy.float32

    # (input shape, input dtype, extra prepare kwargs, expected output shape)
    cases = (
        ((320, 240, 3), uint8, {}, (3, 224, 224)),
        ((320, 240), uint8, {}, (3, 224, 224)),
        ((160, 120, 3), float32, {"size": None}, (3, 160, 120)),
        ((1, 160, 120), float32, {}, (3, 224, 224)),
        ((3, 160, 120), uint8, {"size": None}, (3, 160, 120)),
    )

    # Draw every input before the first prepare call, in declaration order.
    samples = [
        numpy.random.uniform(0, 255, in_shape).astype(in_dtype)
        for in_shape, in_dtype, _, _ in cases
    ]

    for sample, (_, _, extra, out_shape) in zip(samples, cases):
        result = vgg.prepare(sample, **extra)
        self.assertEqual(result.shape, out_shape)
        self.assertEqual(result.dtype, float32)
def transform(in_data):
    """Apply ``prepare`` to each element of a three-item example.

    Args:
        in_data: An iterable of exactly three items (presumably an
            anchor/positive/negative triple, judging by the original
            ``x_a``/``x_p``/``x_n`` names; confirm against the caller).

    Returns:
        A 3-tuple with ``prepare`` applied to each item, in input order.
    """
    # Unpacking enforces that exactly three items were supplied.
    first, second, third = in_data
    return tuple(prepare(item) for item in (first, second, third))