Code example #1

import argparse
import os

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision import transforms

import RIN
import RIN_pipeline
# The helpers below come from the same repository; the module name 'utils' is
# an assumption here, adjust it to the actual project layout.
from utils import StrToBool, StrToInt, check_folder, clc_pad


def main():
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'logs_vqvae\\MIT_base_256x256_noRetinex_withBf_leakyrelu_BNUP_Sigmiod_inception_bs4_finetune_woMultiPredict\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--shad_out_conv', type=StrToInt, default=3)
    parser.add_argument('--dataset', type=str, default='mit')
    parser.add_argument('--shapenet_g', type=StrToBool, default=False)
    parser.add_argument('--vq_flag', type=StrToBool, default=False)
    parser.add_argument('--use_tanh', type=StrToBool, default=False)
    parser.add_argument('--use_inception', type=StrToBool, default=True)
    parser.add_argument('--use_skip', type=StrToBool, default=True)
    parser.add_argument('--use_multiPredict', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=True)
    args = parser.parse_args()

    device = torch.device(args.cuda)
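    # Build the reflectance and shading decomposers; --vae selects the VQVAE
    # variant, otherwise SE-based single decomposers are used.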
    if args.vae:
        reflectance = RIN.VQVAE(
            vq_flag=args.vq_flag,
            use_tanh=args.use_tanh,
            use_inception=args.use_inception,
            use_skip=args.use_skip,
            use_multiPredict=args.use_multiPredict).to(device)
        shading = RIN.VQVAE(vq_flag=args.vq_flag,
                            use_tanh=args.use_tanh,
                            use_inception=args.use_inception,
                            use_skip=args.use_skip,
                            use_multiPredict=args.use_multiPredict).to(device)
    else:
        reflectance = RIN.SEDecomposerSingle(
            multi_size=args.refl_multi_size,
            low_se=args.refl_low_se,
            skip_se=args.refl_skip_se,
            detach=args.refl_detach_flag,
            reduction=args.refl_reduction).to(device)
        shading = RIN.SEDecomposerSingle(
            multi_size=args.shad_multi_size,
            low_se=args.shad_low_se,
            skip_se=args.shad_skip_se,
            se_squeeze=args.shad_squeeze_flag,
            reduction=args.shad_reduction,
            detach=args.shad_detach_flag,
            last_conv_ch=args.shad_out_conv).to(device)
    reflectance.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    shading.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.shad_checkpoint,
                         args.state_dict_shad)))
    print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

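    # Select the test split: the MIT Barron split (full-size or 256x256) or
    # ShapeNet, loaded one image per batch.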
    if args.dataset == 'mit':
        if args.fullsize:
            print('test fullsize....')
            test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MIT_TXT\\MIT_BarronSplit_fullsize_test.txt'
        else:
            print('test size256....')
            test_txt = 'MIT_TXT\\MIT_BarronSplit_test.txt'
        test_set = RIN_pipeline.MIT_Dataset_Revisit(test_txt, mode='test')
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=1,
            num_workers=args.loader_threads,
            shuffle=False)
    else:
        remove_names = os.listdir('F:\\ShapeNet\\remove')
        if args.shapenet_g:
            test_set = RIN_pipeline.ShapeNet_Dateset_new_new(
                'F:\\ShapeNet',
                size_per_dataset=9000,
                mode='test',
                image_size=256,
                remove_names=remove_names,
                shapenet_g=args.shapenet_g)
        else:
            test_set = RIN_pipeline.ShapeNet_Dateset_new_new(
                'F:\\ShapeNet',
                size_per_dataset=9000,
                mode='test',
                image_size=256,
                remove_names=remove_names)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=1,
            num_workers=args.loader_threads,
            shuffle=False)

    if args.shapenet_g:
        check_folder(os.path.join(args.save_path, "refl_target_G"))
        check_folder(os.path.join(args.save_path, "shad_target_G"))
        check_folder(os.path.join(args.save_path, "refl_output_G"))
        check_folder(os.path.join(args.save_path, "shad_output_G"))
        check_folder(os.path.join(args.save_path, "mask_G"))
    else:
        if args.fullsize:
            check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
            check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
            check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
            check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
            check_folder(os.path.join(args.save_path, "mask"))
        else:
            check_folder(os.path.join(args.save_path, "refl_target"))
            check_folder(os.path.join(args.save_path, "shad_target"))
            check_folder(os.path.join(args.save_path, "refl_output"))
            check_folder(os.path.join(args.save_path, "shad_output"))
            check_folder(os.path.join(args.save_path, "mask"))

    ToPIL = transforms.ToPILImage()

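    # Inference only: switch to eval mode and disable gradient tracking.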
    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp

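            # For full-size inputs, reflection-pad H and W up to multiples of 16
            # before the forward pass; negative padding crops the outputs back.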
            if args.fullsize:
                h, w = input_g.size()[2], input_g.size()[3]
                pad_h, pad_w = clc_pad(h, w, 16)
                print(pad_h, pad_w)
                tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
                tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
                input_g = tmp_pad(input_g)
            if args.refl_multi_size and args.shad_multi_size:
                albedo_fake, shading_fake, _, _ = composer.forward(input_g)
            elif args.refl_multi_size or args.shad_multi_size:
                albedo_fake, shading_fake, _ = composer.forward(input_g)
            else:
                albedo_fake, shading_fake = composer.forward(input_g)
            if args.fullsize:
                albedo_fake, shading_fake = tmp_inversepad(
                    albedo_fake), tmp_inversepad(shading_fake)

            if args.use_tanh:
                albedo_fake = (albedo_fake + 1) / 2
                shading_fake = (shading_fake + 1) / 2
                albedo_g = (albedo_g + 1) / 2
                shading_g = (shading_g + 1) / 2

            albedo_fake = albedo_fake * mask_g
            shading_fake = shading_fake * mask_g

            albedo_fake = albedo_fake.cpu().clamp(0, 1)
            shading_fake = shading_fake.cpu().clamp(0, 1)
            albedo_g = albedo_g.cpu().clamp(0, 1)
            shading_g = shading_g.cpu().clamp(0, 1)

            lab_refl_targ = ToPIL(albedo_g.squeeze())
            lab_sha_targ = ToPIL(shading_g.squeeze())
            refl_pred = ToPIL(albedo_fake.squeeze())
            sha_pred = ToPIL(shading_fake.squeeze())
            mask_g = ToPIL(mask_g.cpu().squeeze())

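            # Save targets, predictions, and the mask as PNGs in folders that
            # match the dataset / full-size configuration.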
            if args.shapenet_g:
                lab_refl_targ.save(
                    os.path.join(args.save_path, "refl_target_G",
                                 "{}.png".format(ind)))
                lab_sha_targ.save(
                    os.path.join(args.save_path, "shad_target_G",
                                 "{}.png".format(ind)))
                refl_pred.save(
                    os.path.join(args.save_path, "refl_output_G",
                                 "{}.png".format(ind)))
                sha_pred.save(
                    os.path.join(args.save_path, "shad_output_G",
                                 "{}.png".format(ind)))
                mask_g.save(
                    os.path.join(args.save_path, "mask_G",
                                 "{}.png".format(ind)))
            else:
                lab_refl_targ.save(
                    os.path.join(
                        args.save_path, "refl_target_fullsize" if args.fullsize
                        else "refl_target", "{}.png".format(ind)))
                lab_sha_targ.save(
                    os.path.join(
                        args.save_path, "shad_target_fullsize" if args.fullsize
                        else "shad_target", "{}.png".format(ind)))
                refl_pred.save(
                    os.path.join(
                        args.save_path, "refl_output_fullsize" if args.fullsize
                        else "refl_output", "{}.png".format(ind)))
                sha_pred.save(
                    os.path.join(
                        args.save_path, "shad_output_fullsize" if args.fullsize
                        else "shad_output", "{}.png".format(ind)))
                mask_g.save(
                    os.path.join(args.save_path, "mask", "{}.png".format(ind)))
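

if __name__ == '__main__':  # assumed entry point; not shown in the original snippet
    main()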
Code example #2

import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
# SummaryWriter may come from tensorboardX in the original project; the
# torch.utils.tensorboard import is an assumption that works with recent PyTorch.
from torch.utils.tensorboard import SummaryWriter

import RIN
import RIN_pipeline
# The helpers below come from the same repository; the module name 'utils' is
# an assumption here, adjust it to the actual project layout.
from utils import StrToBool, StrToInt, check_folder


def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode',               type=str,   default='train')
    parser.add_argument('--data_path',          type=str,   default='F:\\ShapeNet',
                        help='base folder of datasets')
    parser.add_argument('--save_path',          type=str,   default='logs_shapenet\\RIID_new_RIN_updateLR1_epoch160_CosBF_VGG0.1_shading_SceneSplit_GAN_selayer1_ReflMultiSize_320x320\\',
                        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',    type=str,   default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',    type=str,   default='shad_checkpoint')
    parser.add_argument('--lr',                 type=float, default=0.001,
                        help='learning rate')
    parser.add_argument('--loader_threads',     type=int,   default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',         type=StrToBool, default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs',         type=int,   default=40)
    parser.add_argument('--batch_size',         type=int,   default=20)
    parser.add_argument('--checkpoint',         type=StrToBool,  default=False)
    parser.add_argument('--state_dict_refl',    type=str,   default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',    type=str,    default='composer_shading_state.t7')
    parser.add_argument('--remove_names',       type=str,   default='F:\\ShapeNet\\remove')
    parser.add_argument('--cur_epoch',          type=StrToInt,   default=0)
    parser.add_argument('--skip_se',            type=StrToBool,  default=False)
    parser.add_argument('--cuda',               type=str,        default='cuda:1')
    parser.add_argument('--dilation',           type=StrToBool,  default=False)
    parser.add_argument('--se_improved',        type=StrToBool,  default=False)
    parser.add_argument('--weight_decay',       type=float,      default=0.0001)
    parser.add_argument('--refl_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--shad_skip_se',       type=StrToBool,  default=False)
    parser.add_argument('--refl_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--shad_low_se',        type=StrToBool,  default=False)
    parser.add_argument('--refl_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--shad_multi_size',    type=StrToBool,  default=False)
    parser.add_argument('--refl_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_vgg_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--shad_bf_flag',       type=StrToBool,  default=False)
    parser.add_argument('--refl_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--shad_cos_flag',      type=StrToBool,  default=False)
    parser.add_argument('--refl_grad_flag',     type=StrToBool,  default=False)
    parser.add_argument('--shad_grad_flag',     type=StrToBool,  default=False)
    parser.add_argument('--refl_detach_flag',   type=StrToBool,  default=False)
    parser.add_argument('--shad_detach_flag',   type=StrToBool,  default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool,  default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool,  default=False)
    parser.add_argument('--shad_squeeze_flag',  type=StrToBool,  default=False)
    parser.add_argument('--refl_reduction',     type=StrToInt,   default=8)
    parser.add_argument('--shad_reduction',     type=StrToInt,   default=8)
    parser.add_argument('--refl_bn',            type=StrToBool,  default=True)
    parser.add_argument('--shad_bn',            type=StrToBool,  default=True)
    parser.add_argument('--refl_act',           type=str,        default='relu')
    parser.add_argument('--shad_act',           type=str,        default='relu')
    # parser.add_argument('--refl_gan',           type=StrToBool,  default=False)
    # parser.add_argument('--shad_gan',           type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation',  type=StrToBool,  default=False)
    parser.add_argument('--fullsize',           type=StrToBool,  default=False)
    parser.add_argument('--fullsize_test',      type=StrToBool,  default=False)
    parser.add_argument('--image_size',         type=StrToInt,   default=256)
    parser.add_argument('--ttur',               type=StrToBool,  default=False)
    args = parser.parse_args()

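    # Create the log directory and the two checkpoint subfolders.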
    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size, low_se=args.refl_low_se, skip_se=args.refl_skip_se, detach=args.refl_detach_flag, reduction=args.refl_reduction, bn=args.refl_bn, act=args.refl_act).to(device)
    shading = RIN.SEDecomposerSingle(multi_size=args.shad_multi_size, low_se=args.shad_low_se, skip_se=args.shad_skip_se, se_squeeze=args.shad_squeeze_flag, reduction=args.shad_reduction, detach=args.shad_detach_flag, bn=args.shad_bn, act=args.shad_act).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(torch.load(os.path.join(args.save_path, args.refl_checkpoint, args.state_dict_refl)))
        shading.load_state_dict(torch.load(os.path.join(args.save_path, args.shad_checkpoint, args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size, args.shad_multi_size).to(device)
    
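    # One discriminator per branch for adversarial training; --ttur switches to
    # the alternative discriminator (and, below, trainer) variant.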
    if not args.ttur:
        Discriminator_R = RIN.SEUG_Discriminator().to(device)
        Discriminator_S = RIN.SEUG_Discriminator().to(device)
    else:
        Discriminator_R = RIN.SEUG_Discriminator_new().to(device)
        Discriminator_S = RIN.SEUG_Discriminator_new().to(device)
    
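    # Build the ShapeNet train/test loaders (90k training / 1k test samples per
    # epoch); file names listed under --remove_names are handed to the dataset.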
    remove_names = os.listdir(args.remove_names)
    train_set = RIN_pipeline.ShapeNet_Dateset_new_new(args.data_path, size_per_dataset=90000, mode='train', image_size=args.image_size, remove_names=remove_names,refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)
    test_set = RIN_pipeline.ShapeNet_Dateset_new_new(args.data_path, size_per_dataset=1000, mode='test', image_size=args.image_size, remove_names=remove_names)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=False)

    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

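    # TensorBoard writer and GAN trainer; again, --ttur picks the alternative
    # trainer class.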
    writer = SummaryWriter(log_dir=args.save_path)

    if not args.ttur:
        trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R, Discriminator_S, train_loader, device, writer, args)
    else:
        trainer = RIN_pipeline.SEUGTrainerNew(composer, Discriminator_R, Discriminator_S, train_loader, device, writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

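    # Training loop: decay the LR by 0.75 every 10 epochs, evaluate after every
    # epoch, log MSEs to TensorBoard and a text file, and save per-epoch weights.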
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        
        trainer.train()

        if (epoch + 1) % 10 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        
        # if (epoch + 1) % 10 == 0:
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(composer, test_loader, device, args)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'), 'a+') as f:
            f.write('epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'.format(epoch, average_loss, albedo_test_loss, shading_test_loss))
        if args.save_model:
            state = composer.reflectance.state_dict()
            torch.save(state, os.path.join(args.save_path, args.refl_checkpoint, 'composer_reflectance_state_{}.t7'.format(epoch)))
            state = composer.shading.state_dict()
            torch.save(state, os.path.join(args.save_path, args.shad_checkpoint, 'composer_shading_state_{}.t7'.format(epoch)))
        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            # if args.save_model:
            #     state = composer.reflectance.state_dict()
            #     torch.save(state, os.path.join(args.save_path, args.refl_checkpoint, 'composer_reflectance_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'), 'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            # if args.save_model:
            #     state = composer.shading.state_dict()
            #     torch.save(state, os.path.join(args.save_path, args.shad_checkpoint, 'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'), 'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(epoch, shading_test_loss))
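

if __name__ == '__main__':  # assumed entry point; not shown in the original snippet
    main()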