import os
import shutil
import time
import traceback

import torch


def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """
    Train loop for text recognition.

    :param net: network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param loss_func: loss function
    :param train_loader: dataloader for the training set
    :param eval_loader: dataloader for the validation set
    :param to_use_device: device
    :param cfg: config used for this run
    :param global_state: global training state, e.g. cur_epoch, cur_iter and info about the best model
    :param logger: logger object
    :return: None
    """
    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    logger.info('Training...')
    # parameters for printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {len(train_loader.dataset)} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {len(eval_loader.dataset)} samples, {len(eval_loader)} batches in dataloader')
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # zero the gradients, then forward/backward
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) % train_options['val_interval'] == 0:
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # val
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['models'] = net_save_path
                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'),
                        net, optimizer, logger, cfg, global_state=global_state)
    except Exception:  # log any other failure instead of swallowing it with a bare except
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
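# The loop above assumes a save_checkpoint helper that bundles the model,
# optimizer and global training state into one file. A minimal sketch of such
# a helper follows; the payload keys ('state_dict', 'optimizer',
# 'global_state') are assumptions for illustration, not necessarily the
# project's actual serialization format. The older variants further below use
# a (path, net, optimizer, epoch, logger, cfg) signature instead.
def save_checkpoint(checkpoint_path, net, optimizer, logger, cfg, global_state=None):
    """Serialize model/optimizer state (plus optional global_state) to disk."""
    payload = {
        'state_dict': net.state_dict(),       # model weights
        'optimizer': optimizer.state_dict(),  # optimizer buffers (momentum etc.)
        'cfg': cfg,                           # config used for this run
    }
    if global_state is not None:
        payload['global_state'] = global_state  # cur epoch/step, best-model info
    torch.save(payload, checkpoint_path)
    logger.info(f'models saved to {checkpoint_path}')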
def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger, post_process):
    """
    Train loop for text detection.

    :param net: network
    :param optimizer: optimizer
    :param loss_func: loss function
    :param train_loader: dataloader for the training set
    :param eval_loader: dataloader for the validation set
    :param to_use_device: device
    :param cfg: config used for this run
    :param global_state: global training state, e.g. cur_epoch, cur_iter and info about the best model
    :param logger: logger object
    :param post_process: post-processing object
    :return: None
    """
    train_options = cfg.train_options
    metric = DetMetric()
    logger.info('Training...')
    # parameters for printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {len(train_loader.dataset)} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {len(eval_loader.dataset)} samples, {len(eval_loader)} batches in dataloader')
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'recall': 0, 'precision': 0, 'hmean': 0, 'best_model_epoch': 0}
        start_epoch = 0
        global_step = 0
    # start training
    base_lr = cfg['optimizer']['lr']
    all_iters = all_step * train_options['epochs']
    warmup_iters = 3 * all_step
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            train_loss = 0.
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = adjust_learning_rate(optimizer, base_lr, global_step, all_iters, 0.9,
                                                  warmup_iters=warmup_iters)
                # zero the gradients, then forward/backward
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                labels = batch_data['score_maps'].to(to_use_device)
                training_mask = batch_data['training_mask'].to(to_use_device)
                loss_c, loss_s, loss = loss_func(output, labels, training_mask)
                loss.backward()
                optimizer.step()
                # statistic loss for print
                train_loss += loss.item()
                loss_str = 'loss: {:.4f}'.format(loss.item())
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"{loss_str} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                global_step += 1
            logger.info(f'train_loss: {train_loss / len(train_loader)}')
            if (epoch + 1) % train_options['val_interval'] == 0:
                global_state['start_epoch'] = epoch
                global_state['best_model'] = best_model
                global_state['global_step'] = global_step
                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                if train_options['ckpt_save_type'] == 'HighestAcc':
                    # val
                    eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
                    if eval_dict['hmean'] > best_model['hmean']:
                        best_model.update(eval_dict)
                        best_model['best_model_epoch'] = epoch
                        best_model['models'] = net_save_path
                        global_state['start_epoch'] = epoch
                        global_state['best_model'] = best_model
                        global_state['global_step'] = global_step
                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                        save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                    shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                best_str = 'current best, '
                for k, v in best_model.items():
                    best_str += '{}: {}, '.format(k, v)
                logger.info(best_str)
    except KeyboardInterrupt:
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'),
                        net, optimizer, logger, cfg, global_state=global_state)
    except Exception:
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
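# Both detection loops rely on an adjust_learning_rate helper, called as
# adjust_learning_rate(optimizer, base_lr, step, max_steps, power,
# warmup_iters=0). A minimal sketch under the usual "poly" schedule with
# linear warmup follows; the exact decay formula is an assumption based on
# the common (1 - step/max_steps)**power convention, not a confirmed copy of
# the project's helper. The epoch-based call in the older variant below works
# with the same signature (step=epoch, max_steps=total epochs, no warmup).
def adjust_learning_rate(optimizer, base_lr, step, max_steps, power, warmup_iters=0):
    """Linear warmup followed by polynomial ("poly") decay; returns the new lr."""
    if warmup_iters > 0 and step < warmup_iters:
        # ramp linearly from 0 to base_lr over the warmup iterations
        lr = base_lr * (step + 1) / warmup_iters
    else:
        # poly decay: lr = base_lr * (1 - progress)^power
        lr = base_lr * (1 - step / max_steps) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr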
logger.info(f"[{epoch}/{train_options['epochs']}] - " f"[{i + 1}/{all_step}] - " f"lr:{current_lr} - " f"loss:{loss_dict['loss'].item():.4f} - " f"acc:{acc:.4f} - " f"norm_edit_dis:{norm_edit_dis:.4f} - " f"time:{interval_batch_time:.4f}") start = time.time() if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0: global_state['start_epoch'] = epoch global_state['best_model'] = best_model global_state['global_step'] = global_step net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth" save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state) if train_options['ckpt_save_type'] == 'HighestAcc': # val eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric) if eval_dict['eval_acc'] > best_model['eval_acc']: best_model.update(eval_dict) best_model['best_model_epoch'] = epoch best_model['models'] = net_save_path global_state['start_epoch'] = epoch global_state['best_model'] = best_model global_state['global_step'] = global_step net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth" save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state) elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0: shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth')) global_step += 1 scheduler.step()
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, _epoch, logger):
    """
    Train loop for text recognition (epoch-resume variant).

    :param net: network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param loss_func: loss function
    :param train_loader: dataloader for the training set
    :param eval_loader: dataloader for the validation set
    :param to_use_device: device
    :param cfg: config used for this run
    :param _epoch: epoch to resume training from
    :param logger: logger object
    """
    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    with open(cfg.dataset.alphabet, 'r', encoding='utf-8') as file:
        cfg.dataset.alphabet = ''.join([s.strip('\n') for s in file.readlines()])
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    logger.info('Training...')
    # parameters for printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {len(train_loader.dataset)} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {len(eval_loader.dataset)} samples, {len(eval_loader)} batches in dataloader')
    best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
    # start training
    try:
        start = time.time()
        for epoch in range(_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) % train_options['val_interval'] == 0:
                    # val
                    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['models'] = net_save_path
                            shutil.copy(net_save_path, net_save_path.replace('latest', 'best'))
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        net_save_path = f"{train_options['checkpoint_save_dir']}/{epoch}.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
            scheduler.step()
    except KeyboardInterrupt:
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], f'final_{epoch}.pth'),
                        net, optimizer, epoch, logger, cfg)
    except Exception:
        error_msg = traceback.format_exc(limit=1)
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
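# The recognition loops lean on CTCLabelConverter for the text <-> index
# mapping that CTC loss expects, with index 0 reserved for the blank symbol.
# A minimal sketch of the encode side is shown below to illustrate the
# contract; the real torchocr class also implements decode and handles
# characters outside the alphabet, so treat this as an assumption-laden
# illustration, not the actual implementation.
class CTCLabelConverterSketch:
    def __init__(self, alphabet):
        # index 0 is the CTC blank, so real characters start at 1
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, texts):
        """Turn a batch of strings into a flat index tensor plus per-sample lengths."""
        lengths = [len(t) for t in texts]
        indices = [self.dict[c] for t in texts for c in t]
        return torch.tensor(indices, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)

# Usage mirrors the loops above: targets, targets_lengths = converter.encode(batch_data['label'])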
def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, _epoch, logger, post_process):
    """
    Train loop for text detection (epoch-resume variant).

    :param net: network
    :param optimizer: optimizer
    :param loss_func: loss function
    :param train_loader: dataloader for the training set
    :param eval_loader: dataloader for the validation set
    :param to_use_device: device
    :param cfg: config used for this run
    :param _epoch: epoch to resume training from
    :param logger: logger object
    :param post_process: post-processing object
    :return: None
    """
    train_options = cfg.train_options
    metric = DetMetric()
    logger.info('Training...')
    # parameters for printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {len(train_loader.dataset)} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {len(eval_loader.dataset)} samples, {len(eval_loader)} batches in dataloader')
    best_model = {'recall': 0, 'precision': 0, 'hmean': 0, 'best_model_epoch': 0}
    # start training
    base_lr = optimizer.param_groups[0]['lr']
    try:
        for epoch in range(_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            train_loss = 0.
            start = time.time()
            epoch_start_time = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = adjust_learning_rate(optimizer, base_lr, epoch, train_options['epochs'], 0.9)
                # move tensors in the batch to the device
                for key, value in batch_data.items():
                    if value is not None and isinstance(value, torch.Tensor):
                        batch_data[key] = value.to(to_use_device)
                # zero the gradients, then forward/backward
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                optimizer.step()
                # statistic loss for print
                train_loss += loss_dict['loss'].item()
                loss_str = 'loss: {:.4f} - '.format(loss_dict.pop('loss').item())
                for idx, (key, value) in enumerate(loss_dict.items()):
                    loss_dict[key] = value.item()
                    loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                    if idx < len(loss_dict) - 1:
                        loss_str += ' - '
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr:.7f} - "
                                f"{loss_str} - "
                                f"time:{interval_batch_time:.4f}s")
                    start = time.time()
            epoch_time = time.time() - epoch_start_time
            logger.info(f'epoch_time: {epoch_time:.4f}s')
            avg_loss = train_loss / len(train_loader)
            logger.info(f'train_loss: {avg_loss:.4f}')
            # periodically save the model
            if train_options['ckpt_save_type'] == 'EpochStep' and (epoch + 1) % train_options['ckpt_save_epoch'] == 0:
                net_save_path = f"{train_options['checkpoint_save_dir']}/{epoch}.pth"
                save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
            # save the best model
            if (epoch + 1) % train_options['val_interval'] == 0:
                # val
                eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
                if eval_dict['hmean'] > best_model['hmean']:
                    best_model.update(eval_dict)
                    best_model['models'] = net_save_path
                    shutil.copy(net_save_path, net_save_path.replace('latest', 'best'))
                best_str = 'current best, '
                for k, v in best_model.items():
                    best_str += '{}: {}, '.format(k, v)
                logger.info(best_str)
    except KeyboardInterrupt:
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'),
                        net, optimizer, epoch, logger, cfg)
    except BaseException:
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
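# For reference, a hypothetical driver wiring the detection variant above
# together. The build_model / build_loss / build_post_process helpers and the
# cfg attributes they consume are assumptions standing in for whatever the
# project actually uses to construct these objects, so this stays commented
# out as a usage sketch rather than runnable glue code.
#
# net = build_model(cfg.model).to(to_use_device)               # hypothetical builder
# optimizer = torch.optim.Adam(net.parameters(), lr=cfg['optimizer']['lr'])
# loss_func = build_loss(cfg.loss)                             # hypothetical builder
# post_process = build_post_process(cfg.post_process)          # hypothetical builder
# train(net, optimizer, loss_func, train_loader, eval_loader,
#       to_use_device, cfg, _epoch=0, logger=logger, post_process=post_process)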