Example #1
def train(train_loader, model, optimizer, lr_scheduler, tb_writer):
    cur_lr = lr_scheduler.get_cur_lr()
    rank = get_rank()

    average_meter = AverageMeter()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    world_size = get_world_size()
    num_per_epoch = len(train_loader.dataset) // \
        cfg.TRAIN.EPOCH // (cfg.TRAIN.BATCH_SIZE * world_size)
    start_epoch = cfg.TRAIN.START_EPOCH
    epoch = start_epoch

    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and \
            get_rank() == 0:
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

    logger.info("model\n{}".format(describe(model.module)))
    end = time.time()
    for idx, data in enumerate(train_loader):
        if epoch != idx // num_per_epoch + start_epoch:
            epoch = idx // num_per_epoch + start_epoch

            if get_rank() == 0:
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    cfg.TRAIN.SNAPSHOT_DIR + '/checkpoint_e%d.pth' % (epoch))

            if epoch == cfg.TRAIN.EPOCH:
                return

            if cfg.BACKBONE.TRAIN_EPOCH == epoch:
                logger.info('start training backbone.')
                optimizer, lr_scheduler = build_opt_lr(model.module, epoch)
                logger.info("model\n{}".format(describe(model.module)))

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()
            logger.info('epoch: {}'.format(epoch + 1))

        tb_idx = idx
        if idx % num_per_epoch == 0 and idx != 0:
            # log the learning rate of each parameter group at the epoch boundary
            for pg_idx, pg in enumerate(optimizer.param_groups):
                logger.info('epoch {} lr {}'.format(epoch + 1, pg['lr']))
                if rank == 0:
                    tb_writer.add_scalar('lr/group{}'.format(pg_idx + 1),
                                         pg['lr'], tb_idx)

        data_time = average_reduce(time.time() - end)
        if rank == 0:
            tb_writer.add_scalar('time/data', data_time, tb_idx)

        outputs = model(data)
        loss = outputs['total_loss']

        if is_valid_number(loss.data.item()):
            optimizer.zero_grad()
            loss.backward()
            reduce_gradients(model)

            if rank == 0 and cfg.TRAIN.LOG_GRADS:
                log_grads(model.module, tb_writer, tb_idx)

            # clip gradient
            clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
            optimizer.step()

        batch_time = time.time() - end
        batch_info = {}
        batch_info['batch_time'] = average_reduce(batch_time)
        batch_info['data_time'] = average_reduce(data_time)
        for k, v in sorted(outputs.items()):
            batch_info[k] = average_reduce(v.data.item())

        average_meter.update(**batch_info)

        if rank == 0:
            for k, v in batch_info.items():
                tb_writer.add_scalar(k, v, tb_idx)

            if (idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
                info = "Epoch: [{}][{}/{}] lr: {:.6f}\n".format(
                    epoch + 1, (idx + 1) % num_per_epoch, num_per_epoch,
                    cur_lr)
                for cc, (k, v) in enumerate(batch_info.items()):
                    if cc % 2 == 0:
                        info += ("\t{:s}\t").format(getattr(average_meter, k))
                    else:
                        info += ("{:s}\n").format(getattr(average_meter, k))
                logger.info(info)
                print_speed(idx + 1 + start_epoch * num_per_epoch,
                            average_meter.batch_time.avg,
                            cfg.TRAIN.EPOCH * num_per_epoch)
        end = time.time()
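Example #1 converts a flat dataloader index into an epoch number: num_per_epoch is the number of iterations per epoch (the dataset length covers the whole run, hence the extra division by cfg.TRAIN.EPOCH), and idx // num_per_epoch + start_epoch changes exactly once per epoch, which is when the checkpoint is written. Below is a minimal sketch of that bookkeeping; the numbers are hypothetical stand-ins for cfg.TRAIN.* and the dataset size.

# Minimal sketch of the epoch bookkeeping used above.
dataset_len = 600000           # stand-in for len(train_loader.dataset)
total_epochs = 20              # stand-in for cfg.TRAIN.EPOCH
batch_size, world_size = 28, 8
start_epoch = 0                # stand-in for cfg.TRAIN.START_EPOCH

# The dataset is sized for the whole run, so iterations per epoch are:
num_per_epoch = dataset_len // total_epochs // (batch_size * world_size)

epoch = start_epoch
for idx in range(num_per_epoch * total_epochs):
    if epoch != idx // num_per_epoch + start_epoch:
        # first iteration of a new epoch: this is where the checkpoint is saved
        epoch = idx // num_per_epoch + start_epoch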
Example #2
def train(train_loader, model, optimizer, lr_scheduler, tb_writer):
    '''
    :param train_loader: training data loader
    :param model: model being trained (wrapped, so weights live in model.module)
    :param optimizer: optimizer over the trainable parameters
    :param lr_scheduler: learning-rate scheduler
    :param tb_writer: TensorBoard SummaryWriter
    :return:
    '''
    cur_lr = lr_scheduler.get_cur_lr()  # get the current learning rate
    rank = get_rank()

    average_meter = AverageMeter()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    world_size = get_world_size()
    num_per_epoch = len(train_loader.dataset) // cfg.TRAIN.EPOCH // (
        cfg.TRAIN.BATCH_SIZE * world_size)
    start_epoch = cfg.TRAIN.START_EPOCH
    epoch = start_epoch

    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and get_rank() == 0:
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

    logger.info("model\n{}".format(describe(model.module)))  #打印模型
    end = time.time()
    for idx, data in enumerate(train_loader):

        if epoch != idx // num_per_epoch + start_epoch:  # save a checkpoint at each epoch boundary
            epoch = idx // num_per_epoch + start_epoch

            if get_rank() == 0:
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    cfg.TRAIN.SNAPSHOT_DIR + '/checkpoint_e%d.pth' % (epoch))

            if epoch == cfg.TRAIN.EPOCH:
                return

            # Once the configured epoch is reached (e.g. the 10th), start fine-tuning the
            # last 3 backbone layers: rebuild the optimizer/scheduler to reset which
            # parameters are trainable, which stay frozen, and the learning-rate factors.
            if cfg.BACKBONE.TRAIN_EPOCH == epoch:
                logger.info('start training backbone.')
                optimizer, lr_scheduler = build_opt_lr(model.module, epoch)
                logger.info("model\n{}".format(describe(model.module)))

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()
            logger.info('epoch: {}'.format(epoch + 1))

        tb_idx = idx  # TensorBoard step index
        if idx % num_per_epoch == 0 and idx != 0:
            # log each parameter group's learning rate to TensorBoard
            for pg_idx, pg in enumerate(optimizer.param_groups):
                logger.info('epoch {} lr {}'.format(epoch + 1, pg['lr']))
                if rank == 0:
                    tb_writer.add_scalar('lr/group{}'.format(pg_idx + 1),
                                         pg['lr'], tb_idx)

        data_time = average_reduce(time.time() - end)
        if rank == 0:
            tb_writer.add_scalar('time/data', data_time, tb_idx)

        # show_tensor(data, tb_idx, tb_writer)  # visualize only the input data in TensorBoard
        data[0]['iter'] = tb_idx  # attach the iteration index for monitoring
        outputs = model(data)
        # loss = outputs['feat_loss']
        loss = outputs['total_loss']
        show_tensor(data, tb_idx, tb_writer,
                    outputs)  # visualize both inputs and outputs in TensorBoard

        if is_valid_number(
                loss.data.item()):  # skip invalid losses (nan, inf, or > 1e4)
            optimizer.zero_grad()
            loss.backward()
            reduce_gradients(model)  # reduce (average) gradients across processes

            if rank == 0 and cfg.TRAIN.LOG_GRADS:  # log gradient statistics
                log_grads(model.module, tb_writer, tb_idx)

            # clip gradient
            clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
            optimizer.step()

        batch_time = time.time() - end
        batch_info = {}
        batch_info['batch_time'] = average_reduce(batch_time)
        batch_info['data_time'] = average_reduce(data_time)
        for k, v in sorted(outputs.items()):
            # skip tensor outputs that are only used for visualization
            if k in ('zf', 'zf_gt', 'zfs', 'box_img'):
                continue
            batch_info[k] = average_reduce(v.data.item())

        average_meter.update(**batch_info)

        if rank == 0:
            for k, v in batch_info.items():
                tb_writer.add_scalar(k, v, tb_idx)

            if (idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
                info = "Epoch: [{}][{}/{}] lr: {:.6f}\n".format(
                    epoch + 1, (idx + 1) % num_per_epoch, num_per_epoch,
                    cur_lr)
                for cc, (k, v) in enumerate(batch_info.items()):
                    if cc % 2 == 0:
                        info += ("\t{:s}\t").format(getattr(average_meter, k))
                    else:
                        info += ("{:s}\n").format(getattr(average_meter, k))
                logger.info(info)
                print_speed(idx + 1 + start_epoch * num_per_epoch,
                            average_meter.batch_time.avg,
                            cfg.TRAIN.EPOCH * num_per_epoch)
        end = time.time()
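Both train() examples rely on an AverageMeter helper that is not part of the excerpt: average_meter.update(**batch_info) must keep a running average per key, each key must be readable as an attribute (getattr(average_meter, k), average_meter.batch_time.avg), and each per-key meter must accept a '{:s}' format spec. The sketch below is a hypothetical stand-in with that behavior, not the project's actual implementation.

# Hypothetical stand-in for the AverageMeter used above.
class Meter(object):
    def __init__(self, name):
        self.name, self.val, self.sum, self.count, self.avg = name, 0.0, 0.0, 0, 0.0

    def update(self, val):
        # keep both the latest value and the running average
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

    def __format__(self, fmt):
        return '{}: {:.6f} ({:.6f})'.format(self.name, self.val, self.avg)


class AverageMeter(object):
    def update(self, **kwargs):
        # one Meter per keyword, exposed as an attribute of the same name
        for name, value in kwargs.items():
            if not hasattr(self, name):
                setattr(self, name, Meter(name))
            getattr(self, name).update(value)


meter = AverageMeter()
meter.update(batch_time=0.42, total_loss=1.3)
print('{:s}'.format(meter.batch_time))   # batch_time: 0.420000 (0.420000)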
Example #3
    def train(self):

        rank = get_rank()

        def is_valid_number(x):
            return not (math.isnan(x) or math.isinf(x) or x > 1e4)

        # world_size = get_world_size()

        if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and \
                get_rank() == 0:
            os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

        logger.info("model\n{}".format(describe(self._model.module)))

        eval_result = False
        epoch = cfg.TRAIN.START_EPOCH + 1
        while not eval_result and epoch <= cfg.TSA.MAX_ITERATIONS:

            for idx, data in enumerate(self._train_loader):

                outputs = self._model(data)
                loss = outputs['total_loss']

                start_time = time.time()
                tb_idx = idx
                if is_valid_number(loss.data.item()):
                    self._optimizer.zero_grad()
                    loss.backward()
                    reduce_gradients(self._model)

                    if rank == 0 and cfg.TRAIN.LOG_GRADS:
                        log_grads(self._model.module, self._tb_writer, tb_idx)

                    # clip gradient
                    clip_grad_norm_(self._model.parameters(), cfg.TRAIN.GRAD_CLIP)
                    self._optimizer.step()

                print("Step: %d  first_loss: %f Total_loss: %f  Speed:  %.0f examples per second" %
                      (epoch, outputs['first_loss'], loss.data.item(), cfg.TSA.BATCH_SIZE * cfg.TSA.TIME_STEP / (time.time() - start_time)))

                if epoch % cfg.TSA.MODEL_SAVE_STEP == 0 or epoch == cfg.TSA.MAX_ITERATIONS or epoch % cfg.TSA.VALIDATE_STEP == 0:
                    if get_rank() == 0:
                        torch.save(
                            {'epoch': epoch,
                             'state_dict': self._model.module.state_dict(),
                             'optimizer': self._optimizer.state_dict()},
                            cfg.TRAIN.SNAPSHOT_DIR + '/checkpoint_e%d.pth' % (epoch))
                    print('Save to checkpoint at step %d' % (epoch))

                if epoch % cfg.TSA.VALIDATE_STEP == 0:
                    if self.evaluate(epoch, 'loss'):
                        eval_result = True
                        break
                if epoch > cfg.TSA.MAX_ITERATIONS:
                    break
                # advance the iteration counter
                epoch += 1
                if epoch % cfg.TSA.LR_UPDATE == 0:
                    self._scheduler.step(epoch)
                    cur_lr = self._scheduler.get_lr()
                    logger.info('epoch: {} cur_lr: {}'.format(epoch, cur_lr))
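All three training loops call reduce_gradients(model) and average_reduce(value) around optimizer.step(), but those helpers are not shown. Assuming they follow the usual torch.distributed pattern, the sketch below shows what they are expected to do (average across workers, pass through in single-process runs); the repository's real helpers may differ.

# Hypothetical sketch of the distributed averaging helpers assumed above.
import torch
import torch.distributed as dist


def average_reduce(value):
    """Average a scalar across all workers; pass through when not distributed."""
    if not (dist.is_available() and dist.is_initialized()):
        return value
    tensor = torch.tensor(float(value))
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item() / dist.get_world_size()


def reduce_gradients(model):
    """Average every parameter gradient across workers before optimizer.step()."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
            p.grad.data /= world_size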
Example #4
def show_tensor(batch_data, global_iter, tb_writer, outputs):
    '''
    :param batch_data: batch of data fed to the network
    :param global_iter: global iteration counter used as the TensorBoard step
    :param tb_writer: TensorBoard SummaryWriter
    :param outputs: network outputs (feature maps and the predicted-box image)
    :return:
    '''

    rank = get_rank()
    # global_iter = 0  # TensorBoard iteration counter
    # if rank == 0:
    #     dataiter = iter(train_loader)
    #     data = next(dataiter)               # take a single batch from the iterator to build the graph
    #     # tb_writer.add_graph(model, data)

    max_batch = cfg.TRAIN.MaxShowBatch  # show at most 4 samples per batch in TensorBoard
    batch, _, _, _ = batch_data[0]["template"].shape
    batch = min(batch, max_batch)

    if rank == 0 and global_iter % cfg.TRAIN.ShowPeriod == 0:
        for i in range(cfg.GRU.SEQ_IN):
            # each data[i] contains 'template', 'search', 'label_cls',
            # 'label_loc', 'label_loc_weight', 'bbox', 'neg'
            xi = batch_data[i]

            tensor_t = draw_rect(xi["template"][0:batch],
                                 xi["t_bbox"][0:batch].view(batch, -1, 4))
            tensor_s = draw_rect(xi["search"][0:batch],
                                 xi["s_bbox"][0:batch].view(batch, -1, 4))
            # tile the (B, C, H, W) batch into a single grid image
            tb_xi_template = vutils.make_grid(
                tensor_t, normalize=True, scale_each=True)
            # t_bbox is expressed in template coordinates
            tb_writer.add_image('input/{}th_input_template'.format(i),
                                tb_xi_template, global_iter)
            tb_xi_search = vutils.make_grid(tensor_s,
                                            normalize=True,
                                            scale_each=True)
            # s_bbox is expressed in search-region coordinates
            tb_writer.add_image('input/{}th_input_search'.format(i),
                                tb_xi_search, global_iter)

        for i in range(cfg.GRU.SEQ_OUT):
            # each data[i] contains 'template', 'search', 'label_cls',
            # 'label_loc', 'label_loc_weight', 'bbox', 'neg'
            xi = batch_data[i + cfg.GRU.SEQ_IN]

            tensor_t = draw_rect(xi["template"][0:batch],
                                 xi["t_bbox"][0:batch].view(batch, -1, 4))
            tensor_s = draw_rect(xi["search"][0:batch],
                                 xi["s_bbox"][0:batch].view(batch, -1, 4))
            # tile the (B, C, H, W) batch into a single grid image
            tb_xi_template = vutils.make_grid(
                tensor_t, normalize=True, scale_each=True)
            # t_bbox is expressed in template coordinates
            tb_writer.add_image(
                'input/{}th_output_template'.format(i + cfg.GRU.SEQ_IN),
                tb_xi_template, global_iter)
            tb_xi_search = vutils.make_grid(tensor_s,
                                            normalize=True,
                                            scale_each=True)
            # s_bbox is expressed in search-region coordinates
            tb_writer.add_image(
                'input/{}th_output_search'.format(i + cfg.GRU.SEQ_IN),
                tb_xi_search, global_iter)

        if outputs['zf'] is not None:
            fb, fc, fh, fw = outputs['zf'].shape
            fc = (min(fc, 9) // 3) * 3  # visualize at most 9 channels
            for i in range(0, fc, 3):
                # tile the (B, 3, H, W) feature slice into a grid of images
                tb_feat = vutils.make_grid(outputs['zf'][0:batch, i:i + 3, ...],
                                           normalize=True,
                                           scale_each=True)
                tb_writer.add_image('feature/{}th_feat'.format(i), tb_feat,
                                    global_iter)

        if outputs['zfs'] is not None:
            _, ft, _, _, _ = outputs['zfs'].shape
            for t in range(ft):
                feat = outputs['zfs'][:, t, ...]
                fb, fc, fh, fw = feat.shape
                fc = (min(fc, 9) // 3) * 3  # visualize at most 9 channels
                for i in range(0, fc, 3):
                    # tile the (B, 3, H, W) feature slice into a grid of images
                    tb_feat = vutils.make_grid(feat[0:batch, i:i + 3, ...],
                                               normalize=True,
                                               scale_each=True)
                    tb_writer.add_image('feature/{}th_{}feat'.format(i, t),
                                        tb_feat, global_iter)

        if outputs['zf_gt'] is not None:
            fb, fc, fh, fw = outputs['zf_gt'].shape
            fc = (min(fc, 9) // 3) * 3  # visualize at most 9 channels
            for i in range(0, fc, 3):
                # tile the (B, 3, H, W) feature slice into a grid of images
                tb_feat_gt = vutils.make_grid(outputs['zf_gt'][0:batch, i:i + 3, ...],
                                              normalize=True,
                                              scale_each=True)
                tb_writer.add_image('feature/{}th_feat_gt'.format(i),
                                    tb_feat_gt, global_iter)

        if outputs['box_img'] is not None:
            # log the rendered prediction image directly
            tb_writer.add_image('predict/box_img', outputs['box_img'],
                                global_iter)
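show_tensor builds on two standard PyTorch utilities: torchvision.utils.make_grid, which tiles a (B, C, H, W) tensor into one image, and SummaryWriter.add_image, which logs it under a tag. A minimal standalone usage sketch with random data follows; the log directory, tag, and tensor shape are illustrative only.

# Standalone sketch of the make_grid + add_image pattern used in Example #4.
import torch
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/show_tensor_demo')    # hypothetical log directory

batch = torch.rand(4, 3, 127, 127)                 # stands in for a template batch
grid = vutils.make_grid(batch, normalize=True, scale_each=True)  # (3, H', W') tile
writer.add_image('input/0th_input_template', grid, global_step=0)
writer.close()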