Example #1
File: shading.py Project: zvict/pytorch3d
def phong_shading(meshes, fragments, lights, cameras, materials,
                  texels) -> torch.Tensor:
    """
    Apply per pixel shading. First interpolate the vertex normals and
    vertex coordinates using the barycentric coordinates to get the position
    and normal at each pixel. Then compute the illumination for each pixel.
    The pixel color is obtained by multiplying the pixel textures by the ambient
    and diffuse illumination and adding the specular component.

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights
        cameras: Cameras class containing a batch of cameras
        materials: Materials class containing a batch of material properties
        texels: texture per pixel of shape (N, H, W, K, 3)

    Returns:
        colors: (N, H, W, K, 3)
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords = interpolate_face_attributes(fragments.pix_to_face,
                                               fragments.bary_coords,
                                               faces_verts)
    pixel_normals = interpolate_face_attributes(fragments.pix_to_face,
                                                fragments.bary_coords,
                                                faces_normals)
    ambient, diffuse, specular = _apply_lighting(pixel_coords, pixel_normals,
                                                 lights, cameras, materials)
    colors = (ambient + diffuse) * texels + specular
    return colors
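The example above relies on the shape contract of `interpolate_face_attributes`: per-face vertex attributes of shape (F, 3, D) are blended with the barycentric coordinates to give one D-dimensional value per (pixel, face) entry. Below is a minimal, self-contained sketch of that contract using toy tensors only; it assumes the function is importable from `pytorch3d.ops`, and the sizes and values are purely illustrative.

import torch
from pytorch3d.ops import interpolate_face_attributes

# Toy sizes: 1 image, 4x4 pixels, K=1 face per pixel, F_=2 faces, D=3 attribute dims.
N, H, W, K, F_, D = 1, 4, 4, 1, 2, 3
pix_to_face = torch.randint(-1, F_, (N, H, W, K))              # -1 marks background pixels
bary_coords = torch.rand(N, H, W, K, 3)
bary_coords = bary_coords / bary_coords.sum(-1, keepdim=True)  # barycentric weights sum to 1
face_attrs = torch.rand(F_, 3, D)                               # one attribute per face vertex

pixel_attrs = interpolate_face_attributes(pix_to_face, bary_coords, face_attrs)
print(pixel_attrs.shape)  # torch.Size([1, 4, 4, 1, 3]); background entries are zeroed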
Example #2
def neural_shading(
    meshes, fragments, lights, cameras, pixels_uv, texels, NN
) -> torch.Tensor:
    # Interpolate per-pixel positions and normals from the packed mesh, then
    # delegate the shading itself to a learned model (NN) via _neural_shading.
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )
    return _neural_shading(
        pixel_coords, pixel_normals, lights, cameras, pixels_uv, NN
    )
Example #3
    def sample_textures(self, fragments, faces_packed=None) -> torch.Tensor:
        """
        Determine the color for each rasterized face. Interpolate the colors for
        vertices which form the face using the barycentric coordinates.
        Args:
            fragments:
                The outputs of rasterization. From this we use

                - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
                of the faces (in the packed representation) which
                overlap each pixel in the image.
                - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
                the barycentric coordinates of each pixel
                relative to the faces (in the packed
                representation) which overlap the pixel.

        Returns:
            texels: A texture per pixel of shape (N, H, W, K, C).
            There will be one C dimensional value for each element in
            fragments.pix_to_face.
        """
        verts_features_packed = self.verts_features_packed()
        faces_verts_features = verts_features_packed[faces_packed]

        texels = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, faces_verts_features
        )
        return texels
Example #4
def _interpolate_zbuf(pix_to_face: torch.Tensor,
                      barycentric_coords: torch.Tensor,
                      meshes) -> torch.Tensor:
    """
    A helper function to calculate the z buffer for each pixel in the
    rasterized output.

    Args:
        pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
            of the faces (in the packed representation) which
            overlap each pixel in the image.
        barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
            the barycentric coordinates of each pixel
            relative to the faces (in the packed
            representation) which overlap the pixel.
        meshes: Meshes object representing a batch of meshes.

    Returns:
        zbuffer: (N, H, W, K) FloatTensor
    """
    verts = meshes.verts_packed()
    faces = meshes.faces_packed()
    faces_verts_z = verts[faces][..., 2][..., None]  # (F, 3, 1)
    zbuf = interpolate_face_attributes(pix_to_face, barycentric_coords,
                                       faces_verts_z)[..., 0]  # (N, H, W, K)
    zbuf[pix_to_face == -1] = -1
    return zbuf
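A hedged usage sketch for the helper above: given `fragments` and `meshes` from a PyTorch3D rasterizer, as in the other examples, a single-channel depth image can be read off the closest overlapping face per pixel. The variable names are illustrative only.

# Hypothetical usage: build a depth image from the nearest face at each pixel.
zbuf = _interpolate_zbuf(fragments.pix_to_face, fragments.bary_coords, meshes)  # (N, H, W, K)
depth_image = zbuf[..., 0]  # (N, H, W): z of the closest face; background pixels stay at -1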
Example #5
def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
    """
    Determine the color for each rasterized face. Interpolate the colors for
    vertices which form the face using the barycentric coordinates.
    Args:
        meshes: A Meshes class representing a batch of meshes.
        fragments:
            The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.

    Returns:
        texels: A texture per pixel of shape (N, H, W, K, C).
        There will be one C dimensional value for each element in
        fragments.pix_to_face.
    """
    vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3)  # (V, C)
    vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
    faces_packed = meshes.faces_packed()
    faces_textures = vertex_textures[faces_packed]  # (F, 3, C)
    texels = interpolate_face_attributes(fragments.pix_to_face,
                                         fragments.bary_coords, faces_textures)
    return texels
Example #6
    def forward(self, fragments, meshes, **kwargs):
        colors = interpolate_face_attributes(fragments.pix_to_face,
                                             fragments.bary_coords,
                                             self.face_colors)
        blend_params = BlendParams(sigma=1e-4, gamma=self.blend_gamma)
        imgs = soft_feature_blending(colors,
                                     fragments,
                                     blend_params=blend_params)
        return imgs
Example #7
def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tensor:
    """
    Apply per vertex shading. First compute the vertex illumination by applying
    ambient, diffuse and specular lighting. If vertex color is available,
    combine the ambient and diffuse vertex illumination with the vertex color
    and add the specular component to determine the vertex shaded color.
    Then interpolate the vertex shaded colors using the barycentric coordinates
    to get a color per pixel.

    Gouraud shading is only supported for meshes with texture type `TexturesVertex`.
    This is because the illumination is applied to the vertex colors.

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights parameters
        cameras: Cameras class containing a batch of cameras parameters
        materials: Materials class containing a batch of material properties

    Returns:
        colors: (N, H, W, K, 3)
    """
    if not isinstance(meshes.textures, TexturesVertex):
        raise ValueError("Mesh textures must be an instance of TexturesVertex")

    faces = meshes.faces_packed()  # (F, 3)
    verts = meshes.verts_packed()  # (V, 3)
    verts_normals = meshes.verts_normals_packed()  # (V, 3)
    verts_colors = meshes.textures.verts_features_packed()  # (V, D)
    vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx()

    # Format properties of lights and materials so they are compatible
    # with the packed representation of the vertices. This transforms
    # all tensor properties in the class from shape (N, ...) -> (V, ...) where
    # V is the number of packed vertices. If the number of meshes in the
    # batch is one then this is not necessary.
    if len(meshes) > 1:
        lights = lights.clone().gather_props(vert_to_mesh_idx)
        cameras = cameras.clone().gather_props(vert_to_mesh_idx)
        materials = materials.clone().gather_props(vert_to_mesh_idx)

    # Calculate the illumination at each vertex
    ambient, diffuse, specular = _apply_lighting(
        verts, verts_normals, lights, cameras, materials
    )

    verts_colors_shaded = verts_colors * (ambient + diffuse) + specular
    face_colors = verts_colors_shaded[faces]
    colors = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, face_colors
    )
    return colors
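Since the docstring above restricts Gouraud shading to `TexturesVertex`, here is a minimal sketch of building a mesh batch that satisfies that check: a single triangle with per-vertex colors. The imports and constructor arguments are the standard PyTorch3D API; the values are illustrative, and the remaining shading arguments are assumed to come from a rasterizer and renderer setup as in the other examples.

import torch
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex

verts = torch.tensor([[[0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]])  # (N=1, V=3, 3)
faces = torch.tensor([[[0, 1, 2]]])                                           # (N=1, F=1, 3)
colors = torch.ones_like(verts)                                               # (N, V, 3) per-vertex RGB
meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(verts_features=colors))
# gouraud_shading(meshes, fragments, lights, cameras, materials) now passes the
# isinstance(meshes.textures, TexturesVertex) check above.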
Example #8
    def get_view_to_texture_map(self, meshes_world, **kwargs):
        fragments = self.rasterizer(meshes_world, **kwargs)
        tex = meshes_world.textures
        tex_map = tex.maps_padded()

        packing_list = [
            i[j] for i, j in zip(tex.verts_uvs_list(), tex.faces_uvs_list())
        ]
        faces_verts_uvs = torch.cat(packing_list)

        pixel_uvs = interpolate_face_attributes(fragments.pix_to_face,
                                                fragments.bary_coords,
                                                faces_verts_uvs)

        return pixel_uvs
Example #9
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """
        LightfieldShader just returns interpolated vertex coordinates for a plane.
        """
        # Get the packed faces and vertices from the mesh.
        faces = meshes.faces_packed()  # (F, 3)
        verts = meshes.verts_packed()  # (V, 3)
        faces_verts = verts[faces]
        Nv, H_out, W_out, K = fragments.pix_to_face.shape

        # Only the closest face per pixel is used, so restrict K to 1.
        K = 1
        # pixel_verts: (Nv, H, W, K=1, 3) -> (Nv, K=1, H, W, 3) -> (Nv*K, H, W, 3)
        pixel_verts = interpolate_face_attributes(
            fragments.pix_to_face[:, :, :, 0:K],
            fragments.bary_coords[:, :, :, 0:K, :],
            faces_verts,
        )
        pixel_verts = pixel_verts.permute(0, 3, 1, 2, 4).reshape(Nv * K, H_out, W_out, 3)

        return pixel_verts
Example #10
    def sample_pixel_uvs(self, fragments, **kwargs) -> torch.Tensor:
        """
        Copied from super().sample_textures and adapted to output pixel_uvs instead of the sampled texture.

        Args:
            fragments:
                The outputs of rasterization. From this we use

                - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
                of the faces (in the packed representation) which
                overlap each pixel in the image.
                - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
                the barycentric coordinates of each pixel
                relative to the faces (in the packed
                representation) which overlap the pixel.

        Returns:
            texels: tensor of shape (N, H, W, K, C) giving the interpolated
            texture for each pixel in the rasterized image.
        """
        if self.isempty():
            faces_verts_uvs = torch.zeros((self._N, 3, 2),
                                          dtype=torch.float32,
                                          device=self.device)
        else:
            packing_list = [
                i[j]
                for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
            ]
            faces_verts_uvs = torch.cat(packing_list)
            # Each face has 3 vertices, each with a (u, v) coordinate: (F, 3, 2)
        # pixel_uvs: (N, H, W, K, 2)
        pixel_uvs = interpolate_face_attributes(fragments.pix_to_face,
                                                fragments.bary_coords,
                                                faces_verts_uvs)

        N, H_out, W_out, K = fragments.pix_to_face.shape
        # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
        pixel_uvs = pixel_uvs.permute(0, 3, 1, 2,
                                      4).reshape(N * K, H_out, W_out, 2)
        pixel_uvs = pixel_uvs * 2.0 - 1.0
        return pixel_uvs
Example #11
def debug_shading(meshes, fragments) -> torch.Tensor:
    # Early return for debugging: visualize the raw barycentric coordinates.
    return fragments.bary_coords
    # Unreachable alternative output: interpolated per-pixel normals.
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_normals = vertex_normals[faces]
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )
    return pixel_normals


def flat_shading(meshes, fragments, lights, cameras, materials, texels) -> torch.Tensor:
    # Note: this signature is an assumption; the original snippet only contained the
    # body below, which computes per-face (flat) shading from face centers and normals.
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    face_normals = meshes.faces_normals_packed()  # (F, 3)
    faces_verts = verts[faces]
    face_coords = faces_verts.mean(dim=-2)  # (F, 3) mean xyz across verts

    # Replace empty pixels in pix_to_face with 0 in order to interpolate.
    mask = fragments.pix_to_face == -1
    pix_to_face = fragments.pix_to_face.clone()
    pix_to_face[mask] = 0

    N, H, W, K = pix_to_face.shape
    idx = pix_to_face.view(N * H * W * K, 1).expand(N * H * W * K, 3)

    # Gather pixel coords.
    pixel_coords = face_coords.gather(0, idx).view(N, H, W, K, 3)
    pixel_coords[mask] = 0.0
    # Gather pixel normals.
    pixel_normals = face_normals.gather(0, idx).view(N, H, W, K, 3)
    pixel_normals[mask] = 0.0

    # Calculate the illumination at each face.
    ambient, diffuse, specular = _apply_lighting(
        pixel_coords, pixel_normals, lights, cameras, materials
    )
    colors = (ambient + diffuse) * texels + specular
    return colors
Example #12
    def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
        """
        Interpolate a 2D texture map using uv vertex texture coordinates for each
        face in the mesh. First interpolate the vertex uvs using barycentric coordinates
        for each pixel in the rasterized output. Then interpolate the texture map
        using the uv coordinate for each pixel.

        Args:
            fragments:
                The outputs of rasterization. From this we use

                - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
                of the faces (in the packed representation) which
                overlap each pixel in the image.
                - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
                the barycentric coordinates of each pixel
                relative to the faces (in the packed
                representation) which overlap the pixel.

        Returns:
            texels: tensor of shape (N, H, W, K, C) giving the interpolated
            texture for each pixel in the rasterized image.
        """
        if self.isempty():
            faces_verts_uvs = torch.zeros(
                (self._N, 3, 2), dtype=torch.float32, device=self.device
            )
        else:
            packing_list = [
                i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
            ]
            faces_verts_uvs = torch.cat(packing_list)
        texture_maps = self.maps_padded()

        # pixel_uvs: (N, H, W, K, 2)
        pixel_uvs = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
        )

        N, H_out, W_out, K = fragments.pix_to_face.shape
        N, H_in, W_in, C = texture_maps.shape  # 3 for RGB

        # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
        pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)

        # textures.map:
        #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
        #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
        texture_maps = (
            texture_maps.permute(0, 3, 1, 2)[None, ...]
            .expand(K, -1, -1, -1, -1)
            .transpose(0, 1)
            .reshape(N * K, C, H_in, W_in)
        )

        # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
        # Now need to format the pixel uvs and the texture map correctly!
        # From pytorch docs, grid_sample takes `grid` and `input`:
        #   grid specifies the sampling pixel locations normalized by
        #   the input spatial dimensions. It should have most
        #   values in the range of [-1, 1]. Values x = -1, y = -1
        #   is the left-top pixel of input, and values x = 1, y = 1 is the
        #   right-bottom pixel of input.

        pixel_uvs = pixel_uvs * 2.0 - 1.0

        texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
        if texture_maps.device != pixel_uvs.device:
            texture_maps = texture_maps.to(pixel_uvs.device)
        texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
        # texels now has shape (NK, C, H_out, W_out)
        texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
        return texels
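The comment block above about `grid_sample` conventions can be checked with a toy texture: uv coordinates in [0, 1] are rescaled to [-1, 1], and the texture's vertical axis is flipped so that v = 0 addresses the bottom row of the map. A small self-contained sketch with plain PyTorch; the values are illustrative only.

import torch
import torch.nn.functional as F

tex = torch.arange(4.0).view(1, 2, 2, 1)             # (N, H, W, C) texture: rows [[0, 1], [2, 3]]
uv = torch.tensor([[[[0.25, 0.25]]]])                 # (N, H_out, W_out, 2): one pixel at uv = (0.25, 0.25)

grid = uv * 2.0 - 1.0                                 # map the [0, 1] uv range to grid_sample's [-1, 1]
tex_chw = torch.flip(tex.permute(0, 3, 1, 2), [2])    # (N, C, H, W), flip the v (height) axis
texel = F.grid_sample(tex_chw, grid, align_corners=False)
print(texel.shape)                                    # (1, 1, 1, 1): one bilinearly sampled value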
Example #13
File: test.py Project: sanjeevmk/Woodhouse
def main(cfg: DictConfig):
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    obj_path = cfg.data.obj_path
    texture_path = cfg.data.texture_path
    views_folder = cfg.data.views_folder
    params_file = os.path.join(views_folder, "params.json")
    dataset = CowMultiViews(obj_path,
                            views_folder,
                            texture_path,
                            params_file=params_file)

    train_dataset, validation_dataset, test_dataset = CowMultiViews.random_split_dataset(
        dataset, train_fraction=0.7, validation_fraction=0.2)

    del dataset
    train_dataset.unit_normalize()
    validation_dataset.unit_normalize()
    test_dataset.unit_normalize()

    mesh_verts = test_dataset.get_verts()
    mesh_edges = test_dataset.get_edges()
    mesh_vert_normals = test_dataset.get_vert_normals()
    mesh_texture = test_dataset.get_texture()
    pytorch_mesh = test_dataset.pytorch_mesh.cuda()
    face_attrs = test_dataset.get_faces_as_vertex_matrices()

    feature_size = test_dataset.param_vectors.shape[1]

    torch_verts = torch.from_numpy(np.array(mesh_verts)).float().cuda()
    torch_edges = torch.from_numpy(np.array(mesh_edges)).long().cuda()
    torch_normals = torch.from_numpy(
        np.array(mesh_vert_normals)).float().cuda()
    torch_texture = torch.from_numpy(np.array(mesh_texture)).float().cuda()
    torch_texture = torch.unsqueeze(torch_texture.permute(2, 0, 1), 0)
    torch_face_attrs = torch.from_numpy(np.array(face_attrs)).float().cuda()

    subset_indices = [82]  #random.sample(list(range(len(test_dataset))),1)
    test_dataloader = Subset(test_dataset, subset_indices)
    print(subset_indices, len(test_dataloader))

    image_translator = ImageTranslator(input_dim=6,
                                       output_dim=3,
                                       image_size=tuple(
                                           cfg.data.image_size)).cuda()

    mse_loss = torch.nn.MSELoss()

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        image_translator.parameters(),
        lr=cfg.optimizer.lr,
    )

    stats = None
    start_epoch = 0
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(),
                                   cfg.checkpoint_path)

    # Init the stats object.
    if stats is None:
        stats = Stats(["mse_loss", "sec/it"], )

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma**(
            epoch / cfg.optimizer.lr_scheduler_step_size)

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda,
                                                     last_epoch=start_epoch - 1,
                                                     verbose=False)

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    loaded_data = torch.load(checkpoint_path)

    image_translator.load_state_dict(loaded_data["model"], strict=False)
    image_translator.eval()
    stats.new_epoch()

    image_list = []
    for iteration, data in enumerate(test_dataloader):
        print(iteration)
        optimizer.zero_grad()

        views, param_vectors = data
        views = torch.unsqueeze(torch.from_numpy(views), 0)
        param_vectors = torch.unsqueeze(torch.from_numpy(param_vectors), 0)
        views = views.float().cuda()
        param_vectors = param_vectors.float().cuda()
        camera_instance = Camera()
        camera_instance.lookAt(param_vectors[0][0],
                               math.degrees(param_vectors[0][1]),
                               math.degrees(param_vectors[0][2]))

        rasterizer_instance = Rasterizer()
        rasterizer_instance.init_rasterizer(camera_instance.camera)
        fragments = rasterizer_instance.rasterizer(pytorch_mesh)
        pix_to_face = fragments.pix_to_face
        bary_coords = fragments.bary_coords

        pix_features = torch.squeeze(
            interpolate_face_attributes(pix_to_face, bary_coords,
                                        torch_face_attrs), 3)
        param_matrix = torch.zeros(pix_features.size()[0],
                                   pix_features.size()[1],
                                   pix_features.size()[2],
                                   param_vectors.size()[1]).float().cuda()
        param_matrix[:, :, :, :] = param_vectors
        image_features = pix_features  # torch.cat([pix_features,param_matrix],3)
        predicted_render = image_translator(image_features, torch_texture)

        image_list = [
            views[0].permute(2, 0, 1), predicted_render[0].permute(2, 0, 1)
        ]

    if viz is not None:
        visualize_image_outputs(validation_images=image_list,
                                viz=viz,
                                visdom_env=cfg.visualization.visdom_env)
Example #14
    def forward(self, fragments, meshes, **kwargs):
        colors = interpolate_face_attributes(fragments.pix_to_face,
                                             fragments.bary_coords,
                                             self.face_colors)
        return colors[:, :, :, 0]
Example #15
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
    """
    Interpolate a 2D texture map using uv vertex texture coordinates for each
    face in the mesh. First interpolate the vertex uvs using barycentric coordinates
    for each pixel in the rasterized output. Then interpolate the texture map
    using the uv coordinate for each pixel.

    Args:
        fragments:
            The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.
        meshes: Meshes representing a batch of meshes. It is expected that
            meshes has a textures attribute which is an instance of the
            Textures class.

    Returns:
        texels: tensor of shape (N, H, W, K, C) giving the interpolated
        texture for each pixel in the rasterized image.
    """
    if not isinstance(meshes.textures, Textures):
        msg = "Expected meshes.textures to be an instance of Textures; got %r"
        raise ValueError(msg % type(meshes.textures))

    faces_uvs = meshes.textures.faces_uvs_packed()
    verts_uvs = meshes.textures.verts_uvs_packed()
    faces_verts_uvs = verts_uvs[faces_uvs]
    texture_maps = meshes.textures.maps_padded()

    # pixel_uvs: (N, H, W, K, 2)
    pixel_uvs = interpolate_face_attributes(fragments.pix_to_face,
                                            fragments.bary_coords,
                                            faces_verts_uvs)

    N, H_out, W_out, K = fragments.pix_to_face.shape
    N, H_in, W_in, C = texture_maps.shape  # 3 for RGB

    # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2,
                                  4).reshape(N * K, H_out, W_out, 2)

    # textures.map:
    #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
    #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
    texture_maps = (texture_maps.permute(0, 3, 1, 2)[None, ...].expand(
        K, -1, -1, -1, -1).transpose(0, 1).reshape(N * K, C, H_in, W_in))

    # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
    # Now need to format the pixel uvs and the texture map correctly!
    # From pytorch docs, grid_sample takes `grid` and `input`:
    #   grid specifies the sampling pixel locations normalized by
    #   the input spatial dimensions. It should have most
    #   values in the range of [-1, 1]. Values x = -1, y = -1
    #   is the left-top pixel of input, and values x = 1, y = 1 is the
    #   right-bottom pixel of input.

    pixel_uvs = pixel_uvs * 2.0 - 1.0
    texture_maps = torch.flip(texture_maps,
                              [2])  # flip y axis of the texture map
    if texture_maps.device != pixel_uvs.device:
        texture_maps = texture_maps.to(pixel_uvs.device)
    texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
    texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
    return texels
Example #16
def main(cfg: DictConfig):
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    obj_path = cfg.data.obj_path
    texture_path = cfg.data.texture_path
    views_folder = cfg.data.views_folder
    params_file = os.path.join(views_folder,"params.json")
    dataset = CowMultiViews(obj_path,views_folder,texture_path,params_file=params_file)

    train_dataset, validation_dataset, test_dataset = CowMultiViews.random_split_dataset(dataset,
                                                                                         train_fraction=0.7,
                                                                                         validation_fraction=0.2)

    del dataset
    train_dataset.unit_normalize()
    validation_dataset.unit_normalize()

    mesh_verts = train_dataset.get_verts()
    mesh_edges = train_dataset.get_edges()
    mesh_vert_normals = train_dataset.get_vert_normals()
    mesh_texture = train_dataset.get_texture()
    pytorch_mesh = train_dataset.pytorch_mesh.cuda()

    random_face_attrs = train_dataset.get_faces_as_vertex_matrices(features_list=['random'],num_random_dims=cfg.training.feature_dim)
    coord_face_attrs = train_dataset.get_faces_as_vertex_matrices(features_list=['coord'],num_random_dims=cfg.training.feature_dim)
    normal_face_attrs = train_dataset.get_faces_as_vertex_matrices(features_list=['normal'],num_random_dims=cfg.training.feature_dim)

    torch_verts = torch.from_numpy(np.array(mesh_verts)).float().cuda()
    torch_edges = torch.from_numpy(np.array(mesh_edges)).long().cuda()
    torch_normals = torch.from_numpy(np.array(mesh_vert_normals)).float().cuda()
    torch_texture = torch.from_numpy(np.array(mesh_texture)).float().cuda()
    torch_texture = torch.unsqueeze(torch_texture,0)
    torch_random_face_attrs = torch.tensor(np.array(random_face_attrs),requires_grad=True).float().cuda()
    torch_random_face_attrs = torch.nn.Parameter(torch_random_face_attrs)
    torch_coord_face_attrs = torch.tensor(np.array(coord_face_attrs)).float().cuda()
    torch_normal_face_attrs = torch.tensor(np.array(normal_face_attrs)).float().cuda()

    train_dataloader = DataLoader(train_dataset,batch_size=cfg.training.batch_size,shuffle=True,num_workers=4)
    validation_dataloader = DataLoader(validation_dataset,batch_size=cfg.training.batch_size,shuffle=True,num_workers=4)

    image_translator = ImageTranslator(input_dim=cfg.training.feature_dim+9,output_dim=3,
                                   image_size=tuple(cfg.data.image_size)).cuda()

    mse_loss = torch.nn.MSELoss()

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        list(image_translator.parameters())+[torch_random_face_attrs],
        lr=cfg.optimizer.lr,
    )

    stats = None
    start_epoch = 0
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["mse_loss", "sec/it"],
        )

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
                epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    for epoch in range(cfg.optimizer.max_epochs):
        image_translator.train()
        stats.new_epoch()
        for iteration,data in enumerate(train_dataloader):
            optimizer.zero_grad()

            views,param_vectors = data
            views = views.float().cuda()
            param_vectors = param_vectors.float().cuda()
            camera_instance = Camera()
            camera_instance.lookAt(param_vectors[0][0],math.degrees(param_vectors[0][1]),math.degrees(param_vectors[0][2]))
            camera_location = param_vectors[0,3:6]
            light_location = param_vectors[0,6:9]
            torch_camera_face_attrs = torch_coord_face_attrs - camera_location
            torch_light_face_attrs = torch_coord_face_attrs - light_location
            torch_face_attrs = torch.cat([torch_camera_face_attrs,torch_normal_face_attrs,torch_light_face_attrs,torch_random_face_attrs],2)

            rasterizer_instance = Rasterizer()
            rasterizer_instance.init_rasterizer(camera_instance.camera)
            fragments = rasterizer_instance.rasterizer(pytorch_mesh)
            pix_to_face = fragments.pix_to_face
            bary_coords = fragments.bary_coords

            pix_features = torch.squeeze(interpolate_face_attributes(pix_to_face,bary_coords,torch_face_attrs),3)
            predicted_render = image_translator(pix_features,torch_texture)

            loss = 1000*mse_loss(predicted_render,views)
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"mse_loss": float(loss)},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

        # Adjust the learning rate.
        #lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0: # and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(iter(validation_dataloader))
            views, param_vectors= val_batch
            views = views.float().cuda()
            param_vectors = param_vectors.float().cuda()

            # Activate eval mode of the model (allows to do a full rendering pass).
            image_translator.eval()
            with torch.no_grad():
                camera_instance = Camera()
                camera_instance.lookAt(param_vectors[0][0], math.degrees(param_vectors[0][1]), math.degrees(param_vectors[0][2]))
                camera_location = param_vectors[0,3:6]
                light_location = param_vectors[0,6:9]
                torch_camera_face_attrs = torch_coord_face_attrs - camera_location
                torch_light_face_attrs = torch_coord_face_attrs - light_location
                torch_face_attrs = torch.cat([torch_camera_face_attrs,torch_normal_face_attrs,torch_light_face_attrs,torch_random_face_attrs],2)

                rasterizer_instance = Rasterizer()
                rasterizer_instance.init_rasterizer(camera_instance.camera)
                fragments = rasterizer_instance.rasterizer(pytorch_mesh)
                pix_to_face = fragments.pix_to_face
                bary_coords = fragments.bary_coords

                pix_features = torch.squeeze(interpolate_face_attributes(pix_to_face, bary_coords, torch_face_attrs), 3)
                #pix_features = pix_features.permute(0, 3, 1, 2)
                predicted_render = image_translator(pix_features,torch_texture)
                loss = 1000*mse_loss(predicted_render,views)


            # Update stats with the validation metrics.
            stats.update({"mse_loss":loss}, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot that loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                render_max = torch.max(predicted_render)
                visualize_image_outputs(
                    validation_images = [views[0].permute(2,0,1),predicted_render[0].permute(2,0,1)],viz=viz,visdom_env=cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            image_translator.train()

        # Checkpoint.
        if (
                epoch % cfg.checkpoint_epoch_interval == 0
                and len(cfg.checkpoint_path) > 0
                and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": image_translator.state_dict(),
                "features" : torch_face_attrs,
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)