Example #1
0
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """
    Training loop for the CTC-based text-recognition model.

    :param net: network to train
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler, stepped once per epoch
    :param loss_func: loss function; must return a dict with a 'loss' entry
    :param train_loader: training dataset dataloader
    :param eval_loader: evaluation dataset dataloader
    :param to_use_device: device to run on
    :param cfg: configuration used for this training run
    :param global_state: global training state (start_epoch, global_step,
        best-model info) used to resume from a checkpoint; empty for a fresh run
    :param logger: logger object
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters controlling how often loss is printed
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        # resume from a previous run
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # zero gradients and back-propagate
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                # NOTE(review): assumes acc_dict['norm_edit_dis'] is an edit distance
                # accumulated over the batch — confirm against RecMetric
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    # periodically snapshot the latest model so training can resume
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # validate and keep the best-accuracy checkpoint
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            # NOTE(review): records the latest.pth path, not best.pth — confirm intent
                            best_model['models'] = net_save_path

                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        # save a final checkpoint when the user interrupts training
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg, global_state=global_state)
    except Exception:
        # log the full traceback instead of silently swallowing unexpected errors;
        # narrowed from a bare except so SystemExit/GeneratorExit still propagate
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
Example #2
0
def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger, post_process):
    """
    Training loop for the text-detection model.

    :param net: network to train
    :param optimizer: optimizer (learning rate is adjusted manually per step)
    :param loss_func: loss function returning (loss_c, loss_s, loss)
    :param train_loader: training dataset dataloader
    :param eval_loader: evaluation dataset dataloader
    :param to_use_device: device to run on
    :param cfg: configuration used for this training run
    :param global_state: global training state (start_epoch, global_step,
        best-model info) used to resume from a checkpoint; empty for a fresh run
    :param logger: logger object
    :param post_process: post-processing object used during evaluation
    :return: None
    """

    train_options = cfg.train_options
    metric = DetMetric()
    # ===>
    logger.info('Training...')
    # ===> parameters controlling how often loss is printed
    all_step = len(train_loader)
    logger.info(
        f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader'
    )
    logger.info(
        f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader'
    )
    if len(global_state) > 0:
        # resume from a previous run
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {
            'recall': 0,
            'precision': 0,
            'hmean': 0,
            'best_model_epoch': 0
        }
        start_epoch = 0
        global_step = 0
    # start training
    base_lr = cfg['optimizer']['lr']
    all_iters = all_step * train_options['epochs']
    # warm the learning rate up over the first 3 epochs
    warmup_iters = 3 * all_step
    try:
        for epoch in range(start_epoch,
                           train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            train_loss = 0.
            start = time.time()
            for i, batch_data in enumerate(
                    train_loader):  # traverse each batch in the epoch
                current_lr = adjust_learning_rate(optimizer,
                                                  base_lr,
                                                  global_step,
                                                  all_iters,
                                                  0.9,
                                                  warmup_iters=warmup_iters)
                # zero gradients and back-propagate
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                labels, training_mask = batch_data['score_maps'].to(
                    to_use_device), batch_data['training_mask'].to(
                        to_use_device)
                loss_c, loss_s, loss = loss_func(output, labels, training_mask)
                loss.backward()
                optimizer.step()
                # statistic loss for print
                train_loss += loss.item()
                loss_str = 'loss: {:.4f} - '.format(loss.item())

                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"{loss_str} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                global_step += 1
            logger.info(f'train_loss: {train_loss / len(train_loader)}')
            if (epoch + 1) % train_options['val_interval'] == 0:
                # periodically snapshot the latest model so training can resume
                global_state['start_epoch'] = epoch
                global_state['best_model'] = best_model
                global_state['global_step'] = global_step
                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                save_checkpoint(net_save_path,
                                net,
                                optimizer,
                                logger,
                                cfg,
                                global_state=global_state)
                if train_options['ckpt_save_type'] == 'HighestAcc':
                    # validate and keep the best-hmean checkpoint
                    eval_dict = evaluate(net, eval_loader, to_use_device,
                                         logger, post_process, metric)
                    if eval_dict['hmean'] > best_model['hmean']:
                        best_model.update(eval_dict)
                        best_model['best_model_epoch'] = epoch
                        # NOTE(review): records the latest.pth path, not best.pth — confirm intent
                        best_model['models'] = net_save_path

                        global_state['start_epoch'] = epoch
                        global_state['best_model'] = best_model
                        global_state['global_step'] = global_step
                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                        save_checkpoint(net_save_path,
                                        net,
                                        optimizer,
                                        logger,
                                        cfg,
                                        global_state=global_state)
                elif train_options[
                        'ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options[
                            'ckpt_save_epoch'] == 0:
                    shutil.copy(
                        net_save_path,
                        net_save_path.replace('latest.pth', f'{epoch}.pth'))
                best_str = 'current best, '
                for k, v in best_model.items():
                    best_str += '{}: {}, '.format(k, v)
                logger.info(best_str)
    except KeyboardInterrupt:
        # save a final checkpoint when the user interrupts training
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'],
                                     'final.pth'),
                        net,
                        optimizer,
                        logger,
                        cfg,
                        global_state=global_state)
    except Exception:
        # log the full traceback instead of silently swallowing unexpected errors;
        # narrowed from a bare except so SystemExit/GeneratorExit still propagate
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
Example #3
0
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # val
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['models'] = net_save_path

                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
Example #4
0
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, _epoch, logger):
    """
    Training loop for the CTC-based text-recognition model.

    :param net: network to train
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler, stepped once per epoch
    :param loss_func: loss function; must return a dict with a 'loss' entry
    :param train_loader: training dataset dataloader
    :param eval_loader: evaluation dataset dataloader
    :param to_use_device: device to run on
    :param cfg: configuration used for this training run
    :param _epoch: epoch index to start training from
    :param logger: logger object
    :return: None
    """
    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    # replace the alphabet file path in the config with its contents
    with open(cfg.dataset.alphabet, 'r', encoding='utf-8') as file:
        cfg.dataset.alphabet = ''.join([s.strip('\n') for s in file.readlines()])
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters controlling how often loss is printed
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
    # start training
    try:
        start = time.time()
        for epoch in range(_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths

                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                # NOTE(review): assumes acc_dict['norm_edit_dis'] is an edit distance
                # accumulated over the batch — confirm against RecMetric
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] -"
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    # validate and save according to the configured checkpoint policy
                    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['models'] = net_save_path
                            shutil.copy(net_save_path, net_save_path.replace('latest', 'best'))
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        net_save_path = f"{train_options['checkpoint_save_dir']}/{epoch}.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
            scheduler.step()
    except KeyboardInterrupt:
        # save a final checkpoint when the user interrupts training
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final_' + str(epoch) + '.pth'), net,
                        optimizer, epoch, logger, cfg)
    except Exception:
        # log the (truncated) traceback instead of silently swallowing unexpected
        # errors; narrowed from a bare except so SystemExit still propagates
        error_msg = traceback.format_exc(limit=1)
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
Example #5
0
def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, _epoch, logger, post_process):
    """
    Training loop for the text-detection model.

    :param net: network to train
    :param optimizer: optimizer (learning rate is decayed manually per epoch)
    :param loss_func: loss function; must return a dict with a 'loss' entry
    :param train_loader: training dataset dataloader
    :param eval_loader: evaluation dataset dataloader
    :param to_use_device: device to run on
    :param cfg: configuration used for this training run
    :param _epoch: epoch index to start training from
    :param logger: logger object
    :param post_process: post-processing object used during evaluation
    :return: None
    """

    train_options = cfg.train_options
    metric = DetMetric()
    # ===>
    logger.info('Training...')
    # ===> parameters controlling how often loss is printed
    all_step = len(train_loader)
    logger.info(
        f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader'
    )
    logger.info(
        f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader'
    )
    best_model = {
        'recall': 0,
        'precision': 0,
        'hmean': 0,
        'best_model_epoch': 0
    }
    # start training
    base_lr = optimizer.param_groups[0]['lr']
    try:
        for epoch in range(_epoch,
                           train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            train_loss = 0.
            start = time.time()
            epoch_start_time = time.time()
            for i, batch_data in enumerate(
                    train_loader):  # traverse each batch in the epoch
                current_lr = adjust_learning_rate(optimizer, base_lr, epoch,
                                                  train_options['epochs'], 0.9)
                # move tensor entries of the batch to the target device
                for key, value in batch_data.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch_data[key] = value.to(to_use_device)
                # zero gradients and back-propagate
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                optimizer.step()
                # statistic loss for print
                train_loss += loss_dict['loss'].item()
                # pop the total so the remaining entries are the per-component losses
                loss_str = 'loss: {:.4f} - '.format(
                    loss_dict.pop('loss').item())
                for idx, (key, value) in enumerate(loss_dict.items()):
                    loss_dict[key] = value.item()
                    loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                    if idx < len(loss_dict) - 1:
                        loss_str += ' - '
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr:.7f} - "
                                f"{loss_str} - "
                                f"time:{interval_batch_time:.4f}s")
                    start = time.time()
            epoch_time = time.time() - epoch_start_time
            logger.info(f'epoch_time: {epoch_time:.4f}s')
            avg_loss = train_loss / len(train_loader)
            logger.info(f'train_loss: {avg_loss:.4f}')
            # periodic checkpoint by epoch count
            if train_options['ckpt_save_type'] == 'EpochStep' and (
                    epoch + 1) % train_options['ckpt_save_epoch'] == 0:
                net_save_path = f"{train_options['checkpoint_save_dir']}/{epoch}.pth"
                save_checkpoint(net_save_path, net, optimizer, epoch, logger,
                                cfg)

            # save the best model by hmean
            if (epoch + 1) % train_options['val_interval'] == 0:
                eval_dict = evaluate(net, eval_loader, to_use_device, logger,
                                     post_process, metric)
                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                save_checkpoint(net_save_path, net, optimizer, epoch, logger,
                                cfg)
                if eval_dict['hmean'] > best_model['hmean']:
                    best_model.update(eval_dict)
                    best_model['models'] = net_save_path
                    shutil.copy(net_save_path,
                                net_save_path.replace('latest', 'best'))

                best_str = 'current best, '
                for k, v in best_model.items():
                    best_str += '{}: {}, '.format(k, v)
                logger.info(best_str)
    except KeyboardInterrupt:
        # save a final checkpoint when the user interrupts training
        import os
        save_checkpoint(
            os.path.join(train_options['checkpoint_save_dir'], 'final.pth'),
            net, optimizer, epoch, logger, cfg)
    except Exception:
        # log the full traceback; narrowed from BaseException so SystemExit
        # and GeneratorExit are not swallowed (KeyboardInterrupt is handled above)
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')