def main():
    # Set parameters of the trainer.
    global args, device
    args = parse_args()
    print('=' * 60)
    print(args)
    print('=' * 60)

    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        #cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Load the alphabet from file (first character of each line).
    if os.path.isfile(args.alphabet):
        alphabet = ''
        with open(args.alphabet, mode='rb') as f:
            for line in f.readlines():
                alphabet += line.decode('utf-8')[0]
        args.alphabet = alphabet
    converter = utils.CTCLabelConverter(args.alphabet, ignore_case=False)

    # Data loaders.
    image_size = (args.image_h, args.image_w)
    collater = DatasetCollater(image_size, keep_ratio=args.keep_ratio)
    train_dataset = Dataset(mode='train', data_root=args.data_root, transform=None)
    #sampler = RandomSequentialSampler(train_dataset, args.batch_size)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=collater,
                                   shuffle=True,
                                   num_workers=args.workers)
    val_dataset = Dataset(mode='val', data_root=args.data_root, transform=None)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collater,
                                 shuffle=True,
                                 num_workers=args.workers)

    # Network: len(alphabet) output classes plus one for the CTC blank.
    num_classes = len(args.alphabet) + 1
    num_channels = 1
    if args.arch == 'crnn':
        model = CRNN(args.image_h, num_channels, num_classes, args.num_hidden)
    elif args.arch == 'densenet':
        model = DenseNet(num_channels=num_channels,
                         num_classes=num_classes,
                         growth_rate=12,
                         block_config=(3, 6, 9),  #(3, 6, 12, 16)
                         compression=0.5,
                         num_init_features=64,
                         bn_size=4,
                         rnn=args.rnn,
                         num_hidden=args.num_hidden,
                         drop_rate=0,
                         small_inputs=True,
                         efficient=False)
    else:
        raise ValueError('unknown architecture {}'.format(args.arch))
    model = model.to(device)
    summary(model, torch.zeros((2, 1, 32, 650)).to(device))
    #print('=' * 60)
    #print(model)
    #print('=' * 60)

    # Loss.
    criterion = CTCLoss()
    criterion = criterion.to(device)

    # Set up the optimizer.
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999),
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), weight_decay=args.weight_decay)
    else:
        raise ValueError('unknown optimizer {}'.format(args.optimizer))
    print('=' * 60)
    print(optimizer)
    print('=' * 60)

    # Define the learning-rate decay schedule.
    global scheduler
    #exp_decay = math.exp(-0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate)
    #step_size = 10000
    #gamma_decay = 0.8
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma_decay)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma_decay)

    # Initialize the model from pre-trained weights if given.
    if args.pretrained and os.path.isfile(args.pretrained):
        print(">> Using pre-trained model '{}'".format(os.path.basename(args.pretrained)))
        state_dict = torch.load(args.pretrained)
        model.load_state_dict(state_dict)
        print("loading pretrained model done.")

    global is_best, best_accuracy
    is_best = False
    best_accuracy = 0.0
    start_epoch = 0

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            # Load checkpoint weights and update the model and optimizer.
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_accuracy = checkpoint['best_accuracy']
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(args.resume, start_epoch))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            # Important: do not forget to update the scheduler as well.
            #scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate, last_epoch=start_epoch - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Create the export dir if it doesn't exist.
    checkpoint = "{}".format(args.arch)
    checkpoint += "_{}".format(args.optimizer)
    checkpoint += "_lr_{}".format(args.lr)
    checkpoint += "_decay_rate_{}".format(args.decay_rate)
    checkpoint += "_bsize_{}".format(args.batch_size)
    checkpoint += "_height_{}".format(args.image_h)
    checkpoint += "_keep_ratio" if args.keep_ratio else "_width_{}".format(image_size[1])
    args.checkpoint = os.path.join(args.checkpoint, checkpoint)
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print('start training...')
    for epoch in range(start_epoch, args.max_epoch):
        # Adjust the learning rate for each epoch.
        scheduler.step()

        # Train for one epoch on the train set.
        _ = train(train_loader, val_loader, model, criterion, optimizer, epoch, converter)
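

# The train() called in the epoch loop above is defined elsewhere in this
# repo. A minimal sketch of what one CTC training epoch over such a loader
# might look like, assuming the collater yields (images, texts) batches and
# that converter.encode returns (targets, target_lengths); both are
# assumptions here, not this repo's confirmed interfaces:
def _train_epoch_sketch(train_loader, model, criterion, optimizer, device, converter):
    model.train()
    running_loss = 0.0
    for images, texts in train_loader:
        images = images.to(device)
        targets, target_lengths = converter.encode(texts)
        optimizer.zero_grad()
        preds = model(images)  # expected shape (T, N, C) for CTC
        preds_lengths = torch.LongTensor([preds.size(0)] * images.size(0))
        loss = criterion(preds, targets, preds_lengths, target_lengths)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / max(len(train_loader), 1)
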
# A second CRNN training script variant (note that this redefines main()).
def main():
    config = Config()
    if not os.path.exists(config.expr_dir):
        os.makedirs(config.expr_dir)

    if torch.cuda.is_available() and not config.use_cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

    # Load the training dataset.
    train_dataset = HubDataset(config, "train", transform=None)
    train_kwargs = ({'num_workers': 2, 'pin_memory': True,
                     'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)}
                    if torch.cuda.is_available() else {})
    training_data_batch = DataLoader(train_dataset, batch_size=config.train_batch_size,
                                     shuffle=True, drop_last=False, **train_kwargs)

    # Load the fixed-length validation dataset.
    eval_dataset = HubDataset(config, "eval", transform=transforms.Compose(
        [ResizeNormalize(config.img_height, config.img_width)]))
    eval_kwargs = {'num_workers': 2, 'pin_memory': False} if torch.cuda.is_available() else {}
    eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size,
                                 shuffle=False, drop_last=False, **eval_kwargs)

    # Load the variable-length validation dataset.
    # eval_dataset = HubDataset(config, "eval")
    # eval_kwargs = {'num_workers': 2, 'pin_memory': False,
    #                'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    # eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Define the network model; one extra output class for the CTC blank.
    nclass = len(config.label_classes) + 1
    crnn = CRNN(config.img_height, config.nc, nclass, config.hidden_size, n_rnn=config.n_layers)

    # Load a pre-trained model if one is given.
    if config.pretrained != '':
        print('loading pretrained model from %s' % config.pretrained)
        crnn.load_state_dict(torch.load(config.pretrained))
    print(crnn)

    # Running average over `torch.Variable` / `torch.Tensor` losses.
    loss_avg = utils.averager()

    # Converts between strings and label-index sequences.
    converter = utils.strLabelConverter(config.label_classes)
    criterion = CTCLoss()  # the loss function

    # Placeholder tensors; utils.loadData resizes them to each batch.
    image = torch.FloatTensor(config.train_batch_size, 3, config.img_height, config.img_height)
    text = torch.LongTensor(config.train_batch_size * 5)
    length = torch.LongTensor(config.train_batch_size)

    if config.use_cuda and torch.cuda.is_available():
        criterion = criterion.cuda()
        image = image.cuda()
    crnn = crnn.to(config.device)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Set up the optimizer.
    if config.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    elif config.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)

    def val(net, criterion, eval_data_batch):
        print('Start val')
        for p in crnn.parameters():
            p.requires_grad = False
        net.eval()

        n_correct = 0
        loss_avg_eval = utils.averager()
        for data in eval_data_batch:
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)

            preds = crnn(image)
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg_eval.add(cost)  # accumulate the validation loss

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            cpu_texts_decode = list(cpu_texts)
            for pred, target in zip(sim_preds, cpu_texts_decode):  # count exact matches
                if pred == target:
                    n_correct += 1

        raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.n_val_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(len(eval_dataset))
        print('Val loss: %f, accuracy: %f' % (loss_avg_eval.val(), accuracy))

    # Train on a single batch of data.
    def train(net, criterion, optimizer, data):
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)  # actual size of this batch
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)  # convert strings to label indices
        utils.loadData(text, t)
        utils.loadData(length, l)

        optimizer.zero_grad()  # clear accumulated gradients
        preds = net(image)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        cost.backward()
        optimizer.step()
        return cost

    for epoch in range(config.nepoch):
        i = 0
        for batch_data in training_data_batch:
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = train(crnn, criterion, optimizer, batch_data)
            loss_avg.add(cost)
            i += 1

            if i % config.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, config.nepoch, i, len(training_data_batch), loss_avg.val()))
                loss_avg.reset()

            # if i % config.valInterval == 0:
            #     val(crnn, criterion, eval_data_batch)
            #
            # # do checkpointing
            # if i % config.saveInterval == 0:
            #     torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(config.expr_dir, epoch, i))

        val(crnn, criterion, eval_data_batch)
        torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_end.pth'.format(config.expr_dir, epoch))
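

# For reference, converter.decode above applies CTC greedy decoding:
# collapse consecutive repeats, then drop blanks. A minimal sketch, assuming
# index 0 is the blank and that class i maps to alphabet[i - 1] (assumptions
# about the converter's internals, not its documented API):
def _greedy_ctc_decode_sketch(indices, alphabet):
    chars = []
    prev = 0
    for idx in indices:
        if idx != 0 and idx != prev:  # skip blanks and consecutive repeats
            chars.append(alphabet[idx - 1])
        prev = idx
    return ''.join(chars)


# Entry-point guard (assumed; the snippet above does not include one):
if __name__ == '__main__':
    main()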