コード例 #1
0
class ToyNeuralGraphicsDataset(data.Dataset):
    """Synthetic dataset of randomly posed meshes rendered from a fixed camera.

    Every ``.obj`` file in a directory is loaded once. Each item picks one of
    those meshes, applies a random uniform scale, a random rotation and a
    random translation, renders it with a flat-shaded rasterizer, and returns
    the rendered RGB image together with the 4x4 transform that produced it.

    Indices wrap around the loaded meshes and the pose is re-sampled on every
    access, so the same index yields a different sample each time.
    """

    def __init__(self,
                 dir: str,
                 rasterization_settings: dict,
                 znear: float = 1.0,
                 zfar: float = 1000.0,
                 scale_min: float = 0.5,
                 scale_max: float = 2.0,
                 device: str = 'cuda',
                 samples_per_mesh: int = 1024):
        """Load the meshes and build the camera and renderer.

        Args:
            dir: Directory scanned (non-recursively) for ``.obj`` files.
            rasterization_settings: Keyword arguments forwarded to
                ``RasterizationSettings``.
            znear: Near clipping plane of the perspective camera.
            zfar: Far clipping plane of the perspective camera.
            scale_min: Lower bound of the uniform random scale.
            scale_max: Upper bound of the uniform random scale.
            device: Device the meshes, camera and renderer live on.
            samples_per_mesh: Number of dataset items per loaded mesh; only
                affects ``len(self)``. Default 1024.

        Raises:
            ValueError: If ``dir`` contains no ``.obj`` files, or if
                ``samples_per_mesh`` is not positive.
        """
        super().__init__()
        if samples_per_mesh < 1:
            raise ValueError(
                f"samples_per_mesh must be positive, got {samples_per_mesh}")
        device = torch.device(device)
        self.device = device
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.scale_range = scale_max - scale_min
        self.samples_per_mesh = samples_per_mesh
        self.meshes = load_objs_as_meshes(self._list_obj_files(dir),
                                          device=device)
        # The extrinsics set here act as placeholders: __getitem__ passes
        # per-sample R and T to the renderer, and get_random_transform
        # unprojects with world_coordinates=False, which ignores R and T.
        R, T = look_at_view_transform(0, 0, 0)
        self.cameras = FoVPerspectiveCameras(R=R,
                                             T=T,
                                             znear=znear,
                                             zfar=zfar,
                                             device=device)
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=RasterizationSettings(
                    **rasterization_settings),
            ),
            shader=HardFlatShader(
                device=device,
                cameras=self.cameras,
            ),
        )

    @staticmethod
    def _list_obj_files(directory):
        """Return sorted paths of the ``.obj`` files directly inside ``directory``.

        Sorting makes the index -> mesh mapping independent of the
        filesystem's arbitrary ``os.listdir`` order.

        Raises:
            ValueError: If no ``.obj`` file is found.
        """
        objs = [
            os.path.join(directory, f)
            for f in sorted(os.listdir(directory))
            if f.endswith('.obj')
        ]
        if not objs:
            raise ValueError(f"no .obj files found in {directory!r}")
        return objs

    @staticmethod
    def _first_image_chw(images):
        """Return the RGB channels of the first image as a (3, H, W) tensor.

        pytorch3d renderers return channels-last images of shape
        (N, H, W, C). The alpha channel is dropped and the channel axis is
        moved to the front. ``permute(2, 0, 1)`` is used rather than
        ``transpose(0, -1)``, which would produce (3, W, H): a transposed image.
        """
        return images[0, ..., :3].permute(2, 0, 1)

    def get_random_transform(self):
        """Sample a random scale, rotation and camera-space translation.

        Returns:
            Tuple ``(scale, rot, trans)``:

            - ``scale``: Python float drawn uniformly from
              ``[scale_min, scale_max)``.
            - ``rot``: rotation matrix from ``random_rotation()``. It is
              presumably (3, 3), since it is batched with ``unsqueeze(0)``
              before being used as camera ``R``.
            - ``trans``: CPU tensor of shape (3,). This is a point in camera
              view coordinates, obtained by unprojecting a random NDC
              position with x, y in ``[-1, 1)`` and scaled depth in
              ``[0, 1)``.
        """
        scale = (torch.rand(()) * self.scale_range + self.scale_min).item()

        rot = random_rotation()

        # Random NDC point: x and y span the whole image. The depth is
        # interpreted in the camera's scaled-depth range
        # (scaled_depth_input=True).
        ndc = torch.rand(3)
        ndc[:2] = ndc[:2] * 2.0 - 1.0
        trans = self.cameras.unproject_points(
            ndc.unsqueeze(0).to(self.device),
            world_coordinates=False,
            scaled_depth_input=True)[0].cpu()
        return scale, rot, trans

    def __getitem__(self, index):
        """Render a randomly posed copy of mesh ``index % len(self.meshes)``.

        Returns:
            Tuple ``(pixels, [transform])``:

            - ``pixels``: the rendered RGB image as a (3, H, W) tensor.
            - ``transform``: the 4x4 scale -> rotate -> translate matrix.

            Both live on ``self.device``.
        """
        index %= len(self.meshes)
        scale, rot, trans = self.get_random_transform()
        # pytorch3d transforms use row vectors (points @ M), so this matrix
        # applies the scale first, then the rotation, then the translation.
        transform = Transform3d() \
            .scale(scale) \
            .compose(Rotate(rot)) \
            .translate(*trans) \
            .get_matrix() \
            .squeeze()
        # Only the scale is baked into the geometry. Rotation and translation
        # are applied as camera extrinsics (view = X @ R + T), which gives the
        # same scale -> rotate -> translate composition as `transform`.
        mesh = self.meshes[index].scale_verts(scale)
        pixels = self.renderer(mesh,
                               R=rot.unsqueeze(0).to(self.device),
                               T=trans.unsqueeze(0).to(self.device))
        return (self._first_image_chw(pixels), [transform.to(self.device)])

    def __len__(self):
        """Return the virtual length: ``samples_per_mesh`` items per mesh."""
        return len(self.meshes) * self.samples_per_mesh
コード例 #2
0
    # With world coordinates +Y up, +X left and +Z in, the front of the cow is facing the -Z direction.
    # So we move the camera by 180 in the azimuth direction so it is facing the front of the cow.
    R, T = look_at_view_transform(0, 0, 0)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, zfar=zfar)
    smin = 0.1
    smax = 2.0
    srange = smax - smin
    scale = (torch.rand(1).squeeze() * srange + smin).item()

    # Generate a random NDC coordinate https://pytorch3d.org/docs/cameras
    x, y, d = torch.rand(3)
    x = x * 2.0 - 1.0
    y = y * 2.0 - 1.0
    trans = torch.Tensor([x, y, d]).to(device)
    trans = cameras.unproject_points(trans.unsqueeze(0),
                                     world_coordinates=False,
                                     scaled_depth_input=True)[0]
    rot = random_rotations(1)[0].to(device)

    transform = Transform3d() \
        .scale(scale) \
        .compose(Rotate(rot)) \
        .translate(*trans)

    # TODO: transform mesh
    # Create a phong renderer by composing a rasterizer and a shader. The textured phong shader will
    # interpolate the texture uv coordinates for each vertex, sample from a texture image and
    # apply the Phong lighting model
    renderer = MeshRenderer(rasterizer=MeshRasterizer(
        cameras=cameras, raster_settings=raster_settings),
                            shader=SoftPhongShader(