Example #1
def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_log_paper\\GAN_RIID_updateLR3_epoch100_CosbfVGG_SceneSplit_refl-se-skip_shad-se-low_multi_new_shadSqueeze_256_Reduction2\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,  # DataLoader's num_workers must be an int
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state_25.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state_60.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=True)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=True)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=True)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=True)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=True)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_reduction', type=StrToInt, default=2)
    parser.add_argument('--shad_reduction', type=StrToInt, default=2)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    args = parser.parse_args()

    device = torch.device(args.cuda)
    reflectance = RIN.SEDecomposerSingle(
        multi_size=args.refl_multi_size,
        low_se=args.refl_low_se,
        skip_se=args.refl_skip_se,
        detach=args.refl_detach_flag,
        reduction=args.refl_reduction).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size,
                                     low_se=args.shad_low_se,
                                     skip_se=args.shad_skip_se,
                                     se_squeeze=args.shad_squeeze_flag,
                                     reduction=args.shad_reduction,
                                     detach=args.shad_detach_flag).to(device)
    reflectance.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    shading.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.shad_checkpoint,
                         args.state_dict_shad)))
    print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    if args.fullsize:
        print('test fullsize....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
        h, w = 436, 1024
        pad_h, pad_w = clc_pad(h, w, 16)
        print(pad_h, pad_w)
        tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
        tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
    else:
        print('test size256....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.fullsize:
        check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
        check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
        check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
        check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
    else:
        check_folder(os.path.join(args.save_path, "refl_target"))
        check_folder(os.path.join(args.save_path, "shad_target"))
        check_folder(os.path.join(args.save_path, "refl_output"))
        check_folder(os.path.join(args.save_path, "shad_output"))

    ToPIL = transforms.ToPILImage()

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            if args.fullsize:
                input_g = tmp_pad(input_g)
            if args.refl_multi_size and args.shad_multi_size:
                albedo_fake, shading_fake, _, _ = composer.forward(input_g)
            elif args.refl_multi_size or args.shad_multi_size:
                albedo_fake, shading_fake, _ = composer.forward(input_g)
            else:
                albedo_fake, shading_fake = composer.forward(input_g)
            if args.fullsize:
                albedo_fake, shading_fake = tmp_inversepad(
                    albedo_fake), tmp_inversepad(shading_fake)

            albedo_fake = albedo_fake * mask_g

            # lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1,2,0)
            # lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1,2,0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)

            albedo_fake = albedo_fake.cpu().clamp(0, 1)
            shading_fake = shading_fake.cpu().clamp(0, 1)
            albedo_g = albedo_g.cpu().clamp(0, 1)
            shading_g = shading_g.cpu().clamp(0, 1)

            lab_refl_targ = ToPIL(albedo_g.squeeze())
            lab_sha_targ = ToPIL(shading_g.squeeze())
            refl_pred = ToPIL(albedo_fake.squeeze())
            sha_pred = ToPIL(shading_fake.squeeze())

            lab_refl_targ.save(
                os.path.join(
                    args.save_path,
                    "refl_target_fullsize" if args.fullsize else "refl_target",
                    "{}.png".format(ind)))
            lab_sha_targ.save(
                os.path.join(
                    args.save_path,
                    "shad_target_fullsize" if args.fullsize else "shad_target",
                    "{}.png".format(ind)))
            refl_pred.save(
                os.path.join(
                    args.save_path,
                    "refl_output_fullsize" if args.fullsize else "refl_output",
                    "{}.png".format(ind)))
            sha_pred.save(
                os.path.join(
                    args.save_path,
                    "shad_output_fullsize" if args.fullsize else "shad_output",
                    "{}.png".format(ind)))
Example #2
def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\BOLD_dataset',
                        help='base folder of datasets')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='logs\\lihao\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=4,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,  # type=bool treats any non-empty string as True
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--supervision_set_size', type=int, default=33900)
    # parser.add_argument('--unsupervision_set_size',     type=int,   default=10170)
    parser.add_argument('--num_epochs', type=int, default=80)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--img_resize_shape', nargs=2, type=int, default=(256, 256))  # was type=str, which would break img_size for CLI-supplied values
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    parser.add_argument('--dataset', type=str, default='BOLD')
    parser.add_argument('--remove_names',
                        type=str,
                        default='F:\\ShapeNet\\remove')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101

    shader = RIN.Shader()
    # shader.load_state_dict(torch.load('logs/shader/shader_state_59.t7'))
    decomposer = RIN.Decomposer()
    # reflection.load_state_dict(torch.load('reflection_state.t7'))
    composer = RIN.Composer(decomposer, shader).to(device)
    # RIN.init_weights(composer, init_type='kaiming')

    cur_epoch = 0
    if args.checkpoint:
        # cur_epoch = int(args.state_dict.split('.')[0].split('_')[-1])
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')

    if args.mode == 'train':
        if args.dataset == "ShapeNet":
            remove_names = os.listdir(args.remove_names)
            supervision_train_set = RIN_pipeline.ShapeNet_Dateset_new(
                args.data_path,
                size_per_dataset=args.supervision_set_size,
                mode='train',
                img_size=args.img_resize_shape,
                remove_names=remove_names)
            sv_train_loader = torch.utils.data.DataLoader(
                supervision_train_set,
                batch_size=args.batch_size,
                num_workers=args.loader_threads,
                shuffle=True)
            test_set = RIN_pipeline.ShapeNet_Dateset_new(
                args.data_path,
                size_per_dataset=20,
                mode='test',
                img_size=args.img_resize_shape,
                remove_names=remove_names)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)
        else:
            train_txt = 'lihao/train_list.txt'
            test_txt = 'lihao/test_list.txt'
            supervision_train_set = RIN_pipeline.BOLD_Dataset(
                args.data_path,
                size_per_dataset=args.supervision_set_size,
                mode='train',
                img_size=args.img_resize_shape,
                file_name=train_txt)
            sv_train_loader = torch.utils.data.DataLoader(
                supervision_train_set,
                batch_size=args.batch_size,
                num_workers=args.loader_threads,
                shuffle=True)
            test_set = RIN_pipeline.BOLD_Dataset(
                args.data_path,
                size_per_dataset=20,
                mode='test',
                img_size=args.img_resize_shape,
                file_name=test_txt)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)

        writer = SummaryWriter(log_dir=args.save_path)

        step = 0
        trainer = RIN_pipeline.ShapeNetSupervisionTrainer(
            composer,
            sv_train_loader,
            args.lr,
            device,
            writer,
            step,
            optim_choose=args.optimizer)
        for epoch in range(cur_epoch + 1, args.num_epochs):
            print('<Main> Epoch {}'.format(epoch))
            if epoch % 10 == 0:
                trainer.update_lr(args.lr * 0.75)
            if epoch < 100:
                step = trainer.train()
            # else:
            #     trainer = RIN_pipeline.UnsupervisionTrainer(composer, unsv_train_loader, args.lr, device, writer, new_step)
            #     new_step = trainer.train()

            if args.save_model:
                state = composer.state_dict()
                torch.save(state,
                           os.path.join(args.save_path, 'composer_state.t7'))

            # step += new_step
            # loss = RIN_pipeline.visualize_composer(composer, test_loader, device, os.path.join(args.save_path, '{}.png'.format(epoch)))
            # writer.add_scalar('test_recon_loss', loss[0], epoch)
            # writer.add_scalar('test_refl_loss', loss[1], epoch)
            # writer.add_scalar('test_sha_loss', loss[2], epoch)
    else:
        check_folder(os.path.join(args.save_path, "refl_target"))
        check_folder(os.path.join(args.save_path, "shad_target"))
        check_folder(os.path.join(args.save_path, "refl_output"))
        check_folder(os.path.join(args.save_path, "shad_output"))
        check_folder(os.path.join(args.save_path, "shape_output"))
        if args.dataset == "ShapeNet":
            # check_folder(os.path.join(args.save_path, "mask"))
            remove_names = os.listdir(args.remove_names)
            test_set = RIN_pipeline.ShapeNet_Dateset_new(
                args.data_path,
                size_per_dataset=9488,
                mode='test',
                img_size=args.img_resize_shape,
                remove_names=remove_names)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)
        else:
            test_txt = 'lihao/test_list.txt'
            test_set = RIN_pipeline.BOLD_Dataset(
                args.data_path,
                size_per_dataset=18984,
                mode='test',
                img_size=args.img_resize_shape,
                file_name=test_txt)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)

        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        composer.eval()
        with torch.no_grad():
            for ind, tensors in enumerate(test_loader):

                inp = [t.float().to(device) for t in tensors]
                if len(inp) == 4:
                    lab_inp, lab_refl_targ, lab_sha_targ, mask = inp
                elif len(inp) == 3:
                    lab_inp, lab_refl_targ, lab_sha_targ = inp
                else:
                    raise ValueError('input dim should be 3 or 4')
                recon_pred, refl_pred, sha_pred, shape_pred = composer.forward(
                    lab_inp)

                lab_refl_targ = lab_refl_targ.squeeze().cpu().numpy().transpose(1, 2, 0)
                lab_sha_targ = lab_sha_targ.squeeze().cpu().numpy().transpose(1, 2, 0)
                # mask = mask.squeeze().cpu().numpy().transpose(1,2,0)
                refl_pred = refl_pred.squeeze().cpu().numpy().transpose(1, 2, 0)
                sha_pred = sha_pred.squeeze().cpu().numpy().transpose(1, 2, 0)
                shape_pred = shape_pred.squeeze().cpu().numpy().transpose(1, 2, 0)
                lab_refl_targ = np.clip(lab_refl_targ, 0, 1)
                lab_sha_targ = np.clip(lab_sha_targ, 0, 1)
                refl_pred = np.clip(refl_pred, 0, 1)
                sha_pred = np.clip(sha_pred, 0, 1)
                shape_pred = np.clip(shape_pred, 0, 1)
                # mask = np.clip(mask, 0, 1)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "refl_target",
                                 "{}.png".format(ind)), lab_refl_targ)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "shad_target",
                                 "{}.png".format(ind)), lab_sha_targ)
                # scipy.misc.imsave(os.path.join(args.save_path, "mask", "{}.png".format(ind)), mask)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "refl_output",
                                 "{}.png".format(ind)), refl_pred)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "shad_output",
                                 "{}.png".format(ind)), sha_pred)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "shape_output",
                                 "{}.png".format(ind)), shape_pred)
Example #3
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device("cuda: 1" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)
    # shader = RIN.Shader()
    # decomposer = RIN.Decomposer()
    # composer = RIN.Composer(decomposer, shader).to(device)
    # RIN.init_weights(composer, init_type='kaiming')

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1

    # train_transform = RIN_pipeline.MPI_Train_Agumentation()
    # train_set = RIN_pipeline.MPI_Dataset(args.data_path, mode=args.mode[0], transform=train_transform)
    # train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    # test_transform = RIN_pipeline.MPI_Test_Agumentation()
    # test_set = RIN_pipeline.MPI_Dataset(args.data_path, mode=args.mode[1], transform=test_transform)
    # test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=args.loader_threads, shuffle=True)

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=True)

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.MPI_TrainerOriginRemoveShape(
        composer, train_loader, args.lr, device, writer)

    best_average_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        step = trainer.train()

        if (epoch + 1) % 120 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

            if args.save_model:
                state = composer.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_state_{}.t7'.format(epoch)))

        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_remove_shape(
            composer, test_loader, device)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))

        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_state_{}.t7'.format(epoch)))
            #RIN_pipeline.visualize_MPI(composer, test_loader, device, os.path.join(args.save_path, 'image_{}.png'.format(epoch)))
            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
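This example resumes training by parsing the epoch out of the checkpoint filename. A worked example of that line, assuming a checkpoint named 'composer_state_37.t7':

state_dict = 'composer_state_37.t7'
cur_epoch = int(state_dict.split('_')[-1].split('.')[0]) + 1  # '37' -> 37, so training resumes at epoch 38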
Example #4
def main():
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)
    # cudnn.benchmark = True
    # cudnn.deterministic = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_path',          type=str,   default='IIW_logs\\RIID_new_RIN_updateLR1_epoch240\\',
    help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',                 type=float, default=0.0005,
    help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
    help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool,  default=True,
    help='whether to save model or not')
    parser.add_argument('--num_epochs',         type=int,   default=40)
    parser.add_argument('--batch_size',         type=int,   default=1)
    parser.add_argument('--checkpoint',         type=StrToBool,  default=False)
    parser.add_argument('--state_dict',         type=str,   default='composer_state.t7')
    parser.add_argument('--cuda',               type=str,   default='cuda')
    parser.add_argument('--image_size',         type=StrToInt, default=256)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    composer = RIN.SEDecomposerSingle().to(device)

    IIW_train_txt = 'F:\\revisit_IID\\iiw-dataset\\iiw_Learning_Lightness_train.txt'
    IIW_test_txt = 'F:\\revisit_IID\\iiw-dataset\\iiw_Learning_Lightness_test.txt'

    train_set = RIN_pipeline.IIW_Dataset_Revisit(IIW_train_txt)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    test_set = RIN_pipeline.IIW_Dataset_Revisit(IIW_test_txt, out_mode='txt')
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)
    
    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.IIWTrainer(composer, train_loader, device, writer, args)

    best_score = 9999

    for epoch in range(args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        trainer.train()

        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
            
        score = RIN_pipeline.IIW_test_unet(composer, test_loader, device)
        writer.add_scalar('score', score, epoch)

        with open(os.path.join(args.save_path, 'score.txt'), 'a+') as f:
            f.write('epoch{} --- score: {}\n'.format(epoch, score))

        if score < best_score:
            best_score = score
            if args.save_model:
                state = composer.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_state.t7'))
            with open(os.path.join(args.save_path, 'score_best.txt'), 'a+') as f:
                f.write('epoch{} --- score:{}\n'.format(epoch, score))
Example #5
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',          type=str,   default='F:\\MPI-sintel-PyrResNet\\Sintel\\images\\',
    help='base folder of datasets')
    parser.add_argument('--split',              type=str,   default='SceneSplit')
    parser.add_argument('--mode',               type=str,   default='train')
    parser.add_argument('--save_path',          type=str,   default='MPI_logs\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_GAN_selayer1_ReflMultiSize_DA\\',
    help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',                 type=float, default=0.0005,
    help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
    help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool,  default=True,
    help='whether to save model or not')
    parser.add_argument('--num_epochs',         type=int,   default=1000)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool,  default=False)
    parser.add_argument('--state_dict',         type=str,   default='composer_state.t7')
    parser.add_argument('--skip_se',            type=StrToBool, default=False)
    parser.add_argument('--cuda',               type=str,   default='cuda:1')
    parser.add_argument('--dilation',           type=StrToBool,   default=False)
    parser.add_argument('--se_improved',        type=StrToBool,  default=False)
    parser.add_argument('--weight_decay',       type=float, default=0.0001)
    parser.add_argument('--refl_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation',  type=StrToBool,  default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    # shader = RIN.Shader(output_ch=3)
    print(args.skip_se)
    Generator_R = RIN.SESingleGenerator(multi_size=args.refl_multi_size).to(device)
    Generator_S = RIN.SESingleGenerator(multi_size=args.shad_multi_size).to(device)
    composer = RIN.SEComposerGenerater(Generator_R, Generator_S, args.refl_multi_size, args.shad_multi_size).to(device)
    Discriminator_R = RIN.SEUG_Discriminator().to(device)
    Discriminator_S = RIN.SEUG_Discriminator().to(device)
    # composer = RIN.Composer(reflection, shader).to(device)

    
    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    if args.data_augmentation:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        # cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1
    if args.data_augmentation:
        train_transform = RIN_pipeline.MPI_Train_Agumentation()
    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt, transform=train_transform if args.data_augmentation else None, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)
    
    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R, Discriminator_S, train_loader, args.lr, device, writer, weight_decay=args.weight_decay, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        step = trainer.train()

        if (epoch + 1) % 100 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
            
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))

        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))
        
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))
Example #6
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='one')
    parser.add_argument('--choose', type=str, default='refl')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_mish_ranger_Octave_one_bs10_updateLR1_epoch240_CosBF_VGG0.1_refl_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=StrToInt, default=10)
    parser.add_argument('--checkpoint', type=StrToBool, default=True)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')

    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device("cuda: 1" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    # shader = RIN.Shader(output_ch=3)
    composer = octave.get_mish_model_one_output().to(device)
    # composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 119
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        # cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.OctaveTrainer(composer,
                                         train_loader,
                                         args.lr,
                                         device,
                                         writer,
                                         mode=args.mode,
                                         choose=args.choose)

    best_average_loss = 0.01970498738396499

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        step = trainer.train()

        if (epoch + 1) % 120 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        if args.mode == 'two':
            albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
                composer, test_loader, device)
            average_loss = (albedo_test_loss + shading_test_loss) / 2
            writer.add_scalar('A_mse', albedo_test_loss, epoch)
            writer.add_scalar('S_mse', shading_test_loss, epoch)
            writer.add_scalar('aver_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
        else:
            average_loss = RIN_pipeline.MPI_test_unet_one(composer,
                                                          test_loader,
                                                          device,
                                                          choose=args.choose)
            writer.add_scalar('test_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write('epoch{} --- test_loss: {}\n'.format(
                    epoch, average_loss))

        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(state,
                           os.path.join(args.save_path, 'composer_state.t7'))

            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                if args.mode == 'two':
                    f.write(
                        'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                        .format(epoch, average_loss, albedo_test_loss,
                                shading_test_loss))
                else:
                    f.write('epoch{} --- average_loss: {}\n'.format(
                        epoch, average_loss))
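Several of these trainers expose an update_lr method for the manual decay schedule above. Its internals are not shown; a minimal sketch of the usual pattern, written as a free function over a torch optimizer:

def update_lr(optimizer, lr):
    # overwrite the learning rate on every parameter group
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr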
Example #7
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',          type=str,   default='F:\\sintel',
    help='base folder of datasets')
    parser.add_argument('--split',              type=str,   default='SceneSplit')
    parser.add_argument('--mode',               type=str,   default='two')
    parser.add_argument('--save_path',          type=str,   default='MPI_logs_new\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_selayer1_reflmultiSize\\',
    help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr',                 type=float, default=0.0005,
    help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
    help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool,  default=True,
    help='whether to save model or not')
    parser.add_argument('--num_epochs',         type=int,   default=120)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool,  default=False)
    parser.add_argument('--state_dict',         type=str,   default='composer_state.t7')
    parser.add_argument('--cuda',               type=str,   default='cuda')
    parser.add_argument('--choose',             type=str,   default='refl')
    parser.add_argument('--refl_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--shad_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--refl_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--shad_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--refl_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--refl_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--shad_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--refl_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--image_size',         type=StrToInt, default=256)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size, low_se=args.refl_low_se, skip_se=args.refl_skip_se).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size, low_se=args.shad_low_se, skip_se=args.shad_skip_se).to(device)
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size, args.shad_multi_size).to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)
    
    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.OctaveTrainer(composer, train_loader, device, writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        trainer.train()

        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
            
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)

        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))

        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))
        
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(state, os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))
Example #8
def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs_new\\GAN_RIID_updateLR3_epoch160_CosbfVGG_SceneSplit_refl-se-skip_shad-se-low_multi_new_shadSqueeze_grad\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state_81.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state_81.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--heatmap', type=StrToBool, default=False)
    args = parser.parse_args()

    device = torch.device(args.cuda)
    model = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size,
                                   low_se=args.refl_low_se,
                                   skip_se=args.refl_skip_se,
                                   detach=args.refl_detach_flag,
                                   reduction=args.refl_reduction,
                                   heatmap=args.heatmap).to(device)
    model.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    print('load checkpoint success!')

    if args.fullsize:
        print('test fullsize....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
        h, w = 436, 1024
        pad_h, pad_w = clc_pad(h, w, 16)
        print(pad_h, pad_w)
        tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
        tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
        tmp_inversepad_heatmap = nn.ReflectionPad2d((0, 0, 0, -3))  # heatmaps come out at 1/4 resolution, so the 12-row input pad corresponds to 3 rows here
    else:
        print('test size256....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.fullsize:
        check_folder(os.path.join(args.save_path, "refl_heapmapin"))
        check_folder(os.path.join(args.save_path, "refl_heapmapout"))

    ToPIL = transforms.ToPILImage()

    model.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            if args.fullsize:
                input_g = tmp_pad(input_g)
            if args.refl_multi_size:
                albedo_fake, _, heapmap = model.forward(input_g)
            else:
                # the single-output model returns no heatmaps; the code below
                # would otherwise fail on the undefined name `heapmap`
                raise RuntimeError('heatmap dumping requires --refl_multi_size true')
            if args.fullsize:
                input_g = tmp_inversepad(input_g)
                heapmap[0] = tmp_inversepad_heatmap(heapmap[0])
                heapmap[1] = tmp_inversepad_heatmap(heapmap[1])

            # albedo_fake  = albedo_fake*mask_g

            # lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1,2,0)
            # lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1,2,0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)
            print(heapmap[0].squeeze().size())
            heapmap[0] = torch.sum(heapmap[0], dim=1, keepdim=True)
            heapmap[1] = torch.sum(heapmap[1], dim=1, keepdim=True)
            heapmapin = tensor2numpy(heapmap[0][0])
            heapmapout = tensor2numpy(heapmap[1][0])
            #heapmapout = torch.sum(heapmap[1].squeeze(), dim=0, keepdim=True).cpu().clamp(0,1).numpy().transpose(1,2,0)
            print(heapmapin.shape)
            heapmapin = cam(heapmapin)
            print(heapmapin.shape)
            heapmapout = cam(heapmapout)
            # heapmapin = heapmapin.transpose(2,0,1)
            # heapmapout = heapmapout.transpose(2,0,1)
            # input_g = input_g.squeeze().cpu().clamp(0, 1).numpy()
            # print(heapmapin.shape)
            # print(input_g.shape)
            # heapmapin = np.concatenate((heapmapin, input_g), 1).astype(np.float32)
            # heapmapout = np.concatenate((heapmapout, input_g), 1).astype(np.float32)
            # print(heapmapin.shape)
            # heapmapin = torch.from_numpy(heapmapin)
            # heapmapout = torch.from_numpy(heapmapout)
            # lab_refl_targ = ToPIL(input_g.squeeze())
            # refl_pred = ToPIL(albedo_fake.squeeze())

            # heapmapin = torch.cat([heapmapin, torch.zeros(2, h // 4, w // 4)])
            # heapmapout = torch.cat([heapmapout, torch.zeros(2, h // 4, w // 4)])

            # print(heapmapin.size)
            # print(heapmapout.size)
            cv2.imwrite(
                os.path.join(args.save_path, "refl_heapmapin",
                             '{}.png'.format(ind)), heapmapin * 255.0)
            cv2.imwrite(
                os.path.join(args.save_path, "refl_heapmapout",
                             '{}.png'.format(ind)), heapmapout * 255.0)
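Example #8 also depends on tensor2numpy and cam, which are not shown. A sketch consistent with the shapes printed in the loop, assuming cam min-max normalizes a single-channel map and colorizes it with an OpenCV colormap (an assumption, not the original code):

import cv2
import numpy as np

def tensor2numpy(tensor):
    # CHW tensor -> HWC numpy array on the CPU, clamped to [0, 1]
    return tensor.cpu().clamp(0, 1).numpy().transpose(1, 2, 0)

def cam(heatmap):
    # min-max normalize, then map to a BGR heatmap in [0, 1]
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    colored = cv2.applyColorMap(np.uint8(255 * heatmap.squeeze()), cv2.COLORMAP_JET)
    return colored.astype(np.float32) / 255.0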
Example #9
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='ImageSplit')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_epoch240_ImageSplit_size256_remove_shape\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--state_dict',
                        type=str,
                        default='composer_state_222.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)
    # RIN.init_weights(composer, init_type='kaiming')
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    composer.load_state_dict(
        torch.load(os.path.join(args.save_path, args.state_dict)))
    print('load checkpoint success!')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    check_folder(os.path.join(args.save_path, "refl_target"))
    check_folder(os.path.join(args.save_path, "shad_target"))
    check_folder(os.path.join(args.save_path, "refl_output"))
    check_folder(os.path.join(args.save_path, "shad_output"))
    # check_folder(os.path.join(args.save_path, "shape_output"))
    check_folder(os.path.join(args.save_path, "mask"))

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            _, albedo_fake, shading_fake = composer.forward(input_g)
            albedo_fake = albedo_fake * mask_g

            lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            mask = mask_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1, 2, 0)
            sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1, 2, 0)

            lab_refl_targ = np.clip(lab_refl_targ, 0, 1)
            lab_sha_targ = np.clip(lab_sha_targ, 0, 1)
            refl_pred = np.clip(refl_pred, 0, 1)
            sha_pred = np.clip(sha_pred, 0, 1)
            mask = np.clip(mask, 0, 1)

            # scipy.misc.imsave was deprecated in SciPy 1.0 and removed
            # in 1.2; imageio.imwrite is the usual drop-in replacement
            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_target",
                             "{}.png".format(ind)), lab_refl_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_target",
                             "{}.png".format(ind)), lab_sha_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "mask", "{}.png".format(ind)),
                mask)
            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_output",
                             "{}.png".format(ind)), refl_pred)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_output",
                             "{}.png".format(ind)), sha_pred)
def main():
    # random.seed(6666)
    # torch.manual_seed(6666)
    # torch.cuda.manual_seed(6666)
    # np.random.seed(6666)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'logs_vqvae\\MIT_base_256x256_noRetinex_withBf_ALLleakyrelu_BNUP_Sigmiod_inception_bs4_finetune\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=StrToInt, default=4)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--vae', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--vq_flag', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--shad_out_conv', type=StrToInt, default=3)
    parser.add_argument('--finetune', type=StrToBool, default=True)
    parser.add_argument('--use_tanh', type=StrToBool, default=False)
    parser.add_argument('--use_inception', type=StrToBool, default=True)
    parser.add_argument('--use_skip', type=StrToBool, default=True)
    parser.add_argument('--use_multiPredict', type=StrToBool, default=True)
    parser.add_argument('--init_weights', type=StrToBool, default=False)
    parser.add_argument('--adam_flag', type=StrToBool, default=False)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    device = torch.device(args.cuda)

    reflectance = RIN.VQVAE(vq_flag=args.vq_flag,
                            init_weights=args.init_weights,
                            use_tanh=args.use_tanh,
                            use_inception=args.use_inception,
                            use_skip=args.use_skip,
                            use_multiPredict=args.use_multiPredict).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag,
                        init_weights=args.init_weights,
                        use_tanh=args.use_tanh,
                        use_inception=args.use_inception,
                        use_skip=args.use_skip,
                        use_multiPredict=args.use_multiPredict).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    MIT_train_fullsize_txt = 'MIT_TXT\\MIT_BarronSplit_fullsize_train.txt'
    MIT_test_fullsize_txt = 'MIT_TXT\\MIT_BarronSplit_fullsize_test.txt'
    MIT_train_txt = 'MIT_TXT\\MIT_BarronSplit_train.txt'
    MIT_test_txt = 'MIT_TXT\\MIT_BarronSplit_test.txt'
    if args.fullsize and not args.finetune:
        # train_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_train_fullsize_txt, mode='train', refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size, fullsize=args.fullsize)
        train_set = RIN_pipeline.MIT_Dataset_Revisit(
            MIT_train_fullsize_txt,
            mode='train',
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size,
            image_size=args.image_size)
    else:
        train_set = RIN_pipeline.MIT_Dataset_Revisit(
            MIT_train_txt,
            mode='train',
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size,
            image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_test_fullsize_txt,
                                                mode='test')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)
    best_albedo_loss = 9999
    best_shading_loss = 9999
    best_avg_lmse = 9999
    flag = True
    trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device, writer,
                                        args)
    logging.info('start training....')
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if epoch >= 80 and args.finetune and flag:
            flag = False
            train_set = RIN_pipeline.MIT_Dataset_Revisit(
                MIT_train_fullsize_txt,
                mode='train',
                refl_multi_size=args.refl_multi_size,
                shad_multi_size=args.shad_multi_size,
                image_size=args.image_size,
                fullsize=args.fullsize)
            train_loader = torch.utils.data.DataLoader(
                train_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=True)
            trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device,
                                                writer, args)
            # else:
            #     flag = True
            #     train_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_train_txt, mode='train', refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size)
            #     train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)
            #     trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device, writer, args)

        albedo_test_loss, shading_test_loss = RIN_pipeline.MIT_test_unet(
            composer, test_loader, device, args)

        if (epoch + 1) % 40 == 0:
            args.lr *= 0.75
            trainer.update_lr(args.lr)

        average_loss = (albedo_test_loss + shading_test_loss) / 2
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))

        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
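
The manual schedule above (args.lr *= 0.75 every 40 epochs, pushed into the trainer via update_lr) matches PyTorch's built-in StepLR. A minimal sketch, assuming direct access to the optimizer rather than the trainer wrapper:

import torch

model = torch.nn.Linear(8, 8)  # stand-in module for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
# multiply the learning rate by 0.75 every 40 epochs, matching the
# (epoch + 1) % 40 == 0 branch in the loop above
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=40,
                                            gamma=0.75)
for epoch in range(100):
    # ... optimizer.step() calls happen inside the training pass ...
    scheduler.step()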
Example #11
def main():
    random.seed(6666)
    torch.manual_seed(6666)
    torch.cuda.manual_seed(6666)
    np.random.seed(6666)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MIT_logs\\RIID_origin_RIN_updateLR0.0005_4_bf_cosLoss_VGG0.1_400epochs_bs22\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=StrToInt, default=16)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_bn', type=StrToBool, default=True)
    parser.add_argument('--shad_bn', type=StrToBool, default=True)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    # parser.add_argument('--refl_gan',           type=StrToBool,  default=False)
    # parser.add_argument('--shad_gan',           type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=False)
    # parser.add_argument('--fullsize_test',      type=StrToBool,  default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--shad_out_conv', type=StrToInt, default=3)
    parser.add_argument('--finetune', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=False)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    reflectance = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size,
                                         low_se=args.refl_low_se,
                                         skip_se=args.refl_skip_se,
                                         detach=args.refl_detach_flag,
                                         reduction=args.refl_reduction,
                                         bn=args.refl_bn,
                                         act=args.refl_act).to(device)
    shading = RIN.SEDecomposerSingle(
        multi_size=args.shad_multi_size,
        low_se=args.shad_low_se,
        skip_se=args.shad_skip_se,
        se_squeeze=args.shad_squeeze_flag,
        reduction=args.shad_reduction,
        detach=args.shad_detach_flag,
        bn=args.shad_bn,
        act=args.shad_act,
        last_conv_ch=args.shad_out_conv).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)
    # shader = RIN.Shader()
    # reflection = RIN.Decomposer()
    # composer = RIN.Composer(reflection, shader).to(device)
    Discriminator_R = RIN.SEUG_Discriminator().to(device)
    # if args.shad_gan:
    Discriminator_S = RIN.SEUG_Discriminator().to(device)

    MIT_train_fullsize_txt = 'MIT_TXT\\MIT_BarronSplit_fullsize_train.txt'
    MIT_test_fullsize_txt = 'MIT_TXT\\MIT_BarronSplit_fullsize_test.txt'
    MIT_train_txt = 'MIT_TXT\\MIT_BarronSplit_train.txt'
    MIT_test_txt = 'MIT_TXT\\MIT_BarronSplit_test.txt'
    if args.fullsize and not args.finetune:
        train_set = RIN_pipeline.MIT_Dataset_Revisit(
            MIT_train_fullsize_txt,
            mode='train',
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size,
            image_size=args.image_size,
            fullsize=args.fullsize)
    else:
        train_set = RIN_pipeline.MIT_Dataset_Revisit(
            MIT_train_txt,
            mode='train',
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size,
            image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_test_fullsize_txt,
                                                mode='test')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    # test_loader_2 = torch.utils.data.DataLoader(test_set, batch_size=10, num_workers=args.loader_threads, shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)
    best_albedo_loss = 9999
    best_shading_loss = 9999
    best_avg_lmse = 9999
    flag = True
    # trainer = RIN_pipeline.MIT_TrainerOrigin(composer, train_loader, args.lr, device, writer)
    trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R,
                                       Discriminator_S, train_loader, device,
                                       writer, args)
    logging.info('start training....')
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if epoch >= 80 and args.finetune:
            if flag and args.finetune:
                flag = False
                train_set = RIN_pipeline.MIT_Dataset_Revisit(
                    MIT_train_fullsize_txt,
                    mode='train',
                    refl_multi_size=args.refl_multi_size,
                    shad_multi_size=args.shad_multi_size,
                    image_size=args.image_size,
                    fullsize=args.fullsize)
                train_loader = torch.utils.data.DataLoader(
                    train_set,
                    batch_size=1,
                    num_workers=args.loader_threads,
                    shuffle=True)
                trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R,
                                                   Discriminator_S,
                                                   train_loader, device,
                                                   writer, args)

        albedo_test_loss, shading_test_loss = RIN_pipeline.MIT_test_unet(
            composer, test_loader, device, args)
        # albedo_test_loss, shading_test_loss = 0, 0
        # with torch.no_grad():
        #     composer.eval()
        #     criterion = torch.nn.MSELoss(size_average=True).to(device)
        #     for _, labeled in enumerate(test_loader):
        #         labeled = [t.to(device) for t in labeled]
        #         input_g, albedo_g, shading_g, mask_g = labeled
        #         lab_inp_pred, lab_refl_pred, lab_shad_pred, _ = composer.forward(input_g)
        #         lab_inp_pred = lab_inp_pred * mask_g
        #         lab_refl_pred = lab_refl_pred * mask_g
        #         lab_shad_pred = lab_shad_pred * mask_g
        #         refl_loss = criterion(lab_refl_pred, albedo_g)
        #         shad_loss = criterion(lab_shad_pred, shading_g)
        #         # recon_loss = criterion(lab_inp_pred, input_g)
        #         albedo_test_loss = refl_loss.item()
        #         shading_test_loss = shad_loss.item()
        #         writer.add_scalar('test_refl_loss', refl_loss.item(), epoch)
        #         writer.add_scalar('test_shad_loss', shad_loss.item(), epoch)
        # writer.add_scalar('test_recon_loss', recon_loss.item(), epoch)
        # cur_aver_loss = (refl_loss.item() + shad_loss.item()) / 2
        # writer.add_scalar('cur_aver_loss', cur_aver_loss, epoch)
        # if cur_aver_loss < best_loss:
        #     best_loss = cur_aver_loss

        if (epoch + 1) % 40 == 0:
            args.lr *= 0.75
            # logging.info('epoch{} learning rate : {}'.format(epoch, args.lr))
            trainer.update_lr(args.lr)

        average_loss = (albedo_test_loss + shading_test_loss) / 2
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))

        if args.save_model:
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state_{}.t7'.format(epoch)))
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state_{}.t7'.format(epoch)))

        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            # if args.save_model:
            #     state = composer.reflectance.state_dict()
            #     torch.save(state, os.path.join(args.save_path, args.refl_checkpoint, 'composer_reflectance_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            # if args.save_model:
            #     state = composer.shading.state_dict()
            #     torch.save(state, os.path.join(args.save_path, args.shad_checkpoint, 'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
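
For reference, a hypothetical invocation of the script above (the file name train_mit_seug.py is assumed; StrToBool accepts spellings like 'true'/'false'):

python train_mit_seug.py --batch_size 16 --num_epochs 100 --finetune true --refl_bf_flag true --shad_bf_flag true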
Example #12
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='E:\\BOLD',
                        help='base folder of datasets')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='logs_vqvae\\BOLD_base_256x256\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    # the --checkpoint branch below reads these two paths; without them
    # the script raises AttributeError when resuming
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--vq_flag', type=StrToBool, default=False)
    parser.add_argument('--img_resize_shape',
                        type=int,
                        nargs=2,
                        default=(256, 256))
    parser.add_argument('--use_tanh', type=StrToBool, default=False)
    parser.add_argument('--use_inception', type=StrToBool, default=False)
    parser.add_argument('--init_weights', type=StrToBool, default=False)
    parser.add_argument('--adam_flag', type=StrToBool, default=False)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag,
                            init_weights=args.init_weights,
                            use_tanh=args.use_tanh,
                            use_inception=args.use_inception).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag,
                        init_weights=args.init_weights,
                        use_tanh=args.use_tanh,
                        use_inception=args.use_inception).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    # train_txt = "BOLD_TXT\\train_list.txt"
    # test_txt = "BOLD_TXT\\test_list.txt"

    supervision_train_set = RIN_pipeline.BOLD_Dataset(
        args.data_path,
        size_per_dataset=40000,
        mode='train',
        img_size=args.img_resize_shape)
    train_loader = torch.utils.data.DataLoader(supervision_train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.BOLD_Dataset(args.data_path,
                                         size_per_dataset=None,
                                         mode='val',
                                         img_size=args.img_resize_shape)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.BOLDVQVAETrainer(composer, train_loader, device,
                                            writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if (epoch + 1) % 20 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        if (epoch + 1) % 5 == 0:
            albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
                composer, test_loader, device, args)
            average_loss = (albedo_test_loss + shading_test_loss) / 2
            writer.add_scalar('A_mse', albedo_test_loss, epoch)
            writer.add_scalar('S_mse', shading_test_loss, epoch)
            writer.add_scalar('aver_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
            if albedo_test_loss < best_albedo_loss:
                state = composer.reflectance.state_dict()
                torch.save(
                    state,
                    os.path.join(
                        args.save_path, args.refl_checkpoint,
                        'composer_reflectance_state_{}.t7'.format(epoch)))
                best_albedo_loss = albedo_test_loss
                with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                          'a+') as f:
                    f.write('epoch{} --- albedo_loss:{}\n'.format(
                        epoch, albedo_test_loss))
            if shading_test_loss < best_shading_loss:
                best_shading_loss = shading_test_loss
                state = composer.shading.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path, args.shad_checkpoint,
                                 'composer_shading_state_{}.t7'.format(epoch)))
                with open(os.path.join(args.save_path, 'shading_loss.txt'),
                          'a+') as f:
                    f.write('epoch{} --- shading_loss:{}\n'.format(
                        epoch, shading_test_loss))
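
Every trainer here exposes an update_lr method that the loops call after decaying args.lr. A plausible sketch (the self.optimizer attribute name is an assumption):

class TrainerLRMixin:
    def update_lr(self, new_lr):
        # rewrite the learning rate of every optimizer parameter
        # group in place
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr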
Example #13
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='MPI_logs_vqvae\\vqvae_base_256x256\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_bn', type=StrToBool, default=True)
    parser.add_argument('--shad_bn', type=StrToBool, default=True)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    # parser.add_argument('--refl_gan',           type=StrToBool,  default=False)
    # parser.add_argument('--shad_gan',           type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--ttur', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=True)
    parser.add_argument('--vq_flag', type=StrToBool, default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    # if not args.ttur:
    #     Discriminator_R = RIN.SEUG_Discriminator().to(device)
    #     Discriminator_S = RIN.SEUG_Discriminator().to(device)
    # else:
    #     Discriminator_R = RIN.SEUG_Discriminator_new().to(device)
    #     Discriminator_S = RIN.SEUG_Discriminator_new().to(device)

    if args.data_augmentation:
        print('data_augmentation.....')
        if args.fullsize:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-train.txt'
        else:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-320-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-320-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
        MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'

    if args.fullsize_test:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
    else:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'

    if args.fullsize_test:
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
    else:
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    if args.data_augmentation:
        print('augmentation...')
        train_transform = RIN_pipeline.MPI_Train_Agumentation_fy2()

    train_set = RIN_pipeline.MPI_Dataset_Revisit(
        train_txt,
        transform=train_transform if args.data_augmentation else None,
        refl_multi_size=args.refl_multi_size,
        shad_multi_size=args.shad_multi_size,
        image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)

    # both branches of the original ttur check constructed the identical
    # VQVAETrainer, so the flag is a no-op here
    trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device,
                                        writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        # if (epoch + 1) % 10 == 0:
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))
        if albedo_test_loss < best_albedo_loss:
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state_{}.t7'.format(epoch)))
            best_albedo_loss = albedo_test_loss
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
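
These scripts seed random, numpy, and torch but also set cudnn.benchmark = True, which lets cuDNN autotune kernels and can make runs non-deterministic despite the seeding. A sketch of a stricter setup, assuming reproducibility matters more than speed:

import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn


def seed_everything(seed=520):
    # same seeding as the scripts above, plus deterministic cuDNN
    # kernel selection (cudnn.benchmark = True trades this away for
    # autotuned speed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True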
Example #14
def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'logs_vqvae\\MIT_base_256x256_noRetinex_withBf_leakyrelu_BNUP_Sigmiod_inception_bs4_finetune_woMultiPredict\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--shad_out_conv', type=StrToInt, default=3)
    parser.add_argument('--dataset', type=str, default='mit')
    parser.add_argument('--shapenet_g', type=StrToBool, default=False)
    parser.add_argument('--vq_flag', type=StrToBool, default=False)
    parser.add_argument('--use_tanh', type=StrToBool, default=False)
    parser.add_argument('--use_inception', type=StrToBool, default=True)
    parser.add_argument('--use_skip', type=StrToBool, default=True)
    parser.add_argument('--use_multiPredict', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=True)
    args = parser.parse_args()

    device = torch.device(args.cuda)
    if args.vae:
        reflectance = RIN.VQVAE(
            vq_flag=args.vq_flag,
            use_tanh=args.use_tanh,
            use_inception=args.use_inception,
            use_skip=args.use_skip,
            use_multiPredict=args.use_multiPredict).to(device)
        shading = RIN.VQVAE(vq_flag=args.vq_flag,
                            use_tanh=args.use_tanh,
                            use_inception=args.use_inception,
                            use_skip=args.use_skip,
                            use_multiPredict=args.use_multiPredict).to(device)
    else:
        reflectance = RIN.SEDecomposerSingle(
            multi_size=args.refl_multi_size,
            low_se=args.refl_low_se,
            skip_se=args.refl_skip_se,
            detach=args.refl_detach_flag,
            reduction=args.refl_reduction).to(device)
        shading = RIN.SEDecomposerSingle(
            multi_size=args.shad_multi_size,
            low_se=args.shad_low_se,
            skip_se=args.shad_skip_se,
            se_squeeze=args.shad_squeeze_flag,
            reduction=args.shad_reduction,
            detach=args.shad_detach_flag,
            last_conv_ch=args.shad_out_conv).to(device)
    reflectance.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    shading.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.shad_checkpoint,
                         args.state_dict_shad)))
    print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    if args.dataset == 'mit':
        if args.fullsize:
            print('test fullsize....')
            test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MIT_TXT\\MIT_BarronSplit_fullsize_test.txt'
        else:
            print('test size256....')
            test_txt = 'MIT_TXT\\MIT_BarronSplit_test.txt'
        test_set = RIN_pipeline.MIT_Dataset_Revisit(test_txt, mode='test')
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=1,
            num_workers=args.loader_threads,
            shuffle=False)
    else:
        remove_names = os.listdir('F:\\ShapeNet\\remove')
        if args.shapenet_g:
            test_set = RIN_pipeline.ShapeNet_Dateset_new_new(
                'F:\\ShapeNet',
                size_per_dataset=9000,
                mode='test',
                image_size=256,
                remove_names=remove_names,
                shapenet_g=args.shapenet_g)
        else:
            test_set = RIN_pipeline.ShapeNet_Dateset_new_new(
                'F:\\ShapeNet',
                size_per_dataset=9000,
                mode='test',
                image_size=256,
                remove_names=remove_names)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=1,
            num_workers=args.loader_threads,
            shuffle=False)

    if args.shapenet_g:
        check_folder(os.path.join(args.save_path, "refl_target_G"))
        check_folder(os.path.join(args.save_path, "shad_target_G"))
        check_folder(os.path.join(args.save_path, "refl_output_G"))
        check_folder(os.path.join(args.save_path, "shad_output_G"))
        check_folder(os.path.join(args.save_path, "mask_G"))
    else:
        if args.fullsize:
            check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
            check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
            check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
            check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
            check_folder(os.path.join(args.save_path, "mask"))
        else:
            check_folder(os.path.join(args.save_path, "refl_target"))
            check_folder(os.path.join(args.save_path, "shad_target"))
            check_folder(os.path.join(args.save_path, "refl_output"))
            check_folder(os.path.join(args.save_path, "shad_output"))
            check_folder(os.path.join(args.save_path, "mask"))

    ToPIL = transforms.ToPILImage()

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp

            if args.fullsize:
                h, w = input_g.size()[2], input_g.size()[3]
                pad_h, pad_w = clc_pad(h, w, 16)
                print(pad_h, pad_w)
                tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
                tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
                input_g = tmp_pad(input_g)
            if args.refl_multi_size and args.shad_multi_size:
                albedo_fake, shading_fake, _, _ = composer.forward(input_g)
            elif args.refl_multi_size or args.shad_multi_size:
                albedo_fake, shading_fake, _ = composer.forward(input_g)
            else:
                albedo_fake, shading_fake = composer.forward(input_g)
            if args.fullsize:
                albedo_fake, shading_fake = tmp_inversepad(
                    albedo_fake), tmp_inversepad(shading_fake)

            if args.use_tanh:
                albedo_fake = (albedo_fake + 1) / 2
                shading_fake = (shading_fake + 1) / 2
                albedo_g = (albedo_g + 1) / 2
                shading_g = (shading_g + 1) / 2

            albedo_fake = albedo_fake * mask_g
            shading_fake = shading_fake * mask_g

            albedo_fake = albedo_fake.cpu().clamp(0, 1)
            shading_fake = shading_fake.cpu().clamp(0, 1)
            albedo_g = albedo_g.cpu().clamp(0, 1)
            shading_g = shading_g.cpu().clamp(0, 1)

            lab_refl_targ = ToPIL(albedo_g.squeeze())
            lab_sha_targ = ToPIL(shading_g.squeeze())
            refl_pred = ToPIL(albedo_fake.squeeze())
            sha_pred = ToPIL(shading_fake.squeeze())
            mask_g = ToPIL(mask_g.cpu().squeeze())

            if args.shapenet_g:
                lab_refl_targ.save(
                    os.path.join(args.save_path, "refl_target_G",
                                 "{}.png".format(ind)))
                lab_sha_targ.save(
                    os.path.join(args.save_path, "shad_target_G",
                                 "{}.png".format(ind)))
                refl_pred.save(
                    os.path.join(args.save_path, "refl_output_G",
                                 "{}.png".format(ind)))
                sha_pred.save(
                    os.path.join(args.save_path, "shad_output_G",
                                 "{}.png".format(ind)))
                mask_g.save(
                    os.path.join(args.save_path, "mask_G",
                                 "{}.png".format(ind)))
            else:
                lab_refl_targ.save(
                    os.path.join(
                        args.save_path, "refl_target_fullsize" if args.fullsize
                        else "refl_target", "{}.png".format(ind)))
                lab_sha_targ.save(
                    os.path.join(
                        args.save_path, "shad_target_fullsize" if args.fullsize
                        else "shad_target", "{}.png".format(ind)))
                refl_pred.save(
                    os.path.join(
                        args.save_path, "refl_output_fullsize" if args.fullsize
                        else "refl_output", "{}.png".format(ind)))
                sha_pred.save(
                    os.path.join(
                        args.save_path, "shad_output_fullsize" if args.fullsize
                        else "shad_output", "{}.png".format(ind)))
                mask_g.save(
                    os.path.join(args.save_path, "mask", "{}.png".format(ind)))
Example #15
def main():
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--state_dict',
                        type=str,
                        default='composer_state_179.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'

    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    composer.load_state_dict(
        torch.load(os.path.join(args.save_path, args.state_dict)))
    print('load checkpoint success!')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
    check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
    check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
    check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
    check_folder(os.path.join(args.save_path, "shape_output_fullsize"))
    check_folder(os.path.join(args.save_path, "mask_fullsize"))

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            lab_refl_pred = np.zeros(
                (input_g.size()[2], input_g.size()[3], input_g.size()[1]))
            lab_sha_pred = np.zeros_like(lab_refl_pred)
            lab_shape_pred = np.zeros_like(lab_refl_pred)
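            # MPI-Sintel frames are 436x1024; cover them with eight 256x256
            # crops (4 columns x 2 rows: top rows 0-255, bottom rows 180-435)
            # and accumulate the network outputs into the (H, W, C) buffers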
            for ind2 in range(8):
                if ind2 % 2 == 0:
                    input_g_tmp = input_g[:, :, :256, (ind2 // 2) *
                                          256:(ind2 // 2 + 1) * 256]
                    mask_g_tmp = mask_g[:, :, :256, (ind2 // 2) *
                                        256:(ind2 // 2 + 1) * 256]
                    _, albedo_fake, shading_fake, shape_fake = composer.forward(
                        input_g_tmp)
                    albedo_fake = albedo_fake * mask_g_tmp
                    lab_refl_pred[:256, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                  256, :] += albedo_fake.squeeze().cpu().numpy(
                                  ).transpose(1, 2, 0)
                    lab_sha_pred[:256, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                 256, :] += shading_fake.squeeze().cpu().numpy(
                                 ).transpose(1, 2, 0)
                    lab_shape_pred[:256, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                   256, :] += shape_fake.squeeze().cpu().numpy(
                                   ).transpose(1, 2, 0)
                else:
                    input_g_tmp = input_g[:, :, 180:, (ind2 // 2) *
                                          256:(ind2 // 2 + 1) * 256]
                    mask_g_tmp = mask_g[:, :, 180:, (ind2 // 2) *
                                        256:(ind2 // 2 + 1) * 256]
                    _, albedo_fake, shading_fake, shape_fake = composer.forward(
                        input_g_tmp)
                    albedo_fake = albedo_fake * mask_g_tmp
                    lab_refl_pred[180:, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                  256, :] += albedo_fake.squeeze().cpu().numpy(
                                  ).transpose(1, 2, 0)
                    lab_sha_pred[180:, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                 256, :] += shading_fake.squeeze().cpu().numpy(
                                 ).transpose(1, 2, 0)
                    lab_shape_pred[180:, (ind2 // 2) * 256:(ind2 // 2 + 1) *
                                   256, :] += shape_fake.squeeze().cpu().numpy(
                                   ).transpose(1, 2, 0)

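            # rows 180-255 were written by both the top and bottom crops,
            # so average the doubly-accumulated overlap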
            lab_refl_pred[180:256, :, :] /= 2
            lab_sha_pred[180:256, :, :] /= 2
            lab_shape_pred[180:256, :, :] /= 2

            lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            mask = mask_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # shape_pred = shape_fake.squeeze().cpu().numpy().transpose(1,2,0)
            refl_pred = lab_refl_pred
            sha_pred = lab_sha_pred
            shape_pred = lab_shape_pred

            lab_refl_targ = np.clip(lab_refl_targ, 0, 1)
            lab_sha_targ = np.clip(lab_sha_targ, 0, 1)
            refl_pred = np.clip(refl_pred, 0, 1)
            sha_pred = np.clip(sha_pred, 0, 1)
            shape_pred = np.clip(shape_pred, 0, 1)
            mask = np.clip(mask, 0, 1)

            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_target_fullsize",
                             "{}.png".format(ind)), lab_refl_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_target_fullsize",
                             "{}.png".format(ind)), lab_sha_targ)
            scipy.misc.imsave(
                os.path.join(args.save_path, "mask_fullsize",
                             "{}.png".format(ind)), mask)
            scipy.misc.imsave(
                os.path.join(args.save_path, "refl_output_fullsize",
                             "{}.png".format(ind)), refl_pred)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shad_output_fullsize",
                             "{}.png".format(ind)), sha_pred)
            scipy.misc.imsave(
                os.path.join(args.save_path, "shape_output_fullsize",
                             "{}.png".format(ind)), shape_pred)
Example #16
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode',               type=str,   default='train')
    parser.add_argument('--data_path',          type=str,   default='F:\\ShapeNet',
                        help='base folder of datasets')
    parser.add_argument('--save_path',          type=str,   default='logs_shapenet\\RIID_new_RIN_updateLR1_epoch160_CosBF_VGG0.1_shading_SceneSplit_GAN_selayer1_ReflMultiSize_320x320\\',
                        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',    type=str,   default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',    type=str,   default='shad_checkpoint')
    parser.add_argument('--lr',                 type=float, default=0.001,
                        help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool, default=True,
                        help='whether to save the model after each epoch')
    parser.add_argument('--num_epochs',         type=int,   default=40)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool,  default=False)
    parser.add_argument('--state_dict_refl',    type=str,   default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',    type=str,    default='composer_shading_state.t7')
    parser.add_argument('--remove_names',       type=str,   default='F:\\ShapeNet\\remove')
    parser.add_argument('--cur_epoch',          type=StrToInt,   default=0)
    parser.add_argument('--skip_se',            type=StrToBool,  default=False)
    parser.add_argument('--cuda',               type=str,        default='cuda:1')
    parser.add_argument('--dilation',           type=StrToBool,  default=False)
    parser.add_argument('--se_improved',        type=StrToBool,  default=False)
    parser.add_argument('--weight_decay',       type=float,      default=0.0001)
    parser.add_argument('--refl_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--shad_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--refl_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--shad_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--refl_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--refl_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--shad_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--refl_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_grad_flag',     type=StrToBool,  default=False)
    parser.add_argument('--shad_grad_flag',     type=StrToBool,  default=False)
    parser.add_argument('--refl_detach_flag',   type=StrToBool,  default=False)
    parser.add_argument('--shad_detach_flag',   type=StrToBool,  default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool,  default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool,  default=False)
    parser.add_argument('--shad_squeeze_flag',  type=StrToBool,  default=False)
    parser.add_argument('--refl_reduction',     type=StrToInt,   default=8)
    parser.add_argument('--shad_reduction',     type=StrToInt,   default=8)
    parser.add_argument('--refl_bn',            type=StrToBool,  default=True)
    parser.add_argument('--shad_bn',            type=StrToBool,  default=True)
    parser.add_argument('--refl_act',           type=str,        default='relu')
    parser.add_argument('--shad_act',           type=str,        default='relu')
    # parser.add_argument('--refl_gan',           type=StrToBool,  default=False)
    # parser.add_argument('--shad_gan',           type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation',  type=StrToBool,  default=False)
    parser.add_argument('--fullsize',           type=StrToBool,  default=False)
    parser.add_argument('--fullsize_test',      type=StrToBool,  default=False)
    parser.add_argument('--image_size',         type=StrToInt,   default=256)
    parser.add_argument('--ttur',               type=StrToBool,  default=False)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size, low_se=args.refl_low_se, skip_se=args.refl_skip_se, detach=args.refl_detach_flag, reduction=args.refl_reduction, bn=args.refl_bn, act=args.refl_act).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size, low_se=args.shad_low_se, skip_se=args.shad_skip_se, se_squeeze=args.shad_squeeze_flag, reduction=args.shad_reduction, detach=args.shad_detach_flag, bn=args.shad_bn, act=args.shad_act).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(torch.load(os.path.join(args.save_path, args.refl_checkpoint, args.state_dict_refl)))
        shading.load_state_dict(torch.load(os.path.join(args.save_path, args.shad_checkpoint, args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size, args.shad_multi_size).to(device)
    
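    # TTUR (two time-scale update rule): the *_new discriminator and trainer
    # variants presumably train D and G with separate learning rates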
    if not args.ttur:
        Discriminator_R = RIN.SEUG_Discriminator().to(device)
        Discriminator_S = RIN.SEUG_Discriminator().to(device)
    else:
        Discriminator_R = RIN.SEUG_Discriminator_new().to(device)
        Discriminator_S = RIN.SEUG_Discriminator_new().to(device)
    
    remove_names = os.listdir(args.remove_names)
    train_set = RIN_pipeline.ShapeNet_Dateset_new_new(args.data_path, size_per_dataset=90000, mode='train', image_size=args.image_size, remove_names=remove_names, refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)
    test_set = RIN_pipeline.ShapeNet_Dateset_new_new(args.data_path, size_per_dataset=1000, mode='test', image_size=args.image_size, remove_names=remove_names)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)

    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)

    if not args.ttur:
        trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R, Discriminator_S, train_loader, device, writer, args)
    else:
        trainer = RIN_pipeline.SEUGTrainerNew(composer, Discriminator_R, Discriminator_S, train_loader, device, writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        trainer.train()

        if (epoch + 1) % 10 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
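            # decays the rate by 0.75 every 10 epochs, i.e. the schedule of
            # torch.optim.lr_scheduler.StepLR(optimizer, step_size=10,
            # gamma=0.75), assuming access to the underlying optimizer; here
            # the bookkeeping goes through the trainer's update_lr instead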
        
        # if (epoch + 1) % 10 == 0:
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, args)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))
        if args.save_model:
            state = composer.reflectance.state_dict()
            torch.save(state, os.path.join(args.save_path, args.refl_checkpoint, 'composer_reflectance_state_{}.t7'.format(epoch)))
            state = composer.shading.state_dict()
            torch.save(state, os.path.join(args.save_path, args.shad_checkpoint, 'composer_shading_state_{}.t7'.format(epoch)))
        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            # if args.save_model:
            #     state = composer.reflectance.state_dict()
            #     torch.save(state, os.path.join(args.save_path, args.refl_checkpoint, 'composer_reflectance_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            # if args.save_model:
            #     state = composer.shading.state_dict()
            #     torch.save(state, os.path.join(args.save_path, args.shad_checkpoint, 'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))