Example #1
import argparse
import os

import numpy as np
import scipy.misc
import torch
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard on newer PyTorch

import RIN            # project-local: network definitions
import RIN_pipeline   # project-local: datasets and trainers
# check_folder is a project-local helper (a sketch follows this example)


def main():
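    """Train or evaluate the RIN composer (decomposer + shader) on the BOLD
    or ShapeNet intrinsic-image datasets, saving checkpoints, TensorBoard
    logs, and (in test mode) per-image reflectance/shading/shape outputs."""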
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='F:\\BOLD_dataset',
                        help='base folder of datasets')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='logs\\lihao\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,  # DataLoader's num_workers must be an int
                        default=4,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=bool,  # caveat: argparse parses any non-empty string as True
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--supervision_set_size', type=int, default=33900)
    # parser.add_argument('--unsupervision_set_size',     type=int,   default=10170)
    parser.add_argument('--num_epochs', type=int, default=80)
    parser.add_argument('--batch_size', type=int, default=16)
    # caveat: type=bool treats any non-empty CLI string as True, and type=str
    # on --img_resize_shape only applies to CLI-supplied values (the default
    # stays a tuple)
    parser.add_argument('--checkpoint', type=bool, default=False)
    parser.add_argument('--img_resize_shape', type=str, default=(256, 256))
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    parser.add_argument('--dataset', type=str, default='BOLD')
    parser.add_argument('--remove_names',
                        type=str,
                        default='F:\\ShapeNet\\remove')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    # pylint: enable=E1101

    shader = RIN.Shader()
    # shader.load_state_dict(torch.load('logs/shader/shader_state_59.t7'))
    decomposer = RIN.Decomposer()
    # reflection.load_state_dict(torch.load('reflection_state.t7'))
    composer = RIN.Composer(decomposer, shader).to(device)
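    # Composer wraps the two sub-networks; its forward pass (used in test
    # mode below) returns reconstruction, reflectance, shading, and shape
    # predictions.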
    # RIN.init_weights(composer, init_type='kaiming')

    cur_epoch = 0
    if args.checkpoint:
        # cur_epoch = int(args.state_dict.split('.')[0].split('_')[-1])
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('checkpoint loaded successfully')

    if args.mode == 'train':
        if args.dataset == "ShapeNet":
            remove_names = os.listdir(args.remove_names)
            supervision_train_set = RIN_pipeline.ShapeNet_Dateset_new(
                args.data_path,
                size_per_dataset=args.supervision_set_size,
                mode='train',
                img_size=args.img_resize_shape,
                remove_names=remove_names)
            sv_train_loader = torch.utils.data.DataLoader(
                supervision_train_set,
                batch_size=args.batch_size,
                num_workers=args.loader_threads,
                shuffle=True)
            test_set = RIN_pipeline.ShapeNet_Dateset_new(
                args.data_path,
                size_per_dataset=20,
                mode='test',
                img_size=args.img_resize_shape,
                remove_names=remove_names)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)
        else:
            train_txt = 'lihao/train_list.txt'
            test_txt = 'lihao/test_list.txt'
            supervision_train_set = RIN_pipeline.BOLD_Dataset(
                args.data_path,
                size_per_dataset=args.supervision_set_size,
                mode='train',
                img_size=args.img_resize_shape,
                file_name=train_txt)
            sv_train_loader = torch.utils.data.DataLoader(
                supervision_train_set,
                batch_size=args.batch_size,
                num_workers=args.loader_threads,
                shuffle=True)
            test_set = RIN_pipeline.BOLD_Dataset(
                args.data_path,
                size_per_dataset=20,
                mode='test',
                img_size=args.img_resize_shape,
                file_name=test_txt)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)

        writer = SummaryWriter(log_dir=args.save_path)

        step = 0
        trainer = RIN_pipeline.ShapeNetSupervisionTrainer(
            composer,
            sv_train_loader,
            args.lr,
            device,
            writer,
            step,
            optim_choose=args.optimizer)
        for epoch in range(cur_epoch + 1, args.num_epochs):
            print('<Main> Epoch {}'.format(epoch))
            if epoch % 10 == 0:
                # note: args.lr itself is never updated, so this resets the LR
                # to a fixed 0.75 x its initial value rather than decaying it
                # cumulatively (compare Example #2, which mutates args.lr)
                trainer.update_lr(args.lr * 0.75)
            if epoch < 100:  # always true with the default 80 epochs
                step = trainer.train()
            # else:
            #     trainer = RIN_pipeline.UnsupervisionTrainer(composer, unsv_train_loader, args.lr, device, writer, new_step)
            #     new_step = trainer.train()

            if args.save_model:
                state = composer.state_dict()
                torch.save(state,
                           os.path.join(args.save_path, 'composer_state.t7'))

            # step += new_step
            # loss = RIN_pipeline.visualize_composer(composer, test_loader, device, os.path.join(args.save_path, '{}.png'.format(epoch)))
            # writer.add_scalar('test_recon_loss', loss[0], epoch)
            # writer.add_scalar('test_refl_loss', loss[1], epoch)
            # writer.add_scalar('test_sha_loss', loss[2], epoch)
    else:
        check_folder(os.path.join(args.save_path, "refl_target"))
        check_folder(os.path.join(args.save_path, "shad_target"))
        check_folder(os.path.join(args.save_path, "refl_output"))
        check_folder(os.path.join(args.save_path, "shad_output"))
        check_folder(os.path.join(args.save_path, "shape_output"))
        if args.dataset == "ShapeNet":
            # check_folder(os.path.join(args.save_path, "mask"))
            remove_names = os.listdir(args.remove_names)
            test_set = RIN_pipeline.ShapeNet_Dateset_new(
                args.data_path,
                size_per_dataset=9488,
                mode='test',
                img_size=args.img_resize_shape,
                remove_names=remove_names)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)
        else:
            test_txt = 'lihao/test_list.txt'
            test_set = RIN_pipeline.BOLD_Dataset(
                args.data_path,
                size_per_dataset=18984,
                mode='test',
                img_size=args.img_resize_shape,
                file_name=test_txt)
            test_loader = torch.utils.data.DataLoader(
                test_set,
                batch_size=1,
                num_workers=args.loader_threads,
                shuffle=False)

        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        composer.eval()
        with torch.no_grad():
            for ind, tensors in enumerate(test_loader):

                inp = [t.float().to(device) for t in tensors]
                if len(inp) == 4:
                    lab_inp, lab_refl_targ, lab_sha_targ, mask = inp
                elif len(inp) == 3:
                    lab_inp, lab_refl_targ, lab_sha_targ = inp
                else:
                    raise ValueError('input should have 3 or 4 tensors, '
                                     'got {}'.format(len(inp)))
                recon_pred, refl_pred, sha_pred, shape_pred = composer(lab_inp)

                def to_hwc(t):
                    # (1, C, H, W) tensor -> (H, W, C) numpy image in [0, 1]
                    return np.clip(
                        t.squeeze().cpu().numpy().transpose(1, 2, 0), 0, 1)

                lab_refl_targ = to_hwc(lab_refl_targ)
                lab_sha_targ = to_hwc(lab_sha_targ)
                refl_pred = to_hwc(refl_pred)
                sha_pred = to_hwc(sha_pred)
                shape_pred = to_hwc(shape_pred)
                # mask = to_hwc(mask)
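                # note: scipy.misc.imsave was deprecated in SciPy 1.0 and
                # removed in 1.2; on newer stacks, imageio.imwrite is a
                # drop-in replacement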
                scipy.misc.imsave(
                    os.path.join(args.save_path, "refl_target",
                                 "{}.png".format(ind)), lab_refl_targ)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "shad_target",
                                 "{}.png".format(ind)), lab_sha_targ)
                # scipy.misc.imsave(os.path.join(args.save_path, "mask", "{}.png".format(ind)), mask)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "refl_output",
                                 "{}.png".format(ind)), refl_pred)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "shad_output",
                                 "{}.png".format(ind)), sha_pred)
                scipy.misc.imsave(
                    os.path.join(args.save_path, "shape_output",
                                 "{}.png".format(ind)), shape_pred)
Example #2
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard on newer PyTorch

import RIN            # project-local: network definitions
import RIN_pipeline   # project-local: datasets and trainers
# check_folder, StrToBool, and StrToInt are project-local helpers
# (sketches appear after Example #1 and at the end of this example)


def main():
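    """Train or evaluate a pair of RIN VQVAE branches (reflectance and
    shading) on the BOLD dataset, checkpointing each branch whenever its
    validation MSE improves."""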
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',
                        type=str,
                        default='E:\\BOLD',
                        help='base folder of datasets')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='logs_vqvae\\BOLD_base_256x256\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint',
                        type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint',
                        type=str,
                        default='shad_checkpoint')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument('--loader_threads',
                        type=int,  # DataLoader's num_workers must be an int
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--save_model',
                        type=bool,  # caveat: argparse parses any non-empty string as True
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    # the checkpoint branch below reads these two file names, so they must be
    # defined as arguments; the defaults here are placeholders
    parser.add_argument('--state_dict_refl', type=str, default='reflectance_state.t7')
    parser.add_argument('--state_dict_shad', type=str, default='shading_state.t7')
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--vq_flag', type=StrToBool, default=False)
    parser.add_argument('--img_resize_shape', type=str, default=(256, 256))  # default stays a tuple; type=str applies only to CLI values
    parser.add_argument('--use_tanh', type=StrToBool, default=False)
    parser.add_argument('--use_inception', type=StrToBool, default=False)
    parser.add_argument('--init_weights', type=StrToBool, default=False)
    parser.add_argument('--adam_flag', type=StrToBool, default=False)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: enable=E1101
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag,
                            init_weights=args.init_weights,
                            use_tanh=args.use_tanh,
                            use_inception=args.use_inception).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag,
                        init_weights=args.init_weights,
                        use_tanh=args.use_tanh,
                        use_inception=args.use_inception).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('checkpoint loaded successfully')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)
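    # SEComposer pairs the two VQVAE branches; they stay individually
    # addressable as composer.reflectance and composer.shading, which the
    # checkpointing code below relies on.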

    # train_txt = "BOLD_TXT\\train_list.txt"
    # test_txt = "BOLD_TXT\\test_list.txt"

    supervision_train_set = RIN_pipeline.BOLD_Dataset(
        args.data_path,
        size_per_dataset=40000,
        mode='train',
        img_size=args.img_resize_shape)
    train_loader = torch.utils.data.DataLoader(supervision_train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.BOLD_Dataset(args.data_path,
                                         size_per_dataset=None,
                                         mode='val',
                                         img_size=args.img_resize_shape)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)

    trainer = RIN_pipeline.BOLDVQVAETrainer(composer, train_loader, device,
                                            writer, args)

    best_albedo_loss = float('inf')
    best_shading_loss = float('inf')

    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))

        trainer.train()

        if (epoch + 1) % 20 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        if (epoch + 1) % 5 == 0:
            albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
                composer, test_loader, device, args)
            average_loss = (albedo_test_loss + shading_test_loss) / 2
            writer.add_scalar('A_mse', albedo_test_loss, epoch)
            writer.add_scalar('S_mse', shading_test_loss, epoch)
            writer.add_scalar('aver_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
            if albedo_test_loss < best_albedo_loss:
                state = composer.reflectance.state_dict()
                torch.save(
                    state,
                    os.path.join(
                        args.save_path, args.refl_checkpoint,
                        'composer_reflectance_state_{}.t7'.format(epoch)))
                best_albedo_loss = albedo_test_loss
                with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                          'a+') as f:
                    f.write('epoch{} --- albedo_loss:{}\n'.format(
                        epoch, albedo_test_loss))
            if shading_test_loss < best_shading_loss:
                best_shading_loss = shading_test_loss
                state = composer.shading.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path, args.shad_checkpoint,
                                 'composer_shading_state_{}.t7'.format(epoch)))
                with open(os.path.join(args.save_path, 'shading_loss.txt'),
                          'a+') as f:
                    f.write('epoch{} --- shading_loss:{}\n'.format(
                        epoch, shading_test_loss))