def load_weights_from_torch(self):
    """Copy pretrained EfficientNet weights from a PyTorch checkpoint into ``self``.

    Downloads the lukemelas/EfficientNet-PyTorch release checkpoint that matches
    ``self.number`` (only b0 and b2 are available here), deserializes it with
    ``torch.load`` and copies every tensor into the attribute of ``self`` whose
    name corresponds to the checkpoint key.

    Raises:
        ValueError: if no checkpoint is known for ``self.number``. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    import torch

    urls = {
        0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
        2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
    }
    if self.number not in urls:
        raise ValueError("no pretrained weights for efficientnet-b%s" % (self.number,))
    checkpoint = fetch(urls[self.number])

    # NOTE(review): torch.load unpickles the downloaded file, and the eval() calls
    # below execute checkpoint keys as Python expressions. This is acceptable only
    # because the URLs are fixed release assets; never point this at untrusted files.
    state_dict = torch.load(io.BytesIO(checkpoint))
    for k, v in state_dict.items():
        if '_blocks.' in k:
            # "_blocks.<i>.<rest>" -> "_blocks[<i>].<rest>" so the name indexes self._blocks
            k = "%s[%s].%s" % tuple(k.split(".", 2))
        mk = "self." + k
        try:
            mv = eval(mk)
        except AttributeError:
            try:
                # fallback: the model may hold the weight tensor directly as the attribute
                mv = eval(mk.replace(".weight", ""))
            except AttributeError:
                # fallback: the model may hold the bias as a sibling "<name>_bias" attribute
                mv = eval(mk.replace(".bias", "_bias"))
        vnp = v.numpy().astype(np.float32)
        # the model expects _fc.weight transposed relative to the checkpoint layout
        mv.data[:] = vnp if k != '_fc.weight' else vnp.T
        if GPU:  # presumably a module-level flag defined elsewhere in the file
            mv.cuda_()
def fetch_mnist(base_url="http://yann.lecun.com/exdb/mnist/"):
    """Download MNIST and return ``(X_train, Y_train, X_test, Y_test)``.

    All four arrays are writable ``uint8`` NumPy arrays. Image arrays are reshaped
    to ``(-1, 28, 28)``; label arrays are 1-D.

    Args:
        base_url: directory URL holding the four gzip'd IDX files. The default is
            the original location; pass a mirror URL (ending in "/") if it is
            unreachable.
    """
    import gzip

    # Bytes to skip before the payload: IDX image files have a 16-byte header
    # (magic, count, rows, cols) and label files an 8-byte header (magic, count).
    image_header, label_header = 0x10, 8

    def parse(name, header):
        # Decompress one IDX file and return its payload after `header` bytes.
        raw = gzip.decompress(fetch(base_url + name))
        # frombuffer over bytes is read-only; copy() yields an independent writable array
        return np.frombuffer(raw, dtype=np.uint8, offset=header).copy()

    X_train = parse("train-images-idx3-ubyte.gz", image_header).reshape((-1, 28, 28))
    Y_train = parse("train-labels-idx1-ubyte.gz", label_header)
    X_test = parse("t10k-images-idx3-ubyte.gz", image_header).reshape((-1, 28, 28))
    Y_test = parse("t10k-labels-idx1-ubyte.gz", label_header)
    return X_train, Y_train, X_test, Y_test
def load_cifar():
    """Load the first CIFAR-10 training batch (``data_batch_1``) as ``(X, Y)``.

    Returns:
        X: the batch's ``b'data'`` array reshaped to ``(-1, 3, 32, 32)``.
        Y: ``np.array`` built from the batch's ``b'labels'`` list.
    """
    archive = io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'))
    # Context managers guarantee the tarfile and the extracted member are closed.
    with tarfile.open(fileobj=archive, mode='r:gz') as tt:
        with tt.extractfile('cifar-10-batches-py/data_batch_1') as member:
            # NOTE(review): pickle.load can execute arbitrary code from the archive;
            # acceptable only because the URL is the fixed official dataset location.
            db = pickle.load(member, encoding="bytes")
    X = db[b'data'].reshape((-1, 3, 32, 32))
    Y = np.array(db[b'labels'])
    return X, Y
def load_weights_from_torch(self):
    """Copy pretrained EfficientNet weights from a PyTorch checkpoint into ``self``.

    Downloads the lukemelas/EfficientNet-PyTorch release checkpoint matching
    ``self.number`` (b0, b2, b4 or b7). It deserializes the checkpoint with
    ``torch.load`` when ``USE_TORCH`` is set, and with ``fake_torch_load``
    otherwise. Each tensor is copied into the attribute of ``self`` whose name
    corresponds to the checkpoint key. On a shape mismatch the tensor is skipped
    and a message is printed.

    Raises:
        ValueError: if no checkpoint is known for ``self.number``. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
    urls = {
        0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
        2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
        4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
        7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
    }
    if self.number not in urls:
        raise ValueError("no pretrained weights for efficientnet-b%s" % (self.number,))
    checkpoint = fetch(urls[self.number])

    # NOTE(review): torch.load unpickles the downloaded file, and the eval() calls
    # below execute checkpoint keys as Python expressions. This is acceptable only
    # because the URLs are fixed release assets; never point this at untrusted files.
    if USE_TORCH:
        import io
        import torch
        state_dict = torch.load(io.BytesIO(checkpoint))
    else:
        state_dict = fake_torch_load(checkpoint)

    for k, v in state_dict.items():
        if '_blocks.' in k:
            # "_blocks.<i>.<rest>" -> "_blocks[<i>].<rest>" so the name indexes self._blocks
            k = "%s[%s].%s" % tuple(k.split(".", 2))
        mk = "self." + k
        try:
            mv = eval(mk)
        except AttributeError:
            try:
                # fallback: the model may hold the weight tensor directly as the attribute
                mv = eval(mk.replace(".weight", ""))
            except AttributeError:
                # fallback: the model may hold the bias as a sibling "<name>_bias" attribute
                mv = eval(mk.replace(".bias", "_bias"))
        # presumably fake_torch_load already yields NumPy arrays -- TODO confirm dtype
        vnp = v.numpy().astype(np.float32) if USE_TORCH else v
        # the model expects _fc.weight transposed relative to the checkpoint layout
        vnp = vnp if k != '_fc.weight' else vnp.T
        # 0-d checkpoint values are broadcast into the target; other shapes must match
        if mv.shape == vnp.shape or vnp.shape == ():
            mv.data[:] = vnp
        else:
            print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
vnp = v.numpy().astype(np.float32) mv.data[:] = vnp if k != '_fc.weight' else vnp.T if __name__ == "__main__": # instantiate my net model = EfficientNet() model.load_weights_from_torch() # load image and preprocess from PIL import Image if len(sys.argv) > 1: url = sys.argv[1] else: url = "https://raw.githubusercontent.com/karpathy/micrograd/master/puppy.jpg" img = Image.open(io.BytesIO(fetch(url))) aspect_ratio = img.size[0] / img.size[1] img = img.resize((int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))) img = np.array(img) y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2 img = img[y0:y0 + 224, x0:x0 + 224] img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) img = img.astype(np.float32).reshape(1, 3, 224, 224) img /= 255.0 img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) # if you want to look at the micrograd puppy """
mv = eval(mk) except AttributeError: try: mv = eval(mk.replace(".weight", "")) except AttributeError: mv = eval(mk.replace(".bias", "_bias")) vnp = v.numpy().astype(np.float32) mv.data[:] = vnp if k != '_fc.weight' else vnp.T if __name__ == "__main__": # instantiate my net model = EfficientNet() model.load_weights_from_torch() # load cat image from PIL import Image img = Image.open( io.BytesIO( fetch( "https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg" ))) img = img.resize((224, 224)) img = np.moveaxis(np.array(img), [2, 0, 1], [0, 1, 2]) img = img.astype(np.float32).reshape(1, 3, 224, 224) print(img.shape, img.dtype) # run the net out = model.forward(Tensor(img)) print(np.argmax(out.data), np.max(out.data))