Пример #1
0
def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size,
                   workers, is_train, keep_ratio):
    """Build a loader that mixes real samples into a synthetic training set.

    The two datasets are concatenated (synthetic first, real second). A random
    subset of ``len(real_dataset)`` synthetic samples is dropped, and every real
    sample is added, so the sampler yields exactly ``len(synthetic_dataset)``
    indices per epoch. The choice of dropped synthetic samples is made once,
    here, and is fixed for the lifetime of the returned loader.

    Args:
        synthetic_dataset: map-style dataset of synthetic samples.
        real_dataset: map-style dataset of real samples; must not contain more
            samples than ``synthetic_dataset``.
        height, width: target image size passed to ``AlignCollate``.
        batch_size: samples per batch.
        workers: number of DataLoader worker processes.
        is_train: unused; kept so the signature matches the other loader
            builders.
        keep_ratio: passed to ``AlignCollate``.

    Returns:
        ``(concated_dataset, data_loader)``.

    Raises:
        ValueError: if ``real_dataset`` is larger than ``synthetic_dataset``.
    """
    num_synthetic_dataset = len(synthetic_dataset)
    num_real_dataset = len(real_dataset)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently produce a longer epoch.
    if num_real_dataset > num_synthetic_dataset:
        raise ValueError(
            'real_dataset ({}) must not be larger than synthetic_dataset ({})'
            .format(num_real_dataset, num_synthetic_dataset))

    # Randomly drop `num_real_dataset` synthetic indices to make room for the
    # real samples.
    synthetic_indices = list(np.random.permutation(num_synthetic_dataset))
    synthetic_indices = synthetic_indices[num_real_dataset:]
    # Inside the ConcatDataset the real samples come after the synthetic ones,
    # so their indices are offset by the synthetic dataset length.
    real_indices = list(
        np.random.permutation(num_real_dataset) + num_synthetic_dataset)
    concated_indices = synthetic_indices + real_indices

    sampler = SubsetRandomSampler(concated_indices)
    concated_dataset = ConcatDataset([synthetic_dataset, real_dataset])
    print('total image: ', len(concated_dataset))

    # `shuffle` must stay False: ordering is handled by the sampler.
    data_loader = DataLoader(concated_dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=False,
                             pin_memory=True,
                             drop_last=True,
                             sampler=sampler,
                             collate_fn=AlignCollate(imgH=height,
                                                     imgW=width,
                                                     keep_ratio=keep_ratio))
    return concated_dataset, data_loader
    def forward(self, image_path, coordinates):
        """Recognize the text in each given region of one image.

        @input
        image_path : one image path, without '.xml' or '.png'
        coordinates : a list of region coordinates, passed to ``crop_image``

        @output : a list with one entry per region, as produced by
                  ``self.idx2char`` from the beam-search prediction
        """
        import torch  # local import: only needed for torch.no_grad()

        args = self.args
        encoder = self.encoder
        decoder = self.decoder
        if args.cuda:
            device = self.device

        image = get_data_image(image_path, args)

        cropped_images = crop_image(image,
                                    coordinates,
                                    args,
                                    resample=Image.BICUBIC)  # list of imgs

        # Wrap each crop in a sample dict; the label fields are placeholders
        # because only the images are used here.
        # NOTE(review): assumes this is the sample format AlignCollate
        # consumes — confirm against AlignCollate.
        cropped_images = [{
            'images': x,
            'rec_targets': 0,
            'rec_lengths': 0
        } for x in cropped_images]

        test_loader = DataLoader(cropped_images,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 collate_fn=AlignCollate(imgH=args.height,
                                                         imgW=args.width,
                                                         keep_ratio=True))

        test_pred = []
        # Inference only: no autograd graph is needed, so don't build one.
        with torch.no_grad():
            for batch in test_loader:
                x = batch[0]
                if args.cuda:
                    x = x.to(device)

                encoder_feats = encoder(x)
                rec_pred, _ = decoder.beam_search(encoder_feats,
                                                  args.beam_width, args.eos)
                test_pred.extend(rec_pred.detach().cpu().numpy())

        return [self.idx2char(x, self.id2char_dict) for x in test_pred]
Пример #3
0
def get_data_txt(data_dir,
                 gt_file_path,
                 embed_dir,
                 voc_type,
                 max_len,
                 num_samples,
                 height,
                 width,
                 batch_size,
                 workers,
                 is_train,
                 keep_ratio):
    """Create a ``CustomDataset`` (or a concatenation of them) and its loader.

    Args:
        data_dir: a single data directory, or a list of directories.
        gt_file_path: ground-truth file for ``data_dir``; when ``data_dir`` is
            a list, a list/tuple with one entry per directory.
        embed_dir: embedding directory for ``data_dir``; when ``data_dir`` is
            a list, a list/tuple with one entry per directory.
        voc_type, max_len, num_samples: forwarded to ``CustomDataset``.
        height, width, keep_ratio: forwarded to ``AlignCollate``.
        batch_size: samples per batch.
        workers: number of DataLoader worker processes.
        is_train: if true, shuffle and drop the last incomplete batch.

    Returns:
        ``(dataset, data_loader)``.

    Raises:
        ValueError: if ``data_dir`` is an empty list, or if ``gt_file_path`` /
            ``embed_dir`` are not lists of the same length as ``data_dir``.
    """
    if isinstance(data_dir, list):
        if not data_dir:
            raise ValueError('data_dir must not be an empty list')
        # zip() would silently truncate mismatched lists (or iterate the
        # characters of a plain string), so validate the lengths up front.
        if not (isinstance(gt_file_path, (list, tuple))
                and isinstance(embed_dir, (list, tuple))
                and len(gt_file_path) == len(data_dir) == len(embed_dir)):
            raise ValueError(
                'when data_dir is a list, gt_file_path and embed_dir must be '
                'lists of the same length')
        dataset_list = [
            CustomDataset(data_dir_, gt_file_, embed_dir_, voc_type, max_len,
                          num_samples)
            for data_dir_, gt_file_, embed_dir_ in zip(data_dir, gt_file_path,
                                                       embed_dir)
        ]
        # A single directory needs no ConcatDataset wrapper.
        if len(dataset_list) == 1:
            dataset = dataset_list[0]
        else:
            dataset = ConcatDataset(dataset_list)
    else:
        dataset = CustomDataset(data_dir, gt_file_path,
                                embed_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    # Training shuffles and drops the last incomplete batch; evaluation keeps
    # the original order and every sample.
    training = bool(is_train)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=training,
                             pin_memory=True,
                             drop_last=training,
                             collate_fn=AlignCollate(imgH=height,
                                                     imgW=width,
                                                     keep_ratio=keep_ratio))

    return dataset, data_loader
Пример #4
0
def get_data_lmdb(data_dir, voc_type,
                  max_len, num_samples,
                  height, width,
                  batch_size, workers,
                  is_train, keep_ratio,
                  voc_file=None):
    """Build an LMDB-backed dataset and a DataLoader that batches it.

    ``data_dir`` may be one LMDB root or a list of roots. A list yields a
    ``ConcatDataset`` over one ``LmdbDataset`` per root, in list order. With
    ``is_train`` set, the loader shuffles and drops the final partial batch;
    otherwise it keeps the order and every sample.

    Returns:
        ``(dataset, data_loader)``.
    """
    def make_dataset(root):
        # One LmdbDataset per root, all sharing the same settings.
        return LmdbDataset(root, voc_type, max_len, num_samples,
                           voc_file=voc_file)

    if isinstance(data_dir, list):
        dataset = ConcatDataset([make_dataset(root) for root in data_dir])
    else:
        dataset = make_dataset(data_dir)
    print('total image: ', len(dataset))

    training = bool(is_train)
    collate = AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=training,
                             pin_memory=True,
                             drop_last=training,
                             collate_fn=collate)

    return dataset, data_loader
Пример #5
0
def get_data(data_dir,
             voc_type,
             max_len,
             num_samples,
             height,
             width,
             batch_size,
             workers,
             is_train,
             keep_ratio,
             augment=False):
    """Build an LMDB dataset, with optional augmentation, plus its DataLoader.

    With ``augment`` set, every ``LmdbDataset`` gets an albumentations
    pipeline: RGB shift, brightness/contrast and optical distortion, each
    applied with probability 0.5. A list ``data_dir`` is concatenated in
    order. ``is_train`` drives both shuffling and dropping the last partial
    batch.

    Returns:
        ``(dataset, data_loader)``.
    """
    if augment:
        transform = albu.Compose([
            albu.RGBShift(p=0.5),
            albu.RandomBrightnessContrast(p=0.5),
            albu.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=0.5)
        ])
    else:
        transform = None

    if not isinstance(data_dir, list):
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples,
                              transform)
    else:
        parts = [
            LmdbDataset(root, voc_type, max_len, num_samples, transform)
            for root in data_dir
        ]
        dataset = ConcatDataset(parts)
    print('total image: ', len(dataset))

    collate = AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=is_train,
                             pin_memory=True,
                             drop_last=is_train,
                             collate_fn=collate)

    return dataset, data_loader
Пример #6
0
def _load_or_sample_indices(n_all_samples, n_max_samples,
                            cache_file='.sample_indices.cache.pkl'):
    """Return ``n_max_samples`` distinct indices into ``range(n_all_samples)``.

    The indices are pickled to ``cache_file`` so repeated runs train on the
    same subset. A cached array is reused only if it still fits the request:
    a 1-D integer array of length ``n_max_samples`` with distinct values, all
    in range. Otherwise a new subset is drawn and the cache is overwritten.
    """
    if os.path.exists(cache_file):
        # NOTE: pickle.load can execute code stored in the file; this expects
        # the cache to have been written locally by this function.
        with open(cache_file, 'rb') as fin:
            cached = np.asarray(pickle.load(fin))
        if (cached.ndim == 1 and cached.shape[0] == n_max_samples
                and np.issubdtype(cached.dtype, np.integer)
                and cached.min() >= 0 and cached.max() < n_all_samples
                and np.unique(cached).shape[0] == n_max_samples):
            print('load sample indices from sample_indices_cache_file: ',
                  n_max_samples)
            return cached
        print('sample indices cache does not match the request; resampling')

    sample_indices = np.random.choice(n_all_samples,
                                      n_max_samples,
                                      replace=False)
    with open(cache_file, 'wb') as fout:
        pickle.dump(sample_indices, fout)
    print('random sample: ', n_max_samples)
    return sample_indices


def get_data(data_dir,
             voc_type,
             max_len,
             num_samples,
             height,
             width,
             batch_size,
             workers,
             is_train,
             keep_ratio,
             n_max_samples=-1):
    """Build an LMDB dataset and a DataLoader, optionally on a fixed subset.

    Args:
        data_dir: one LMDB root or a list of roots (concatenated in order).
        voc_type, max_len, num_samples: forwarded to ``LmdbDataset``.
        height, width, keep_ratio: forwarded to ``AlignCollate``.
        batch_size: samples per batch.
        workers: number of DataLoader worker processes.
        is_train: training loader (shuffled / subset-sampled, drops the last
            partial batch) vs. evaluation loader (in order, all samples).
        n_max_samples: if > 0, the training loader draws only from a random
            subset of this many samples. The subset is cached on disk so it is
            the same across runs. Ignored by the evaluation loader.

    Returns:
        ``(dataset, data_loader)``.

    Raises:
        ValueError: if ``n_max_samples`` exceeds the dataset size.
    """
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(
                LmdbDataset(data_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    if n_max_samples > 0:
        n_all_samples = len(dataset)
        # Explicit check instead of `assert` (stripped under `python -O`).
        if n_max_samples > n_all_samples:
            raise ValueError(
                'n_max_samples ({}) exceeds the dataset size ({})'.format(
                    n_max_samples, n_all_samples))
        # Make sample indices static for every run.
        sample_indices = _load_or_sample_indices(n_all_samples, n_max_samples)
        sub_sampler = SubsetRandomSampler(sample_indices)
    else:
        sub_sampler = None

    if is_train:
        # DataLoader forbids shuffle=True together with a sampler; the
        # SubsetRandomSampler already randomizes the order.
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=workers,
            sampler=sub_sampler,
            shuffle=sub_sampler is None,
            pin_memory=True,
            drop_last=True,
            collate_fn=AlignCollate(imgH=height,
                                    imgW=width,
                                    keep_ratio=keep_ratio))
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 shuffle=False,
                                 pin_memory=True,
                                 drop_last=False,
                                 collate_fn=AlignCollate(
                                     imgH=height,
                                     imgW=width,
                                     keep_ratio=keep_ratio))

    return dataset, data_loader
def _labels_match(label, pred, eos):
    """Return True if ``pred`` equals ``label`` up to the label's first ``eos``.

    The compared length is the number of tokens in ``label`` before its first
    ``eos`` token (its full length if there is none). ``pred`` is cut to the
    same length, so tokens predicted after that point are ignored.
    """
    eos_idx = 0
    for token in label:
        if token == eos:
            break
        eos_idx += 1
    return np.array_equal(label[:eos_idx], pred[:eos_idx])


def main_aster():
    """Train (unless ``args.eval``) and then evaluate the ASTER recognizer.

    Settings come from ``pred_params.Get_ocr_args()``. Train and test samples
    are built with ``Create_data_list``.
    - Training mode: weights start from ./data/demo.pth.tar (when
      ``args.use_pretrained``) or from Xavier-initialized Linear layers. Final
      weights are saved under ./params.
    - Eval mode: weights are loaded from params/encoder_final and
      params/decoder_final.
    The test set is always evaluated. Per-batch and overall sequence accuracy
    are printed, and predictions are pickled to aster_pred.pkl when
    ``args.save_preds`` is set.
    """
    import os

    from pred_params import Get_ocr_args
    args = Get_ocr_args()

    print('Evaluation : ' + str(args.eval))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    if args.cuda:
        print('using cuda.')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print('using cpu.')
        torch.set_default_tensor_type('torch.FloatTensor')

    #   Create Character dict & max seq len
    args, char2id_dict, id2char_dict = Create_char_dict(args)

    print(id2char_dict)
    rec_num_classes = len(id2char_dict)

    #   Get rec num classes / max len
    print('max len : ' + str(args.max_len))

    #   Create train / test sample lists
    train_list, char2id_dict, id2char_dict, args = Create_data_list(
        args, char2id_dict, id2char_dict, True)
    test_list, char2id_dict, id2char_dict, args = Create_data_list(
        args, char2id_dict, id2char_dict, False)

    encoder = ResNet_ASTER(with_lstm=True,
                           n_group=args.n_group,
                           use_cuda=args.cuda)

    decoder = AttentionRecognitionHead(num_classes=rec_num_classes,
                                       in_planes=encoder.out_planes,
                                       sDim=args.decoder_sdim,
                                       attDim=args.attDim,
                                       max_len_labels=args.max_len,
                                       use_cuda=args.cuda)

    # Remap checkpoint tensors to CPU when not using CUDA, so GPU-saved
    # checkpoints can still be loaded.
    map_location = None if args.cuda else 'cpu'

    #   Load weights
    if not args.eval:
        if args.use_pretrained:
            pretrain_path = './data/demo.pth.tar'
            pretrained_dict = torch.load(
                pretrain_path, map_location=map_location)['state_dict']
            # Split the combined checkpoint into encoder / decoder state dicts,
            # dropping the first dotted component of each key.
            encoder_dict = {}
            decoder_dict = {}
            for key, value in pretrained_dict.items():
                if 'encoder' in key:
                    encoder_dict['.'.join(key.split('.')[1:])] = value
                elif 'decoder' in key:
                    decoder_dict['.'.join(key.split('.')[1:])] = value
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            print('pretrained model loaded')

        else:
            #   init model parameters
            def init_weights(m):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight)

            encoder.apply(init_weights)
            decoder.apply(init_weights)
            print('Random weight initialized!')

    else:
        #   no training: load the fine-tuned weights
        encoder.load_state_dict(
            torch.load('params/encoder_final', map_location=map_location))
        decoder.load_state_dict(
            torch.load('params/decoder_final', map_location=map_location))
        print('fine-tuned model loaded')

    rec_crit = SequenceCrossEntropyLoss()

    device = torch.device('cuda' if args.cuda else 'cpu')
    encoder.to(device)
    decoder.to(device)

    # NOTE(review): only encoder parameters are optimized, so the decoder is
    # never updated (even when randomly initialized). Confirm this is
    # intended; otherwise also pass decoder.parameters().
    param_groups = filter(lambda p: p.requires_grad, encoder.parameters())
    optimizer = torch.optim.Adadelta(param_groups,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    # NOTE(review): scheduler.step() is never called, so this schedule does
    # not change the learning rate.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[4, 5],
                                                     gamma=0.1)

    test_proba = []
    test_pred = []
    test_label = []
    test_image = []

    train_loader = DataLoader(train_list,
                              batch_size=args.batch_size,
                              shuffle=False,
                              collate_fn=AlignCollate(imgH=args.height,
                                                      imgW=args.width,
                                                      keep_ratio=True))
    test_loader = DataLoader(test_list,
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=AlignCollate(imgH=args.height,
                                                     imgW=args.width,
                                                     keep_ratio=True))

    if not args.eval:
        for _ in range(args.n_epochs):
            for batch_idx, batch in enumerate(train_loader):

                x, rec_targets, rec_lengths = batch[0], batch[1], batch[2]

                x = x.to(device)
                encoder_feats = encoder(x)  # bs x w x C
                rec_pred = decoder([encoder_feats, rec_targets, rec_lengths])
                loss_rec = rec_crit(rec_pred, rec_targets, rec_lengths)

                if batch_idx == 0:
                    print('train Loss : ' + str(loss_rec))
                    rec_pred_idx = np.argmax(rec_pred.detach().cpu().numpy(),
                                             axis=-1)
                    print(rec_pred[:3])
                    print(rec_pred_idx[:5])

                optimizer.zero_grad()
                loss_rec.backward()
                optimizer.step()

        # Make sure the output folder exists so saving cannot fail after the
        # whole training run.
        os.makedirs('params', exist_ok=True)
        # NOTE(review): CPU runs save *_cpu files, while eval mode above loads
        # params/encoder_final and params/decoder_final.
        suffix = '' if args.cuda else '_cpu'
        torch.save(encoder.state_dict(), 'params/encoder_final' + suffix)
        torch.save(decoder.state_dict(), 'params/decoder_final' + suffix)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in test_loader:

            x, rec_targets = batch[0], batch[1]

            # Same device move as the training loop. Without it, CUDA
            # evaluation fails with a device mismatch if the collate function
            # returns CPU tensors.
            x = x.to(device)
            encoder_feats = encoder(x)
            rec_pred, rec_pred_scores = decoder.beam_search(
                encoder_feats, args.beam_width, args.eos)

            rec_pred = rec_pred.detach().cpu().numpy()
            rec_targets = rec_targets.numpy()
            print('predictions')
            print(rec_pred[:5])
            print('label')
            print(rec_targets[:5])
            test_proba.extend(rec_pred_scores)
            test_pred.extend(rec_pred)
            test_label.extend(rec_targets)
            test_image.extend(x.detach().cpu().numpy())

            # Per-batch sequence accuracy, using the same EOS-truncated
            # comparison as the final score.
            pairs = list(zip(rec_targets, rec_pred))
            if pairs:
                hit = sum(_labels_match(label, pred, args.eos)
                          for label, pred in pairs)
                print("batch accuracy=", hit / len(pairs))

    if args.save_preds:
        with open('aster_pred.pkl', 'wb') as f:
            pickle.dump([
                test_label, test_pred, test_proba, char2id_dict, id2char_dict,
                test_image
            ], f)

    def get_score(test_label, test_pred):
        """Print the sequence accuracy of ``test_pred`` against ``test_label``.

        The EOS id is ``args.eos``, the same value passed to beam_search.
        """
        total_n = len(test_label)
        true_n = sum(
            _labels_match(label, test_pred[i], args.eos)
            for i, label in enumerate(test_label))
        print('Accuracy')
        if total_n == 0:
            print('no test samples')
        else:
            print(true_n / total_n)

    get_score(test_label, test_pred)
def main_aster(folder_name):
    """Train the ASTER recognizer on every sample found under ./data/<folder_name>.

    @Input
    folder_name : name of the folder (under ./data) whose sub-folders contain
                  the training '.xml' annotation files.

    @Output
    trained parameters are stored in the 'params' folder
    (params/encoder_final and params/decoder_final).

    Starting weights depend on ``args`` (from ``pred_params.Get_ocr_args``):
    - ``args.eval`` set: the weights already in ./params are loaded (and then
      trained further).
    - otherwise, ``args.use_pretrained`` set: ./data/demo.pth.tar is loaded.
    - otherwise: Linear layers are Xavier-initialized.
    """
    import os

    #   arguments are stored in pred_params.py
    from pred_params import Get_ocr_args
    args = Get_ocr_args()

    print('Evaluation : ' + str(args.eval))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    if args.cuda:
        print('using cuda.')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print('using cpu.')
        torch.set_default_tensor_type('torch.FloatTensor')

    #   Create Character dict & max seq len
    args, char2id_dict, id2char_dict = Create_char_dict(args)

    print(id2char_dict)
    rec_num_classes = len(id2char_dict)

    #   Get rec num classes / max len
    print('max len : ' + str(args.max_len))

    #   Get file list for train set: every '.xml' one level below the folder,
    #   with the 4-character '.xml' suffix stripped.
    filenames = glob.glob(os.path.join('.', 'data', folder_name, '*', '*.xml'))
    filenames = [x[:-4] for x in filenames]
    print('file len : ' + str(len(filenames)))

    #   files are not split into train/valid sets.
    train_list = Create_data_list_byfolder(args, char2id_dict, id2char_dict,
                                           filenames)

    encoder = ResNet_ASTER(with_lstm=True,
                           n_group=args.n_group,
                           use_cuda=args.cuda)

    decoder = AttentionRecognitionHead(num_classes=rec_num_classes,
                                       in_planes=encoder.out_planes,
                                       sDim=args.decoder_sdim,
                                       attDim=args.attDim,
                                       max_len_labels=args.max_len,
                                       use_cuda=args.cuda)

    #   Load pretrained weights
    if not args.eval:
        if args.use_pretrained:
            #   use pretrained model
            pretrain_path = './data/demo.pth.tar'
            if args.cuda:
                pretrained_dict = torch.load(pretrain_path)['state_dict']
            else:
                pretrained_dict = torch.load(pretrain_path,
                                             map_location='cpu')['state_dict']

            # Split the combined checkpoint into encoder / decoder state dicts,
            # dropping the first dotted component of each key.
            encoder_dict = {}
            decoder_dict = {}
            for key, value in pretrained_dict.items():
                if 'encoder' in key:
                    encoder_dict['.'.join(key.split('.')[1:])] = value
                elif 'decoder' in key:
                    decoder_dict['.'.join(key.split('.')[1:])] = value
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            print('pretrained model loaded')

        else:
            #   init model parameters
            def init_weights(m):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform_(m.weight)

            encoder.apply(init_weights)
            decoder.apply(init_weights)
            print('Random weight initialized!')

    else:
        #   loading previously trained parameters
        if args.cuda:
            encoder.load_state_dict(torch.load('params/encoder_final'))
            decoder.load_state_dict(torch.load('params/decoder_final'))
        else:
            encoder.load_state_dict(
                torch.load('params/encoder_final',
                           map_location=torch.device('cpu')))
            decoder.load_state_dict(
                torch.load('params/decoder_final',
                           map_location=torch.device('cpu')))
        print('fine-tuned model loaded')

    #   Training Phase

    rec_crit = SequenceCrossEntropyLoss()

    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    encoder.to(device)
    decoder.to(device)

    # NOTE(review): only encoder parameters are optimized, so the decoder is
    # never updated (even when randomly initialized). Confirm this is
    # intended; otherwise also pass decoder.parameters().
    param_groups = filter(lambda p: p.requires_grad, encoder.parameters())
    optimizer = torch.optim.Adadelta(param_groups,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    # NOTE(review): scheduler.step() is never called, so this schedule does
    # not change the learning rate.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[4, 5],
                                                     gamma=0.1)

    train_loader = DataLoader(train_list,
                              batch_size=args.batch_size,
                              shuffle=False,
                              collate_fn=AlignCollate(imgH=args.height,
                                                      imgW=args.width,
                                                      keep_ratio=True))

    for _ in range(args.n_epochs):
        for batch_idx, batch in enumerate(train_loader):

            x, rec_targets, rec_lengths = batch[0], batch[1], batch[2]

            x = x.to(device)
            encoder_feats = encoder(x)  # bs x w x C
            rec_pred = decoder([encoder_feats, rec_targets, rec_lengths])
            loss_rec = rec_crit(rec_pred, rec_targets, rec_lengths)

            if batch_idx == 0:
                print('train Loss : ' + str(loss_rec))
                rec_pred_idx = np.argmax(rec_pred.detach().cpu().numpy(),
                                         axis=-1)
                print(rec_pred[:3])
                print(rec_pred_idx[:5])

            optimizer.zero_grad()
            loss_rec.backward()
            optimizer.step()

    #   Training phase ends

    #   Save trained parameters. Create the folder first so the save cannot
    #   fail after the whole training run.
    os.makedirs('params', exist_ok=True)
    torch.save(encoder.state_dict(), 'params/encoder_final')
    torch.save(decoder.state_dict(), 'params/decoder_final')