Example #1
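Judging from the body, this method warps the reference MPI into a target camera with mpi_render_view, tiles a latent appearance code across every plane and pixel, runs the generator netG to predict per-plane RGB, and over-composites the planes into a single novel-view image.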
    def viz_new_views(self, ref_imgs, tgt_imgs, latent_z, K_ref, K_tgt, R_tr,
                      t_tr, depths):

        proj_ref_imgs = self.mpi_render_view(ref_imgs, tgt_imgs, R_tr, t_tr,
                                             K_ref, K_tgt, depths).to(1)

        # Channel layout of the warped MPI: [features | RGB | alpha].
        # proj_ref_features = proj_ref_imgs[:, :self.opt.num_mpi_f, :, :]
        proj_ref_rgb = proj_ref_imgs[:, self.opt.num_mpi_f:-1, :, :]
        proj_ref_alphas = proj_ref_imgs[:, -1:, :, :]

        # composite_feature_img = projector.over_composite(proj_ref_features, proj_ref_alphas).unsqueeze(0).repeat(self.opt.num_mpi_planes, 1, 1, 1)

        # Tile the latent appearance code across every plane and pixel.
        latent_z_rep = latent_z.view(latent_z.size(0), latent_z.size(1), 1,
                                     1).repeat(proj_ref_imgs.size(0), 1,
                                               proj_ref_imgs.size(2),
                                               proj_ref_imgs.size(3))

        depth_planes = depths.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        proj_depth_planes = depth_planes.repeat(1, 1, proj_ref_imgs.size(2),
                                                proj_ref_imgs.size(3))
        warp_mpi_depth = self.create_mpi_depth(proj_depth_planes.to(1),
                                               proj_ref_alphas)

        # target_pts_3d = self.create_3d_pts(warp_mpi_depth, K_tgt).unsqueeze(0).repeat(self.opt.num_mpi_planes, 1, 1, 1)

        proj_ref_inputs_concat = torch.cat([proj_ref_imgs, latent_z_rep],
                                           dim=1)

        render_rgb_mpi = self.netG.forward(proj_ref_inputs_concat)

        composite_rgb_img = projector.over_composite(render_rgb_mpi,
                                                     proj_ref_alphas)
        composite_rgb_img = composite_rgb_img.unsqueeze(0)

        return composite_rgb_img
Example #2
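A fuller variant of the method above: it splits the warped MPI into feature, albedo, and alpha channels, over-composites an albedo preview and an expected depth map, builds a visibility mask (pixels covered by more than 3/4 of the planes), and returns them alongside the rendered RGB.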
    def infer_rgb_from_mpi(self, ref_imgs, img_a, latent_z_a, K_ref, K_a, R_ar,
                           t_ar, depths):

        proj_ref_imgs_a = self.mpi_render_view(ref_imgs, img_a, R_ar, t_ar,
                                               K_ref, K_a, depths).to(1)
        # proj_ref_imgs_b = self.mpi_render_view(ref_imgs, img_b, R_br, t_br, K_ref, K_b, depths).to(1)

        # Channel layout of the warped MPI: [features | albedo | alpha].
        proj_ref_features_a = proj_ref_imgs_a[:, 0:self.opt.num_mpi_f, :, :]
        proj_ref_albedo_a = proj_ref_imgs_a[:, self.opt.num_mpi_f:-1, :, :]
        proj_ref_alphas_a = proj_ref_imgs_a[:, -1:, :, :]

        composite_albedo_a = projector.over_composite(
            proj_ref_albedo_a, proj_ref_alphas_a,
            premultiply_alpha=0).unsqueeze(0)

        depth_planes = depths.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        proj_depth_planes_a = depth_planes.repeat(1, 1,
                                                  proj_ref_imgs_a.size(2),
                                                  proj_ref_imgs_a.size(3))
        warp_mpi_depth_a = self.create_mpi_depth(proj_depth_planes_a.to(1),
                                                 proj_ref_alphas_a)

        proj_ref_mask_a = torch.sum(proj_ref_alphas_a > 1e-8,
                                    dim=0,
                                    keepdim=True)
        proj_ref_mask_a = (proj_ref_mask_a > self.opt.num_mpi_planes * 3 //
                           4).type(torch.cuda.FloatTensor)

        # Render at viewpoint a, conditioned on image a: tile the latent code
        # across every plane and pixel.
        latent_z_a_rep = latent_z_a.view(
            latent_z_a.size(0), latent_z_a.size(1), 1,
            1).repeat(proj_ref_imgs_a.size(0), 1, proj_ref_imgs_a.size(2),
                      proj_ref_imgs_a.size(3))

        buffer_a = torch.cat(
            [proj_ref_features_a, proj_ref_albedo_a, proj_ref_alphas_a], dim=1)
        proj_ref_inputs_concat_a = torch.cat([buffer_a, latent_z_a_rep], dim=1)

        render_rgb_mpi_a = self.netG.forward(proj_ref_inputs_concat_a)

        composite_rgb_img_a = projector.over_composite(
            render_rgb_mpi_a, proj_ref_alphas_a,
            premultiply_alpha=0).unsqueeze(0)

        return (composite_rgb_img_a, composite_albedo_a.detach(),
                warp_mpi_depth_a, proj_ref_mask_a)
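As a standalone illustration of the visibility test above (the function name is hypothetical, not part of the project):

import torch

def visibility_mask(alphas, num_planes, eps=1e-8):
    # alphas: [D, 1, H, W] per-plane alpha projected into the target view
    covered = torch.sum(alphas > eps, dim=0, keepdim=True)
    # a pixel is valid when more than 3/4 of the planes contribute to it
    return (covered > num_planes * 3 // 4).float()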
Example #3
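Helper that extracts the latent appearance code for an input image: the reference feature MPI is over-composited into a single feature image, concatenated with the warped input and the composited reference albedo, and fed to the appearance encoder netE.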
        def get_latent_feature(img_a_full, warp_img_a):
            composite_feature_ref = projector.over_composite(
                self.ref_feature_img_small.detach(),
                self.ref_albedo_rgba_mpi_small[:, -1:, :, :],
                premultiply_alpha=0).unsqueeze(0)
            warp_concat = torch.cat(
                [warp_img_a, self.composite_albedo_ref, composite_feature_ref],
                dim=1)

            latent_z_a = self.netE.forward(warp_concat, img_a_full)

            return latent_z_a
Example #4
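End-to-end inference for one target view: extract the latent appearance code (the same steps as get_latent_feature above), then render through infer_rgb_from_mpi with the concatenated reference feature/albedo MPI.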
    def render_mpi_imgs_func(self, img_a_full, warp_img_a, img_a, K_a, R_ar,
                             t_ar):

        composite_feature_ref = projector.over_composite(
            self.ref_feature_img_small.detach(),
            self.ref_albedo_rgba_mpi_small[:, -1:, :, :],
            premultiply_alpha=0).unsqueeze(0)

        warp_concat = torch.cat(
            [warp_img_a, self.composite_albedo_ref, composite_feature_ref],
            dim=1)

        latent_z_a = self.netE.forward(warp_concat, img_a_full)

        ref_mpi_concat = torch.cat([self.ref_feature_img, self.ref_albedo_mpi],
                                   dim=1)

        return self.infer_rgb_from_mpi(ref_mpi_concat, img_a, latent_z_a,
                                       self.K_ref, K_a, R_ar, t_ar,
                                       self.mpi_planes)
Example #5
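Converts per-plane depth values into a single expected depth map by over-compositing the depth planes with the MPI alphas.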
    def create_mpi_depth(self, depth_planes, alphas):
        final_depth = projector.over_composite(depth_planes, alphas)

        return final_depth
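For reference, a minimal sketch of what the over-compositing step likely computes, assuming projector.over_composite implements the standard MPI "over" operator with planes ordered back to front (plane 0 farthest, which would match the constructor in Example #7 forcing plane 0 to be fully opaque). This is an illustration, not the project's actual implementation:

import torch

def over_composite_sketch(values, alphas):
    # values: [D, C, H, W] per-plane values (RGB, albedo, or depth planes)
    # alphas: [D, 1, H, W] per-plane alpha in [0, 1]
    out = values[0]  # start from the (assumed) farthest plane
    for d in range(1, values.size(0)):
        # "over" blend: the nearer plane occludes what lies behind it
        out = values[d] * alphas[d] + out * (1.0 - alphas[d])
    return out  # [C, H, W] composited image

The premultiply_alpha flag passed in some calls above presumably toggles whether the plane values are multiplied by alpha before blending; the sketch shows the non-premultiplied case.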
Example #6
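Renders a "wander" camera path: 90 poses on a sinusoidal loop whose translation is bounded so the induced parallax stays roughly within max_disp pixels. Each pose re-renders the appearance-transferred MPI into the new view, masks poorly covered pixels, and writes the frame to disk. Note the code assumes a multi-GPU setup (tensors are moved explicitly to CUDA devices 1 and 2).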
    def render_wander(self, save_root_dir):

        num_frames = 90
        img_path = self.targets['img_path_c'][0]
        img_name = img_path.split('/')[-1][:-4]

        save_dir = (save_root_dir + self.targets['img_path_a'][0] + '_' +
                    self.targets['img_path_b'][0] + '/')

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        max_disp = 32.

        # Maximum camera translation that satisfies the max_disp parameter
        max_trans = max_disp / self.targets['K_a'][0, 0, 0]
        output_poses = []

        for i in range(num_frames):
            theta = 2.0 * np.pi * float(i) / float(num_frames)
            x_trans = max_trans * np.sin(theta)
            y_trans = max_trans * np.cos(theta) / 2.0
            z_trans = -max_trans * np.sin(theta) / 2.0

            # 4x4 camera pose: identity rotation plus the sinusoidal translation.
            i_pose = np.concatenate([
                np.concatenate([
                    np.eye(3),
                    np.array([x_trans, y_trans, z_trans])[:, np.newaxis]
                ], axis=1),
                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
            ], axis=0)[np.newaxis, :, :]

            output_poses.append(i_pose)

        with torch.no_grad():

            K_a = Variable(self.targets['K_a'].cuda(), requires_grad=False)

            R_ar = Variable(self.targets['R_ar'].cuda(), requires_grad=False)
            t_ar = Variable(self.targets['t_ar'].cuda(), requires_grad=False)

            K_b = Variable(self.targets['K_b'].cuda(), requires_grad=False)

            R_br = Variable(self.targets['R_br'].cuda(), requires_grad=False)
            t_br = Variable(self.targets['t_br'].cuda(), requires_grad=False)

            img_b_full = Variable(self.targets['img_b_full'].cuda(),
                                  requires_grad=False)
            warp_img_b = Variable(self.targets['warp_img_b'].cuda(),
                                  requires_grad=False)

            img_a = Variable(self.targets['img_a'].cuda(), requires_grad=False)
            img_b = Variable(self.targets['img_b'].cuda(), requires_grad=False)

            img_a_np = img_a[0].data.cpu().numpy().transpose(1, 2, 0)
            img_b_np = img_b[0].data.cpu().numpy().transpose(1, 2, 0)

            cv2.imwrite(save_dir + '/%s' % self.targets['img_path_a'][0],
                        np.uint8(img_a_np[:, :, ::-1] * 255))
            cv2.imwrite(save_dir + '/%s' % self.targets['img_path_b'][0],
                        np.uint8(img_b_np[:, :, ::-1] * 255))

            latent_z_b = self.get_latent_feature(img_b_full, warp_img_b)
            ref_mpi_concat = torch.cat(
                [self.ref_feature_img, self.ref_albedo_mpi], dim=1)

            w, h = 512, 384

            # Rescale the principal point for the new resolution (the focal
            # lengths are left untouched here).
            K_a[:, 0, 2] = K_a[:, 0, 2] * float(w) / img_a.size(3)
            K_a[:, 1, 2] = K_a[:, 1, 2] * float(h) / img_a.size(2)

            # Enlarged intrinsics for rendering the MPI on a 1.25x canvas.
            scale_factor = 1.25
            K_a_large = K_a.clone()
            K_a_large[:, 0, 2] = K_a[:, 0, 2] * scale_factor
            K_a_large[:, 1, 2] = K_a[:, 1, 2] * scale_factor

            img_a = nn.functional.interpolate(img_a, size=[h, w])

            _, render_rgba_mpi_ba = self.infer_app_mpi_from_mpi(
                ref_mpi_concat,
                nn.functional.interpolate(img_a, scale_factor=scale_factor),
                latent_z_b.to(1), self.K_ref, K_a_large, R_ar, t_ar,
                self.mpi_planes)

            render_rgba_mpi_ba_2 = render_rgba_mpi_ba.to(2)
            K_a_large_2 = K_a_large.to(2)
            K_a_2 = K_a.to(2)
            mpi_planes_2 = self.mpi_planes.to(2)

            for i in range(num_frames):
                print('render ', i)
                pose = output_poses[i]
                R_ta = Variable(torch.from_numpy(
                    np.ascontiguousarray(
                        pose[:, :3, :3])).contiguous().float().cuda(2),
                                requires_grad=False)
                t_ta = Variable(torch.from_numpy(
                    np.ascontiguousarray(
                        pose[:, :3, 3:])).contiguous().float().cuda(2),
                                requires_grad=False)

                proj_ref_imgs_a = self.mpi_render_view(render_rgba_mpi_ba_2,
                                                       img_a,
                                                       R_ta,
                                                       t_ta,
                                                       K_a_large_2,
                                                       K_a_2,
                                                       mpi_planes_2,
                                                       cuda_id=2)

                proj_ref_rgb_a = proj_ref_imgs_a[:, :-1, :, :]
                proj_ref_alphas_a = proj_ref_imgs_a[:, -1:, :, :]
                proj_ref_mask_a = torch.sum(proj_ref_alphas_a > 1e-8,
                                            dim=0,
                                            keepdim=True)
                proj_ref_mask_a = (proj_ref_mask_a > self.opt.num_mpi_planes *
                                   3 // 4).type(torch.cuda.FloatTensor)

                pred_render_rgb_ba = projector.over_composite(
                    proj_ref_rgb_a, proj_ref_alphas_a)

                pred_render_rgb_ba = pred_render_rgb_ba * proj_ref_mask_a
                pred_render_rgb_np = pred_render_rgb_ba[0].data.cpu().numpy()
                pred_render_rgb_np = pred_render_rgb_np.transpose(1, 2, 0)
                # Save the frame, and again at index num_frames + i so the
                # saved sequence plays for a second loop.
                cv2.imwrite(save_dir + '/wander_rgb_ours_%d.png' % i,
                            np.uint8(pred_render_rgb_np[:, :, ::-1] * 255))
                cv2.imwrite(
                    save_dir + '/wander_rgb_ours_%d.png' % (num_frames + i),
                    np.uint8(pred_render_rgb_np[:, :, ::-1] * 255))
Example #7
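Constructor for evaluation: loads a per-dataset pretrained albedo MPI and feature MPI from .npy files, builds the MPI depth planes and reference intrinsics, restores the appearance encoder (netE) and the neural renderer (netG) from checkpoints, and puts both in eval mode, spread across GPUs 1-3 via DataParallel.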
    def __init__(self, opt, _isTrain):
        BaseModel.initialize(self, opt)
        self.num_scales = 4

        pretrain_dir = opt.local_dir + '/pretrain_alpha_new/'

        # Per-dataset pretrained albedo MPI checkpoint; only the max_depth and
        # fov tags differ between datasets.
        mpi_depth_fov = {
            'trevi': (4, 70),
            'pantheon': (25, 65),
            'coeur': (20, 65),
            'rock': (75, 70),
            'navona': (25, 70),
        }
        max_depth, fov = mpi_depth_fov[opt.dataset]
        pretrain_mpi_path = (
            pretrain_dir + 'latest_%s_feature_mpi_exp_clean_model_resnet_'
            'encoder_lr_0.001_num_mpi_planes_64_max_depth_%d_min_depth_1_'
            'fov_%d_stage_2_use_log_l1_0ref_feature_mpi.npy' %
            (opt.dataset, max_depth, fov))

        self.criterion_joint = networks.JointLoss(opt)

        ref_albedo_mpi = np.load(pretrain_mpi_path)
        self.ref_albedo_mpi = Variable(torch.from_numpy(
            np.ascontiguousarray(ref_albedo_mpi)).contiguous().float().cuda(),
                                       requires_grad=False)
        self.ref_albedo_mpi = torch.sigmoid(self.ref_albedo_mpi)
        # Force the alpha of plane 0 (presumably the farthest plane) to 1 so
        # it serves as an opaque backdrop for compositing.
        self.ref_albedo_mpi[0, -1:, :, :] = 1.0

        mpi_planes = self.generate_mpi_depth_planes(opt.min_depth,
                                                    opt.max_depth,
                                                    opt.num_mpi_planes)
        self.mpi_planes = Variable(torch.from_numpy(
            np.ascontiguousarray(mpi_planes)).contiguous().float().cuda(),
                                   requires_grad=False)

        self.K_ref = self.create_ref_intrinsic(opt.mpi_w, opt.mpi_h,
                                               opt.ref_fov)
        self.ref_feature_img = self.create_feature_mpi(opt.num_mpi_planes,
                                                       opt.num_mpi_f,
                                                       opt.mpi_w, opt.mpi_h)
        self.opt = opt

        self.ref_albedo_rgba_mpi_small = torch.nn.functional.interpolate(
            self.ref_albedo_mpi, [512, 512], mode='bilinear').to(1)
        self.composite_albedo_ref = projector.over_composite(
            self.ref_albedo_rgba_mpi_small[:, :-1, :, :],
            self.ref_albedo_rgba_mpi_small[:, -1:, :, :],
            premultiply_alpha=0).unsqueeze(0)

        self.ref_feature_img_small = torch.nn.functional.interpolate(
            self.ref_feature_img, [512, 512], mode='bilinear').to(1)

        appearance_encoder = models_spade.AppearanceEncoder(opt)

        if opt.where_add == 'adain':
            # With AdaIN the latent code is injected through the generator's
            # normalization layers, so it is not concatenated to the input.
            num_input_features = 3 + opt.num_mpi_f + 1
            neural_render = networks.define_G(input_nc=num_input_features,
                                              output_nc=3,
                                              nz=0,
                                              ngf=opt.ngf,
                                              norm=opt.norm_G,
                                              nl=opt.nl,
                                              use_dropout=False,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              where_add=opt.where_add,
                                              upsample=opt.upsample,
                                              style_dim=opt.num_latent_f)
        else:
            # Otherwise the latent code is concatenated channel-wise with the
            # MPI features, albedo, and alpha.
            num_input_features = 3 + opt.num_mpi_f + opt.num_latent_f + 1
            neural_render = networks.define_G(input_nc=num_input_features,
                                              output_nc=3,
                                              nz=0,
                                              ngf=opt.ngf,
                                              norm=opt.norm_G,
                                              nl=opt.nl,
                                              use_dropout=False,
                                              init_type=opt.init_type,
                                              init_gain=opt.init_gain,
                                              where_add=opt.where_add,
                                              upsample=opt.upsample)

        # Checkpoint tag; only the dataset name differs between datasets.
        model_name = ('_best_%s_feature_mpi_exp_independent_warp_model_munit_'
                      'encoder_lr_0.0004_use_gan_1_use_vgg_loss_1_'
                      'warp_src_img_1_where_add_adain' % opt.dataset)

        feature_mpi_name = (opt.local_dir + '/deep_mpi/' + model_name[1:] +
                            'ref_feature_img.npy')

        appearance_encoder.load_state_dict(
            self.load_network(appearance_encoder, 'E', model_name))
        neural_render.load_state_dict(
            self.load_network(neural_render, 'G', model_name))

        self.ref_feature_img = Variable(torch.from_numpy(
            np.ascontiguousarray(
                np.load(feature_mpi_name))).contiguous().float().cuda(),
                                        requires_grad=False)

        self.netE = torch.nn.parallel.DataParallel(appearance_encoder.cuda(1),
                                                   device_ids=[1])
        self.netG = torch.nn.parallel.DataParallel(neural_render.cuda(1),
                                                   device_ids=[1, 2, 3])

        self.netE.eval()
        self.netG.eval()

        print('---------- Encoder Networks initialized -------------')
        networks.print_network(self.netE)

        print('---------- neural_render Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')