예제 #1
0
 def load(self):
     """Build the CRNN network, load its trained weights and set up helpers.

     After this call the instance holds:

     * ``self.session``     -- the CRNN network (moved to the GPU when CUDA
       is available) with the weights from ``self.model_path`` loaded and
       switched to evaluation mode.
     * ``self.cuda``        -- whether the network lives on the GPU.
     * ``self.converter``   -- ``utils.strLabelConverter`` built from
       ``self.alphabet``.
     * ``self.transformer`` -- ``dataset.resizeNormalize((100, 32))``.
     """
     use_cuda = torch.cuda.is_available()
     self.session = crnn.CRNN(32, 1, 37, 256, 1)
     if use_cuda:
         self.session = self.session.cuda()
     # Assign the flag on both paths so a later read never sees a missing
     # or stale value when CUDA is unavailable.
     self.cuda = use_cuda
     # Deserialize every storage onto the CPU, wherever the checkpoint was
     # saved; load_state_dict() then copies the values into the parameters
     # on the model's own device. Without this, a checkpoint saved from a
     # GPU cannot be loaded on a CPU-only machine.
     state = torch.load(self.model_path,
                        map_location=lambda storage, loc: storage)
     self.session.load_state_dict(state)
     # The loaded network is put in evaluation mode so that any
     # train/eval-sensitive layers run in their evaluation behaviour.
     self.session.eval()
     self.converter = utils.strLabelConverter(self.alphabet)
     self.transformer = dataset.resizeNormalize((100, 32))
예제 #2
0
 def load(self):
     """Create the CRNN model, restore its weights and prepare helpers.

     Sets ``self.session`` (the network, placed on the GPU when CUDA is
     available), ``self.cuda`` (True only when the GPU is used),
     ``self.converter`` (``utils.strLabelConverter`` over
     ``self.alphabet``) and ``self.transformer``
     (``dataset.resizeNormalize((100, 32))``).
     """
     network = crnn.CRNN(32, 1, 37, 256, 1)
     gpu_available = torch.cuda.is_available()
     if gpu_available:
         network = network.cuda()
     self.session = network
     # Always record the device decision; previously the flag was only
     # written on the CUDA path.
     self.cuda = gpu_available
     # map_location keeps every deserialized storage on the CPU, so a
     # checkpoint written on a GPU machine still loads when no GPU is
     # present. load_state_dict() copies the values onto the model's device.
     weights = torch.load(self.model_path,
                          map_location=lambda storage, loc: storage)
     self.session.load_state_dict(weights)
     # Switch to evaluation mode once, right after loading the weights.
     self.session.eval()
     self.converter = utils.strLabelConverter(self.alphabet)
     self.transformer = dataset.resizeNormalize((100, 32))
예제 #3
0
 def load(self):
     """Load the CRNN model in evaluation mode and build its helpers.

     Logs a notice, builds the network (on the GPU when CUDA is available),
     restores the weights stored at ``self.model_path`` and switches the
     network to evaluation mode. Also creates ``self.converter`` from
     ``self.alphabet`` and ``self.transformer`` as
     ``dataset.resizeNormalize((100, 32))``.
     """
     logging.info("Loding CRNN model first apply will be slow")
     on_gpu = torch.cuda.is_available()
     self.session = crnn.CRNN(32, 1, 37, 256, 1)
     if on_gpu:
         self.session = self.session.cuda()
     # Written on both paths so the flag is always defined and accurate.
     self.cuda = on_gpu
     # Load the checkpoint onto the CPU first; this works whether or not the
     # file was saved from a GPU, and load_state_dict() copies the tensors
     # to wherever the model's parameters live.
     checkpoint = torch.load(self.model_path,
                             map_location=lambda storage, loc: storage)
     self.session.load_state_dict(checkpoint)
     self.session.eval()
     self.converter = utils.strLabelConverter(self.alphabet)
     self.transformer = dataset.resizeNormalize((100, 32))
예제 #4
0
def recognize_text(video_pk):
    """
    Recognize text in regions with name CTPN_TEXTBOX using CRNN.

    For every ``Region`` of the given video whose ``object_name`` is
    ``CTPN_TEXTBOX``, the image at
    ``<MEDIA_ROOT>/<video_pk>/detections/<region pk>.jpg`` is converted to
    grayscale, run through the CRNN model, and the decoded string is saved
    as a new ``Region`` named ``CRNN_TEXT`` with the same frame and
    bounding box as the source region.

    :param video_pk: primary key of the video; converted with ``int()``
    :return: None
    """
    setup_django()
    # These imports stay local and after setup_django(), as in the original;
    # presumably the Django models need the configured environment.
    from dvaapp.models import Region
    from django.conf import settings
    from PIL import Image
    import torch
    from torch.autograd import Variable
    import dvalib.crnn.utils as utils
    import dvalib.crnn.dataset as dataset
    import dvalib.crnn.models.crnn as crnn

    video_pk = int(video_pk)
    model_path = '/root/DVA/dvalib/crnn/data/crnn.pth'
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'

    model = crnn.CRNN(32, 1, 37, 256, 1)
    # The model stays on the CPU here, so force the checkpoint's storages
    # onto the CPU as well; otherwise a GPU-saved file fails to load on a
    # machine without CUDA.
    model.load_state_dict(
        torch.load(model_path, map_location=lambda storage, loc: storage))
    # Evaluation mode is loop-invariant: set it once instead of per region.
    model.eval()
    converter = utils.strLabelConverter(alphabet)
    transformer = dataset.resizeNormalize((100, 32))

    text_boxes = Region.objects.filter(video_id=video_pk,
                                       object_name='CTPN_TEXTBOX')
    for r in text_boxes:
        img_path = "{}/{}/detections/{}.jpg".format(settings.MEDIA_ROOT,
                                                    video_pk, r.pk)
        # Close the file handle deterministically; convert() returns an
        # independent image that outlives the context manager.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('L')
        image = transformer(image)
        # Prepend a batch dimension of size 1.
        image = Variable(image.view(1, *image.size()))
        preds = model(image)
        _, preds = preds.max(2)
        # Older PyTorch keeps the reduced dimension after max(dim); newer
        # releases drop it. Only squeeze when it is still present so both
        # behaviours end up with the same 2-D tensor.
        if preds.dim() == 3:
            preds = preds.squeeze(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

        # Store the recognized text as a new annotation that mirrors the
        # geometry and frame of the source text box.
        dr = Region()
        dr.video_id = r.video_id
        dr.frame_id = r.frame_id
        dr.object_name = "CRNN_TEXT"
        dr.region_type = Region.ANNOTATION
        dr.x = r.x
        dr.y = r.y
        dr.w = r.w
        dr.h = r.h
        dr.metadata_text = sim_pred
        dr.save()