Example #1
    def forward(self, input):
        # Split the input of shape (B, N, C, H, W) into a list over the input
        # images: [(B, 1, C, H, W)_1, ..., (B, 1, C, H, W)_N]
        input_images = torch.split(input, 1, dim=1)

        # Invoke the generator for all the input images
        encoder_decoder_outputs = []
        global_track_outputs = []
        for input_image in input_images:
            encoder_decoder_output, global_track_output = self.generator(
                input_image.squeeze(1))
            encoder_decoder_outputs.append(encoder_decoder_output.unsqueeze(1))
            global_track_outputs.append(global_track_output.unsqueeze(1))

        # Merge the outputs back into tensors of shape (B, N, C, H, W)
        encoder_decoder_outputs = torch.cat(encoder_decoder_outputs, dim=1)
        global_track_outputs = torch.cat(global_track_outputs, dim=1)

        # Pool over the input image dimension
        pooled_encoder_decoder_outputs, _ = torch.max(encoder_decoder_outputs,
                                                      dim=1)
        pooled_global_track_outputs, _ = torch.max(global_track_outputs, dim=1)

        x = self.merge(pooled_encoder_decoder_outputs,
                       pooled_global_track_outputs)
        mean = torch.mean(pooled_encoder_decoder_outputs,
                          dim=(2, 3),
                          keepdim=False)
        global_track = self.gt1(mean, pooled_global_track_outputs)
        x, mean = self.conv1(x, global_track)
        global_track = self.gt2(mean, global_track)
        x, mean = self.conv2(x, global_track)
        global_track = self.gt3(mean, global_track)
        x, mean = self.conv3(x, global_track)

        svbrdf = self.activation(x)

        # Expand the 9-channel SVBRDF to 12 channels
        svbrdf = utils.decode_svbrdf(svbrdf)

        # Map ranges from [-1, 1] to [0, 1], except for the normals
        normals, diffuse, roughness, specular = utils.unpack_svbrdf(svbrdf)
        diffuse = utils.encode_as_unit_interval(diffuse)
        roughness = utils.encode_as_unit_interval(roughness)
        specular = utils.encode_as_unit_interval(specular)

        return utils.pack_svbrdf(normals, diffuse, roughness, specular)
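The per-image loop above can also be folded into a single batched generator call by merging the image dimension N into the batch dimension. A minimal equivalent sketch, assuming the generator has no cross-batch state (e.g. no batch-dependent normalization statistics):

    def run_generator_batched(generator, input):
        # input: (B, N, C, H, W) -> fold N into the batch dimension
        B, N, C, H, W = input.shape
        enc, glob = generator(input.reshape(B * N, C, H, W))
        # Restore the image dimension as (B, N, ...)
        enc = enc.reshape(B, N, *enc.shape[1:])
        glob = glob.reshape(B, N, *glob.shape[1:])
        return enc, glob

The explicit loop keeps peak memory lower (only B images pass through the generator at a time), while the batched form trades memory for throughput.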
Example #2
    def forward(self, input):
        if len(input.shape) == 5:
            # If we get multiple input images, keep only the first and ignore the rest
            input = input[:, 0, :, :, :]

        svbrdf, _ = self.generator(input)
        svbrdf = self.activation(svbrdf)

        # Expand the 9-channel SVBRDF to 12 channels
        svbrdf = utils.decode_svbrdf(svbrdf)

        # Map ranges from [-1, 1] to [0, 1], except for the normals
        normals, diffuse, roughness, specular = utils.unpack_svbrdf(svbrdf)
        diffuse = utils.encode_as_unit_interval(diffuse)
        roughness = utils.encode_as_unit_interval(roughness)
        specular = utils.encode_as_unit_interval(specular)

        return utils.pack_svbrdf(normals, diffuse, roughness, specular)
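For reference, the [-1, 1] to [0, 1] mapping applied to the diffuse, roughness, and specular maps is a simple affine transform; a minimal sketch of what `utils.encode_as_unit_interval` presumably computes (the actual implementation in utils may differ):

    def encode_as_unit_interval(x):
        # Affine map from [-1, 1] to [0, 1]
        return (x + 1.0) / 2.0

    def decode_from_unit_interval(x):
        # Inverse map from [0, 1] back to [-1, 1] (helper name hypothetical)
        return 2.0 * x - 1.0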
Example #3
        # We only have one image in the inputs
        batch_inputs.squeeze_(0)

        input = utils.gamma_encode(batch_inputs)
        svbrdf = batch_svbrdf

        normals, diffuse, roughness, specular = utils.unpack_svbrdf(svbrdf)

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 1)
        plt.imshow(input.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 2)
        plt.imshow(
            utils.encode_as_unit_interval(normals.squeeze(0).permute(1, 2, 0)))
        plt.axis('off')

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 3)
        plt.imshow(diffuse.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 4)
        plt.imshow(roughness.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 5)
        plt.imshow(specular.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        rendering = utils.gamma_encode(
Example #4
    def render(self, scene, svbrdf):
        imgs = []

        svbrdf = svbrdf.unsqueeze(0) if len(svbrdf.shape) == 3 else svbrdf

        sensor_size = (svbrdf.shape[-1], svbrdf.shape[-2])

        for svbrdf_single in torch.split(svbrdf, 1, dim=0):
            normals, diffuse, roughness, specular = utils.unpack_svbrdf(
                svbrdf_single.squeeze(0))
            # Redner expects the normal map to be in range [0, 1]
            normals = utils.encode_as_unit_interval(normals)
            # Redner expects the roughness to have one channel only.
            # We also need to convert from GGX roughness to Blinn-Phong power.
            # See: https://github.com/iondune/csc473/blob/master/lectures/07-cook-torrance.md
            roughness = torch.mean(torch.clamp(roughness, min=0.001),
                                   dim=0,
                                   keepdim=True)**4

            # Convert from [c,h,w] to [h,w,c] for redner
            normals = normals.permute(1, 2, 0)
            diffuse = diffuse.permute(1, 2, 0)
            roughness = roughness.permute(1, 2, 0)
            specular = specular.permute(1, 2, 0)

            material = pyredner.Material(
                diffuse_reflectance=pyredner.Texture(
                    diffuse.to(self.redner_device)),
                specular_reflectance=pyredner.Texture(
                    specular.to(self.redner_device)),
                roughness=pyredner.Texture(roughness.to(self.redner_device)),
                normal_map=pyredner.Texture(normals.to(self.redner_device)))

            material_patch = pyredner.Object(vertices=self.patch_vertices,
                                             uvs=self.patch_uvs,
                                             indices=self.patch_indices,
                                             material=material)

            # Define the camera parameters (focused at the middle of the patch)
            # and make sure we always have a valid 'up' direction
            position = np.array(scene.camera.pos)
            lookat = np.array([0.0, 0.0, 0.0])
            cz = lookat - position  # Principal axis
            up = np.array([0.0, 0.0, 1.0])
            if np.linalg.norm(np.cross(cz, up)) == 0.0:
                up = np.array([0.0, 1.0, 0.0])

            camera = pyredner.Camera(
                position=torch.FloatTensor(position).to(self.redner_device),
                look_at=torch.FloatTensor(lookat).to(self.redner_device),
                up=torch.FloatTensor(up).to(self.redner_device),
                fov=torch.FloatTensor([90]),
                resolution=sensor_size,
                camera_type=self.camera_type)

            # # The deferred rendering path.
            # # It does not have a specular model and therefore is of limited usability for us
            # full_scene = pyredner.Scene(camera = camera, objects = [material_patch])
            # light = pyredner.PointLight(position = torch.tensor(scene.light.pos).to(self.redner_device),
            #                                    intensity = torch.tensor(scene.light.color).to(self.redner_device))
            # img = pyredner.render_deferred(scene = full_scene, lights = [light])

            light = pyredner.generate_quad_light(
                position=torch.Tensor(scene.light.pos).to(self.redner_device),
                look_at=torch.zeros(3).to(self.redner_device),
                size=torch.Tensor([0.6, 0.6]).to(self.redner_device),
                intensity=torch.Tensor(scene.light.color).to(
                    self.redner_device))
            full_scene = pyredner.Scene(camera=camera,
                                        objects=[material_patch, light])
            img = pyredner.render_pathtracing(full_scene, num_samples=(16, 8))

            # Transform the rendered image back to something torch can interpret
            imgs.append(img.permute(2, 0, 1).to(svbrdf.device))

        return torch.stack(imgs)
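A minimal call-site sketch (the `renderer` instance name is hypothetical; the `scene` object is assumed to expose `scene.camera.pos`, `scene.light.pos`, and `scene.light.color` as used above):

    # svbrdf: (12, H, W) or (B, 12, H, W), packed as
    # [normals, diffuse, roughness, specular] via utils.pack_svbrdf
    renderings = renderer.render(scene, svbrdf)  # -> (B, 3, H, W) on svbrdf.device

Note that pyredner returns linear-radiance images, so a display pipeline would still gamma-encode them (as Example #3 does with utils.gamma_encode).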
Example #5
    outputs = model(batch_inputs)

    input = utils.gamma_encode(batch_inputs.squeeze(0)[0]).cpu().permute(
        1, 2, 0)
    target_maps = torch.cat(batch_svbrdf.split(3, dim=1),
                            dim=0).clone().cpu().detach().permute(0, 2, 3, 1)
    output_maps = torch.cat(outputs.split(3, dim=1),
                            dim=0).clone().cpu().detach().permute(0, 2, 3, 1)

    fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 1)
    plt.imshow(input)
    plt.axis('off')

    fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 2)
    plt.imshow(utils.encode_as_unit_interval(target_maps[0]))
    plt.axis('off')

    fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 3)
    plt.imshow(target_maps[1])
    plt.axis('off')

    fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 4)
    plt.imshow(target_maps[2])
    plt.axis('off')

    fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 5)
    plt.imshow(target_maps[3])
    plt.axis('off')

    fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 7)