def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size,
                   workers, is_train, keep_ratio):
    """Build one DataLoader over synthetic + real data with a fixed epoch size.

    One randomly chosen synthetic sample is dropped for every real sample, so
    the sampled epoch always contains exactly ``len(synthetic_dataset)`` items.
    Returns ``(concat_dataset, data_loader)``.
    """
    n_synth = len(synthetic_dataset)
    n_real = len(real_dataset)
    # Random order over synthetic samples, then drop the first n_real of them
    # to make room for the real samples.
    synth_idx = list(np.random.permutation(n_synth))[n_real:]
    # Real samples come after the synthetic ones inside the ConcatDataset,
    # hence the n_synth offset.
    real_idx = list(np.random.permutation(n_real) + n_synth)
    chosen = synth_idx + real_idx
    assert len(chosen) == n_synth
    sampler = SubsetRandomSampler(chosen)
    merged = ConcatDataset([synthetic_dataset, real_dataset])
    print('total image: ', len(merged))
    data_loader = DataLoader(
        merged,
        batch_size=batch_size,
        num_workers=workers,
        shuffle=False,  # ordering is driven by the sampler
        pin_memory=True,
        drop_last=True,
        sampler=sampler,
        collate_fn=AlignCollate(imgH=height, imgW=width,
                                keep_ratio=keep_ratio))
    return merged, data_loader
def forward(self, image_path, coordinates):
    """Recognize the text inside cropped regions of one image.

    @input
        image_path : one image path without '.xml' or '.png'
        coordinates: a list of crop coordinates
    @output
        a list of predicted character strings, one per coordinate box
    """
    args = self.args
    # Bind the device up front instead of only inside the cuda branch; tensors
    # stay on CPU when CUDA is disabled, matching the original behavior.
    device = self.device if args.cuda else None
    image = get_data_image(image_path, args)
    # One cropped PIL image per coordinate box.
    cropped_images = crop_image(image, coordinates, args,
                                resample=Image.BICUBIC)
    # Wrap each crop in the dict layout AlignCollate expects; targets and
    # lengths are dummies since this is inference only.
    cropped_images = [{'images': x, 'rec_targets': 0, 'rec_lengths': 0}
                      for x in cropped_images]
    test_pred = []
    test_image = []
    test_loader = DataLoader(cropped_images,
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=AlignCollate(imgH=args.height,
                                                     imgW=args.width,
                                                     keep_ratio=True))
    for batch in test_loader:
        x = batch[0].to(device) if args.cuda else batch[0]
        encoder_feats = self.encoder(x)
        rec_pred, rec_pred_scores = self.decoder.beam_search(
            encoder_feats, args.beam_width, args.eos)
        rec_pred = rec_pred.detach().cpu().numpy()
        test_pred.extend(rec_pred)
        test_image.extend(x.detach().cpu().numpy())
    # Map predicted class-id sequences back to character strings.
    test_pred_char = [
        self.idx2char(x, self.id2char_dict) for x in test_pred
    ]
    return test_pred_char
def get_data_txt(data_dir, gt_file_path, embed_dir, voc_type, max_len,
                 num_samples, height, width, batch_size, workers, is_train,
                 keep_ratio):
    """Build a CustomDataset (or a concat of several) plus its DataLoader.

    ``data_dir`` / ``gt_file_path`` / ``embed_dir`` may each be a single path
    or parallel lists; lists longer than one produce one dataset per triple,
    concatenated together. Returns ``(dataset, data_loader)``.
    """
    if isinstance(data_dir, list) and len(data_dir) > 1:
        dataset = ConcatDataset([
            CustomDataset(data_dir_, gt_file_, embed_dir_, voc_type, max_len,
                          num_samples)
            for data_dir_, gt_file_, embed_dir_ in zip(data_dir, gt_file_path,
                                                       embed_dir)
        ])
    else:
        dataset = CustomDataset(data_dir, gt_file_path, embed_dir, voc_type,
                                max_len, num_samples)
    print('total image: ', len(dataset))
    # Training shuffles and drops the ragged last batch; evaluation keeps the
    # original order and every sample. This replaces two branches that were
    # identical except for shuffle/drop_last (plus a dead commented-out copy).
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=is_train,
                             pin_memory=True,
                             drop_last=is_train,
                             collate_fn=AlignCollate(imgH=height, imgW=width,
                                                     keep_ratio=keep_ratio))
    return dataset, data_loader
def get_data_lmdb(data_dir, voc_type, max_len, num_samples, height, width,
                  batch_size, workers, is_train, keep_ratio, voc_file=None):
    """Build an LmdbDataset (or a concat of several) plus its DataLoader.

    ``data_dir`` may be a single path or a list of paths. Returns
    ``(dataset, data_loader)``.
    """
    if isinstance(data_dir, list):
        dataset = ConcatDataset([
            LmdbDataset(data_dir_, voc_type, max_len, num_samples,
                        voc_file=voc_file)
            for data_dir_ in data_dir
        ])
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples,
                              voc_file=voc_file)
    print('total image: ', len(dataset))
    # Training shuffles and drops the ragged last batch; evaluation keeps the
    # original order and every sample. Merges two branches that differed only
    # in shuffle/drop_last, matching the style of get_data() above.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=is_train,
                             pin_memory=True,
                             drop_last=is_train,
                             collate_fn=AlignCollate(imgH=height, imgW=width,
                                                     keep_ratio=keep_ratio))
    return dataset, data_loader
def get_data(data_dir, voc_type, max_len, num_samples, height, width,
             batch_size, workers, is_train, keep_ratio, augment=False):
    """Build an LmdbDataset (optionally photometrically augmented) + loader.

    When ``augment`` is true, an albumentations pipeline (RGB shift,
    brightness/contrast, optical distortion) is attached to the dataset.
    Returns ``(dataset, data_loader)``.
    """
    if augment:
        transform = albu.Compose([
            albu.RGBShift(p=0.5),
            albu.RandomBrightnessContrast(p=0.5),
            albu.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=0.5),
        ])
    else:
        transform = None
    if isinstance(data_dir, list):
        parts = []
        for single_dir in data_dir:
            parts.append(
                LmdbDataset(single_dir, voc_type, max_len, num_samples,
                            transform))
        dataset = ConcatDataset(parts)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples,
                              transform)
    print('total image: ', len(dataset))
    # shuffle/drop_last follow the train/eval switch.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=is_train,
                             pin_memory=True,
                             drop_last=is_train,
                             collate_fn=AlignCollate(imgH=height, imgW=width,
                                                     keep_ratio=keep_ratio))
    return dataset, data_loader
def get_data(data_dir, voc_type, max_len, num_samples, height, width,
             batch_size, workers, is_train, keep_ratio, n_max_samples=-1):
    """Build an LmdbDataset (or concat) + loader, optionally sub-sampled.

    When ``n_max_samples`` > 0 a fixed random subset of that size is drawn and
    cached on disk (``.sample_indices.cache.pkl``) so every run trains on the
    same subset. Returns ``(dataset, data_loader)``.
    """
    if isinstance(data_dir, list):
        dataset = ConcatDataset([
            LmdbDataset(data_dir_, voc_type, max_len, num_samples)
            for data_dir_ in data_dir
        ])
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    sub_sampler = None
    if n_max_samples > 0:
        n_all_samples = len(dataset)
        assert n_max_samples < n_all_samples
        # Make sample indices static for every run.
        sample_indices_cache_file = '.sample_indices.cache.pkl'
        sample_indices = None
        if os.path.exists(sample_indices_cache_file):
            with open(sample_indices_cache_file, 'rb') as fin:
                sample_indices = pickle.load(fin)
            # BUG FIX: a stale cache written for a different dataset or a
            # different n_max_samples used to be loaded blindly, silently
            # training on the wrong subset (or crashing on out-of-range
            # indices). Discard the cache when it no longer matches.
            if (len(sample_indices) != n_max_samples
                    or int(np.max(sample_indices)) >= n_all_samples):
                sample_indices = None
            else:
                print('load sample indices from sample_indices_cache_file: ',
                      n_max_samples)
        if sample_indices is None:
            sample_indices = np.random.choice(n_all_samples, n_max_samples,
                                              replace=False)
            with open(sample_indices_cache_file, 'wb') as fout:
                pickle.dump(sample_indices, fout)
            print('random sample: ', n_max_samples)
        sub_sampler = SubsetRandomSampler(sample_indices)

    if is_train:
        # sampler and shuffle are mutually exclusive in DataLoader.
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 sampler=sub_sampler,
                                 shuffle=(sub_sampler is None),
                                 pin_memory=True,
                                 drop_last=True,
                                 collate_fn=AlignCollate(
                                     imgH=height, imgW=width,
                                     keep_ratio=keep_ratio))
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 shuffle=False,
                                 pin_memory=True,
                                 drop_last=False,
                                 collate_fn=AlignCollate(
                                     imgH=height, imgW=width,
                                     keep_ratio=keep_ratio))
    return dataset, data_loader
def main_aster():
    """Train and/or evaluate the ASTER encoder/decoder OCR model.

    Behavior is driven entirely by the args object from pred_params:
    when ``args.eval`` is False the model is trained (from the pretrained
    checkpoint or random init) and saved under ``params/``; the test loop
    then runs in either case, optionally pickling predictions.
    """
    # from config import get_args
    # args = get_args(sys.argv[1:])
    from pred_params import Get_ocr_args
    args = Get_ocr_args()
    print('Evaluation : ' + str(args.eval))
    # Seed every RNG for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    # args.cuda = True and torch.cuda.is_available()
    if args.cuda:
        print('using cuda.')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print('using cpu.')
        torch.set_default_tensor_type('torch.FloatTensor')
    # Create Character dict & max seq len
    args, char2id_dict, id2char_dict = Create_char_dict(args)
    print(id2char_dict)
    rec_num_classes = len(id2char_dict)  # Get rec num classes / max len
    print('max len : ' + str(args.max_len))
    # Create data list
    # train_list, args = Create_data_list(args, char2id_dict, True)
    # test_list, args = Create_data_list(args, char2id_dict, False)
    train_list, char2id_dict, id2char_dict, args = Create_data_list(
        args, char2id_dict, id2char_dict, True)
    test_list, char2id_dict, id2char_dict, args = Create_data_list(
        args, char2id_dict, id2char_dict, False)
    # CNN+LSTM feature extractor and attention-based sequence decoder.
    encoder = ResNet_ASTER(with_lstm=True,
                           n_group=args.n_group,
                           use_cuda=args.cuda)
    encoder_out_planes = encoder.out_planes
    decoder = AttentionRecognitionHead(num_classes=rec_num_classes,
                                       in_planes=encoder_out_planes,
                                       sDim=args.decoder_sdim,
                                       attDim=args.attDim,
                                       max_len_labels=args.max_len,
                                       use_cuda=args.cuda)
    # if rectification is on
    """
    if args.STN_ON:
        self.tps = TPSSpatialTransformer(
            output_image_size = tuple(args.global_args.tps_outputsize),
            num_control_points = args.num_control_points,
    """
    # Load pretrained weights
    if not args.eval:
        if args.use_pretrained:
            # use pretrained model; checkpoint keys are prefixed with
            # 'encoder.'/'decoder.', so strip the first dotted component
            # before loading into each sub-model.
            pretrain_path = './data/demo.pth.tar'
            pretrained_dict = torch.load(pretrain_path)['state_dict']
            encoder_dict = {}
            decoder_dict = {}
            for i, x in enumerate(pretrained_dict.keys()):
                if 'encoder' in x:
                    encoder_dict['.'.join(
                        x.split('.')[1:])] = pretrained_dict[x]
                elif 'decoder' in x:
                    decoder_dict['.'.join(
                        x.split('.')[1:])] = pretrained_dict[x]
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            print('pretrained model loaded')
        else:
            # init model parameters
            # NOTE(review): xavier_uniform is deprecated in favor of
            # xavier_uniform_ in recent torch versions.
            def init_weights(m):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform(m.weight)
                    #m.bias.data.fill_(0.01)
            encoder.apply(init_weights)
            decoder.apply(init_weights)
            print('Random weight initialized!')
    else:
        # no training
        # encoder.load_state_dict(torch.load('../params/encoder_final'))
        # decoder.load_state_dict(torch.load('../params/decoder_final'))
        encoder.load_state_dict(torch.load('params/encoder_final'))
        decoder.load_state_dict(torch.load('params/decoder_final'))
        print('fine-tuned model loaded')
    rec_crit = SequenceCrossEntropyLoss()
    if args.cuda == True:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    encoder.to(device)
    decoder.to(device)
    # param_groups = model.parameters()
    # NOTE(review): only encoder parameters are optimized here; decoder
    # parameters are never passed to the optimizer — confirm intentional.
    param_groups = encoder.parameters()
    param_groups = filter(lambda p: p.requires_grad, param_groups)
    optimizer = torch.optim.Adadelta(param_groups,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    # NOTE(review): scheduler is created but scheduler.step() is never
    # called, so the LR milestones have no effect.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[4, 5],
                                                     gamma=0.1)
    test_proba = []
    test_pred = []
    test_label = []
    test_image = []
    train_loader = DataLoader(train_list,
                              batch_size=args.batch_size,
                              shuffle=False,
                              collate_fn=AlignCollate(imgH=args.height,
                                                      imgW=args.width,
                                                      keep_ratio=True))
    test_loader = DataLoader(test_list,
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=AlignCollate(imgH=args.height,
                                                     imgW=args.width,
                                                     keep_ratio=True))
    if not args.eval:
        # Training loop: teacher-forced decoding + sequence cross-entropy.
        for epoch in range(args.n_epochs):
            for batch_idx, batch in enumerate(train_loader):
                x, rec_targets, rec_lengths = batch[0], batch[1], batch[2]
                x = x.to(device)
                encoder_feats = encoder(x)  # bs x w x C
                rec_pred = decoder([encoder_feats, rec_targets, rec_lengths])
                loss_rec = rec_crit(rec_pred, rec_targets, rec_lengths)
                if batch_idx == 0:
                    # Log loss and a few raw/argmax predictions once per epoch.
                    print('train Loss : ' + str(loss_rec))
                    rec_pred_idx = np.argmax(rec_pred.detach().cpu().numpy(),
                                             axis=-1)
                    print(rec_pred[:3])
                    print(rec_pred_idx[:5])
                optimizer.zero_grad()
                loss_rec.backward()
                optimizer.step()
        # Persist trained weights; CPU runs get a separate filename.
        if args.cuda:
            torch.save(encoder.state_dict(), 'params/encoder_final')
            torch.save(decoder.state_dict(), 'params/decoder_final')
        else:
            torch.save(encoder.state_dict(), 'params/encoder_final_cpu')
            torch.save(decoder.state_dict(), 'params/decoder_final_cpu')
    # Evaluation: beam-search decoding over the test set.
    for batch_idx, batch in enumerate(test_loader):
        x, rec_targets, rec_lengths = batch[0], batch[1], batch[2]
        encoder_feats = encoder(x)
        rec_pred, rec_pred_scores = decoder.beam_search(encoder_feats,
                                                        args.beam_width,
                                                        args.eos)
        rec_pred = rec_pred.detach().cpu().numpy()
        rec_targets = rec_targets.numpy()
        print('predictions')
        print(rec_pred[:5])
        print('label')
        print(rec_targets[:5])
        test_proba.extend(rec_pred_scores)
        test_pred.extend(rec_pred)
        test_label.extend(rec_targets)
        test_image.extend(x.detach().cpu().numpy())
        hit = 0
        miss = 0
        # NOTE(review): rec_pred[i] == rec_targets[i] compares numpy rows
        # and yields an element-wise array; using it in `if` raises, which
        # the bare except silently swallows — so this per-batch accuracy
        # likely never prints. The real scoring happens in get_score below.
        try:
            for i, x in enumerate(rec_pred):
                if rec_pred[i] == rec_targets[i]:
                    hit += 1
                else:
                    miss += 1
            accuracy = hit / (hit + miss)
            print("batch accuracy=", accuracy)
        except:
            pass
        hit = 0
        miss = 0
    if args.save_preds == True:
        # Dump everything needed to inspect predictions offline.
        with open('aster_pred.pkl', 'wb') as f:
            pickle.dump([
                test_label, test_pred, test_proba, char2id_dict,
                id2char_dict, test_image
            ], f)

    def get_score(test_label, test_pred):
        # Sequence-level accuracy: a prediction counts as correct only if it
        # matches the label exactly up to the label's EOS position.
        total_n = 0
        true_n = 0
        eos = 94  # NOTE(review): hard-coded EOS id; presumably args.eos — confirm.
        for i, x in enumerate(test_label):
            total_n += 1
            eos_idx = 0
            for j, y in enumerate(x):
                if y != eos:
                    eos_idx += 1
                else:
                    break
            label = x[:eos_idx]
            pred = test_pred[i][:eos_idx]
            if np.array_equal(label, pred):
                true_n += 1
        print('Accuracy')
        print(true_n / total_n)

    get_score(test_label, test_pred)
def main_aster(folder_name):
    """Train the ASTER encoder/decoder on data found under a folder.

    @Input
        folder_name : name of the folder where training data are stored
                      (scanned as ./data/<folder_name>/*/*.xml)
    @Output
        trained parameters are stored in the 'params' folder
    """
    # arguments are stored in pred_params.py
    from pred_params import Get_ocr_args
    args = Get_ocr_args()
    print('Evaluation : ' + str(args.eval))
    # Seed every RNG for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    if args.cuda:
        print('using cuda.')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print('using cpu.')
        torch.set_default_tensor_type('torch.FloatTensor')
    # Create Character dict & max seq len
    args, char2id_dict, id2char_dict = Create_char_dict(args)
    print(id2char_dict)
    rec_num_classes = len(id2char_dict)  # Get rec num classes / max len
    print('max len : ' + str(args.max_len))
    # Get file list for train set; strip the 4-char '.xml' extension so the
    # paths can serve as a common stem for image/annotation pairs.
    filenames = glob.glob('./data/' + folder_name + '/*/*.xml')
    filenames = [x[:-4] for x in filenames]
    print('file len : ' + str(len(filenames)))
    # files are not splitted into train/valid set.
    train_list = Create_data_list_byfolder(args, char2id_dict, id2char_dict,
                                           filenames)
    # CNN+LSTM feature extractor and attention-based sequence decoder.
    encoder = ResNet_ASTER(with_lstm=True,
                           n_group=args.n_group,
                           use_cuda=args.cuda)
    encoder_out_planes = encoder.out_planes
    decoder = AttentionRecognitionHead(num_classes=rec_num_classes,
                                       in_planes=encoder_out_planes,
                                       sDim=args.decoder_sdim,
                                       attDim=args.attDim,
                                       max_len_labels=args.max_len,
                                       use_cuda=args.cuda)
    # Load pretrained weights
    if not args.eval:
        if args.use_pretrained:
            # use pretrained model; checkpoint keys are prefixed with
            # 'encoder.'/'decoder.', so strip the first dotted component
            # before loading into each sub-model.
            pretrain_path = './data/demo.pth.tar'
            if args.cuda:
                pretrained_dict = torch.load(pretrain_path)['state_dict']
            else:
                pretrained_dict = torch.load(
                    pretrain_path, map_location='cpu')['state_dict']
            encoder_dict = {}
            decoder_dict = {}
            for i, x in enumerate(pretrained_dict.keys()):
                if 'encoder' in x:
                    encoder_dict['.'.join(
                        x.split('.')[1:])] = pretrained_dict[x]
                elif 'decoder' in x:
                    decoder_dict['.'.join(
                        x.split('.')[1:])] = pretrained_dict[x]
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            print('pretrained model loaded')
        else:
            # init model parameters
            # NOTE(review): xavier_uniform is deprecated in favor of
            # xavier_uniform_ in recent torch versions.
            def init_weights(m):
                if type(m) == nn.Linear:
                    torch.nn.init.xavier_uniform(m.weight)
                    #m.bias.data.fill_(0.01)
            encoder.apply(init_weights)
            decoder.apply(init_weights)
            print('Random weight initialized!')
    else:
        # loading parameters for inference
        if args.cuda:
            encoder.load_state_dict(torch.load('params/encoder_final'))
            decoder.load_state_dict(torch.load('params/decoder_final'))
        else:
            encoder.load_state_dict(
                torch.load('params/encoder_final',
                           map_location=torch.device('cpu')))
            decoder.load_state_dict(
                torch.load('params/decoder_final',
                           map_location=torch.device('cpu')))
        print('fine-tuned model loaded')
    # Training Phase
    rec_crit = SequenceCrossEntropyLoss()
    # NOTE(review): bitwise & works here only because both operands are
    # bools; logical `and` would be the conventional spelling.
    if (args.cuda == True) & torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    encoder.to(device)
    decoder.to(device)
    # NOTE(review): only encoder parameters are optimized here; decoder
    # parameters are never passed to the optimizer — confirm intentional.
    param_groups = encoder.parameters()
    param_groups = filter(lambda p: p.requires_grad, param_groups)
    optimizer = torch.optim.Adadelta(param_groups,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    # NOTE(review): scheduler is created but scheduler.step() is never
    # called, so the LR milestones have no effect.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[4, 5],
                                                     gamma=0.1)
    train_loader = DataLoader(train_list,
                              batch_size=args.batch_size,
                              shuffle=False,
                              collate_fn=AlignCollate(imgH=args.height,
                                                      imgW=args.width,
                                                      keep_ratio=True))
    # Training loop: teacher-forced decoding + sequence cross-entropy.
    for epoch in range(args.n_epochs):
        for batch_idx, batch in enumerate(train_loader):
            x, rec_targets, rec_lengths = batch[0], batch[1], batch[2]
            x = x.to(device)
            encoder_feats = encoder(x)  # bs x w x C
            rec_pred = decoder([encoder_feats, rec_targets, rec_lengths])
            loss_rec = rec_crit(rec_pred, rec_targets, rec_lengths)
            if batch_idx == 0:
                # Log loss and a few raw/argmax predictions once per epoch.
                print('train Loss : ' + str(loss_rec))
                rec_pred_idx = np.argmax(rec_pred.detach().cpu().numpy(),
                                         axis=-1)
                print(rec_pred[:3])
                print(rec_pred_idx[:5])
            optimizer.zero_grad()
            loss_rec.backward()
            optimizer.step()
    # Training phase ends
    # this is where trained model parameters are saved
    torch.save(encoder.state_dict(), 'params/encoder_final')
    torch.save(decoder.state_dict(), 'params/decoder_final')