def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=bool,  # caution: argparse's type=bool is True for any non-empty string
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--checkpoint', type=bool, default=False)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)
    # shader = RIN.Shader()
    # decomposer = RIN.Decomposer()
    # composer = RIN.Composer(decomposer, shader).to(device)
    # RIN.init_weights(composer, init_type='kaiming')

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1

    # train_transform = RIN_pipeline.MPI_Train_Agumentation()
    # train_set = RIN_pipeline.MPI_Dataset(args.data_path, mode=args.mode[0], transform=train_transform)
    # train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    # test_transform = RIN_pipeline.MPI_Test_Agumentation()
    # test_set = RIN_pipeline.MPI_Dataset(args.data_path, mode=args.mode[1], transform=test_transform)
    # test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=args.loader_threads, shuffle=True)

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=True)

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.MPI_TrainerOriginRemoveShape(
        composer, train_loader, args.lr, device, writer)

    best_average_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        step = trainer.train()

        if (epoch + 1) % 120 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

            if args.save_model:
                state = composer.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_state_{}.t7'.format(epoch)))

        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_remove_shape(
            composer, test_loader, device)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))

        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_state_{}.t7'.format(epoch)))
            #RIN_pipeline.visualize_MPI(composer, test_loader, device, os.path.join(args.save_path, 'image_{}.png'.format(epoch)))
            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
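
# `check_folder` is called by these scripts but never defined in this
# listing. A minimal sketch consistent with its usage (create the directory
# if it does not already exist); an assumption, not the repo's actual helper:
import os

def check_folder(path):
    """Create `path` (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)
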
def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_log_paper\\GAN_RIID_updateLR3_epoch100_CosbfVGG_SceneSplit_refl-se-skip_shad-se-low_multi_new_shadSqueeze_256_Reduction2\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state_25.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state_60.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=True)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=True)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=True)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=True)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=True)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_reduction', type=StrToInt, default=2)
    parser.add_argument('--shad_reduction', type=StrToInt, default=2)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    args = parser.parse_args()

    device = torch.device(args.cuda)
    reflectance = RIN.SEDecomposerSingle(
        multi_size=args.refl_multi_size,
        low_se=args.refl_low_se,
        skip_se=args.refl_skip_se,
        detach=args.refl_detach_flag,
        reduction=args.refl_reduction).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size,
                                     low_se=args.shad_low_se,
                                     skip_se=args.shad_skip_se,
                                     se_squeeze=args.shad_squeeze_flag,
                                     reduction=args.shad_reduction,
                                     detach=args.shad_detach_flag).to(device)
    reflectance.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    shading.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.shad_checkpoint,
                         args.state_dict_shad)))
    print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    if args.fullsize:
        print('test fullsize....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
        h, w = 436, 1024
        pad_h, pad_w = clc_pad(h, w, 16)
        print(pad_h, pad_w)
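        # pad right/bottom so H and W become multiples of 16 (the network's
        # stride); the negative padding below crops outputs back to 436x1024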
        tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
        tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
    else:
        print('test size256....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.fullsize:
        check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
        check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
        check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
        check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
    else:
        check_folder(os.path.join(args.save_path, "refl_target"))
        check_folder(os.path.join(args.save_path, "shad_target"))
        check_folder(os.path.join(args.save_path, "refl_output"))
        check_folder(os.path.join(args.save_path, "shad_output"))

    ToPIL = transforms.ToPILImage()

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            if args.fullsize:
                input_g = tmp_pad(input_g)
            if args.refl_multi_size and args.shad_multi_size:
                albedo_fake, shading_fake, _, _ = composer.forward(input_g)
            elif args.refl_multi_size or args.shad_multi_size:
                albedo_fake, shading_fake, _ = composer.forward(input_g)
            else:
                albedo_fake, shading_fake = composer.forward(input_g)
            if args.fullsize:
                albedo_fake, shading_fake = tmp_inversepad(
                    albedo_fake), tmp_inversepad(shading_fake)

            albedo_fake = albedo_fake * mask_g

            # lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1,2,0)
            # lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1,2,0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)

            albedo_fake = albedo_fake.cpu().clamp(0, 1)
            shading_fake = shading_fake.cpu().clamp(0, 1)
            albedo_g = albedo_g.cpu().clamp(0, 1)
            shading_g = shading_g.cpu().clamp(0, 1)

            lab_refl_targ = ToPIL(albedo_g.squeeze())
            lab_sha_targ = ToPIL(shading_g.squeeze())
            refl_pred = ToPIL(albedo_fake.squeeze())
            sha_pred = ToPIL(shading_fake.squeeze())

            lab_refl_targ.save(
                os.path.join(
                    args.save_path,
                    "refl_target_fullsize" if args.fullsize else "refl_target",
                    "{}.png".format(ind)))
            lab_sha_targ.save(
                os.path.join(
                    args.save_path,
                    "shad_target_fullsize" if args.fullsize else "shad_target",
                    "{}.png".format(ind)))
            refl_pred.save(
                os.path.join(
                    args.save_path,
                    "refl_output_fullsize" if args.fullsize else "refl_output",
                    "{}.png".format(ind)))
            sha_pred.save(
                os.path.join(
                    args.save_path,
                    "shad_output_fullsize" if args.fullsize else "shad_output",
                    "{}.png".format(ind)))
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',          type=str,   default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split',              type=str,   default='SceneSplit')
    parser.add_argument('--mode',               type=str,   default='two')
    parser.add_argument('--save_path',          type=str,   default='MPI_logs_new\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_selayer1_reflmultiSize\\',
                        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',                 type=float, default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool, default=True,
                        help='whether to save the model or not')
    parser.add_argument('--num_epochs',         type=int,   default=120)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool, default=False)
    parser.add_argument('--state_dict',         type=str,   default='composer_state.t7')
    parser.add_argument('--cuda',               type=str,   default='cuda')
    parser.add_argument('--choose',             type=str,   default='refl')
    parser.add_argument('--refl_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--shad_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--refl_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--shad_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--refl_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--refl_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--shad_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--refl_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--image_size',         type=StrToInt, default=256)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size, low_se=args.refl_low_se, skip_se=args.refl_skip_se).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size, low_se=args.shad_low_se, skip_se=args.shad_skip_se).to(device)
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size, args.shad_multi_size).to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')  # note: cur_epoch is not restored here

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)
    
    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.OctaveTrainer(composer, train_loader, device, writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        trainer.train()

        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
            
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))

        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))
        
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',          type=str,   default='F:\\MPI-sintel-PyrResNet\\Sintel\\images\\',
                        help='base folder of datasets')
    parser.add_argument('--split',              type=str,   default='SceneSplit')
    parser.add_argument('--mode',               type=str,   default='train')
    parser.add_argument('--save_path',          type=str,   default='MPI_logs\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_GAN_selayer1_ReflMultiSize_DA\\',
                        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',                 type=float, default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool, default=True,
                        help='whether to save the model or not')
    parser.add_argument('--num_epochs',         type=int,   default=1000)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool, default=False)
    parser.add_argument('--state_dict',         type=str,   default='composer_state.t7')
    parser.add_argument('--skip_se',            type=StrToBool, default=False)
    parser.add_argument('--cuda',               type=str,   default='cuda:1')
    parser.add_argument('--dilation',           type=StrToBool, default=False)
    parser.add_argument('--se_improved',        type=StrToBool, default=False)
    parser.add_argument('--weight_decay',       type=float, default=0.0001)
    parser.add_argument('--refl_multi_size',    type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool, default=False)
    parser.add_argument('--data_augmentation',  type=StrToBool, default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    # shader = RIN.Shader(output_ch=3)
    print(args.skip_se)
    Generator_R = RIN.SESingleGenerator(multi_size=args.refl_multi_size).to(device)
    Generator_S = RIN.SESingleGenerator(multi_size=args.shad_multi_size).to(device)
    composer = RIN.SEComposerGenerater(Generator_R, Generator_S, args.refl_multi_size, args.shad_multi_size).to(device)
    Discriminator_R = RIN.SEUG_Discriminator().to(device)
    Discriminator_S = RIN.SEUG_Discriminator().to(device)
    # composer = RIN.Composer(reflection, shader).to(device)

    
    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    if args.data_augmentation:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        # cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1
    if args.data_augmentation:
        train_transform = RIN_pipeline.MPI_Train_Agumentation()
    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt, transform=train_transform if args.data_augmentation else None, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)
    
    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R, Discriminator_S, train_loader, args.lr, device, writer, weight_decay=args.weight_decay, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        step = trainer.train()

        if (epoch + 1) % 100 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
            
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))

        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))
        
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='one')
    parser.add_argument('--choose', type=str, default='refl')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_mish_ranger_Octave_one_bs10_updateLR1_epoch240_CosBF_VGG0.1_refl_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=StrToInt, default=10)
    parser.add_argument('--checkpoint', type=StrToBool, default=True)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')

    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # pylint: disable=E1101
    # shader = RIN.Shader(output_ch=3)
    composer = octave.get_mish_model_one_output().to(device)
    # composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 119  # hardcoded resume point for the loaded checkpoint
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        # cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.OctaveTrainer(composer,
                                         train_loader,
                                         args.lr,
                                         device,
                                         writer,
                                         mode=args.mode,
                                         choose=args.choose)

    best_average_loss = 0.01970498738396499  # best loss carried over from the resumed run

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        step = trainer.train()

        if (epoch + 1) % 120 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        if args.mode == 'two':
            albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
                composer, test_loader, device)
            average_loss = (albedo_test_loss + shading_test_loss) / 2
            writer.add_scalar('A_mse', albedo_test_loss, epoch)
            writer.add_scalar('S_mse', shading_test_loss, epoch)
            writer.add_scalar('aver_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
        else:
            average_loss = RIN_pipeline.MPI_test_unet_one(composer,
                                                          test_loader,
                                                          device,
                                                          choose=args.choose)
            writer.add_scalar('test_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write('epoch{} --- test_loss: {}\n'.format(
                    epoch, average_loss))

        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(state,
                           os.path.join(args.save_path, 'composer_state.t7'))

            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                if args.mode == 'two':
                    f.write(
                        'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                        .format(epoch, average_loss, albedo_test_loss,
                                shading_test_loss))
                else:
                    f.write('epoch{} --- average_loss: {}\n'.format(
                        epoch, average_loss))
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='ImageSplit')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_epoch240_ImageSplit_size256_remove_shape\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--state_dict',
                        type=str,
                        default='composer_state_222.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)
    # RIN.init_weights(composer, init_type='kaiming')
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    composer.load_state_dict(
        torch.load(os.path.join(args.save_path, args.state_dict)))
    print('load checkpoint success!')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    check_folder(os.path.join(args.save_path, "refl_target"))
    check_folder(os.path.join(args.save_path, "shad_target"))
    check_folder(os.path.join(args.save_path, "refl_output"))
    check_folder(os.path.join(args.save_path, "shad_output"))
    # check_folder(os.path.join(args.save_path, "shape_output"))
    check_folder(os.path.join(args.save_path, "mask"))

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            _, albedo_fake, shading_fake = composer.forward(input_g)
            albedo_fake = albedo_fake * mask_g

            lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            mask = mask_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1, 2, 0)
            sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1, 2, 0)

            lab_refl_targ = np.clip(lab_refl_targ, 0, 1)
            lab_sha_targ = np.clip(lab_sha_targ, 0, 1)
            refl_pred = np.clip(refl_pred, 0, 1)
            sha_pred = np.clip(sha_pred, 0, 1)
            mask = np.clip(mask, 0, 1)

            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_target",
                             "{}.png".format(ind)), lab_refl_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_target",
                             "{}.png".format(ind)), lab_sha_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "mask", "{}.png".format(ind)),
                mask)
            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_output",
                             "{}.png".format(ind)), refl_pred)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_output",
                             "{}.png".format(ind)), sha_pred)
def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs_new\\GAN_RIID_updateLR3_epoch160_CosbfVGG_SceneSplit_refl-se-skip_shad-se-low_multi_new_shadSqueeze_grad\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state_81.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state_81.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--heatmap', type=StrToBool, default=False)
    args = parser.parse_args()

    device = torch.device(args.cuda)
    model = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size,
                                   low_se=args.refl_low_se,
                                   skip_se=args.refl_skip_se,
                                   detach=args.refl_detach_flag,
                                   reduction=args.refl_reduction,
                                   heatmap=args.heatmap).to(device)
    model.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    print('load checkpoint success!')

    if args.fullsize:
        print('test fullsize....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
        h, w = 436, 1024
        pad_h, pad_w = clc_pad(h, w, 16)
        print(pad_h, pad_w)
        tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
        tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
        tmp_inversepad_heatmap = nn.ReflectionPad2d((0, 0, 0, -3))
    else:
        print('test size256....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.fullsize:
        # note: these heatmap folders are only created in fullsize mode,
        # though the cv2.imwrite calls below always target them
        check_folder(os.path.join(args.save_path, "refl_heapmapin"))
        check_folder(os.path.join(args.save_path, "refl_heapmapout"))

    ToPIL = transforms.ToPILImage()

    model.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            if args.fullsize:
                input_g = tmp_pad(input_g)
            # NOTE: the heatmap path below requires --refl_multi_size true
            # (and --heatmap true); with the defaults, `heapmap` is undefined.
            if args.refl_multi_size:
                albedo_fake, _, heapmap = model.forward(input_g)
            else:
                albedo_fake = model.forward(input_g)
            if args.fullsize:
                input_g = tmp_inversepad(input_g)
                heapmap[0] = tmp_inversepad_heatmap(heapmap[0])
                heapmap[1] = tmp_inversepad_heatmap(heapmap[1])

            # albedo_fake  = albedo_fake*mask_g

            # lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1,2,0)
            # lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1,2,0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)
            print(heapmap[0].squeeze().size())
            heapmap[0] = torch.sum(heapmap[0], dim=1, keepdim=True)
            heapmap[1] = torch.sum(heapmap[1], dim=1, keepdim=True)
            heapmapin = tensor2numpy(heapmap[0][0])
            heapmapout = tensor2numpy(heapmap[1][0])
            #heapmapout = torch.sum(heapmap[1].squeeze(), dim=0, keepdim=True).cpu().clamp(0,1).numpy().transpose(1,2,0)
            print(heapmapin.shape)
            heapmapin = cam(heapmapin)
            print(heapmapin.shape)
            heapmapout = cam(heapmapout)
            # heapmapin = heapmapin.transpose(2,0,1)
            # heapmapout = heapmapout.transpose(2,0,1)
            # input_g = input_g.squeeze().cpu().clamp(0, 1).numpy()
            # print(heapmapin.shape)
            # print(input_g.shape)
            # heapmapin = np.concatenate((heapmapin, input_g), 1).astype(np.float32)
            # heapmapout = np.concatenate((heapmapout, input_g), 1).astype(np.float32)
            # print(heapmapin.shape)
            # heapmapin = torch.from_numpy(heapmapin)
            # heapmapout = torch.from_numpy(heapmapout)
            # lab_refl_targ = ToPIL(input_g.squeeze())
            # refl_pred = ToPIL(albedo_fake.squeeze())

            # heapmapin = torch.cat([heapmapin, torch.zeros(2, h // 4, w // 4)])
            # heapmapout = torch.cat([heapmapout, torch.zeros(2, h // 4, w // 4)])

            # print(heapmapin.size)
            # print(heapmapout.size)
            cv2.imwrite(
                os.path.join(args.save_path, "refl_heapmapin",
                             '{}.png'.format(ind)), heapmapin * 255.0)
            cv2.imwrite(
                os.path.join(args.save_path, "refl_heapmapout",
                             '{}.png'.format(ind)), heapmapout * 255.0)
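
# `tensor2numpy` and `cam` are used above but not defined in this listing.
# Minimal sketches consistent with their call sites (and with the
# commented-out line that mirrors tensor2numpy); assumptions, not the
# repo's actual implementations:
import numpy as np
import cv2

def tensor2numpy(t):
    """CxHxW tensor -> HxWxC float array clamped to [0, 1]."""
    return t.cpu().clamp(0, 1).numpy().transpose(1, 2, 0)

def cam(x):
    """Render a single-channel map as a JET heatmap in [0, 1] (BGR order)."""
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)
    heat = cv2.applyColorMap(np.uint8(255 * x.squeeze()), cv2.COLORMAP_JET)
    return heat.astype(np.float32) / 255.0
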
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='MPI_logs_vqvae\\vqvae_base_256x256\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_bn', type=StrToBool, default=True)
    parser.add_argument('--shad_bn', type=StrToBool, default=True)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    # parser.add_argument('--refl_gan',           type=StrToBool,  default=False)
    # parser.add_argument('--shad_gan',           type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--ttur', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=True)
    parser.add_argument('--vq_flag', type=StrToBool, default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    # if not args.ttur:
    #     Discriminator_R = RIN.SEUG_Discriminator().to(device)
    #     Discriminator_S = RIN.SEUG_Discriminator().to(device)
    # else:
    #     Discriminator_R = RIN.SEUG_Discriminator_new().to(device)
    #     Discriminator_S = RIN.SEUG_Discriminator_new().to(device)

    if args.data_augmentation:
        print('data_augmentation.....')
        if args.fullsize:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-train.txt'
        else:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-320-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-320-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
        MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'

    if args.fullsize_test:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
    else:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'

    if args.fullsize_test:
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
    else:
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    if args.data_augmentation:
        print('augmentation...')
        train_transform = RIN_pipeline.MPI_Train_Agumentation_fy2()

    train_set = RIN_pipeline.MPI_Dataset_Revisit(
        train_txt,
        transform=train_transform if args.data_augmentation else None,
        refl_multi_size=args.refl_multi_size,
        shad_multi_size=args.shad_multi_size,
        image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)

    # VQVAETrainer receives args (including args.ttur), so both TTUR and
    # non-TTUR runs construct the same trainer
    trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device,
                                        writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        # if (epoch + 1) % 10 == 0:
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))
        if albedo_test_loss < best_albedo_loss:
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state_{}.t7'.format(epoch)))
            best_albedo_loss = albedo_test_loss
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--state_dict',
                        type=str,
                        default='composer_state_179.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    composer.load_state_dict(
        torch.load(os.path.join(args.save_path, args.state_dict)))
    print('load checkpoint success!')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
    check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
    check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
    check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
    check_folder(os.path.join(args.save_path, "shape_output_fullsize"))
    check_folder(os.path.join(args.save_path, "mask_fullsize"))

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            lab_refl_pred = np.zeros(
                (input_g.size()[2], input_g.size()[3], input_g.size()[1]))
            lab_sha_pred = np.zeros_like(lab_refl_pred)
            lab_shape_pred = np.zeros_like(lab_refl_pred)
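            # Tile the 436x1024 frame into eight 256x256 crops: rows 0:256
            # (even ind2) and rows 180:436 (odd ind2), across four 256-wide
            # column bands; rows 180:256 are covered twice and averaged below.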
            for ind2 in range(8):
                if ind2 % 2 == 0:
                    input_g_tmp = input_g[:, :, :256, (ind2 // 2) *
                                          256:(ind2 // 2 + 1) * 256]
                    mask_g_tmp = mask_g[:, :, :256, (ind2 // 2) *
                                        256:(ind2 // 2 + 1) * 256]
                    _, albedo_fake, shading_fake, shape_fake = composer.forward(
                        input_g_tmp)
                    albedo_fake = albedo_fake * mask_g_tmp
                    lab_refl_pred[:256, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                  256, :] += albedo_fake.squeeze().cpu().numpy(
                                  ).transpose(1, 2, 0)
                    lab_sha_pred[:256, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                 256, :] += shading_fake.squeeze().cpu().numpy(
                                 ).transpose(1, 2, 0)
                    lab_shape_pred[:256, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                   256, :] += shape_fake.squeeze().cpu().numpy(
                                   ).transpose(1, 2, 0)
                else:
                    input_g_tmp = input_g[:, :, 180:, (ind2 // 2) *
                                          256:(ind2 // 2 + 1) * 256]
                    mask_g_tmp = mask_g[:, :, 180:, (ind2 // 2) *
                                        256:(ind2 // 2 + 1) * 256]
                    _, albedo_fake, shading_fake, shape_fake = composer.forward(
                        input_g_tmp)
                    albedo_fake = albedo_fake * mask_g_tmp
                    lab_refl_pred[180:, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                  256, :] += albedo_fake.squeeze().cpu().numpy(
                                  ).transpose(1, 2, 0)
                    lab_sha_pred[180:, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                 256, :] += shading_fake.squeeze().cpu().numpy(
                                 ).transpose(1, 2, 0)
                    lab_shape_pred[180:, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                   256, :] += shape_fake.squeeze().cpu().numpy(
                                   ).transpose(1, 2, 0)

            lab_refl_pred[180:256, :, :] /= 2
            lab_sha_pred[180:256, :, :] /= 2
            lab_shape_pred[180:256, :, :] /= 2

            lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            mask = mask_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # shape_pred = shape_fake.squeeze().cpu().numpy().transpose(1,2,0)
            refl_pred = lab_refl_pred
            sha_pred = lab_sha_pred
            shape_pred = lab_shape_pred

            lab_refl_targ = np.clip(lab_refl_targ, 0, 1)
            lab_sha_targ = np.clip(lab_sha_targ, 0, 1)
            refl_pred = np.clip(refl_pred, 0, 1)
            sha_pred = np.clip(sha_pred, 0, 1)
            shape_pred = np.clip(shape_pred, 0, 1)
            mask = np.clip(mask, 0, 1)

            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_target_fullsize",
                             "{}.png".format(ind)), lab_refl_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_target_fullsize",
                             "{}.png".format(ind)), lab_sha_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "mask_fullsize",
                             "{}.png".format(ind)), mask)
            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_output_fullsize",
                             "{}.png".format(ind)), refl_pred)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_output_fullsize",
                             "{}.png".format(ind)), sha_pred)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shape_output_fullsize",
                             "{}.png".format(ind)), shape_pred)