Example #1
0
    def analyze_all(self):
        """Analyze every prediction image under ``self.pred_dir``.

        For each sampled prediction image this method:
          1. locates the matching source and ground-truth images,
          2. binarizes the gt and derives a "band" gt via ``__gen_band_gt``,
          3. computes loss and f1/acc/recall/precision scores,
          4. saves the band-gt image and a 4-panel combined image
             (src | gt | band gt | pred),
          5. collects per-image stats and dumps them all to an Excel file
             at ``self.save_excel_dir``.

        :return: None — all results are written to disk as side effects.
        """
        path_dir = self.pred_dir
        # 1. path check — abort the whole run if the prediction dir is missing
        if os.path.exists(path_dir):
            print('path ok')
        else:
            print('Path Not find:', path_dir)
            sys.exit()

        pred_list = os.listdir(path_dir)
        fileNumber = len(pred_list)
        print('The number of images are: ', fileNumber)
        # rate is the sampled fraction; 1 means every image is analyzed
        rate = 1
        pickNumber = int(fileNumber * rate)
        sample = random.sample(pred_list, pickNumber)

        # per-image stat accumulators (one entry appended per image)
        f1_score_list = []
        pred_name_list = []
        loss_list = []
        src_name_list = []
        gt_name_list = []
        band_gt_name_list = []
        precision_score_list = []
        acc_list = []
        recall_list = []

        # reusable canvas for the side-by-side plot: src | gt | band gt | pred
        # NOTE(review): assumes every image is 320x320 — TODO confirm
        combineArray = np.zeros((320, 4 * 320, 3))

        for index, name in enumerate(sample):
            print(index, '/', len(sample))
            src_path, gt_path = Analyze.__find_src_and_gt(self, name)
            pred_img = os.path.join(path_dir, name)
            pred_img = Image.open(pred_img)
            src_img = Image.open(src_path)
            gt_img = Image.open(gt_path)

            # check channel required 1 dim: keep only the first channel of
            # any 3-channel prediction / gt image
            if len(pred_img.split()) == 3:
                pred_img = pred_img.split()[0]
            else:
                pass
            if len(gt_img.split()) == 3:
                gt_img = gt_img.split()[0]
            else:
                pass

            # convert to ndarray and normalize,then to tensor
            pred_ndarray = np.array(pred_img)
            # (H, W, 1) copy kept for the combined plot (pre-normalization)
            pred_ndarray3D = np.expand_dims(pred_ndarray, axis=2)
            pred_ndarray = pred_ndarray / 255

            # (1, 1, H, W) layout expected by the loss/metric functions
            pred_ndarray4D = pred_ndarray[np.newaxis, np.newaxis, :, :]
            # convert numpy to tensor
            pred_img_tensor = torch.from_numpy(pred_ndarray4D)

            # compute loss
            gt_ndarray = np.array(gt_img)
            gt = gt_ndarray.copy()
            # binarize: pixels valued 100 or 255 count as positives
            gt = np.where((gt == 100) | (gt == 255), 1, 0)
            gt_ndarray3D = np.expand_dims(gt_ndarray, axis=2)
            gt_ndarray4D = gt_ndarray[np.newaxis, np.newaxis, :, :]
            band_gt_np = Analyze.__gen_band_gt(self, gt_ndarray4D)
            band_gt_np3D = band_gt_np.squeeze(0)
            band_gt_np2D = band_gt_np3D.squeeze(0)
            band_gt_np3DLast1 = np.expand_dims(band_gt_np2D, axis=2)
            # scale band gt up for visualization in the combined plot
            band_gt_np3DLast1 = band_gt_np3DLast1 * 255
            band_gt_img = Image.fromarray(band_gt_np2D)
            band_gt_prefix = 'band5_'

            # NOTE(review): '/'-split assumes POSIX-style paths — os.path
            # would be portable; left as-is
            gt_name = gt_path.split('/')[-1]

            band_gt_name = band_gt_prefix + gt_name
            band_gt_img.save(os.path.join(self.save_band_dir, band_gt_name))

            band_gt_tensor = torch.from_numpy(band_gt_np)

            # compute loss,f1,acc,precision,recall
            gt = torch.from_numpy(gt)
            gt = gt.unsqueeze(0)
            gt = gt.unsqueeze(0)
            # loss is computed against the binarized gt ...
            loss_tonsor = wce_dice_huber_loss(pred_img_tensor.float(),
                                              gt.float())

            # loss_tonsor = wce_dice_huber_loss(pred_img_tensor.float(), band_gt_tensor.float())
            loss = loss_tonsor.item()

            # ... while the scores below use the band gt.
            # NOTE(review): confirm this gt/band-gt mismatch is intentional.
            f1_score = my_f1_score(pred_img_tensor, band_gt_tensor)

            acc_score = my_acc_score(pred_img_tensor, band_gt_tensor)
            recall = my_recall_score(pred_img_tensor, band_gt_tensor)
            precision = my_precision_score(pred_img_tensor, band_gt_tensor)

            # collect per-image stats for the final spreadsheet
            f1_score_list.append(f1_score)
            pred_name_list.append(name)
            loss_list.append(loss)
            acc_list.append(acc_score)
            recall_list.append(recall)
            precision_score_list.append(precision)

            src_name = src_path.split('/')[-1]
            src_name_list.append(src_name)

            gt_name_list.append(gt_name)

            band_gt_name_list.append(band_gt_name)

            # combine_plot: paste the four panels into one 320x1280 image
            src_ndarray = np.array(src_img)
            combineArray[:, :320, :] = src_ndarray
            combineArray[:, 320:640, :] = gt_ndarray3D
            combineArray[:, 640:960, :] = band_gt_np3DLast1
            combineArray[:, 960:, :] = pred_ndarray3D

            combineImg = Image.fromarray(combineArray.astype(np.uint8))
            combineImg_prefix = 'comb_'
            combineImg_name = combineImg_prefix + src_name
            combineImg.save(
                os.path.join(self.save_combine_dir, combineImg_name))

            # difficult top-k
        data = {
            'srcName': src_name_list,
            'gtName': gt_name_list,
            'bandGtName': band_gt_name_list,
            'predName': pred_name_list,
            'loss': loss_list,
            'precision': precision_score_list,
            'recall': recall_list,
            'f1': f1_score_list,
            'acc': acc_list
        }
        test = pd.DataFrame(data)
        test.to_excel(self.save_excel_dir)
Example #2
0
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch):
    """Run one training epoch of the two-stage pipeline.

    Stage one (``model1``) predicts a band mask from the tampered image.
    Its first output gates the RGB input, which is concatenated back onto
    the image and fed to stage two (``model2``) together with three
    stage-one side outputs.  The optimized loss is the stage-two loss plus
    10x the summed 8-map relation losses; the stage-one loss is computed
    and logged only.

    :param model1: stage-one network, trained in-place.
    :param model2: stage-two network, trained in-place.
    :param optimizer1: optimizer for ``model1``'s parameters.
    :param optimizer2: optimizer for ``model2``'s parameters.
    :param dataParser: iterable of batches providing 'tamper_image',
        'gt_band', 'gt_dou_edge' and 'relation_map'.
    :param epoch: current epoch index, used for TensorBoard global steps.
    :return: dict of epoch-average losses and per-stage metric averages.
    """
    # number of batches in one epoch
    train_epoch = len(dataParser)
    # running-average trackers for timing, losses and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # one tracker per relation map (10 allocated, 8 used in the loop below)
    map8_loss_value = [
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue()
    ]

    # switch to train mode
    model1.train()
    model2.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input tensors
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        # accumulator for the 8 relation-map losses
        if torch.cuda.is_available():
            loss_8t = torch.zeros(()).cuda()
        else:
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # skip malformed batches (expects 3-channel 320-wide images)
            if images.shape[1] != 3 or images.shape[2] != 320:
                continue
            # network forward pass

            # NOTE(review): stage one occasionally raises on odd inputs;
            # the batch is logged and skipped rather than crashing the epoch
            try:
                one_stage_outputs = model1(images)
            except Exception as e:
                print(e)
                print(images.shape)
                continue

            # gate the RGB image with the stage-one prediction, then feed
            # (gated, original) channels plus stage-one side outputs to stage two
            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # build the stage-one loss
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_band)
            ##############################################
            # deal with two stage issues
            loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                               labels_dou_edge)

            # per-relation-map cross-entropy losses (outputs 1..8)
            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
                loss_8t += one_loss_t
                # print(one_loss_t)
                map8_loss_value[c_index].update(one_loss_t.item())

            # print(loss_stage_2)
            # print(map8_loss_value)
            # NOTE(review): loss_stage_1 is logged below but excluded from
            # the optimized loss — confirm this is intentional
            loss = loss_stage_2 + loss_8t * 10
            #######################################
            # total loss logging
            # print(type(loss_stage_2.item()))
            writer.add_scalars('loss_gather', {
                'stage_one_loss': loss_stage_1.item(),
                'stage_two_fuse_loss': loss_stage_2.item()
            },
                               global_step=epoch * train_epoch + batch_index)
            ##########################################
            loss.backward()

            optimizer1.step()
            optimizer2.step()

        # record the batch statistics in the running averages
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics — real computation disabled (presumably for
        # speed); constants keep the logging/return shape intact
        # f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        # precisionscore_stage2 = my_precision_score(two_stage_outputs[0], labels_dou_edge)
        # accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        # recallscore_stage2 = my_recall_score(two_stage_outputs[0], labels_dou_edge)

        f1score_stage2 = 1
        precisionscore_stage2 = 1
        accscore_stage2 = 1
        recallscore_stage2 = 1
        #
        # f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        # precisionscore_stage1 = my_precision_score(one_stage_outputs[0], labels_band)
        # accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        # recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        f1score_stage1 = 1
        precisionscore_stage1 = 1
        accscore_stage1 = 1
        recallscore_stage1 = 1
        writer.add_scalars('f1_score_stage', {
            'stage1': f1score_stage1,
            'stage2': f1score_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        writer.add_scalars('precision_score_stage', {
            'stage1': precisionscore_stage1,
            'stage2': precisionscore_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        writer.add_scalars('acc_score_stage', {
            'stage1': accscore_stage1,
            'stage2': accscore_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        writer.add_scalars('recall_score_stage', {
            'stage1': recallscore_stage1,
            'stage2': recallscore_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        # periodic console progress report
        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   '两阶段总Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   '第一阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1) + \
                   '第二阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage2) + \
                   '第一阶段:f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   '第一阶段:precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage1) + \
                   '第一阶段:acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage1) + \
                   '第一阶段:recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage1) + \
                   '第二阶段:f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   '第二阶段:precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage2) + \
                   '第二阶段:acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage2) + \
                   '第二阶段:recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        # guard against data parsers that iterate past one epoch
        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
Example #3
0
def val(model1, model2, dataParser, epoch):
    """Run one validation epoch of the two-stage pipeline.

    Mirrors the two-stage ``train`` loop but with gradients disabled and
    real metric computation enabled.  All per-batch losses and metrics are
    logged to TensorBoard and accumulated into running averages.

    :param model1: stage-one network, evaluated in eval mode.
    :param model2: stage-two network, evaluated in eval mode.
    :param dataParser: iterable of batches providing 'tamper_image',
        'gt_band', 'gt_dou_edge' and 'relation_map'.
    :param epoch: current epoch index, used for TensorBoard global steps.
    :return: dict of epoch-average losses and per-stage metric averages.
    """
    # number of batches in one epoch
    val_epoch = len(dataParser)
    # running-average trackers for timing, losses and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # one tracker per relation map (10 allocated, 8 used in the loop below)
    map8_loss_value = [
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue()
    ]
    # switch to eval mode
    model1.eval()
    model2.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input tensors
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        # accumulator for the 8 relation-map losses
        if torch.cuda.is_available():
            loss_8t = torch.zeros(()).cuda()
        else:
            loss_8t = torch.zeros(())

        # gradients are disabled for the whole forward/loss computation
        with torch.set_grad_enabled(False):
            images.requires_grad = False
            # network forward pass
            one_stage_outputs = model1(images)

            # gate the RGB image with the stage-one prediction, then feed
            # (gated, original) channels plus stage-one side outputs to stage two
            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # build the stage-one loss
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_band)
            ##############################################
            # deal with two stage issues
            # NOTE(review): the *12 / /20 scaling constants are undocumented
            # magic numbers and differ from the train loop — confirm intent
            _loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                                labels_dou_edge) * 12

            # per-relation-map cross-entropy losses (outputs 1..8)
            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
                loss_8t += one_loss_t
                map8_loss_value[c_index].update(one_loss_t.item())

            _loss_stage_2 += loss_8t
            loss_stage_2 = _loss_stage_2 / 20
            loss = (loss_stage_1 + loss_stage_2) / 2

            #######################################
            # total loss logging
            writer.add_scalar('val_stage_one_loss',
                              loss_stage_1.item(),
                              global_step=epoch * val_epoch + batch_index)
            writer.add_scalar('val_stage_two_fuse_loss',
                              loss_stage_2.item(),
                              global_step=epoch * val_epoch + batch_index)
            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * val_epoch + batch_index)
            ##########################################

        # record the batch statistics in the running averages
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())

        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics (computed for real during validation)
        f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        precisionscore_stage2 = my_precision_score(two_stage_outputs[0],
                                                   labels_dou_edge)
        accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        recallscore_stage2 = my_recall_score(two_stage_outputs[0],
                                             labels_dou_edge)

        f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        precisionscore_stage1 = my_precision_score(one_stage_outputs[0],
                                                   labels_band)
        accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        writer.add_scalar('val_f1_score_stage1',
                          f1score_stage1,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score_stage1',
                          precisionscore_stage1,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score_stage1',
                          accscore_stage1,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score_stage1',
                          recallscore_stage1,
                          global_step=epoch * val_epoch + batch_index)

        writer.add_scalar('val_f1_score_stage2',
                          f1score_stage2,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score_stage2',
                          precisionscore_stage2,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score_stage2',
                          accscore_stage2,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score_stage2',
                          recallscore_stage2,
                          global_step=epoch * val_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        # periodic console progress report
        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, val_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   '两阶段总Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   '第一阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1) + \
                   '第二阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage2) + \
                   '第一阶段:f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   '第一阶段:precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage1) + \
                   '第一阶段:acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage1) + \
                   '第一阶段:recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage1) + \
                   '第二阶段:f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   '第二阶段:precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage2) + \
                   '第二阶段:acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage2) + \
                   '第二阶段:recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        # guard against data parsers that iterate past one epoch
        if batch_index >= val_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
Example #4
0
def train(model, optimizer, dataParser, epoch):
    """Train the single-stage ``model`` for one epoch.

    Per batch: forward pass, ``wce_dice_huber_loss`` against the band
    ground truth, backward + optimizer step, TensorBoard logging of the
    loss and f1/precision/acc/recall scores, and a periodic console report.

    Fixes vs. original: removed dead code — the ``loss = torch.zeros(1)``
    pre-initialization was always overwritten, and ``loss_8t`` /
    ``map8_loss_value`` were tracked but never logged nor returned.

    :param model: network to train (updated in-place).
    :param optimizer: optimizer stepping ``model``'s parameters.
    :param dataParser: iterable of batches providing 'tamper_image'
        and 'gt_band'.
    :param epoch: current epoch index, used for TensorBoard global steps.
    :return: dict of epoch-average loss/f1/precision/accuracy/recall.
    """
    # number of batches in one epoch
    train_epoch = len(dataParser)
    # running-average trackers for timing, loss and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()

    # switch to train mode
    model.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)
        # prepare the input tensors
        # NOTE(review): .cuda() is unconditional here — assumes a GPU host
        images = input_data['tamper_image'].cuda()
        labels = input_data['gt_band'].cuda()

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer.zero_grad()
            # forward pass; only the first output map is supervised here
            outputs = model(images)[0]

            loss = wce_dice_huber_loss(outputs, labels)
            writer.add_scalar('loss_per_batch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)

            loss.backward()
            optimizer.step()

        # record the batch statistics in the running averages
        losses.update(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs, labels)
        precisionscore = my_precision_score(outputs, labels)
        accscore = my_acc_score(outputs, labels)
        recallscore = my_recall_score(outputs, labels)

        writer.add_scalar('f1_score',
                          f1score,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score',
                          precisionscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score',
                          accscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score',
                          recallscore,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        # periodic console progress report
        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        # guard against data parsers that iterate past one epoch
        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
Example #5
0
def val(model, dataParser, epoch):
    """Validate the single-stage ``model`` for one epoch.

    Per batch: forward pass, ``wce_dice_huber_loss`` against the band
    ground truth, TensorBoard logging of the loss and metrics, optional
    dumping of intermediate results, and a periodic console report.

    Fixes vs. original: the forward pass and loss now run under
    ``torch.no_grad()`` (the eval loop previously tracked gradients for
    nothing, wasting memory); removed dead ``loss``/``loss_8t``/
    ``map8_loss_value`` code that was never logged nor returned; fixed
    the 'vla_Loss' typo in the console report.

    :param model: network to evaluate (switched to eval mode).
    :param dataParser: iterable of batches providing 'tamper_image'
        and 'gt_band'.
    :param epoch: current epoch index, used for TensorBoard global steps.
    :return: dict of epoch-average loss/f1/precision/accuracy/recall.
    """
    # number of batches in one epoch
    val_epoch = len(dataParser)

    # running-average trackers for timing, loss and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()

    # switch to eval mode
    model.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)

        images = input_data['tamper_image']
        labels = input_data['gt_band']

        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()

        # validation needs no gradients: saves memory and compute
        with torch.no_grad():
            # forward pass; only the first output map is evaluated here
            outputs = model(images)[0]

            # optionally dump intermediate results for selected batches
            # NOTE(review): train_phase=True inside val looks suspicious —
            # confirm against save_mid_result's signature
            if args.save_mid_result and batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=True)

            loss = wce_dice_huber_loss(outputs, labels)

        writer.add_scalar('val_fuse_loss_per_epoch',
                          loss.item(),
                          global_step=epoch * val_epoch + batch_index)
        # record the batch statistics in the running averages
        losses.update(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs, labels)
        precisionscore = my_precision_score(outputs, labels)
        accscore = my_acc_score(outputs, labels)
        recallscore = my_recall_score(outputs, labels)

        writer.add_scalar('val_f1_score',
                          f1score,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score',
                          precisionscore,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score',
                          accscore,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score',
                          recallscore,
                          global_step=epoch * val_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        # periodic console progress report
        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, val_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'val_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'val_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'val_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'val_acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'val_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        # guard against data parsers that iterate past one epoch
        if batch_index >= val_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch):
    """Train the two-stage tampering-detection pipeline for one epoch.

    Stage 1 (``model1``) predicts a band mask from the input image.  Input
    pixels where the stage-1 prediction exceeds 0.1 are kept, and the masked
    image concatenated with the original RGB is fed to stage 2 (``model2``)
    together with three extra stage-1 outputs (indices 9, 10, 11 --
    presumably intermediate feature maps; verify against model1's forward).

    Per-batch losses and metrics for both stages are written to the
    module-level TensorBoard ``writer``, and a progress line is printed every
    ``args.print_freq`` batches.

    Args:
        model1: first-stage network, called as ``model1(images)``.
        model2: second-stage network, called with the concatenated tensor and
            stage-1 outputs ``[9]``, ``[10]``, ``[11]``.
        optimizer1: optimizer for ``model1``'s parameters.
        optimizer2: optimizer for ``model2``'s parameters.
        dataParser: iterable yielding dicts with keys ``'tamper_image'``,
            ``'gt_band'``, ``'gt_dou_edge'`` and ``'relation_map'``.
        epoch: current epoch index, used to compute the TensorBoard
            ``global_step``.

    Returns:
        dict of epoch averages: total loss plus f1 / precision / accuracy /
        recall for each of the two stages.
    """
    # iterator over the training data

    train_epoch = len(dataParser)
    # running-average trackers for timing, losses and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to train mode
    model1.train()
    model2.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data (unconditionally moved to GPU -- assumes
        # CUDA is available; TODO confirm, the zeros below do check)
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # skip batches that are not 3-channel 320-pixel images
            if images.shape[1] != 3 or images.shape[2] != 320:
                continue
            # stage-1 forward pass
            one_stage_outputs = model1(images)

            zero = torch.zeros_like(one_stage_outputs[0])
            one = torch.ones_like(one_stage_outputs[0])

            # keep only input pixels where stage 1 predicts > 0.1
            # (0.1 is a hard-coded confidence threshold)
            rgb_pred = images * torch.where(one_stage_outputs[0] > 0.1, one,
                                            zero)
            # channel-concat of masked image and original image for stage 2
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[9],
                                       one_stage_outputs[10],
                                       one_stage_outputs[11])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # stage-1 loss against the band ground truth
            _loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                                labels_band)
            loss_stage_1 = _loss_stage_1
            ##############################################
            # deal with two stage issues
            # fused stage-2 prediction loss, weighted by 12 (magic weight;
            # together with the /20 below -- TODO confirm intended scaling)
            _loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                                labels_dou_edge) * 12

            # cross-entropy on the 8 relation maps (stage-2 outputs 1..8)
            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = cross_entropy_loss(c,
                                                relation_map[c_index].cuda())
                loss_8t += one_loss_t
                writer.add_scalar('stage2_%d_map_loss' % (c_index),
                                  one_loss_t.item(),
                                  global_step=epoch * train_epoch +
                                  batch_index)

            _loss_stage_2 += loss_8t
            loss_stage_2 = _loss_stage_2 / 20
            # total loss: mean of the two stage losses
            loss = (loss_stage_1 + loss_stage_2) / 2

            #######################################
            # log the losses
            writer.add_scalar('stage_one_loss',
                              loss_stage_1.item(),
                              global_step=epoch * train_epoch + batch_index)
            writer.add_scalar('stage_two_pred_loss',
                              _loss_stage_2.item(),
                              global_step=epoch * train_epoch + batch_index)
            writer.add_scalar('stage_two_fuse_loss',
                              loss_stage_2.item(),
                              global_step=epoch * train_epoch + batch_index)

            writer.add_scalar('fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)
            ##########################################

            # one backward pass drives both optimizers
            loss.backward()
            optimizer1.step()
            optimizer2.step()

        # record the various values in their running-average trackers
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())

        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics for both stages
        f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        precisionscore_stage2 = my_precision_score(two_stage_outputs[0],
                                                   labels_dou_edge)
        accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        recallscore_stage2 = my_recall_score(two_stage_outputs[0],
                                             labels_dou_edge)

        f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        precisionscore_stage1 = my_precision_score(one_stage_outputs[0],
                                                   labels_band)
        accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        writer.add_scalar('f1_score_stage1',
                          f1score_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score_stage1',
                          precisionscore_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score_stage1',
                          accscore_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score_stage1',
                          recallscore_stage1,
                          global_step=epoch * train_epoch + batch_index)

        writer.add_scalar('f1_score_stage2',
                          f1score_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score_stage2',
                          precisionscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score_stage2',
                          accscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score_stage2',
                          recallscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   '两阶段总Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   '第一阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1) + \
                   '第二阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage2) + \
                   '第一阶段:f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   '第一阶段:precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value_stage1) + \
                   '第一阶段:acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage1) +\
                   '第一阶段:recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage1) + \
                   '第二阶段:f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   '第二阶段:precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value_stage2) + \
                   '第二阶段:acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage2) +\
                   '第二阶段:recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        # defensive stop; with enumerate(dataParser) batch_index stays below
        # train_epoch, so this normally never fires
        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg
    }
# Example #7
# 0
def val(model, dataParser, epoch):
    """Run one validation epoch and log loss/metrics to TensorBoard.

    Fixes over the previous revision:
      * the printed log label read ``vla_Loss`` instead of ``val_Loss``;
      * the forward pass now runs under ``torch.no_grad()`` (no gradients are
        needed during validation, saving memory);
      * the step count is named ``val_epoch`` (it was misleadingly called
        ``train_epoch``).

    Args:
        model: network under evaluation, called as ``model(images)``; expected
            to return a sequence whose element 0 is the fused prediction.
        dataParser: data source with a ``val_steps`` attribute; batches are
            drawn via ``generate_minibatches(dataParser, False)`` as
            ``(images, labels_numpy)`` numpy pairs.
        epoch: current epoch index, used for the TensorBoard ``global_step``.

    Returns:
        dict of epoch-average loss, f1, precision, accuracy and recall.
    """
    # number of validation steps in one epoch
    val_epoch = int(dataParser.val_steps)
    # running-average trackers for timing, losses and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to evaluation mode
    model.eval()
    end = time.time()

    for batch_index, (images, labels_numpy) in enumerate(
            generate_minibatches(dataParser, False)):
        # time spent loading this batch
        data_time.update(time.time() - end)

        # convert the incoming numpy arrays to tensors (GPU when available)
        labels = []
        if torch.cuda.is_available():
            images = torch.from_numpy(images).cuda()
            for item in labels_numpy:
                labels.append(torch.from_numpy(item).cuda())
        else:
            images = torch.from_numpy(images)
            for item in labels_numpy:
                labels.append(torch.from_numpy(item))

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        # forward pass; gradients are not needed during validation
        with torch.no_grad():
            outputs = model(images)
        # optionally dump intermediate results for inspection
        if args.save_mid_result:
            if batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=True)
            else:
                pass
        else:
            pass
        """""" """""" """""" """""" """"""
        "         Loss 函数           "
        """""" """""" """""" """""" """"""

        if not args.band_mode:
            # outside band_mode the losses of the 8 side maps are added in
            loss = wce_dice_huber_loss(outputs[0],
                                       labels[0]) * args.fuse_loss_weight

            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item() / args.fuse_loss_weight,
                              global_step=epoch * val_epoch + batch_index)

            for c_index, c in enumerate(outputs[1:]):
                one_loss_t = wce_dice_huber_loss(c, labels[c_index + 1])
                loss_8t += one_loss_t
                # NOTE(review): global_step is constant here, unlike the
                # other scalars -- looks unintended, confirm before changing
                writer.add_scalar('val_%d_map_loss' % (c_index),
                                  one_loss_t.item(),
                                  global_step=val_epoch)
            loss += loss_8t
            loss = loss / 20
        else:
            loss = wce_dice_huber_loss(outputs[0], labels[0])
            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * val_epoch + batch_index)

        # record the various values in their running-average trackers
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics on the fused prediction
        f1score = my_f1_score(outputs[0], labels[0])
        precisionscore = my_precision_score(outputs[0], labels[0])
        accscore = my_acc_score(outputs[0], labels[0])
        recallscore = my_recall_score(outputs[0], labels[0])

        writer.add_scalar('val_f1_score',
                          f1score,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score',
                          precisionscore,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score',
                          accscore,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score',
                          recallscore,
                          global_step=epoch * val_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.val_steps) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'val_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'val_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'val_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'val_acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'val_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        # stop once the configured number of validation steps is reached
        if batch_index >= val_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
def train(model1, optimizer1, dataParser, epoch):
    """Train ``model1`` alone for one epoch against the double-edge ground truth.

    Single-stage variant of the two-stage trainer: only the stage-1 loss is
    computed and optimized.  Dead code from the two-stage version (unused
    ``loss_8t`` tensor, the 10-element ``map8_loss_value`` list and the
    commented-out relation-map loop) has been removed.

    Args:
        model1: network to train, called as ``model1(images)``; element 0 of
            its output is the edge prediction.
        optimizer1: optimizer for ``model1``'s parameters.
        dataParser: iterable yielding dicts with keys ``'tamper_image'``,
            ``'gt_band'``, ``'gt_dou_edge'`` and ``'relation_map'``.
        epoch: current epoch index, used for the TensorBoard ``global_step``.

    Returns:
        dict with the epoch-average loss under key ``'loss_avg'``.
    """
    # iterator over the training data
    train_epoch = len(dataParser)
    # running-average trackers for timing and loss
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()

    # switch to train mode
    model1.train()
    end = time.time()
    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)
        # prepare the input data; gt_band and relation_map are fetched (and
        # moved like before) for parity with the two-stage variant even
        # though this single-stage loss does not use them
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            # skip batches that are not 3-channel 320-pixel images
            if images.shape[1] != 3 or images.shape[2] != 320:
                continue

            # forward pass; best-effort -- log and skip a failing batch
            try:
                one_stage_outputs = model1(images)
            except Exception as e:
                print(e)
                print(images.shape)
                continue

            # stage-1 loss against the double-edge ground truth
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_dou_edge)
            loss = loss_stage_1

            writer.add_scalars('loss_gather', {
                'stage_one_loss': loss_stage_1.item(),
            },
                               global_step=epoch * train_epoch + batch_index)

            loss.backward()
            optimizer1.step()

        # record the values in their running-average trackers
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   '第一阶段Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1)
            print(info)

        # defensive stop; with enumerate this normally never fires
        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
    }
def train(model, optimizer, dataParser, epoch):
    """Train ``model`` for one epoch against the band ground truth.

    Single-output variant: only the fused prediction (``outputs[0]``) is
    trained with ``wce_dice_huber_loss`` against ``'gt_band'``; the
    commented-out code below is the multi-map version kept for reference.
    Loss and metrics are logged to the module-level TensorBoard ``writer``
    and a progress line is printed every ``args.print_freq`` batches.

    Args:
        model: network to train, called as ``model(images)``.
        optimizer: optimizer for ``model``'s parameters.
        dataParser: iterable yielding dicts with keys ``'tamper_image'`` and
            ``'gt_band'``.
        epoch: current epoch index, used for the TensorBoard ``global_step``.

    Returns:
        dict of epoch-average loss, f1, precision, accuracy and recall.
    """
    # iterator over the training data

    train_epoch = len(dataParser)
    # running-average trackers for timing, losses and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to train mode
    model.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data (unconditionally moved to GPU -- assumes
        # CUDA is available; TODO confirm)
        images = input_data['tamper_image'].cuda()
        labels = input_data['gt_band'].cuda()

        # earlier numpy-based conversion path, kept for reference:
        # labels = []
        # if torch.cuda.is_available():
        #     images = torch.from_numpy(images).cuda()
        #     for item in labels_numpy:
        #         labels.append(torch.from_numpy(item).cuda())
        # else:
        #     images = torch.from_numpy(images)
        #     for item in labels_numpy:
        #         labels.append(torch.from_numpy(item))

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer.zero_grad()
            # forward pass
            outputs = model(images)
            # optionally dump intermediate results for inspection
            if args.save_mid_result:
                if batch_index in args.mid_result_index:
                    save_mid_result(outputs,
                                    labels,
                                    epoch,
                                    batch_index,
                                    args.mid_result_root,
                                    save_8map=True,
                                    train_phase=True)
                else:
                    pass
            else:
                pass
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""

            # multi-map loss of the non-band mode, kept for reference:
            # if not args.band_mode:
            #     # outside band_mode the losses of the 8 side maps are added
            #     loss = wce_dice_huber_loss(outputs[0], labels[0]) * args.fuse_loss_weight
            #
            #     writer.add_scalar('fuse_loss_per_epoch', loss.item() / args.fuse_loss_weight,
            #                       global_step=epoch * train_epoch + batch_index)
            #
            #     for c_index, c in enumerate(outputs[1:]):
            #         one_loss_t = wce_dice_huber_loss(c, labels[c_index + 1])
            #         loss_8t += one_loss_t
            #         writer.add_scalar('%d_map_loss' % (c_index), one_loss_t.item(), global_step=train_epoch)
            #     loss += loss_8t
            #     loss = loss / 20
            loss = wce_dice_huber_loss(outputs[0], labels)
            writer.add_scalar('fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)

            loss.backward()
            optimizer.step()

        # record the various values in their running-average trackers
        # (loss_8t stays zero here, so map8_loss_value records zeros)
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics on the fused prediction
        f1score = my_f1_score(outputs[0], labels)
        precisionscore = my_precision_score(outputs[0], labels)
        accscore = my_acc_score(outputs[0], labels)
        recallscore = my_recall_score(outputs[0], labels)

        writer.add_scalar('f1_score',
                          f1score,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score',
                          precisionscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score',
                          accscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score',
                          recallscore,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        # defensive stop; with enumerate this normally never fires
        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
# Example #10
# 0
def test(model, dataParser, epoch):
    """Run one test epoch, logging loss, metrics and prediction images.

    Fixes over the previous revision: a failing batch now logs the exception
    before being skipped (previously it was silently swallowed), matching the
    sibling training loop, and the forward pass runs under
    ``torch.no_grad()`` since no gradients are needed at test time.

    Args:
        model: network under evaluation; ``model(images)[0]`` is taken as the
            fused prediction.
        dataParser: iterable yielding dicts with keys ``'tamper_image'`` and
            ``'gt_band'``.
        epoch: current epoch index, used for the TensorBoard ``global_step``.

    Returns:
        dict of epoch-average loss, f1, precision, accuracy and recall.
    """
    # iterator over the test data
    test_epoch = len(dataParser)

    # running-average trackers for timing, losses and metrics
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to test mode
    model.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading this batch
        data_time.update(time.time() - end)

        images = input_data['tamper_image']
        labels = input_data['gt_band']

        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        # forward pass + loss; best-effort -- log and skip a failing batch
        try:
            with torch.no_grad():
                outputs = model(images)[0]
            loss = wce_dice_huber_loss(outputs, labels)
        except Exception as e:
            print(e)
            continue
        # NOTE(review): tag says 'val_...' inside the test loop -- looks like
        # a copy-paste leftover; confirm before renaming the metric
        writer.add_scalar('val_fuse_loss_per_epoch',
                          loss.item(),
                          global_step=epoch * test_epoch + batch_index)

        # record the various values in their running-average trackers
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics on the fused prediction
        f1score = my_f1_score(outputs, labels)
        precisionscore = my_precision_score(outputs, labels)
        accscore = my_acc_score(outputs, labels)
        recallscore = my_recall_score(outputs, labels)

        writer.add_scalar('test_f1_score',
                          f1score,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_scalar('test_precision_score',
                          precisionscore,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_scalar('test_acc_score',
                          accscore,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_scalar('test_recall_score',
                          recallscore,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_images('test_image_batch:%d' % (batch_index),
                          outputs,
                          global_step=epoch)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'TEST_Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, test_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'test_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'test_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'test_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value) + \
                   'test_acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'test_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        # defensive stop; with enumerate this normally never fires
        if batch_index >= test_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }