Code Example #1
def main():
    # random.seed(6666)
    # torch.manual_seed(6666)
    # torch.cuda.manual_seed(6666)
    # np.random.seed(6666)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'logs_vqvae\\MIT_base_256x256_noRetinex_withBf_ALLleakyrelu_BNUP_Sigmiod_inception_bs4_finetune\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save the model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=StrToInt, default=4)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--vae', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--vq_flag', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--shad_out_conv', type=StrToInt, default=3)
    parser.add_argument('--finetune', type=StrToBool, default=True)
    parser.add_argument('--use_tanh', type=StrToBool, default=False)
    parser.add_argument('--use_inception', type=StrToBool, default=True)
    parser.add_argument('--use_skip', type=StrToBool, default=True)
    parser.add_argument('--use_multiPredict', type=StrToBool, default=True)
    parser.add_argument('--init_weights', type=StrToBool, default=False)
    parser.add_argument('--adam_flag', type=StrToBool, default=False)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    device = torch.device(args.cuda)

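    # Two VQVAE branches: one predicts reflectance (albedo), one predicts shading.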
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag,
                            init_weights=args.init_weights,
                            use_tanh=args.use_tanh,
                            use_inception=args.use_inception,
                            use_skip=args.use_skip,
                            use_multiPredict=args.use_multiPredict).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag,
                        init_weights=args.init_weights,
                        use_tanh=args.use_tanh,
                        use_inception=args.use_inception,
                        use_skip=args.use_skip,
                        use_multiPredict=args.use_multiPredict).to(device)
    cur_epoch = 0
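    # Optionally resume both branches from previously saved state dicts.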
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    MIT_train_fullsize_txt = 'MIT_TXT\\MIT_BarronSplit_fullsize_train.txt'
    MIT_test_fullsize_txt = 'MIT_TXT\\MIT_BarronSplit_fullsize_test.txt'
    MIT_train_txt = 'MIT_TXT\\MIT_BarronSplit_train.txt'
    MIT_test_txt = 'MIT_TXT\\MIT_BarronSplit_test.txt'
    if args.fullsize and not args.finetune:
        # train_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_train_fullsize_txt, mode='train', refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size, fullsize=args.fullsize)
        train_set = RIN_pipeline.MIT_Dataset_Revisit(
            MIT_train_fullsize_txt,
            mode='train',
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size,
            image_size=args.image_size)
    else:
        train_set = RIN_pipeline.MIT_Dataset_Revisit(
            MIT_train_txt,
            mode='train',
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size,
            image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_test_fullsize_txt,
                                                mode='test')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)
    best_albedo_loss = 9999
    best_shading_loss = 9999
    best_avg_lmse = 9999
    flag = True
    trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device, writer,
                                        args)
    logging.info('start training....')
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

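        # After epoch 80 (finetune mode), switch once to the full-size training list with batch size 1.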
        if epoch >= 80 and args.finetune and flag:
            flag = False
            train_set = RIN_pipeline.MIT_Dataset_Revisit(
                MIT_train_fullsize_txt,
                mode='train',
                refl_multi_size=args.refl_multi_size,
                shad_multi_size=args.shad_multi_size,
                image_size=args.image_size,
                fullsize=args.fullsize)
            train_loader = torch.utils.data.DataLoader(
                train_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=True)
            trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device,
                                                writer, args)
            # else:
            #     flag = True
            #     train_set = RIN_pipeline.MIT_Dataset_Revisit(MIT_train_txt, mode='train', refl_multi_size=args.refl_multi_size, shad_multi_size=args.shad_multi_size, image_size=args.image_size)
            #     train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=args.loader_threads, shuffle=True)
            #     trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device, writer, args)

        albedo_test_loss, shading_test_loss = RIN_pipeline.MIT_test_unet(
            composer, test_loader, device, args)

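        # Decay the learning rate by a factor of 0.75 every 40 epochs.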
        if (epoch + 1) % 40 == 0:
            args.lr *= 0.75
            trainer.update_lr(args.lr)

        average_loss = (albedo_test_loss + shading_test_loss) / 2
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))

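        # Checkpoint each branch separately whenever its test loss improves.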
        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
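
Both examples call a few helpers that are defined elsewhere in the original scripts and are not shown here: the argparse converters StrToBool and StrToInt, and check_folder for creating output directories. They also assume the usual imports (argparse, os, logging, random, numpy, torch, torch.backends.cudnn, SummaryWriter from tensorboardX or torch.utils.tensorboard, and the project modules RIN and RIN_pipeline). The following is only a minimal sketch of what these helpers could look like, not the original implementations:

import argparse
import os


def StrToBool(value):
    # argparse passes option values as strings; map common spellings to bool.
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', 't', 'yes', '1'):
        return True
    if value.lower() in ('false', 'f', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


def StrToInt(value):
    # Convert a command-line string such as '4' to an int.
    return int(value)


def check_folder(path):
    # Create the directory (and any missing parents) if it does not exist yet.
    os.makedirs(path, exist_ok=True)
    return path
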
Code Example #2
def main():
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='MPI_logs_vqvae\\vqvae_base_256x256\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save the model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl',
                        type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad',
                        type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_bn', type=StrToBool, default=True)
    parser.add_argument('--shad_bn', type=StrToBool, default=True)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    # parser.add_argument('--refl_gan',           type=StrToBool,  default=False)
    # parser.add_argument('--shad_gan',           type=StrToBool,  default=False)
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--ttur', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=True)
    parser.add_argument('--vq_flag', type=StrToBool, default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: enable=E1101
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    # if not args.ttur:
    #     Discriminator_R = RIN.SEUG_Discriminator().to(device)
    #     Discriminator_S = RIN.SEUG_Discriminator().to(device)
    # else:
    #     Discriminator_R = RIN.SEUG_Discriminator_new().to(device)
    #     Discriminator_S = RIN.SEUG_Discriminator_new().to(device)

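    # Choose the MPI training/test file lists depending on augmentation, full-size, and split options.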
    if args.data_augmentation:
        print('data_augmentation.....')
        if args.fullsize:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-train.txt'
        else:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-320-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-320-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
        MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'

    if args.fullsize_test:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
    else:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    if args.data_augmentation:
        print('augmentation...')
        train_transform = RIN_pipeline.MPI_Train_Agumentation_fy2()
    else:
        train_transform = None

    train_set = RIN_pipeline.MPI_Dataset_Revisit(
        train_txt,
        transform=train_transform,
        refl_multi_size=args.refl_multi_size,
        shad_multi_size=args.shad_multi_size,
        image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

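    # Test mode: evaluate the (possibly checkpointed) composer on the test split and exit.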
    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)

    # The ttur branch originally selected a different trainer; both branches build the same VQVAETrainer here.
    trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device, writer,
                                        args)

    best_albedo_loss = 9999
    best_shading_loss = 9999

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        # if (epoch + 1) % 10 == 0:
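        # Evaluate on the test split after every epoch and log the MSE values to TensorBoard.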
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))
        if albedo_test_loss < best_albedo_loss:
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state_{}.t7'.format(epoch)))
            best_albedo_loss = albedo_test_loss
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
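
Both training loops lower the learning rate by multiplying args.lr by 0.75 every 40 epochs and then calling trainer.update_lr(args.lr), so the decay compounds (0.75x, then 0.5625x of the initial rate, and so on). The VQVAETrainer class itself is not shown; a minimal sketch of what such an update_lr method typically does, assuming the trainer keeps a single torch optimizer in self.optimizer (an assumption, not the original code):

def update_lr(self, lr):
    # Push the new learning rate into every parameter group of the wrapped optimizer.
    for param_group in self.optimizer.param_groups:
        param_group['lr'] = lr

Because the loop mutates args.lr in place, the --lr flag sets both the initial rate and the base from which each subsequent decay step is computed.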