Esempio n. 1
0
def main():
    """Train the two-stage tamper-detection pipeline (Net1 -> Net2).

    Builds stage-2 ``TamperDataset`` train/val/test splits and loaders,
    creates ``Net1``/``Net2`` with one Adam optimizer and one StepLR scheduler
    each, optionally restores weights from ``args.resume[0]`` (stage 1) and
    ``args.resume[1]`` (stage 2), then for every epoch in
    ``range(args.start_epoch, args.maxepoch)``:

    * runs ``train``, ``val`` and ``test``;
    * logs learning rates and train/val metrics to ``writer``;
    * saves one checkpoint per stage into ``args.model_save_dir``.

    Depends on module-level names defined elsewhere in this file: ``args``,
    ``writer``, ``train``, ``val``, ``test``, ``Net1``, ``Net2``,
    ``TamperDataset``, ``weights_init`` and ``output_name_file_name``.
    """
    args.cuda = True
    # 1. Choose the data sources used for training / validation.
    using_data = {
        'my_sp': True,
        'my_cm': True,
        'template_casia_casia': True,
        'template_coco_casia': False,
        'cod10k': True,
        'casia': False,
        'copy_move': False,
        'texture_sp': True,
        'texture_cm': True,
        'columb': False,
        'negative': True,
        'negative_casia': False,
    }
    using_data_test = {
        'my_sp': False,
        'my_cm': False,
        'template_casia_casia': False,
        'template_coco_casia': False,
        'cod10k': False,
        'casia': False,
        'coverage': True,
        'columb': False,
        'negative': False,
        'negative_casia': False,
    }
    # 2. Build the three dataset splits.
    trainData = TamperDataset(stage_type='stage2',
                              using_data=using_data,
                              train_val_test_mode='train')
    valData = TamperDataset(stage_type='stage2',
                            using_data=using_data,
                            train_val_test_mode='val')
    testData = TamperDataset(stage_type='stage2',
                             using_data=using_data_test,
                             train_val_test_mode='test')

    # 3. Data loaders.
    trainDataLoader = torch.utils.data.DataLoader(trainData,
                                                  batch_size=args.batch_size,
                                                  num_workers=4,
                                                  shuffle=True,
                                                  pin_memory=False)
    valDataLoader = torch.utils.data.DataLoader(valData,
                                                batch_size=args.batch_size,
                                                num_workers=4,
                                                shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testData,
                                                 batch_size=1,
                                                 num_workers=0)

    # Models.
    model1 = Net1(bilinear=True)
    model2 = Net2(bilinear=True)
    if torch.cuda.is_available():
        model1.cuda()
        model2.cuda()
    else:
        model1.cpu()
        model2.cpu()

    # Apply the project's custom weight initialisation (``weights_init``)
    # instead of relying on PyTorch's default per-layer initialisation.
    model1.apply(weights_init)
    model2.apply(weights_init)

    # Optimizers: stage 1 is trained with a much smaller learning rate.
    optimizer1 = optim.Adam(model1.parameters(),
                            lr=1e-6,
                            betas=(0.9, 0.999),
                            eps=1e-8)
    optimizer2 = optim.Adam(model2.parameters(),
                            lr=1e-3,
                            betas=(0.9, 0.999),
                            eps=1e-8)

    # Optionally resume; each stage's checkpoint is restored independently.
    has_ckpt1 = isfile(args.resume[0])
    has_ckpt2 = isfile(args.resume[1])
    if has_ckpt1 or has_ckpt2:
        print("=> loading checkpoint '{}'".format(args.resume))
        if has_ckpt1:
            checkpoint1 = torch.load(args.resume[0])
            model1.load_state_dict(checkpoint1['state_dict'])
        if has_ckpt2:
            checkpoint2 = torch.load(args.resume[1])
            model2.load_state_dict(checkpoint2['state_dict'])
            if not has_ckpt1:
                # NOTE(review): optimizer-2 state is restored only when stage 2
                # is resumed on its own (pre-existing behaviour) -- confirm intent.
                optimizer2.load_state_dict(checkpoint2['optimizer'])
        print("=> loaded checkpoint '{}'".format(args.resume))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    # Learning-rate schedules.
    scheduler1 = lr_scheduler.StepLR(optimizer1,
                                     step_size=args.stepsize,
                                     gamma=args.gamma)
    scheduler2 = lr_scheduler.StepLR(optimizer2,
                                     step_size=args.stepsize,
                                     gamma=args.gamma)

    # (tag suffix, metrics key) pairs plotted as train-vs-val curves.
    curve_keys = (('loss', 'loss_avg'),
                  ('f1', 'f1_avg_stage2'),
                  ('precision', 'precision_avg_stage2'),
                  ('acc', 'accuracy_avg_stage2'),
                  ('recall', 'recall_avg_stage2'))

    for epoch in range(args.start_epoch, args.maxepoch):
        train_avg = train(model1=model1,
                          model2=model2,
                          optimizer1=optimizer1,
                          optimizer2=optimizer2,
                          dataParser=trainDataLoader,
                          epoch=epoch)
        val_avg = val(model1=model1,
                      model2=model2,
                      dataParser=valDataLoader,
                      epoch=epoch)
        test(model1=model1,
             model2=model2,
             dataParser=testDataLoader,
             epoch=epoch)

        # TensorBoard logging (best effort: a failure is printed, not raised).
        try:
            # Log the LR actually used this epoch; scheduler.get_lr() returns a
            # pre-decayed value at StepLR boundaries.
            writer.add_scalars('lr_per_epoch', {
                'stage1': optimizer1.param_groups[0]['lr'],
                'stage2': optimizer2.param_groups[0]['lr'],
            }, global_step=epoch)
            for suffix, key in curve_keys:
                writer.add_scalars('tr-val_avg_%s_per_epoch' % suffix, {
                    'train': train_avg[key],
                    'val': val_avg[key],
                }, global_step=epoch)
            # Per-map losses; each plot reads only from its own split.
            for split, stats in (('tr', train_avg), ('val', val_avg)):
                writer.add_scalars(
                    '%s_avg_map8_loss_per_epoch' % split,
                    {'map%d' % (i + 1): stats['map8_loss'][i] for i in range(8)},
                    global_step=epoch)
        except Exception as e:
            print(e)

        # Checkpoint names encode the validation metrics of each stage.
        output_name1 = output_name_file_name % \
                      (epoch, val_avg['loss_avg'],
                       val_avg['f1_avg_stage1'],
                       val_avg['precision_avg_stage1'],
                       val_avg['accuracy_avg_stage1'],
                       val_avg['recall_avg_stage1'])
        output_name2 = output_name_file_name % \
                      (epoch, val_avg['loss_avg'],
                       val_avg['f1_avg_stage2'],
                       val_avg['precision_avg_stage2'],
                       val_avg['accuracy_avg_stage2'],
                       val_avg['recall_avg_stage2'])

        # Save one checkpoint per stage every epoch.
        save_model_name_stage1 = os.path.join(args.model_save_dir,
                                              'stage1_' + output_name1)
        save_model_name_stage2 = os.path.join(args.model_save_dir,
                                              'stage2_' + output_name2)
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model1.state_dict(),
                'optimizer': optimizer1.state_dict()
            }, save_model_name_stage1)
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model2.state_dict(),
                'optimizer': optimizer2.state_dict()
            }, save_model_name_stage2)

        scheduler1.step(epoch=epoch)
        scheduler2.step(epoch=epoch)
    print('训练已完成!')
def main():
    """Train the two-stage pipeline (Net1 -> Net2) on training data only.

    Builds stage-2 ``TamperDataset`` splits and loaders, creates ``Net1`` and
    ``Net2`` with one Adam optimizer and one StepLR scheduler each, optionally
    restores weights from ``args.resume[0]`` / ``args.resume[1]``, then for
    every epoch in ``range(args.start_epoch, args.maxepoch)`` trains, logs
    training metrics and learning rates to ``writer``, and saves one
    checkpoint per stage into ``args.model_save_dir``. Validation and testing
    are not run in this variant; checkpoint names use training metrics.

    Depends on module-level names defined elsewhere in this file: ``args``,
    ``writer``, ``train``, ``Net1``, ``Net2``, ``TamperDataset``,
    ``weights_init`` and ``output_name_file_name``.
    """
    args.cuda = True
    # 1. Choose the data sources used for training / validation.
    using_data = {
        'my_sp': False,
        'my_cm': False,
        'template_casia_casia': False,
        'template_coco_casia': False,
        'cod10k': True,
        'casia': False,
        'coverage': False,
        'columb': False,
        'negative_coco': False,
        'negative_casia': False,
        'texture_sp': False,
        'texture_cm': False,
    }
    using_data_test = {
        'my_sp': False,
        'my_cm': False,
        'template_casia_casia': False,
        'template_coco_casia': False,
        'cod10k': False,
        'casia': False,
        'coverage': True,
        'columb': False,
        'negative_coco': False,
        'negative_casia': False,
    }
    # 2. Build the three dataset splits.
    trainData = TamperDataset(stage_type='stage2',
                              using_data=using_data,
                              train_val_test_mode='train')
    valData = TamperDataset(stage_type='stage2',
                            using_data=using_data,
                            train_val_test_mode='val')
    testData = TamperDataset(stage_type='stage2',
                             using_data=using_data_test,
                             train_val_test_mode='test')

    # 3. Data loaders.
    trainDataLoader = torch.utils.data.DataLoader(trainData,
                                                  batch_size=args.batch_size,
                                                  num_workers=3,
                                                  shuffle=True,
                                                  pin_memory=False)
    valDataLoader = torch.utils.data.DataLoader(valData,
                                                batch_size=args.batch_size,
                                                num_workers=3)
    testDataLoader = torch.utils.data.DataLoader(testData,
                                                 batch_size=args.batch_size,
                                                 num_workers=0)

    # Models. Both go to the GPU when available, otherwise both stay on CPU
    # (model2 must not be sent to CUDA on a machine without it).
    model1 = Net1()
    model2 = Net2()
    if torch.cuda.is_available():
        model1.cuda()
        model2.cuda()
    else:
        model1.cpu()
        model2.cpu()

    # Apply the project's custom weight initialisation (``weights_init``)
    # instead of relying on PyTorch's default per-layer initialisation.
    model1.apply(weights_init)
    model2.apply(weights_init)

    # Optimizers. (The reference TensorFlow configuration was
    # Adam(lr=1e-2, beta_1=0.9, beta_2=0.999).)
    optimizer1 = optim.Adam(model1.parameters(),
                            lr=1e-5,
                            betas=(0.9, 0.999),
                            eps=1e-8)
    optimizer2 = optim.Adam(model2.parameters(),
                            lr=args.lr,
                            betas=(0.9, 0.999),
                            eps=1e-8)

    # Optionally resume both stages. Only model weights are restored;
    # optimizer state starts fresh.
    if args.resume[0]:
        if isfile(args.resume[0]):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint1 = torch.load(args.resume[0])
            checkpoint2 = torch.load(args.resume[1])
            model1.load_state_dict(checkpoint1['state_dict'])
            model2.load_state_dict(checkpoint2['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> Error!!!! no checkpoint found at '{}'".format(args.resume))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    # Learning-rate schedules.
    scheduler1 = lr_scheduler.StepLR(optimizer1,
                                     step_size=args.stepsize,
                                     gamma=args.gamma)
    scheduler2 = lr_scheduler.StepLR(optimizer2,
                                     step_size=args.stepsize,
                                     gamma=args.gamma)

    # (tag suffix, metrics key) pairs for the per-epoch training curves.
    train_curves = (('loss', 'loss_avg_stage2'),
                    ('f1', 'f1_avg_stage2'),
                    ('precision', 'precision_avg_stage2'),
                    ('acc', 'accuracy_avg_stage2'),
                    ('recall', 'recall_avg_stage2'))

    for epoch in range(args.start_epoch, args.maxepoch):
        train_avg = train(model1=model1,
                          model2=model2,
                          optimizer1=optimizer1,
                          optimizer2=optimizer2,
                          dataParser=trainDataLoader,
                          epoch=epoch)
        print('start val')

        # TensorBoard logging (best effort: a failure is printed, not raised).
        try:
            for suffix, key in train_curves:
                writer.add_scalar('tr_avg_%s_per_epoch' % suffix,
                                  train_avg[key],
                                  global_step=epoch)
            # Log the LR actually used this epoch; scheduler.get_lr() returns a
            # pre-decayed value at StepLR boundaries.
            writer.add_scalar('lr_per_epoch_stage1',
                              optimizer1.param_groups[0]['lr'],
                              global_step=epoch)
            writer.add_scalar('lr_per_epoch_stage2',
                              optimizer2.param_groups[0]['lr'],
                              global_step=epoch)
        except Exception as e:
            print(e)

        # Checkpoint names encode the training metrics (no validation here).
        output_name = output_name_file_name % \
                      (epoch, train_avg['loss_avg_stage2'],
                       train_avg['f1_avg_stage2'],
                       train_avg['precision_avg_stage2'],
                       train_avg['accuracy_avg_stage2'],
                       train_avg['recall_avg_stage2'])

        # Save one checkpoint per stage every epoch; each stage stores its own
        # model's weights.
        save_model_name_stage1 = os.path.join(args.model_save_dir,
                                              'stage1' + output_name)
        save_model_name_stage2 = os.path.join(args.model_save_dir,
                                              'stage2' + output_name)
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model1.state_dict(),
                'optimizer': optimizer1.state_dict()
            }, save_model_name_stage1)
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model2.state_dict(),
                'optimizer': optimizer2.state_dict()
            }, save_model_name_stage2)

        scheduler1.step(epoch=epoch)
        scheduler2.step(epoch=epoch)
    print('训练已完成!')
Esempio n. 3
0
def main():
    """Train the single-model stage-1 tamper-detection network.

    Builds stage-1 ``TamperDataset`` train/val/test splits (``device='wkl'``)
    and loaders, creates ``Net`` with an Adam optimizer and StepLR scheduler,
    optionally restores weights from ``args.resume``, then for every epoch in
    ``range(args.start_epoch, args.maxepoch)`` runs ``train``/``val``/``test``,
    logs the learning rate and train/val metrics to ``writer``, and saves a
    checkpoint named after the validation metrics into
    ``args.model_save_dir``.

    Exits the process with status 1 if ``args.resume`` is set but the path
    is not a file.

    Depends on module-level names defined elsewhere in this file: ``args``,
    ``writer``, ``train``, ``val``, ``test``, ``Net``, ``TamperDataset``,
    ``weights_init``, ``output_name_file_name`` and ``sys``.
    """
    args.cuda = True
    # 1. Choose the data sources used for training / validation.
    using_data = {
        'my_sp': True,
        'my_cm': True,
        'template_casia_casia': True,
        'template_coco_casia': False,
        'cod10k': True,
        'casia': False,
        'coverage': False,
        'columb': False,
        'negative': True,
        'negative_casia': False,
        'texture_sp': True,
        'texture_cm': False,
    }
    using_data_test = {
        'my_sp': False,
        'my_cm': False,
        'template_casia_casia': False,
        'template_coco_casia': False,
        'cod10k': False,
        'casia': False,
        'coverage': True,
        'columb': False,
        'negative': False,
        'negative_casia': False,
    }
    # 2. Build the three dataset splits.
    trainData = TamperDataset(stage_type='stage1',
                              using_data=using_data,
                              train_val_test_mode='train',
                              device='wkl')
    valData = TamperDataset(stage_type='stage1',
                            using_data=using_data,
                            train_val_test_mode='val',
                            device='wkl')
    testData = TamperDataset(stage_type='stage1',
                             using_data=using_data_test,
                             train_val_test_mode='test',
                             device='wkl')

    # 3. Data loaders.
    trainDataLoader = torch.utils.data.DataLoader(trainData,
                                                  batch_size=args.batch_size,
                                                  num_workers=4,
                                                  shuffle=True,
                                                  pin_memory=True)
    valDataLoader = torch.utils.data.DataLoader(valData,
                                                batch_size=args.batch_size,
                                                num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(testData,
                                                 batch_size=1,
                                                 num_workers=1)

    # Model.
    model = Net(bilinear=True)
    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()

    # Apply the project's custom weight initialisation (``weights_init``)
    # instead of relying on PyTorch's default per-layer initialisation.
    model.apply(weights_init)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-8)

    # Optionally resume from a pretrained checkpoint (weights only; the
    # optimizer state starts fresh).
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            # A pretrained model was requested but the path is wrong: abort.
            print("=> 想要使用预训练模型,但是路径出错 '{}'".format(args.resume))
            sys.exit(1)
    else:
        # No pretrained model requested: train from scratch.
        print("=> 不使用预训练模型,直接开始训练 '{}'".format(args.resume))

    # Learning-rate schedule.
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.stepsize,
                                    gamma=args.gamma)

    # (tag suffix, metrics key) pairs plotted as train-vs-val curves.
    curve_keys = (('loss', 'loss_avg'),
                  ('f1', 'f1_avg'),
                  ('precision', 'precision_avg'),
                  ('acc', 'accuracy_avg'),
                  ('recall', 'recall_avg'))

    for epoch in range(args.start_epoch, args.maxepoch):
        train_avg = train(model=model,
                          optimizer=optimizer,
                          dataParser=trainDataLoader,
                          epoch=epoch)
        val_avg = val(model=model, dataParser=valDataLoader, epoch=epoch)
        # The test pass is still run every epoch, but its metrics are not
        # currently logged or used for the checkpoint name.
        test(model=model, dataParser=testDataLoader, epoch=epoch)

        # TensorBoard logging (best effort: a failure is printed, not raised).
        try:
            # Log the LR actually used this epoch; scheduler.get_lr() returns a
            # pre-decayed value at StepLR boundaries.
            writer.add_scalar('lr_per_epoch',
                              optimizer.param_groups[0]['lr'],
                              global_step=epoch)
            for suffix, key in curve_keys:
                writer.add_scalars('tr-val-test_avg_%s_per_epoch' % suffix, {
                    'train': train_avg[key],
                    'val': val_avg[key]
                }, global_step=epoch)
        except Exception as e:
            print(e)

        # Save a checkpoint every epoch, named after the validation metrics.
        output_name = output_name_file_name % \
                      (epoch, val_avg['loss_avg'],
                       val_avg['f1_avg'],
                       val_avg['precision_avg'],
                       val_avg['accuracy_avg'],
                       val_avg['recall_avg'])
        save_model_name = os.path.join(args.model_save_dir, output_name)
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, save_model_name)

        scheduler.step(epoch=epoch)

    print('训练已完成!')
def main():
    """Train Net1, Net2 and Net3 on ``TamperDataset`` and checkpoint them every epoch.

    Steps:
      1. Build train/val/test ``TamperDataset`` objects (``stage_type='stage2'``)
         and wrap them in ``torch.utils.data.DataLoader`` instances.
      2. Create the three models on GPU when available (CPU otherwise), apply
         ``weights_init`` and give each model its own Adam optimizer and StepLR
         scheduler (``args.stepsize`` / ``args.gamma``).
      3. Optionally restore model1 / model2 weights from ``args.resume[0]`` /
         ``args.resume[1]``.
      4. For every epoch in ``range(args.start_epoch, args.maxepoch)`` run
         ``train``, ``val`` and ``test``, log scalars to the module-level
         ``writer`` (best effort), save one checkpoint per model into
         ``args.model_save_dir`` and step all three schedulers.

    Relies on module-level names defined elsewhere in the file (``args``,
    ``writer``, ``train``, ``val``, ``test``, ``weights_init``,
    ``output_name_file_name``, the ``Net*`` classes, ``TamperDataset``).
    """
    args.cuda = True
    # 1. Choose which data sources are used for training / validation.
    using_data = {
        'my_sp': True,
        'my_cm': True,
        'template_casia_casia': True,
        'template_coco_casia': False,
        'cod10k': True,
        'casia': False,
        'copy_move': False,
        'texture_sp': True,
        'texture_cm': True,
        'columb': False,
        'negative': True,
        'negative_casia': False,
    }
    # Data sources used for the test split.
    using_data_test = {
        'my_sp': False,
        'my_cm': False,
        'template_casia_casia': False,
        'template_coco_casia': False,
        'cod10k': False,
        'casia': False,
        'coverage': True,
        'columb': False,
        'negative': False,
        'negative_casia': False,
    }
    # 2. Build the three dataset splits.
    trainData = TamperDataset(stage_type='stage2',
                              using_data=using_data,
                              train_val_test_mode='train',
                              device='413')
    valData = TamperDataset(stage_type='stage2',
                            using_data=using_data,
                            train_val_test_mode='val',
                            device='413')
    testData = TamperDataset(stage_type='stage2',
                             using_data=using_data_test,
                             train_val_test_mode='test',
                             device='413')

    # 3. Data loaders for each split.
    trainDataLoader = torch.utils.data.DataLoader(trainData,
                                                  batch_size=args.batch_size,
                                                  num_workers=8,
                                                  shuffle=True,
                                                  pin_memory=False)
    valDataLoader = torch.utils.data.DataLoader(valData,
                                                batch_size=args.batch_size,
                                                num_workers=4,
                                                shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testData,
                                                 batch_size=1,
                                                 num_workers=0)

    # Models: all three are placed on the same device (the previous CPU
    # fallback forgot model3).
    model1 = Net1()
    model2 = Net2()
    model3 = Net3(2, 10)
    models = (model1, model2, model3)
    use_cuda = torch.cuda.is_available()
    for model in models:
        if use_cuda:
            model.cuda()
        else:
            model.cpu()

    # Custom weight initialisation; without this step every layer keeps
    # PyTorch's default initialisation.
    for model in models:
        model.apply(weights_init)

    # One Adam optimizer per model, identical hyper-parameters.
    optimizer1, optimizer2, optimizer3 = [
        optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.999), eps=1e-8)
        for model in models
    ]

    # Resume model1 / model2 from args.resume[0] / args.resume[1] when present.
    # map_location='cpu' lets GPU-saved checkpoints load on a CPU-only machine;
    # load_state_dict then copies the tensors onto each model's own device.
    has_ckpt1 = isfile(args.resume[0])
    has_ckpt2 = isfile(args.resume[1])
    if has_ckpt1 or has_ckpt2:
        print("=> loading checkpoint '{}'".format(args.resume))
        if has_ckpt1:
            checkpoint1 = torch.load(args.resume[0], map_location='cpu')
            model1.load_state_dict(checkpoint1['state_dict'])
        if has_ckpt2:
            checkpoint2 = torch.load(args.resume[1], map_location='cpu')
            model2.load_state_dict(checkpoint2['state_dict'])
            if not has_ckpt1:
                # Kept from the original logic: optimizer2's state is only
                # restored when stage 2 is resumed on its own.
                optimizer2.load_state_dict(checkpoint2['optimizer'])
        print("=> loaded checkpoint '{}'".format(args.resume))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    # Step-wise learning-rate decay, one scheduler per optimizer.
    scheduler1, scheduler2, scheduler3 = [
        lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
        for optimizer in (optimizer1, optimizer2, optimizer3)
    ]

    def map8_scalars(avg):
        """Return {'map1': ..., 'map8': ...} built from avg['map8_loss'][0..7]."""
        return {'map{}'.format(i + 1): avg['map8_loss'][i] for i in range(8)}

    # Epoch loop.
    for epoch in range(args.start_epoch, args.maxepoch):
        train_avg = train(model1=model1,
                          model2=model2,
                          model3=model3,
                          optimizer1=optimizer1,
                          optimizer2=optimizer2,
                          optimizer3=optimizer3,
                          dataParser=trainDataLoader,
                          epoch=epoch)

        val_avg = val(model1=model1,
                      model2=model2,
                      model3=model3,
                      dataParser=valDataLoader,
                      epoch=epoch)
        test(model1=model1,
             model2=model2,
             model3=model3,
             dataParser=testDataLoader,
             epoch=epoch)

        # ---- TensorBoard logging (best effort: never interrupts training) ----
        try:
            # Read the lr actually used this epoch from the optimizer;
            # scheduler.get_lr() is not meant to be called directly and can
            # return an already-decayed value on StepLR step boundaries.
            writer.add_scalars('lr_per_epoch', {
                'stage1': optimizer1.param_groups[0]['lr'],
                'stage2': optimizer2.param_groups[0]['lr'],
                'stage3': optimizer3.param_groups[0]['lr'],
            }, global_step=epoch)
            writer.add_scalars('tr-val_avg_loss_per_epoch', {
                'train': train_avg['loss_avg'],
                'val': val_avg['loss_avg'],
            }, global_step=epoch)
            writer.add_scalars('val_avg_f1_recall__precision_acc', {
                'precision': val_avg['precision_avg_stage2'],
                'acc': val_avg['accuracy_avg_stage2'],
                'f1': val_avg['f1_avg_stage2'],
                'recall': val_avg['recall_avg_stage2'],
            }, global_step=epoch)
            writer.add_scalars('tr_avg_map8_loss_per_epoch',
                               map8_scalars(train_avg),
                               global_step=epoch)
            # Previously 'map1' here was taken from train_avg by mistake.
            writer.add_scalars('val_avg_map8_loss_per_epoch',
                               map8_scalars(val_avg),
                               global_step=epoch)
        except Exception as e:
            print(e)

        # ---- Checkpointing (every epoch) ----
        output_name1 = output_name_file_name % \
                       (epoch, val_avg['loss_avg'],
                        val_avg['f1_avg_stage1'],
                        val_avg['precision_avg_stage1'],
                        val_avg['accuracy_avg_stage1'],
                        val_avg['recall_avg_stage1'])
        output_name2 = output_name_file_name % \
                       (epoch, val_avg['loss_avg'],
                        val_avg['f1_avg_stage2'],
                        val_avg['precision_avg_stage2'],
                        val_avg['accuracy_avg_stage2'],
                        val_avg['recall_avg_stage2'])

        # model3 was previously trained but never saved. No stage-3 validation
        # metrics are read in this function, so its file name reuses the
        # stage-2 metrics. TODO confirm the intended naming.
        checkpoints = (
            ('stage1_' + output_name1, model1, optimizer1),
            ('stage2_' + output_name2, model2, optimizer2),
            ('stage3_' + output_name2, model3, optimizer3),
        )
        for file_name, model, optimizer in checkpoints:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(args.model_save_dir, file_name))

        # Step every scheduler; scheduler3 was previously never stepped, so
        # model3's learning rate never decayed. The explicit epoch argument is
        # kept to preserve the existing schedule.
        for scheduler in (scheduler1, scheduler2, scheduler3):
            scheduler.step(epoch=epoch)
    print('训练已完成!')