示例#1
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu,
         visualize):
    """Load a CRNN snapshot and run detection over a test dataset.

    Args:
        data_path: directory with the test data; must not be None.
        abc: alphabet string (currently overridden below -- see NOTE).
        seq_proj: sequence projection size as a "WxH" string.
        backend: CRNN backbone identifier.
        snapshot: path to saved weights, or None to use a fresh net.
        input_size: input image size as a "WxH" string.
        gpu: value for CUDA_VISIBLE_DEVICES; empty string disables CUDA.
        visualize: whether detect() should visualize results.

    Raises:
        ValueError: if data_path is None.
    """
    # Set once -- the original set this env var twice.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # BUG FIX: `gpu is not ''` compared object identity with a literal
    # (undefined result, SyntaxWarning on py3.8+); use equality instead.
    cuda = gpu != ''
    # NOTE(review): the `abc` argument is immediately overridden with a
    # hard-coded alphanumeric alphabet -- confirm this is intentional.
    abc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose(
        [Rotation(), Resize(size=(input_size[0], input_size[1]))])
    if data_path is not None:
        data = LoadDataset(data_path=data_path,
                           mode="test",
                           transform=transform)
    else:
        # Previously `data` was left unbound here and detect() crashed
        # with a NameError; fail early with a clear message instead.
        raise ValueError("data_path must be provided")

    seq_proj = [int(x) for x in seq_proj.split('x')]

    net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend)
    if snapshot is not None:
        load_weights(net, torch.load(snapshot))
    if cuda:
        net = net.cuda()
    net = net.eval()
    detect(net, data, cuda, visualize)
示例#2
0
    # Build the evaluation loader; alignCollate resizes each batch to a
    # common (imgH, imgW), optionally keeping aspect ratio.
    val_loader = torch.utils.data.DataLoader(testdataset,
                                             shuffle=False,
                                             batch_size=opt.batch_size,
                                             num_workers=int(opt.workers),
                                             collate_fn=alignCollate(
                                                 imgH=imgH,
                                                 imgW=imgW,
                                                 keep_ratio=keep_ratio))

    alphabet = keys.alphabetChinese
    print("char num ", len(alphabet))
    # CRNN(imgH=32, nc=1 grayscale, nclass=len(alphabet)+1 for the CTC
    # blank, hidden size 256, ...).
    model = CRNN(32, 1, len(alphabet) + 1, 256, 1)

    converter = strLabelConverter(''.join(alphabet))

    # Load the checkpoint onto CPU regardless of where it was saved.
    state_dict = torch.load("../SceneOcr/model/ocr-lstm.pth",
                            map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
    # Drop BatchNorm's num_batches_tracked buffers (absent in older
    # model definitions); all other entries are copied through unchanged.
    for k, v in state_dict.items():
        name = k
        if "num_batches_tracked" not in k:
            # name = name.replace('module.', '')  # remove `module.`
            new_state_dict[name] = v
    model.cuda()
    # NOTE(review): the model is wrapped in DataParallel *before* loading,
    # so this only works if the checkpoint keys already carry the
    # 'module.' prefix (i.e. it was saved from a DataParallel model) --
    # confirm against the checkpoint.
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])

    # load params
    model.load_state_dict(new_state_dict)
    model.eval()

    # Run a short validation pass (at most 5 batches).
    curAcc = val(model, converter, val_loader, max_iter=5)
示例#3
0
class Demo(object):
    """CRNN-based text recognizer.

    Loads a checkpoint into a CRNN and decodes text from grayscale
    images, either one at a time or in zero-padded batches.
    """

    def __init__(self, args):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        self.args = args
        self.alphabet = alphabetChinese
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank label
        nc = 1  # single-channel (grayscale) input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.converter = utils.strLabelConverter(self.alphabet, ignore_case=False)
        self.transformer = resizeNormalize(args.imgH)

        print('loading pretrained model from %s' % args.model_path)
        checkpoint = torch.load(args.model_path)
        if 'model_state_dict' in checkpoint.keys():
            checkpoint = checkpoint['model_state_dict']
        from collections import OrderedDict
        model_dict = OrderedDict()
        # Strip the 'module.' prefix DataParallel adds to saved keys.
        for k, v in checkpoint.items():
            if 'module' in k:
                model_dict[k[7:]] = v
            else:
                model_dict[k] = v
        self.net.load_state_dict(model_dict)

        if args.cuda and torch.cuda.is_available():
            print('available gpus is,', torch.cuda.device_count())
            # BUG FIX: DataParallel has no `output_dim` keyword, so the
            # original call raised TypeError. NOTE(review): with >1 GPU the
            # default gather dim (0) would concatenate CRNN's (T, b, nclass)
            # output along time -- confirm whether multi-GPU inference is
            # actually required here.
            self.net = torch.nn.DataParallel(self.net).cuda()

        self.net.eval()

    def predict(self, image):
        """Recognize the text in one PIL image; return the decoded string."""
        image = self.transformer(image)
        if torch.cuda.is_available():
            image = image.cuda()
        image = image.view(1, *image.size())  # add a batch dimension
        image = Variable(image)

        # Inference only -- no autograd bookkeeping needed.
        with torch.no_grad():
            preds = self.net(image)
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True)
        sim_pred = self.converter.decode(preds.data, preds_size.data, raw=False)
        print('%-20s => %-20s' % (raw_pred, sim_pred))

        return sim_pred

    def predict_batch(self, images):
        """Recognize text in a list of PIL images.

        Images in each batch are right-padded with zeros to the widest
        image so they can be stacked. Returns the decoded strings.
        """
        N = len(images)
        n_batch = N // self.args.batch_size
        n_batch += 1 if N % self.args.batch_size else 0
        res = []
        for i in range(n_batch):
            batch = images[i*self.args.batch_size : min((i+1)*self.args.batch_size, N)]
            maxW = 0
            # Transform first so the common padded width is known.
            # (Inner index renamed from `i` to stop shadowing the batch index.)
            for j in range(len(batch)):
                batch[j] = self.transformer(batch[j])
                maxW = max(maxW, batch[j].shape[2])

            # Right-pad every tensor to the widest image in the batch.
            for j in range(len(batch)):
                if batch[j].shape[2] < maxW:
                    pad = torch.zeros(
                        (1, self.args.imgH, maxW - batch[j].shape[2]),
                        dtype=batch[j].dtype)
                    batch[j] = torch.cat((batch[j], pad), 2)
            batch_imgs = torch.cat([t.unsqueeze(0) for t in batch], 0)
            # BUG FIX: the batch tensor was left on CPU even when the model
            # sits on GPU; mirror predict()'s device handling.
            if torch.cuda.is_available():
                batch_imgs = batch_imgs.cuda()
            with torch.no_grad():
                preds = self.net(batch_imgs)
            preds_size = Variable(torch.IntTensor([preds.size(0)]*len(batch)))
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            raw_preds = self.converter.decode(preds.data, preds_size.data, raw=True)
            sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
            for raw_pred, sim_pred in zip(raw_preds, sim_preds):
                print('%-20s => %-20s' % (raw_pred, sim_pred))
            res.extend(sim_preds)
        return res

    def inference(self, image_path, batch_pred=False):
        """Recognize text for a single image file or every image in a directory.

        Returns a list of decoded strings (one per processed image).
        """
        if os.path.isdir(image_path):
            file_list = os.listdir(image_path)
            image_list = [os.path.join(image_path, i) for i in file_list if i.rsplit('.')[-1].lower() in img_types]
        else:
            image_list = [image_path]

        res = []
        images = []
        for img_path in image_list:
            image = Image.open(img_path).convert('L')  # force grayscale
            if not batch_pred:
                sim_pred = self.predict(image)
                res.append(sim_pred)
            else:
                images.append(image)
        if batch_pred and images:
            res = self.predict_batch(images)
        return res
示例#4
0
    # Load the frozen object-detection model into a fresh TF graph.
    # NOTE(review): tf.GraphDef / tf.gfile are TensorFlow 1.x APIs -- under
    # TF2 these live at tf.compat.v1; confirm the pinned TF version.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as f:
            od_graph_def.ParseFromString(f.read())
            tf.import_graph_def(od_graph_def, name='')

    # Load CRNN model
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
    # CRNN(imgH=32, nc=1 grayscale, nclass=37 = 36 chars + CTC blank,
    # hidden size 256).
    crnn = CRNN(32, 1, 37, 256)
    if torch.cuda.is_available():
        crnn = crnn.cuda()
    crnn.load_state_dict(torch.load(args.recognition_model_path))
    converter = utils.strLabelConverter(alphabet)
    transformer = dataset.resizeNormalize((100, 32))
    crnn.eval()

    # Open a video file or an image file
    cap = cv2.VideoCapture(args.input if args.input else 0)

    # Process frames until a key is pressed in the OpenCV window.
    while cv2.waitKey(1) < 0:
        has_frame, frame = cap.read()
        if not has_frame:
            cv2.waitKey(0)
            break

        im_height, im_width, _ = frame.shape
        # TF detection expects a batched image: (1, H, W, C).
        tf_frame = np.expand_dims(frame, axis=0)
        output_dict = run_inference_for_single_image(tf_frame, detection_graph)

        # Per-detection processing (loop body continues beyond this chunk).
        for i in range(output_dict['num_detections']):
示例#5
0
class Trainer(object):
    """Train a CRNN text recognizer with CTC loss.

    Configuration comes from the module-level ``args`` namespace: the
    trainer builds the network, data loaders, optimizer and label
    converter, optionally resumes from a checkpoint, and drives the
    train/validate/save loop.
    """

    def __init__(self):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank label
        nc = 1  # grayscale input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.best_acc = 0.00001  # small epsilon so the first real acc wins

        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name

        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)

            if 'model_state_dict' in checkpoint.keys():
                # Full training checkpoint: also restore progress state.
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']

            from collections import OrderedDict
            model_dict = OrderedDict()
            # Strip the 'module.' prefix DataParallel adds to saved keys.
            for k, v in checkpoint.items():
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)

        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )

        elif args.cuda and torch.cuda.is_available():
            print('available gpus is ', torch.cuda.device_count())
            # BUG FIX: DataParallel has no `output_dim` keyword, so the
            # original call raised TypeError. NOTE(review): with >1 GPU the
            # default gather dim (0) concatenates CRNN's (T, b, nclass)
            # output along time -- confirm multi-GPU training is intended.
            self.net = torch.nn.DataParallel(self.net).cuda()
            self.criterion = self.criterion.cuda()

    def dataloader(self, alphabet):
        """Build the train loader and, if a val dir exists, the val loader.

        Returns (train_dataloader, val_dataloader); val_dataloader is None
        when ``args.val_dir`` does not exist.
        """
        train_dataset = NumDataset(args.train_dir,
                                   alphabet,
                                   transform=resizeNormalize(args.imgH))
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        if os.path.exists(args.val_dir):
            val_dataset = NumDataset(args.val_dir,
                                     alphabet,
                                     mode='test',
                                     transform=resizeNormalize(args.imgH))
            val_dataloader = DataLoader(dataset=val_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        else:
            val_dataloader = None

        return train_dataloader, val_dataloader

    def get_optimizer(self):
        """Create the optimizer selected by ``args.optimizer``.

        'sgd' and 'adam' are explicit choices; anything else falls back
        to RMSprop.
        """
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                self.net.parameters(),
                lr=args.lr,
                betas=(args.beta1, 0.999),
            )
        else:
            optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        return optimizer

    def train(self):
        """Run the training loop, validating and checkpointing per args."""
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))

        losses = utils.Averager()
        train_accuracy = utils.Averager()

        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample character counts for the batch;
                # text: concatenated label indices for all samples.
                text, length = self.converter.encode(labels)
                # CTC needs the output length (timesteps) per sample.
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()

                losses.update(loss_avg.item(), batch_size)

                # Greedy (best-path) decoding for training accuracy.
                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)

                if args.log_interval and not (i + 1) % args.log_interval:
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader),
                                batch_size / (time.time() - btic),
                                losses.val()))
                    losses.reset()
                    # BUG FIX: btic was set once per epoch, so the reported
                    # speed shrank with every interval; restart the timer.
                    btic = time.time()

            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()

            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)
                # BUG FIX: was logging.info (root logger), bypassing the
                # configured file handler; use the module's logger.
                logger.info("best acc is:{:.3f}".format(self.best_acc))
                # NOTE(review): the periodic save is nested inside the
                # val_interval block, so it only fires on validation epochs
                # -- preserved as-is; confirm this is intended.
                if args.save_interval and not (epoch + 1) % args.save_interval:
                    save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                        args.save_prefix, epoch + 1, acc)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)

    def validate(self, logger):
        """Evaluate on the validation set; return exact-match accuracy.

        Returns 0 when no validation loader is configured or the loader
        yields no samples.
        """
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(self.val_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample character counts for the batch;
                # text: concatenated label indices for all samples.
                text, length = self.converter.encode(labels)
                # timesteps * batchsize
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                losses.update(loss_avg.item(), batch_size)

                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1

        # Guard against an empty loader (division by zero).
        if not losses.count:
            return 0
        accuracy = n_correct / float(losses.count)

        logger.info(
            'Evaling loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))

        return accuracy