def __init__(self, opt):
    opt.src_select_data = opt.src_select_data.split('-')
    opt.src_batch_ratio = opt.src_batch_ratio.split('-')
    opt.tar_select_data = opt.tar_select_data.split('-')
    opt.tar_batch_ratio = opt.tar_batch_ratio.split('-')

    """ vocab / character number configuration """
    if opt.sensitive:
        # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).
    if opt.char_dict is not None:
        # strip the special symbols introduced by the Attention and CTC converters
        opt.character = load_char_dict(opt.char_dict)[3:-2]

    """ model configuration """
    self.converter = AttnLabelConverter(opt.character)
    opt.num_class = len(self.converter.character)
    if opt.rgb:
        opt.input_channel = 3
    self.opt = opt

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    self.save_opt_log(opt)
    self.build_model(opt)
def load_net(self):
    """ model configuration """
    if 'CTC' in self.opt.Prediction:
        self.converter = CTCLabelConverter(self.opt.character)
    else:
        self.converter = AttnLabelConverter(self.opt.character)
    self.opt.num_class = len(self.converter.character)

    if self.opt.rgb:
        self.opt.input_channel = 3
    print("input channel: {}".format(self.opt.input_channel))

    self.model = Model(self.opt)
    print('model input parameters', self.opt.imgH, self.opt.imgW,
          self.opt.num_fiducial, self.opt.input_channel, self.opt.output_channel,
          self.opt.hidden_size, self.opt.num_class, self.opt.batch_max_length,
          self.opt.Transformation, self.opt.FeatureExtraction,
          self.opt.SequenceModeling, self.opt.Prediction)
    self.model = torch.nn.DataParallel(self.model).to(device)

    # load model
    print('loading pretrained model from %s' % self.opt.saved_model)
    state_dict = torch.load(self.opt.saved_model, map_location=device)
    self.model.load_state_dict(state_dict)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    self.AlignCollate_demo = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW,
                                          keep_ratio_with_pad=self.opt.PAD)
def __init__(self, opt):
    if 'CTC' in opt.Prediction:
        self.converter = CTCLabelConverter(opt.character)
    else:
        self.converter = AttnLabelConverter(opt.character)
    opt.num_class = len(self.converter.character)

    if opt.rgb:
        opt.input_channel = 3
    super().__init__(opt)

    # NOTE: the original `self = torch.nn.DataParallel(self).to(device)` has no
    # effect outside __init__ (it only rebinds the local name). Move the module
    # to the device and strip the "module." prefix that a DataParallel
    # checkpoint carries instead.
    self.to(device)
    print('loading pretrained model from %s' % opt.saved_model)
    state_dict = torch.load(opt.saved_model, map_location=device)
    state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    self.load_state_dict(state_dict)
    self.opt = opt
def init_model(args):
    """ model configuration """
    if 'CTC' in args.Prediction:
        converter = CTCLabelConverter(args.character)
    else:
        converter = AttnLabelConverter(args.character)
    args.num_class = len(converter.character)

    if args.rgb:
        args.input_channel = 3
    model = Model(args)
    print('model input parameters', args.imgH, args.imgW, args.num_fiducial,
          args.input_channel, args.output_channel, args.hidden_size,
          args.num_class, args.batch_max_length, args.Transformation,
          args.FeatureExtraction, args.SequenceModeling, args.Prediction)
    try:
        model = torch.nn.DataParallel(model).to(device)
    except RuntimeError:
        raise RuntimeError(device)

    # load model
    print('loading pretrained model from %s' % args.saved_model)
    model.load_state_dict(torch.load(args.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=args.imgH, imgW=args.imgW,
                                     keep_ratio_with_pad=args.PAD)
    model.eval()
    return converter, model, AlignCollate_demo
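# A minimal usage sketch for init_model() above, not part of the original code.
# It assumes the RawDataset/AlignCollate utilities and the global `device` used
# throughout these snippets; `args` is the usual argparse namespace.
def run_demo_folder(args):
    converter, model, AlignCollate_demo = init_model(args)
    demo_data = RawDataset(root=args.image_folder, opt=args)
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=args.batch_size, shuffle=False,
        num_workers=int(args.workers), collate_fn=AlignCollate_demo, pin_memory=True)
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            length_for_pred = torch.IntTensor([args.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, args.batch_max_length + 1).fill_(0).to(device)
            preds = model(image, text_for_pred, is_train=False)
            _, preds_index = preds.max(2)
            for path, pred in zip(image_path_list, converter.decode(preds_index, length_for_pred)):
                print(path, pred.split('[s]')[0])  # prune after the Attn end token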
def main(config):
    train_loader, eval_loader = get_dataloader(config['data_loader']['type'],
                                               config['data_loader']['args'])
    if os.path.isfile(config['data_loader']['args']['alphabet']):
        config['data_loader']['args']['alphabet'] = str(np.load(config['data_loader']['args']['alphabet']))

    prediction_type = config['arch']['args']['prediction']['type']

    # label converter setup
    if prediction_type == 'CTC':
        converter = CTCLabelConverter(config['data_loader']['args']['alphabet'])
    else:
        converter = AttnLabelConverter(config['data_loader']['args']['alphabet'])
    num_class = len(converter.character)

    # loss setup
    if prediction_type == 'CTC':
        criterion = CTCLoss(zero_infinity=True).cuda()
    else:
        criterion = CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0

    model = get_model(num_class, config)
    config['name'] = config['name'] + '_' + model.name
    trainer = Trainer(config=config, model=model, criterion=criterion,
                      train_loader=train_loader, val_loader=eval_loader,
                      converter=converter, weights_init=weights_init)
    trainer.train()
def test(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    model = Model(opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
                  opt.hidden_size, opt.num_class, opt.batch_max_length,
                  Transformation=opt.Transformation,
                  FeatureExtraction=opt.FeatureExtraction,
                  SequenceModeling=opt.SequenceModeling,
                  Prediction=opt.Prediction)
    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).cuda()

    # load model
    if opt.saved_model != '':
        print('loading pretrained model from %s' % opt.saved_model)
        model.load_state_dict(torch.load(opt.saved_model))
    opt.name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    """ keep evaluation model and result logs """
    os.makedirs(f'./result/{opt.name}', exist_ok=True)
    os.system(f'cp {opt.saved_model} ./result/{opt.name}/')

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = CTCLoss(reduction='sum')
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0

    """ evaluation """
    model.eval()
    if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
        benchmark_all_eval(model, criterion, converter, opt)
    else:
        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
        eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
        evaluation_loader = torch.utils.data.DataLoader(
            eval_data, batch_size=opt.batch_size, shuffle=False,
            num_workers=int(opt.workers), collate_fn=AlignCollate_evaluation,
            pin_memory=True)
        _, accuracy_by_best_model, _, _, _, _, _ = validation(
            model, criterion, evaluation_loader, converter, opt)
        print(accuracy_by_best_model)
        with open('./result/{0}/log_evaluation.txt'.format(opt.name), 'a') as log:
            log.write(str(accuracy_by_best_model) + '\n')
def test(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    elif 'Transformer' in opt.Prediction or 'Test' in opt.Prediction or 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    """ keep evaluation model and result logs """
    os.makedirs(f'./result/{opt.experiment_name}', exist_ok=True)
    os.system(f'cp {opt.saved_model} ./result/{opt.experiment_name}/')

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    """ evaluation """
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            log = open(f'./result/{opt.experiment_name}/log_evaluation.txt', 'a')
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                                   keep_ratio_with_pad=opt.PAD)
            eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data, batch_size=opt.batch_size, shuffle=False,
                num_workers=int(opt.workers), collate_fn=AlignCollate_evaluation,
                pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)
            log.write(eval_data_log)
            print(f'{accuracy_by_best_model:0.3f}')
            log.write(f'{accuracy_by_best_model:0.3f}\n')
            log.close()
def loader():
    opt = config()
    """ model configuration """
    if 'CTC' in opt['Prediction']:
        converter = CTCLabelConverter(opt['character'])
    else:
        converter = AttnLabelConverter(opt['character'])
    opt['num_class'] = len(converter.character)
    if opt['rgb'] == 3:
        opt['input_channel'] = 3

    model = Model(opt)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    # print('loading pretrained model from %s' % opt['saved_model'])
    model.load_state_dict(torch.load(opt['saved_model'], map_location=device))

    length_for_pred = torch.IntTensor([opt['batch_max_length']] * opt['batch_size']).to(device)
    text_for_pred = torch.LongTensor(opt['batch_size'], opt['batch_max_length'] + 1).fill_(0).to(device)

    model.eval()
    return model, converter, length_for_pred, text_for_pred, opt
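# A minimal inference sketch using loader() above, not part of the original
# code. The input tensor must already be normalized and batched to
# opt['batch_size'], since text_for_pred is pre-sized there; these are
# assumptions about the surrounding pipeline.
def recognize_batch(image_batch):
    model, converter, length_for_pred, text_for_pred, opt = loader()
    with torch.no_grad():
        preds = model(image_batch.to(device), text_for_pred, is_train=False)
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, length_for_pred)
    return [s.split('[s]')[0] for s in preds_str]  # prune the Attn end token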
def generate_recognition_model(args):
    if 'CTC' in args.Prediction:
        converter = CTCLabelConverter(args.character)
    else:
        converter = AttnLabelConverter(args.character)
    args.num_class = len(converter.character)

    if args.rgb:
        args.input_channel = 3
    model = Model(args)

    if 'CTC' in args.Prediction:
        # the original had CrossEntropyLoss in both branches; a CTC head needs CTCLoss
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0

    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num: {}'.format(sum(params_num)))

    # NOTE: eps=0.999 is unusually large for Adadelta; the benchmark repo uses eps=1e-8
    optimizer = torch.optim.Adadelta(filtered_parameters, lr=args.reco_lr, rho=0.95, eps=0.999)
    return model, converter, criterion, optimizer
def initialize(self):
    start = time.time()
    # self.saved_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train_addKorean_synth/best_accuracy.pth'
    # self.craft_trained_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train/craft_mlt_25k.pth'
    # self.saved_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train_v2/best_accuracy.pth'
    # self.craft_trained_model = '/home_hongdo/sungeun.kim/checkpoints/ocr/ocr_train_v2/best_accuracy_craft.pth'

    # official
    self.saved_model = './data_ocr/best_accuracy.pth'
    self.craft_trained_model = './data_ocr/best_accuracy_craft.pth'
    self.logfilepath = './data_ocr/log_ocr_result.txt'

    """ vocab / character number configuration """
    # if self.sensitive:
    #     self.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

    cudnn.benchmark = True
    cudnn.deterministic = True
    self.num_gpu = torch.cuda.device_count()

    """ model configuration """
    # detection
    self.net = CRAFT(self)  # initialize
    print('Loading detection weights from checkpoint ' + self.craft_trained_model)
    self.net.load_state_dict(
        copyStateDict(torch.load(self.craft_trained_model, map_location=self.device)))
    self.net = torch.nn.DataParallel(self.net).to(self.device)

    self.converter = AttnLabelConverter(self.character)
    self.num_class = len(self.converter.character)
    if self.rgb:
        self.input_channel = 3
    self.model = Model(self, self.num_class)
    # print('model input parameters', self.imgH, self.imgW, self.num_fiducial, self.input_channel,
    #       self.output_channel, self.hidden_size, self.num_class, self.batch_max_length)

    # load model
    self.model = torch.nn.DataParallel(self.model).to(self.device)
    print('Loading recognition weights from checkpoint %s' % self.saved_model)
    self.model.load_state_dict(torch.load(self.saved_model, map_location=self.device))

    if torch.cuda.is_available():
        self.model = self.model.cuda()
        self.net = self.net.cuda()
        cudnn.benchmark = False

    # print('Initialization Done! It took {:.2f} mins.\n'.format((time.time() - start) / 60))
    print('Initialization Done! It took {:.2f} sec.\n'.format(time.time() - start))
    return True
def recognition(image):
    converter = AttnLabelConverter(args.character)
    args.num_class = len(converter.character)
    if args.rgb:
        args.input_channel = 3
    model = Model(args)
    model = torch.nn.DataParallel(model).to(device)
    model.load_state_dict(torch.load(args.model_dir, map_location=device))

    transformer = resizeNormalize((100, 32))
    # Convert RGB images to grayscale, which is necessary for our convolution layers.
    if image.mode == 'RGB':
        image = image.convert('L')
    image = transformer(image)
    batch_size = image.size(0)
    if torch.cuda.is_available():
        image = image.cuda()
    image = image.view(1, *image.size())
    image = Variable(image)

    model.eval()
    length_for_pred = torch.IntTensor([args.batch_max_length] * batch_size).to(device)
    text_for_pred = torch.LongTensor(batch_size, args.batch_max_length + 1).fill_(0).to(device)
    preds = model(image, text_for_pred, is_train=False)
    _, preds_index = preds.max(2)
    preds_str = converter.decode(preds_index, length_for_pred)
    preds_prob = F.softmax(preds, dim=2)
    text_prediction = preds_str[0].replace("[s]", "")
    return text_prediction
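# A usage sketch for recognition() above, not part of the original code. The
# function expects a PIL image and relies on the module-level `args` namespace
# being configured first; the file name below is a placeholder.
from PIL import Image

if __name__ == '__main__':
    img = Image.open('demo_image/demo_1.png')  # hypothetical sample image
    print(recognition(img))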
def main(config):
    import torch
    from modeling import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import CTCLabelConverter, AttnLabelConverter, load

    if os.path.isfile(config['dataset']['alphabet']):
        config['dataset']['alphabet'] = ''.join(load(config['dataset']['alphabet']))

    prediction_type = config['arch']['head']['type']

    # loss setup
    criterion = build_loss(config['loss'])
    if prediction_type == 'CTC':
        converter = CTCLabelConverter(config['dataset']['alphabet'])
    elif prediction_type == 'Attn':
        converter = AttnLabelConverter(config['dataset']['alphabet'])
    else:
        raise NotImplementedError

    img_channel = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1
    config['arch']['backbone']['in_channels'] = img_channel
    config['arch']['head']['n_class'] = len(converter.character)
    # model = get_model(img_channel, len(converter.character), config['arch']['args'])
    model = build_model(config['arch'])

    img_h, img_w = 32, 100
    for process in config['dataset']['train']['dataset']['args']['pre_processes']:
        if process['type'] == "Resize":
            img_h = process['args']['img_h']
            img_w = process['args']['img_w']
            break
    sample_input = torch.zeros((2, img_channel, img_h, img_w))
    num_label = model.get_batch_max_length(sample_input)

    train_loader = get_dataloader(config['dataset']['train'], num_label)
    assert train_loader is not None
    if 'validate' in config['dataset'] and config['dataset']['validate']['dataset']['args']['data_path'][0] is not None:
        validate_loader = get_dataloader(config['dataset']['validate'], num_label)
    else:
        validate_loader = None

    trainer = Trainer(config=config, model=model, criterion=criterion,
                      train_loader=train_loader, validate_loader=validate_loader,
                      sample_input=sample_input, converter=converter)
    trainer.train()
def __init__(self, model_path: str, gpu_id=None):
    '''
    Initialize the pytorch model.
    :param model_path: path to the model file (either the parameters alone, or the parameters saved together with the computation graph)
    :param alphabet: alphabet
    :param img_shape: image size as (w, h)
    :param net: network computation graph; required when model_path stores only the parameters
    :param img_channel: number of image channels: 1 or 3
    :param gpu_id: id of the GPU to run on
    '''
    self.gpu_id = gpu_id
    checkpoint = torch.load(model_path)
    if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
        self.device = torch.device("cuda:%s" % self.gpu_id)
    else:
        self.device = torch.device("cpu")
    print('device:', self.device)

    config = checkpoint['config']
    self.prediction_type = config['arch']['args']['prediction']['type']
    if self.prediction_type == 'CTC':
        self.converter = CTCLabelConverter(config['data_loader']['args']['alphabet'])
    else:
        self.converter = AttnLabelConverter(config['data_loader']['args']['alphabet'])
    num_class = len(self.converter.character)
    self.net = get_model(num_class, config)

    self.img_w = config['data_loader']['args']['dataset']['img_w']
    self.img_h = config['data_loader']['args']['dataset']['img_h']
    self.img_channel = config['data_loader']['args']['dataset']['img_channel']

    self.net.load_state_dict(checkpoint['state_dict'])
    self.net.to(self.device)
    self.net.eval()
def __init__(self):
    # model settings
    self.path_model = 'model/TPS-ResNet-BiLSTM-Attn.pth'  # NOTE: was commented out, but torch.load below needs it
    self.batch_size = 1
    self.batch_max_length = 25
    self.imgH = 32
    self.imgW = 100
    self.character = '0123456789abcdefghijklmnopqrstuvwxyz'
    self.Transformation = 'TPS'
    self.FeatureExtraction = 'ResNet'
    self.SequenceModeling = 'BiLSTM'
    self.Prediction = 'Attn'
    self.num_fiducial = 20
    self.input_channel = 1
    self.output_channel = 512
    self.hidden_size = 256
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    self.opt = parser.parse_args()
    self.opt.num_gpu = torch.cuda.device_count()
    # Model(self.opt) reads the settings above from opt, so copy them over
    for key in ('batch_max_length', 'imgH', 'imgW', 'character', 'Transformation',
                'FeatureExtraction', 'SequenceModeling', 'Prediction',
                'num_fiducial', 'input_channel', 'output_channel', 'hidden_size'):
        setattr(self.opt, key, getattr(self, key))

    # load model
    self.converter = AttnLabelConverter(self.character)
    self.opt.num_class = len(self.converter.character)
    if self.opt.rgb:
        self.opt.input_channel = 3
    self.model = Model(self.opt)
    self.model = torch.nn.DataParallel(self.model).to(self.device)  # use self.device rather than hard-coded 'cuda:0'

    # load model weights
    self.model.load_state_dict(torch.load(self.path_model, map_location=self.device))
def init_model(model_path):
    global model, converter, batch_max_length, imgW, imgH, transform
    cudnn.benchmark = True
    cudnn.deterministic = True

    """ model configuration """
    converter = AttnLabelConverter(character)
    num_class = len(converter.character)
    model = Model(imgW, imgH, num_class, batch_max_length)
    model = torch.nn.DataParallel(model).to(device)  # was `model = model = ...`

    # load model
    model.load_state_dict(torch.load(model_path, map_location=device))
    transform = ResizeNormalize((imgW, imgH))
def transform_to_onnx(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # change output, so that OpenCV sample can run it correctly
    post_model = Post_Model(model)  # was misspelled `pose_model`
    model = post_model

    # transform to onnx
    model.eval()
    batch_size = 1
    channel = 1
    h_size = 32
    w_size = 100
    dummy_input = torch.randn(batch_size, channel, h_size, w_size,
                              requires_grad=True).to(device)

    output_path = "./onnx_output"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    onnx_output = os.path.join(output_path, opt.onnx_name)
    print("output path :", onnx_output)

    dummy_output = model(dummy_input)
    print("output size :", dummy_output.size())

    torch.onnx.export(model, dummy_input, onnx_output, verbose=True,
                      export_params=True, opset_version=11,
                      do_constant_folding=True,
                      input_names=["input"], output_names=["output"])
    print("Transform Successfully!")
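# A quick sanity-check sketch for the exported graph, not part of the original
# code. It uses onnxruntime (an extra dependency); the "input"/"output" names
# and the 1x1x32x100 dummy shape mirror the export call above.
import numpy as np
import onnxruntime as ort

def check_onnx(onnx_path):
    sess = ort.InferenceSession(onnx_path)
    dummy = np.random.randn(1, 1, 32, 100).astype(np.float32)
    (out,) = sess.run(["output"], {"input": dummy})
    print("onnx output shape:", out.shape)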
def demo(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data, batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
                # NOTE: `custom_output` is not defined in this snippet; it must be
                # an open file object provided by the caller.
                custom_output.write(f'{img_name}\t{pred}\t{confidence_score:0.4f}\n')
            log.close()
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
class PytorchNet:
    def __init__(self, model_path, gpu_id=None):
        """
        Initialize the model.
        :param model_path: path to the checkpoint file
        :param gpu_id: id of the GPU to run on
        """
        checkpoint = torch.load(model_path)
        print(f"load {checkpoint['epoch']} epoch params")
        config = checkpoint['config']
        alphabet = config['dataset']['alphabet']
        if gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)

        self.transform = []
        for t in config['dataset']['train']['dataset']['args']['transforms']:
            if t['type'] in ['ToTensor', 'Normalize']:
                self.transform.append(t)
        self.transform = get_transforms(self.transform)
        self.gpu_id = gpu_id

        img_h, img_w = 32, 100
        for process in config['dataset']['train']['dataset']['args']['pre_processes']:
            if process['type'] == "Resize":
                img_h = process['args']['img_h']
                img_w = process['args']['img_w']
                break
        self.img_w = img_w
        self.img_h = img_h
        self.img_mode = config['dataset']['train']['dataset']['args']['img_mode']
        self.alphabet = alphabet
        img_channel = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1

        if config['arch']['args']['prediction']['type'] == 'CTC':
            self.converter = CTCLabelConverter(config['dataset']['alphabet'])
        elif config['arch']['args']['prediction']['type'] == 'Attn':
            self.converter = AttnLabelConverter(config['dataset']['alphabet'])

        self.net = get_model(img_channel, len(self.converter.character), config['arch']['args'])
        self.net.load_state_dict(checkpoint['state_dict'])
        # self.net = torch.jit.load('crnn_lite_gpu.pt')
        self.net.to(self.device)
        self.net.eval()

        sample_input = torch.zeros((2, img_channel, img_h, img_w)).to(self.device)
        self.net.get_batch_max_length(sample_input)

    def predict(self, img_path, model_save_path=None):
        """
        Run prediction on the given image; accepts an image path (or a numpy array).
        :param img_path: path to the image
        :return:
        """
        assert os.path.exists(img_path), 'file does not exist'
        img = self.pre_processing(img_path)
        tensor = self.transform(img)
        tensor = tensor.unsqueeze(dim=0)
        tensor = tensor.to(self.device)
        preds, tensor_img = self.net(tensor)
        preds = preds.softmax(dim=2).detach().cpu().numpy()
        # result = decode(preds, self.alphabet, raw=True)
        # print(result)
        result = self.converter.decode(preds)
        if model_save_path is not None:
            # export the model for deployment
            save(self.net, tensor, model_save_path)
        return result, tensor_img

    def pre_processing(self, img_path):
        """
        Preprocess the image: first resize by height; if the resulting width is
        below the target width, pad with black pixels, otherwise force-scale to
        the target width.
        :param img_path: path to the image
        :return:
        """
        img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
        if self.img_mode == 'RGB':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        ratio_h = float(self.img_h) / h
        new_w = int(w * ratio_h)
        if new_w < self.img_w:
            img = cv2.resize(img, (new_w, self.img_h))
            step = np.zeros((self.img_h, self.img_w - new_w, img.shape[-1]), dtype=img.dtype)
            img = np.column_stack((img, step))
        else:
            img = cv2.resize(img, (self.img_w, self.img_h))
        return img
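# A usage sketch for PytorchNet, not part of the original code; the checkpoint
# and image paths are placeholders.
if __name__ == '__main__':
    net = PytorchNet('output/checkpoint.pth', gpu_id=0)
    result, _ = net.predict('test.jpg')
    print(result)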
def train(opt):
    """ dataset preparation """
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # import ipdb; ipdb.set_trace()
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    # print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf-8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.continue_model != '':
        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # import ipdb; ipdb.set_trace()
        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid a ctc_loss issue, disable cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf-8") as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    # pred = pred.encode('utf-8')
                    # gt = gt.encode('utf-8')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt): plotDir = os.path.join(opt.exp_dir,opt.exp_name,'plots') if not os.path.exists(plotDir): os.makedirs(plotDir) lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') #considering the real images for discriminator opt.batch_size = opt.batch_size*2 train_dataset = Batch_Balanced_Dataset(opt) log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle=False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() """ model configuration """ if 'CTC' in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = AdaINGen(opt) ocrModel = Model(opt) disModel = MsImageDisV1(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) # weight initialization for currModel in [model, ocrModel, disModel]: for name, param in currModel.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. 
if 'weight' in name: param.data.fill_(1) continue # data parallel for multi-GPU ocrModel = torch.nn.DataParallel(ocrModel).to(device) if not opt.ocrFixed: ocrModel.train() else: ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() model = torch.nn.DataParallel(model).to(device) model.train() disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth")))>0: opt.saved_dis_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth"))[-1] #loading pre-trained model if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') if opt.FT: ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False) else: ocrModel.load_state_dict(torch.load(opt.saved_ocr_model)) print("OCRModel:") print(ocrModel) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') if opt.FT: model.load_state_dict(torch.load(opt.saved_synth_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_synth_model)) print("SynthModel:") print(model) if opt.saved_dis_model != '' and opt.saved_dis_model != 'None': print(f'loading pretrained discriminator model from {opt.saved_dis_model}') if opt.FT: disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False) else: disModel.load_state_dict(torch.load(opt.saved_dis_model)) print("DisModel:") print(disModel) """ setup loss """ if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 recCriterion = torch.nn.L1Loss() styleRecCriterion = torch.nn.L1Loss() # loss averager loss_avg_ocr = Averager() loss_avg = Averager() loss_avg_dis = Averager() loss_avg_ocrRecon_1 = Averager() loss_avg_ocrRecon_2 = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_styRecon = Averager() ##---------------------------------------## # filter that only require gradient decent filtered_parameters = [] params_num = [] for p in filter(lambda p: p.requires_grad, model.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) print('Trainable params num : ', sum(params_num)) # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] # setup optimizer if opt.optim=='adam': optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("SynthOptimizer:") print(optimizer) #filter parameters for OCR training ocr_filtered_parameters = [] ocr_params_num = [] for p in filter(lambda p: p.requires_grad, ocrModel.parameters()): ocr_filtered_parameters.append(p) ocr_params_num.append(np.prod(p.size())) print('OCR Trainable params num : ', sum(ocr_params_num)) # setup optimizer if opt.optim=='adam': ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), 
weight_decay=opt.weight_decay) else: ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("OCROptimizer:") print(ocr_optimizer) #filter parameters for OCR training dis_filtered_parameters = [] dis_params_num = [] for p in filter(lambda p: p.requires_grad, disModel.parameters()): dis_filtered_parameters.append(p) dis_params_num.append(np.prod(p.size())) print('Dis Trainable params num : ', sum(dis_params_num)) # setup optimizer if opt.optim=='adam': dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay) else: dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay) print("DisOptimizer:") print(dis_optimizer) ##---------------------------------------## """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) ocr_scheduler = get_scheduler(ocr_optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) start_time = time.time() best_accuracy = -1 best_norm_ED = -1 best_accuracy_ocr = -1 best_norm_ED_ocr = -1 iteration = start_iter cntr=0 while(True): # train part if opt.lr_policy !="None": scheduler.step() ocr_scheduler.step() dis_scheduler.step() image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch() # ## comment # pdb.set_trace() # for imgCntr in range(image_tensors.shape[0]): # save_image(tensor2im(image_tensors[imgCntr]),'temp/'+str(imgCntr)+'.png') # pdb.set_trace() # ### # print(cntr) cntr+=1 disCnt = int(image_tensors_all.size(0)/2) image_tensors, image_tensors_real, labels_gt, labels_2 = image_tensors_all[:disCnt], image_tensors_all[disCnt:disCnt+disCnt], labels_1_all[:disCnt], labels_2_all[:disCnt] image = image_tensors.to(device) image_real = image_tensors_real.to(device) batch_size = image.size(0) ##-----------------------------------## #generate text(labels) from ocr.forward if opt.ocrFixed: # ocrModel.eval() length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) if 'CTC' in opt.Prediction: preds = ocrModel(image, text_for_pred) preds = preds[:, :text_for_loss.shape[1] - 1, :] preds_size = torch.IntTensor([preds.size(1)] * batch_size) _, preds_index = preds.max(2) labels_1 = converter.decode(preds_index.data, preds_size.data) else: preds = ocrModel(image, text_for_pred, is_train=False) _, preds_index = preds.max(2) labels_1 = converter.decode(preds_index, length_for_pred) for idx, pred in enumerate(labels_1): pred_EOS = pred.find('[s]') labels_1[idx] = pred[:pred_EOS] # prune after "end of sentence" token ([s]) # ocrModel.train() else: labels_1 = labels_gt ##-----------------------------------## text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) #forward pass 
from style and word generator images_recon_1, images_recon_2, style = model(image, text_1, text_2) if 'CTC' in opt.Prediction: if not opt.ocrFixed: #ocr training with orig image preds_ocr = ocrModel(image, text_1) preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size) preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2) ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1) #content loss for reconstructed images preds_1 = ocrModel(images_recon_1, text_1) preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size) preds_1 = preds_1.log_softmax(2).permute(1, 0, 2) preds_2 = ocrModel(images_recon_2, text_2) preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size) preds_2 = preds_2.log_softmax(2).permute(1, 0, 2) ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1) ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2) # ocrCost = 0.5*( ocrCost_1 + ocrCost_2 ) else: if not opt.ocrFixed: #ocr training with orig image preds_ocr = ocrModel(image, text_1[:, :-1]) # align with Attention.forward target_ocr = text_1[:, 1:] # without [GO] Symbol ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1)) #content loss for reconstructed images preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False) # align with Attention.forward target_1 = text_1[:, 1:] # without [GO] Symbol preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False) # align with Attention.forward target_2 = text_2[:, 1:] # without [GO] Symbol ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1)) ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1)) # ocrCost = 0.5*(ocrCost_1+ocrCost_2) if not opt.ocrFixed: #training OCR ocrModel.zero_grad() ocrCost_train.backward() # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) ocr_optimizer.step() #if ocr is fixed; ignore this loss loss_avg_ocr.add(ocrCost_train) else: loss_avg_ocr.add(torch.tensor(0.0)) #Domain discriminator: Dis update disModel.zero_grad() disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) + disModel.module.calc_dis_loss(images_recon_2.detach(), image)) disCost.backward() # torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) dis_optimizer.step() loss_avg_dis.add(disCost) # #[Style Encoder] + [Word Generator] update #Adversarial loss disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)+disModel.module.calc_gen_loss(images_recon_2)) #Input reconstruction loss recCost = recCriterion(images_recon_1,image) #Pair style reconstruction loss if opt.styleReconWeight == 0.0: styleRecCost = torch.tensor(0.0) else: if opt.styleDetach: styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach()) else: styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style) #OCR Content cost ocrCost = 0.5*(ocrCost_1+ocrCost_2) cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.styleReconWeight*styleRecCost model.zero_grad() ocrModel.zero_grad() disModel.zero_grad() cost.backward() # torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) optimizer.step() loss_avg.add(cost) #Individual losses loss_avg_ocrRecon_1.add(opt.ocrWeight*0.5*ocrCost_1) 
loss_avg_ocrRecon_2.add(opt.ocrWeight*0.5*ocrCost_2) loss_avg_gen.add(opt.disWeight*disGenCost) loss_avg_imgRecon.add(opt.reconWeight*recCost) loss_avg_styRecon.add(opt.styleReconWeight*styleRecCost) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #Save training images os.makedirs(os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration)), exist_ok=True) for trImgCntr in range(batch_size): try: save_image(tensor2im(image[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_input_'+labels_gt[trImgCntr]+'.png')) save_image(tensor2im(images_recon_1[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_recon_'+labels_1[trImgCntr]+'.png')) save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_pair_'+labels_2[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: model.eval() ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() disModel.eval() with torch.no_grad(): valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res( iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion, ocrCriterion, valid_loader, converter, opt) model.train() if not opt.ocrFixed: ocrModel.train() else: # ocrModel.module.Transformation.eval() # ocrModel.module.FeatureExtraction.eval() # ocrModel.module.AdaptiveAvgPool.eval() ocrModel.module.SequenceModeling.train() # ocrModel.module.Prediction.eval() disModel.train() # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}' current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}' current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}' current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}' #plotting lib.plot.plot(os.path.join(plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Synth-Loss'), loss_avg.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item()) lib.plot.plot(os.path.join(plotDir,'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Loss'), valid_loss[0].item()) 
lib.plot.plot(os.path.join(plotDir,'Valid-Synth-Loss'), valid_loss[1].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Dis-Loss'), valid_loss[2].item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon1-Loss'), valid_loss[3].item()) lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon2-Loss'), valid_loss[4].item()) lib.plot.plot(os.path.join(plotDir,'Valid-Gen-Loss'), valid_loss[5].item()) lib.plot.plot(os.path.join(plotDir,'Valid-ImgRecon1-Loss'), valid_loss[6].item()) lib.plot.plot(os.path.join(plotDir,'Valid-StyRecon2-Loss'), valid_loss[7].item()) lib.plot.plot(os.path.join(plotDir,'Orig-OCR-WordAccuracy'), current_accuracy[0]) lib.plot.plot(os.path.join(plotDir,'Recon-OCR-WordAccuracy'), current_accuracy[1]) lib.plot.plot(os.path.join(plotDir,'Pair-OCR-WordAccuracy'), current_accuracy[2]) lib.plot.plot(os.path.join(plotDir,'Orig-OCR-CharAccuracy'), current_norm_ED[0]) lib.plot.plot(os.path.join(plotDir,'Recon-OCR-CharAccuracy'), current_norm_ED[1]) lib.plot.plot(os.path.join(plotDir,'Pair-OCR-CharAccuracy'), current_norm_ED[2]) # keep best accuracy model (on valid dataset) if current_accuracy[1] > best_accuracy: best_accuracy = current_accuracy[1] torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy.pth')) torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_dis.pth')) if current_norm_ED[1] > best_norm_ED: best_norm_ED = current_norm_ED[1] torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED.pth')) torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_dis.pth')) best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}' # keep best accuracy model (on valid dataset) if current_accuracy[0] > best_accuracy_ocr: best_accuracy_ocr = current_accuracy[0] if not opt.ocrFixed: torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_ocr.pth')) if current_norm_ED[0] > best_norm_ED_ocr: best_norm_ED_ocr = current_norm_ED[0] if not opt.ocrFixed: torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_ocr.pth')) best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}' loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}' print(loss_model_log) log.write(loss_model_log + '\n') # show some predicted results dashed_line = '-' * 80 head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F' predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(labels[0][:5], preds[0][:5], confidence_score[0][:5], labels[1][:5], preds[1][:5], confidence_score[1][:5], labels[2][:5], preds[2][:5], confidence_score[2][:5]): if 'Attn' in opt.Prediction: # gt_ocr = gt_ocr[:gt_ocr.find('[s]')] pred_ocr = pred_ocr[:pred_ocr.find('[s]')] # gt_1 = gt_1[:gt_1.find('[s]')] pred_1 = pred_1[:pred_1.find('[s]')] # gt_2 = gt_2[:gt_2.find('[s]')] pred_2 = pred_2[:pred_2.find('[s]')] predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n' predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n' predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == 
gt_2)}\n' predicted_result_log += f'{dashed_line}' print(predicted_result_log) log.write(predicted_result_log + '\n') loss_avg_ocr.reset() loss_avg.reset() loss_avg_dis.reset() loss_avg_ocrRecon_1.reset() loss_avg_ocrRecon_2.reset() loss_avg_gen.reset() loss_avg_imgRecon.reset() loss_avg_styRecon.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+5 iter. if (iteration) % 1e+5 == 0: torch.save( model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if not opt.ocrFixed: torch.save( ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_ocr.pth')) torch.save( disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_dis.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1
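# The Averager used above (loss_avg_ocr.add/.val/.reset and friends) is not defined in this
# excerpt. A minimal sketch consistent with how it is called -- and with the usual
# deep-text-recognition-benchmark utility -- shown here for reference (assumption, not
# necessarily this repo's exact implementation):
class Averager(object):
    """Running average of a stream of loss tensors via add()/val()/reset()."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res  # a tensor, hence the .val().item() calls above

    def reset(self):
        self.n_count = 0
        self.sum = 0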
def demo(opt): """ model configuration """ if 'CTC' in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = Model(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) model = torch.nn.DataParallel(model) if torch.cuda.is_available(): model = model.cuda() # load model print('loading pretrained model from %s' % opt.saved_model) model.load_state_dict(torch.load(opt.saved_model)) # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo # AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) # demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDataset # demo_loader = torch.utils.data.DataLoader( # demo_data, batch_size=opt.batch_size, # shuffle=False, # num_workers=int(opt.workers), # collate_fn=AlignCollate_demo, pin_memory=True) demo_data = baidu_raw_dataset(root=opt.image_folder, opt=opt) demo_loader = torch.utils.data.DataLoader(demo_data, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers), collate_fn=BaiduCollate( opt.imgH, keep_ratio=True), pin_memory=True) outf = open(opt.out_csv, 'w') # predict model.eval() for image_tensors, image_path_list in demo_loader: batch_size = image_tensors.size(0) with torch.no_grad(): image = image_tensors.cuda() # For max length prediction length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size) text_for_pred = torch.cuda.LongTensor( batch_size, opt.batch_max_length + 1).fill_(0) if 'CTC' in opt.Prediction: preds = model(image, text_for_pred).log_softmax(2) # Select max probabilty (greedy decoding) then decode index to character preds_size = torch.IntTensor([preds.size(1)] * batch_size) _, preds_index = preds.permute(1, 0, 2).max(2) preds_index = preds_index.transpose(1, 0).contiguous().view(-1) preds_str = converter.decode(preds_index.data, preds_size.data) else: preds = model(image, text_for_pred, is_train=False) # select max probabilty (greedy decoding) then decode index to character _, preds_index = preds.max(2) preds_str = converter.decode(preds_index, length_for_pred) print('-' * 80) print('image_path\tpredicted_labels') print('-' * 80) for img_name, pred in zip(image_path_list, preds_str): if 'Attn' in opt.Prediction: pred = pred[:pred.find( '[s]')] # prune after "end of sentence" token ([s]) print(f'{img_name}\t{pred}') content = os.path.basename(img_name) + '\t' + pred + '\n' outf.write(content) outf.flush() outf.close()
parser.add_argument( "--input_channel", type=int, default=1, help="the number of input channel of Feature extractor", ) parser.add_argument( "--output_channel", type=int, default=512, help="the number of output channel of Feature extractor", ) parser.add_argument( "--hidden_size", type=int, default=256, help="the size of the LSTM hidden state" ) opt = parser.parse_args() if "CTC" in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) model = Model(opt) model.load_state_dict(torch.load(opt.model_path)) model.to(device) model.eval() for img_path in os.listdir("task1_2_test(361p)"): predict("task1_2_test(361p)/" + img_path, model)
def demoToTxt1(image_folder, saved_model, txtFile): # sensitive parser = argparse.ArgumentParser() parser.add_argument('--image_folder', default=image_folder, help='path to image_folder which contains text images') parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) parser.add_argument('--batch_size', type=int, default=100, help='input batch size') parser.add_argument('--saved_model', default=saved_model, help="path to saved_model to evaluation") """ Data processing """ parser.add_argument('--batch_max_length', type=int, default=20, help='maximum-label-length') parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') parser.add_argument('--rgb', action='store_true', help='use rgb input') parser.add_argument('--character', type=str, default='0123456789', help='character label') parser.add_argument('--sensitive', default=True, help='for sensitive character mode') parser.add_argument('--PAD', default=False, action='store_true', help='whether to keep ratio then pad for image resize') """ Model Architecture """ parser.add_argument('--Transformation', default='TPS', type=str, help='Transformation stage. None|TPS') parser.add_argument('--FeatureExtraction', default='ResNet', type=str, help='FeatureExtraction stage. VGG|RCNN|ResNet') parser.add_argument('--SequenceModeling', default='BiLSTM', type=str, help='SequenceModeling stage. None|BiLSTM') parser.add_argument('--Prediction', default='CTC', type=str, help='Prediction stage. CTC|Attn') parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') parser.add_argument( '--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') parser.add_argument( '--output_channel', type=int, default=512, help='the number of output channel of Feature extractor') parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') opt = parser.parse_args() """ vocab / character number configuration """ if opt.sensitive: opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' # opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). cudnn.benchmark = True cudnn.deterministic = True opt.num_gpu = torch.cuda.device_count() """ model configuration """ if 'CTC' in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = Model(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) model = torch.nn.DataParallel(model) if torch.cuda.is_available(): model = model.cuda() # load model print('loading pretrained model from %s' % opt.saved_model) model.load_state_dict(torch.load(opt.saved_model)) # prepare data. 
two demo images from https://github.com/bgshih/crnn#run-demo AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDataset demo_loader = torch.utils.data.DataLoader(demo_data, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True) # predict model.eval() saved_file = open(txtFile, 'w') for image_tensors, image_path_list in demo_loader: batch_size = image_tensors.size(0) with torch.no_grad(): image = image_tensors.cuda() # For max length prediction length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size) text_for_pred = torch.cuda.LongTensor( batch_size, opt.batch_max_length + 1).fill_(0) if 'CTC' in opt.Prediction: preds = model(image, text_for_pred).log_softmax(2) # Select max probability (greedy decoding) then decode index to character preds_size = torch.IntTensor([preds.size(1)] * batch_size) _, preds_index = preds.permute(1, 0, 2).max(2) preds_index = preds_index.transpose(1, 0).contiguous().view(-1) preds_str = converter.decode(preds_index.data, preds_size.data) else: preds = model(image, text_for_pred, is_train=False) # select max probability (greedy decoding) then decode index to character _, preds_index = preds.max(2) preds_str = converter.decode(preds_index, length_for_pred) print('-' * 80) print('image_path\tpredicted_labels') print('-' * 80) for img_name, pred in zip(image_path_list, preds_str): if 'Attn' in opt.Prediction: pred = pred[:pred.find( '[s]')] # prune after "end of sentence" token ([s]) print(f'{img_name}\t{pred}') saved_file.write(f'{img_name}\t{pred}\n') saved_file.close()
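# Example invocation (the paths below are hypothetical placeholders):
# demoToTxt1('demo_image/', './saved_models/TPS-ResNet-BiLSTM-CTC/best_accuracy.pth', './result.txt')
# Note that batch_max_length=20 and character='0123456789' (plus uppercase letters via the
# sensitive default) are baked into the parser defaults above, so the checkpoint must match them.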
def demo(opt): """ model configuration """ lists = [] #목적지라고 생각하는 사진에서 인식한 text를 담을 배열 converter = AttnLabelConverter(opt.character) #ATTN opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = Model(opt) #model.py의 Model import print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) #파라미터 값 정보 출력 model = torch.nn.DataParallel(model).to(device) #GPU로 데이터 병렬 처리 진행 # load model print('loading pretrained model from %s' % opt.saved_model) model.load_state_dict(torch.load(opt.saved_model, map_location=device)) #모델의 매개변수를 불러옴 AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) demo_data1 = RawDataset(root=opt.image_folder1, opt=opt) # use RawDataset 간판탐지결과 demo_data2 = RawDataset(root=opt.image_folder2, opt=opt) # use RawDataset 구글맵문자열탐지결과 demo_loader1 = torch.utils.data.DataLoader(demo_data1, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True) demo_loader2 = torch.utils.data.DataLoader(demo_data2, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), collate_fn=AlignCollate_demo, pin_memory=True) # predict model.eval() with torch.no_grad(): for image_tensors, image_path_list in demo_loader1: batch_size = image_tensors.size(0) image = image_tensors.to(device) # For max length prediction length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) #ATTn preds = model(image, text_for_pred, is_train=False) # select max probabilty (greedy decoding) then decode index to character _, preds_index = preds.max(2) preds_str = converter.decode(preds_index, length_for_pred) log = open(f'./log_demo_result.txt', 'a') #이어서 쓸수 있게 열고 dashed_line = '-' * 80 head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score' print(f'{dashed_line}\n{head}\n{dashed_line}') #테이블 양식 출력 log.write( f'{dashed_line}\n{head}\n{dashed_line}\n') #txt에 테이블 양식 저장 preds_prob = F.softmax(preds, dim=2) preds_max_prob, _ = preds_prob.max(dim=2) for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob): pred_EOS = pred.find('[s]') pred = pred[: pred_EOS] # prune after "end of sentence" token ([s]) pred_max_prob = pred_max_prob[:pred_EOS] # calculate confidence score (= multiply of pred_max_prob) confidence score 값을 계산 confidence_score = pred_max_prob.cumprod(dim=0)[-1] lists.append(pred) print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}' ) #구한 값을 출력 log.write( f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n' ) #구한 값을 txt에 저장 log.close() #파일 닫기 with torch.no_grad(): for image_tensors, image_path_list in demo_loader2: batch_size = image_tensors.size(0) image = image_tensors.to(device) # For max length prediction length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) #ATTn preds = model(image, text_for_pred, is_train=False) # select max probabilty (greedy decoding) then decode index to character _, preds_index = preds.max(2) preds_str = converter.decode(preds_index, length_for_pred) log = open(f'./log_demo_result.txt', 'a') #이어서 쓸수 있게 열고 dashed_line = '-' * 80 head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence 
score' print(f'{dashed_line}\n{head}\n{dashed_line}') #테이블 양식 출력 log.write( f'{dashed_line}\n{head}\n{dashed_line}\n') #txt에 테이블 양식 저장 preds_prob = F.softmax(preds, dim=2) preds_max_prob, _ = preds_prob.max(dim=2) for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob): pred_EOS = pred.find('[s]') pred = pred[: pred_EOS] # prune after "end of sentence" token ([s]) pred_max_prob = pred_max_prob[:pred_EOS] # confidence score 값을 계산 confidence_score = pred_max_prob.cumprod(dim=0)[-1] print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}' ) #구한 값을 출력 log.write( f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n' ) #구한 값을 txt에 저장 if pred in lists: print(pred + "은(는) 알맞은 목적지입니다.") else: print(pred + "은(는) 알맞은 목적지가 아닙니다.") log.close() #파일 닫기
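# The confidence score used in both loops above is the product of the per-character max
# softmax probabilities; cumprod(dim=0)[-1] is simply that product written as a cumulative
# product. A tiny illustrative check:
import torch
p = torch.tensor([0.9, 0.8, 0.99])  # per-step max probabilities after [s]-pruning
assert torch.isclose(p.cumprod(dim=0)[-1], p.prod())  # 0.9 * 0.8 * 0.99 = 0.7128
# (If the prediction were empty, pred_max_prob[:0] would make the [-1] indexing fail;
# the code above assumes at least one character before '[s]'.)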
def train(opt): os.makedirs(opt.log, exist_ok=True) writer = SummaryWriter(opt.log) """ dataset preparation """ if not opt.data_filtering_off: print( 'Filtering the images containing characters which are not in opt.character' ) print( 'Filtering the images whose label is longer than opt.batch_max_length' ) opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') train_dataset = Batch_Balanced_Dataset(opt) log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset( root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle= True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() """ model configuration """ ctc_converter = CTCLabelConverter(opt.character) attn_converter = AttnLabelConverter(opt.character) opt.num_class = len(attn_converter.character) if opt.rgb: opt.input_channel = 3 model = Model(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length) # weight initialization for name, param in model.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. if 'weight' in name: param.data.fill_(1) continue # data parallel for multi-GPU model = torch.nn.DataParallel(model).to(device) model.train() if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') if opt.FT: model.load_state_dict(torch.load(opt.saved_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_model)) """ setup loss """ loss_avg = Averager() ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device) attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # filter parameters that require gradient descent filtered_parameters = [] params_num = [] for p in filter(lambda p: p.requires_grad, model.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) print('Trainable params num : ', sum(params_num)) # setup optimizer if opt.adam: optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) else: optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) print("Optimizer:") """ final options """ # print(opt) with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_model != '': try: start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass start_time = time.time() best_accuracy = -1 best_norm_ED = -1 iteration = start_iter pbar = tqdm(range(opt.num_iter)) for iteration in pbar: # train part image_tensors, labels = train_dataset.get_batch() image = image_tensors.to(device) ctc_text, ctc_length = ctc_converter.encode( labels, 
batch_max_length=opt.batch_max_length) attn_text, attn_length = attn_converter.encode( labels, batch_max_length=opt.batch_max_length) batch_size = image.size(0) preds, refiner = model(image, attn_text[:, :-1]) refiner_size = torch.IntTensor([refiner.size(1)] * batch_size) refiner = refiner.log_softmax(2).permute(1, 0, 2) refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length) total_loss = opt.lambda_ctc * refiner_loss target = attn_text[:, 1:] # without [GO] Symbol for pred in preds: total_loss += opt.lambda_attn * attn_loss( pred.view(-1, pred.shape[-1]), target.contiguous().view(-1)) model.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_( model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) optimizer.step() loss_avg.add(total_loss) if loss_avg.val() <= 0.6: opt.grad_clip = 2 if loss_avg.val() <= 0.3: opt.grad_clip = 1 preds = (p.cpu() for p in preds) refiner = refiner.cpu() image = image.cpu() torch.cuda.empty_cache() writer.add_scalar('train_loss', loss_avg.val(), iteration) pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format( iteration, opt.num_iter, loss_avg.val())) # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: elapsed_time = time.time() - start_time # for log with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: model.eval() with torch.no_grad(): valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation( model, attn_loss, valid_loader, attn_converter, opt) model.train() # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' writer.add_scalar('Val_loss', valid_loss, iteration) pbar.set_description(loss_log) loss_avg.reset() current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' # keep best accuracy model (on valid dataset) if current_accuracy > best_accuracy: best_accuracy = current_accuracy torch.save( model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth' ) if current_norm_ED > best_norm_ED: best_norm_ED = current_norm_ED torch.save( model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' # print(loss_model_log) log.write(loss_model_log + '\n') # show some predicted results dashed_line = '-' * 80 head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): if 'Attn' in opt.Prediction or 'Transformer' in opt.Prediction: gt = gt[:gt.find('[s]')] pred = pred[:pred.find('[s]')] predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' predicted_result_log += f'{dashed_line}' log.write(predicted_result_log + '\n') # save model per 1e+3 iter. if (iteration + 1) % 1e+3 == 0: torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/SCATTER_STR.pth') if (iteration + 1) == opt.num_iter: print('end the training') sys.exit()
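# The training loss above combines a CTC term on the refiner branch with one attention
# cross-entropy term per decoder output, weighted by lambda_ctc / lambda_attn. Pulled out
# as a stand-alone sketch (illustrative, not a drop-in; names follow the loop above):
import torch

def combined_loss(preds, refiner, attn_text, ctc_text, ctc_length,
                  ctc_loss, attn_loss, lambda_ctc, lambda_attn):
    batch_size = attn_text.size(0)
    refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
    log_probs = refiner.log_softmax(2).permute(1, 0, 2)  # (T, N, C) as CTCLoss expects
    total = lambda_ctc * ctc_loss(log_probs, ctc_text, refiner_size, ctc_length)
    target = attn_text[:, 1:]  # drop the [GO] symbol
    for pred in preds:  # one cross-entropy term per (selective) decoder
        total += lambda_attn * attn_loss(pred.view(-1, pred.shape[-1]),
                                         target.contiguous().view(-1))
    return total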
def train(opt): lib.print_model_settings(locals().copy()) if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) text_len = opt.batch_max_length+2 else: converter = CTCLabelConverter(opt.character) text_len = opt.batch_max_length opt.classes = converter.character """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images shuffle=True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images shuffle=False, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) print('-' * 80) log.write('-' * 80 + '\n') log.close() text_dataset = text_gen(opt) text_loader = torch.utils.data.DataLoader( text_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers), pin_memory=True, drop_last=True) opt.num_class = len(converter.character) c_code_size = opt.latent cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size) ocrModel = ModelV1(opt) genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier) disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size) accumulate(g_ema, genModel, 0) # uCriterion = torch.nn.MSELoss() # sCriterion = torch.nn.MSELoss() # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': # ocrCriterion = torch.nn.L1Loss() # else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: print('Not implemented error') sys.exit() # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 cEncoder= torch.nn.DataParallel(cEncoder).to(device) cEncoder.train() genModel = torch.nn.DataParallel(genModel).to(device) g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() disEncModel = torch.nn.DataParallel(disEncModel).to(device) disEncModel.train() ocrModel = torch.nn.DataParallel(ocrModel).to(device) if opt.ocrFixed: if opt.Transformation == 'TPS': ocrModel.module.Transformation.eval() ocrModel.module.FeatureExtraction.eval() ocrModel.module.AdaptiveAvgPool.eval() # ocrModel.module.SequenceModeling.eval() ocrModel.module.Prediction.eval() else: ocrModel.train() g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / 
(opt.d_reg_every + 1) optimizer = optim.Adam( list(genModel.parameters())+list(cEncoder.parameters()), lr=opt.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) dis_optimizer = optim.Adam( disEncModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) ocr_optimizer = optim.Adam( ocrModel.parameters(), lr=opt.lr, betas=(0.9, 0.99), ) ## Loading pre-trained files if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) # if opt.saved_gen_model !='' and opt.saved_gen_model !='None': # print(f'loading pretrained gen model from {opt.saved_gen_model}') # checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) # genModel.module.load_state_dict(checkpoint['g']) # g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disEncModel.load_state_dict(checkpoint['disEncModel']) ocrModel.load_state_dict(checkpoint['ocrModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_unsup = Averager() loss_avg_sup = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() loss_avg_ocr_sup = Averager() loss_avg_ocr_unsup = Averager() """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) ocr_scheduler = get_scheduler(ocr_optimizer,opt) start_time = time.time() iteration = start_iter cntr=0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 # loss_dict = {} accum = 0.5 ** (32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 epsilon = 10e-50 # sample_z = 
torch.randn(opt.n_sample, opt.latent, device=device) while(True): # print(cntr) # train part if opt.lr_policy !="None": scheduler.step() dis_scheduler.step() ocr_scheduler.step() image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = iter(text_loader).next() image_input_tensors = image_input_tensors.to(device) gt_image_tensors = image_input_tensors[:opt.batch_size].detach() real_image_tensors = image_input_tensors[opt.batch_size:].detach() labels_gt = labels[:opt.batch_size] requires_grad(cEncoder, False) requires_grad(genModel, False) requires_grad(disEncModel, True) requires_grad(ocrModel, False) text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) z_c_code = cEncoder(text_z_c) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img,_ = genModel(style, input_is_latent=opt.input_latent) # #unsupervised code prediction on generated image # u_pred_code = disEncModel(fake_img, mode='enc') # uCost = uCriterion(u_pred_code, z_code) # #supervised code prediction on gt image # s_pred_code = disEncModel(gt_image_tensors, mode='enc') # sCost = uCriterion(s_pred_code, gt_phoc_tensors) #Domain discriminator fake_pred = disEncModel(fake_img) real_pred = disEncModel(real_image_tensors) disCost = d_logistic_loss(real_pred, fake_pred) # dis_cost = disCost + opt.gamma_e*uCost + opt.beta*sCost loss_avg_dis.add(disCost) # loss_avg_sup.add(opt.beta*sCost) # loss_avg_unsup.add(opt.gamma_e * uCost) disEncModel.zero_grad() disCost.backward() dis_optimizer.step() d_regularize = cntr % opt.d_reg_every == 0 if d_regularize: real_image_tensors.requires_grad = True real_pred = disEncModel(real_image_tensors) r1_loss = d_r1_loss(real_pred, real_image_tensors) disEncModel.zero_grad() (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward() dis_optimizer.step() log_r1_val.add(r1_loss) # Recognizer update if not opt.ocrFixed and not opt.zAlone: requires_grad(disEncModel, False) requires_grad(ocrModel, True) if 'CTC' in opt.Prediction: preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True) preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt) else: print("Not implemented error") sys.exit() ocrModel.zero_grad() ocrCost.backward() # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) ocr_optimizer.step() loss_avg_ocr_sup.add(ocrCost) else: loss_avg_ocr_sup.add(torch.tensor(0.0)) # [Word Generator] update # image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = iter(text_loader).next() # image_input_tensors = image_input_tensors.to(device) # gt_image_tensors = image_input_tensors[:opt.batch_size] # real_image_tensors = image_input_tensors[opt.batch_size:] # labels_gt = labels[:opt.batch_size] requires_grad(cEncoder, True) requires_grad(genModel, True) requires_grad(disEncModel, False) requires_grad(ocrModel, False) text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length) z_c_code 
= cEncoder(text_z_c) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img,_ = genModel(style, input_is_latent=opt.input_latent) fake_pred = disEncModel(fake_img) disGenCost = g_nonsaturating_loss(fake_pred) if opt.zAlone: ocrCost = torch.tensor(0.0) else: #Compute OCR prediction (Reconstruction of content) # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device) # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device) if 'CTC' in opt.Prediction: preds_recon = ocrModel(fake_img, text_z_c, is_train=False) preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size) preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2) ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c) else: print("Not implemented error") sys.exit() genModel.zero_grad() cEncoder.zero_grad() gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0] loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0] loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) if opt.grad_balance: gen_enc_cost.backward(retain_graph=True) grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0] grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0] a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR)) if a is None: print(ocrCost, disGenCost, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) if a>1000 or a<0.0001: print(a) ocrCost = a.detach() * ocrCost gen_enc_cost = disGenCost + ocrCost gen_enc_cost.backward(retain_graph=True) grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=False, retain_graph=True)[0] grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=False, retain_graph=True)[0] loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) with torch.no_grad(): gen_enc_cost.backward() else: gen_enc_cost.backward() loss_avg_gen.add(disGenCost) loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost) optimizer.step() g_regularize = cntr % opt.g_reg_every == 0 if g_regularize: path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink) # image_input_tensors, _, labels, _ = iter(train_loader).next() labels_z_c = iter(text_loader).next() # image_input_tensors = image_input_tensors.to(device) # gt_image_tensors = image_input_tensors[:path_batch_size] # labels_gt = labels[:path_batch_size] text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length) # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length) z_c_code = cEncoder(text_z_c) noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device) style=[] style.append(noise_style[0]*z_c_code) if len(noise_style)>1: style.append(noise_style[1]*z_c_code) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style[0][:,:opt.latent]) if len(style)>1: 
newstyle.append(style[1][:,:opt.latent]) style = newstyle fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length) decay = 0.01 path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length) path_loss = (path_lengths - mean_path_length_orig).pow(2).mean() mean_path_length = mean_path_length_orig.detach().item() genModel.zero_grad() cEncoder.zero_grad() weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss if opt.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() optimizer.step() # mean_path_length_avg = ( # reduce_sum(mean_path_length).item() / get_world_size() # ) #commented above for multi-gpu , non-distributed setting mean_path_length_avg = mean_path_length accumulate(g_ema, genModel, accum) log_avg_path_loss_val.add(path_loss) log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg)) log_ada_aug_p.add(torch.tensor(ada_aug_p)) if get_rank() == 0: if wandb and opt.wandb: wandb.log( { "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_loss.item(), "Path Length Regularization": path_loss.item(), "Mean Path Length": mean_path_length, "Real Score": real_pred.mean().item(), "Fake Score": fake_pred.mean().item(), "Path Length": path_lengths.mean().item(), } ) # log values that exist in this scope; the reduced loss_dict of the original stylegan2 code is not kept here # validation part if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' #generate paired content with similar style labels_z_c_1 = iter(text_loader).next() labels_z_c_2 = iter(text_loader).next() text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length) text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length) z_c_code_1 = cEncoder(text_z_c_1) z_c_code_2 = cEncoder(text_z_c_2) style_c1_s1 = [] style_c2_s1 = [] style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style_c1_s1.append(style_s1[0]*z_c_code_1) style_c2_s1.append(style_s1[0]*z_c_code_2) if len(style_s1)>1: style_c1_s1.append(style_s1[1]*z_c_code_1) style_c2_s1.append(style_s1[1]*z_c_code_2) noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device) style_c1_s2 = [] style_c1_s2.append(noise_style[0]*z_c_code_1) if len(noise_style)>1: style_c1_s2.append(noise_style[1]*z_c_code_1) if opt.zAlone: #to validate orig style gan results newstyle = [] newstyle.append(style_c1_s1[0][:,:opt.latent]) if len(style_c1_s1)>1: newstyle.append(style_c1_s1[1][:,:opt.latent]) style_c1_s1 = newstyle style_c2_s1 = newstyle style_c1_s2 = newstyle fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent) fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent) fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent) if not opt.zAlone: #Run OCR prediction if 'CTC' in opt.Prediction: preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False) preds_size = 
torch.IntTensor([preds.size(1)] * opt.batch_size) _, preds_index = preds.max(2) preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data) preds = ocrModel(gt_image_tensors, text_gt, is_train=False) preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0]) _, preds_index = preds.max(2) preds_str_gt = converter.decode(preds_index.data, preds_size.data) else: print("Not implemented error") sys.exit() else: preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0] preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0] os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True) for trImgCntr in range(opt.batch_size): try: save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png')) if not opt.zAlone: save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png')) save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png')) if trImgCntr<gt_image_tensors.shape[0]: save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png')) except: print('Warning while saving training image') elapsed_time = time.time() - start_time # for log with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log: # training loss and validation loss loss_log = f'[{iteration+1}/{opt.num_iter}] \ Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\ Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \ Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \ Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \ Elapsed_time: {elapsed_time:0.5f}' #plotting lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item()) lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item()) print(loss_log) loss_avg_dis.reset() loss_avg_gen.reset() loss_avg_ocr_unsup.reset() loss_avg_ocr_sup.reset() log_r1_val.reset() log_avg_path_loss_val.reset() log_avg_mean_path_length_avg.reset() log_ada_aug_p.reset() lib.plot.flush() lib.plot.tick() # save model per 1e+4 iter. 
if (iteration) % 1e+4 == 0: torch.save({ 'cEncoder':cEncoder.state_dict(), 'genModel':genModel.state_dict(), 'g_ema':g_ema.state_dict(), 'ocrModel':ocrModel.state_dict(), 'disEncModel':disEncModel.state_dict(), 'optimizer':optimizer.state_dict(), 'ocr_optimizer':ocr_optimizer.state_dict(), 'dis_optimizer':dis_optimizer.state_dict()}, os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth')) if (iteration + 1) == opt.num_iter: print('end the training') sys.exit() iteration += 1 cntr+=1
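# accumulate(g_ema, genModel, accum) above keeps an exponential moving average of the
# generator weights. This repo's helper is not shown; the usual stylegan2-pytorch version
# (assumed to be what is imported here) is:
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
# With accum = 0.5 ** (32 / (10 * 1000)) ~= 0.9978, the EMA halves its memory every
# ~312 steps, i.e. every 10k images at stylegan2's reference batch size of 32.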
def train(opt, tb): """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 opt.select_data = opt.select_data.split('-') opt.batch_ratio = opt.batch_ratio.split('-') train_dataset = Batch_Balanced_Dataset(opt) log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a') AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, shuffle=True, # 'True' to check training progress with validation function. num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() """ model configuration """ if 'CTC' in opt.Prediction: converter = CTCLabelConverter(opt.character) else: converter = AttnLabelConverter(opt.character) opt.num_class = len(converter.character) if opt.rgb: opt.input_channel = 3 model = Model(opt) print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction) # weight initialization for name, param in model.named_parameters(): if 'localization_fc2' in name: print(f'Skip {name} as it is already initialized') continue try: if 'bias' in name: init.constant_(param, 0.0) elif 'weight' in name: init.kaiming_normal_(param) except Exception as e: # for batchnorm. if 'weight' in name: param.data.fill_(1) continue # data parallel for multi-GPU model = torch.nn.DataParallel(model).to(device) model.train() if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') if opt.FT: model.load_state_dict(torch.load(opt.saved_model), strict=False) else: model.load_state_dict(torch.load(opt.saved_model)) print("Model:") print(model) """ setup loss """ if 'CTC' in opt.Prediction: criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 # loss averager loss_avg = Averager() # filter parameters that require gradient descent filtered_parameters = [] params_num = [] print("~~~~~~~~~~~~Gradient Descent~~~~~~~~~~~~~") 
for p in filter(lambda p: p.requires_grad, model.parameters()): filtered_parameters.append(p) params_num.append(np.prod(p.size())) print('Filtered parameters for gradient descent: \n', len(filtered_parameters)) print('Trainable params num : ', sum(params_num)) # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] # setup optimizer if opt.adam: optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) else: optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) print("Optimizer:") print(optimizer) """ final options """ # print(opt) with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_model != '': try: start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass start_time = time.time() best_accuracy = -1 best_norm_ED = -1 i = start_iter while(True): # train part image_tensors, labels = train_dataset.get_batch() image = image_tensors.to(device) text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) batch_size = image.size(0) if 'CTC' in opt.Prediction: preds = model(image, text).log_softmax(2) preds_size = torch.IntTensor([preds.size(1)] * batch_size) preds = preds.permute(1, 0, 2) # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss # https://github.com/jpuigcerver/PyLaia/issues/16 torch.backends.cudnn.enabled = False cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) torch.backends.cudnn.enabled = True # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a). # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0. # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0. 
# # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707 # cost = criterion(preds, text, preds_size, length) else: preds = model(image, text[:, :-1]) # align with Attention.forward target = text[:, 1:] # without [GO] Symbol cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) model.zero_grad() cost.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) optimizer.step() loss_avg.add(cost) # validation part if i % opt.valInterval == 0: elapsed_time = time.time() - start_time # for log with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log: model.eval() with torch.no_grad(): valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation( model, criterion, valid_loader, converter, opt) model.train() # training loss and validation loss loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' tb.add_scalar('Training Loss vs Iteration', loss_avg.val(), i) # Record to Tensorboard tb.add_scalar('Validation Loss vs Iteration', valid_loss, i) # Record to Tensorboard loss_avg.reset() current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' tb.add_scalar('Current Accuracy vs Iteration', current_accuracy, i) # Record to Tensorboard tb.add_scalar('Current Norm ED vs Iteration', current_norm_ED, i) # Record to Tensorboard # keep best accuracy model (on valid dataset) if current_accuracy > best_accuracy: best_accuracy = current_accuracy torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth') if current_norm_ED > best_norm_ED: best_norm_ED = current_norm_ED torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth') best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' print(loss_model_log) log.write(loss_model_log + '\n') # show some predicted results dashed_line = '-' * 80 head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): if 'Attn' in opt.Prediction: gt = gt[:gt.find('[s]')] pred = pred[:pred.find('[s]')] predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' predicted_result_log += f'{dashed_line}' print(predicted_result_log) log.write(predicted_result_log + '\n') # save model per 1e+5 iter. if (i + 1) % 1e+5 == 0: torch.save( model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth') if i == opt.num_iter: print('end the training') sys.exit() i += 1
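# Why the permute(1, 0, 2) and the cudnn toggle in the CTC branch above: torch.nn.CTCLoss
# wants log-probabilities shaped (T, N, C) plus integer lengths, while the model emits
# (N, T, C); per the (ctc_a) comment, cudnn is temporarily disabled to dodge a known
# ctc_loss issue in PyTorch 1.2/1.3. A minimal shape-checked sketch of the same call:
import torch
ctc = torch.nn.CTCLoss(zero_infinity=True)
N, T, C = 4, 26, 37                       # batch, time steps, num_class (blank = index 0)
log_probs = torch.randn(N, T, C).log_softmax(2).permute(1, 0, 2)  # -> (T, N, C)
input_lengths = torch.IntTensor([T] * N)  # every sample uses the full prediction width
targets = torch.randint(1, C, (N, 10))    # dummy labels, never the blank index
target_lengths = torch.IntTensor([10] * N)
loss = ctc(log_probs, targets, input_lengths, target_lengths)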
def train(opt): lib.print_model_settings(locals().copy()) """ dataset preparation """ if not opt.data_filtering_off: print('Filtering the images containing characters which are not in opt.character') print('Filtering the images whose label is longer than opt.batch_max_length') # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size, sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed), num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) log.write(train_dataset_log) print('-' * 80) valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=opt.batch_size, sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed), num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True) log.write(valid_dataset_log) print('-' * 80) log.write('-' * 80 + '\n') log.close() if 'Attn' in opt.Prediction: converter = AttnLabelConverter(opt.character) else: converter = CTCLabelConverter(opt.character) opt.num_class = len(converter.character) # styleModel = StyleTensorEncoder(input_dim=opt.input_channel) # genModel = AdaIN_Tensor_WordGenerator(opt) # disModel = MsImageDisV2(opt) # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none') # mixModel = Mixer(opt,nblk=3, dim=opt.latent) genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device) disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device) g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device) ocrModel = ModelV1(opt).to(device) accumulate(g_ema, genModel, 0) # # weight initialization # for currModel in [styleModel, mixModel]: # for name, param in currModel.named_parameters(): # if 'localization_fc2' in name: # print(f'Skip {name} as it is already initialized') # continue # try: # if 'bias' in name: # init.constant_(param, 0.0) # elif 'weight' in name: # init.kaiming_normal_(param) # except Exception as e: # for batchnorm. 
# if 'weight' in name: # param.data.fill_(1) # continue if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': ocrCriterion = torch.nn.L1Loss() else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 # vggRecCriterion = torch.nn.L1Loss() # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion) print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length) if opt.distributed: genModel = torch.nn.parallel.DistributedDataParallel( genModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, ) disModel = torch.nn.parallel.DistributedDataParallel( disModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, ) ocrModel = torch.nn.parallel.DistributedDataParallel( ocrModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False ) # styleModel = torch.nn.DataParallel(styleModel).to(device) # styleModel.train() # mixModel = torch.nn.DataParallel(mixModel).to(device) # mixModel.train() # genModel = torch.nn.DataParallel(genModel).to(device) # g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() # disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() # vggModel = torch.nn.DataParallel(vggModel).to(device) # vggModel.eval() # ocrModel = torch.nn.DataParallel(ocrModel).to(device) # if opt.distributed: # ocrModel.module.Transformation.eval() # ocrModel.module.FeatureExtraction.eval() # ocrModel.module.AdaptiveAvgPool.eval() # # ocrModel.module.SequenceModeling.eval() # ocrModel.module.Prediction.eval() # else: # ocrModel.Transformation.eval() # ocrModel.FeatureExtraction.eval() # ocrModel.AdaptiveAvgPool.eval() # # ocrModel.SequenceModeling.eval() # ocrModel.Prediction.eval() ocrModel.eval() if opt.distributed: g_module = genModel.module d_module = disModel.module else: g_module = genModel d_module = disModel g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1) optimizer = optim.Adam( genModel.parameters(), lr=opt.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) dis_optimizer = optim.Adam( disModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) ## Loading pre-trained files if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': if not opt.distributed: ocrModel = torch.nn.DataParallel(ocrModel) print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) #temporary fix if not opt.distributed: ocrModel = ocrModel.module if opt.saved_gen_model !='' and opt.saved_gen_model !='None': print(f'loading pretrained gen model from {opt.saved_gen_model}') checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) genModel.module.load_state_dict(checkpoint['g']) g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # 
styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disModel.load_state_dict(checkpoint['disModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg = Averager() loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_vgg_per = Averager() loss_avg_vgg_sty = Averager() loss_avg_ocr = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) start_time = time.time() iteration = start_iter cntr=0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} accum = 0.5 ** (32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 sample_z = torch.randn(opt.n_sample, opt.latent, device=device) while(True): # print(cntr) # train part if opt.lr_policy !="None": scheduler.step() dis_scheduler.step() image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next() image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) requires_grad(genModel, False) # requires_grad(styleModel, False) # requires_grad(mixModel, False) requires_grad(disModel, True) text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator # style = styleModel(image_input_tensors).squeeze(2).squeeze(2) style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device) # scInput = mixModel(style,text_2) if 'CTC' in opt.Prediction: images_recon_2,_ = genModel(style, text_2, input_is_latent=opt.input_latent) else: images_recon_2,_ = genModel(style, text_2[:,1:-1], input_is_latent=opt.input_latent) #Domain discriminator: Dis update if opt.augment: image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p) images_recon_2, _ = augment(images_recon_2, ada_aug_p) else: image_gt_tensors_aug = image_gt_tensors fake_pred = disModel(images_recon_2) real_pred = disModel(image_gt_tensors_aug) disCost = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = disCost*opt.disWeight loss_dict["real_score"] = 
while True:
    # print(cntr)
    # train part
    if opt.lr_policy != "None":
        scheduler.step()
        dis_scheduler.step()

    # a fresh iterator is created every step, so distinct batches rely on
    # the loader's shuffling
    image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

    image_input_tensors = image_input_tensors.to(device)
    image_gt_tensors = image_gt_tensors.to(device)
    batch_size = image_input_tensors.size(0)

    requires_grad(genModel, False)
    # requires_grad(styleModel, False)
    # requires_grad(mixModel, False)
    requires_grad(disModel, True)

    text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
    text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

    # forward pass from style and word generator
    # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
    # use the actual batch size rather than opt.batch_size, in case the
    # loader returns a short batch
    style = mixing_noise(batch_size, opt.latent, opt.mixing, device)
    # scInput = mixModel(style, text_2)

    if 'CTC' in opt.Prediction:
        images_recon_2, _ = genModel(style, text_2, input_is_latent=opt.input_latent)
    else:
        # strip the [GO] and [s] tokens that AttnLabelConverter adds
        images_recon_2, _ = genModel(style, text_2[:, 1:-1], input_is_latent=opt.input_latent)

    # Domain discriminator: Dis update
    if opt.augment:
        image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p)
        images_recon_2, _ = augment(images_recon_2, ada_aug_p)
    else:
        image_gt_tensors_aug = image_gt_tensors

    fake_pred = disModel(images_recon_2)
    real_pred = disModel(image_gt_tensors_aug)
    disCost = d_logistic_loss(real_pred, fake_pred)

    loss_dict["d"] = disCost * opt.disWeight
    loss_dict["real_score"] = real_pred.mean()
    loss_dict["fake_score"] = fake_pred.mean()

    loss_avg_dis.add(disCost)

    disModel.zero_grad()
    disCost.backward()
    dis_optimizer.step()

    # ADA: track r_t = E[sign(D(real))] and nudge the augmentation
    # probability toward opt.ada_target once enough predictions are seen
    if opt.augment and opt.augment_p == 0:
        ada_augment += torch.tensor(
            (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device
        )
        ada_augment = reduce_sum(ada_augment)

        if ada_augment[1] > 255:
            pred_signs, n_pred = ada_augment.tolist()
            r_t_stat = pred_signs / n_pred

            if r_t_stat > opt.ada_target:
                sign = 1
            else:
                sign = -1

            ada_aug_p += sign * ada_aug_step * n_pred
            ada_aug_p = min(1, max(0, ada_aug_p))
            ada_augment.mul_(0)

    # lazy R1 regularization on real images
    d_regularize = cntr % opt.d_reg_every == 0
    if d_regularize:
        image_gt_tensors.requires_grad = True
        image_input_tensors.requires_grad = True

        cat_tensor = image_gt_tensors
        real_pred = disModel(cat_tensor)
        r1_loss = d_r1_loss(real_pred, cat_tensor)

        disModel.zero_grad()
        # the 0 * real_pred[0] term keeps the output in the graph so DDP
        # does not complain about unused parameters
        (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()
        dis_optimizer.step()

    loss_dict["r1"] = r1_loss
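    # Generator step: a fresh batch is drawn, the discriminator is frozen,
    # and the generator is optimized on the non-saturating GAN loss plus the
    # recognizer content loss on the generated image. mixing_noise supplies
    # either one latent or, for style-mixing regularization, two; roughly:
    #   [randn(B, latent)]                    with prob 1 - opt.mixing
    #   [randn(B, latent), randn(B, latent)]  with prob opt.mixing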
    # [Style Encoder] + [Word Generator] update
    image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

    image_input_tensors = image_input_tensors.to(device)
    image_gt_tensors = image_gt_tensors.to(device)
    batch_size = image_input_tensors.size(0)

    requires_grad(genModel, True)
    # requires_grad(styleModel, True)
    # requires_grad(mixModel, True)
    requires_grad(disModel, False)

    text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
    text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

    # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
    # scInput = mixModel(style, text_2)
    # images_recon_2, _ = genModel([scInput], input_is_latent=opt.input_latent)
    style = mixing_noise(batch_size, opt.latent, opt.mixing, device)

    if 'CTC' in opt.Prediction:
        images_recon_2, _ = genModel(style, text_2)
    else:
        images_recon_2, _ = genModel(style, text_2[:, 1:-1])

    if opt.augment:
        images_recon_2, _ = augment(images_recon_2, ada_aug_p)

    fake_pred = disModel(images_recon_2)
    disGenCost = g_nonsaturating_loss(fake_pred)
    loss_dict["g"] = disGenCost

    # # Adversarial loss
    # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2, image_input_tensors), dim=1))
    # # Input reconstruction loss
    # recCost = recCriterion(images_recon_2, image_gt_tensors)
    # # vgg loss
    # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

    # ocr loss
    text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
    length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        # match recognizer features (visual or sequential) instead of text
        preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
        preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
        ocrCost = ocrCriterion(preds_recon, preds_gt)
    elif 'CTC' in opt.Prediction:
        preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
        preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
        # CTCLoss expects (T, N, C) log-probabilities
        preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
        ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)

        # predict ocr recognition on generated images
        _, preds_recon_index = preds_recon.max(2)
        labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

        # predict ocr recognition on gt style images
        preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
        preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
        _, preds_s_index = preds_s.max(2)
        labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

        # predict ocr recognition on gt stylecontent images
        preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
        preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
        _, preds_sc_index = preds_sc.max(2)
        labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)
    else:
        preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
        target_2 = text_2[:, 1:]  # without [GO] symbol
        ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

        # predict ocr recognition on generated images
        _, preds_o_index = preds_recon.max(2)
        labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
        for idx, pred in enumerate(labels_o_ocr):
            pred_EOS = pred.find('[s]')
            labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        # predict ocr recognition on gt style images
        preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
        _, preds_s_index = preds_s.max(2)
        labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
        for idx, pred in enumerate(labels_s_ocr):
            pred_EOS = pred.find('[s]')
            labels_s_ocr[idx] = pred[:pred_EOS]

        # predict ocr recognition on gt stylecontent images
        preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
        _, preds_sc_index = preds_sc.max(2)
        labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
        for idx, pred in enumerate(labels_sc_ocr):
            pred_EOS = pred.find('[s]')
            labels_sc_ocr[idx] = pred[:pred_EOS]

    # cost = opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost
    cost = opt.disWeight * disGenCost + opt.ocrWeight * ocrCost

    genModel.zero_grad()
    disModel.zero_grad()
    ocrModel.zero_grad()
    cost.backward()
    optimizer.step()
    loss_avg.add(cost)

    # lazy path-length regularization on a (possibly shrunk) batch
    g_regularize = cntr % opt.g_reg_every == 0
    if g_regularize:
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        path_batch_size = max(1, batch_size // opt.path_batch_shrink)

        style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)

        if 'CTC' in opt.Prediction:
            images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True)
        else:
            images_recon_2, latents = genModel(style, text_2[:path_batch_size, 1:-1], return_latents=True)

        path_loss, mean_path_length, path_lengths = g_path_regularize(
            images_recon_2, latents, mean_path_length
        )

        genModel.zero_grad()
        weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

        if opt.path_batch_shrink:
            # dummy term: ties the output to the graph without changing the gradient
            weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

        weighted_path_loss.backward()
        optimizer.step()

        mean_path_length_avg = (
            reduce_sum(mean_path_length).item() / get_world_size()
        )

    loss_dict["path"] = path_loss
    loss_dict["path_length"] = path_lengths.mean()
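    # g_ema is an exponential moving average of the generator weights used
    # for sampling/validation. accum = 0.5 ** (32 / (10 * 1000)) gives the
    # average a half-life of 10k images at batch size 32; accumulate()
    # effectively does: p_ema = accum * p_ema + (1 - accum) * p.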
    accumulate(g_ema, g_module, accum)

    loss_reduced = reduce_loss_dict(loss_dict)

    d_loss_val = loss_reduced["d"].mean().item()
    g_loss_val = loss_reduced["g"].mean().item()
    r1_val = loss_reduced["r1"].mean().item()
    path_loss_val = loss_reduced["path"].mean().item()
    real_score_val = loss_reduced["real_score"].mean().item()
    fake_score_val = loss_reduced["fake_score"].mean().item()
    path_length_val = loss_reduced["path_length"].mean().item()

    # individual losses
    loss_avg_gen.add(opt.disWeight * disGenCost)
    loss_avg_imgRecon.add(torch.tensor(0.0))
    loss_avg_vgg_per.add(torch.tensor(0.0))
    loss_avg_vgg_sty.add(torch.tensor(0.0))
    loss_avg_ocr.add(opt.ocrWeight * ocrCost)

    log_r1_val.add(loss_reduced["r1"])
    log_avg_path_loss_val.add(loss_reduced["path"])
    log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
    log_ada_aug_p.add(torch.tensor(ada_aug_p))

    if get_rank() == 0:
        # pbar.set_description(
        #     (
        #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
        #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
        #         f"augment: {ada_aug_p:.4f}"
        #     )
        # )

        if wandb and opt.wandb:
            wandb.log(
                {
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                }
            )

        # if cntr % 100 == 0:
        #     with torch.no_grad():
        #         g_ema.eval()
        #         sample, _ = g_ema([scInput[:, :opt.latent], scInput[:, opt.latent:]])
        #         utils.save_image(
        #             sample,
        #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
        #             nrow=int(opt.n_sample ** 0.5),
        #             normalize=True,
        #             range=(-1, 1),
        #         )

    # validation part
    # To see training progress, validation also runs at iteration == 0.
    if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
        # save training images
        curr_batch_size = style[0].shape[0]
        images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent)

        os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)
        for trImgCntr in range(batch_size):
            try:
                if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                    save_image(tensor2im(image_input_tensors[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration),
                                            str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '.png'))
                    save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration),
                                            str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration),
                                            str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '.png'))
                else:
                    # file names also carry the recognizer's prediction
                    save_image(tensor2im(image_input_tensors[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration),
                                            str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '_' + labels_s_ocr[trImgCntr] + '.png'))
                    save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration),
                                            str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '_' + labels_sc_ocr[trImgCntr] + '.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),
                               os.path.join(opt.trainDir, str(iteration),
                                            str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '_' + labels_o_ocr[trImgCntr] + '.png'))
            except Exception:
                # the path-regularization batch may be smaller than batch_size
                print('Warning while saving training image')

        elapsed_time = time.time() - start_time
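        # Logging: models are switched to eval mode, validation_synth_v6
        # reports [synth, dis, gen, imgRecon, vggPer, vggSty, OCR] losses,
        # the running train/valid losses are appended to log_train.txt and
        # plotted, and the averagers are reset for the next interval.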
        with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
            # styleModel.eval()
            genModel.eval()
            g_ema.eval()
            # mixModel.eval()
            disModel.eval()

            with torch.no_grad():
                valid_loss, infer_time, length_of_data = validation_synth_v6(
                    iteration, g_ema, ocrModel, disModel, ocrCriterion,
                    valid_loader, converter, opt)

            # styleModel.train()
            genModel.train()
            # mixModel.train()
            disModel.train()

            # training loss and validation loss
            loss_log = (
                f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, '
                f'Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f}, '
                f'Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                f'Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, '
                f'Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, '
                f'Train ada-aug-p: {log_ada_aug_p.val():0.5f}, '
                f'Valid Synth loss: {valid_loss[0]:0.5f}, '
                f'Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, '
                f'Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
            )
            log.write(loss_log + '\n')

            # plotting
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
            # lib.plot.plot(os.path.join(opt.plotDir, 'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
            # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
            # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Synth-Loss'), valid_loss[0].item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Dis-Loss'), valid_loss[1].item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Gen-Loss'), valid_loss[2].item())
            # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-ImgRecon1-Loss'), valid_loss[3].item())
            # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Per-Loss'), valid_loss[4].item())
            # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Sty-Loss'), valid_loss[5].item())
            lib.plot.plot(os.path.join(opt.plotDir, 'Valid-OCR-Loss'), valid_loss[6].item())

            print(loss_log)

            loss_avg.reset()
            loss_avg_dis.reset()
            loss_avg_gen.reset()
            loss_avg_imgRecon.reset()
            loss_avg_vgg_per.reset()
            loss_avg_vgg_sty.reset()
            loss_avg_ocr.reset()

            log_r1_val.reset()
            log_avg_path_loss_val.reset()
            log_avg_mean_path_length_avg.reset()
            log_ada_aug_p.reset()

        lib.plot.flush()
        lib.plot.tick()

    # save model every 1e+4 iterations
    if iteration % 10000 == 0:
        torch.save({
            # 'styleModel': styleModel.state_dict(),
            # 'mixModel': mixModel.state_dict(),
            'genModel': g_module.state_dict(),
            'g_ema': g_ema.state_dict(),
            'disModel': d_module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict()},
            os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth'))

    if (iteration + 1) == opt.num_iter:
        print('end the training')
        sys.exit()

    iteration += 1
    cntr += 1