Пример #1
0
def test_imagenet_labels():
    """Sanity-check imagenet_labels(): spot-check the leading entries and the total count."""
    # Only the first few labels are compared, which is enough for a quick check.
    expected_head = (
        "tench",
        "goldfish",
        "great_white_shark",
        "tiger_shark",
        "hammerhead",
    )

    labels = imagenet_labels()
    for position, name in enumerate(expected_head):
        assert labels[position] == name

    # The label list must contain exactly 1000 classes.
    assert len(labels) == 1000
Пример #2
0
def model_to_learner(model: nn.Module,
                     im_size: int = IMAGENET_IM_SIZE) -> Learner:
    """Wrap a pyTorch ImageNet model in a fast.ai Learner usable for prediction.

    Args:
        model (nn.Module): Base ImageNet model, e.g. models.resnet18().
        im_size (int): Image size the model will expect to receive.

    Returns:
        Learner: a model trainer for prediction.
    """
    # The fast.ai API needs a DataBunch to build a Learner. For prediction
    # without retraining, an "empty" DataBunch that only carries the ImageNet
    # class names is enough. single_from_classes is deprecated, but it is the
    # easiest workaround. Related thread:
    # https://forums.fast.ai/t/solved-get-prediction-from-the-imagenet-model-without-creating-a-databunch/36197/5
    class_names = imagenet_labels()
    bare_bunch = ImageDataBunch.single_from_classes(
        "", classes=class_names, size=im_size)
    # Apply the ImageNet normalization statistics the pretrained model expects.
    empty_data = bare_bunch.normalize(imagenet_stats)

    return Learner(empty_data, model)
Пример #3
0
# Decide which model name(s) to run from the --model option.
if args.model == "list":
    # Print every supported model name and stop.
    for m in all_models:
        print(m)
    sys.exit(0)
elif args.model is None:
    # Default model when --model is not given.
    modeln = ["resnet152"]
elif args.model == "all":
    modeln = all_models
    # Running every model requires explicit image paths; per the message
    # below, the webcam mode (no paths) cannot handle "all".
    if not args.path:
        sys.stderr.write("Cannot utilise all models from the webcam. " +
                         "Do not choose --model=all.\n")
        sys.exit(1)
else:
    modeln = [args.model]
    
# Fetch the 1000 class labels; any failure is reported to the user as a
# likely network problem and the script exits.
try:
    labels = imagenet_labels()   # The 1000 labels.
except Exception:
    # Catch ordinary errors only: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and misreport them as a network failure.
    sys.stderr.write("Failed to obtain labels probably because of " +
                     "a network connection error.\n")
    sys.exit(1)

# ----------------------------------------------------------------------
# Load the pre-built model
# ----------------------------------------------------------------------

# Process each path or URL given on the command line (args.path).
for path in args.path:

    # When the argument is a URL (per is_url), download it to a local file
    # so it can be read like an ordinary path.
    if is_url(path):
        tempdir = tempfile.gettempdir()
        # NOTE(review): a fixed name in the shared temp dir is overwritten for
        # every URL and may collide with other processes; consider
        # tempfile.mkstemp / NamedTemporaryFile.
        imfile = os.path.join(tempdir, "temp.jpg")
        urllib.request.urlretrieve(path, imfile)
Пример #4
0
# Prepare processing function
# ----------------------------------------------------------------------

# Webcam classification


def classify_frame(capture, learner, label):
    """Use the learner to predict the class label of one captured frame.

    Args:
        capture: Video source whose ``read()`` returns ``(ok, frame)``,
            presumably a ``cv2.VideoCapture``.
        learner: Learner whose ``predict()`` returns
            ``(category, index, probabilities)``.
        label: Sequence of class names indexed by the predicted class index.

    Returns:
        The frame annotated with the predicted label and its probability,
        converted by ``utils.cv2matplotlib``.

    Raises:
        RuntimeError: If no frame could be read from ``capture``.
    """
    ok, frame = capture.read()  # Capture frame-by-frame
    if not ok or frame is None:
        # A failed grab yields no image; fail clearly here instead of
        # passing None into the image conversion and the model.
        raise RuntimeError("Failed to read a frame from the capture device.")
    _, ind, prob = learner.predict(Image(utils.cv2torch(frame)))
    utils.put_text(frame, f"{label[ind]} ({prob[ind]:.2f})")
    return utils.cv2matplotlib(frame)


# Class names used to turn the predicted index into a readable label.
labels = imagenet_labels()

# Wrap a pretrained ResNet-18 in a Learner. Its weights come from
#   https://download.pytorch.org/models/resnet18-5c106cde.pth
# and are cached at ~/.cache/torch/checkpoints/resnet18-5c106cde.pth
learn = model_to_learner(models.resnet18(pretrained=True), IMAGENET_IM_SIZE)
# Other backbones tried here in the same way:
#   models.resnet152(pretrained=True), models.xresnet152(pretrained=True)

# Loading from a local copy instead of ~/.torch might work with load_learner
# (unverified), e.g.
#   learn = load_learner(file="resnet18-5c106cde.pth")
# or
#   learn = load_learner(untar_data("resnet18-5c106cde.pth"))

# Pre-bind the learner and labels so callers only need to pass the capture.
func = partial(classify_frame, label=labels, learner=learn)
Пример #5
0
# Restore stderr (presumably the original stream saved earlier) so argparse
# can report usage errors to the user.
sys.stderr = stderr

option_parser = argparse.ArgumentParser(add_help=False)
option_parser.add_argument(
    "path",
    nargs="+",
    help="path or url to image",
)

# Disabled option for choosing a different model:
#option_parser.add_argument(
#    '--model',
#    help="use this model instead of '{}'.".format(RESNET18))

args = option_parser.parse_args()

# Redirect stderr to devnull again for the rest of the script
# (presumably to hide library noise).
sys.stderr = devnull

# ----------------------------------------------------------------------

# Class names used to decode predictions.
labels = imagenet_labels()

# Convert a pretrained ImageNet model into a Learner for prediction.
# An exported model can also be loaded with learn = load_learner(path).
# --model=
# models.BasicBlock 	models.Darknet 	models.DynamicUnet
# models.ResLayer 	models.ResNet 	models.SqueezeNet
# models.UnetBlock 	models.WideResNet 	models.XResNet
# models.alexnet 	models.darknet 	models.densenet121
# models.densenet161 	models.densenet169 	models.densenet201
# models.resnet101 	models.resnet152 	models.resnet18
# models.resnet34 	models.resnet50 	models.squeezenet1_0
# models.squeezenet1_1 	models.unet 	models.vgg16_bn
# models.vgg19_bn 	models.wrn 	models.wrn_22
# models.xception 	models.xresnet 	models.xresnet101
# models.xresnet152 	models.xresnet18 	models.xresnet34