Example #1
0
def val(model1, model2, dataParser, epoch):
    """Run one validation pass of the two-stage model over ``dataParser``.

    For every batch, ``model1`` is applied to the tampered image. Its output 0
    is multiplied into the image, concatenated with the raw image along dim 1
    and passed to ``model2`` together with ``model1``'s outputs 1-3.
    Stage-one output 0 is scored against ``gt_band``; stage-two output 0 is
    scored against ``gt_dou_edge`` and stage-two outputs 1-8 against the
    matching entries of ``relation_map``.

    Per-batch losses and metrics are sent to the module-level ``writer``
    (presumably a TensorBoard ``SummaryWriter`` — confirm) and a progress line
    is printed every ``args.print_freq`` batches.

    Args:
        model1: stage-one network (switched to eval mode here).
        model2: stage-two network (switched to eval mode here).
        dataParser: sized iterable yielding dict batches with the keys
            ``'tamper_image'``, ``'gt_band'``, ``'gt_dou_edge'`` and
            ``'relation_map'``.
        epoch: current epoch index; used to compute the logging step.

    Returns:
        dict with the average total loss, the stage-one and stage-two
        f1/precision/accuracy/recall averages, and ``'map8_loss'``: a list of
        10 per-relation-map loss averages, of which only the first 8 are ever
        updated.
    """
    val_epoch = len(dataParser)

    # Running-average meters.
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # Ten meters are kept so the returned 'map8_loss' list keeps its length,
    # but only the first 8 are fed (two_stage_outputs[1:9] below).
    map8_loss_value = [Averagvalue() for _ in range(10)]

    # (label, meter) pairs reported in the periodic progress line.
    progress_meters = [
        ('两阶段总Loss', losses),
        ('第一阶段Loss', loss_stage1),
        ('第二阶段Loss', loss_stage2),
        ('第一阶段:f1_score', f1_value_stage1),
        ('第一阶段:precision_score:', precision_value_stage1),
        ('第一阶段:acc_score', acc_value_stage1),
        ('第一阶段:recall_score', recall_value_stage1),
        ('第二阶段:f1_score', f1_value_stage2),
        ('第二阶段:precision_score:', precision_value_stage2),
        ('第二阶段:acc_score', acc_value_stage2),
        ('第二阶段:recall_score', recall_value_stage2),
    ]

    # Switch to evaluation mode.
    model1.eval()
    model2.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # Time spent waiting for the next batch.
        data_time.update(time.time() - end)
        global_step = epoch * val_epoch + batch_index

        # Prepare the inputs.
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        # Relation-map loss accumulator, created on the same device as the
        # inputs (which are always moved with .cuda() above).
        loss_8t = torch.zeros((), device=images.device)

        with torch.no_grad():
            # Forward pass through both stages.
            one_stage_outputs = model1(images)

            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])

            # Stage-one loss.
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_band)

            # Stage-two loss: weighted edge loss plus the relation-map losses.
            _loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                                labels_dou_edge) * 12

            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
                loss_8t += one_loss_t
                map8_loss_value[c_index].update(one_loss_t.item())

            _loss_stage_2 += loss_8t
            # NOTE(review): this weighting differs from the training objective
            # in train() (loss_stage_2 + loss_8t * 10); confirm the reported
            # validation loss is meant to use a different formula.
            loss_stage_2 = _loss_stage_2 / 20
            loss = (loss_stage_1 + loss_stage_2) / 2

        # Loss logging.
        writer.add_scalar('val_stage_one_loss',
                          loss_stage_1.item(),
                          global_step=global_step)
        writer.add_scalar('val_stage_two_fuse_loss',
                          loss_stage_2.item(),
                          global_step=global_step)
        writer.add_scalar('val_fuse_loss_per_epoch',
                          loss.item(),
                          global_step=global_step)

        # Record the losses in the running meters.
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())

        batch_time.update(time.time() - end)
        end = time.time()

        # Evaluation metrics.
        f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        precisionscore_stage2 = my_precision_score(two_stage_outputs[0],
                                                   labels_dou_edge)
        accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        recallscore_stage2 = my_recall_score(two_stage_outputs[0],
                                             labels_dou_edge)

        f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        precisionscore_stage1 = my_precision_score(one_stage_outputs[0],
                                                   labels_band)
        accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        metric_scalars = [
            ('val_f1_score_stage1', f1score_stage1),
            ('val_precision_score_stage1', precisionscore_stage1),
            ('val_acc_score_stage1', accscore_stage1),
            ('val_recall_score_stage1', recallscore_stage1),
            ('val_f1_score_stage2', f1score_stage2),
            ('val_precision_score_stage2', precisionscore_stage2),
            ('val_acc_score_stage2', accscore_stage2),
            ('val_recall_score_stage2', recallscore_stage2),
        ]
        for tag, value in metric_scalars:
            writer.add_scalar(tag, value, global_step=global_step)

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            # Every field is separated by a single space (the original
            # concatenation ran some fields together).
            parts = [
                'Epoch: [{0}/{1}][{2}/{3}]'.format(epoch, args.maxepoch,
                                                   batch_index, val_epoch),
                'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f})'.format(
                    batch_time=batch_time),
            ]
            parts.extend('{} {:f} (avg:{:f})'.format(label, meter.val, meter.avg)
                         for label, meter in progress_meters)
            print(' '.join(parts))

        # Guard for parsers that yield more batches than len() reports.
        if batch_index >= val_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
Example #2
0
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch):
    """Run one training epoch of the two-stage model over ``dataParser``.

    For every batch, ``model1`` is applied to the tampered image. Its output 0
    is multiplied into the image, concatenated with the raw image along dim 1
    and passed to ``model2`` together with ``model1``'s outputs 1-3. The
    optimized loss is the stage-two edge loss plus 10x the summed
    relation-map losses. Both optimizers are stepped after one backward pass.

    Batches whose image tensor does not have size 3 in dim 1 and 320 in
    dim 2 are skipped. Batches on which ``model1`` raises are also skipped
    (the exception and the image shape are printed).

    Per-batch losses are sent to the module-level ``writer`` (presumably a
    TensorBoard ``SummaryWriter`` — confirm). A progress line is printed every
    ``args.print_freq`` batches.

    Args:
        model1: stage-one network (switched to train mode here).
        model2: stage-two network (switched to train mode here).
        optimizer1: optimizer stepped after each backward pass (presumably
            over model1's parameters).
        optimizer2: optimizer stepped after each backward pass (presumably
            over model2's parameters).
        dataParser: sized iterable yielding dict batches with the keys
            ``'tamper_image'``, ``'gt_band'``, ``'gt_dou_edge'`` and
            ``'relation_map'``.
        epoch: current epoch index; used to compute the logging step.

    Returns:
        dict with the average optimized loss, the stage-one and stage-two
        metric averages, and ``'map8_loss'``: a list of 10 per-relation-map
        loss averages, of which only the first 8 are ever updated. The metric
        averages are currently constant placeholders (see the loop body).
    """
    train_epoch = len(dataParser)

    # Running-average meters.
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # Ten meters are kept so the returned 'map8_loss' list keeps its length,
    # but only the first 8 are fed (two_stage_outputs[1:9] below).
    map8_loss_value = [Averagvalue() for _ in range(10)]

    # (label, meter) pairs reported in the periodic progress line.
    progress_meters = [
        ('两阶段总Loss', losses),
        ('第一阶段Loss', loss_stage1),
        ('第二阶段Loss', loss_stage2),
        ('第一阶段:f1_score', f1_value_stage1),
        ('第一阶段:precision_score:', precision_value_stage1),
        ('第一阶段:acc_score', acc_value_stage1),
        ('第一阶段:recall_score', recall_value_stage1),
        ('第二阶段:f1_score', f1_value_stage2),
        ('第二阶段:precision_score:', precision_value_stage2),
        ('第二阶段:acc_score', acc_value_stage2),
        ('第二阶段:recall_score', recall_value_stage2),
    ]

    # Switch to train mode.
    model1.train()
    model2.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # Time spent waiting for the next batch.
        data_time.update(time.time() - end)
        global_step = epoch * train_epoch + batch_index

        # Skip batches the network is not set up for, before paying for the
        # GPU transfer. Presumably NCHW: 3 channels, height 320 — confirm.
        tamper_image = input_data['tamper_image']
        if tamper_image.shape[1] != 3 or tamper_image.shape[2] != 320:
            # Reset the timer so the skipped batch is not counted as data
            # loading time for the next one.
            end = time.time()
            continue

        # Prepare the inputs. The images themselves need no gradient: only
        # the model parameters are optimized.
        images = tamper_image.cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        # Relation-map loss accumulator, on the same device as the inputs.
        loss_8t = torch.zeros((), device=images.device)

        with torch.set_grad_enabled(True):
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            try:
                one_stage_outputs = model1(images)
            except Exception as e:
                # Deliberate best-effort: report and skip the batch. Note this
                # also swallows runtime errors such as CUDA out-of-memory.
                print(e)
                print(images.shape)
                end = time.time()
                continue

            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])

            # Stage-one loss (logged and averaged, but not optimized directly).
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_band)
            # Stage-two edge loss.
            loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                               labels_dou_edge)

            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
                loss_8t += one_loss_t
                map8_loss_value[c_index].update(one_loss_t.item())

            # NOTE(review): loss_stage_1 is not part of the optimized loss.
            # model1 still receives gradients through rgb_pred and the
            # one_stage_outputs[1:4] passed to model2. Confirm this is intended.
            loss = loss_stage_2 + loss_8t * 10

            writer.add_scalars('loss_gather', {
                'stage_one_loss': loss_stage_1.item(),
                'stage_two_fuse_loss': loss_stage_2.item()
            }, global_step=global_step)

            loss.backward()

            optimizer1.step()
            optimizer2.step()

        # Record the losses in the running meters.
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # Metric computation is disabled during training. Constant
        # placeholders of 1 keep the logged tags and returned keys populated.
        # To re-enable, compute them with my_f1_score / my_precision_score /
        # my_acc_score / my_recall_score as val() does.
        f1score_stage1 = 1
        precisionscore_stage1 = 1
        accscore_stage1 = 1
        recallscore_stage1 = 1

        f1score_stage2 = 1
        precisionscore_stage2 = 1
        accscore_stage2 = 1
        recallscore_stage2 = 1

        stage_scores = [
            ('f1_score_stage', f1score_stage1, f1score_stage2),
            ('precision_score_stage', precisionscore_stage1,
             precisionscore_stage2),
            ('acc_score_stage', accscore_stage1, accscore_stage2),
            ('recall_score_stage', recallscore_stage1, recallscore_stage2),
        ]
        for tag, stage1_score, stage2_score in stage_scores:
            writer.add_scalars(tag, {
                'stage1': stage1_score,
                'stage2': stage2_score
            }, global_step=global_step)

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            # Every field is separated by a single space (the original
            # concatenation ran some fields together).
            parts = [
                'Epoch: [{0}/{1}][{2}/{3}]'.format(epoch, args.maxepoch,
                                                   batch_index, train_epoch),
                'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f})'.format(
                    batch_time=batch_time),
            ]
            parts.extend('{} {:f} (avg:{:f})'.format(label, meter.val, meter.avg)
                         for label, meter in progress_meters)
            print(' '.join(parts))

        # Guard for parsers that yield more batches than len() reports.
        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }