예제 #1
0
    def load_weights_from_torch(self):
        """Load pretrained EfficientNet weights from the lukemelas PyTorch release.

        Downloads the checkpoint for ``self.number`` (only b0 and b2 are known
        here), deserializes it with ``torch.load`` and copies every tensor into
        the matching attribute of this model, moving it to the GPU when
        ``GPU`` is set.

        Raises:
            Exception: if no pretrained checkpoint URL is known for ``self.number``.
        """
        urls = {
            0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
            2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
        }
        if self.number not in urls:
            raise Exception("no pretrained weights")

        # Imported lazily so the model itself can be built without torch.
        import io
        import torch

        # map_location="cpu" lets checkpoints that were saved from GPU tensors
        # load on CPU-only machines; values are converted to numpy below anyway.
        # NOTE(review): torch.load unpickles the downloaded file -- only use
        # trusted URLs here.
        state = torch.load(io.BytesIO(fetch(urls[self.number])), map_location="cpu")

        for k, v in state.items():
            if '_blocks.' in k:
                # "_blocks.3.rest" -> "_blocks[3].rest" so eval indexes the list
                k = "%s[%s].%s" % tuple(k.split(".", 2))
            mk = "self." + k
            # NOTE(review): eval() on keys read from the downloaded checkpoint;
            # only safe as long as the checkpoint source is trusted.
            try:
                mv = eval(mk)
            except AttributeError:
                try:
                    mv = eval(mk.replace(".weight", ""))
                except AttributeError:
                    mv = eval(mk.replace(".bias", "_bias"))
            vnp = v.numpy().astype(np.float32)
            # this model keeps _fc.weight transposed relative to the checkpoint
            mv.data[:] = vnp if k != '_fc.weight' else vnp.T
            if GPU:
                mv.cuda_()
예제 #2
0
def fetch_mnist(base_url="http://yann.lecun.com/exdb/mnist/"):
  """Download MNIST and return ``(X_train, Y_train, X_test, Y_test)``.

  All four arrays are ``uint8``; images are reshaped to ``(-1, 28, 28)`` and
  labels are 1-D.

  Args:
    base_url: URL of the directory holding the four original ``*-ubyte.gz``
      files. Defaults to the original host; pass a mirror if it is unreachable.
  """
  import gzip

  root = base_url.rstrip("/") + "/"

  def load(name, header_size):
    # Decompress one IDX file and drop its fixed-size header
    # (16 bytes for image files, 8 bytes for label files).
    raw = gzip.decompress(fetch(root + name))
    return np.frombuffer(raw, dtype=np.uint8)[header_size:].copy()

  X_train = load("train-images-idx3-ubyte.gz", 0x10).reshape((-1, 28, 28))
  Y_train = load("train-labels-idx1-ubyte.gz", 8)
  X_test = load("t10k-images-idx3-ubyte.gz", 0x10).reshape((-1, 28, 28))
  Y_test = load("t10k-labels-idx1-ubyte.gz", 8)
  return X_train, Y_train, X_test, Y_test
예제 #3
0
def load_cifar():
    """Download CIFAR-10 and return its first training batch as ``(X, Y)``.

    ``X`` is the batch's ``b'data'`` array reshaped to ``(-1, 3, 32, 32)``;
    ``Y`` is a numpy array built from the batch's ``b'labels'`` list.
    """
    archive = io.BytesIO(
        fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'))
    # Close the tar handle deterministically instead of leaking it.
    with tarfile.open(fileobj=archive, mode='r:gz') as tt:
        member = tt.extractfile('cifar-10-batches-py/data_batch_1')
        # NOTE(review): pickle.load on downloaded data -- only acceptable
        # because the source URL is trusted.
        db = pickle.load(member, encoding="bytes")
    X = db[b'data'].reshape((-1, 3, 32, 32))
    Y = np.array(db[b'labels'])
    return X, Y
예제 #4
0
  def load_weights_from_torch(self):
    """Load pretrained EfficientNet weights from the lukemelas PyTorch release.

    Fetches the checkpoint for ``self.number`` (b0, b2, b4 or b7), deserializes
    it with ``torch.load`` when ``USE_TORCH`` is set (otherwise with
    ``fake_torch_load``) and copies each array into the matching model
    attribute. Arrays whose shape does not match are reported and skipped.

    Raises:
      Exception: if no pretrained checkpoint URL is known for ``self.number``.
    """
    # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
    base = "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/"
    files = {
      0: "efficientnet-b0-355c32eb.pth",
      2: "efficientnet-b2-8bb594d6.pth",
      4: "efficientnet-b4-6ed6700e.pth",
      7: "efficientnet-b7-dcc49843.pth",
    }
    if self.number not in files:
      raise Exception("no pretrained weights")
    dat = fetch(base + files[self.number])

    if USE_TORCH:
      import io
      import torch
      # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
      # NOTE(review): torch.load unpickles the downloaded file -- trusted URLs only.
      state = torch.load(io.BytesIO(dat), map_location="cpu")
    else:
      state = fake_torch_load(dat)

    for k, v in state.items():
      if '_blocks.' in k:
        # "_blocks.3.rest" -> "_blocks[3].rest" so eval indexes the list
        k = "%s[%s].%s" % tuple(k.split(".", 2))
      mk = "self." + k
      # NOTE(review): eval() on keys read from the downloaded checkpoint;
      # only safe as long as the checkpoint source is trusted.
      try:
        mv = eval(mk)
      except AttributeError:
        try:
          mv = eval(mk.replace(".weight", ""))
        except AttributeError:
          mv = eval(mk.replace(".bias", "_bias"))
      vnp = v.numpy().astype(np.float32) if USE_TORCH else v
      if k == '_fc.weight':
        # this model keeps _fc.weight transposed relative to the checkpoint
        vnp = vnp.T

      # scalars (shape ()) are broadcast into the target
      if mv.shape == vnp.shape or vnp.shape == ():
        mv.data[:] = vnp
      else:
        print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
예제 #5
0
            vnp = v.numpy().astype(np.float32)
            mv.data[:] = vnp if k != '_fc.weight' else vnp.T


if __name__ == "__main__":
    # instantiate my net
    model = EfficientNet()
    model.load_weights_from_torch()

    # load image and preprocess
    from PIL import Image
    if len(sys.argv) > 1:
        url = sys.argv[1]
    else:
        url = "https://raw.githubusercontent.com/karpathy/micrograd/master/puppy.jpg"
    img = Image.open(io.BytesIO(fetch(url)))
    aspect_ratio = img.size[0] / img.size[1]
    img = img.resize((int(224 * max(aspect_ratio, 1.0)),
                      int(224 * max(1.0 / aspect_ratio, 1.0))))

    img = np.array(img)
    y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
    img = img[y0:y0 + 224, x0:x0 + 224]
    img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
    img = img.astype(np.float32).reshape(1, 3, 224, 224)
    img /= 255.0
    img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
    img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))

    # if you want to look at the micrograd puppy
    """
예제 #6
0
                mv = eval(mk)
            except AttributeError:
                try:
                    mv = eval(mk.replace(".weight", ""))
                except AttributeError:
                    mv = eval(mk.replace(".bias", "_bias"))
            vnp = v.numpy().astype(np.float32)
            mv.data[:] = vnp if k != '_fc.weight' else vnp.T


if __name__ == "__main__":
    # instantiate my net
    model = EfficientNet()
    model.load_weights_from_torch()

    # load cat image; convert("RGB") guarantees 3 channels so the reshape
    # below also works for grayscale or RGBA sources
    from PIL import Image
    img = Image.open(
        io.BytesIO(
            fetch(
                "https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"
            ))).convert("RGB")
    img = img.resize((224, 224))
    # HWC -> CHW, then add a batch dimension
    img = np.moveaxis(np.array(img), [2, 0, 1], [0, 1, 2])
    img = img.astype(np.float32).reshape(1, 3, 224, 224)
    # scale to [0, 1] and normalize with the ImageNet channel mean/std that the
    # pretrained EfficientNet-PyTorch weights expect
    img /= 255.0
    img -= np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, -1, 1, 1))
    img /= np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, -1, 1, 1))
    print(img.shape, img.dtype)

    # run the net
    out = model.forward(Tensor(img))
    print(np.argmax(out.data), np.max(out.data))