Example #1
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, _epoch, logger):
    """
    Returns:
    :param net: 模型
    :param optimizer: 优化器
    :param scheduler: 学习率更新
    :param loss_func: loss函数
    :param train_loader: 训练数据集dataloader
    :param eval_loader: 验证数据集dataloader
    :param to_use_device:
    :param train_options:
    :param _epoch:
    :param logger:
    """
    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    # load the recognition alphabet from file and build the CTC label converter
    with open(cfg.dataset.alphabet, 'r', encoding='utf-8') as file:
        cfg.dataset.alphabet = ''.join([s.strip('\n') for s in file.readlines()])
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    logger.info('Training...')
    # settings used when printing loss information
    all_step = len(train_loader)
    logger.info(f'train dataset has {len(train_loader.dataset)} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {len(eval_loader.dataset)} samples, {len(eval_loader)} batches in dataloader')
    best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
    # start training
    try:
        start = time.time()
        for epoch in range(_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                # encode the text labels into CTC targets and target lengths
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths

                # zero gradients, run the forward pass, backpropagate and update the weights
                optimizer.zero_grad()
                output = net(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # compute batch accuracy statistics for logging
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] -"
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    # run validation
                    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['model_path'] = net_save_path
                            shutil.copy(net_save_path, net_save_path.replace('latest', 'best'))
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        net_save_path = f"{train_options['checkpoint_save_dir']}/{epoch}.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
            scheduler.step()
    except KeyboardInterrupt:
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final_' + str(epoch) + '.pth'), net,
                        optimizer, epoch, logger, cfg)
    except Exception:
        error_msg = traceback.format_exc(limit=1)
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
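
The loop in Example #1 only reads a handful of keys from cfg.train_options. The sketch below collects those keys in one place; only the key names come from the example, every value and the checkpoint directory are illustrative placeholders.

# Minimal sketch of the train_options block consumed by the loop above.
# Only the key names are taken from the example; the values are placeholders.
train_options = {
    'epochs': 100,                      # total number of training epochs
    'print_interval': 50,               # log loss/acc every N batches
    'val_interval': 500,                # run evaluate() every N batches
    'ckpt_save_type': 'HighestAcc',     # or 'FixedEpochStep'
    'ckpt_save_epoch': 5,               # only used with 'FixedEpochStep'
    'checkpoint_save_dir': './output',  # where latest.pth / best.pth are written
}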
Example #2
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """
    训练函数

    :param net: 网络
    :param optimizer: 优化器
    :param scheduler: 学习率更新器
    :param loss_func: loss函数
    :param train_loader: 训练数据集 dataloader
    :param eval_loader: 验证数据集 dataloader
    :param to_use_device: device
    :param cfg: 当前训练所使用的配置
    :param global_state: 训练过程中的一些全局状态,如cur_epoch,cur_iter,最优模型的相关信息
    :param logger: logger 对象
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    logger.info('Training...')
    # settings used when printing loss information
    all_step = len(train_loader)
    logger.info(f'train dataset has {len(train_loader.dataset)} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {len(eval_loader.dataset)} samples, {len(eval_loader)} batches in dataloader')
    # resume state from a previous run if provided, otherwise start from scratch
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # zero gradients, run the forward pass, backpropagate and update the weights
                optimizer.zero_grad()
                output = net(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # compute batch accuracy statistics for logging
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    # snapshot the current training state and save the latest checkpoint
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # run validation and keep the best checkpoint
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['model_path'] = net_save_path

                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg, global_state=global_state)
    except Exception:
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
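
Example #2 resumes from a global_state dict instead of a bare epoch number. The sketch below shows one way a caller might rebuild that dict from a previously saved checkpoint; it assumes save_checkpoint stores the dict under a 'global_state' key, and the file path is a placeholder, so treat load_global_state as a hypothetical helper rather than part of the torchocr API.

import torch

def load_global_state(checkpoint_path):
    # Hypothetical resume helper: returns the saved global_state if the
    # checkpoint exists, or an empty dict so train() starts from scratch.
    try:
        ckpt = torch.load(checkpoint_path, map_location='cpu')
    except FileNotFoundError:
        return {}
    # keys mirror what train() reads: start_epoch, global_step, best_model
    return ckpt.get('global_state', {})

# Usage (the path is a placeholder):
# global_state = load_global_state('./output/latest.pth')
# train(net, optimizer, scheduler, loss_func, train_loader, eval_loader,
#       to_use_device, cfg, global_state, logger)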