Example #1
0
    def forward(self, input):
        """Predict one SVBRDF from several input images of the same sample.

        Every image along dimension 1 of ``input`` is passed through
        ``self.generator`` separately. The per-image results are max-pooled
        over that dimension and then refined by the merge / conv / global
        track layers before being converted into the packed SVBRDF.

        Args:
            input: Tensor of shape (B, N, C, H, W), N being the image count.

        Returns:
            The result of ``utils.pack_svbrdf`` on the normals and the
            diffuse, roughness and specular maps re-encoded to [0, 1].
        """
        # Run the generator once per input image; each view is (B, C, H, W)
        per_view_features = []
        per_view_globals = []
        for view in torch.unbind(input, dim=1):
            features, global_features = self.generator(view)
            per_view_features.append(features)
            per_view_globals.append(global_features)

        # Re-assemble along a new image dimension and max-pool over it
        pooled_features = torch.stack(per_view_features, dim=1).max(dim=1)[0]
        pooled_globals = torch.stack(per_view_globals, dim=1).max(dim=1)[0]

        # Refinement: the global track is updated from the spatial mean of
        # the features that feed the next convolution
        x = self.merge(pooled_features, pooled_globals)
        mean = pooled_features.mean(dim=(2, 3))
        global_track = self.gt1(mean, pooled_globals)
        x, mean = self.conv1(x, global_track)
        global_track = self.gt2(mean, global_track)
        x, mean = self.conv2(x, global_track)
        global_track = self.gt3(mean, global_track)
        x, _ = self.conv3(x, global_track)

        # 9 channel SVBRDF to 12 channels
        svbrdf = utils.decode_svbrdf(self.activation(x))

        # Map ranges from [-1, 1] to [0, 1], except for the normals
        normals, diffuse, roughness, specular = utils.unpack_svbrdf(svbrdf)
        return utils.pack_svbrdf(
            normals,
            utils.encode_as_unit_interval(diffuse),
            utils.encode_as_unit_interval(roughness),
            utils.encode_as_unit_interval(specular),
        )
Example #2
0
    def forward(self, input):
        """Predict an SVBRDF from a single input image.

        Args:
            input: Either one image per sample, or a 5-dimensional tensor
                holding several images per sample along dimension 1. In the
                latter case only the image at index 0 is used.

        Returns:
            The result of ``utils.pack_svbrdf`` on the normals and the
            diffuse, roughness and specular maps re-encoded to [0, 1].
        """
        # If we get multiple input images, we just ignore all but one
        single_view = input[:, 0] if input.ndim == 5 else input

        generated, _ = self.generator(single_view)

        # 9 channel SVBRDF to 12 channels
        decoded = utils.decode_svbrdf(self.activation(generated))

        # Map ranges from [-1, 1] to [0, 1], except for the normals
        normals, diffuse, roughness, specular = utils.unpack_svbrdf(decoded)
        return utils.pack_svbrdf(
            normals,
            utils.encode_as_unit_interval(diffuse),
            utils.encode_as_unit_interval(roughness),
            utils.encode_as_unit_interval(specular),
        )
    def mix(self, svbrdf_0, svbrdf_1, alpha=None):
        """Blend two SVBRDFs into a new one.

        The diffuse, roughness and specular maps are interpolated linearly.
        The normals are first divided by their z component (clamped from
        below at 0.01), interpolated, and then re-normalized to unit length
        along dimension 0.

        Args:
            svbrdf_0: First packed SVBRDF.
            svbrdf_1: Second packed SVBRDF.
            alpha: Weight given to ``svbrdf_0``; ``svbrdf_1`` gets
                ``1 - alpha``. If ``None``, a weight is drawn uniformly
                from [0.1, 0.9].

        Returns:
            The mixed maps packed with ``utils.pack_svbrdf``.

        Note:
            The indexing below treats dimension 0 as the channel dimension,
            i.e. it presumably expects unbatched (C, H, W) maps -- confirm
            against the caller.
        """
        normals_0, diffuse_0, roughness_0, specular_0 = utils.unpack_svbrdf(svbrdf_0)
        normals_1, diffuse_1, roughness_1, specular_1 = utils.unpack_svbrdf(svbrdf_1)

        if alpha is None:
            # Draw the weight on the same device as the maps, so the blend
            # below also works for inputs that do not live on the CPU.
            alpha = torch.empty(1, device=normals_0.device).uniform_(0.1, 0.9)

        # Reference "Project the normals to use the X and Y derivative"
        # clamp() avoids building a CPU-only constant tensor, which would
        # raise a device mismatch for non-CPU inputs.
        normals_0_projected = normals_0 / torch.clamp(normals_0[2:3, :, :], min=0.01)
        normals_1_projected = normals_1 / torch.clamp(normals_1[2:3, :, :], min=0.01)

        normals_mixed = alpha * normals_0_projected + (1.0 - alpha) * normals_1_projected
        # Re-normalize to unit length along dimension 0
        normals_mixed = normals_mixed / torch.sqrt(
            torch.sum(normals_mixed ** 2, dim=0, keepdim=True))

        diffuse_mixed = alpha * diffuse_0 + (1.0 - alpha) * diffuse_1
        roughness_mixed = alpha * roughness_0 + (1.0 - alpha) * roughness_1
        specular_mixed = alpha * specular_0 + (1.0 - alpha) * specular_1

        return utils.pack_svbrdf(normals_mixed, diffuse_mixed, roughness_mixed, specular_mixed)
    def read_sample(self, file_path):
        """Load one sample (input images plus SVBRDF) from an image file.

        The file is read with ``plt.imread`` and cut into equally wide tiles
        along its width: ``self.input_image_count`` input images, followed by
        the four SVBRDF maps (normals, diffuse, roughness, specular) unless
        ``self.no_svbrdf`` is set. Without stored maps a dummy SVBRDF is
        built: normals of (0, 0, 1) and all-zero remaining maps.

        Args:
            file_path: Path of the image file to read.

        Returns:
            A tuple ``(input_images, svbrdf)``. ``input_images`` holds the
            last ``min(self.input_image_count, self.used_input_image_count)``
            input tiles; ``svbrdf`` is the output of ``utils.pack_svbrdf``
            with ``squeeze(0)`` applied.
        """
        # Read full image
        # TODO: Use utils.read_image_tensor()
        full_image = torch.Tensor(plt.imread(file_path)).permute(2, 0, 1)

        # Cut the full image into tiles along the horizontal direction.
        # The magic number 4 is the number of maps in the SVBRDF.
        first_map_index = self.input_image_count
        map_count = 0 if self.no_svbrdf else 4
        tiles = full_image.chunk(first_map_index + map_count, dim=-1)
        image_parts = torch.cat([tile.unsqueeze(0) for tile in tiles], 0)  # [n, 3, 256, 256]

        if self.no_svbrdf:
            # No SVBRDF in the data, so the input images serve as size guide
            # for a dummy one.
            height, width = image_parts[0].shape[-2:]

            normals = torch.zeros((3, height, width))
            normals[2] = 1.0
            diffuse = torch.zeros_like(normals)
            roughness = torch.zeros_like(normals)
            specular = torch.zeros_like(normals)
        else:
            # The maps directly follow the input images
            normals, diffuse, roughness, specular = (
                image_parts[first_map_index + offset].unsqueeze(0)
                for offset in range(4)
            )
            normals = utils.decode_from_unit_interval(normals)

        packed = utils.pack_svbrdf(normals, diffuse, roughness, specular)
        svbrdf = packed.squeeze(0)  # [12, 256, 256]

        # We read as many input images from the disk as we can
        # FIXME: This is a little bit counter-intuitive, as we are reading the last n images, not the first n
        read_count = min(self.input_image_count, self.used_input_image_count)
        input_images = image_parts[first_map_index - read_count:first_map_index]  # [ni, 3, 256, 256]

        return input_images, svbrdf
        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 3)
        plt.imshow(diffuse.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 4)
        plt.imshow(roughness.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 5)
        plt.imshow(specular.squeeze(0).permute(1, 2, 0))
        plt.axis('off')

        rendering = utils.gamma_encode(
            renderer.render(
                scene,
                utils.pack_svbrdf(normals, diffuse, roughness,
                                  specular))).squeeze(0).permute(1, 2, 0)
        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 6)
        plt.imshow(rendering)
        plt.axis('off')

        perspective_rendering = perspective_mapping.apply(rendering.numpy())
        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 7)
        plt.imshow(perspective_rendering)
        plt.axis('off')

        rendering = utils.gamma_encode(
            redner_renderer.render(
                scene,
                utils.pack_svbrdf(normals, diffuse, roughness,
                                  specular))).squeeze(0).permute(1, 2, 0)
        fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 8)