Example #1
class ShapeTrainer(train_utils.Trainer):
    def log_param(self):
        opts = self.opts
        log_param("max_data_num", opts.max_data_num)
        log_param("num_epochs", opts.num_epochs)
        log_param("sphere_initial", opts.sphere_initial)
        log_param("use_gtpose", opts.use_gtpose)
        log_param("kp_loss_wt", opts.kp_loss_wt)
        log_param("vert2kp_loss_wt", opts.vert2kp_loss_wt)
        log_param("renderer", opts.renderer)
        log_param("add_smr_loss", opts.add_smr_loss)

    def define_model(self):
        opts = self.opts
        # ----------
        # Options
        # ----------
        self.symmetric = opts.symmetric
        anno_sfm_path = osp.join(opts.cub_cache_dir, 'sfm', 'anno_train.mat')
        anno_sfm = sio.loadmat(anno_sfm_path,
                               struct_as_record=False,
                               squeeze_me=True)
        if opts.sphere_initial:
            sfm_mean_shape = mesh.create_sphere(3)
        else:
            sfm_mean_shape = (np.transpose(anno_sfm['S']),
                              anno_sfm['conv_tri'] - 1)

        img_size = (opts.img_size, opts.img_size)
        self.model = mesh_net.MeshNet(img_size,
                                      opts,
                                      nz_feat=opts.nz_feat,
                                      num_kps=opts.num_kps,
                                      sfm_mean_shape=sfm_mean_shape)

        if opts.num_pretrain_epochs > 0:
            self.load_network(self.model, 'pred', opts.num_pretrain_epochs)

        self.model = self.model.cuda(device=opts.gpu_id)

        # Data structures to use for triangle priors.
        edges2verts = self.model.edges2verts
        # B x E x 4
        edges2verts = np.tile(np.expand_dims(edges2verts, 0),
                              (opts.batch_size, 1, 1))
        self.edges2verts = Variable(
            torch.LongTensor(edges2verts).cuda(device=opts.gpu_id),
            requires_grad=False)
        # For rendering.
        faces = self.model.faces.view(1, -1, 3)
        self.faces = faces.repeat(opts.batch_size, 1, 1)
        # Pick the renderer backend: "nmr" -> NeuralRenderer, otherwise SoftRenderer ("smr").
        renderer_cls = NeuralRenderer if opts.renderer == "nmr" else SoftRenderer
        self.renderer = renderer_cls(opts.img_size)
        # Separate renderer for the camera loss via projection.
        self.renderer_predcam = renderer_cls(opts.img_size)

        # Need separate NMR for each fwd/bwd call.
        if opts.texture:
            self.tex_renderer = renderer_cls(opts.img_size)
            # Only use ambient light for tex renderer
            self.tex_renderer.ambient_light_only()

        # For visualization
        self.vis_rend = bird_vis.VisRenderer(opts.img_size,
                                             faces.data.cpu().numpy())

        return

    def init_dataset(self):
        opts = self.opts
        if opts.dataset == 'cub':
            self.data_module = cub_data
        else:
            print('Unknown dataset %s!' % opts.dataset)

        self.dataloader = self.data_module.data_loader(opts)
        self.resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def define_criterion(self):
        self.projection_loss = loss_utils.kp_l2_loss
        self.mask_loss_fn = torch.nn.MSELoss()
        self.entropy_loss = loss_utils.entropy_loss
        self.deform_reg_fn = loss_utils.deform_l2reg
        self.camera_loss = loss_utils.camera_loss
        self.triangle_loss_fn = loss_utils.LaplacianLoss(self.faces)

        if self.opts.texture:
            self.texture_loss = loss_utils.PerceptualTextureLoss()
            self.texture_dt_loss_fn = loss_utils.texture_dt_loss

        if self.opts.add_smr_loss:
            self.flatten_loss_fn = sr.FlattenLoss(self.model.faces)
            self.ori_reg_fn = loss_utils.sym_reg

    def set_input(self, batch):
        opts = self.opts

        # Image with annotations.
        input_img_tensor = batch['img'].type(torch.FloatTensor)
        for b in range(input_img_tensor.size(0)):
            input_img_tensor[b] = self.resnet_transform(input_img_tensor[b])
        img_tensor = batch['img'].type(torch.FloatTensor)
        mask_tensor = batch['mask'].type(torch.FloatTensor)
        kp_tensor = batch['kp'].type(torch.FloatTensor)
        cam_tensor = batch['sfm_pose'].type(torch.FloatTensor)

        self.input_imgs = Variable(input_img_tensor.cuda(device=opts.gpu_id),
                                   requires_grad=False)
        self.imgs = Variable(img_tensor.cuda(device=opts.gpu_id),
                             requires_grad=False)
        self.masks = Variable(mask_tensor.cuda(device=opts.gpu_id),
                              requires_grad=False)
        self.kps = Variable(kp_tensor.cuda(device=opts.gpu_id),
                            requires_grad=False)
        self.cams = Variable(cam_tensor.cuda(device=opts.gpu_id),
                             requires_grad=False)

        # Compute barrier distance transform.
        mask_dts = np.stack(
            [image_utils.compute_dt_barrier(m) for m in batch['mask']])
        dt_tensor = torch.FloatTensor(mask_dts).cuda(device=opts.gpu_id)
        # B x 1 x N x N
        self.dts_barrier = Variable(dt_tensor,
                                    requires_grad=False).unsqueeze(1)

    def forward(self):
        opts = self.opts
        if opts.texture:
            pred_codes, self.textures = self.model(self.input_imgs)
        else:
            pred_codes = self.model(self.input_imgs)
        self.delta_v, scale, trans, quat = pred_codes

        self.cam_pred = torch.cat([scale, trans, quat], 1)

        if opts.only_mean_sym:
            del_v = self.delta_v
        else:
            del_v = self.model.symmetrize(self.delta_v)

        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()
        self.pred_v = self.mean_shape + del_v

        # Compute keypoints.
        self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp, dim=1)
        self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Decide which camera to use for projection.
        if opts.use_gtpose:
            proj_cam = self.cams
        else:
            proj_cam = self.cam_pred

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts, proj_cam)

        # Render mask.
        self.mask_pred = self.renderer(self.pred_v, self.faces, proj_cam)
        if opts.renderer == "smr":
            self.mask_pred = self.mask_pred[0][:, 0, :, :]

        if opts.texture:
            self.texture_flow = self.textures
            self.textures = geom_utils.sample_textures(self.texture_flow,
                                                       self.imgs)
            if opts.renderer == "smr":
                self.textures = self.textures.contiguous()
                bs, fs, ts, _, _ = self.textures.size()
                self.textures = self.textures.view(bs, fs, -1, 3)
                texture_rgba, p2f_info, _ = self.tex_renderer.forward(
                    self.pred_v.detach(), self.faces, proj_cam.detach(),
                    self.textures)
                self.texture_pred = texture_rgba[:, 0:3, :, :]
            else:
                tex_size = self.textures.size(2)
                self.textures = self.textures.unsqueeze(4).repeat(
                    1, 1, 1, 1, tex_size, 1)
                self.texture_pred = self.tex_renderer(self.pred_v.detach(),
                                                      self.faces,
                                                      proj_cam.detach(),
                                                      textures=self.textures)
        else:
            self.textures = None

        # Compute losses for this instance.
        self.kp_loss = self.projection_loss(self.kp_pred, self.kps)
        self.mask_loss = self.mask_loss_fn(self.mask_pred, self.masks)
        self.cam_loss = self.camera_loss(self.cam_pred, self.cams, 0)

        if opts.texture:
            self.tex_loss = self.texture_loss(self.texture_pred, self.imgs,
                                              self.mask_pred, self.masks)
            self.tex_dt_loss = self.texture_dt_loss_fn(self.texture_flow,
                                                       self.dts_barrier)

        # Priors:
        self.vert2kp_loss = self.entropy_loss(self.vert2kp)
        self.deform_reg = self.deform_reg_fn(self.delta_v)
        self.triangle_loss = self.triangle_loss_fn(self.pred_v)
        if opts.add_smr_loss:
            self.flatten_loss = self.flatten_loss_fn(self.pred_v).mean()
            self.ori_loss = self.ori_reg_fn(self.pred_v)

        # Finally sum up the loss.
        # Instance loss:
        self.total_loss = opts.kp_loss_wt * self.kp_loss
        self.total_loss += opts.mask_loss_wt * self.mask_loss
        self.total_loss += opts.cam_loss_wt * self.cam_loss
        if opts.texture:
            self.total_loss += opts.tex_loss_wt * self.tex_loss

        # Priors:
        self.total_loss += opts.vert2kp_loss_wt * self.vert2kp_loss
        self.total_loss += opts.deform_reg_wt * self.deform_reg
        self.total_loss += opts.triangle_reg_wt * self.triangle_loss

        if opts.texture:
            self.total_loss += opts.tex_dt_loss_wt * self.tex_dt_loss

        if opts.add_smr_loss:
            self.total_loss += self.flatten_loss * opts.flatten_reg_wt
            if (self.curr_epoch < opts.stop_ori_epoch):
                # constrain prediction to be symmetric on the given axis
                self.total_loss += self.ori_loss * opts.ori_reg_wt

        log_metric("kp_loss", float(self.kp_loss.cpu().detach().numpy()))
        log_metric("mask_loss", float(self.mask_loss.cpu().detach().numpy()))
        log_metric("vert2kp_loss",
                   float(self.vert2kp_loss.cpu().detach().numpy()))
        log_metric("deform_reg", float(self.deform_reg.cpu().detach().numpy()))
        log_metric("triangle_loss",
                   float(self.triangle_loss.cpu().detach().numpy()))
        log_metric("cam_loss", float(self.cam_loss.cpu().detach().numpy()))
        if opts.texture:
            log_metric("tex_loss",
                       float(self.tex_loss.cpu().detach().numpy()))
            log_metric("tex_dt_loss",
                       float(self.tex_dt_loss.cpu().detach().numpy()))
        log_metric("total_loss", float(self.total_loss.cpu().detach().numpy()))

    def get_current_visuals(self):
        vis_dict = {}
        mask_concat = torch.cat([self.masks, self.mask_pred], 2)

        if self.opts.texture:
            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            uv_flows = uv_flows.permute(0, 2, 3, 1)
            uv_images = torch.nn.functional.grid_sample(self.imgs,
                                                        uv_flows,
                                                        align_corners=True)

        num_show = min(2, self.opts.batch_size)
        show_uv_imgs = []
        show_uv_flows = []

        for i in range(num_show):
            input_img = bird_vis.kp2im(self.kps[i].data, self.imgs[i].data)
            pred_kp_img = bird_vis.kp2im(self.kp_pred[i].data,
                                         self.imgs[i].data)
            masks = bird_vis.tensor2mask(mask_concat[i].data)
            if self.opts.texture:
                texture_here = self.textures[i]
            else:
                texture_here = None

            rend_predcam = self.vis_rend(self.pred_v[i],
                                         self.cam_pred[i],
                                         texture=texture_here)
            # Render from front & back:
            rend_frontal = self.vis_rend.diff_vp(self.pred_v[i],
                                                 self.cam_pred[i],
                                                 texture=texture_here,
                                                 kp_verts=self.kp_verts[i])
            rend_top = self.vis_rend.diff_vp(self.pred_v[i],
                                             self.cam_pred[i],
                                             axis=[0, 1, 0],
                                             texture=texture_here,
                                             kp_verts=self.kp_verts[i])
            diff_rends = np.hstack((rend_frontal, rend_top))

            if self.opts.texture:
                uv_img = bird_vis.tensor2im(uv_images[i].data)
                show_uv_imgs.append(uv_img)
                uv_flow = bird_vis.visflow(uv_flows[i].data)
                show_uv_flows.append(uv_flow)

                tex_img = bird_vis.tensor2im(self.texture_pred[i].data)
                imgs = np.hstack((input_img, pred_kp_img, tex_img))
            else:
                imgs = np.hstack((input_img, pred_kp_img))

            rend_gtcam = self.vis_rend(self.pred_v[i],
                                       self.cams[i],
                                       texture=texture_here)
            rends = np.hstack((diff_rends, rend_predcam, rend_gtcam))
            vis_dict['%d' % i] = np.hstack((imgs, rends, masks))
            vis_dict['masked_img %d' % i] = bird_vis.tensor2im(
                (self.imgs[i] * self.masks[i]).data)

        if self.opts.texture:
            vis_dict['uv_images'] = np.hstack(show_uv_imgs)
            vis_dict['uv_flow_vis'] = np.hstack(show_uv_flows)

        return vis_dict

    def get_current_points(self):
        return {
            'mean_shape': visutil.tensor2verts(self.mean_shape.data),
            'verts': visutil.tensor2verts(self.pred_v.data),
        }

    def get_current_scalars(self):
        sc_dict = OrderedDict([
            ('smoothed_total_loss', self.smoothed_total_loss),
            ('total_loss', self.total_loss.item()),
            ('kp_loss', self.kp_loss.item()),
            ('mask_loss', self.mask_loss.item()),
            ('vert2kp_loss', self.vert2kp_loss.item()),
            ('deform_reg', self.deform_reg.item()),
            ('tri_loss', self.triangle_loss.item()),
            ('cam_loss', self.cam_loss.item()),
        ])
        if self.opts.texture:
            sc_dict['tex_loss'] = self.tex_loss.item()
            sc_dict['tex_dt_loss'] = self.tex_dt_loss.item()
        if self.opts.add_smr_loss:
            sc_dict['flatten_loss'] = self.flatten_loss.item()
            sc_dict['ori_loss'] = self.ori_loss.item()

        return sc_dict
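
A minimal driver sketch for ShapeTrainer, assuming CMR-style train_utils.Trainer entry points (init_training and train) and absl-flags option plumbing; these names are assumptions rather than code shown above.

import torch
from absl import app, flags

opts = flags.FLAGS  # assumed: flags defined elsewhere in the training script

def main(_):
    torch.manual_seed(0)
    trainer = ShapeTrainer(opts)
    trainer.init_training()  # assumed: builds dataset, model, criterion, optimizer
    trainer.train()          # assumed: epoch loop calling set_input() / forward()

if __name__ == '__main__':
    app.run(main)
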
Example #2
class MeshPredictor(object):
    def __init__(self, opts):
        self.opts = opts

        self.symmetric = opts.symmetric

        img_size = (opts.img_size, opts.img_size)
        print('Setting up model..')
        self.model = mesh_net.MeshNet(img_size, opts, nz_feat=opts.nz_feat)

        self.load_network(self.model, 'pred', self.opts.num_train_epoch)
        self.model.eval()
        self.model = self.model.cuda(device=self.opts.gpu_id)

        # TODO(junzhe): make the renderer backend a configurable option.
        print('self.opts.renderer_opt:', self.opts.renderer_opt)
        if self.opts.renderer_opt == 'nmr':
            from nnutils.nmr import NeuralRenderer
        elif self.opts.renderer_opt == 'nmr_kaolin':
            from nnutils.nmr_kaolin import NeuralRenderer
        elif self.opts.renderer_opt == 'dibr_kaolin':
            from nnutils.dibr_kaolin import NeuralRenderer
        else:
            raise NotImplementedError

        self.renderer = NeuralRenderer(opts.img_size,
                                       uv_sampler=self.model.uv_sampler)

        if opts.texture:
            self.tex_renderer = NeuralRenderer(
                opts.img_size, uv_sampler=self.model.uv_sampler)
            # Only use ambient light for tex renderer
            self.tex_renderer.ambient_light_only()

        if opts.use_sfm_ms:
            anno_sfm_path = osp.join(opts.cub_cache_dir, 'sfm',
                                     'anno_testval.mat')
            anno_sfm = sio.loadmat(anno_sfm_path,
                                   struct_as_record=False,
                                   squeeze_me=True)
            sfm_mean_shape = torch.Tensor(np.transpose(
                anno_sfm['S'])).cuda(device=opts.gpu_id)
            self.sfm_mean_shape = Variable(sfm_mean_shape, requires_grad=False)
            self.sfm_mean_shape = self.sfm_mean_shape.unsqueeze(0).repeat(
                opts.batch_size, 1, 1)
            sfm_face = torch.LongTensor(anno_sfm['conv_tri'] -
                                        1).cuda(device=opts.gpu_id)
            self.sfm_face = Variable(sfm_face, requires_grad=False)
            faces = self.sfm_face.view(1, -1, 3)
        else:
            # For visualization
            faces = self.model.faces.view(1, -1, 3)
        self.faces = faces.repeat(opts.batch_size, 1, 1)
        self.vis_rend = bird_vis.VisRenderer(opts.img_size,
                                             faces.data.cpu().numpy(),
                                             self.opts,
                                             uv_sampler=self.model.uv_sampler)
        self.vis_rend.set_bgcolor([1., 1., 1.])

        self.resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def load_network(self, network, network_label, epoch_label):
        save_filename = '{}_net_{}.pth'.format(network_label, epoch_label)
        network_dir = os.path.join(self.opts.checkpoint_dir, self.opts.name)
        save_path = os.path.join(network_dir, save_filename)
        print('loading {}..'.format(save_path))
        network.load_state_dict(torch.load(save_path))

        return

    def set_input(self, batch):
        opts = self.opts

        # original image where texture is sampled from.
        img_tensor = batch['img'].clone().type(torch.FloatTensor)

        # input_img is the input to resnet
        input_img_tensor = batch['img'].type(torch.FloatTensor)
        for b in range(input_img_tensor.size(0)):
            input_img_tensor[b] = self.resnet_transform(input_img_tensor[b])

        self.input_imgs = Variable(input_img_tensor.cuda(device=opts.gpu_id),
                                   requires_grad=False)
        self.imgs = Variable(img_tensor.cuda(device=opts.gpu_id),
                             requires_grad=False)
        if opts.use_sfm_camera:
            cam_tensor = batch['sfm_pose'].type(torch.FloatTensor)
            self.sfm_cams = Variable(cam_tensor.cuda(device=opts.gpu_id),
                                     requires_grad=False)

    def predict(self, batch):
        """
        batch contains B x C x H x W numpy images
        """
        self.set_input(batch)
        self.forward()
        return self.collect_outputs()

    def forward(self):
        if self.opts.texture:
            # Texture flow from the model: B x F x T x T x 2, e.g. [B, 1280, 6, 6, 2].
            pred_codes, self.textures = self.model.forward(self.input_imgs)
        else:
            pred_codes = self.model.forward(self.input_imgs)

        self.delta_v, scale, trans, quat = pred_codes

        if self.opts.use_sfm_camera:
            self.cam_pred = self.sfm_cams
        else:
            self.cam_pred = torch.cat([scale, trans, quat], 1)

        del_v = self.model.symmetrize(self.delta_v)
        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()

        if self.opts.use_sfm_ms:
            self.pred_v = self.sfm_mean_shape
        elif self.opts.ignore_pred_delta_v:
            self.pred_v = self.mean_shape + del_v * 0
        else:
            self.pred_v = self.mean_shape + del_v

        # Compute keypoints.
        if self.opts.use_sfm_ms:
            self.kp_verts = self.pred_v
        else:
            self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp,
                                                       dim=1)
            self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts,
                                                    self.cam_pred)
        self.mask_pred = self.renderer.forward(self.pred_v, self.faces,
                                               self.cam_pred)
        # Render texture.
        if self.opts.texture and not self.opts.use_sfm_ms:
            if self.textures.size(-1) == 2:
                # Flow texture!
                self.texture_flow = self.textures
                self.textures = geom_utils.sample_textures(
                    self.textures, self.imgs)
            if self.textures.dim() == 5:  # B x F x T x T x 3
                tex_size = self.textures.size(2)
                # Expand to the B x F x T x T x T x 3 volume-texture layout
                # that the neural mesh renderer expects.
                self.textures = self.textures.unsqueeze(4).repeat(
                    1, 1, 1, 1, tex_size, 1)
            # Render texture:
            # texture_pred is rendered img
            self.texture_pred = self.tex_renderer.forward(
                self.pred_v, self.faces, self.cam_pred, textures=self.textures)

            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            self.uv_flows = uv_flows.permute(0, 2, 3, 1)
            self.uv_images = torch.nn.functional.grid_sample(
                self.imgs, self.uv_flows, align_corners=True)
        else:
            self.textures = None

    def collect_outputs(self):
        outputs = {
            'kp_pred': self.kp_pred.data,
            'verts': self.pred_v.data,
            'kp_verts': self.kp_verts.data,
            'cam_pred': self.cam_pred.data,
            'mask_pred': self.mask_pred.data,
        }
        if self.opts.texture and not self.opts.use_sfm_ms:
            outputs['texture'] = self.textures
            outputs['texture_pred'] = self.texture_pred.data
            outputs['uv_image'] = self.uv_images.data
            outputs['uv_flow'] = self.uv_flows.data

        return outputs
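
A usage sketch for MeshPredictor: run prediction over an evaluation loader and read the collected outputs. The loader is an assumption; batches must be dicts with an 'img' key (plus 'sfm_pose' when use_sfm_camera is set), as set_input expects.

predictor = MeshPredictor(opts)
for batch in dataloader:  # assumed: yields dicts shaped like the training batches
    outputs = predictor.predict(batch)
    verts = outputs['verts']          # B x V x 3 predicted vertices
    kp_pred = outputs['kp_pred']      # projected 2D keypoints
    mask_pred = outputs['mask_pred']  # rendered silhouettes
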
Example #3
class VisRenderer(object):
    """
    Utility to render meshes using pytorch NMR
    faces are F x 3 or 1 x F x 3 numpy
    """
    def __init__(self, img_size, faces, t_size=3):
        self.renderer = NeuralRenderer(img_size)
        self.faces = Variable(torch.IntTensor(faces).cuda(),
                              requires_grad=False)
        if self.faces.dim() == 2:
            self.faces = torch.unsqueeze(self.faces, 0)
        default_tex = np.ones(
            (1, self.faces.shape[1], t_size, t_size, t_size, 3))
        blue = np.array([156, 199, 234.]) / 255.
        default_tex = default_tex * blue
        # Could make each triangle different color
        self.default_tex = Variable(torch.FloatTensor(default_tex).cuda(),
                                    requires_grad=False)
        # rot = transformations.quaternion_about_axis(np.pi/8, [1, 0, 0])
        # This is median quaternion from sfm_pose
        # rot = np.array([ 0.66553962,  0.31033762, -0.02249813,  0.01267084])
        # This is the side view:
        import cv2
        R0 = cv2.Rodrigues(np.array([np.pi / 3, 0, 0]))[0]
        R1 = cv2.Rodrigues(np.array([0, np.pi / 2, 0]))[0]
        R = R1.dot(R0)
        R = np.vstack((np.hstack((R, np.zeros((3, 1)))), np.array([0, 0, 0,
                                                                   1])))
        rot = transformations.quaternion_from_matrix(R, isprecise=True)
        cam = np.hstack([0.75, 0, 0, rot])
        self.default_cam = Variable(torch.FloatTensor(cam).cuda(),
                                    requires_grad=False)
        self.default_cam = torch.unsqueeze(self.default_cam, 0)

    def __call__(self, verts, cams=None, texture=None, rend_mask=False):
        """
        verts is |V| x 3 cuda torch Variable
        cams is 7, cuda torch Variable
        Returns N x N x 3 numpy
        """
        if texture is None:
            texture = self.default_tex
        elif texture.dim() == 5:
            # Here the input is F x T x T x T x 3 (instead of F x T x T x 3)
            # So add batch dim.
            texture = torch.unsqueeze(texture, 0)
        if cams is None:
            cams = self.default_cam
        elif cams.dim() == 1:
            cams = torch.unsqueeze(cams, 0)

        if verts.dim() == 2:
            verts = torch.unsqueeze(verts, 0)

        # ---------------- edited by parker: dump the edited bird mesh ----------------
        # Write the current mesh to an OFF file, keeping flat coordinate and
        # triangle arrays around for the optional plotly preview below.
        verts_np = verts[0].data.cpu().numpy()
        faces_np = self.faces[0].data.cpu().numpy()
        with open("edited_bird_mesh.off", "w") as f:
            f.write("OFF\n")
            f.write("{} {} 0\n".format(len(verts_np), len(faces_np)))
            for v in verts_np:
                f.write("{} {} {}\n".format(float(v[0]), float(v[1]),
                                            float(v[2])))
            for face in faces_np:
                f.write("3 {} {} {}\n".format(int(face[0]), int(face[1]),
                                              int(face[2])))
        mesh_x, mesh_y, mesh_z = verts_np[:, 0], verts_np[:, 1], verts_np[:, 2]
        tri_i, tri_j, tri_k = faces_np[:, 0], faces_np[:, 1], faces_np[:, 2]
        # fig = go.Figure(data=[go.Mesh3d(x=mesh_x, y=mesh_y, z=mesh_z,
        #                                 color='lightblue', opacity=0.5,
        #                                 i=tri_i, j=tri_j, k=tri_k)])
        # fig.show()
        # ---------------------- end edited by parker -----------------

        verts = asVariable(verts)
        cams = asVariable(cams)
        texture = asVariable(texture)

        if rend_mask:
            rend = self.renderer.forward(verts, self.faces, cams)
            rend = rend.repeat(3, 1, 1)
            rend = rend.unsqueeze(0)
        else:
            rend = self.renderer.forward(verts, self.faces, cams, texture)

        rend = rend.data.cpu().numpy()[0].transpose((1, 2, 0))
        rend = np.clip(rend, 0, 1) * 255.0

        return rend.astype(np.uint8)

    def rotated(self, vert, deg, axis=[0, 1, 0], cam=None, texture=None):
        """
        vert is N x 3, torch FloatTensor (or Variable)
        """
        import cv2
        new_rot = cv2.Rodrigues(np.deg2rad(deg) * np.array(axis))[0]
        new_rot = convert_as(torch.FloatTensor(new_rot), vert)

        center = vert.mean(0)
        new_vert = torch.t(torch.matmul(new_rot,
                                        torch.t(vert - center))) + center
        # new_vert = torch.matmul(vert - center, new_rot) + center

        return self.__call__(new_vert, cams=cam, texture=texture)

    def diff_vp(self,
                verts,
                cam=None,
                angle=90,
                axis=[1, 0, 0],
                texture=None,
                kp_verts=None,
                new_ext=None,
                extra_elev=False):
        if cam is None:
            cam = self.default_cam[0]
        if new_ext is None:
            new_ext = [0.6, 0, 0]
        # Cam is 7D: [s, tx, ty, rot]
        import cv2
        cam = asVariable(cam)
        quat = cam[-4:].view(1, 1, -1)
        R = transformations.quaternion_matrix(
            quat.squeeze().data.cpu().numpy())[:3, :3]
        rad_angle = np.deg2rad(angle)
        rotate_by = cv2.Rodrigues(rad_angle * np.array(axis))[0]
        # new_R = R.dot(rotate_by)

        new_R = rotate_by.dot(R)
        if extra_elev:
            # Left multiply the camera by 30deg on X.
            R_elev = cv2.Rodrigues(np.array([np.pi / 9, 0, 0]))[0]
            new_R = R_elev.dot(new_R)
        # Make homogeneous
        new_R = np.vstack(
            [np.hstack((new_R, np.zeros((3, 1)))),
             np.array([0, 0, 0, 1])])
        new_quat = transformations.quaternion_from_matrix(new_R,
                                                          isprecise=True)
        new_quat = Variable(torch.Tensor(new_quat).cuda(), requires_grad=False)
        # new_cam = torch.cat([cam[:-4], new_quat], 0)
        new_ext = Variable(torch.Tensor(new_ext).cuda(), requires_grad=False)
        new_cam = torch.cat([new_ext, new_quat], 0)

        rend_img = self.__call__(verts, cams=new_cam, texture=texture)
        if kp_verts is None:
            return rend_img
        else:
            kps = self.renderer.project_points(kp_verts.unsqueeze(0),
                                               new_cam.unsqueeze(0))
            kps = kps[0].data.cpu().numpy()
            return kp2im(kps, rend_img, radius=1)

    def set_bgcolor(self, color):
        self.renderer.set_bgcolor(color)

    def set_light_dir(self, direction, int_dir=0.8, int_amb=0.8):
        renderer = self.renderer.renderer
        renderer.light_direction = direction
        renderer.light_intensity_directional = int_dir
        renderer.light_intensity_ambient = int_amb
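
The 7D camera used throughout is [scale, tx, ty, qw, qx, qy, qz]. Below is a rough numpy sketch of the weak-perspective projection behind project_points, assuming the usual CMR convention (rotate by the unit quaternion, keep x/y, scale, then translate); the real renderer may differ in axis conventions, so treat it as illustrative only.

import numpy as np

def quat_to_mat(q):
    # Unit quaternion [w, x, y, z] -> 3x3 rotation matrix.
    w, x, y, z = q
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])

def project_points_np(verts, cam):
    # verts: V x 3 vertices, cam: [s, tx, ty, qw, qx, qy, qz].
    s, tx, ty = cam[0], cam[1], cam[2]
    rotated = verts @ quat_to_mat(cam[3:]).T
    return s * rotated[:, :2] + np.array([tx, ty])
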
Example #4
class MeshPredictor(object):
    def __init__(self, opts):
        self.opts = opts

        self.symmetric = opts.symmetric
        # img_size is (256, 256).
        img_size = (opts.img_size, opts.img_size)
        print('Setting up model..')
        # Current guess: somewhere after this line the mean mesh turns into the
        # learned mesh. nz_feat is 200; still need to trace where it comes from.
        self.model = mesh_net.MeshNet(img_size, opts, nz_feat=opts.nz_feat)
        # After this call the verts become 337, but the original mesh has at
        # least 600 vertices, so either some points were altered, or the full
        # mean shape is recovered later by symmetrization.
        self.load_network(self.model, 'pred', self.opts.num_train_epoch)
        # Switch the model from training mode to evaluation mode.
        self.model.eval()

        self.model = self.model.cuda(device=self.opts.gpu_id)

        self.renderer = NeuralRenderer(opts.img_size)

        if opts.texture:  # This option is simply True here.
            self.tex_renderer = NeuralRenderer(opts.img_size)
            # Only use ambient light for tex renderer
            self.tex_renderer.ambient_light_only()
        # Training takes the initial mean shape and produces the learned mean
        # shape. use_sfm_ms (off by default) uses the raw SfM mesh instead; that
        # mesh is very crude and only becomes a proper mean shape after learning.
        if opts.use_sfm_ms:
            anno_sfm_path = osp.join(opts.cub_cache_dir, 'sfm',
                                     'anno_testval.mat')
            anno_sfm = sio.loadmat(anno_sfm_path,
                                   struct_as_record=False,
                                   squeeze_me=True)
            sfm_mean_shape = torch.Tensor(np.transpose(
                anno_sfm['S'])).cuda(device=opts.gpu_id)
            self.sfm_mean_shape = Variable(sfm_mean_shape, requires_grad=False)
            self.sfm_mean_shape = self.sfm_mean_shape.unsqueeze(0).repeat(
                opts.batch_size, 1, 1)
            sfm_face = torch.LongTensor(anno_sfm['conv_tri'] -
                                        1).cuda(device=opts.gpu_id)
            self.sfm_face = Variable(sfm_face, requires_grad=False)
            faces = self.sfm_face.view(1, -1, 3)
        else:
            # For visualization
            faces = self.model.faces.view(1, -1, 3)

        self.faces = faces.repeat(opts.batch_size, 1, 1)
        # This goes into VisRenderer.__init__().
        self.vis_rend = bird_vis.VisRenderer(opts.img_size,
                                             faces.data.cpu().numpy())
        self.vis_rend.set_bgcolor([1., 1., 1.])
        self.resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def load_network(self, network, network_label, epoch_label):
        save_filename = '{}_net_{}.pth'.format(network_label, epoch_label)
        network_dir = os.path.join(self.opts.checkpoint_dir, self.opts.name)
        save_path = os.path.join(network_dir, save_filename)
        print('loading {}..'.format(save_path))
        network.load_state_dict(torch.load(save_path))

        return

    def set_input(self, batch):
        opts = self.opts

        # original image where texture is sampled from.
        img_tensor = batch['img'].clone().type(torch.FloatTensor)

        # input_img is the input to resnet
        input_img_tensor = batch['img'].type(torch.FloatTensor)
        for b in range(input_img_tensor.size(0)):
            input_img_tensor[b] = self.resnet_transform(input_img_tensor[b])

        self.input_imgs = Variable(input_img_tensor.cuda(device=opts.gpu_id),
                                   requires_grad=False)
        self.imgs = Variable(img_tensor.cuda(device=opts.gpu_id),
                             requires_grad=False)
        if opts.use_sfm_camera:
            cam_tensor = batch['sfm_pose'].type(torch.FloatTensor)
            self.sfm_cams = Variable(cam_tensor.cuda(device=opts.gpu_id),
                                     requires_grad=False)

    def predict(self, batch):
        """
        batch contains B x C x H x W numpy images
        """
        self.set_input(batch)
        self.forward()
        return self.collect_outputs()

    def forward(self):
        if self.opts.texture:
            # The textures obtained here have shape 1 x 1280 x 6 x 6 x 2.
            pred_codes, self.textures = self.model.forward(self.input_imgs)
        else:
            pred_codes = self.model.forward(self.input_imgs)

        self.delta_v, scale, trans, quat = pred_codes

        if self.opts.use_sfm_camera:
            self.cam_pred = self.sfm_cams
        else:
            self.cam_pred = torch.cat([scale, trans, quat], 1)

        del_v = self.model.symmetrize(self.delta_v)
        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()
        # ---------------- edited by parker: this really is the mean shape ----
        # Dump the learned mean shape to an OFF file, keeping flat coordinate
        # and triangle arrays around for the optional plotly preview below.
        mean_np = self.mean_shape.data.cpu().numpy()
        faces_np = self.faces[0].data.cpu().numpy()
        with open("bird_mean_mesh.off", "w") as f:
            f.write("OFF\n")
            f.write("{} {} 0\n".format(len(mean_np), len(faces_np)))
            for v in mean_np:
                f.write("{} {} {}\n".format(float(v[0]), float(v[1]),
                                            float(v[2])))
            for face in faces_np:
                f.write("3 {} {} {}\n".format(int(face[0]), int(face[1]),
                                              int(face[2])))
        mesh_x, mesh_y, mesh_z = mean_np[:, 0], mean_np[:, 1], mean_np[:, 2]
        tri_i, tri_j, tri_k = faces_np[:, 0], faces_np[:, 1], faces_np[:, 2]
        # I do not need to show these for now:
        # fig = go.Figure(data=[go.Mesh3d(x=mesh_x, y=mesh_y, z=mesh_z,
        #                                 color='lightgreen', opacity=0.5,
        #                                 i=tri_i, j=tri_j, k=tri_k)])
        # fig.show()
        # ---------------- end edited by parker ----------------
        if self.opts.use_sfm_ms:
            self.pred_v = self.sfm_mean_shape
        elif self.opts.ignore_pred_delta_v:
            self.pred_v = self.mean_shape + del_v * 0
        else:
            self.pred_v = self.mean_shape + del_v

        # Compute keypoints.
        if self.opts.use_sfm_ms:
            self.kp_verts = self.pred_v
        else:
            self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp,
                                                       dim=1)
            self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts,
                                                    self.cam_pred)
        self.mask_pred = self.renderer.forward(self.pred_v, self.faces,
                                               self.cam_pred)

        # Render texture.
        if self.opts.texture and not self.opts.use_sfm_ms:
            if self.textures.size(-1) == 2:
                # Flow texture!
                self.texture_flow = self.textures
                self.textures = geom_utils.sample_textures(
                    self.textures, self.imgs)

            if self.textures.dim() == 5:  # B x F x T x T x 3
                tex_size = self.textures.size(2)
                # Expand to the B x F x T x T x T x 3 volume-texture layout
                # that the neural mesh renderer expects.
                self.textures = self.textures.unsqueeze(4).repeat(
                    1, 1, 1, 1, tex_size, 1)

            # Render texture:
            self.texture_pred = self.tex_renderer.forward(
                self.pred_v, self.faces, self.cam_pred, textures=self.textures)

            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            self.uv_flows = uv_flows.permute(0, 2, 3, 1)
            self.uv_images = torch.nn.functional.grid_sample(
                self.imgs, self.uv_flows, align_corners=True)
            # ---------------- show uv image ------ parker ----------------
            import matplotlib.pyplot as plt

            # C x H x W -> H x W x C for plotting.
            uv_image_array = self.uv_images[0].data.cpu().numpy().transpose(
                1, 2, 0)
            plt.imshow(uv_image_array)
            # Save before show(); show() clears the figure, so saving after
            # it writes out a blank image.
            plt.savefig('uv_image_test.png')
            plt.show()
            #----------------------------------
        else:
            self.textures = None

    def collect_outputs(self):
        outputs = {
            'kp_pred': self.kp_pred.data,
            'verts': self.pred_v.data,
            'kp_verts': self.kp_verts.data,
            'cam_pred': self.cam_pred.data,
            'mask_pred': self.mask_pred.data,
        }
        if self.opts.texture and not self.opts.use_sfm_ms:
            outputs['texture'] = self.textures
            outputs['texture_pred'] = self.texture_pred.data
            outputs['uv_image'] = self.uv_images.data
            outputs['uv_flow'] = self.uv_flows.data

        return outputs
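
A small helper sketch for loading the OFF files written above (bird_mean_mesh.off, edited_bird_mesh.off) back into numpy arrays. The parser matches exactly what the dump code writes: an OFF header, a count line, vertex rows, then "3 i j k" face rows.

import numpy as np

def load_off(path):
    with open(path) as f:
        assert f.readline().strip() == "OFF"
        n_verts, n_faces, _ = (int(x) for x in f.readline().split())
        verts = np.array([[float(x) for x in f.readline().split()]
                          for _ in range(n_verts)])
        faces = np.array([[int(x) for x in f.readline().split()][1:]
                          for _ in range(n_faces)])
    return verts, faces

verts, faces = load_off("bird_mean_mesh.off")
print(verts.shape, faces.shape)  # (V, 3) and (F, 3)
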
Example #5
class VisRenderer(object):
    """
    Utility to render meshes using pytorch NMR
    faces are F x 3 or 1 x F x 3 numpy
    """
    def __init__(self, img_size, faces, opts, t_size=3, uv_sampler=None):
        self.opts = opts
        # TODO(junzhe): make the renderer backend a configurable option.
        if self.opts.renderer_opt == 'nmr':
            from nnutils.nmr import NeuralRenderer
        elif self.opts.renderer_opt == 'nmr_kaolin':
            from nnutils.nmr_kaolin import NeuralRenderer
        elif self.opts.renderer_opt == 'dibr_kaolin':
            from nnutils.dibr_kaolin import NeuralRenderer
        else:
            raise NotImplementedError
        self.renderer = NeuralRenderer(img_size, uv_sampler=uv_sampler)
        self.faces = Variable(torch.IntTensor(faces).cuda(),
                              requires_grad=False)
        if self.faces.dim() == 2:
            self.faces = torch.unsqueeze(self.faces, 0)
        default_tex = np.ones(
            (1, self.faces.shape[1], t_size, t_size, t_size, 3))
        blue = np.array([156, 199, 234.]) / 255.
        default_tex = default_tex * blue
        # Could make each triangle different color
        self.default_tex = Variable(torch.FloatTensor(default_tex).cuda(),
                                    requires_grad=False)
        # rot = transformations.quaternion_about_axis(np.pi/8, [1, 0, 0])
        # This is median quaternion from sfm_pose
        # rot = np.array([ 0.66553962,  0.31033762, -0.02249813,  0.01267084])
        # This is the side view:
        import cv2
        R0 = cv2.Rodrigues(np.array([np.pi / 3, 0, 0]))[0]
        R1 = cv2.Rodrigues(np.array([0, np.pi / 2, 0]))[0]
        R = R1.dot(R0)
        R = np.vstack((np.hstack((R, np.zeros((3, 1)))), np.array([0, 0, 0,
                                                                   1])))
        rot = transformations.quaternion_from_matrix(R, isprecise=True)
        cam = np.hstack([0.75, 0, 0, rot])
        self.default_cam = Variable(torch.FloatTensor(cam).cuda(),
                                    requires_grad=False)
        self.default_cam = torch.unsqueeze(self.default_cam, 0)

    def __call__(self, verts, cams=None, texture=None, rend_mask=False):
        """
        verts is |V| x 3 cuda torch Variable
        cams is 7, cuda torch Variable
        Returns N x N x 3 numpy
        """
        if texture is None:
            texture = self.default_tex
        elif texture.dim() == 5:
            # Here the input is F x T x T x T x 3 (instead of F x T x T x 3)
            # So add batch dim.
            texture = torch.unsqueeze(texture, 0)
        if cams is None:
            cams = self.default_cam
        elif cams.dim() == 1:
            cams = torch.unsqueeze(cams, 0)

        if verts.dim() == 2:
            verts = torch.unsqueeze(verts, 0)

        verts = asVariable(verts)
        cams = asVariable(cams)
        texture = asVariable(texture)

        if rend_mask:
            rend = self.renderer.forward(verts, self.faces, cams)
            rend = rend.repeat(3, 1, 1)
            rend = rend.unsqueeze(0)
        else:
            rend = self.renderer.forward(verts, self.faces, cams, texture)

        rend = rend.data.cpu().numpy()[0].transpose((1, 2, 0))
        rend = np.clip(rend, 0, 1) * 255.0

        return rend.astype(np.uint8)

    def rotated(self, vert, deg, axis=[0, 1, 0], cam=None, texture=None):
        """
        vert is N x 3, torch FloatTensor (or Variable)
        """
        import cv2
        new_rot = cv2.Rodrigues(np.deg2rad(deg) * np.array(axis))[0]
        new_rot = convert_as(torch.FloatTensor(new_rot), vert)

        center = vert.mean(0)
        new_vert = torch.t(torch.matmul(new_rot,
                                        torch.t(vert - center))) + center
        # new_vert = torch.matmul(vert - center, new_rot) + center

        return self.__call__(new_vert, cams=cam, texture=texture)

    def diff_vp(self,
                verts,
                cam=None,
                angle=90,
                axis=[1, 0, 0],
                texture=None,
                kp_verts=None,
                new_ext=None,
                extra_elev=False):
        if cam is None:
            cam = self.default_cam[0]
        if new_ext is None:
            new_ext = [0.6, 0, 0]
        # Cam is 7D: [s, tx, ty, rot]
        import cv2
        cam = asVariable(cam)
        quat = cam[-4:].view(1, 1, -1)
        R = transformations.quaternion_matrix(
            quat.squeeze().data.cpu().numpy())[:3, :3]
        rad_angle = np.deg2rad(angle)
        rotate_by = cv2.Rodrigues(rad_angle * np.array(axis))[0]
        # new_R = R.dot(rotate_by)

        new_R = rotate_by.dot(R)
        if extra_elev:
            # Left multiply the camera by 30deg on X.
            R_elev = cv2.Rodrigues(np.array([np.pi / 9, 0, 0]))[0]
            new_R = R_elev.dot(new_R)
        # Make homogeneous
        new_R = np.vstack(
            [np.hstack((new_R, np.zeros((3, 1)))),
             np.array([0, 0, 0, 1])])
        new_quat = transformations.quaternion_from_matrix(new_R,
                                                          isprecise=True)
        new_quat = Variable(torch.Tensor(new_quat).cuda(), requires_grad=False)
        # new_cam = torch.cat([cam[:-4], new_quat], 0)
        new_ext = Variable(torch.Tensor(new_ext).cuda(), requires_grad=False)
        new_cam = torch.cat([new_ext, new_quat], 0)

        rend_img = self.__call__(verts, cams=new_cam, texture=texture)
        if kp_verts is None:
            return rend_img
        else:
            kps = self.renderer.project_points(kp_verts.unsqueeze(0),
                                               new_cam.unsqueeze(0))
            kps = kps[0].data.cpu().numpy()
            return kp2im(kps, rend_img, radius=1)

    def set_bgcolor(self, color):
        self.renderer.set_bgcolor(color)

    def set_light_dir(self, direction, int_dir=0.8, int_amb=0.8):
        renderer = self.renderer.renderer
        renderer.light_direction = direction
        renderer.light_intensity_directional = int_dir
        renderer.light_intensity_ambient = int_amb
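
A standalone sketch of the rotation round trip inside diff_vp: quaternion to matrix, left-multiply a new axis-angle rotation, then convert back to a quaternion. It assumes Gohlke's transformations module and OpenCV, both already used by these classes.

import cv2
import numpy as np
import transformations

quat = np.array([1., 0., 0., 0.])  # identity orientation, [w, x, y, z]
R = transformations.quaternion_matrix(quat)[:3, :3]
# 90 degrees about X, applied in the world frame (left multiplication).
rotate_by = cv2.Rodrigues(np.deg2rad(90) * np.array([1., 0., 0.]))[0]
new_R = rotate_by.dot(R)
# Homogenize before converting back to a quaternion.
new_R = np.vstack([np.hstack((new_R, np.zeros((3, 1)))),
                   np.array([0, 0, 0, 1])])
new_quat = transformations.quaternion_from_matrix(new_R, isprecise=True)
print(new_quat)  # ~[0.7071, 0.7071, 0, 0]
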
Example #6
class ShapeTrainer(train_utils.Trainer):
    def define_model(self):
        opts = self.opts

        # ----------
        # Options
        # ----------
        # Whether the shape is symmetric.
        self.symmetric = opts.symmetric
        # parker: not sure what this is for.
        # anno_sfm_path holds a file path.
        anno_sfm_path = osp.join(opts.cub_cache_dir, 'sfm', 'anno_train.mat')
        # anno_train.mat is the most basic bird model, with only 15 points.
        anno_sfm = sio.loadmat(anno_sfm_path,
                               struct_as_record=False,
                               squeeze_me=True)
        # anno_sfm['S'] holds the 15 vertices; anno_sfm['conv_tri'] holds the
        # faces formed by those 15 points.
        sfm_mean_shape = (np.transpose(anno_sfm['S']),
                          anno_sfm['conv_tri'] - 1)
        img_size = (opts.img_size, opts.img_size)
        # This goes into MeshNet in mesh_net.py.
        self.model = mesh_net.MeshNet(img_size,
                                      opts,
                                      nz_feat=opts.nz_feat,
                                      num_kps=opts.num_kps,
                                      sfm_mean_shape=sfm_mean_shape)
        # If pretrained epochs exist, load the corresponding checkpoint.
        if opts.num_pretrain_epochs > 0:
            self.load_network(self.model, 'pred', opts.num_pretrain_epochs)

        self.model = self.model.cuda(device=opts.gpu_id)

        # Data structures to use for triangle priors.
        # Fetch the edges2verts the model has already computed.
        edges2verts = self.model.edges2verts
        # B x E x 4
        edges2verts = np.tile(np.expand_dims(edges2verts, 0),
                              (opts.batch_size, 1, 1))
        self.edges2verts = Variable(
            torch.LongTensor(edges2verts).cuda(device=opts.gpu_id),
            requires_grad=False)
        # For rendering.
        faces = self.model.faces.view(1, -1, 3)
        self.faces = faces.repeat(opts.batch_size, 1, 1)
        # Uses nmr, imported under the name NeuralRenderer.
        self.renderer = NeuralRenderer(opts.img_size)
        self.renderer_predcam = NeuralRenderer(
            opts.img_size)  #for camera loss via projection

        # If texture is enabled the line below runs, though I do not
        # understand why it is needed.
        # Need separate NMR for each fwd/bwd call.
        if opts.texture:
            self.tex_renderer = NeuralRenderer(opts.img_size)
            # Only use ambient light for tex renderer
            self.tex_renderer.ambient_light_only()

        # For visualization
        self.vis_rend = bird_vis.VisRenderer(opts.img_size,
                                             faces.data.cpu().numpy())


        return

    def init_dataset(self):
        opts = self.opts
        if opts.dataset == 'cub':
            self.data_module = cub_data
        else:
            print('Unknown dataset %s!' % opts.dataset)

        self.dataloader = self.data_module.data_loader(opts)
        self.resnet_transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def define_criterion(self):
        self.projection_loss = loss_utils.kp_l2_loss
        self.mask_loss_fn = torch.nn.MSELoss()
        self.entropy_loss = loss_utils.entropy_loss
        self.deform_reg_fn = loss_utils.deform_l2reg
        self.camera_loss = loss_utils.camera_loss
        self.triangle_loss_fn = loss_utils.LaplacianLoss(self.faces)

        if self.opts.texture:
            self.texture_loss = loss_utils.PerceptualTextureLoss()
            self.texture_dt_loss_fn = loss_utils.texture_dt_loss

    def set_input(self, batch):
        opts = self.opts

        # Image with annotations.
        input_img_tensor = batch['img'].type(torch.FloatTensor)
        for b in range(input_img_tensor.size(0)):
            input_img_tensor[b] = self.resnet_transform(input_img_tensor[b])
        img_tensor = batch['img'].type(torch.FloatTensor)
        mask_tensor = batch['mask'].type(torch.FloatTensor)
        kp_tensor = batch['kp'].type(torch.FloatTensor)
        cam_tensor = batch['sfm_pose'].type(torch.FloatTensor)

        self.input_imgs = Variable(input_img_tensor.cuda(device=opts.gpu_id),
                                   requires_grad=False)
        self.imgs = Variable(img_tensor.cuda(device=opts.gpu_id),
                             requires_grad=False)
        self.masks = Variable(mask_tensor.cuda(device=opts.gpu_id),
                              requires_grad=False)
        self.kps = Variable(kp_tensor.cuda(device=opts.gpu_id),
                            requires_grad=False)
        self.cams = Variable(cam_tensor.cuda(device=opts.gpu_id),
                             requires_grad=False)

        # Compute barrier distance transform.
        mask_dts = np.stack(
            [image_utils.compute_dt_barrier(m) for m in batch['mask']])
        dt_tensor = torch.FloatTensor(mask_dts).cuda(device=opts.gpu_id)
        # B x 1 x N x N
        self.dts_barrier = Variable(dt_tensor,
                                    requires_grad=False).unsqueeze(1)

    def forward(self):
        opts = self.opts
        if opts.texture:
            pred_codes, self.textures = self.model(self.input_imgs)
        else:
            pred_codes = self.model(self.input_imgs)
        self.delta_v, scale, trans, quat = pred_codes

        self.cam_pred = torch.cat([scale, trans, quat], 1)

        if opts.only_mean_sym:
            del_v = self.delta_v
        else:
            del_v = self.model.symmetrize(self.delta_v)

        # Deform mean shape:
        self.mean_shape = self.model.get_mean_shape()
        self.pred_v = self.mean_shape + del_v

        # Compute keypoints.
        self.vert2kp = torch.nn.functional.softmax(self.model.vert2kp, dim=1)
        self.kp_verts = torch.matmul(self.vert2kp, self.pred_v)

        # Decide which camera to use for projection.
        if opts.use_gtpose:
            proj_cam = self.cams
        else:
            proj_cam = self.cam_pred

        # Project keypoints
        self.kp_pred = self.renderer.project_points(self.kp_verts, proj_cam)

        # Render mask.
        self.mask_pred = self.renderer(self.pred_v, self.faces, proj_cam)

        if opts.texture:
            self.texture_flow = self.textures
            self.textures = geom_utils.sample_textures(self.texture_flow,
                                                       self.imgs)
            tex_size = self.textures.size(2)
            self.textures = self.textures.unsqueeze(4).repeat(
                1, 1, 1, 1, tex_size, 1)

            self.texture_pred = self.tex_renderer(self.pred_v.detach(),
                                                  self.faces,
                                                  proj_cam.detach(),
                                                  textures=self.textures)
        else:
            self.textures = None

        # Compute losses for this instance.
        self.kp_loss = self.projection_loss(self.kp_pred, self.kps)
        self.mask_loss = self.mask_loss_fn(self.mask_pred, self.masks)
        self.cam_loss = self.camera_loss(self.cam_pred, self.cams, 0)

        if opts.texture:
            self.tex_loss = self.texture_loss(self.texture_pred, self.imgs,
                                              self.mask_pred, self.masks)
            self.tex_dt_loss = self.texture_dt_loss_fn(self.texture_flow,
                                                       self.dts_barrier)

        # Priors:
        self.vert2kp_loss = self.entropy_loss(self.vert2kp)
        self.deform_reg = self.deform_reg_fn(self.delta_v)
        self.triangle_loss = self.triangle_loss_fn(self.pred_v)

        # Finally sum up the loss.
        # Instance loss:
        self.total_loss = opts.kp_loss_wt * self.kp_loss
        self.total_loss += opts.mask_loss_wt * self.mask_loss
        self.total_loss += opts.cam_loss_wt * self.cam_loss
        if opts.texture:
            self.total_loss += opts.tex_loss_wt * self.tex_loss

        # Priors:
        self.total_loss += opts.vert2kp_loss_wt * self.vert2kp_loss
        self.total_loss += opts.deform_reg_wt * self.deform_reg
        self.total_loss += opts.triangle_reg_wt * self.triangle_loss

        if opts.texture:
            self.total_loss += opts.tex_dt_loss_wt * self.tex_dt_loss

    def get_current_visuals(self):
        vis_dict = {}
        mask_concat = torch.cat([self.masks, self.mask_pred], 2)

        if self.opts.texture:
            # B x 2 x H x W
            uv_flows = self.model.texture_predictor.uvimage_pred
            # B x H x W x 2
            uv_flows = uv_flows.permute(0, 2, 3, 1)
            uv_images = torch.nn.functional.grid_sample(self.imgs,
                                                        uv_flows,
                                                        align_corners=True)

        num_show = min(2, self.opts.batch_size)
        show_uv_imgs = []
        show_uv_flows = []

        for i in range(num_show):
            input_img = bird_vis.kp2im(self.kps[i].data, self.imgs[i].data)
            pred_kp_img = bird_vis.kp2im(self.kp_pred[i].data,
                                         self.imgs[i].data)
            masks = bird_vis.tensor2mask(mask_concat[i].data)
            if self.opts.texture:
                texture_here = self.textures[i]
            else:
                texture_here = None

            rend_predcam = self.vis_rend(self.pred_v[i],
                                         self.cam_pred[i],
                                         texture=texture_here)
            # Render from front & back:
            rend_frontal = self.vis_rend.diff_vp(self.pred_v[i],
                                                 self.cam_pred[i],
                                                 texture=texture_here,
                                                 kp_verts=self.kp_verts[i])
            rend_top = self.vis_rend.diff_vp(self.pred_v[i],
                                             self.cam_pred[i],
                                             axis=[0, 1, 0],
                                             texture=texture_here,
                                             kp_verts=self.kp_verts[i])
            diff_rends = np.hstack((rend_frontal, rend_top))

            if self.opts.texture:
                uv_img = bird_vis.tensor2im(uv_images[i].data)
                show_uv_imgs.append(uv_img)
                uv_flow = bird_vis.visflow(uv_flows[i].data)
                show_uv_flows.append(uv_flow)

                tex_img = bird_vis.tensor2im(self.texture_pred[i].data)
                imgs = np.hstack((input_img, pred_kp_img, tex_img))
            else:
                imgs = np.hstack((input_img, pred_kp_img))

            rend_gtcam = self.vis_rend(self.pred_v[i],
                                       self.cams[i],
                                       texture=texture_here)
            rends = np.hstack((diff_rends, rend_predcam, rend_gtcam))
            vis_dict['%d' % i] = np.hstack((imgs, rends, masks))
            vis_dict['masked_img %d' % i] = bird_vis.tensor2im(
                (self.imgs[i] * self.masks[i]).data)

        if self.opts.texture:
            vis_dict['uv_images'] = np.hstack(show_uv_imgs)
            vis_dict['uv_flow_vis'] = np.hstack(show_uv_flows)

        return vis_dict

    def get_current_points(self):
        return {
            'mean_shape': visutil.tensor2verts(self.mean_shape.data),
            'verts': visutil.tensor2verts(self.pred_v.data),
        }

    def get_current_scalars(self):
        sc_dict = OrderedDict([
            ('smoothed_total_loss', self.smoothed_total_loss),
            ('total_loss', self.total_loss.item()),
            ('kp_loss', self.kp_loss.item()),
            ('mask_loss', self.mask_loss.item()),
            ('vert2kp_loss', self.vert2kp_loss.item()),
            ('deform_reg', self.deform_reg.item()),
            ('tri_loss', self.triangle_loss.item()),
            ('cam_loss', self.cam_loss.item()),
        ])
        if self.opts.texture:
            sc_dict['tex_loss'] = self.tex_loss.item()
            sc_dict['tex_dt_loss'] = self.tex_dt_loss.item()

        return sc_dict