コード例 #1
0
    def __init__(self):
        """Download the ``URLs.PETS`` dataset and build a resnet34 learner on it.

        Sets three attributes:

        * ``self.x`` -- the literal string ``"hello"``.
        * ``self.data`` -- a normalized ``ImageDataBunch`` whose labels are
          parsed from the image file names.
        * ``self.learner`` -- the result of ``create_cnn`` on ``self.data``
          with ``models.resnet34`` and the ``error_rate`` metric.
        """
        self.x = "hello"

        dataset_root = untar_data(URLs.PETS)
        fnames = get_image_files(dataset_root / 'images')

        # Seed NumPy before the data bunch is built.
        # NOTE(review): presumably this keeps the random train/validation
        # split reproducible -- confirm against the fastai version in use.
        np.random.seed(2)

        # The capture group is the text between the last '/' and '_<digits>'.
        # NOTE(review): the '.' before 'jpg' is unescaped, so it matches any
        # single character, not only a literal dot.
        pat = re.compile(r'/([^/]+)_\d+.jpg$')

        bs = 6
        # Create the data bunch (labels come from the file names via `pat`).
        self.data = ImageDataBunch.from_name_re(
            dataset_root, fnames, pat, ds_tfms=get_transforms(), size=224,
            bs=bs).normalize(imagenet_stats)

        # Train on the bunch built just above. This method assigns no
        # `self.dataloader`, so referencing it here would fail unless the
        # class defines it elsewhere.
        self.learner = create_cnn(self.data,
                                  models.resnet34,
                                  metrics=error_rate)
コード例 #2
0
# Build a data bunch from made-up file names, one "/<name>_1.jpg" entry per
# category; the regex given to from_name_re captures the "<name>" part.
cat_images_path = Path("/tmp")
cat_fnames = list(
    map(
        "/{}_1.jpg".format,
        ("esdoorn", "haagbeuk", "kastanje", "beuk", "populier", "eik"),
    )
)
cat_data = ImageDataBunch.from_name_re(
    cat_images_path, cat_fnames, r"/([^/]+)_\d+.jpg$",
    ds_tfms=get_transforms(), size=224,
)

# Wrap resnet50 in a learner, then load the saved weights from
# "rn50-stage-2.pth", mapping the stored tensors onto the CPU.
cat_learner = ConvLearner(cat_data, models.resnet50)
cat_learner.model.load_state_dict(
    torch.load("rn50-stage-2.pth", map_location="cpu"))


@app.route("/upload", methods=["POST"])
async def upload(request):
    """Handle ``POST /upload``.

    Reads the ``file`` field of the submitted form, hands its contents to
    ``predict_image_from_bytes`` and returns whatever that call returns.
    """
    form = await request.form()
    uploaded = form["file"]
    image_bytes = await uploaded.read()
    return predict_image_from_bytes(image_bytes)
コード例 #3
0
ファイル: app.py プロジェクト: devforfu/quick_draw_api
# Template directory (relative to the working directory), as a POSIX string.
ROOT_DIR = (Path().parent / 'templates').as_posix()
# Base name of the weights file under models/; overridable via the environment.
MODEL_NAME = os.environ.get('MODEL_NAME', 'resnet50')
SLOW_PREDICTION = False

app = Starlette()
# NOTE(review): debug mode is switched on unconditionally -- confirm this
# module is not deployed as-is.
app.debug = True
app.mount('/static', StaticFiles(directory='static'))
env = Environment(loader=FileSystemLoader(ROOT_DIR), trim_blocks=True)

# One placeholder file name per line of categories.txt. The file is opened in
# a `with` block so the handle is closed once the list is built, and blank
# lines are skipped so they do not yield a nameless '/_1.jpg' entry.
with Path('categories.txt').open() as categories_file:
    categories = [
        f'/{name.strip()}_1.jpg' for name in categories_file if name.strip()
    ]

# Data bunch built only from the placeholder names above; the regex captures
# the category name between '/' and '_<digits>'.
placeholder_data = ImageDataBunch.from_name_re(Path('/tmp'),
                                               categories,
                                               pat=r'/([^/]+)_\d+.jpg$',
                                               ds_tfms=get_transforms(),
                                               device='cpu',
                                               size=224)

learn = create_cnn(placeholder_data, resnet50)
state = torch.load(f'models/{MODEL_NAME}.pth', map_location='cpu')
# load_state_dict's second positional parameter is the boolean `strict` flag;
# pass only the state dict so the default strict key matching applies.
learn.model.load_state_dict(state)
predictor = Predictor(learn, *imagenet_stats)


@app.route('/')
def home(request):
    """Render ``index.html`` with ``static_url='/static'`` and return it as HTML."""
    page = env.get_template('index.html').render(static_url='/static')
    return HTMLResponse(page)