Example 1
async def setup_learner():
    await download_file(export_file_url, path / export_file_name)
    try:
        learn = load_learner(path, export_file_name)
        return learn
    except RuntimeError as e:
        if len(e.args) > 0 and 'CPU-only machine' in e.args[0]:
            print(e)
            message = "\n\nThis model was trained with an old version of fastai and will not work in a CPU environment.\n\nPlease update the fastai library in your training environment and export your model again.\n\nSee instructions for 'Returning to work' at https://course.fast.ai."
            raise RuntimeError(message)
        else:
            raise
Example 2
def get_learner(model_path, model_file, test_path, test_file):
    """
    Loads the model learner from given model and test path and file.

    :param model_path: Path to dir where .pkl file is located.
    :param model_file: If multiple .pkl files are located in the same path, provide the exact model file name.
    :param test_path: Path to dir where test data is located
    :param test_file: Preprocessed test_labels.csv file, as was done in preprocess.py. It eases the fetching of ImageList.
    :return: The model learner.
    """
    learn = load_learner(model_path, file=model_file, test=ImageList.from_csv(test_path, test_file, folder='test'))
    return learn
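A minimal usage sketch for the helper above, assuming fastai v1 and the snippet's imports; every path and file name below is a placeholder:

from fastai.vision import DatasetType

# Hypothetical paths; substitute your own model and test locations.
learn = get_learner('models/', 'export.pkl', 'data/', 'test_labels.csv')
preds, _ = learn.get_preds(ds_type=DatasetType.Test)  # predictions on the attached test set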
Example 3
def setup_learner():
    asyncio.get_event_loop().run_until_complete(
        asyncio.ensure_future(download_file()))
    try:
        learn = load_learner('model/', export_file_name)
        return learn.to_fp32()
    except RuntimeError as e:
        if len(e.args) > 0 and 'CPU-only machine' in e.args[0]:
            print(e)
            message = "\n\nThis model was trained with an old version of fastai and will not work in a CPU environment.\n\nPlease update the fastai library in your training environment and export your model again.\n\nSee instructions for 'Returning to work' at https://course.fast.ai."
            raise RuntimeError(message)
        else:
            raise
Example 4
async def setup_model():
    if pkl_file.is_file():
        model = load_learner(models_path, model_name + ".pkl")
        return model
    elif pth_file.is_file():
        return model_loader()

    elif avaible_models[model_name].split("/")[0] == "https:":
        await download_file(file_url, models_path / model_name)
        try:  # Is it a pickle file?
            try:
                model = load_learner(models_path, model_name)
            except Exception:
                model = torch.load(models_path / model_name, map_location="cpu")
            file_name = models_path / model_name
            file_name.replace(file_name.with_suffix('.pkl'))
            return model
        except Exception:  # Not a pickle file, so treat it as a .pth file saved with PyTorch.
            file_name = models_path / model_name
            file_name.replace(file_name.with_suffix('.pth'))
            return model_loader()
    else:
        raise Exception("Something went wrong; the model URL may not start with https:")
Example 5
def search(bot, update):
    """Send reply of user's message."""
    photo_file = bot.get_file(update.message.photo[-1].file_id)
    photo_file.download('testing.jpeg')
    try:
        bs = 32
        path = "classes"

        np.random.seed(42)
        data = ImageDataBunch.from_folder(
            path,
            train='.',
            valid_pct=0.2,
            ds_tfms=get_transforms(),
            size=224,
            num_workers=4).normalize(imagenet_stats)

        learn = cnn_learner(data, models.resnet34,
                            metrics=error_rate).load("stage-1")
        learn.export()
        learn = load_learner("classes")

        cat, tensor, probs = learn.predict(open_image("testing.jpeg"))

        a = int(tensor.item())   # index of the predicted class
        b = float(probs[a])      # probability of the predicted class
        if b >= 0.9:
            update.message.reply_text(
                '`' + str(cat) + '`',
                parse_mode=ParseMode.MARKDOWN,
                reply_to_message_id=update.message.message_id)
#             print("prediction :")
#             print(cat)
        else:
            cat = "sry I am not sure "
            update.message.reply_text(
                '`' + str(cat) + '`',
                parse_mode=ParseMode.MARKDOWN,
                reply_to_message_id=update.message.message_id)


#             print("prediction :")
#             print("Not Sure")

    except Exception as e:
        update.message.reply_text(str(e))
Example 6
def numeral():
    img = open_image(request.files['file'])
    learn = load_learner(MODEL)
    pred_class, pred_idx, outputs = learn.predict(img)

    response = Response(
        status=200,
        headers=Headers({
            'Access-Control-Allow-Origin': '*'
        })
    )
    response.data = json.dumps({
        'prediction': learn.data.classes[pred_idx]
    })
    return response
Example 7
def predict_img(img_test):
    # Temporarily displays a message while executing
    with st.spinner('Wait for it...Predicting...'):
        time.sleep(3)

    model = load_learner('Shreshth1991/FossilImage/releases/download/v1.0.1/')
    #model = load_learner('C:\\Users\\H231148\\OneDrive - Halliburton\\Desktop\\models','model.pkl')
    pred_class, pred_idx, outputs = model.predict(img_test)
    res = zip(model.data.classes, outputs.tolist())
    predictions = sorted(res, key=lambda x: x[1], reverse=True)
    top_predictions = predictions[0:5]
    df = pd.DataFrame(top_predictions, columns=["Fossil", "Probability"])
    df['Probability'] = df['Probability'] * 100
    st.write(df)
Example 8
def predict(img, display_img):
    st.image(display_img, use_column_width=True)

    with st.spinner('Wait for it...'):
        time.sleep(5)

    model = load_learner('model/images/')
    pred_class, pred_idx, outputs = model.predict(img)
    pred_prob = torch.max(outputs).item() * 100

    if str(pred_class) == 'poorichole':
        st.success("This is Poori Chole. Probability -> " + str(pred_prob) +
                   '%.')
    else:
        st.success("This is Samosa. Probability -> " + str(pred_prob) + '%.')
Example 9
def upload_file():
    
    if request.method == 'POST':
        image = request.files['file']
        filename = secure_filename(image.filename)
        
        #saving file in upload path
        image.save(Path(app.config["IMAGE_UPLOADS"]+"/"+ filename))

        my_dict = {}
        #loading images from upload path      
        img_list_loader = ImageList.from_folder(upload_path)
        
        #Checking if valid images are uploaded
        if len(img_list_loader.items)>0:
            #loading model
            load_model = load_learner(model, 
                                  test=img_list_loader)
            #running inference
            preds,y = load_model.get_preds(ds_type=DatasetType.Test)
            index =0
            
            #Processing results for UI
            for pred, img_src in zip(preds, img_list_loader.items):

                top3_return_msg, top_pred = print_top_3_pred(pred)

                if np.round(pred[top_pred].numpy() * 100, 2) < threshold:
                    custom_msg = "NA"
                    Prediction_percent = "NA"
                else:
                    custom_msg = str(get_label(int(top_pred)))
                    Prediction_percent = str("{:.2f}%".format(np.round(pred[top_pred].numpy() * 100, 2)))

                temp_val=[]
                temp_val.append(img_src)
                temp_val.append(custom_msg)
                temp_val.append(Prediction_percent)
                temp_val.append(top3_return_msg)

                my_dict[index]=temp_val
                index+=1

            return render_template('result.html', mydict=my_dict)

            
        elif len(img_list_loader.items)== 0:
            return "ERROR: Invalid image. Go back to upload new image"
Example 10
def segment(path: str, name: str):
    url = path  #cwd+path

    # def acc_camvid(input, target):
    #     target = target.squeeze(1)
    #     mask = target != void_code
    #     return (input.argmax(dim=1)[mask]==target[mask]).float().mean()

    learn = load_learner(cwd + r'\static', "segmentation")

    im = open_image(url)

    data = learn.predict(im)

    data[0].save(cwd + f'\\static\\segment\\{name}.png')

    save(data[1], cwd + f'\\static\\segment\\{name}.pt')
Example 11
def classify_photo(pic=None, destination=None):
    """
    Feeds a picture through the sneaker detection pipeline.

    :param pic: Picture of a sneakers to be classified.
    :type pic: werkzeug.datastructures.FileStorage or Pathstr
    :param pic_filename: UUID4 random filename of picture.
    :type pic_filename: str

    :return: Prediction class, Prediction class index, Output probabilities.
    """
    if pic is not None:
        img = open_image(pic)
    else:
        img = open_image(destination)
    classifier_path = "app/models/cnn_classifier/"
    classifier = load_learner(classifier_path)
    pred_class, pred_idx, outputs = classifier.predict(img)

    if max(outputs) < 0.92:

        # if destination is not None:
        #     img_data = cv2.imread(destination)
        # else:
        #     img_data = cv2.imread(pic)

        # with global_graph.as_default():
        #     img_paths = detector.predict(img_data, destination)
        # img_paths.append(destination)

        # classifier_outputs = []

        # for path in img_paths:
        #     img = open_image(path)
        #     pred_class, pred_idx, outputs = classifier.predict(img)
        #     classifier_outputs.append(
        #         [max(outputs), pred_class, pred_idx, outputs])
        # classifier_outputs.sort(key=lambda x: x[0], reverse=True)

        # pred_class = classifier_outputs[0][1]
        # pred_idx = classifier_outputs[0][2]
        # outputs = classifier_outputs[0][3]

        return pred_class, pred_idx, outputs

    return pred_class, pred_idx, outputs
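A brief usage sketch for classify_photo, assuming a local test image; the file name is a placeholder:

# 'sneaker.jpg' is a hypothetical local image path.
pred_class, pred_idx, outputs = classify_photo(pic='sneaker.jpg')
print(pred_class, float(outputs[pred_idx]))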
Example 12
async def setup_learner():
    """
    Setup learner.
    """
    await download_file(MODEL_URL, path / MODEL_NAME)
    try:
        learner = load_learner(path, MODEL_NAME)
        return learner
    except RuntimeError as e:
        if len(e.args) > 0 and "CPU-only machine" in e.args[0]:
            print(e)
            message = "This model was trained with an old version of fastai \
                and will not work in a CPU environment. Please update the \
                fastai library in your training environment and export your \
                model again."

            raise RuntimeError(message)
        else:
            raise
Example 13
def main():
    st.markdown(
        "<h1 style='text-align: center;'>What is this Vietnamese food?�</h1>",
        unsafe_allow_html=True)
    st.markdown(
        "<center><img src='https://www.google.com/logos/doodles/2020/celebrating-banh-mi-6753651837108330.3-2xa.gif' width='500'></center>",
        unsafe_allow_html=True)
    learn = load_learner("models/")

    # Input URL
    st.write("")
    url = st.text_input(
        "URL: ",
        "https://cuisine-vn.com/wp-content/uploads/2020/03/google-first-honors-vietnamese-bread-promoting-more-than-10-countries-around-the-world-2.jpg",
    )

    if url:
        # Get and show image
        img_input = open_image_url(url)
        st.markdown("<h2 style='text-align: center;'>Image📷</h2>",
                    unsafe_allow_html=True)
        st.markdown(f"<center><img src='{url}' width='500'></center>",
                    unsafe_allow_html=True)

        # Predict
        st.write("")
        st.markdown("<h2 style='text-align: center;'>Output�</h2>",
                    unsafe_allow_html=True)
        pred_class, pred_idx, outputs = learn.predict(img_input)
        st.markdown(info[str(pred_class)])
        st.markdown(f"**Probability:** {outputs[pred_idx] * 100:.2f}%")

        # Plot
        plot_probs(outputs)

    # Reference
    st.markdown("""## Resources
[![](https://img.shields.io/badge/GitHub-View_Repository-blue?logo=GitHub)](https://github.com/chriskhanhtran/vn-food-app)
- [How the Vietnamese Food Classifier was trained](https://github.com/chriskhanhtran/vn-food-app/blob/master/notebook.ipynb)
- [Fast AI: Lesson 1 - What's your pet](https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson1-pets.ipynb)
- [Fast AI: Lesson 2 - Creating your own dataset from Google Images](https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson2-download.ipynb)
- [PyImageSearch: How to (quickly) build a deep learning image dataset](https://www.pyimagesearch.com/2018/04/09/how-to-quickly-build-a-deep-learning-image-dataset/)
""")
Example 14
    def pre_analyse(self, config):
        path, fname = os.path.split(config["path"])
        self.LEARNER = load_learner(path, fname)
        self.LABELS = self.LEARNER.data.classes

        if "labels" in config.keys() and len(config["labels"]) > 0:
            self.LABELS = [lb for lb in self.LABELS if lb in config["labels"]]

        self.logger(f"Model successfully loaded from {path}/{fname}.")

        def get_preds(img_path):
            img = open_image(img_path)
            _, _, losses = self.LEARNER.predict(img)
            return [
                x
                for x in zip(self.LEARNER.data.classes, map(float, losses))
                if x[0] in self.LABELS
            ]

        self.get_preds = get_preds
Example 15
def predict_fifa():
    import fastai.vision as fastai
    global fifa_learn

    name = request.args.get('name')
    path = './static/uploads/%s' % name
    print('------------------')
    print(path)
    if not os.path.exists(path):
        return "File doesn't exist, soz, go to the home page! %s" % path

    img = fastai.open_image(path)
    if fifa_learn is None:
        fifa_learn = fastai.load_learner('.', 'fifa.learn')
    pred_class, pred_idx, outputs = fifa_learn.predict(img)
    return render_template('fifa-or-real-predict.html',
                           img=path,
                           predict_class=pred_class,
                           predict_confidence=outputs,
                           name=name)
Example 16
def classify(name: str):
    try:
        # url = cwd+name
        # print(url)
        url = name
        learner = load_learner(cwd + r'\static', r'classification.pkl')

        # print(learner.data.classes)

        im = open_image(url)
        data = []
        with open(cwd + r'\static\classes.pkl', 'rb') as d:
            data = load(d)

        res = learner.predict(im)
        predictedClass = data[res[1]]
        # print(predictedClass)
    except Exception as e:
        predictedClass = f'Unknown error occurred: {e}'
    return predictedClass, url
Example 17
    def __init__(self):
        # load room type model
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.learn = fv.load_learner('./models', file='resnet-roomtype.pkl')
        self.learn.model = self.learn.model.module

        # Load caption model
        checkpoint = torch.load(
            'BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar',
            map_location=str(device))
        self.decoder = checkpoint['decoder']
        self.decoder = self.decoder.to(device)
        self.decoder.eval()
        self.encoder = checkpoint['encoder']
        self.encoder = self.encoder.to(device)
        self.encoder.eval()

        # Load word map (word2ix)
        with open('WORDMAP_coco_5_cap_per_img_5_min_word_freq.json', 'r') as j:
            self.word_map = json.load(j)
        self.rev_word_map = {v: k for k, v in self.word_map.items()}  # ix2word
Example 18
def predict(img, display_img):

    # Display the test image
    st.image(display_img, use_column_width=True)

    # Temporarily displays a message while executing
    with st.spinner('Wait for it...'):
        time.sleep(3)

    # Load model and make prediction
    model = load_learner('model/data/train/')
    pred_class, pred_idx, outputs = model.predict(img)
    pred_prob = round(torch.max(outputs).item() * 100)

    # Display the prediction
    if str(pred_class) == 'mtp':
        st.success("This is Son Tung M-TP with the probability of " +
                   str(pred_prob) + '%.')
    else:
        st.success("This is G-Dragon with the probability of " +
                   str(pred_prob) + '%.')
Example 19
def main(req: func.HttpRequest) -> func.HttpResponse:
    logging.info('HTTP trigger function processed a request.')
    # Get Image base64 uri and convert it to png
    try:
        req_body = req.get_json()
    except ValueError:
        return func.HttpResponse("Error: HTTP Request Body is empty",
                                 status_code=400)

    img_url = req_body['imgURL']
    path = Path.cwd()
    cnn_model = load_learner(path=path, file="model.pkl")
    if img_url:
        header, img_b64_data = img_url.split(",", 1)
        with open("alphabet.png", "wb") as f:
            try:
                f.write(base64.b64decode(img_b64_data))
            except Exception:
                return func.HttpResponse(
                    f"Error: Unable to parse data url in request body",
                    status_code=400)

        # Load up the image and predict alphabet
        img = open_image("alphabet.png")
        prediction = cnn_model.predict(img)
        class_idx = prediction[1].item()
        predicted_alpha = ALPHABETS[class_idx]
        logging.info(f"Predicted Alphabet: {ALPHABETS[class_idx]}")
    else:
        return func.HttpResponse(
            f"Error: Request body must have field `imgURL` that contains base64 image uri",
            status_code=400)

    res = {"Predicted Alphabet": predicted_alpha}
    headers = {"Access-Control-Allow-Origin": "*"}
    return func.HttpResponse(json.dumps(res),
                             mimetype="application/json",
                             headers=headers)
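For reference, a client-side sketch of the request this function expects: a JSON body whose imgURL field holds a base64 data URI. The endpoint URL and image file below are placeholders.

import base64
import requests

# Hypothetical endpoint and local image path.
with open('alphabet.png', 'rb') as f:
    data_url = 'data:image/png;base64,' + base64.b64encode(f.read()).decode()

resp = requests.post('https://<function-app>.azurewebsites.net/api/classify',
                     json={'imgURL': data_url})
print(resp.json())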
Example 20
def model_from_s3(s3_path):
    """
    Load model from S3 as `fastai.Learner`.

    Parameters
    -------------
    s3_path: string
        S3 path to bucket to download from
        Ex: s3://recycling-classification/folder/filename.pkl

    """
    bucket, key = get_bucket_and_key(s3_path)

    s3_client = boto3.client('s3')

    temp_file = NamedTemporaryFile()
    s3_client.download_file(bucket, key, temp_file.name)
    temp_file_pathname = temp_file.name[:temp_file.name.rindex('/')]
    temp_file_filename = temp_file.name[temp_file.name.rindex('/')+1:]
    learner = load_learner(temp_file_pathname, file=temp_file_filename)
    temp_file.close()

    return learner
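A short usage sketch for model_from_s3, with placeholder bucket, key, and image file:

from fastai.vision import open_image

# The S3 path and image file below are placeholders.
learner = model_from_s3('s3://my-bucket/models/export.pkl')
pred_class, pred_idx, outputs = learner.predict(open_image('sample.jpg'))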
Example 21
def upload_file():
	if request.method == 'GET':
		return render_template('index.html')
	if request.method == 'POST':
			# check if the post request has the file part
		if 'file' not in request.files:
			flash('No file part')
			return redirect(request.url)
	file = request.files['file']
	# if user does not select file, browser also
	# submit an empty part without filename
	if file.filename == '':
		flash('No selected file')
		return redirect(request.url)
	if file and allowed_file(file.filename):
		filename = secure_filename(file.filename)
		file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
		model_path = 'model/'
		learn = load_learner(model_path)
		img = open_image(UPLOAD_FOLDER + filename)
		prediction = learn.predict(img)
		return render_template('result.html', number = prediction[0], score = prediction[2])
	return 'Null'
Example 22
def main():
    #load trained learner
    learn = load_learner(PATH_DATA_ID_44)
    #load dictionary of subject:index_interval
    subject_segment_map = load_object(PATH_DATA + 'mapper_subject.file')
    #load dataframe. with name of files.png
    subject_segment_map_img = load_object(PATH_DATA_ID_44 +
                                          'image_name_map.file')
    #create dataframe from index interval and name of subject
    map_df = from_map_to_df(subject_segment_map, 'subjects')
    #the resulting dataframe has columns subjects('id_subject')
    # and name('file name of .png')
    df_images_id = map_df.join(subject_segment_map_img)

    predictions_hold = []
    for each_png in df_images_id.name:
        x = open_image(path_to_images + each_png)
        p = learn.predict(x)
        tensor = p[2]
        single_pred = np.round(tensor.cpu().detach().numpy()[0], 5)
        predictions_hold = predictions_hold + [single_pred]

    df_prediction=\
        df_images_id.join(pd.DataFrame(predictions_hold,
                                        columns=['Predicted_Target_id44']))
    df_prediction_decision = df_prediction.groupby('subjects').mean()
    df_prediction_decision.rename(
        columns={'Predicted_Target_id44': 'Decision_id44'}, inplace=True)
    print(df_prediction.head())
    print(df_prediction_decision.head())

    df_prediction.to_csv(PATH_RESULTS + 'Predicted_Target_id44.csv')
    df_prediction_decision.to_csv(PATH_RESULTS + 'Decision_id44.csv')


#if __name__ == "__main__":
#    main()
Example 23
async def analyze(request):
    """
    Analyze an image.
    """
    # Get image
    img_data = await request.form()
    img_bytes = await (img_data["file"].read())
    img = open_image(BytesIO(img_bytes))

    # Load model
    learner = load_learner(path, MODEL_NAME)

    # Make predictions
    pred_class, pred_idx, outputs = learner.predict(img)

    # Get predicted class
    result = str(pred_class).capitalize()

    # Combine labels and outputs probabilities
    predictions = dict(zip(LABELS, outputs.tolist()))

    # Transform predictions into percentages
    predictions = dict((k, percent(v)) for k, v in predictions.items())

    # Sort predictions in descending order
    predictions = dict(
        sorted(predictions.items(), key=operator.itemgetter(1), reverse=True))

    # Keep only 3 best predictions
    predictions = dict(list(predictions.items())[0:3])

    # Format predictions
    labels = list(
        map(lambda x: x.replace("_", " ").capitalize(), predictions.keys()))
    data = list(predictions.values())

    return JSONResponse({"result": result, "labels": labels, "data": data})
Example 24
from fastai.vision import (
    open_image,
    load_learner,
)
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import torch
from pathlib import Path
from io import BytesIO
import sys
import uvicorn
import aiohttp
import asyncio
import os


learner = load_learner(Path("/app"), Path("/app/training/trained_model.pkl"))

# TODO: less open CORS def. We'd need to pass the frontend server's domain name via ENV var.
middleware = [Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])]
app = Starlette(middleware=middleware)


@app.route("/upload", methods=["POST"])
async def upload(request):
    data = await request.form()
    bytes = await (data["file"].read())
    return predict_image_from_bytes(bytes)


@app.route("/classify-url", methods=["POST"])
async def classify_url(request):
Example 25
    json_str = read_file.read()
label_dict = literal_eval(json_str)
gt_labels = [k + '_' + v for k, v in label_dict.items()]


def get_prediction_image(path, sz=256):
    bgr_img = imread(path)
    b, g, r = split(bgr_img)
    rgb_img = merge([r, g, b])
    rgb_img = rgb_img / 255.0
    img = Image(px=pil2tensor(rgb_img, np.float32))
    img = img.resize((3, sz, sz))
    return img.px.reshape(1, 3, sz, sz)


learner = load_learner(path='.', file='export.pkl', device='cpu')
model = learner.model.cpu()

img_read = get_prediction_image(sys.argv[1])
with torch_no_grad():
    output = model(img_read)

bboxes, preds, scores = get_predictions(output, 0, detect_thresh=0.015)

pred_labels = [classes[pred] for pred in preds]

missing = set(gt_labels) - set(pred_labels)

if (len(missing.union(rotated)) == 0):
    print("No Errors")
for mc in missing:
Example 26
from io import BytesIO
import sys
import uvicorn
import aiohttp
import asyncio

from fastai.vision import load_learner
from starlette.applications import Starlette


async def get_bytes(url):
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.read()


app = Starlette()

pneumonia_learner = load_learner('')


@app.route("/upload", methods=["POST"])
async def upload(request):
    data = await request.form()
    bytes = await (data["file"].read())
    return predict_image_from_bytes(bytes)


@app.route("/classify-url", methods=["GET"])
async def classify_url(request):
    bytes = await get_bytes(request.query_params["url"])
    return predict_image_from_bytes(bytes)

Example 27
from pathlib import Path

from fastai.vision import load_learner, open_image

MODEL_DIR = Path(__file__).resolve().parents[1] / "models"
MODEL_NAME = "multilabel_model_20190407.pkl"
model = load_learner(MODEL_DIR, MODEL_NAME)
CLASSES = model.data.classes


def predict_single(path):
    image = open_image(path)
    pred_classes, preds, probs = model.predict(image)
    probs = [prob.item() for prob in probs]
    return dict(zip(CLASSES, probs))


def predict_multiple(path_list):
    predictions = []
    for path in path_list:
        predictions.append(predict_single(path))
    return predictions


if __name__ == "__main__":
    test_image_path = Path(__file__).parent / "test/flower.jpeg"
    prediction = predict_single(test_image_path)
Example 28
from fastai.vision import load_learner, open_image
import cv2
import numpy as np

learn = load_learner('E:\\STUDY\\fastai\\Pytorch\\Detection\\models',
                     'export.pkl')
cap = cv2.VideoCapture(0)
while True:
    ret, frame = cap.read()

    a = cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA)

    cv2.imwrite('a.png', a)

    a = open_image('a.png')

    font = cv2.FONT_HERSHEY_SIMPLEX
    pred_class, pred_idx, outputs = learn.predict(a)
    if outputs[pred_idx] < 0.4:
        pred_class = 'None'
    cv2.putText(frame, str(pred_class), (50, 50), font, 1, (0, 255, 255), 2,
                cv2.LINE_4)
    cv2.imshow("frame", frame)
    #  print(pred_class)

    key = cv2.waitKey(1)
    if key == 27:
        break
cap.release()
cv2.destroyAllWindows()
Example 29
def init_model(model_file):
    learn = load_learner(path="", file=model_file)
    return learn
Example 30
    def __init__(self):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = load_learner(path_model)
        self.to_tensor = transforms.ToTensor()