Example #1
def train(data_dir, model_path=None, vis_port=None, init=None):
    # loading meta data
    # -----------------------------------------------------------------------------------------------------
    configs = './net/config.json'
    cfg = json.load(open(configs))

    # create dataset
    # -----------------------------------------------------------------------------------------------------
    trainloader, validloader = build_data_loader(cfg)
    anchors = None
    # create summary writer
    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    summary_writer = SummaryWriter(config.log_dir)
    if vis_port:
        vis = visual(port=vis_port)

    # start training
    # -----------------------------------------------------------------------------------------------------
    model = SiameseAlexNet()
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # load model weight
    # -----------------------------------------------------------------------------------------------------
    start_epoch = 1
    if model_path and init:
        print("init training with checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        if 'model' in checkpoint.keys():
            model.load_state_dict(checkpoint['model'])
        else:
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        torch.cuda.empty_cache()
        print("inited checkpoint")
    elif model_path and not init:
        print("loading checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
        print("loaded checkpoint")
    elif not model_path and config.pretrained_model:
        print("init with pretrained checkpoint %s" % config.pretrained_model +
              '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)

    # freeze layers
    def freeze_layers(model):
        print(
            '------------------------------------------------------------------------------------------------'
        )
        for layer in model.featureExtract[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError('error in fixing former 3 layers')
        print("fixed layers:")
        print(model.featureExtract[:10])

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    for epoch in range(start_epoch, config.epoch + 1):
        train_loss = []
        model.train()
        if config.fix_former_3_layers:
            if torch.cuda.device_count() > 1:
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        loss_temp_cls = 0
        loss_temp_reg = 0
        for i, data in enumerate(tqdm(trainloader)):
            exemplar_imgs, instance_imgs, regression_target, conf_target, delta_weight, gt = data
            # conf_target (8,1125) (8,225x5)
            regression_target, conf_target = regression_target.cuda(
            ), conf_target.cuda()

            pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                instance_imgs.cuda())

            pred_conf = pred_score.reshape(
                -1, 2, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            pred_offset = pred_regression.reshape(
                -1, 4, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                 conf_target,
                                                 config.num_pos,
                                                 config.num_neg,
                                                 anchors,
                                                 ohem_pos=config.ohem_pos,
                                                 ohem_neg=config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_offset,
                                    regression_target,
                                    conf_target,
                                    config.num_pos,
                                    ohem=config.ohem_reg)
            loss = cls_loss + config.lamb * reg_loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            optimizer.step()

            step = (epoch - 1) * len(trainloader) + i
            summary_writer.add_scalar('train/cls_loss', cls_loss.data, step)
            summary_writer.add_scalar('train/reg_loss', reg_loss.data, step)
            train_loss.append(loss.detach().cpu())
            loss_temp_cls += cls_loss.detach().cpu().numpy()
            loss_temp_reg += reg_loss.detach().cpu().numpy()
            # if vis_port:
            #     vis.plot_error({'rpn_cls_loss': cls_loss.detach().cpu().numpy().ravel()[0],
            #                     'rpn_regress_loss': reg_loss.detach().cpu().numpy().ravel()[0]}, win=0)
            if (i + 1) % config.show_interval == 0:
                tqdm.write(
                    "[epoch %2d][iter %4d] cls_loss: %.4f, reg_loss: %.4f lr: %.2e"
                    % (epoch, i, loss_temp_cls / config.show_interval,
                       loss_temp_reg / config.show_interval,
                       optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0

        train_loss = np.mean(train_loss)

        valid_loss = []
        model.eval()
        for i, data in enumerate(tqdm(validloader)):
            exemplar_imgs, instance_imgs, regression_target, conf_target, delta_weight, gt = data

            regression_target, conf_target = regression_target.cuda(
            ), conf_target.cuda()

            pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                instance_imgs.cuda())

            pred_conf = pred_score.reshape(
                -1, 2, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            pred_offset = pred_regression.reshape(
                -1, 4, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                 conf_target,
                                                 config.num_pos,
                                                 config.num_neg,
                                                 anchors,
                                                 ohem_pos=config.ohem_pos,
                                                 ohem_neg=config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_offset,
                                    regression_target,
                                    conf_target,
                                    config.num_pos,
                                    ohem=config.ohem_reg)
            loss = cls_loss + config.lamb * reg_loss
            valid_loss.append(loss.detach().cpu())
        valid_loss = np.mean(valid_loss)
        print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" %
              (epoch, valid_loss, train_loss))
        summary_writer.add_scalar('valid/loss', valid_loss,
                                  (epoch + 1) * len(trainloader))
        adjust_learning_rate(
            optimizer, config.gamma
        )  # adjust before save, and it will be epoch+1's lr when next load
        if epoch % config.save_interval == 0:
            if not os.path.exists('./data/models/'):
                os.makedirs("./data/models/")
            save_name = "./data/models/siamrpn_{}.pth".format(epoch)
            new_state_dict = model.state_dict()
            if torch.cuda.device_count() > 1:
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    namekey = k[7:]  # remove `module.`
                    new_state_dict[namekey] = v
            torch.save(
                {
                    'epoch': epoch,
                    'model': new_state_dict,
                    'optimizer': optimizer.state_dict(),
                }, save_name)
            print('save model: {}'.format(save_name))
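The reshape/permute pattern applied to pred_score and pred_regression above flattens the RPN head outputs into one row per anchor. A minimal, self-contained sketch of that shape manipulation (the anchor_num=5 and score_size=19 values are assumed for illustration; the real ones come from the config):

import torch

# assumed values for illustration only; the examples read them from their config
N, anchor_num, score_size = 2, 5, 19

# raw RPN head outputs: 2 logits (bg/fg) and 4 box offsets per anchor per location
pred_score = torch.randn(N, 2 * anchor_num, score_size, score_size)
pred_regression = torch.randn(N, 4 * anchor_num, score_size, score_size)

# reshape to (N, 2, A*S*S), then permute so each of the A*S*S anchors is one row
pred_conf = pred_score.reshape(
    -1, 2, anchor_num * score_size * score_size).permute(0, 2, 1)
pred_offset = pred_regression.reshape(
    -1, 4, anchor_num * score_size * score_size).permute(0, 2, 1)

print(pred_conf.shape)    # torch.Size([2, 1805, 2]) -> per-anchor class scores
print(pred_offset.shape)  # torch.Size([2, 1805, 4]) -> per-anchor box offsets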
Example #2
def train(data_dir, model_path=None, vis_port=None, init=None):
    # load all (pre-processed) video sequences
    meta_data_path = os.path.join(data_dir, "meta_data.pkl")
    meta_data = pickle.load(open(meta_data_path, 'rb'))
    all_sequences = [x[0] for x in meta_data]
    # split into training and validation sets
    train_sequences, valid_sequences = train_test_split(all_sequences,
                                                        test_size=1 - Config.train_ratio, random_state=Config.seed)
    # define transforms
    train_z_transforms = transforms.Compose([
        ToTensor()
    ])
    train_x_transforms = transforms.Compose([
        ToTensor()
    ])
    valid_z_transforms = transforms.Compose([
        ToTensor()
    ])
    valid_x_transforms = transforms.Compose([
        ToTensor()
    ])

    # get train dataset
    train_dataset = GetDataSet(train_sequences, data_dir, train_z_transforms, train_x_transforms, meta_data,
                               training=True)
    anchors = train_dataset.anchors
    # get valid dataset
    valid_dataset = GetDataSet(valid_sequences, data_dir, valid_z_transforms, valid_x_transforms, meta_data,
                               training=False)
    # create the dataloader iterators
    train_batch_size = Config.stmm_train_batch_size if Config.update_template else Config.train_batch_size
    valid_batch_size = Config.stmm_valid_batch_size if Config.update_template else Config.valid_batch_size
    trainloader = DataLoader(train_dataset, batch_size=train_batch_size * t.cuda.device_count(),
                             shuffle=True, pin_memory=True,
                             num_workers=Config.train_num_workers * t.cuda.device_count(),
                             drop_last=True)
    validloader = DataLoader(valid_dataset, batch_size=valid_batch_size * t.cuda.device_count(),
                             shuffle=False, pin_memory=True,
                             num_workers=Config.valid_num_workers * t.cuda.device_count(), drop_last=True)
    # create summary writer
    if not os.path.exists(Config.log_dir):
        os.mkdir(Config.log_dir)
    summary_writer = SummaryWriter(Config.log_dir)
    # visualization
    if vis_port:
        vis = visual()
    # start training
    model = SiameseAlexNet()
    model = model.cuda()
    optimizer = t.optim.SGD(model.parameters(), lr=Config.lr, momentum=Config.momentum,
                            weight_decay=Config.weight_dacay)
    start_epoch = 1
    # load model weight
    if model_path and init:  # initialize training from an existing checkpoint
        print("init training with checkpoint %s" % model_path + '\n')
        print('--------------------------------------------------------------------------------- \n')
        # the checkpoint holds the whole training state: network weights, optimizer state, etc.
        checkpoint = t.load(model_path)
        if 'model' in checkpoint.keys():
            # load the network weights stored under the 'model' key
            model.load_state_dict(checkpoint['model'])
        # otherwise, merge the raw state dict into the model's own
        else:
            model_dict = model.state_dict()  # state_dict() returns a dict with the whole network's state
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        # GPU memory is only released (as seen in nvidia-smi) after this call
        t.cuda.empty_cache()
        print("finish initing checkpoint! \n")
    elif model_path and not init:  # resume from a previously saved checkpoint without re-initializing
        print("loading the previous checkpoint %s" % model_path + '\n')
        print('-------------------------------------------------------------------------------- \n')
        checkpoint = t.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        t.cuda.empty_cache()
        print("finish loading previous checkpoint! \n")
    elif not model_path and Config.pretrained_model:  # load a pre-trained backbone
        print("load pre-trained checkpoint %s" % Config.pretrained_model + '\n')
        print('-------------------------------------------------------------------------------- \n')
        checkpoint = t.load(Config.pretrained_model)
        checkpoint = {k.replace('features.features', 'sharedFeatExtra'): v for k, v in checkpoint.items()}
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)
        print("finish loading pre-trained model \n")

    # the parameters of the first three conv blocks are kept fixed during training
    def freeze_layers(model):
        for layer in model.sharedFeatExtra[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()  # parameters are frozen, so this BN layer runs in eval mode
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError("something wrong in fixing 3 layers \n")
            print("fixed layers:  \n", layer)

    if t.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    for epoch in range(start_epoch, Config.epoch + 1):
        print("staring epoch{} \n".format(epoch))
        train_loss = []
        model.train()  # switch to training mode
        if Config.fix_former_3_layers:
            if t.cuda.device_count() > 1:  # when wrapped in DataParallel, freeze layers on the underlying module
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        # running-loss accumulators for printing in the terminal during training
        loss_temp_cls = 0
        loss_temp_reg = 0
        loss_temp_template = 0
        for i, data in enumerate(tqdm(trainloader)):
            # each iteration loads one mini-batch of samples
            # exemplar_imgs size:[32,127,127,3]  regression_target size:[32,1805,4]
            exemplar_imgs, instance_imgs, regression_target, cls_label_map, instance_his_imgs = data
            if Config.update_template:  # the collected historical search frames are already cropped to the template size
                instance_his_imgs = [x.numpy() for x in instance_his_imgs]
                instance_his_imgs = np.stack(instance_his_imgs).transpose(1, 0, 2, 3, 4)
                instance_his_imgs = instance_his_imgs.reshape(-1, Config.exemplar_size, Config.exemplar_size, 3)
                # exemplar_imgs = np.concatenate(exemplar_imgs, axis=0)  # merge the first two dims, since the network expects 4-D input
                instance_his_imgs = t.from_numpy(instance_his_imgs)
                pred_cls_score, pred_regression, template_loss = model(exemplar_imgs.cuda(),
                                                        instance_imgs.cuda(),
                                                        instance_his_imgs,
                                                        training=True)
            else:
                pred_cls_score, pred_regression = model(exemplar_imgs.cuda(),
                                                        instance_imgs.cuda(),
                                                        instance_his_imgs,
                                                        training=True)
            regression_target, cls_label_map = regression_target.cuda(), cls_label_map.cuda()


            pred_cls_score = pred_cls_score.reshape(-1, 2,
                                                    Config.anchor_num *
                                                    Config.train_map_size *
                                                    Config.train_map_size).permute(0, 2, 1)

            pred_regression = pred_regression.reshape(-1, 4,
                                                      Config.anchor_num * Config.train_map_size *
                                                      Config.train_map_size).permute(0, 2, 1)

            cls_loss = rpn_cross_entropy_banlance(pred_cls_score, cls_label_map, Config.num_pos,
                                                  Config.num_neg, anchors,
                                                  ohem_pos=Config.ohem_pos, ohem_neg=Config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_regression, regression_target, cls_label_map,
                                    Config.num_pos, ohem=Config.ohem_reg)
            # add the template loss to the total loss
            if Config.update_template:
                loss = cls_loss + Config.lamb * reg_loss + template_loss
            else:
                loss = cls_loss + Config.lamb * reg_loss
            # zero the gradients
            optimizer.zero_grad()
            # back-propagate to compute gradients
            loss.backward()
            t.nn.utils.clip_grad_norm_(model.parameters(), Config.clip)
            # update the parameters
            optimizer.step()
            step = (epoch - 1) * len(trainloader) + i
            # summary_writer.add_scalar('train/cls_loss', cls_loss.data, step)
            # summary_writer.add_scalar('train/reg_loss', reg_loss.data, step)
            if Config.update_template:
                summary_writer.add_scalars('train',
                                           {'cls_loss': cls_loss.data.item(), 'reg_loss': reg_loss.data.item(),
                                            'template_loss': template_loss.data.item(),
                                            'total_loss': loss.data.item()},
                                           step)
            else:
                summary_writer.add_scalars('train',
                                           {'cls_loss': cls_loss.data.item(), 'reg_loss': reg_loss.data.item(),
                                            'total_loss': loss.data.item()},
                                           step)
            # accumulate the total loss
            train_loss.append(loss.detach().cpu())
            loss_temp_cls += cls_loss.detach().cpu().numpy()
            loss_temp_reg += reg_loss.detach().cpu().numpy()
            if Config.update_template:
                loss_temp_template += template_loss.detach().cpu().numpy()
            if (i + 1) % Config.show_interval == 0:
                tqdm.write("[epoch %2d][iter %4d] cls_loss: %.4f, reg_loss: %.4f, temp_loss: %.4f, lr: %.2e"
                           % (epoch, i, loss_temp_cls / Config.show_interval,
                              loss_temp_reg / Config.show_interval, loss_temp_template / Config.show_interval,
                              optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0
                loss_temp_template = 0

                # visualization
                if vis_port:
                    anchors_show = train_dataset.anchors
                    exem_img = exemplar_imgs[0].cpu().detach().numpy()
                    inst_img = instance_imgs[0].cpu().detach().numpy()
                    # choose odd layer and show the heatmap
                    # cls_response = cls_map_vis.squeeze()[0:10, :, :]
                    # cls_res_show = []
                    # for x in range(10):
                    #     if x % 2 == 1:
                    #         res = cls_response[x:x + 1, :, :].squeeze().cpu().detach().numpy()
                    #         cls_res_show.append(res)
                    # count = 20
                    # for heatmap in cls_res_show:
                    #     vis.plot_heatmap(heatmap, win=count)
                    #     count += count
                    topk = Config.show_topK
                    vis.plot_img(exem_img.transpose(2, 0, 1), win=1, name='exemplar_img')
                    cls_pred = cls_label_map[0]  # TODO: double-check what cls_pred actually holds here
                    gt_box = get_topK_box(cls_pred, regression_target[0], anchors_show)[0]
                    # show gt box
                    img_box = add_box_img(inst_img, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1), win=2, name='instance_img')
                    # show anchor with max score (without regression)
                    cls_pred = F.softmax(pred_cls_score, dim=2)[0, :, 1]  # index 1 in the last dim is the positive-class score
                    scores, index = t.topk(cls_pred, k=topk)
                    img_box = add_box_img(inst_img, anchors_show[index.cpu()])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1), win=3, name='max_score_anchors')

                    # max score anchor with regression
                    cls_pred = F.softmax(pred_cls_score, dim=2)[0, :, 1]
                    topk_box = get_topK_box(cls_pred, pred_regression[0], anchors_show, topk=topk)
                    img_box = add_box_img(inst_img, topk_box.squeeze())
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1), win=4, name='max_score_box')

                    # show anchor with max iou (without regression)
                    iou = compute_iou(anchors_show, gt_box).flatten()
                    index = np.argsort(iou)[-topk:]
                    img_box = add_box_img(inst_img, anchors_show[index])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1), win=5, name='max_iou_anchor')
                    # show regressed anchor with max iou
                    reg_offset = pred_regression[0].cpu().detach().numpy()
                    topk_offset = reg_offset[index, :]
                    anchors_det = anchors_show[index, :]
                    pred_box = box_transform_use_reg_offset(anchors_det, topk_offset)
                    img_box = add_box_img(inst_img, pred_box.squeeze())
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1), win=6, name='max_iou_box')

        train_loss = np.mean(train_loss)

        # finish training an epoch, starting validation
        valid_loss = []
        model.eval()
        for i, data in enumerate(tqdm(validloader)):
            exemplar_imgs, instance_imgs, regression_target, cls_label_map, instance_his_imgs = data
            if Config.update_template:  # the collected historical search frames are already cropped to the template size
                instance_his_imgs = [x.numpy() for x in instance_his_imgs]
                instance_his_imgs = np.stack(instance_his_imgs).transpose(1, 0, 2, 3, 4)
                instance_his_imgs = instance_his_imgs.reshape(-1, Config.exemplar_size, Config.exemplar_size, 3)
                # exemplar_imgs = np.concatenate(exemplar_imgs, axis=0)  # merge the first two dims, since the network expects 4-D input
                instance_his_imgs = t.from_numpy(instance_his_imgs)
                pred_cls_score, pred_regression, template_loss = model(exemplar_imgs.cuda(),
                                                                       instance_imgs.cuda(),
                                                                       instance_his_imgs,
                                                                       training=False)
            else:
                pred_cls_score, pred_regression = model(exemplar_imgs.cuda(),
                                                        instance_imgs.cuda(),
                                                        instance_his_imgs)
            regression_target, cls_label_map = regression_target.cuda(), cls_label_map.cuda()
            pred_cls_score = pred_cls_score.reshape(-1, 2,
                                                    Config.anchor_num *
                                                    Config.train_map_size *
                                                    Config.train_map_size).permute(0, 2, 1)
            pred_regression = pred_regression.reshape(-1, 4,
                                                      Config.anchor_num * Config.train_map_size *
                                                      Config.train_map_size).permute(0, 2, 1)

            cls_loss = rpn_cross_entropy_banlance(pred_cls_score, cls_label_map, Config.num_pos,
                                                  Config.num_neg, anchors, ohem_pos=Config.ohem_pos,
                                                  ohem_neg=Config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_regression, regression_target, cls_label_map,
                                    Config.num_pos, Config.ohem_reg)
            if Config.update_template:
                loss = cls_loss + Config.lamb * reg_loss + template_loss
            else:
                loss = cls_loss + Config.lamb * reg_loss
            valid_loss.append(loss.detach().cpu())
        valid_loss = np.mean(valid_loss)
        print("[EPOCH %2d] valid_loss: %.4f, train_loss: %.4f", (epoch, valid_loss, train_loss))
        # note: the step passed to add_scalars here differs from the per-iteration step used during training
        if Config.update_template:
            summary_writer.add_scalars('valid', {'cls_loss': cls_loss.data.item(),
                                                 'reg_loss': reg_loss.data.item(),
                                                 'template_loss': template_loss.data.item(),
                                                 'total_loss': loss.data.item()},
                                       (epoch + 1) * len(trainloader))
        else:
            summary_writer.add_scalars('valid', {'cls_loss': cls_loss.data.item(),
                                                 'reg_loss': reg_loss.data.item(),
                                                 'total_loss': loss.data.item()},
                                       (epoch + 1) * len(trainloader))
        ajust_learning_rate(optimizer, Config.gamma)
        if epoch % 10 == 0:  # every 10 epochs, inspect which sequences have been chosen so far
            print(sorted(train_dataset.choosed_idx))
        # save model
        if epoch % Config.save_interval == 0:
            if not os.path.exists('../data/models/'):
                os.mkdir('../data/models/')
            if Config.update_template:
                save_name = '../data/models/siamrpn_stmm_select_adaptive_epoch_{}.pth'.format(epoch)
            else:
                save_name = '../data/models/siamrpn_noupdate_adaptive_epoch_{}.pth'.format(epoch)
            if t.cuda.device_count() > 1:  # remove the 'module.' prefix added by DataParallel
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    new_state_dict[k[7:]] = v
            else:
                new_state_dict = model.state_dict()
            t.save({
                'epoch': epoch,
                'model': new_state_dict,
                'optimizer': optimizer.state_dict(),
            }, save_name)
            print('save model as:{}'.format(save_name))
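The checkpoint-saving logic above has to deal with nn.DataParallel, which prepends 'module.' to every key in state_dict(). A small sketch of that key handling (the toy nn.Sequential model is only for illustration):

from collections import OrderedDict

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
wrapped = nn.DataParallel(net)  # keys become 'module.0.weight', 'module.1.running_mean', ...

# strip the 'module.' prefix so a bare (unwrapped) model can reload the checkpoint
state = OrderedDict((k[len('module.'):], v) for k, v in wrapped.state_dict().items())

bare = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
bare.load_state_dict(state)  # succeeds because the keys match again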
Example #3
def train(data_dir, model_path=None, vis_port=None, init=None):
    # loading meta data
    # -----------------------------------------------------------------------------------------------------
    meta_data_path = os.path.join(data_dir, "meta_data.pkl")
    meta_data = pickle.load(open(meta_data_path, 'rb'))
    all_videos = [x[0] for x in meta_data]

    # split train/valid dataset
    # -----------------------------------------------------------------------------------------------------
    train_videos, valid_videos = train_test_split(all_videos,
                                                  test_size=1 -
                                                  config.train_ratio,
                                                  random_state=config.seed)

    # define transforms
    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([ToTensor()])
    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    # open lmdb
    db = lmdb.open(data_dir + '.lmdb', readonly=True, map_size=int(200e9))

    # create dataset
    # -----------------------------------------------------------------------------------------------------
    train_dataset = ImagnetVIDDataset(db, train_videos, data_dir,
                                      train_z_transforms, train_x_transforms)
    anchors = train_dataset.anchors
    # dic_num = {}
    # ind_random = list(range(len(train_dataset)))
    # import random
    # random.shuffle(ind_random)
    # for i in tqdm(ind_random):
    #     exemplar_img, instance_img, regression_target, conf_target = train_dataset[i+1000]

    valid_dataset = ImagnetVIDDataset(db,
                                      valid_videos,
                                      data_dir,
                                      valid_z_transforms,
                                      valid_x_transforms,
                                      training=False)
    # create dataloader
    trainloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size * torch.cuda.device_count(),
        shuffle=True,
        pin_memory=True,
        num_workers=config.train_num_workers * torch.cuda.device_count(),
        drop_last=True)
    validloader = DataLoader(
        valid_dataset,
        batch_size=config.valid_batch_size * torch.cuda.device_count(),
        shuffle=False,
        pin_memory=True,
        num_workers=config.valid_num_workers * torch.cuda.device_count(),
        drop_last=True)

    # create summary writer
    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    summary_writer = SummaryWriter(config.log_dir)
    if vis_port:
        vis = visual(port=vis_port)

    # start training
    # -----------------------------------------------------------------------------------------------------
    model = SiameseAlexNet()
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # load model weight
    # -----------------------------------------------------------------------------------------------------
    start_epoch = 1
    if model_path and init:
        print("init training with checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        if 'model' in checkpoint.keys():
            model.load_state_dict(checkpoint['model'])
        else:
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        torch.cuda.empty_cache()
        print("inited checkpoint")
    elif model_path and not init:
        print("loading checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
        print("loaded checkpoint")
    elif not model_path and config.pretrained_model:
        print("init with pretrained checkpoint %s" % config.pretrained_model +
              '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)

    # freeze layers
    def freeze_layers(model):
        print(
            '------------------------------------------------------------------------------------------------'
        )
        for layer in model.featureExtract[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError('error in fixing former 3 layers')
        print("fixed layers:")
        print(model.featureExtract[:10])

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    for epoch in range(start_epoch, config.epoch + 1):
        train_loss = []
        model.train()
        if config.fix_former_3_layers:
            if torch.cuda.device_count() > 1:
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        loss_temp_cls = 0
        loss_temp_reg = 0
        for i, data in enumerate(tqdm(trainloader)):
            exemplar_imgs, instance_imgs, regression_target, conf_target = data
            # conf_target (8,1125) (8,225x5)
            regression_target, conf_target = regression_target.cuda(
            ), conf_target.cuda()

            pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                instance_imgs.cuda())

            pred_conf = pred_score.reshape(
                -1, 2, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            pred_offset = pred_regression.reshape(
                -1, 4, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                 conf_target,
                                                 config.num_pos,
                                                 config.num_neg,
                                                 anchors,
                                                 ohem_pos=config.ohem_pos,
                                                 ohem_neg=config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_offset,
                                    regression_target,
                                    conf_target,
                                    config.num_pos,
                                    ohem=config.ohem_reg)
            loss = cls_loss + config.lamb * reg_loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            optimizer.step()

            step = (epoch - 1) * len(trainloader) + i
            summary_writer.add_scalar('train/cls_loss', cls_loss.data, step)
            summary_writer.add_scalar('train/reg_loss', reg_loss.data, step)
            train_loss.append(loss.detach().cpu())
            loss_temp_cls += cls_loss.detach().cpu().numpy()
            loss_temp_reg += reg_loss.detach().cpu().numpy()
            # if vis_port:
            #     vis.plot_error({'rpn_cls_loss': cls_loss.detach().cpu().numpy().ravel()[0],
            #                     'rpn_regress_loss': reg_loss.detach().cpu().numpy().ravel()[0]}, win=0)
            if (i + 1) % config.show_interval == 0:
                tqdm.write(
                    "[epoch %2d][iter %4d] cls_loss: %.4f, reg_loss: %.4f lr: %.2e"
                    % (epoch, i, loss_temp_cls / config.show_interval,
                       loss_temp_reg / config.show_interval,
                       optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0
                if vis_port:
                    anchors_show = train_dataset.anchors
                    exem_img = exemplar_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)
                    inst_img = instance_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)

                    # show detected box with max score
                    topk = config.show_topK
                    vis.plot_img(exem_img.transpose(2, 0, 1),
                                 win=1,
                                 name='exemple')
                    cls_pred = conf_target[0]
                    gt_box = get_topk_box(cls_pred, regression_target[0],
                                          anchors_show)[0]

                    # show gt_box
                    img_box = add_box_img(inst_img, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=2,
                                 name='instance')

                    # show anchor with max score
                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    scores, index = torch.topk(cls_pred, k=topk)
                    img_box = add_box_img(inst_img, anchors_show[index.cpu()])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=3,
                                 name='anchor_max_score')

                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    topk_box = get_topk_box(cls_pred,
                                            pred_offset[0],
                                            anchors_show,
                                            topk=topk)
                    img_box = add_box_img(inst_img, topk_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=4,
                                 name='box_max_score')

                    # show anchor and detected box with max iou
                    iou = compute_iou(anchors_show, gt_box).flatten()
                    index = np.argsort(iou)[-topk:]
                    img_box = add_box_img(inst_img, anchors_show[index])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=5,
                                 name='anchor_max_iou')

                    # detected box
                    regress_offset = pred_offset[0].cpu().detach().numpy()
                    topk_offset = regress_offset[index, :]
                    anchors_det = anchors_show[index, :]
                    pred_box = box_transform_inv(anchors_det, topk_offset)
                    img_box = add_box_img(inst_img, pred_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=6,
                                 name='box_max_iou')

        train_loss = np.mean(train_loss)

        valid_loss = []
        model.eval()
        for i, data in enumerate(tqdm(validloader)):
            exemplar_imgs, instance_imgs, regression_target, conf_target = data

            regression_target, conf_target = regression_target.cuda(
            ), conf_target.cuda()

            pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                instance_imgs.cuda())

            pred_conf = pred_score.reshape(
                -1, 2, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            pred_offset = pred_regression.reshape(
                -1, 4, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                 conf_target,
                                                 config.num_pos,
                                                 config.num_neg,
                                                 anchors,
                                                 ohem_pos=config.ohem_pos,
                                                 ohem_neg=config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_offset,
                                    regression_target,
                                    conf_target,
                                    config.num_pos,
                                    ohem=config.ohem_reg)
            loss = cls_loss + config.lamb * reg_loss
            valid_loss.append(loss.detach().cpu())
        valid_loss = np.mean(valid_loss)
        print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" %
              (epoch, valid_loss, train_loss))
        summary_writer.add_scalar('valid/loss', valid_loss,
                                  (epoch + 1) * len(trainloader))
        adjust_learning_rate(
            optimizer, config.gamma
        )  # adjust before save, and it will be epoch+1's lr when next load
        if epoch % config.save_interval == 0:
            if not os.path.exists('./data/models/'):
                os.makedirs("./data/models/")
            save_name = "./data/models/siamrpn_{}.pth".format(epoch)
            new_state_dict = model.state_dict()
            if torch.cuda.device_count() > 1:
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    namekey = k[7:]  # remove `module.`
                    new_state_dict[namekey] = v
            torch.save(
                {
                    'epoch': epoch,
                    'model': new_state_dict,
                    'optimizer': optimizer.state_dict(),
                }, save_name)
            print('save model: {}'.format(save_name))
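adjust_learning_rate is called once per epoch with config.gamma but is not defined in these examples; a plausible sketch, assuming it is a simple exponential decay of every parameter group's learning rate:

import torch


def adjust_learning_rate(optimizer, gamma):
    # assumed implementation: lr <- lr * gamma for every param group, once per epoch
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * gamma


# usage sketch with a dummy parameter; gamma < 1 shrinks the lr a little each epoch
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-2, momentum=0.9)
adjust_learning_rate(optimizer, gamma=0.9)
print(optimizer.param_groups[0]['lr'])  # lr is now 0.01 * 0.9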
Example #4
def train(data_dir, model_path=None, vis_port=None, init=None):
    # loading meta data
    # -----------------------------------------------------------------------------------------------------
    meta_data_path = os.path.join(data_dir, "meta_data.pkl")
    meta_data = pickle.load(
        open(meta_data_path, 'rb')
    )  # meta_data[0] = ('ILSVRC2015_train_00001000', {0: ['000000', '000001', '000002',...]}),
    all_videos = [x[0] for x in meta_data]

    # split train/valid dataset
    # -----------------------------------------------------------------------------------------------------
    train_videos, valid_videos = train_test_split(all_videos,
                                                  test_size=1 -
                                                  config.train_ratio,
                                                  random_state=config.seed)
    print("after split:train_videos {0},valid_videos {1}".format(
        len(train_videos), len(valid_videos)))
    # define transforms
    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([ToTensor()])
    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    # open lmdb
    # db = lmdb.open(data_dir + '_lmdb', readonly=True, map_size=int(1024*1024*1024))  # 200e9, in bytes
    db_path = data_dir + '_Lmdb'
    # create dataset
    # -----------------------------------------------------------------------------------------------------
    train_dataset = ImagnetVIDDataset(db_path, train_videos, data_dir,
                                      train_z_transforms, train_x_transforms)
    # test __getitem__
    # train_dataset.__getitem__(1)
    # exit(0)

    anchors = train_dataset.anchors  # (1805,4) = (19*19*5,4)
    # dic_num = {}
    # ind_random = list(range(len(train_dataset)))
    # import random
    # random.shuffle(ind_random)
    # for i in tqdm(ind_random):
    #     exemplar_img, instance_img, regression_target, conf_target = train_dataset[i+1000]

    valid_dataset = ImagnetVIDDataset(db_path,
                                      valid_videos,
                                      data_dir,
                                      valid_z_transforms,
                                      valid_x_transforms,
                                      training=False)
    # create dataloader
    trainloader = DataLoader(train_dataset,
                             batch_size=config.train_batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=config.train_num_workers,
                             drop_last=True)
    validloader = DataLoader(valid_dataset,
                             batch_size=config.valid_batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.valid_num_workers,
                             drop_last=True)

    # create summary writer
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    summary_writer = SummaryWriter(config.log_dir)
    if vis_port:
        vis = visual(port=vis_port)

    # start training
    # -----------------------------------------------------------------------------------------------------
    # model = SiameseAlexNet()
    model = SiamFPN50()
    model.init_weights()  # weight initialization
    if config.CUDA:
        model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # load model weight
    # -----------------------------------------------------------------------------------------------------
    start_epoch = 1
    if model_path and init:
        print("init training with checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        if 'model' in checkpoint.keys():
            model.load_state_dict(checkpoint['model'])
        else:
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        torch.cuda.empty_cache()
        print("inited checkpoint")
    elif model_path and not init:
        print("loading checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
        print("loaded checkpoint")
    elif not model_path and config.pretrained_model:
        print("init with pretrained checkpoint %s" % config.pretrained_model +
              '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)

    # freeze layers
    def freeze_layers(model):
        print(
            '------------------------------------------------------------------------------------------------'
        )
        for layer in model.featureExtract[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError('error in fixing former 3 layers')
        # print("fixed layers:")
        # print(model.featureExtract[:10])
        '''
        fixed layers:
        Sequential(
        (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(2, 2))
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU(inplace)
        (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1))
        (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (7): ReLU(inplace)
        (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1))
        (9): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        '''

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # requires that model has already been moved to the GPU (.cuda())
    # if isinstance(model, torch.nn.DataParallel):  # multi-GPU training; accessing attributes directly raises AttributeError: 'DataParallel' object has no attribute 'xxxx'
    #     model = model.module
    for epoch in range(start_epoch, config.EPOCH + 1):
        train_loss = []
        model.train()
        if config.fix_former_3_layers:  # temporarily disabled  # fix the first three conv blocks (requires_grad = False)
            if torch.cuda.device_count() > 1:
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        loss_temp_cls = 0
        loss_temp_reg = 0
        loss_temp = 0
        # for i, data in enumerate(tqdm(trainloader)): # can't pickle Transaction objects
        for k, data in enumerate(tqdm(trainloader)):  # FIXME: the loader sometimes stops before iterating over all batches
            # print("done")
            # return
            # (8,3,127,127) \ (8,3,271,271) \ (8,1805,4) \ (8,1805)
            # 8 is the batch size, 1805 = 19 * 19 * 5, 5 = anchor num
            # exemplar_imgs, instance_imgs, regression_target, conf_target = data
            exemplar_imgs, instance_imgs, regression_targets, conf_targets = data

            # conf_target (8,1125) (8,225x5)
            if config.CUDA:
                # note: regression_targets is a list and cannot call .cuda() directly; consider flattening it into an (N, 4) tensor later
                # regression_targets, conf_targets = torch.tensor(regression_targets).cuda(), torch.tensor(conf_targets).cuda()
                exemplar_imgs, instance_imgs = exemplar_imgs.cuda(
                ), instance_imgs.cuda()

            # # single-level loss computation
            # # (8,10,19,19)\(8,20,19,19)
            # pred_score, pred_regression = model(exemplar_imgs, instance_imgs)
            # # (8,1805,2)
            # pred_conf = pred_score.reshape(-1, 2, config.anchor_num * config.score_size * config.score_size).permute(0,2,1)
            # # (8,1805,4)
            # pred_offset = pred_regression.reshape(-1, 4,config.anchor_num * config.score_size * config.score_size).permute(0,2,1)

            # cls_loss = rpn_cross_entropy_balance(pred_conf, conf_target, config.num_pos, config.num_neg, anchors,
            #                                      ohem_pos=config.ohem_pos, ohem_neg=config.ohem_neg)
            # reg_loss = rpn_smoothL1(pred_offset, regression_target, conf_target, config.num_pos, ohem=config.ohem_reg)
            # loss = cls_loss + config.lamb * reg_loss
            # multi-level (feature pyramid) loss computation
            # try:
            #     output = model(input)
            # except RuntimeError as exception:
            #     if "out of memory" in str(exception):
            #         print("WARNING: out of memory")
            #         if hasattr(torch.cuda, 'empty_cache'):
            #             torch.cuda.empty_cache()
            #     else:
            #         raise exception

            pred_scores, pred_regressions = model(exemplar_imgs, instance_imgs)
            # FEATURE_MAP_SIZE, FPN_ANCHOR_NUM
            '''
            when batch_size = 2, anchor_num = 3
            torch.Size([N, 6, 37, 37])
            torch.Size([N, 6, 19, 19])
            torch.Size([N, 6, 10, 10])
            torch.Size([N, 6, 6, 6])

            torch.Size([N, 12, 37, 37])
            torch.Size([N, 12, 19, 19])
            torch.Size([N, 12, 10, 10])
            torch.Size([N, 12, 6, 6])
            '''
            loss = 0
            cls_loss_sum = 0
            reg_loss_sum = 0
            for i in range(len(pred_scores)):
                if i != 1:
                    continue  # only the 19x19 level contributes to the loss for now; the other levels are ignored
                pred_score = pred_scores[i]
                pred_regression = pred_regressions[i]
                anchors_num = config.FPN_ANCHOR_NUM * config.FEATURE_MAP_SIZE[
                    i] * config.FEATURE_MAP_SIZE[i]
                pred_conf = pred_score.reshape(-1, 2,
                                               anchors_num).permute(0, 2, 1)
                pred_offset = pred_regression.reshape(-1, 4,
                                                      anchors_num).permute(
                                                          0, 2, 1)

                conf_target = conf_targets[i]
                regression_target = regression_targets[i].type(
                    torch.FloatTensor)  # pred_offset is a float tensor
                if config.CUDA:
                    conf_target = conf_target.cuda()
                    regression_target = regression_target.cuda()
                # binary classification loss (cross entropy)
                cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                     conf_target,
                                                     config.num_pos,
                                                     config.num_neg,
                                                     anchors[i],
                                                     ohem_pos=config.ohem_pos,
                                                     ohem_neg=config.ohem_neg)
                # regression loss (smooth L1)  # FIXME: the regression loss evaluates to 0 here
                reg_loss = rpn_smoothL1(pred_offset,
                                        regression_target,
                                        conf_target,
                                        config.num_pos,
                                        ohem=config.ohem_reg)

                _loss = cls_loss + reg_loss * config.lamb_reg  # config.lamb_cls

                loss += _loss  # sum the losses of all four levels for now; weighting can be added later
                # keep cls_loss / reg_loss as raw values for TensorBoard display
                cls_loss_sum = cls_loss_sum + cls_loss
                reg_loss_sum = reg_loss_sum + reg_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            optimizer.step()

            step = (epoch - 1) * len(trainloader) + k
            summary_writer.add_scalar('train/cls_loss', cls_loss_sum.data,
                                      step)
            summary_writer.add_scalar('train/reg_loss', reg_loss_sum.data,
                                      step)
            loss = loss.detach().cpu()
            train_loss.append(loss)
            loss_temp_cls += cls_loss_sum.detach().cpu().numpy()
            loss_temp_reg += reg_loss_sum.detach().cpu().numpy()
            loss_temp += loss.numpy()
            # if vis_port:
            #     vis.plot_error({'rpn_cls_loss': cls_loss.detach().cpu().numpy().ravel()[0],
            #                     'rpn_regress_loss': reg_loss.detach().cpu().numpy().ravel()[0]}, win=0)

            # print("Epoch {0} batch {1} training_loss:{2}".format(epoch, k+1, loss))

            if (k + 1) % config.show_interval == 0:
                tqdm.write(
                    "[epoch %2d][iter %4d] loss: %.4f, cls_loss: %.4f, reg_loss: %.4f lr: %.2e"
                    % (epoch, k + 1, loss_temp / config.show_interval,
                       loss_temp_cls / config.show_interval, loss_temp_reg /
                       config.show_interval, optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0
                loss_temp = 0
                # visualization
                if vis_port:
                    anchors_show = train_dataset.anchors
                    exem_img = exemplar_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)
                    inst_img = instance_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)

                    # show detected box with max score
                    topk = config.show_topK
                    vis.plot_img(exem_img.transpose(2, 0, 1),
                                 win=1,
                                 name='exemplar')
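                    # (get_topk_box is assumed to decode the given offsets against anchors_show and
                    # return the highest-scoring boxes; fed the ground-truth conf/regression targets
                    # below, it recovers the ground-truth box for display.)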
                    cls_pred = conf_target[0]
                    gt_box = get_topk_box(cls_pred, regression_target[0],
                                          anchors_show)[0]

                    # show gt_box
                    img_box = add_box_img(inst_img, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=2,
                                 name='instance')

                    # show anchor with max score
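                    # softmax over the two class channels; channel 1 is taken as the foreground
                    # probability of every anchor for the first sample in the batch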
                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    scores, index = torch.topk(cls_pred, k=topk)
                    img_box = add_box_img(inst_img, anchors_show[index.cpu()])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=3,
                                 name='anchor_max_score')

                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    topk_box = get_topk_box(cls_pred,
                                            pred_offset[0],
                                            anchors_show,
                                            topk=topk)
                    img_box = add_box_img(inst_img, topk_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=4,
                                 name='box_max_score')

                    # show anchor and detected box with max iou
                    iou = compute_iou(anchors_show, gt_box).flatten()
                    index = np.argsort(iou)[-topk:]
                    img_box = add_box_img(inst_img, anchors_show[index])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=5,
                                 name='anchor_max_iou')

                    # detected box
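                    # (box_transform_inv is assumed to apply the predicted (dx, dy, dw, dh) offsets
                    # to the selected anchors, yielding decoded boxes in image coordinates.)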
                    regress_offset = pred_offset[0].cpu().detach().numpy()
                    topk_offset = regress_offset[index, :]
                    anchors_det = anchors_show[index, :]
                    pred_box = box_transform_inv(anchors_det, topk_offset)
                    img_box = add_box_img(inst_img, pred_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=6,
                                 name='box_max_iou')
        train_loss = np.mean(train_loss)
        # print("done")

        # exit(0)
        # validation
        valid_loss = []
        # no gradient tracking to save GPU memory (validation is equivalent to testing: forward pass only)
        with torch.no_grad():
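            # eval() switches BatchNorm/Dropout to inference behaviour; the matching model.train()
            # call is assumed to happen at the start of the next training epoch (not shown here).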
            model.eval()
            for i, data in enumerate(tqdm(validloader)):
                exemplar_imgs, instance_imgs, regression_targets, conf_targets = data
                if config.CUDA:
                    exemplar_imgs, instance_imgs = exemplar_imgs.cuda(
                    ), instance_imgs.cuda()

                pred_scores, pred_regressions = model(exemplar_imgs,
                                                      instance_imgs)
                loss = 0
                for i in range(len(pred_scores)):
                    if i != 1:
                        continue  # for now only the 19x19 level contributes to the loss; the other levels are skipped
                    pred_score = pred_scores[i]
                    pred_regression = pred_regressions[i]
                    anchors_num = config.FPN_ANCHOR_NUM * config.FEATURE_MAP_SIZE[
                        i] * config.FEATURE_MAP_SIZE[i]
                    pred_conf = pred_score.reshape(-1, 2, anchors_num).permute(
                        0, 2, 1)
                    pred_offset = pred_regression.reshape(-1, 4,
                                                          anchors_num).permute(
                                                              0, 2, 1)

                    conf_target = conf_targets[i]
                    regression_target = regression_targets[i].type(
                        torch.FloatTensor)  # cast to float to match pred_offset's dtype
                    if config.CUDA:
                        conf_target = conf_target.cuda()
                        regression_target = regression_target.cuda()
                    # classification loss: balanced binary cross entropy
                    cls_loss = rpn_cross_entropy_balance(
                        pred_conf,
                        conf_target,
                        config.num_pos,
                        config.num_neg,
                        anchors[i],
                        ohem_pos=config.ohem_pos,
                        ohem_neg=config.ohem_neg)
                    # regression loss: Smooth L1  # NOTE: suspected bug -- the regression loss evaluates to 0
                    reg_loss = rpn_smoothL1(pred_offset,
                                            regression_target,
                                            conf_target,
                                            config.num_pos,
                                            ohem=config.ohem_reg)

                    _loss = cls_loss * config.lamb_cls + reg_loss * config.lamb_reg
                    loss += _loss  # sum the per-level losses directly for now; per-level weighting may be added later
                valid_loss.append(loss.detach().cpu())
        valid_loss = np.mean(valid_loss)

        print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" %
              (epoch, valid_loss, train_loss))
        summary_writer.add_scalar('valid/loss', valid_loss,
                                  (epoch + 1) * len(trainloader))
        # adjust the learning rate
        adjust_learning_rate(
            optimizer, config.gamma
        )  # adjust before saving so that a resumed run continues with epoch+1's learning rate
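        # (adjust_learning_rate is assumed to multiply each param group's lr by config.gamma,
        # i.e. exponential decay once per epoch -- its implementation is not shown here.)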
        # save a checkpoint of the trained model
        if epoch % config.save_interval == 0:
            if not os.path.exists('./data/models/'):
                os.makedirs("./data/models/")

            save_name = "./data/models/otb_siamfpn_{}_trainloss_{:.4f}_validloss_{:.4f}.pth".format(
                epoch, train_loss, valid_loss)
            new_state_dict = model.state_dict()
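            # When trained with nn.DataParallel, parameter names carry a 'module.' prefix;
            # strip it so the checkpoint also loads into a single-GPU model.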
            if torch.cuda.device_count() > 1:
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    namekey = k[7:]  # remove `module.`
                    new_state_dict[namekey] = v
            torch.save(
                {
                    'epoch': epoch,
                    'model': new_state_dict,
                    'optimizer': optimizer.state_dict(),
                }, save_name)
            print('save model: {}'.format(save_name))

        # release cached GPU memory
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()