def test(model1, model2, dataParser, epoch):
    # iterator over the data
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()

    # switch to eval mode
    model1.eval()
    model2.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        with torch.set_grad_enabled(False):
            images.requires_grad = False
            # network outputs
            one_stage_outputs = model1(images)

            rgb_pred_rgb = torch.cat((one_stage_outputs[0], images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # 建立loss
            z = torch.cat((one_stage_outputs[0], two_stage_outputs[0]), 0)
            writer.add_image('one&two_stage_image_batch:%d' % (batch_index),
                             make_grid(z, nrow=2),
                             global_step=epoch)
def test(model1, dataParser, epoch):
    # iterator over the data
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()

    # switch to eval mode
    model1.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # prepare the input data

        images = input_data['tamper_image'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        # relation_map = input_data['relation_map']
        with torch.set_grad_enabled(False):
            images.requires_grad = False
            # network outputs
            one_stage_outputs = model1(images)
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # 建立loss
            writer.add_images('one&two_stage_image_batch:%d' % (batch_index),
                              one_stage_outputs[0],
                              global_step=epoch)
Example #3
def train(train_loader, model, optimizer, epoch, save_dir):
    batch_time = Averagvalue()
    losses = Averagvalue()
    # switch to train mode
    model.train()
    end = time.time()
    epoch_loss = []
    counter = 0
    #params = list(model.parameters())

    for i, (image, label) in enumerate(train_loader):
        # check whether data is valid
        if not isdir(save_dir):
            os.makedirs(save_dir)
        if image.shape[2] == 100:
            # measure data loading time
            image, label = image.cuda(), label.cuda()
            outputs = model(image)
            loss = torch.zeros(1).cuda()
            for o in outputs:
                loss = loss + cross_entropy_loss(o, label)
            counter += 1
            loss = loss / args.itersize
            loss.backward()
            if counter == args.itersize:
                optimizer.step()
                optimizer.zero_grad()
                counter = 0
            # measure accuracy and record loss
            losses.update(loss.item(), image.size(0))
            epoch_loss.append(loss.item())
            batch_time.update(time.time() - end)
            end = time.time()
            # display and logging
            if i % args.print_freq == 0:
                info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, i, len(train_loader)) + \
                       'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                       'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses)
                print(info)
                outputs.append(label)
                _, _, H, W = outputs[0].shape
                all_results = torch.zeros((len(outputs), 1, H, W))
                for j in range(len(outputs)):
                    all_results[j, 0, :, :] = outputs[j][0, 0, :, :]
                torchvision.utils.save_image(all_results,
                                             join(save_dir, "iter-%d.jpg" % i))
        # save checkpoint
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        },
        filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))

    return losses.avg, epoch_loss
Example #4
def test_luma(model):
    psnr_before = Averagvalue()
    psnr_after = Averagvalue()
    for i, (image, label, qpmap) in enumerate(testloader):
        image, label, qpmap = image.cuda(), label.cuda(), qpmap.cuda()
        outputs = model(torch.cat([image, qpmap], 1))
        psnr_1 = psnr(F.mse_loss(image, label).item())
        psnr_2 = psnr(F.mse_loss(outputs, label).item())

        info = '[{}]'.format(i) + 'PSNR from {:.4f} to {:.4f}'.format(
            psnr_1, psnr_2) + ' Delta:{:.4f}'.format(psnr_2 - psnr_1)
        psnr_before.update(psnr_1)
        psnr_after.update(psnr_2)
    return psnr_after.avg - psnr_before.avg
Example #5
File: test.py  Project: Linkeyboard/NRCNN
def test_chroma(model, testloader):
    psnr_before = Averagvalue()
    psnr_after = Averagvalue()
    for i, (luma, chroma_rec, chroma_pad, chroma_gd) in enumerate(testloader):
        luma, chroma_rec, chroma_pad, chroma_gd = luma.cuda(), chroma_rec.cuda(), chroma_pad.cuda(), chroma_gd.cuda()
        outputs = model(luma, chroma_pad)
        losslist = []
        losslist.append(F.mse_loss(chroma_rec, chroma_gd).item())
        loss = F.mse_loss(outputs, chroma_gd - chroma_rec)
        losslist.append(loss.item())

        info = '[{}]'.format(i) + \
               'PSNR from {:.4f} to {:.4f}'.format(psnr(losslist[0]), psnr(losslist[-1])) + \
               ' Delta:{:.4f}'.format(psnr(losslist[-1]) - psnr(losslist[0]))
        psnr_before.update(psnr(losslist[0]))
        psnr_after.update(psnr(losslist[-1]))

    #print('PSNR from {:.4f} to {:.4f}'.format(psnr_before.avg, psnr_after.avg))
    return psnr_after.avg - psnr_before.avg
Example #6
File: test.py  Project: Linkeyboard/NRCNN
def test_luma(model, testloader):
    psnr_before = Averagvalue()
    psnr_after = Averagvalue()
    for i, (image, label) in enumerate(testloader):
        image, label = image.cuda(), label.cuda()
        outputs = model(image)
        losslist = []
        losslist.append(F.mse_loss(label, image).item())
        loss = F.mse_loss(outputs, label)
        losslist.append(loss.item())

        info = '[{}]'.format(i) + \
               'PSNR from {:.4f} to {:.4f}'.format(psnr(losslist[0]), psnr(losslist[-1])) + \
               ' Delta:{:.4f}'.format(psnr(losslist[-1]) - psnr(losslist[0]))
        psnr_before.update(psnr(losslist[0]))
        psnr_after.update(psnr(losslist[-1]))

    #print('PSNR from {:.4f} to {:.4f}'.format(psnr_before.avg, psnr_after.avg))
    return psnr_after.avg - psnr_before.avg
Example #7
def test_chroma(model):
    psnr_before = Averagvalue()
    psnr_after = Averagvalue()
    for i, (luma, chroma_rec, chroma_en, chroma_gd,
            qpmap) in enumerate(trainloader):
        luma, chroma_rec, chroma_en, chroma_gd, qpmap = luma.cuda(
        ), chroma_rec.cuda(), chroma_en.cuda(), chroma_gd.cuda(), qpmap.cuda()
        outputs = model(torch.cat([chroma_en, qpmap], 1), luma)

        psnr_1 = psnr(F.mse_loss(chroma_rec, chroma_gd).item())
        psnr_2 = psnr(F.mse_loss(outputs, chroma_gd - chroma_rec).item())

        info = '[{}]'.format(i) + 'PSNR from {:.4f} to {:.4f}'.format(
            psnr_1, psnr_2) + ' Delta:{:.4f}'.format(psnr_2 - psnr_1)
        psnr_before.update(psnr_1)
        psnr_after.update(psnr_2)

    return psnr_after.avg - psnr_before.avg
Example #8
def train(trainloader, model, optimizer, epoch, save_dir):
    global_step = epoch * len(trainloader) // args.print_freq
    batch_time = Averagvalue()
    data_time = Averagvalue()
    loss_list = Averagvalue()
    model.train()
    for i, (image, label, qpmap) in enumerate(trainloader):
        image, label, qpmap = image.cuda(), label.cuda(), qpmap.cuda()
        outputs = model(torch.cat([image, qpmap], 1))

        psnr_1 = psnr(F.mse_loss(image, label).item())
        psnr_2 = psnr(F.mse_loss(outputs, label).item())

        loss = F.mse_loss(outputs, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_list.update(loss.item(), image.size(0))

        if i % args.print_freq == args.print_freq - 1:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, i, len(trainloader)) + \
                    'Loss {loss.val:f} (avg:{loss.avg:f})'.format(loss = loss_list) + ' PSNR {:.4f}'.format(psnr_2 - psnr_1)
            print(info)

            writer = SummaryWriter(args.show_path)
            writer.add_scalar('scalar/loss', loss_list.avg, global_step)
            delta_psnr = test_luma(model)
            writer.add_scalar('scalar/psnr', delta_psnr, global_step)
            writer.close()

            global_step += 1
            loss_list.reset()

    if not isdir(save_dir):
        os.makedirs(save_dir)
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        },
        filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))
Example #9
def val(model, dataParser, epoch):
    # iterator over the data
    val_epoch = len(dataParser)

    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to test mode
    model.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)

        images = input_data['tamper_image']
        labels = input_data['gt_band']

        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()

        # adjust the numpy data that was read in

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())
        # network outputs
        outputs = model(images)[0]
        # code that saves intermediate results goes here
        if args.save_mid_result:
            if batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=True)
            else:
                pass
        else:
            pass
        """""" """""" """""" """""" """"""
        "         Loss 函数           "
        """""" """""" """""" """""" """"""
        loss = wce_dice_huber_loss(outputs, labels)
        writer.add_scalar('val_fuse_loss_per_epoch',
                          loss.item(),
                          global_step=epoch * val_epoch + batch_index)
        # record the statistics in their accumulators
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs, labels)
        precisionscore = my_precision_score(outputs, labels)
        accscore = my_acc_score(outputs, labels)
        recallscore = my_recall_score(outputs, labels)

        writer.add_scalar('val_f1_score',
                          f1score,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score',
                          precisionscore,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score',
                          accscore,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score',
                          recallscore,
                          global_step=epoch * val_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, val_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'val_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'val_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'val_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'val_acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'val_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        if batch_index >= val_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
Example #10
def train(model, optimizer, dataParser, epoch):
    # iterator over the data

    train_epoch = len(dataParser)
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to train mode
    model.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        labels = input_data['gt_band'].cuda()
        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer.zero_grad()
            # network outputs
            outputs = model(images)[0]
            # code that saves intermediate results goes here
            # --------------------------------
            #          Loss function
            # --------------------------------

            loss = wce_dice_huber_loss(outputs, labels)
            writer.add_scalar('loss_per_batch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)

            loss.backward()
            optimizer.step()

        # record the statistics in their accumulators
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs, labels)
        precisionscore = my_precision_score(outputs, labels)
        accscore = my_acc_score(outputs, labels)
        recallscore = my_recall_score(outputs, labels)

        writer.add_scalar('f1_score',
                          f1score,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score',
                          precisionscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score',
                          accscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score',
                          recallscore,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
Example #11
def train(train_loader, model, optimizer, epoch, save_dir):

    adversary = PGD(epsilon=args.epsilon,
                    num_steps=args.num_steps,
                    step_size=args.step_size).cuda()

    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()

    # switch to train mode
    model.train()
    end = time.time()
    epoch_loss = []
    counter = 0
    for i, (image, label) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        image, label = image.cuda(), label.cuda()

        outputs_clean = model(image)
        image_adv = adversary(model, image, label)
        outputs = model(image_adv)

        loss = torch.zeros(1).cuda()
        for o in outputs:
            loss = loss + cross_entropy_loss_RCF(o, label)
        counter += 1

        loss = loss / args.itersize
        loss.backward()

        if counter == args.itersize:
            optimizer.step()
            optimizer.zero_grad()
            counter = 0

        # measure accuracy and record loss
        losses.update(loss.item(), image.size(0))
        epoch_loss.append(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # display and logging
        if not isdir(save_dir):
            os.makedirs(save_dir)

        if i % args.print_freq == 0:
            loss_100 = epoch_loss[-100:]
            loss_100 = sum(loss_100) / len(loss_100)
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, i, len(train_loader)) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss avg:{loss.avg:f}, Last 100 losses avg = {loss_100}'.format(loss=losses, loss_100=loss_100)

            print(info)

            # Save output from model
            label_out = label.float()
            save_outputs = [outputs[-1], label_out, outputs_clean[-1]]
            _, _, H, W = save_outputs[0].shape
            all_results = torch.zeros((len(save_outputs), 1, H, W))
            for j in range(len(save_outputs)):
                all_results[j, 0, :, :] = save_outputs[j][0, 0, :, :]
            torchvision.utils.save_image(1 - all_results,
                                         join(save_dir,
                                              "{0}-edges.jpg".format(i)),
                                         nrow=4)

            # Save adversarial image
            torchvision.utils.save_image(
                unnormalize(image_adv),
                join(save_dir, "{0}-adversarial.jpg".format(i)))

            # Save standard image
            torchvision.utils.save_image(
                unnormalize(image), join(save_dir,
                                         "{0}-standard.jpg".format(i)))

            # Save checkpoint
            save_file = os.path.join(TMP_DIR, 'checkpoint.pth')
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename=save_file)

    # save checkpoint
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        },
        filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))

    return losses.avg, epoch_loss
Example #12
File: train.py  Project: WHUHLX/CATS
#!/usr/bin/python
Example #13
def train(train_loader, model, optimizer, epoch, save_dir):
    batch_time = Averagvalue()
    losses = Averagvalue()
    # switch to train mode
    model.train()
    end = time.time()
    epoch_loss = []
    counter = 0
    #params = list(model.parameters())

    for i, (image, label) in enumerate(train_loader):
        # check whether data is valid
        if not isdir(save_dir):
            os.makedirs(save_dir)
        if image.shape[2] == 100:
            # measure data loading time
            image, label = image.cuda(), label.double().cuda()
            outputs = model(image)
            loss = torch.zeros(1, dtype=torch.double).cuda()
            #loss2 = 0
            w_loss = WassersteinLoss.apply

            label_clone = label.clone()
            label = label / torch.sum(label)
            label = label.reshape(1, -1)

            for _, prediction in enumerate(outputs):
                prediction = prediction.squeeze()

                if torch.all(prediction == 0):
                    print('Error! prediction all zeros!')
                prediction = prediction / torch.sum(prediction)
                prediction = prediction.reshape(1, -1)
                loss = loss + w_loss(prediction, label, ground_metric,
                                     args.reg)
                #loss2 += W_loss_lp(sq_prediction2, sq_label, pre_idx, lab_idx)
            counter += 1
            loss = loss / args.itersize
            if loss != 0:
                loss.backward()
            if counter == args.itersize:
                optimizer.step()
                optimizer.zero_grad()
                counter = 0
            # measure accuracy and record loss
            losses.update(loss.item(), image.size(0))
            epoch_loss.append(loss.item())
            batch_time.update(time.time() - end)
            end = time.time()
            # display and logging
            if i % args.print_freq == 0:
                info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, i, len(train_loader)) + \
                       'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                       'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses)
                print(info)
                outputs.append(label_clone)
                _, _, H, W = outputs[0].shape
                all_results = torch.zeros((len(outputs), 1, H, W))
                for j in range(len(outputs)):
                    all_results[j, 0, :, :] = outputs[j][0, 0, :, :]
                torchvision.utils.save_image(all_results,
                                             join(save_dir, "iter-%d.jpg" % i))
        # save checkpoint
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        },
        filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))

    return losses.avg, epoch_loss
Example #14
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch, save_dir):
    # iterator over the data
    train_epoch = int(dataParser.steps_per_epoch)
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    fuse_loss = Averagvalue()
    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()
    map8_loss_value_stage1 = Averagvalue()
    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    map8_loss_value_stage2 = Averagvalue()
    loss_8 = Averagvalue()

    stage1_loss = Averagvalue()
    stage1_pred_loss = Averagvalue()
    stage2_loss = Averagvalue()
    stage2_pred_loss = Averagvalue()
    ###############################

    mid_freq = 100

    # switch to train mode
    model1.train()
    model2.train()
    end = time.time()
    epoch_loss = []

    for batch_index, (images, labels_numpy) in enumerate(
            generate_minibatches(dataParser, True)):
        # time spent loading data
        data_time.update(time.time() - end)

        # convert the incoming numpy arrays to tensors
        labels = []
        if torch.cuda.is_available():
            images = torch.from_numpy(images).cuda()
            for item in labels_numpy:
                labels.append(torch.from_numpy(item).cuda())
        else:
            images = torch.from_numpy(images)
            for item in labels_numpy:
                labels.append(torch.from_numpy(item))

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_stage_2 = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_stage_2 = torch.zeros(1).cuda()
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            one_stage_outputs = model1(images)

            zero = torch.zeros_like(one_stage_outputs[0])
            one = torch.ones_like(one_stage_outputs[0])

            rgb_pred = images * torch.where(one_stage_outputs[0] > 0.1, one,
                                            zero)
            _rgb_pred = rgb_pred.cpu()
            _rgb_pred = _rgb_pred.detach().numpy()
            if abs(batch_index - dataParser.steps_per_epoch) % mid_freq == 0:
                for i in range(args.batch_size):
                    t = _rgb_pred[i, :, :, :]
                    t = t * 255
                    t = np.array(t, dtype='uint8')
                    t = t.transpose(1, 2, 0)
                    t = Image.fromarray(t)
                    t.save(
                        'mid_result/two_stage_input/two_stage_the_midinput_%d_%d.png'
                        % (epoch, batch_index))

            _rgb = images.cpu()
            _rgb = _rgb.detach().numpy()
            if abs(batch_index - dataParser.steps_per_epoch) % mid_freq == 0:
                for i in range(args.batch_size):
                    t = _rgb[i, :, :, :]
                    t = t * 255
                    t = np.array(t, dtype='uint8')
                    t = t.transpose(1, 2, 0)
                    t = Image.fromarray(t)
                    t.save(
                        'mid_result/one_stage_input/two_stage_the_midinput_%d_%d.png'
                        % (epoch, batch_index))
            two_stage_outputs = model2(rgb_pred, one_stage_outputs[9],
                                       one_stage_outputs[10],
                                       one_stage_outputs[11])

            ##########################################
            # deal with the one-stage issue
            # build the loss
            _loss_stage_1 = wce_huber_loss(one_stage_outputs[0], labels[1])

            if False:  # disabled: deep supervision over the 8 side maps
                loss_stage_1 = _loss_stage_1 * 12
                for c_index, c in enumerate(one_stage_outputs[1:9]):
                    one_loss_t = wce_huber_loss(c, labels[c_index + 2])
                    loss_8t += one_loss_t
                    writer.add_scalar('stage1_%d_map_loss' % (c_index),
                                      one_loss_t.item(),
                                      global_step=epoch * train_epoch +
                                      batch_index)

                loss_stage_1 += loss_8t
                loss_8t = torch.zeros(())
                loss_stage_1 = loss_stage_1 / 20
            else:
                loss_stage_1 = _loss_stage_1
            ##############################################
            # deal with two stage issues
            _loss_stage_2 = wce_huber_loss(two_stage_outputs[0],
                                           labels[0]) * 12

            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = wce_huber_loss(c, labels[c_index + 2])
                loss_8t += one_loss_t
                writer.add_scalar('stage2_%d_map_loss' % (c_index),
                                  one_loss_t.item(),
                                  global_step=epoch * train_epoch +
                                  batch_index)

            # fused loss (x12) plus the 8 map losses, normalized
            loss_stage_2 = (_loss_stage_2 + loss_8t) / 20

            #######################################
            # total loss
            writer.add_scalar('stage_one_loss',
                              loss_stage_1.item(),
                              global_step=epoch * train_epoch + batch_index)
            writer.add_scalar('stage_two_pred_loss',
                              _loss_stage_2.item(),
                              global_step=epoch * train_epoch + batch_index)
            writer.add_scalar('stage_two_fuse_loss',
                              loss_stage_2.item(),
                              global_step=epoch * train_epoch + batch_index)
            loss = (loss_stage_1 + loss_stage_2) / 2
            writer.add_scalar('fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)
            ##########################################

            _output = two_stage_outputs[0].cpu()
            _output = _output.detach().numpy()
            if abs(batch_index - dataParser.steps_per_epoch) % mid_freq == 0:
                for i in range(args.batch_size):
                    t = _output[i, :, :]
                    t = np.squeeze(t, 0)
                    t = t * 255
                    t = np.array(t, dtype='uint8')
                    t = Image.fromarray(t)

                    t.save(
                        'mid_result/two_stage_output/two_stage_the_midoutput_%d_%d.png'
                        % (epoch, batch_index))
            _output = one_stage_outputs[0].cpu()
            _output = _output.detach().numpy()
            if abs(batch_index - dataParser.steps_per_epoch) % mid_freq == 0:
                for i in range(args.batch_size):
                    t = _output[i, :, :]
                    t = np.squeeze(t, 0)
                    t = t * 255
                    t = np.array(t, dtype='uint8')
                    t = Image.fromarray(t)

                    t.save(
                        'mid_result/one_stage_output/one_stage_the_midoutput_%d_%d.png'
                        % (epoch, batch_index))
            # code that saves intermediate results goes here
            # if batch_index in args.mid_result_index:
            #     save_mid_result(outputs, labels, epoch, batch_index, args.mid_result_root,save_8map=True,train_phase=True)

            loss.backward()
            optimizer1.step()
            optimizer2.step()

        # measure the accuracy and record loss
        losses.update(loss.item())
        map8_loss_value_stage1.update(loss_8t.item())
        # epoch_loss.append(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()
        stage1_loss.update(loss_stage_1.item())
        stage2_loss.update(loss_stage_2.item())
        stage1_pred_loss.update(_loss_stage_1.item())
        stage2_pred_loss.update(_loss_stage_2.item())
        ##############################################
        # evaluation metrics
        # stage 1
        f1score_stage1 = my_f1_score(one_stage_outputs[0], labels[1])
        precisionscore_stage1 = my_precision_score(one_stage_outputs[0],
                                                   labels[1])
        accscore_stage1 = my_acc_score(one_stage_outputs[0], labels[1])
        recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels[1])
        # stage 2
        f1score_stage2 = my_f1_score(two_stage_outputs[0], labels[0])
        precisionscore_stage2 = my_precision_score(two_stage_outputs[0],
                                                   labels[0])
        accscore_stage2 = my_acc_score(two_stage_outputs[0], labels[0])
        recallscore_stage2 = my_recall_score(two_stage_outputs[0], labels[0])
        #################################################
        writer.add_scalar('f1_score_stage1',
                          f1score_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score_stage1',
                          precisionscore_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score_stage1',
                          accscore_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score_stage1',
                          recallscore_stage1,
                          global_step=epoch * train_epoch + batch_index)

        writer.add_scalar('f1_score_stage2',
                          f1score_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score_stage2',
                          precisionscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score_stage2',
                          accscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score_stage2',
                          recallscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.steps_per_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'two-stage total Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'stage1: f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   'stage1: precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value_stage1) + \
                   'stage1: acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage1) + \
                   'stage1: recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage1) + \
                   'stage2: f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   'stage2: precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value_stage2) + \
                   'stage2: acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage2) + \
                   'stage2: recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        # within each epoch, log the metrics at a fixed frequency to observe oscillation
        # if batch_index % args.per_epoch_freq == 0:
        #     writer.add_scalar('tr_loss_per_epoch', losses.val, global_step=epoch * train_epoch + batch_index)
        #     writer.add_scalar('f1_score_per_epoch', f1score, global_step=epoch * train_epoch + batch_index)
        #     writer.add_scalar('precision_score_per_epoch', precisionscore,
        #                       global_step=epoch * train_epoch + batch_index)
        #     writer.add_scalar('acc_score_per_epoch', accscore, global_step=epoch * train_epoch + batch_index)
        #     writer.add_scalar('recall_score_per_epoch',recallscore,global_step=epoch * train_epoch + batch_index)
        #
        if batch_index >= train_epoch:
            break
    # save the models
    if epoch % 1 == 0:
        save_file = os.path.join(
            args.model_save_dir,
            '1104_stage1_checkpoint_epoch{}.pth'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model1.state_dict(),
                'optimizer': optimizer1.state_dict()
            },
            filename=save_file)
        save_file = os.path.join(
            args.model_save_dir,
            '1104_stage2_checkpoint_epoch{}.pth'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model2.state_dict(),
                'optimizer': optimizer2.state_dict()
            },
            filename=save_file)
    # save_checkpoint({
    #     'epoch': epoch,
    #     'state_dict': model1.state_dict(),
    #     'optimizer': optimizer1.state_dict()
    # }, filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))

    return losses.avg, stage1_loss.avg, stage2_loss.avg, f1_value_stage1.avg, f1_value_stage2.avg
Example #15
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch):
    # iterator over the data

    train_epoch = len(dataParser)
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    map8_loss_value = [
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue()
    ]

    # switch to train mode
    model1.train()
    model2.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        if torch.cuda.is_available():
            loss_8t = torch.zeros(()).cuda()
        else:
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            if images.shape[1] != 3 or images.shape[2] != 320:
                continue
            # network outputs

            try:
                one_stage_outputs = model1(images)
            except Exception as e:
                print(e)
                print(images.shape)
                continue

            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # 建立loss
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_band)
            ##############################################
            # deal with two stage issues
            loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                               labels_dou_edge)

            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
                loss_8t += one_loss_t
                # print(one_loss_t)
                map8_loss_value[c_index].update(one_loss_t.item())

            # print(loss_stage_2)
            # print(map8_loss_value)
            loss = loss_stage_2 + loss_8t * 10
            #######################################
            # total loss
            # print(type(loss_stage_2.item()))
            writer.add_scalars('loss_gather', {
                'stage_one_loss': loss_stage_1.item(),
                'stage_two_fuse_loss': loss_stage_2.item()
            },
                               global_step=epoch * train_epoch + batch_index)
            ##########################################
            loss.backward()

            optimizer1.step()
            optimizer2.step()

        # record the statistics in their accumulators
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        # f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        # precisionscore_stage2 = my_precision_score(two_stage_outputs[0], labels_dou_edge)
        # accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        # recallscore_stage2 = my_recall_score(two_stage_outputs[0], labels_dou_edge)

        f1score_stage2 = 1
        precisionscore_stage2 = 1
        accscore_stage2 = 1
        recallscore_stage2 = 1
        #
        # f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        # precisionscore_stage1 = my_precision_score(one_stage_outputs[0], labels_band)
        # accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        # recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        f1score_stage1 = 1
        precisionscore_stage1 = 1
        accscore_stage1 = 1
        recallscore_stage1 = 1
        writer.add_scalars('f1_score_stage', {
            'stage1': f1score_stage1,
            'stage2': f1score_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        writer.add_scalars('precision_score_stage', {
            'stage1': precisionscore_stage1,
            'stage2': precisionscore_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        writer.add_scalars('acc_score_stage', {
            'stage1': accscore_stage1,
            'stage2': accscore_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        writer.add_scalars('recall_score_stage', {
            'stage1': recallscore_stage1,
            'stage2': recallscore_stage2
        },
                           global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'two-stage total Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'stage1 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1) + \
                   'stage2 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage2) + \
                   'stage1: f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   'stage1: precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage1) + \
                   'stage1: acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage1) + \
                   'stage1: recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage1) + \
                   'stage2: f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   'stage2: precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage2) + \
                   'stage2: acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value_stage2) + \
                   'stage2: recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
Example #16
def train(model1, optimizer1, dataParser, epoch):
    # iterator over the data

    train_epoch = len(dataParser)
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    map8_loss_value = [
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue(),
        Averagvalue()
    ]

    # switch to train mode
    model1.train()
    end = time.time()
    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        if torch.cuda.is_available():
            loss_8t = torch.zeros(()).cuda()
        else:
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            if images.shape[1] != 3 or images.shape[2] != 320:
                continue
            # network outputs

            try:
                one_stage_outputs = model1(images)
            except Exception as e:
                print(e)
                print(images.shape)
                continue
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # 建立loss
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_dou_edge)
            ##############################################
            # deal with two stage issues
            # for c_index, c in enumerate(two_stage_outputs[1:9]):
            #     one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
            #     loss_8t += one_loss_t
            #     # print(one_loss_t)
            #     map8_loss_value[c_index].update(one_loss_t.item())

            # print(loss_stage_2)
            # print(map8_loss_value)
            # loss = (loss_stage_2 * 12 + loss_8t) / 20
            loss = loss_stage_1
            #######################################
            # total loss
            # print(type(loss_stage_2.item()))
            writer.add_scalars('loss_gather', {
                'stage_one_loss': loss_stage_1.item(),
            },
                               global_step=epoch * train_epoch + batch_index)
            ##########################################
            loss.backward()

            optimizer1.step()

        # record the statistics in their accumulators
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'stage1 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1)
            print(info)

        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        # 'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
Example #17
def train(model, optimizer, dataParser, epoch):
    # iterator over the data

    train_epoch = len(dataParser)
    # bookkeeping accumulators
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to train mode
    model.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        labels = input_data['gt_band'].cuda()

        # adjust the numpy data that was read in
        # labels = []
        # if torch.cuda.is_available():
        #     images = torch.from_numpy(images).cuda()
        #     for item in labels_numpy:
        #         labels.append(torch.from_numpy(item).cuda())
        # else:
        #     images = torch.from_numpy(images)
        #     for item in labels_numpy:
        #         labels.append(torch.from_numpy(item))

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer.zero_grad()
            # network outputs
            outputs = model(images)
            # code that saves intermediate results goes here
            if args.save_mid_result:
                if batch_index in args.mid_result_index:
                    save_mid_result(outputs,
                                    labels,
                                    epoch,
                                    batch_index,
                                    args.mid_result_root,
                                    save_8map=True,
                                    train_phase=True)
                else:
                    pass
            else:
                pass
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""

            # if not args.band_mode:
            #     # without band_mode, the losses of the 8 maps must also be computed
            #     loss = wce_dice_huber_loss(outputs[0], labels[0]) * args.fuse_loss_weight
            #
            #     writer.add_scalar('fuse_loss_per_epoch', loss.item() / args.fuse_loss_weight,
            #                       global_step=epoch * train_epoch + batch_index)
            #
            #     for c_index, c in enumerate(outputs[1:]):
            #         one_loss_t = wce_dice_huber_loss(c, labels[c_index + 1])
            #         loss_8t += one_loss_t
            #         writer.add_scalar('%d_map_loss' % (c_index), one_loss_t.item(), global_step=train_epoch)
            #     loss += loss_8t
            #     loss = loss / 20
            loss = wce_dice_huber_loss(outputs[0], labels)
            writer.add_scalar('fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)

            loss.backward()
            optimizer.step()

        # record the statistics in their accumulators
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs[0], labels)
        precisionscore = my_precision_score(outputs[0], labels)
        accscore = my_acc_score(outputs[0], labels)
        recallscore = my_recall_score(outputs[0], labels)

        writer.add_scalar('f1_score',
                          f1score,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score',
                          precisionscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score',
                          accscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score',
                          recallscore,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
Example #18
def train(model, optimizer, epoch, save_dir):
    dataParser = DataParser(args.batch_size)
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    # switch to train mode
    model.train()
    end = time.time()
    epoch_loss = []
    counter = 0


    for batch_index, (images, labels_numpy) in enumerate(generate_minibatches(dataParser, True)):

        # measure data loading time
        data_time.update(time.time() - end)

        labels = []
        if torch.cuda.is_available():
            images = torch.from_numpy(images).cuda()
            for item in labels_numpy:
                labels.append(torch.from_numpy(item).cuda())
        else:
            images = torch.from_numpy(images)
            for item in labels_numpy:
                labels.append(torch.from_numpy(item))

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
        else:
            loss = torch.zeros(1)

        optimizer.zero_grad()
        outputs = model(images)
        # supervision with four GT maps

        for o in outputs[9:]:  # o2 o3 o4
            t_loss = cross_entropy_loss(o, labels[-1])
            loss = loss + t_loss
        counter += 1

        for c_index, c in enumerate(outputs[:8]):
            loss = loss + cross_entropy_loss(c, labels[c_index])
        loss = loss / 11
        loss.backward()
        acc_score = my_accuracy_score(outputs[9].cpu().detach().numpy(),
                                      labels[-1].cpu().detach().numpy())
        print('the acc is:', acc_score)

        # 下面应该是用来解决batch size 过下的问题
        # if counter == args.itersize:
        #     optimizer.step()
        #     optimizer.zero_grad()
        #     counter = 0

        optimizer.step()
        optimizer.zero_grad()

        # measure the accuracy and record loss
        losses.update(loss.item(), images.size(0))
        epoch_loss.append(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # display and logging
        if not isdir(save_dir):
            os.makedirs(save_dir)
        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.steps_per_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(
                       loss=losses)
            print(info)

        # torch.save(model, join(save_dir, "checkpoint.pth"))
    # save the parameters once per epoch
    save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
                    filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))


    return losses.avg, epoch_loss
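
The commented-out `itersize` block in the function above points at the usual gradient-accumulation trick for small batch sizes. Here is a self-contained sketch of that pattern; `criterion`, `loader`, and `itersize` are placeholder names, not the project's actual code.

# Gradient accumulation sketch (an assumption based on the comment above):
# average the loss over `itersize` micro-batches and step the optimizer
# only after all of their gradients have been accumulated.
def train_with_accumulation(model, loader, criterion, optimizer, itersize):
    model.train()
    optimizer.zero_grad()
    counter = 0
    for images, labels in loader:
        loss = criterion(model(images), labels) / itersize
        loss.backward()           # gradients add up across micro-batches
        counter += 1
        if counter == itersize:
            optimizer.step()      # one update per `itersize` micro-batches
            optimizer.zero_grad()
            counter = 0
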
Example #19
def val(model1, model2, dataParser, epoch):
    # iterator over the validation data

    val_epoch = len(dataParser)
    # meters for bookkeeping
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    # one meter per relation-map loss (only the first 8 are updated below)
    map8_loss_value = [Averagvalue() for _ in range(10)]
    # switch to eval mode
    model1.eval()
    model2.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        if torch.cuda.is_available():
            loss_8t = torch.zeros(()).cuda()
        else:
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(False):
            images.requires_grad = False
            # network outputs
            one_stage_outputs = model1(images)

            rgb_pred = images * one_stage_outputs[0]
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[1],
                                       one_stage_outputs[2],
                                       one_stage_outputs[3])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # 建立loss
            loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                               labels_band)
            ##############################################
            # deal with the two-stage losses
            _loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                                labels_dou_edge) * 12

            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = map8_loss_ce(c, relation_map[c_index].cuda())
                loss_8t += one_loss_t
                map8_loss_value[c_index].update(one_loss_t.item())

            _loss_stage_2 += loss_8t
            loss_stage_2 = _loss_stage_2 / 20
            loss = (loss_stage_1 + loss_stage_2) / 2

            #######################################
            # total loss
            writer.add_scalar('val_stage_one_loss',
                              loss_stage_1.item(),
                              global_step=epoch * val_epoch + batch_index)
            writer.add_scalar('val_stage_two_fuse_loss',
                              loss_stage_2.item(),
                              global_step=epoch * val_epoch + batch_index)
            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * val_epoch + batch_index)
            ##########################################

        # record the values in their meters
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())

        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        precisionscore_stage2 = my_precision_score(two_stage_outputs[0],
                                                   labels_dou_edge)
        accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        recallscore_stage2 = my_recall_score(two_stage_outputs[0],
                                             labels_dou_edge)

        f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        precisionscore_stage1 = my_precision_score(one_stage_outputs[0],
                                                   labels_band)
        accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        writer.add_scalar('val_f1_score_stage1',
                          f1score_stage1,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score_stage1',
                          precisionscore_stage1,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score_stage1',
                          accscore_stage1,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score_stage1',
                          recallscore_stage1,
                          global_step=epoch * val_epoch + batch_index)

        writer.add_scalar('val_f1_score_stage2',
                          f1score_stage2,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_precision_score_stage2',
                          precisionscore_stage2,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_acc_score_stage2',
                          accscore_stage2,
                          global_step=epoch * val_epoch + batch_index)
        writer.add_scalar('val_recall_score_stage2',
                          recallscore_stage2,
                          global_step=epoch * val_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, val_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'two-stage total Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'stage-1 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1) + \
                   'stage-2 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage2) + \
                   'stage-1 f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   'stage-1 precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage1) + \
                   'stage-1 acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value_stage1) + \
                   'stage-1 recall_score {recall.val:f} (avg:{recall.avg:f}) '.format(recall=recall_value_stage1) + \
                   'stage-2 f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   'stage-2 precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value_stage2) + \
                   'stage-2 acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value_stage2) + \
                   'stage-2 recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        if batch_index >= val_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg,
        'map8_loss': [map8_loss.avg for map8_loss in map8_loss_value],
    }
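
`wce_dice_huber_loss` is used throughout these snippets but never defined in them. Going only by its name, it plausibly combines a weighted binary cross-entropy, a soft dice loss, and a huber (smooth L1) term over sigmoid probability maps. A hypothetical reconstruction under that assumption:

import torch
import torch.nn.functional as F

def wce_dice_huber_loss_sketch(pred, target, eps=1e-6):
    # Hypothetical reconstruction based on the function's name alone; the
    # project's actual weighting and combination may differ. `pred` and
    # `target` are assumed to be probability maps in [0, 1].
    pos = target.sum()
    neg = target.numel() - pos
    pos_weight = (neg / (pos + eps)).clamp(max=100.0)
    weight = target * pos_weight + (1.0 - target)   # up-weight rare positives
    wce = F.binary_cross_entropy(pred, target, weight=weight)
    inter = (pred * target).sum()
    dice = 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)
    huber = F.smooth_l1_loss(pred, target)
    return wce + dice + huber
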
Example #20
def val(model, dataParser, epoch):
    # iterator over the validation data
    train_epoch = int(dataParser.val_steps)  # despite the name, this counts val steps
    # meters for bookkeeping
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to test mode
    model.eval()
    end = time.time()

    for batch_index, (images, labels_numpy) in enumerate(
            generate_minibatches(dataParser, False)):
        # time spent loading data
        data_time.update(time.time() - end)

        # convert the numpy batches to tensors
        labels = []
        if torch.cuda.is_available():
            images = torch.from_numpy(images).cuda()
            for item in labels_numpy:
                labels.append(torch.from_numpy(item).cuda())
        else:
            images = torch.from_numpy(images)
            for item in labels_numpy:
                labels.append(torch.from_numpy(item))

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        # network outputs
        outputs = model(images)
        # save intermediate results if requested
        if args.save_mid_result and batch_index in args.mid_result_index:
            save_mid_result(outputs,
                            labels,
                            epoch,
                            batch_index,
                            args.mid_result_root,
                            save_8map=True,
                            train_phase=True)
        """""" """""" """""" """""" """"""
        "         Loss 函数           "
        """""" """""" """""" """""" """"""

        if not args.band_mode:
            # when band_mode is off, also compute the loss of the 8 side maps
            loss = wce_dice_huber_loss(outputs[0],
                                       labels[0]) * args.fuse_loss_weight

            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item() / args.fuse_loss_weight,
                              global_step=epoch * train_epoch + batch_index)

            for c_index, c in enumerate(outputs[1:]):
                one_loss_t = wce_dice_huber_loss(c, labels[c_index + 1])
                loss_8t += one_loss_t
                writer.add_scalar('val_%d_map_loss' % (c_index),
                                  one_loss_t.item(),
                                  global_step=epoch * train_epoch + batch_index)
            loss += loss_8t
            loss = loss / 20  # normalize; assumes fuse_loss_weight == 12, so 12 + 8 = 20
        else:
            loss = wce_dice_huber_loss(outputs[0], labels[0])
            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)

        # record the values in their meters
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs[0], labels[0])
        precisionscore = my_precision_score(outputs[0], labels[0])
        accscore = my_acc_score(outputs[0], labels[0])
        recallscore = my_recall_score(outputs[0], labels[0])

        writer.add_scalar('val_f1_score',
                          f1score,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('val_precision_score',
                          precisionscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('val_acc_score',
                          accscore,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('val_recall_score',
                          recallscore,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.val_steps) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'val_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'val_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'val_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'val_acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'val_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
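
The `my_f1_score` / `my_precision_score` / `my_acc_score` / `my_recall_score` helpers are likewise assumed rather than shown. A plausible sketch: binarize the predicted map at 0.5 and score it pixelwise with sklearn; the project's actual thresholds and implementations may differ.

import numpy as np
from sklearn import metrics

def my_f1_score_sketch(pred, gt, threshold=0.5):
    # Hypothetical helper: flatten both maps, binarize at `threshold`,
    # and compute the pixelwise F1. The precision/recall/accuracy
    # variants would swap in the matching sklearn call.
    p = (pred.detach().cpu().numpy().ravel() >= threshold).astype(np.uint8)
    g = (gt.detach().cpu().numpy().ravel() >= threshold).astype(np.uint8)
    return metrics.f1_score(g, p, zero_division=0)
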
Example #21
def val(model, epoch):
    torch.cuda.empty_cache()
    # iterator over the data
    dataParser = DataParser(args.batch_size)
    # meters for bookkeeping
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()

    # switch to eval mode
    model.eval()
    end = time.time()
    epoch_loss = []
    counter = 0

    for batch_index, (images, labels_numpy) in enumerate(
            generate_minibatches(dataParser, False)):

        # time spent loading data
        data_time.update(time.time() - end)

        # convert the numpy batches to tensors
        labels = []
        if torch.cuda.is_available():
            images = torch.from_numpy(images).cuda()
            for item in labels_numpy:
                labels.append(torch.from_numpy(item).cuda())
        else:
            images = torch.from_numpy(images)
            for item in labels_numpy:
                labels.append(torch.from_numpy(item))

        # outputs: [fused map, 8 side maps]
        outputs = model(images)

        # save intermediate results here
        if batch_index in args.mid_result_index:
            save_mid_result(outputs,
                            labels,
                            epoch,
                            batch_index,
                            args.mid_result_root,
                            save_8map=True,
                            train_phase=False)

        # build the loss
        loss = wce_huber_loss(outputs[0], labels[0]) * 12
        for c_index, c in enumerate(outputs[1:]):
            loss = loss + wce_huber_loss_8(c, labels[c_index + 1])
        loss = loss / 20  # normalize: fuse weight (12) + 8 side losses

        # measure the accuracy and record loss
        losses.update(loss.item(), images.size(0))
        epoch_loss.append(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs[0], labels[0])
        precisionscore = my_precision_score(outputs[0], labels[0])
        accscore = my_acc_score(outputs[0], labels[0])
        writer.add_scalar('val_f1score',
                          f1score,
                          global_step=epoch * dataParser.val_steps +
                          batch_index)
        writer.add_scalar('val_precisionscore',
                          precisionscore,
                          global_step=epoch * dataParser.val_steps +
                          batch_index)
        writer.add_scalar('val_acc_score',
                          accscore,
                          global_step=epoch * dataParser.val_steps +
                          batch_index)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.val_steps) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'f1_score : %.4f ' % f1score + \
                   'precision_score: %.4f ' % precisionscore + \
                   'acc_score %.4f ' % accscore

            print('val: ', info)
        writer.add_scalar('val_avg_loss2',
                          losses.val,
                          global_step=epoch * (dataParser.val_steps // 100) +
                          batch_index)
        if batch_index > dataParser.val_steps:
            break

    return losses.avg, epoch_loss
def train(model1, model2, optimizer1, optimizer2, dataParser, epoch):
    # iterator over the training data

    train_epoch = len(dataParser)
    # meters for bookkeeping
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    loss_stage1 = Averagvalue()
    loss_stage2 = Averagvalue()

    f1_value_stage1 = Averagvalue()
    acc_value_stage1 = Averagvalue()
    recall_value_stage1 = Averagvalue()
    precision_value_stage1 = Averagvalue()

    f1_value_stage2 = Averagvalue()
    acc_value_stage2 = Averagvalue()
    recall_value_stage2 = Averagvalue()
    precision_value_stage2 = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to train mode
    model1.train()
    model2.train()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)
        # check_4dim_img_pair(input_data['tamper_image'],input_data['gt_band'])
        # prepare the input data
        images = input_data['tamper_image'].cuda()
        labels_band = input_data['gt_band'].cuda()
        labels_dou_edge = input_data['gt_dou_edge'].cuda()
        relation_map = input_data['relation_map']

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        with torch.set_grad_enabled(True):
            images.requires_grad = True
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            if images.shape[1] != 3 or images.shape[2] != 320:
                continue  # skip malformed batches (3-channel, 320-pixel inputs expected)
            # 网络输出
            one_stage_outputs = model1(images)

            zero = torch.zeros_like(one_stage_outputs[0])
            one = torch.ones_like(one_stage_outputs[0])

            # binarize the stage-1 map at 0.1 and use it as a hard mask on the
            # RGB input (an isolated sketch of this step follows the function)
            rgb_pred = images * torch.where(one_stage_outputs[0] > 0.1, one,
                                            zero)
            rgb_pred_rgb = torch.cat((rgb_pred, images), 1)
            two_stage_outputs = model2(rgb_pred_rgb, one_stage_outputs[9],
                                       one_stage_outputs[10],
                                       one_stage_outputs[11])
            """""" """""" """""" """""" """"""
            "         Loss 函数           "
            """""" """""" """""" """""" """"""
            ##########################################
            # deal with one stage issue
            # 建立loss
            _loss_stage_1 = wce_dice_huber_loss(one_stage_outputs[0],
                                                labels_band)
            loss_stage_1 = _loss_stage_1
            ##############################################
            # deal with the two-stage losses
            _loss_stage_2 = wce_dice_huber_loss(two_stage_outputs[0],
                                                labels_dou_edge) * 12

            for c_index, c in enumerate(two_stage_outputs[1:9]):
                one_loss_t = cross_entropy_loss(c,
                                                relation_map[c_index].cuda())
                loss_8t += one_loss_t
                writer.add_scalar('stage2_%d_map_loss' % (c_index),
                                  one_loss_t.item(),
                                  global_step=epoch * train_epoch +
                                  batch_index)

            _loss_stage_2 += loss_8t
            loss_stage_2 = _loss_stage_2 / 20
            loss = (loss_stage_1 + loss_stage_2) / 2

            #######################################
            # total loss
            writer.add_scalar('stage_one_loss',
                              loss_stage_1.item(),
                              global_step=epoch * train_epoch + batch_index)
            writer.add_scalar('stage_two_pred_loss',
                              _loss_stage_2.item(),
                              global_step=epoch * train_epoch + batch_index)
            writer.add_scalar('stage_two_fuse_loss',
                              loss_stage_2.item(),
                              global_step=epoch * train_epoch + batch_index)

            writer.add_scalar('fuse_loss_per_epoch',
                              loss.item(),
                              global_step=epoch * train_epoch + batch_index)
            ##########################################

            loss.backward()
            optimizer1.step()
            optimizer2.step()

        # record the values in their meters
        losses.update(loss.item())
        loss_stage1.update(loss_stage_1.item())
        loss_stage2.update(loss_stage_2.item())

        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score_stage2 = my_f1_score(two_stage_outputs[0], labels_dou_edge)
        precisionscore_stage2 = my_precision_score(two_stage_outputs[0],
                                                   labels_dou_edge)
        accscore_stage2 = my_acc_score(two_stage_outputs[0], labels_dou_edge)
        recallscore_stage2 = my_recall_score(two_stage_outputs[0],
                                             labels_dou_edge)

        f1score_stage1 = my_f1_score(one_stage_outputs[0], labels_band)
        precisionscore_stage1 = my_precision_score(one_stage_outputs[0],
                                                   labels_band)
        accscore_stage1 = my_acc_score(one_stage_outputs[0], labels_band)
        recallscore_stage1 = my_recall_score(one_stage_outputs[0], labels_band)

        writer.add_scalar('f1_score_stage1',
                          f1score_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score_stage1',
                          precisionscore_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score_stage1',
                          accscore_stage1,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score_stage1',
                          recallscore_stage1,
                          global_step=epoch * train_epoch + batch_index)

        writer.add_scalar('f1_score_stage2',
                          f1score_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('precision_score_stage2',
                          precisionscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('acc_score_stage2',
                          accscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        writer.add_scalar('recall_score_stage2',
                          recallscore_stage2,
                          global_step=epoch * train_epoch + batch_index)
        ################################

        f1_value_stage1.update(f1score_stage1)
        precision_value_stage1.update(precisionscore_stage1)
        acc_value_stage1.update(accscore_stage1)
        recall_value_stage1.update(recallscore_stage1)

        f1_value_stage2.update(f1score_stage2)
        precision_value_stage2.update(precisionscore_stage2)
        acc_value_stage2.update(accscore_stage2)
        recall_value_stage2.update(recallscore_stage2)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'two-stage total Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'stage-1 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage1) + \
                   'stage-2 Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=loss_stage2) + \
                   'stage-1 f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage1) + \
                   'stage-1 precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value_stage1) + \
                   'stage-1 acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value_stage1) + \
                   'stage-1 recall_score {recall.val:f} (avg:{recall.avg:f}) '.format(recall=recall_value_stage1) + \
                   'stage-2 f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value_stage2) + \
                   'stage-2 precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value_stage2) + \
                   'stage-2 acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value_stage2) + \
                   'stage-2 recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value_stage2)

            print(info)

        if batch_index >= train_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg_stage1': f1_value_stage1.avg,
        'precision_avg_stage1': precision_value_stage1.avg,
        'accuracy_avg_stage1': acc_value_stage1.avg,
        'recall_avg_stage1': recall_value_stage1.avg,
        'f1_avg_stage2': f1_value_stage2.avg,
        'precision_avg_stage2': precision_value_stage2.avg,
        'accuracy_avg_stage2': acc_value_stage2.avg,
        'recall_avg_stage2': recall_value_stage2.avg
    }
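
The coupling between the two stages in the loop above — binarize the stage-1 prediction at 0.1, mask the RGB input with it, then stack the masked image with the original as a 6-channel stage-2 input — is easy to miss inside the loop. An isolated sketch of just that step:

import torch

def mask_rgb_with_stage1(images, stage1_pred, threshold=0.1):
    # images: (N, 3, H, W); stage1_pred: (N, 1, H, W) probability map.
    # (stage1_pred > threshold).float() is equivalent to the
    # torch.where(pred > t, ones, zeros) construction used above.
    mask = (stage1_pred > threshold).float()
    rgb_pred = images * mask                      # broadcast over channels
    return torch.cat((rgb_pred, images), dim=1)   # (N, 6, H, W) stage-2 input
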
Example #23
def train(train_loader, model, optimizer, epoch, save_dir):
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    # switch to train mode
    model.train()
    end = time.time()
    epoch_loss = []
    counter = 0
    for i, (image, label) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        image, label = image.cuda(), label.cuda()
        outputs = model(image)
        loss = torch.zeros(1).cuda()
        for o in outputs:
            loss = loss + cross_entropy_loss_RCF(o, label)
        counter += 1
        loss = loss / args.itersize
        loss.backward()
        if counter == args.itersize:
            optimizer.step()
            optimizer.zero_grad()
            counter = 0
        # measure accuracy and record loss
        losses.update(loss.item(), image.size(0))
        epoch_loss.append(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()
        # display and logging
        if not isdir(save_dir):
            os.makedirs(save_dir)
        if i % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, i, len(train_loader)) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(
                       loss=losses)

            print(info)
            label_out = torch.eq(label, 1).float()
            outputs.append(label_out)
            _, _, H, W = outputs[0].shape
            all_results = torch.zeros((len(outputs), 1, H, W))
            for j in range(len(outputs)):
                all_results[j, 0, :, :] = outputs[j][0, 0, :, :]
            torchvision.utils.save_image(1 - all_results,
                                         join(save_dir, "iter-%d.jpg" % i))
        # save checkpoint
    save_checkpoint(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        },
        filename=join(save_dir, "epoch-%d-checkpoint.pth" % epoch))

    # vis = visdom.Visdom()
    # loss_window = vis.line(
    #     Y=torch.zeros((1)).cuda(),
    #     X=torch.zeros((1)).cuda(),
    #     opts=dict(xlabel='epoch', ylabel='Loss', title='training loss', legend=['Loss']))
    # vis.line(X=torch.ones((1, 1)).cuda() * epoch, Y=torch.Tensor([epoch_loss]).unsqueeze(0).cuda(),
    #          win=loss_window, update='append')

    return losses.avg, epoch_loss
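
`save_checkpoint` is called with a state dict and a filename but never defined in these snippets; given that usage, it is presumably a thin wrapper around `torch.save`. A minimal sketch under that assumption:

import torch

def save_checkpoint(state, filename='checkpoint.pth'):
    # Assumed implementation: serialize the {'epoch', 'state_dict',
    # 'optimizer'} dict exactly as passed in.
    torch.save(state, filename)
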
Example #24
def test(model, dataParser, epoch):
    # iterator over the test data
    test_epoch = len(dataParser)

    # meters for bookkeeping
    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    map8_loss_value = Averagvalue()

    # switch to test mode
    model.eval()
    end = time.time()

    for batch_index, input_data in enumerate(dataParser):
        # time spent loading data
        data_time.update(time.time() - end)

        images = input_data['tamper_image']
        labels = input_data['gt_band']

        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()

        if torch.cuda.is_available():
            loss = torch.zeros(1).cuda()
            loss_8t = torch.zeros(()).cuda()
        else:
            loss = torch.zeros(1)
            loss_8t = torch.zeros(())

        # network outputs
        try:
            outputs = model(images)[0]
            ##########################################
            #            Loss function               #
            ##########################################

            loss = wce_dice_huber_loss(outputs, labels)
        except Exception:
            # skip batches the model cannot handle (e.g. unexpected shapes)
            continue
        writer.add_scalar('val_fuse_loss_per_epoch',
                          loss.item(),
                          global_step=epoch * test_epoch + batch_index)

        # record the values in their meters
        losses.update(loss.item())
        map8_loss_value.update(loss_8t.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # evaluation metrics
        f1score = my_f1_score(outputs, labels)
        precisionscore = my_precision_score(outputs, labels)
        accscore = my_acc_score(outputs, labels)
        recallscore = my_recall_score(outputs, labels)

        writer.add_scalar('test_f1_score',
                          f1score,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_scalar('test_precision_score',
                          precisionscore,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_scalar('test_acc_score',
                          accscore,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_scalar('test_recall_score',
                          recallscore,
                          global_step=epoch * test_epoch + batch_index)
        writer.add_images('test_image_batch:%d' % (batch_index),
                          outputs,
                          global_step=epoch)
        ################################

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'TEST_Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, test_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'test_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'test_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'test_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(
                       precision=precision_value) + \
                   'test_acc_score {acc.val:f} (avg:{acc.avg:f})'.format(acc=acc_value) + \
                   'test_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

        if batch_index >= test_epoch:
            break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
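
None of the examples show how these loops are driven. Below is a hypothetical entry point, under the assumption that `writer` is a module-level TensorBoard `SummaryWriter` and that the models, optimizers, and data parsers are constructed elsewhere; every name here is a placeholder, not the project's actual code.

import torch
from torch.utils.tensorboard import SummaryWriter

# The loops above log through a module-level `writer`; assumed setup:
writer = SummaryWriter(log_dir='./logs')

def main(model1, model2, dataParser_train, dataParser_val, args):
    # Placeholder glue code showing one plausible wiring of the pieces.
    optimizer1 = torch.optim.Adam(model1.parameters(), lr=args.lr)
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=args.lr)
    for epoch in range(args.maxepoch):
        train(model1, model2, optimizer1, optimizer2, dataParser_train, epoch)
        stats = val(model1, model2, dataParser_val, epoch)
        print('epoch %d val loss: %f' % (epoch, stats['loss_avg']))
    writer.close()
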