Example #1
def mesh_random_translation(mesh, bound: float, device: str = ""):
    """
        Generates a random translation for the mesh.
        Input:
            -mesh: Pytorch3d meshes
            -bound: translation in pixel units
        Returns the altered mesh
    """

    upper = bound
    lower = -1 * bound

    t_params_x = round(random.uniform(lower, upper), 2)
    # Y bounds: shrink the range towards the X translation so the offsets stay correlated
    y_upper = min(upper, t_params_x) if t_params_x >= 0 else upper
    y_lower = max(lower, t_params_x) if t_params_x < 0 else lower
    t_params_y = round(random.uniform(y_lower, y_upper), 2)
    # Z bounds: same idea using the X and Y translations; the lower bound mirrors the upper-bound logic
    z_upper = min(upper, max(t_params_x, t_params_y)) if t_params_y >= 0 else upper
    z_lower = max(lower, min(t_params_x, t_params_y)) if t_params_y < 0 else lower
    t_params_z = round(random.uniform(z_lower, z_upper), 2)

    transform = Transform3d(device=device).translate(t_params_x, t_params_y, t_params_z)
    verts, faces = mesh.get_mesh_verts_faces(0)
    verts = transform.transform_points(verts)
    mesh = mesh.update_padded(verts.unsqueeze(0))

    return mesh, transform
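
A minimal usage sketch (assumptions: PyTorch3D is installed, the function above is importable, and the ico-sphere mesh and the 0.2 bound are made-up illustrative values; the snippet above additionally needs `random` and `Transform3d` imported):

import torch
from pytorch3d.utils import ico_sphere

device = "cuda" if torch.cuda.is_available() else "cpu"
mesh = ico_sphere(level=2, device=device)  # any PyTorch3D Meshes object works here
mesh, transform = mesh_random_translation(mesh, bound=0.2, device=device)
print(transform.get_matrix())  # (1, 4, 4) row-vector transform matrix
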
Example #2
def rotate_mesh_around_axis(
    mesh, rot: list, renderer, dist: float = 3.5, save: str = "", device: str = ""
):

    if not device:
        device = torch.cuda.current_device()

    rot_x = RotateAxisAngle(rot[0], "X", device=device)
    rot_y = RotateAxisAngle(rot[1], "Y", device=device)
    rot_z = RotateAxisAngle(rot[2], "Z", device=device)

    # Compose the three axis rotations; X is applied first, then Y, then Z.
    rotation = rot_x.compose(rot_y, rot_z)

    verts, faces = mesh.get_mesh_verts_faces(0)
    verts = rotation.transform_points(verts)
    mesh = mesh.update_padded(verts.unsqueeze(0))

    elev = 0
    azim = 0

    R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)

    image_ref = renderer(meshes_world=mesh, R=R, T=T, device=device)
    image_ref = image_ref.cpu().numpy()[..., :3]

    plt.imshow(image_ref.squeeze())
    plt.show()

    if save:
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save, verts, faces)
    return mesh
Example #3
def translate_mesh_on_axis(
    mesh, t: list, renderer, dist: float = 3.5, save: str = "", device: str = ""
):

    translation = Transform3d(device=device).translate(t[0], t[1], t[2])

    verts, faces = mesh.get_mesh_verts_faces(0)
    verts = translation.transform_points(verts)
    mesh = mesh.update_padded(verts.unsqueeze(0))

    elev = 0
    azim = 0

    R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)

    image_ref = renderer(meshes_world=mesh, R=R, T=T, device=device)
    image_ref = image_ref.cpu().numpy()

    plt.imshow(image_ref.squeeze())
    plt.show()

    if save:
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save, verts, faces)
    return mesh
Example #4
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the orthographic projection matrix.
        Use column major order.

        Args:
            **kwargs: parameters for the projection can be passed in to
                      override the default values set in __init__.
        Return:
            P: a Transform3d object which represents a batch of projection
               matrices of shape (N, 4, 4)

        .. code-block:: python

            scale_x = 2 / (max_x - min_x)
            scale_y = 2 / (max_y - min_y)
            scale_z = 2 / (far - near)
            mid_x = (max_x + min_x) / (max_x - min_x)
            mid_y = (max_y + min_y) / (max_y - min_y)
            mid_z = (far + near) / (far - near)

            P = [
                    [scale_x,        0,         0,  -mid_x],
                    [0,        scale_y,         0,  -mid_y],
                    [0,              0,  -scale_z,  -mid_z],
                    [0,              0,         0,       1],
            ]
        """
        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
        max_x = kwargs.get("max_x", self.max_x)  # pyre-ignore[16]
        min_x = kwargs.get("min_x", self.min_x)  # pyre-ignore[16]
        max_y = kwargs.get("max_y", self.max_y)  # pyre-ignore[16]
        min_y = kwargs.get("min_y", self.min_y)  # pyre-ignore[16]
        scale_xyz = kwargs.get("scale_xyz", self.scale_xyz)  # pyre-ignore[16]

        P = torch.zeros((self._N, 4, 4),
                        dtype=torch.float32,
                        device=self.device)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        # NOTE: OpenGL flips handedness of coordinate system between camera
        # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
        # right handed coordinate system throughout.
        z_sign = +1.0

        P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
        P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
        P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
        P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
        P[:, 3, 3] = ones

        # NOTE: This maps the z coordinate to the range [0, 1] and replaces
        # the OpenGL z normalization to [-1, 1].
        P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
        P[:, 2, 3] = -znear / (zfar - znear)

        transform = Transform3d(device=self.device)
        transform._matrix = P.transpose(1, 2).contiguous()
        return transform
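
A hedged numeric check of this mapping (assuming a recent PyTorch3D where this camera is exposed as FoVOrthographicCameras; the clip planes below are made-up values). With the default unit box, x and y pass through unchanged while z maps to 0 at the near plane and 1 at the far plane:

import torch
from pytorch3d.renderer import FoVOrthographicCameras

cam = FoVOrthographicCameras(znear=1.0, zfar=10.0)  # default box keeps x, y in [-1, 1]
P = cam.get_projection_transform()
pts = torch.tensor([[[1.0, -1.0, 1.0], [1.0, -1.0, 10.0]]])  # near-plane and far-plane points
print(P.transform_points(pts))  # x, y unchanged; z becomes 0.0 and 1.0 respectively
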
Example #5
    def prepare_pose(self, p: dict) -> Transform3d:
        # Transform the EVIMO coordinate system to the PyTorch3D coordinate system
        pos_t = self.evimo_to_pytorch3d_xyz(p)
        pos_R = self.evimo_to_pytorch3d_Rotation(p)
        R_tmp = Rotate(pos_R)
        w2v_transform = R_tmp.translate(pos_t)

        return Transform3d(matrix=w2v_transform.get_matrix())
Example #6
def camera_to_eye_at_up(
    world_to_view_transform: Transform3d,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Given a world to view transform, return the eye, at and up vectors which
    represent its position.

    For example, if cam is a camera object, then after running

    .. code-block::

        from cameras import look_at_view_transform
        eye, at, up = camera_to_eye_at_up(cam.get_world_to_view_transform())
        R, T = look_at_view_transform(eye=eye, at=at, up=up)

    any other camera created from R and T will have the same world to view
    transform as cam.

    Also, given a camera position R and T, then after running:

    .. code-block::

        from cameras import get_world_to_view_transform, look_at_view_transform
        eye, at, up = camera_to_eye_at_up(get_world_to_view_transform(R=R, T=T))
        R2, T2 = look_at_view_transform(eye=eye, at=at, up=up)

    R2 will equal R and T2 will equal T.

    Args:
        world_to_view_transform: Transform3d representing the extrinsic
            transformation of N cameras.

    Returns:
        eye: FloatTensor of shape [N, 3] representing the camera centers in world space.
        at: FloatTensor of shape [N, 3] representing points in world space directly in
            front of the cameras e.g. the positions of objects to be viewed by the
            cameras.
        up: FloatTensor of shape [N, 3] representing vectors in world space which
            when projected on to the camera plane point upwards.
    """
    cam_trans = world_to_view_transform.inverse()
    # In the PyTorch3D right handed coordinate system, the camera in view space
    # is always at the origin looking along the +z axis.

    # The up vector is not a position so cannot be transformed with
    # transform_points. However the position eye+up above the camera
    # (whose position vector in the camera coordinate frame is an up vector)
    # can be transformed with transform_points.
    eye_at_up_view = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0]],
                                  dtype=torch.float32,
                                  device=cam_trans.device)
    eye_at_up_world = cam_trans.transform_points(eye_at_up_view).reshape(
        -1, 3, 3)

    eye, at, up_plus_eye = eye_at_up_world.unbind(1)
    up = up_plus_eye - eye
    return eye, at, up
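
A hedged roundtrip sketch of the property described above (assuming a recent PyTorch3D where FoVPerspectiveCameras is available; the eye/at/up values are made up). The recovered eye matches the original, while the recovered at is simply a point one unit in front of the camera on the same viewing ray:

import torch
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform

eye = torch.tensor([[0.0, 1.0, 3.0]])
at = torch.tensor([[0.0, 0.0, 0.0]])
up = torch.tensor([[0.0, 1.0, 0.0]])
R, T = look_at_view_transform(eye=eye, at=at, up=up)
cam = FoVPerspectiveCameras(R=R, T=T)
eye2, at2, up2 = camera_to_eye_at_up(cam.get_world_to_view_transform())
# eye2 is (approximately) eye; look_at_view_transform(eye=eye2, at=at2, up=up2)
# reproduces the same R and T.
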
Example #7
    def forward(self, mesh):
        # R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        # T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)

        # Scale the camera positions, then rotate them about X, Y and Z (angles in radians).
        t = (Transform3d(device=self.device)
             .scale(self.camera_position[3] * self.distance_range)
             .rotate_axis_angle(self.camera_position[0] * self.angle_range, axis="X", degrees=False)
             .rotate_axis_angle(self.camera_position[1] * self.angle_range, axis="Y", degrees=False)
             .rotate_axis_angle(self.camera_position[2] * self.angle_range, axis="Z", degrees=False))
        # translation = Translate(T[0][0], T[0][1], T[0][2], device=self.device)

        # t = Transform3d(matrix=self.camera_position)
        vertices = t.transform_points(self.vertices)

        R = look_at_rotation(vertices[:self.nviews], device=self.device)
        T = -torch.bmm(R.transpose(1, 2), vertices[:self.nviews, :,
                                                   None])[:, :, 0]

        if self.light:
            images = torch.empty(self.nviews * len(mesh),
                                 224,
                                 224,
                                 4,
                                 device=self.device)
            # The loop is needed because PyTorch3D does not yet support a batch of lights.
            for i in range(self.nviews):
                self.lights.location = vertices[i]
                imgs = self.renderer(meshes_world=mesh.clone(),
                                     R=R[None, i],
                                     T=T[None, i],
                                     lights=self.lights)
                for k, j in zip(range(len(imgs)),
                                range(0,
                                      len(imgs) * self.nviews, self.nviews)):
                    images[i + j] = imgs[k]
        else:
            # self.lights.location = self.light_position
            meshes = mesh.extend(self.nviews)
            # R and T hold one entry per view, so repeat them to match the number of extended meshes.
            R = R.repeat(len(mesh), 1, 1)
            T = T.repeat(len(mesh), 1)

            images = self.renderer(meshes_world=meshes.clone(), R=R,
                                   T=T)  # , lights=self.lights)

        images = images.permute(0, 3, 1, 2)
        y = self.net_1(images[:, :3, :, :])
        y = y.view((int(images.shape[0] / self.nviews), self.nviews,
                    y.shape[-3], y.shape[-2], y.shape[-1]))
        return self.net_2(torch.max(y, 1)[0].view(y.shape[0], -1))
Example #8
    def rotate_vert(self, vertices, angles, trans):
        # Build per-batch Euler rotations (angles in radians, axes taken from self.rot_order),
        # then append the translation.
        transformer = Transform3d(device=self.device)
        transformer = transformer.rotate_axis_angle(angles[:, 0],
                                                    self.rot_order[0], False)
        transformer = transformer.rotate_axis_angle(angles[:, 1],
                                                    self.rot_order[1], False)
        transformer = transformer.rotate_axis_angle(angles[:, 2],
                                                    self.rot_order[2], False)
        transformer = transformer.translate(trans)

        rotate_vert = transformer.transform_points(vertices)
        return rotate_vert
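
The same rotate-then-translate chaining can be sketched standalone (a minimal example with made-up values; angles are in radians because degrees=False is passed above):

import torch
from pytorch3d.transforms import Transform3d

verts = torch.rand(1, 100, 3)             # (N, V, 3) batch of vertices
angles = torch.tensor([[0.1, 0.2, 0.3]])  # radians, one row per batch element
trans = torch.tensor([[0.0, 0.0, 0.5]])
t = (Transform3d()
     .rotate_axis_angle(angles[:, 0], "X", degrees=False)
     .rotate_axis_angle(angles[:, 1], "Y", degrees=False)
     .rotate_axis_angle(angles[:, 2], "Z", degrees=False)
     .translate(trans))
out = t.transform_points(verts)           # rotations are applied first, then the translation
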
Example #9
def _make_node_transform(node: Dict[str, Any]) -> Transform3d:
    """
    Convert a transform from the glTF JSON node data into a PyTorch3D
    Transform3d.
    """
    array = node.get("matrix")
    if array is not None:  # Stored in column-major order
        M = np.array(array, dtype=np.float32).reshape(4, 4, order="F")
        return Transform3d(matrix=torch.from_numpy(M))

    out = Transform3d()

    # Given some of (scale/rotation/translation), we do them in that order to
    # get points in to the world space.
    # See https://github.com/KhronosGroup/glTF/issues/743 .

    array = node.get("scale", None)
    if array is not None:
        scale_vector = torch.FloatTensor(array)
        out = out.scale(scale_vector[None])

    # Rotation quaternion (x, y, z, w) where w is the scalar
    array = node.get("rotation", None)
    if array is not None:
        x, y, z, w = array
        # We negate w. This is equivalent to inverting the rotation.
        # This is needed as quaternion_to_matrix makes a matrix which
        # operates on column vectors, whereas Transform3d wants a
        # matrix which operates on row vectors.
        rotation_quaternion = torch.FloatTensor([-w, x, y, z])
        rotation_matrix = quaternion_to_matrix(rotation_quaternion)
        out = out.rotate(R=rotation_matrix)

    array = node.get("translation", None)
    if array is not None:
        translation_vector = torch.FloatTensor(array)
        out = out.translate(x=translation_vector[None])

    return out
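
A hedged sketch of the kind of glTF node dictionary this helper expects (values are made up; glTF stores the rotation quaternion as (x, y, z, w)):

node = {
    "scale": [2.0, 2.0, 2.0],
    "rotation": [0.0, 0.7071, 0.0, 0.7071],  # 90 degrees about +Y in (x, y, z, w) order
    "translation": [0.0, 1.0, 0.0],
}
t = _make_node_transform(node)
print(t.get_matrix().shape)  # torch.Size([1, 4, 4]); scale, then rotation, then translation
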
Example #10
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the projection matrix using
        the multi-view geometry convention.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in __init__.

        Returns:
            P: A `Transform3d` object with a batch of `N` projection transforms.

        .. code-block:: python

            fx = focal_length[:,0]
            fy = focal_length[:,1]
            px = principal_point[:,0]
            py = principal_point[:,1]

            P = [
                    [fx,   0,    0,  px],
                    [0,   fy,    0,  py],
                    [0,    0,    1,   0],
                    [0,    0,    0,   1],
            ]
        """
        # pyre-ignore[16]
        principal_point = kwargs.get("principal_point", self.principal_point)
        # pyre-ignore[16]
        focal_length = kwargs.get("focal_length", self.focal_length)
        # pyre-ignore[16]
        image_size = kwargs.get("image_size", self.image_size)

        # if imwidth > 0, parameters are in screen space
        in_screen = image_size[0][0] > 0
        image_size = image_size if in_screen else None

        P = _get_sfm_calibration_matrix(
            self._N,
            self.device,
            focal_length,
            principal_point,
            orthographic=True,
            image_size=image_size,
        )

        transform = Transform3d(device=self.device)
        transform._matrix = P.transpose(1, 2).contiguous()
        return transform
Example #11
    def __init__(self,
                 points,
                 normals=None,
                 features=None,
                 to_unit_sphere: bool = False,
                 to_unit_box: bool = False,
                 to_axis_aligned: bool = False,
                 up=((0, 1, 0), ),
                 front=((0, 0, 1), )):
        """
        Args:
            points, normals: points (and optional normals) in world coordinates,
                unnormalized and unaligned, as accepted by pytorch3d Pointclouds
            features: can be a dict {name: value} where value can be any form
                accepted by pytorch3d Pointclouds
            to_unit_sphere (bool): normalize the object to the unit sphere
            to_unit_box (bool): normalize the object to the unit box (side length = 1)
            to_axis_aligned (bool): rotate the object using the up and front vectors
            up: the up direction in world coordinates (aligned to the object's up)
            front: the front direction in world coordinates (aligned to the z-axis)
        """
        super().__init__(points, normals=normals, features=features)
        self.obj2world_trans = Transform3d()

        # rotate object to have up direction (0, 1, 0)
        # and front direction (0, 0, -1)
        # (B,3,3) rotation to transform to axis-aligned point clouds
        if to_axis_aligned:
            self.obj2world_trans = Rotate(look_at_rotation(((0, 0, 0), ),
                                                           at=front,
                                                           up=up),
                                          device=self.device)
            world_to_obj_rotate_trans = self.obj2world_trans.inverse()

            # update points, normals
            self.update_points_(
                world_to_obj_rotate_trans.transform_points(
                    self.points_packed()))
            normals_packed = self.normals_packed()
            if normals_packed is not None:
                self.update_normals_(
                    world_to_obj_rotate_trans.transform_normals(
                        normals_packed))

        # normalize to unit box and update obj2world_trans
        if to_unit_box:
            normalizing_trans = self.normalize_to_box_()

        elif to_unit_sphere:
            normalizing_trans = self.normalize_to_sphere_()
Example #12
    def __getitem__(self, index):
        index %= len(self.meshes)
        scale, rot, trans = self.get_random_transform()
        transform = Transform3d() \
            .scale(scale) \
            .compose(Rotate(rot)) \
            .translate(*trans) \
            .get_matrix() \
            .squeeze()
        mesh = self.meshes[index].scale_verts(scale)
        pixels = self.renderer(mesh,
                               R=rot.unsqueeze(0).to(self.device),
                               T=trans.unsqueeze(0).to(self.device))
        pixels = pixels[0, ..., :3].transpose(0, -1)
        return (pixels, [transform.to(self.device)])
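
For reference, the scale -> rotate -> translate composition used above can be checked in isolation (a minimal sketch with made-up values, not part of the dataset class):

import torch
from pytorch3d.transforms import Rotate, Transform3d, random_rotations

scale = 1.5
rot = random_rotations(1)[0]
trans = torch.tensor([0.1, -0.2, 0.3])
M = (Transform3d()
     .scale(scale)
     .compose(Rotate(rot))
     .translate(*trans)
     .get_matrix()
     .squeeze())  # (4, 4); points transform as row vectors, p' = p @ M, so the scale acts first
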
Example #13
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the projection matrix using
        the multi-view geometry convention.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in __init__.

        Return:
            P: a batch of projection matrices of shape (N, 4, 4)

        .. code-block:: python

            fx = focal_length[:,0]
            fy = focal_length[:,1]
            px = principal_point[:,0]
            py = principal_point[:,1]

            P = [
                    [fx,   0,    0,  px],
                    [0,   fy,    0,  py],
                    [0,    0,    1,   0],
                    [0,    0,    0,   1],
            ]
        """
        principal_point = kwargs.get(
            "principal_point", self.principal_point
        )  # pyre-ignore[16]
        focal_length = kwargs.get(
            "focal_length", self.focal_length
        )  # pyre-ignore[16]

        P = _get_sfm_calibration_matrix(
            self._N, self.device, focal_length, principal_point, True
        )

        transform = Transform3d(device=self.device)
        transform._matrix = P.transpose(1, 2).contiguous()
        return transform
Example #14
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the OpenGL perspective projection matrix with a symmetric
        viewing frustum. Use column major order.

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.

        Return:
            P: a Transform3d object which represents a batch of projection
            matrices of shape (N, 4, 4)

        .. code-block:: python

            f1 = -(far + near)/(far - near)
            f2 = -2*far*near/(far-near)
            h1 = (top + bottom)/(top - bottom)
            w1 = (right + left)/(right - left)
            tanhalffov = tan((fov/2))
            s1 = 1/tanhalffov
            s2 = 1/(tanhalffov * (aspect_ratio))

            P = [
                    [s1,   0,   w1,   0],
                    [0,   s2,   h1,   0],
                    [0,    0,   f1,  f2],
                    [0,    0,   -1,   0],
            ]
        """
        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
        fov = kwargs.get("fov", self.fov)  # pyre-ignore[16]
        # pyre-ignore[16]
        aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio)
        degrees = kwargs.get("degrees", self.degrees)

        P = torch.zeros((self._N, 4, 4),
                        device=self.device,
                        dtype=torch.float32)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        if degrees:
            fov = (np.pi / 180) * fov

        if not torch.is_tensor(fov):
            fov = torch.tensor(fov, device=self.device)
        tanHalfFov = torch.tan((fov / 2))
        top = tanHalfFov * znear
        bottom = -top
        right = top * aspect_ratio
        left = -right

        # NOTE: In OpenGL the projection matrix changes the handedness of the
        # coordinate frame. i.e. the NDC space positive z direction is the
        # camera space negative z direction. This is because the sign of the z
        # in the projection matrix is set to -1.0.
        # In pytorch3d we maintain a right handed coordinate system throughout
        # so the z sign is 1.0.
        z_sign = 1.0

        P[:, 0, 0] = 2.0 * znear / (right - left)
        P[:, 1, 1] = 2.0 * znear / (top - bottom)
        P[:, 0, 2] = (right + left) / (right - left)
        P[:, 1, 2] = (top + bottom) / (top - bottom)
        P[:, 3, 2] = z_sign * ones

        # NOTE: This part of the matrix is for z renormalization in OpenGL
        # which maps the z to [-1, 1]. This won't work yet as the torch3d
        # rasterizer ignores faces which have z < 0.
        # P[:, 2, 2] = z_sign * (far + near) / (far - near)
        # P[:, 2, 3] = -2.0 * far * near / (far - near)
        # P[:, 3, 2] = z_sign * torch.ones((N))

        # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
        # is at the near clipping plane and z = 1 when the point is at the far
        # clipping plane. This replaces the OpenGL z normalization to [-1, 1]
        # until rasterization is changed to clip at z = -1.
        P[:, 2, 2] = z_sign * zfar / (zfar - znear)
        P[:, 2, 3] = -(zfar * znear) / (zfar - znear)

        # OpenGL uses column vectors so need to transpose the projection matrix
        # as torch3d uses row vectors.
        transform = Transform3d(device=self.device)
        transform._matrix = P.transpose(1, 2).contiguous()
        return transform
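
A small hedged check of the resulting transform (assuming a recent PyTorch3D where this logic is exposed by FoVPerspectiveCameras; the clip planes are made-up values). Transform3d.transform_points performs the homogeneous divide, so a point on the near plane lands at NDC depth 0 and one on the far plane at depth 1:

import torch
from pytorch3d.renderer import FoVPerspectiveCameras

cam = FoVPerspectiveCameras(znear=1.0, zfar=100.0, fov=60.0)
P = cam.get_projection_transform()
pts = torch.tensor([[[0.0, 0.0, 1.0], [0.0, 0.0, 100.0]]])  # on the near and far planes
print(P.transform_points(pts)[..., 2])  # approximately [0., 1.]
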
Example #15
    def preprocess(self, image, face_model):
        #* input image should be uint8, in RGB order

        image = utils.center_crop_resize(image, self.config.im_size)
        image = self.cropper.crop_image(image, self.config.im_size)

        image = image[:, ::-1].copy()

        images_224 = cv2.resize(image, (224, 224),
                                interpolation=cv2.INTER_AREA).astype(
                                    np.float32)[None]

        images = self.to_tensor(image[None])
        segments = self.segmenter.segment_torch(images)
        segments = center_crop(segments, images.shape[1])
        image_segment = torch.cat([images, segments[..., None]], dim=-1)
        image_segment = image_segment.permute(0, 3, 1, 2)

        coeff, bfm_vert, bfm_neu_vert = self.reconstructor.predict(
            images_224, neutral=True)
        bfm_neu_vert = self.to_tensor(bfm_neu_vert)

        #! using torch from now on -----------------------------
        bfm_vert = self.to_tensor(bfm_vert)
        nsh_vert = self.transfers[face_model].transfer_shape_torch(bfm_vert)
        nsh_neu_vert = self.transfers[face_model].transfer_shape_torch(
            bfm_neu_vert)
        nsh_face_vert = nsh_vert[self.uv_creators[face_model].
                                 nsh_face_start_idx:]

        coeff = self.to_tensor(coeff[None])
        _, _, _, angles, _, translation = utils.split_bfm09_coeff(coeff)
        # angle = (angle / 180.0 * math.pi) if degrees else angle
        transformer = Transform3d(device=self.device)
        transformer = transformer.rotate_axis_angle(angles[:, 0],
                                                    self.rot_order[0], False)
        transformer = transformer.rotate_axis_angle(angles[:, 1],
                                                    self.rot_order[1], False)
        transformer = transformer.rotate_axis_angle(angles[:, 2],
                                                    self.rot_order[2], False)
        transformer = transformer.translate(translation)

        nsh_trans_vert = transformer.transform_points(nsh_face_vert[None])

        nsh_shift_vert = nsh_trans_vert[0] - self.to_tensor([[0, 0, 10]])
        image_segment = torch.flip(image_segment, (3, )).type(torch.float32)

        nsh_trans_mesh = Meshes(nsh_trans_vert,
                                self.nsh_face_tris[face_model][None])

        fragment = self.rasterizer(nsh_trans_mesh)
        visible_face = torch.unique(
            fragment.pix_to_face)[1:]  # exclude face id -1
        visible_vert = self.nsh_face_tris[face_model][visible_face]
        visible_vert = torch.unique(visible_vert)
        vert_alpha = torch.zeros([nsh_shift_vert.shape[0], 1],
                                 device=self.device)
        vert_alpha[visible_vert] = 1
        nsh_shift_vert_alpha = torch.cat([nsh_shift_vert, vert_alpha], axis=-1)

        uvmap = self.uv_creators[face_model].create_nsh_uv_torch(
            nsh_shift_vert_alpha, image_segment, self.config.uv_size)

        uvmap[..., 3] = uvmap[..., 3] + uvmap[..., 4] * 128
        uvmap = uvmap[..., :4].cpu().numpy()
        uvmap = self.test_dataset.process_uvmap(uvmap.astype(np.uint8),
                                                dark_brow=True)

        images = images.permute(0, 3, 1, 2) / 127.5 - 1.0
        images = F.interpolate(images,
                               size=self.config.im_size,
                               mode='bilinear',
                               align_corners=False)
        segments = F.interpolate(segments[:, None],
                                 size=self.config.im_size,
                                 mode='nearest')
        images = torch.cat([images, segments], dim=1)
        uvmaps = uvmap[None].permute(0, 3, 1, 2)

        return images, uvmaps, coeff, nsh_face_vert, nsh_neu_vert
Example #16
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the OpenGL orthographic projection matrix.
        Use column major order.

        Args:
            **kwargs: parameters for the projection can be passed in to
                      override the default values set in __init__.
        Return:
            P: a Transform3d object which represents a batch of projection
               matrices of shape (N, 4, 4)

        .. code-block:: python

            scale_x = 2/(right - left)
            scale_y = 2/(top - bottom)
            scale_z = 2/(far-near)
            mid_x = (right + left)/(right - left)
            mid_y = (top + bottom)/(top - bottom)
            mid_z = (far + near)/(far - near)

            P = [
                    [scale_x,        0,         0,  -mid_x],
                    [0,        scale_y,         0,  -mid_y],
                    [0,              0,  -scale_z,  -mid_z],
                    [0,              0,         0,       1],
            ]
        """
        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
        left = kwargs.get("left", self.left)  # pyre-ignore[16]
        right = kwargs.get("right", self.right)  # pyre-ignore[16]
        top = kwargs.get("top", self.top)  # pyre-ignore[16]
        bottom = kwargs.get("bottom", self.bottom)  # pyre-ignore[16]
        scale_xyz = kwargs.get("scale_xyz", self.scale_xyz)  # pyre-ignore[16]

        P = torch.zeros((self._N, 4, 4),
                        dtype=torch.float32,
                        device=self.device)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        # NOTE: OpenGL flips handedness of coordinate system between camera
        # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
        # right handed coordinate system throughout.
        z_sign = +1.0

        P[:, 0, 0] = (2.0 / (right - left)) * scale_xyz[:, 0]
        P[:, 1, 1] = (2.0 / (top - bottom)) * scale_xyz[:, 1]
        P[:, 0, 3] = -(right + left) / (right - left)
        P[:, 1, 3] = -(top + bottom) / (top - bottom)
        P[:, 3, 3] = ones

        # NOTE: This maps the z coordinate to the range [0, 1] and replaces
        # the OpenGL z normalization to [-1, 1].
        P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
        P[:, 2, 3] = -znear / (zfar - znear)

        # NOTE: This part of the matrix is for z renormalization in OpenGL.
        # The z is mapped to the range [-1, 1] but this won't work yet in
        # pytorch3d as the rasterizer ignores faces which have z < 0.
        # P[:, 2, 2] = z_sign * (2.0 / (far - near)) * scale[:, 2]
        # P[:, 2, 3] = -(far + near) / (far - near)

        transform = Transform3d(device=self.device)
        transform._matrix = P.transpose(1, 2).contiguous()
        return transform
Example #17
    def load(self, include_textures: bool) -> List[Tuple[Optional[str], Meshes]]:
        """
        Attempt to load all the meshes making up the default scene from
        the file as a list of possibly-named Meshes objects.

        Args:
            include_textures: Whether to try loading textures.

        Returns:
            A list of (name, Meshes) pairs, one entry per mesh primitive in the
            default scene; the name may be None.
        """
        if self._json_data is None:
            raise ValueError("Initialization problem")

        # This loads the default scene from the file.
        # This is usually the only one.
        # It is possible to have multiple scenes, in which case
        # you could choose another here instead of taking the default.
        scene_index = self._json_data.get("scene")

        if scene_index is None:
            raise ValueError("Default scene is not specified.")

        scene = self._json_data["scenes"][scene_index]
        nodes = self._json_data.get("nodes", [])
        meshes = self._json_data.get("meshes", [])
        root_node_indices = scene["nodes"]

        mesh_transform = Transform3d()
        names_meshes_list: List[Tuple[Optional[str], Meshes]] = []

        # Keep track and apply the transform of the scene node to mesh vertices
        Q = deque([(Transform3d(), node_index) for node_index in root_node_indices])

        while Q:
            parent_transform, current_node_index = Q.popleft()

            current_node = nodes[current_node_index]

            transform = _make_node_transform(current_node)
            current_transform = transform.compose(parent_transform)

            if "mesh" in current_node:
                mesh_index = current_node["mesh"]
                mesh = meshes[mesh_index]
                mesh_name = mesh.get("name", None)
                mesh_transform = current_transform

                for primitive in mesh["primitives"]:
                    attributes = primitive["attributes"]
                    accessor_index = attributes["POSITION"]
                    positions = torch.from_numpy(
                        self._access_data(accessor_index).copy()
                    )
                    positions = mesh_transform.transform_points(positions)

                    mode = primitive.get("mode", _PrimitiveMode.TRIANGLES)
                    if mode != _PrimitiveMode.TRIANGLES:
                        raise NotImplementedError("Non triangular meshes")

                    if "indices" in primitive:
                        accessor_index = primitive["indices"]
                        indices = self._access_data(accessor_index).astype(np.int64)
                    else:
                        indices = np.arange(0, len(positions), dtype=np.int64)
                    indices = torch.from_numpy(indices.reshape(-1, 3))

                    texture = None
                    if include_textures:
                        texture = self.get_texture_for_mesh(primitive, indices)

                    mesh_obj = Meshes(
                        verts=[positions], faces=[indices], textures=texture
                    )
                    names_meshes_list.append((mesh_name, mesh_obj))

            if "children" in current_node:
                children_node_indices = current_node["children"]
                Q.extend(
                    [
                        (current_transform, node_index)
                        for node_index in children_node_indices
                    ]
                )

        return names_meshes_list
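
The traversal composes each node's local transform with its parent via transform.compose(parent_transform), which (in PyTorch3D's row-vector convention) applies the local transform first and then the parent's, i.e. the usual scene-graph order. A minimal sketch with made-up transforms:

import torch
from pytorch3d.transforms import Transform3d

parent = Transform3d().translate(0.0, 0.0, 5.0)
local = Transform3d().scale(2.0)
current = local.compose(parent)       # local first, then parent
p = torch.tensor([[1.0, 0.0, 0.0]])
print(current.transform_points(p))    # tensor([[2., 0., 5.]])
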
Example #18
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the OpenGL perspective projection matrix with a symmetric
        viewing frustum. Use column major order.
        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.
        Return:
            P: a Transform3d object which represents a batch of projection
            matrices of shape (N, 4, 4)
        .. code-block:: python

            q = -(far + near)/(far - near)
            qn = -2*far*near/(far-near)
            P.T = [
                    [2*fx/w,     0,           0,  0],
                    [0,          -2*fy/h,     0,  0],
                    [(2*px-w)/w, (-2*py+h)/h, -q, 1],
                    [0,          0,           qn, 0],
                ]
            # sometimes P[2, :] *= -1, P[1, :] *= -1
        """
        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
        x0 = kwargs.get("x0", self.x0)  # pyre-ignore[16]
        y0 = kwargs.get("y0", self.y0)  # pyre-ignore[16]
        w = kwargs.get("w", self.w)  # pyre-ignore[16]
        h = kwargs.get("h", self.h)  # pyre-ignore[16]
        principal_point = kwargs.get(
            "principal_point", self.principal_point
        )  # pyre-ignore[16]
        focal_length = kwargs.get(
            "focal_length", self.focal_length
        )  # pyre-ignore[16]

        if not torch.is_tensor(focal_length):
            focal_length = torch.tensor(focal_length, device=self.device)

        if len(focal_length.shape) in (0, 1) or focal_length.shape[1] == 1:
            fx = fy = focal_length
        else:
            fx, fy = focal_length.unbind(1)

        if not torch.is_tensor(principal_point):
            principal_point = torch.tensor(principal_point, device=self.device)
        px, py = principal_point.unbind(1)

        P = torch.zeros(
            (self._N, 4, 4), device=self.device, dtype=torch.float32
        )
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)

        # NOTE: In OpenGL the projection matrix changes the handedness of the
        # coordinate frame. i.e. the NDC space positive z direction is the
        # camera space negative z direction. This is because the sign of the z
        # in the projection matrix is set to -1.0.
        # In pytorch3d we maintain a right handed coordinate system throughout
        # so the z sign is 1.0.
        z_sign = 1.0
        # define P.T directly
        P[:, 0, 0] = 2.0 * fx / w
        P[:, 1, 1] = -2.0 * fy / h
        P[:, 2, 0] = -(-2 * px + w + 2 * x0) / w
        P[:, 2, 1] = -(+2 * py - h + 2 * y0) / h
        P[:, 2, 3] = z_sign * ones

        # NOTE: This part of the matrix is for z renormalization in OpenGL
        # which maps the z to [-1, 1]. This won't work yet as the torch3d
        # rasterizer ignores faces which have z < 0.
        # P[:, 2, 2] = z_sign * (far + near) / (far - near)
        # P[:, 2, 3] = -2.0 * far * near / (far - near)
        # P[:, 2, 3] = z_sign * torch.ones((N))

        # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
        # is at the near clipping plane and z = 1 when the point is at the far
        # clipping plane. This replaces the OpenGL z normalization to [-1, 1]
        # until rasterization is changed to clip at z = -1.
        P[:, 2, 2] = z_sign * zfar / (zfar - znear)
        P[:, 3, 2] = -(zfar * znear) / (zfar - znear)

        # OpenGL uses column vectors so need to transpose the projection matrix
        # as torch3d uses row vectors.
        transform = Transform3d(device=self.device)
        transform._matrix = P
        return transform
Example #19
def train(model, criterion, optimizer, train_loader, val_loader, args):
    best_prec1 = 0
    epoch_no_improve = 0

    for epoch in range(1000):

        statistics = Statistics()
        model.train()
        start_time = time.time()

        for i, (input, target) in enumerate(train_loader):
            loss, (prec1, prec5), y_pred, y_true = execute_batch(
                model, criterion, input, target, args.device)

            statistics.update(loss.detach().cpu().numpy(), prec1, prec5,
                              y_pred, y_true)
            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if args.net_version == 2:
            #    model.camera_position = model.camera_position.clamp(0, 1)
            del loss
            torch.cuda.empty_cache()

        elapsed_time = time.time() - start_time

        # Evaluate on validation set
        val_statistics = validate(val_loader, model, criterion, args.device)

        log_data(statistics, "train", val_loader.dataset.dataset.classes,
                 epoch)
        log_data(val_statistics, "internal_val",
                 val_loader.dataset.dataset.classes, epoch)

        wandb.log({"Epoch elapsed time": elapsed_time}, step=epoch)
        # print(model.camera_position)
        if epoch % 1 == 0:
            vertices = []
            if args.net_version == 1:
                R = look_at_rotation(model.camera_position, device=args.device)
                T = -torch.bmm(R.transpose(1, 2),
                               model.camera_position[:, :, None])[:, :, 0]
            else:
                t = (Transform3d(device=model.device)
                     .scale(model.camera_position[3] * model.distance_range)
                     .rotate_axis_angle(model.camera_position[0] * model.angle_range,
                                        axis="X", degrees=False)
                     .rotate_axis_angle(model.camera_position[1] * model.angle_range,
                                        axis="Y", degrees=False)
                     .rotate_axis_angle(model.camera_position[2] * model.angle_range,
                                        axis="Z", degrees=False))

                vertices = t.transform_points(model.vertices)

                R = look_at_rotation(vertices[:model.nviews],
                                     device=model.device)
                T = -torch.bmm(R.transpose(1, 2), vertices[:model.nviews, :,
                                                           None])[:, :, 0]

            cameras = OpenGLPerspectiveCameras(R=R, T=T, device=args.device)
            wandb.log(
                {
                    "Cameras":
                    [wandb.Image(plot_camera_scene(cameras, args.device))]
                },
                step=epoch)
            plt.close()
            images = render_shape(model, R, T, args, vertices)
            wandb.log(
                {
                    "Views": [
                        wandb.Image(
                            image_grid(images,
                                       rows=int(np.ceil(args.nviews / 2)),
                                       cols=2))
                    ]
                },
                step=epoch)
            plt.close()
        #  Save best model and best prediction
        if val_statistics.top1.avg > best_prec1:
            best_prec1 = val_statistics.top1.avg
            save_model("views_net", model, optimizer, args.fname_best)
            epoch_no_improve = 0
        else:
            # Early stopping
            epoch_no_improve += 1
            if epoch_no_improve == 20:
                wandb.run.summary[
                    "best_internal_val_top1_accuracy"] = best_prec1
                wandb.run.summary[
                    "best_internal_val_top1_accuracy_epoch"] = epoch - 20

                return
Example #20
    def get_projection_transform(self, **kwargs) -> Transform3d:
        transform = Transform3d(device=self.device)
        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
        return transform
Example #21
    def __init__(self, device, params=Params(), template_mesh=None):
        super().__init__()

        self.device = device
        self.params = params
        self.mesh_scale = params.mesh_sphere_scale
        self.ico_level = params.mesh_sphere_level
        self.is_real_data = params.is_real_data
        self.init_pose_R = None
        self.init_pose_t = None

        # Create a source mesh
        if not template_mesh:
            template_mesh = ico_sphere(params.mesh_sphere_level, device)
            template_mesh.scale_verts_(params.mesh_sphere_scale)

        # For EVIMO data, a delta transform is needed to adjust poses from the EVIMO
        # coordinate system to the PyTorch3D system. Since we don't know the initial
        # transform, we optimize the initial pose as a parameter while rendering the
        # mesh. Initialize the delta transform here.
        if params.is_real_data:
            init_trans = Transform3d(device=device)
            R_init = init_trans.get_matrix()[:, :3, :3]
            qua_init = matrix_to_quaternion(R_init)
            random_noise = (torch.randn(qua_init.shape) /
                            params.mesh_pose_init_noise_var).to(self.device)
            qua_init += random_noise

            t_init = init_trans.get_matrix()[:, 3:, :3]
            random_noise_t = (torch.randn(t_init.shape) /
                              params.mesh_pose_init_noise_var).to(self.device)
            t_init += random_noise_t

            self.register_parameter('init_camera_R',
                                    nn.Parameter(qua_init).to(self.device))
            self.register_parameter('init_camera_t',
                                    nn.Parameter(t_init).to(self.device))

        verts, faces = template_mesh.get_mesh_verts_faces(0)
        # Initialize each vertex with a constant white vertex texture
        verts_rgb = torch.ones_like(verts)[None]
        textures = TexturesVertex(verts_rgb.to(self.device))
        self.template_mesh = Meshes(
            verts=[verts.to(self.device)],
            faces=[faces.to(self.device)],
            textures=textures,
        )

        self.register_buffer("vertices", self.template_mesh.verts_padded())
        self.register_buffer("faces", self.template_mesh.faces_padded())
        self.register_buffer("textures", textures.verts_features_padded())

        deform_verts = torch.zeros_like(self.template_mesh.verts_packed(),
                                        device=device,
                                        requires_grad=True)
        # Create an optimizable parameter for the mesh
        self.register_parameter("deform_verts",
                                nn.Parameter(deform_verts).to(self.device))

        # Create optimizer
        self.optimizer = self.params.mesh_optimizer(
            self.parameters(),
            lr=self.params.mesh_learning_rate,
            betas=self.params.mesh_betas)

        self.losses = {"iou": [], "laplacian": [], "flatten": []}

        # Create a silhouette_renderer
        self.renderer = silhouette_renderer(self.params.img_size, device)
Example #22
    smax = 2.0
    srange = smax - smin
    scale = (torch.rand(1).squeeze() * srange + smin).item()

    # Generate a random NDC coordinate https://pytorch3d.org/docs/cameras
    x, y, d = torch.rand(3)
    x = x * 2.0 - 1.0
    y = y * 2.0 - 1.0
    trans = torch.Tensor([x, y, d]).to(device)
    trans = cameras.unproject_points(trans.unsqueeze(0),
                                     world_coordinates=False,
                                     scaled_depth_input=True)[0]
    rot = random_rotations(1)[0].to(device)

    transform = Transform3d() \
        .scale(scale) \
        .compose(Rotate(rot)) \
        .translate(*trans)

    # TODO: transform mesh
    # Create a phong renderer by composing a rasterizer and a shader. The textured phong shader will
    # interpolate the texture uv coordinates for each vertex, sample from a texture image and
    # apply the Phong lighting model
    renderer = MeshRenderer(rasterizer=MeshRasterizer(
        cameras=cameras, raster_settings=raster_settings),
                            shader=SoftPhongShader(
                                device=device,
                                cameras=cameras,
                                lights=lights,
                            ))
    images = renderer(mesh.scale_verts(scale),
                      R=rot.unsqueeze(0),
Example #23
    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
        Calculate the perspective projection matrix with a symmetric
        viewing frustum. Use column major order.
        The viewing frustum will be projected into NDC, such that
        (max_x, max_y) -> (+1, +1)
        (min_x, min_y) -> (-1, -1)

        Args:
            **kwargs: parameters for the projection can be passed in as keyword
                arguments to override the default values set in `__init__`.

        Return:
            P: a Transform3d object which represents a batch of projection
            matrices of shape (N, 4, 4)

        .. code-block:: python

            f1 = -(far + near)/(far - near)
            f2 = -2*far*near/(far-near)
            h1 = (max_y + min_y)/(max_y - min_y)
            w1 = (max_x + min_x)/(max_x - min_x)
            tanhalffov = tan((fov/2))
            s1 = 1/tanhalffov
            s2 = 1/(tanhalffov * (aspect_ratio))

            P = [
                    [s1,   0,   w1,   0],
                    [0,   s2,   h1,   0],
                    [0,    0,   f1,  f2],
                    [0,    0,    1,   0],
            ]
        """
        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
        fov = kwargs.get("fov", self.fov)  # pyre-ignore[16]
        # pyre-ignore[16]
        aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio)
        degrees = kwargs.get("degrees", self.degrees)

        P = torch.zeros((self._N, 4, 4),
                        device=self.device,
                        dtype=torch.float32)
        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        if degrees:
            fov = (np.pi / 180) * fov

        if not torch.is_tensor(fov):
            fov = torch.tensor(fov, device=self.device)
        tanHalfFov = torch.tan((fov / 2))
        max_y = tanHalfFov * znear
        min_y = -max_y
        max_x = max_y * aspect_ratio
        min_x = -max_x

        # NOTE: In OpenGL the projection matrix changes the handedness of the
        # coordinate frame. i.e. the NDC space positive z direction is the
        # camera space negative z direction. This is because the sign of the z
        # in the projection matrix is set to -1.0.
        # In pytorch3d we maintain a right handed coordinate system throughout
        # so the z sign is 1.0.
        z_sign = 1.0

        P[:, 0, 0] = 2.0 * znear / (max_x - min_x)
        P[:, 1, 1] = 2.0 * znear / (max_y - min_y)
        P[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
        P[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
        P[:, 3, 2] = z_sign * ones

        # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
        # is at the near clipping plane and z = 1 when the point is at the far
        # clipping plane.
        P[:, 2, 2] = z_sign * zfar / (zfar - znear)
        P[:, 2, 3] = -(zfar * znear) / (zfar - znear)

        # Transpose the projection matrix as PyTorch3d transforms use row vectors.
        transform = Transform3d(device=self.device)
        transform._matrix = P.transpose(1, 2).contiguous()
        return transform