コード例 #1
0
ファイル: two_stage_train.py プロジェクト: MuskAI/Mymodel
def val(model, epoch):
    """Run one validation pass and log loss and metrics to TensorBoard.

    Relies on module-level names defined elsewhere in the project: ``args``,
    ``writer``, ``DataParser``, ``generate_minibatches``, ``Averagvalue``,
    ``save_mid_result``, ``wce_huber_loss``, ``wce_huber_loss_8`` and the
    ``my_*_score`` metric helpers.

    Args:
        model: network whose forward pass returns a sequence of maps (per the
            original comment: the fused map followed by 8 side maps);
            ``outputs[i]`` is compared against ``labels[i]``.
        epoch: current epoch index; used in log messages and to offset the
            TensorBoard global step.

    Returns:
        tuple: ``(average_loss, epoch_loss)`` where ``average_loss`` is the
        batch-size-weighted mean loss and ``epoch_loss`` is the list of
        per-batch loss values.
    """
    torch.cuda.empty_cache()
    # Iterator over the validation data.
    dataParser = DataParser(args.batch_size)
    val_steps = dataParser.val_steps

    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    use_cuda = torch.cuda.is_available()

    # Switch to evaluation mode.
    model.eval()
    end = time.time()
    epoch_loss = []

    # Validation never calls backward(); disabling autograd stops every batch
    # from building (and holding on to) a computation graph.
    with torch.no_grad():
        for batch_index, (images, labels_numpy) in enumerate(
                generate_minibatches(dataParser, False)):
            # Time spent waiting for the data.
            data_time.update(time.time() - end)

            # Convert the numpy batch to tensors, moving them to the GPU if any.
            images = torch.from_numpy(images)
            labels = [torch.from_numpy(item) for item in labels_numpy]
            if use_cuda:
                images = images.cuda()
                labels = [item.cuda() for item in labels]

            # Network outputs: [fused map, 8 side maps].
            outputs = model(images)

            # Save intermediate results for the configured batch indices.
            if batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=False)

            # Loss: fused map weighted by 12 plus each side map, divided by 20.
            loss = wce_huber_loss(outputs[0], labels[0]) * 12
            for c_index, c in enumerate(outputs[1:]):
                loss = loss + wce_huber_loss_8(c, labels[c_index + 1])
            loss = loss / 20

            # Record the loss and timing.
            loss_value = loss.item()
            losses.update(loss_value, images.size(0))
            epoch_loss.append(loss_value)
            batch_time.update(time.time() - end)
            end = time.time()

            # Evaluation metrics on the fused map.
            step = epoch * val_steps + batch_index
            f1score = my_f1_score(outputs[0], labels[0])
            precisionscore = my_precision_score(outputs[0], labels[0])
            accscore = my_acc_score(outputs[0], labels[0])
            writer.add_scalar('val_f1score', f1score, global_step=step)
            writer.add_scalar('val_precisionscore',
                              precisionscore,
                              global_step=step)
            writer.add_scalar('val_acc_score', accscore, global_step=step)

            if batch_index % args.print_freq == 0:
                info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.val_steps) + \
                       'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                       'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                       'f1_score : %.4f ' % f1score + \
                       'precision_score: %.4f ' % precisionscore + \
                       'acc_score %.4f ' % accscore

                print('val: ', info)

            # Use the same step scale as the other validation scalars; the
            # previous ``epoch * (val_steps // 100) + batch_index`` made the
            # points of consecutive epochs overlap.
            writer.add_scalar('val_avg_loss2', losses.val, global_step=step)

            # Stop after exactly ``val_steps`` batches; the previous test
            # ``batch_index > val_steps`` let two extra batches through.
            if batch_index + 1 >= val_steps:
                break

    return losses.avg, epoch_loss
コード例 #2
0
def val(model, dataParser, epoch):
    """Run one validation pass over ``dataParser`` and log loss/metrics.

    Relies on module-level names defined elsewhere in the project: ``args``,
    ``writer``, ``generate_minibatches``, ``Averagvalue``,
    ``save_mid_result``, ``wce_dice_huber_loss`` and the ``my_*_score``
    metric helpers.

    Args:
        model: network whose forward pass returns a sequence of maps;
            ``outputs[0]`` is scored against ``labels[0]`` and, unless
            ``args.band_mode`` is set, ``outputs[i]`` against ``labels[i]``.
        dataParser: data source; ``int(dataParser.val_steps)`` is used as the
            number of validation batches.
        epoch: current epoch index (log messages and TensorBoard step offset).

    Returns:
        dict: ``loss_avg``, ``f1_avg``, ``precision_avg``, ``accuracy_avg``
        and ``recall_avg`` averaged over the processed batches.
    """
    # Number of validation batches; also the per-epoch TensorBoard step stride.
    val_epoch = int(dataParser.val_steps)

    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    use_cuda = torch.cuda.is_available()

    # Switch to evaluation mode.
    model.eval()
    end = time.time()

    # No backward pass happens here, so do not build autograd graphs.
    with torch.no_grad():
        for batch_index, (images, labels_numpy) in enumerate(
                generate_minibatches(dataParser, False)):
            # Time spent waiting for the data.
            data_time.update(time.time() - end)

            # Convert the numpy batch to tensors, moving them to the GPU if any.
            images = torch.from_numpy(images)
            labels = [torch.from_numpy(item) for item in labels_numpy]
            if use_cuda:
                images = images.cuda()
                labels = [item.cuda() for item in labels]

            step = epoch * val_epoch + batch_index

            # Network outputs.
            outputs = model(images)

            # Save intermediate results for the configured batch indices.
            # This is the validation loop, so train_phase=False (it was True).
            if args.save_mid_result and batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=False)

            # ------------------------- Loss -------------------------
            fuse_loss = wce_dice_huber_loss(outputs[0], labels[0])
            writer.add_scalar('val_fuse_loss_per_epoch',
                              fuse_loss.item(),
                              global_step=step)
            if args.band_mode:
                loss = fuse_loss
            else:
                # Outside band mode the 8 side maps contribute to the loss too.
                side_loss = 0
                for c_index, c in enumerate(outputs[1:]):
                    one_loss_t = wce_dice_huber_loss(c, labels[c_index + 1])
                    side_loss = side_loss + one_loss_t
                    # Logged per batch; previously a constant step stacked
                    # every value on the same x position.
                    writer.add_scalar('val_%d_map_loss' % (c_index),
                                      one_loss_t.item(),
                                      global_step=step)
                loss = (fuse_loss * args.fuse_loss_weight + side_loss) / 20

            # Record the loss and timing.
            losses.update(loss.item())
            batch_time.update(time.time() - end)
            end = time.time()

            # Evaluation metrics on the fused map.
            fuse, target = outputs[0], labels[0]
            f1score = my_f1_score(fuse, target)
            precisionscore = my_precision_score(fuse, target)
            accscore = my_acc_score(fuse, target)
            recallscore = my_recall_score(fuse, target)

            writer.add_scalar('val_f1_score', f1score, global_step=step)
            writer.add_scalar('val_precision_score',
                              precisionscore,
                              global_step=step)
            writer.add_scalar('val_acc_score', accscore, global_step=step)
            writer.add_scalar('val_recall_score', recallscore, global_step=step)

            f1_value.update(f1score)
            precision_value.update(precisionscore)
            acc_value.update(accscore)
            recall_value.update(recallscore)

            if batch_index % args.print_freq == 0:
                info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, dataParser.val_steps) + \
                       'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                       'val_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                       'val_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                       'val_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                       'val_acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value) + \
                       'val_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

                print(info)

            # Stop after exactly ``val_epoch`` batches; comparing the raw
            # index with ``>=`` processed one extra batch.
            if batch_index + 1 >= val_epoch:
                break

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
コード例 #3
0
def val(model, dataParser, epoch):
    """Run one validation pass over ``dataParser`` and log loss/metrics.

    Relies on module-level names defined elsewhere in the project: ``args``,
    ``writer``, ``Averagvalue``, ``save_mid_result``,
    ``wce_dice_huber_loss`` and the ``my_*_score`` metric helpers.

    Args:
        model: network whose forward pass returns a sequence; only its first
            element is evaluated.
        dataParser: iterable of batches (``len()`` must be defined); each
            batch is a mapping with ``'tamper_image'`` and ``'gt_band'``
            tensors.
        epoch: current epoch index (log messages and TensorBoard step offset).

    Returns:
        dict: ``loss_avg``, ``f1_avg``, ``precision_avg``, ``accuracy_avg``
        and ``recall_avg`` averaged over all batches.
    """
    # Number of validation batches; also the per-epoch TensorBoard step stride.
    val_epoch = len(dataParser)

    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    use_cuda = torch.cuda.is_available()

    # Switch to evaluation mode.
    model.eval()
    end = time.time()

    # No backward pass happens here, so do not build autograd graphs.
    with torch.no_grad():
        # Iterating the loader yields exactly ``val_epoch`` batches, so no
        # explicit break is needed.
        for batch_index, input_data in enumerate(dataParser):
            # Time spent waiting for the data.
            data_time.update(time.time() - end)

            images = input_data['tamper_image']
            labels = input_data['gt_band']
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            step = epoch * val_epoch + batch_index

            # Only the first network output is evaluated.
            outputs = model(images)[0]

            # Save intermediate results for the configured batch indices.
            # This is the validation loop, so train_phase=False (it was True).
            if args.save_mid_result and batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=False)

            # ------------------------- Loss -------------------------
            loss = wce_dice_huber_loss(outputs, labels)
            writer.add_scalar('val_fuse_loss_per_epoch',
                              loss.item(),
                              global_step=step)

            # Record the loss and timing.
            losses.update(loss.item())
            batch_time.update(time.time() - end)
            end = time.time()

            # Evaluation metrics.
            f1score = my_f1_score(outputs, labels)
            precisionscore = my_precision_score(outputs, labels)
            accscore = my_acc_score(outputs, labels)
            recallscore = my_recall_score(outputs, labels)

            writer.add_scalar('val_f1_score', f1score, global_step=step)
            writer.add_scalar('val_precision_score',
                              precisionscore,
                              global_step=step)
            writer.add_scalar('val_acc_score', accscore, global_step=step)
            writer.add_scalar('val_recall_score', recallscore, global_step=step)

            f1_value.update(f1score)
            precision_value.update(precisionscore)
            acc_value.update(accscore)
            recall_value.update(recallscore)

            if batch_index % args.print_freq == 0:
                info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, val_epoch) + \
                       'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                       'val_Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                       'val_f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                       'val_precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                       'val_acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value) + \
                       'val_recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

                print(info)

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }
コード例 #4
0
def train(model, optimizer, dataParser, epoch):
    """Train ``model`` for one epoch over ``dataParser`` and log loss/metrics.

    Relies on module-level names defined elsewhere in the project: ``args``,
    ``writer``, ``Averagvalue``, ``save_mid_result``,
    ``wce_dice_huber_loss`` and the ``my_*_score`` metric helpers.

    Args:
        model: network whose forward pass returns a sequence; its first
            element is trained against the ``'gt_band'`` labels.
        optimizer: optimizer stepped once per batch.
        dataParser: iterable of batches (``len()`` must be defined); each
            batch is a mapping with ``'tamper_image'`` and ``'gt_band'``
            tensors.
        epoch: current epoch index (log messages and TensorBoard step offset).

    Returns:
        dict: ``loss_avg``, ``f1_avg``, ``precision_avg``, ``accuracy_avg``
        and ``recall_avg`` averaged over all batches.
    """
    # Number of training batches; also the per-epoch TensorBoard step stride.
    train_epoch = len(dataParser)

    batch_time = Averagvalue()
    data_time = Averagvalue()
    losses = Averagvalue()
    f1_value = Averagvalue()
    acc_value = Averagvalue()
    recall_value = Averagvalue()
    precision_value = Averagvalue()
    use_cuda = torch.cuda.is_available()

    # Switch to train mode.
    model.train()
    end = time.time()

    # Iterating the loader yields exactly ``train_epoch`` batches, so no
    # explicit break is needed.
    for batch_index, input_data in enumerate(dataParser):
        # Time spent waiting for the data.
        data_time.update(time.time() - end)

        # Prepare the inputs. Only move to the GPU when one exists; the
        # previous unconditional ``.cuda()`` crashed on CPU-only machines.
        images = input_data['tamper_image']
        labels = input_data['gt_band']
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        step = epoch * train_epoch + batch_index

        with torch.set_grad_enabled(True):
            # Kept from the original; nothing here reads images.grad.
            # NOTE(review): confirm the model needs a grad-requiring input
            # (e.g. reentrant checkpointing) before removing this.
            images.requires_grad = True
            optimizer.zero_grad()
            # Network outputs.
            outputs = model(images)

            # Save intermediate results for the configured batch indices.
            if args.save_mid_result and batch_index in args.mid_result_index:
                save_mid_result(outputs,
                                labels,
                                epoch,
                                batch_index,
                                args.mid_result_root,
                                save_8map=True,
                                train_phase=True)

            # ------------------------- Loss -------------------------
            loss = wce_dice_huber_loss(outputs[0], labels)
            writer.add_scalar('fuse_loss_per_epoch',
                              loss.item(),
                              global_step=step)

            loss.backward()
            optimizer.step()

        # Record the loss and timing.
        losses.update(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        # Evaluation metrics on the first output.
        f1score = my_f1_score(outputs[0], labels)
        precisionscore = my_precision_score(outputs[0], labels)
        accscore = my_acc_score(outputs[0], labels)
        recallscore = my_recall_score(outputs[0], labels)

        writer.add_scalar('f1_score', f1score, global_step=step)
        writer.add_scalar('precision_score', precisionscore, global_step=step)
        writer.add_scalar('acc_score', accscore, global_step=step)
        writer.add_scalar('recall_score', recallscore, global_step=step)

        f1_value.update(f1score)
        precision_value.update(precisionscore)
        acc_value.update(accscore)
        recall_value.update(recallscore)

        if batch_index % args.print_freq == 0:
            info = 'Epoch: [{0}/{1}][{2}/{3}] '.format(epoch, args.maxepoch, batch_index, train_epoch) + \
                   'Time {batch_time.val:.3f} (avg:{batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
                   'Loss {loss.val:f} (avg:{loss.avg:f}) '.format(loss=losses) + \
                   'f1_score {f1.val:f} (avg:{f1.avg:f}) '.format(f1=f1_value) + \
                   'precision_score: {precision.val:f} (avg:{precision.avg:f}) '.format(precision=precision_value) + \
                   'acc_score {acc.val:f} (avg:{acc.avg:f}) '.format(acc=acc_value) + \
                   'recall_score {recall.val:f} (avg:{recall.avg:f})'.format(recall=recall_value)

            print(info)

    return {
        'loss_avg': losses.avg,
        'f1_avg': f1_value.avg,
        'precision_avg': precision_value.avg,
        'accuracy_avg': acc_value.avg,
        'recall_avg': recall_value.avg
    }