def test_imagenet_labels():
    """Sanity-check the ImageNet label list: leading entries and total count."""
    expected_head = (
        "tench",
        "goldfish",
        "great_white_shark",
        "tiger_shark",
        "hammerhead",
    )
    labels = imagenet_labels()
    # Compare the first five labels in a single slice comparison.
    assert tuple(labels[:5]) == expected_head
    # ImageNet defines exactly 1000 classes.
    assert len(labels) == 1000
def model_to_learner(model: nn.Module, im_size: int = IMAGENET_IM_SIZE) -> Learner:
    """Wrap a pretrained ImageNet model in a fast.ai Learner for inference.

    Args:
        model (nn.Module): Base ImageNet model, e.g. models.resnet18().
        im_size (int): Image size the model will expect to have.

    Returns:
        Learner: a model trainer for prediction.
    """
    # fast.ai currently requires a DataBunch to build a Learner, so for pure
    # prediction (no retraining) we hand it an empty one that only carries the
    # ImageNet class list and normalization stats. single_from_classes is
    # deprecated, but it remains the simplest workaround.
    # Related thread: https://forums.fast.ai/t/solved-get-prediction-from-the-imagenet-model-without-creating-a-databunch/36197/5
    imagenet_spec = ImageDataBunch.single_from_classes(
        "", classes=imagenet_labels(), size=im_size
    ).normalize(imagenet_stats)
    return Learner(imagenet_spec, model)
# Resolve which model name(s) to run from the --model argument.
if args.model == "list":
    # Enumerate the available model names and quit.
    for m in all_models:
        print(m)
    sys.exit(0)
elif args.model is None:
    # Fixed: identity comparison with None (PEP 8), not `== None`.
    modeln = ["resnet152"]  # default model
elif args.model == "all":
    modeln = all_models
    if not len(args.path):
        # Running every model against a live webcam feed is not supported.
        sys.stderr.write("Cannot utilise all models from the webcam. " +
                         "Do not choose --model=all.\n")
        sys.exit(1)
else:
    modeln = [args.model]

try:
    labels = imagenet_labels()  # The 1000 labels.
except Exception:
    # Fixed: narrowed the bare `except:` so KeyboardInterrupt/SystemExit are
    # not swallowed. imagenet_labels() fetches the list, so a network failure
    # is the most likely cause here.
    sys.stderr.write("Failed to obtain labels probably because of " +
                     "a network connection error.\n")
    sys.exit(1)

# ----------------------------------------------------------------------
# Load the pre-built model
# ----------------------------------------------------------------------

for path in args.path:
    if is_url(path):
        # Download remote images to a temp file so they can be opened locally.
        tempdir = tempfile.gettempdir()
        imfile = os.path.join(tempdir, "temp.jpg")
        urllib.request.urlretrieve(path, imfile)
# ----------------------------------------------------------------------
# Prepare processing function
# ----------------------------------------------------------------------

# Webcam classification
def classify_frame(capture, learner, label):
    """Read one frame from `capture` and predict its class with `learner`."""
    _, img = capture.read()  # grab a single frame from the stream
    _, ind, prob = learner.predict(Image(utils.cv2torch(img)))
    # Overlay the predicted label and its probability on the frame itself.
    utils.put_text(img, f"{label[ind]} ({prob[ind]:.2f})")
    return utils.cv2matplotlib(img)

labels = imagenet_labels()  # Load model labels

# Load ResNet model
# * https://download.pytorch.org/models/resnet18-5c106cde.pth -> ~/.cache/torch/checkpoints/resnet18-5c106cde.pth
learn = model_to_learner(models.resnet18(pretrained=True), IMAGENET_IM_SIZE)
#learn = model_to_learner(models.resnet152(pretrained=True), IMAGENET_IM_SIZE)
#learn = model_to_learner(models.xresnet152(pretrained=True), IMAGENET_IM_SIZE)

# Want to load from local copy rather than from ~/.torch? Maybe
#learn = load_learner(file="resnet18-5c106cde.pth")
#model = untar_data("resnet18-5c106cde.pth")
#learn = load_learner(model)

# Bind the learner and labels now so downstream code only supplies the capture.
func = partial(classify_frame, learner=learn, label=labels)
sys.stderr = stderr

# Command-line options: one or more image paths/URLs to classify.
option_parser = argparse.ArgumentParser(add_help=False)
option_parser.add_argument('path', nargs="+", help='path or url to image')
#option_parser.add_argument(
#    '--model',
#    help="use this model instead of '{}'.".format(RESNET18))
args = option_parser.parse_args()

sys.stderr = devnull

# ----------------------------------------------------------------------

labels = imagenet_labels()

# Convert a pretrained imagenet model into Learner for prediction.
# You can load an exported model by learn = load_learner(path) as well.
# --model=
#   models.BasicBlock     models.Darknet       models.DynamicUnet
#   models.ResLayer       models.ResNet        models.SqueezeNet
#   models.UnetBlock      models.WideResNet    models.XResNet
#   models.alexnet        models.darknet       models.densenet121
#   models.densenet161    models.densenet169   models.densenet201
#   models.resnet101      models.resnet152     models.resnet18
#   models.resnet34       models.resnet50      models.squeezenet1_0
#   models.squeezenet1_1  models.unet          models.vgg16_bn
#   models.vgg19_bn       models.wrn           models.wrn_22
#   models.xception       models.xresnet       models.xresnet101
#   models.xresnet152     models.xresnet18     models.xresnet34