def main():
    """Train the two-branch SE composer (reflectance + shading) on MPI-Sintel.

    Parses command-line options, builds the two single-task decomposers and
    the composer wrapping them, then runs the training loop.  After each
    epoch the model is evaluated on the test split; whichever branch
    (albedo / shading) improves on its best MSE so far has its state dict
    checkpointed separately.  Losses are logged to TensorBoard and to plain
    text files under ``save_path``.
    """
    # Fix all RNG seeds for repeatability.  NOTE(review): cudnn.benchmark=True
    # can still introduce nondeterminism across runs; set
    # cudnn.deterministic=True as well if exact reproducibility is required.
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',          type=str,   default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split',              type=str,   default='SceneSplit')
    parser.add_argument('--mode',               type=str,   default='two')
    parser.add_argument('--save_path',          type=str,   default='MPI_logs_new\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_selayer1_reflmultiSize\\',
                        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',                 type=float, default=0.0005,
                        help='learning rate')
    # BUG FIX: was type=float, but DataLoader's num_workers requires an int —
    # any value given on the command line would previously crash at load time.
    parser.add_argument('--loader_threads',     type=int,   default=8,
                        help='number of parallel data-loading threads')
    # BUG FIX: argparse type=bool is broken (bool('False') is True, so any CLI
    # value enabled the flag).  Use the project's StrToBool converter, which is
    # already used for every other boolean option below.
    parser.add_argument('--save_model',         type=StrToBool, default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs',         type=int,   default=120)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool, default=False)
    parser.add_argument('--state_dict',         type=str,   default='composer_state.t7')
    parser.add_argument('--cuda',               type=str,   default='cuda')
    parser.add_argument('--choose',             type=str,   default='refl')
    parser.add_argument('--refl_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--shad_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--refl_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--shad_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--refl_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--refl_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--shad_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--refl_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--image_size',         type=StrToInt, default=256)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    # One decomposer per intrinsic component, sharing the same architecture.
    reflectance = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size, low_se=args.refl_low_se, skip_se=args.refl_skip_se).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size, low_se=args.shad_low_se, skip_se=args.shad_skip_se).to(device)
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size, args.shad_multi_size).to(device)

    # Hard-coded split files.  NOTE(review): these absolute Windows paths make
    # the script machine-specific; consider deriving them from --data_path.
    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.OctaveTrainer(composer, train_loader, device, writer, args)

    # Track the best (lowest) test MSE seen so far for each branch.
    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        # Step-decay the learning rate by 0.75 every 40 epochs.
        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))

        # Checkpoint each branch independently on its own best test loss.
        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))

        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))
# ---------------------------------------------------------------------------
# Example #2 — second training-script variant (single-branch Octave model)
# ---------------------------------------------------------------------------
def main():
    """Train (resume) the single-output Octave/Mish composer on MPI-Sintel.

    Parses command-line options, restores the checkpointed composer, then
    continues training from a hard-coded resume epoch.  Depending on
    ``--mode`` the test step evaluates either both intrinsic components
    ('two') or a single chosen one; the composer state is checkpointed
    whenever the average test loss improves on the best seen so far.
    """
    # Fix RNG seeds for repeatability (cudnn.benchmark may still introduce
    # nondeterminism across runs).
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='one')
    parser.add_argument('--choose', type=str, default='refl')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_mish_ranger_Octave_one_bs10_updateLR1_epoch240_CosBF_VGG0.1_refl_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    # BUG FIX: was type=float, but DataLoader's num_workers requires an int —
    # any value given on the command line would previously crash at load time.
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    # BUG FIX: argparse type=bool is broken (bool('False') is True, so any CLI
    # value enabled the flag); use the project's StrToBool converter instead.
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=StrToInt, default=10)
    parser.add_argument('--checkpoint', type=StrToBool, default=True)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')

    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    # BUG FIX: the device string was "cuda: 1" — the embedded space makes it an
    # invalid device specifier that torch.device() rejects at runtime.
    device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    # shader = RIN.Shader(output_ch=3)
    composer = octave.get_mish_model_one_output().to(device)
    # composer = RIN.Composer(reflection, shader).to(device)

    # Hard-coded split files.  NOTE(review): these absolute Windows paths make
    # the script machine-specific; consider deriving them from --data_path.
    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    # Resume point for a run previously interrupted after epoch 118.
    # NOTE(review): hard-coded; the commented line below derives it from the
    # checkpoint filename instead.
    cur_epoch = 119
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        # cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.OctaveTrainer(composer,
                                         train_loader,
                                         args.lr,
                                         device,
                                         writer,
                                         mode=args.mode,
                                         choose=args.choose)

    # Best test loss of the interrupted run, so resuming does not overwrite a
    # better checkpoint with a worse one.  NOTE(review): hard-coded; could be
    # read back from loss.txt instead.
    best_average_loss = 0.01970498738396499

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        # Step-decay the learning rate by 0.75 every 120 epochs.
        if (epoch + 1) % 120 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        if args.mode == 'two':
            # Evaluate both intrinsic components and log their mean.
            albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
                composer, test_loader, device)
            average_loss = (albedo_test_loss + shading_test_loss) / 2
            writer.add_scalar('A_mse', albedo_test_loss, epoch)
            writer.add_scalar('S_mse', shading_test_loss, epoch)
            writer.add_scalar('aver_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
        else:
            # Single-output mode: evaluate only the chosen component.
            average_loss = RIN_pipeline.MPI_test_unet_one(composer,
                                                          test_loader,
                                                          device,
                                                          choose=args.choose)
            writer.add_scalar('test_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write('epoch{} --- test_loss: {}\n'.format(
                    epoch, average_loss))

        # Checkpoint the full composer whenever the test loss improves.
        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(state,
                           os.path.join(args.save_path, 'composer_state.t7'))

            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                if args.mode == 'two':
                    f.write(
                        'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                        .format(epoch, average_loss, albedo_test_loss,
                                shading_test_loss))
                else:
                    f.write('epoch{} --- average_loss: {}\n'.format(
                        epoch, average_loss))