Example #1
0
# Model hyperparameters matching the served checkpoint.
vocab_size = len(vocab)
num_layers = 1

# Checkpoint location: local path, plus a URL to download it from if absent.
MODEL_PATH = os.getenv('MODEL_PATH')  # e.g. "my_checkpoint.pth.tar"
MODEL_URL = os.getenv(
    'MODEL_URL'
)  # e.g. "https://vonage-models.s3.amazonaws.com/my_checkpoint.pth.tar"

if not path.exists(MODEL_PATH):
    print("downloading model....")
    r = requests.get(MODEL_URL)
    # Fail loudly on an HTTP error instead of writing an error page to disk.
    r.raise_for_status()
    # Context manager guarantees the file handle is closed.
    with open(MODEL_PATH, 'wb') as checkpoint_file:
        checkpoint_file.write(r.content)

print('done!\nloading up the saved model weights...')

# Build the captioning model on CPU and load the trained weights.
# NOTE(review): embed_size/hidden_size appear to be defined earlier in the
# file — not visible in this chunk.
myModel = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to("cpu")
myModel.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device('cpu'))['state_dict'])
myModel.eval()

app = Flask(__name__)

# Uploads directory lives next to this file.
UPLOAD_FOLDER = os.path.dirname(os.path.abspath(__file__)) + '/uploads/'

# Whitelisted image extensions (set literal instead of set([...])).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER


def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* has a whitelisted image extension.

    Args:
        filename: Name of the uploaded file.
        allowed_extensions: Optional iterable of lowercase extensions
            (without the leading dot). Defaults to the module-level
            ALLOWED_EXTENSIONS whitelist, preserving the original call
            signature for existing callers.

    Returns:
        True when the filename contains a dot and its final extension is
        whitelisted; False otherwise.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    # Lower-case the extension so uploads like "photo.PNG" are accepted too.
    return ('.' in filename
            and filename.rsplit('.', 1)[1].lower() in allowed_extensions)
def inferrence(model, dataset, image):
    """Caption a single image with *model* and return the predicted tokens.

    Args:
        model: Trained captioning model exposing ``caption_image``.
        dataset: Dataset whose ``vocab`` maps between tokens and ids.
        image: Input image (PIL, RGB — TODO confirm against callers).

    Returns:
        The list of predicted caption tokens (also printed for convenience).
        Returning the result is new but backward-compatible: previous
        callers ignored the implicit None.
    """
    # Preprocessing must match what the model was trained with.
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    model.eval()
    # Run on GPU when available; assumes the model lives on the same device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    image = transform(image).unsqueeze(0).to(device)
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        image_predict = model.caption_image(image, dataset.vocab)
    print("Predicted :" + " ".join(image_predict))
    return image_predict


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_img = Image.open("test_examples/footable.jpg").convert("RGB")
    _, dataset = get_loader(root_folder="archive/Images",
                            annotation_file="archive/captions.txt",
                            transform=None,
                            batch_size=64,
                            num_workers=0)
    embed_size = 256
    hidden_size = 256
    vocab_size = len(dataset.vocab)
    num_layers = 1
    model = CNNtoRNN(embed_size, hidden_size, vocab_size,
                     num_layers).to(device)
    model.load_state_dict(torch.load("my_checkpoint.pth.tar")["state_dict"])
    model.eval()
    inferrence(model, dataset, test_img)