def forward(self, input):
    # Split the input of shape (B, N, C, H, W) into a list over the input images [(B, 1, C, H, W)_1, ..., (B, 1, C, H, W)_N]
    input_images = torch.split(input, 1, dim=1)

    # Invoke the generator for all the input images
    encoder_decoder_outputs = []
    global_track_outputs = []
    for input_image in input_images:
        encoder_decoder_output, global_track_output = self.generator(input_image.squeeze(1))
        encoder_decoder_outputs.append(encoder_decoder_output.unsqueeze(1))
        global_track_outputs.append(global_track_output.unsqueeze(1))

    # Merge the outputs back into tensors of shape (B, N, C, H, W)
    encoder_decoder_outputs = torch.cat(encoder_decoder_outputs, dim=1)
    global_track_outputs = torch.cat(global_track_outputs, dim=1)

    # Pool over the input image dimension
    pooled_encoder_decoder_outputs, _ = torch.max(encoder_decoder_outputs, dim=1)
    pooled_global_track_outputs, _ = torch.max(global_track_outputs, dim=1)

    # Fuse the pooled features and pass them through the remaining convolution and global track layers
    x = self.merge(pooled_encoder_decoder_outputs, pooled_global_track_outputs)
    mean = torch.mean(pooled_encoder_decoder_outputs, dim=(2, 3), keepdim=False)
    global_track = self.gt1(mean, pooled_global_track_outputs)
    x, mean = self.conv1(x, global_track)
    global_track = self.gt2(mean, global_track)
    x, mean = self.conv2(x, global_track)
    global_track = self.gt3(mean, global_track)
    x, mean = self.conv3(x, global_track)

    svbrdf = self.activation(x)

    # 9 channel SVBRDF to 12 channels
    svbrdf = utils.decode_svbrdf(svbrdf)

    # Map ranges from [-1, 1] to [0, 1], except for the normals
    normals, diffuse, roughness, specular = utils.unpack_svbrdf(svbrdf)
    diffuse = utils.encode_as_unit_interval(diffuse)
    roughness = utils.encode_as_unit_interval(roughness)
    specular = utils.encode_as_unit_interval(specular)

    return utils.pack_svbrdf(normals, diffuse, roughness, specular)
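The fusion step is the architectural point here: the per-image features are max-pooled over the image dimension N, so the prediction does not depend on the number or order of input photographs. A minimal standalone illustration of that pooling on synthetic tensors (not repository code):

import torch

features = torch.rand(2, 4, 64, 16, 16)          # (B, N, C, H, W) per-image features
pooled, _ = torch.max(features, dim=1)           # (B, C, H, W) after pooling over N

shuffled = features[:, torch.randperm(4)]        # reorder the images along N
pooled_shuffled, _ = torch.max(shuffled, dim=1)

assert torch.allclose(pooled, pooled_shuffled)   # the pooled features are order-invariant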
def forward(self, input):
    if len(input.shape) == 5:
        # If we get multiple input images, we just ignore all but one
        input = input[:, 0, :, :, :]

    svbrdf, _ = self.generator(input)
    svbrdf = self.activation(svbrdf)

    # 9 channel SVBRDF to 12 channels
    svbrdf = utils.decode_svbrdf(svbrdf)

    # Map ranges from [-1, 1] to [0, 1], except for the normals
    normals, diffuse, roughness, specular = utils.unpack_svbrdf(svbrdf)
    diffuse = utils.encode_as_unit_interval(diffuse)
    roughness = utils.encode_as_unit_interval(roughness)
    specular = utils.encode_as_unit_interval(specular)

    return utils.pack_svbrdf(normals, diffuse, roughness, specular)
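Both forward passes rely on utils.encode_as_unit_interval (and read_sample below on utils.decode_from_unit_interval). Their implementations are not shown here; judging by the comments they are plain affine remappings between [-1, 1] and [0, 1], roughly along these lines (an assumption, not the repository code):

import torch

def encode_as_unit_interval(tensor: torch.Tensor) -> torch.Tensor:
    # Map values from [-1, 1] to [0, 1]
    return (tensor + 1.0) / 2.0

def decode_from_unit_interval(tensor: torch.Tensor) -> torch.Tensor:
    # Map values from [0, 1] back to [-1, 1]
    return tensor * 2.0 - 1.0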
def mix(self, svbrdf_0, svbrdf_1, alpha=None):
    if alpha is None:
        alpha = torch.Tensor(1).uniform_(0.1, 0.9)

    normals_0, diffuse_0, roughness_0, specular_0 = utils.unpack_svbrdf(svbrdf_0)
    normals_1, diffuse_1, roughness_1, specular_1 = utils.unpack_svbrdf(svbrdf_1)

    # Reference: "Project the normals to use the X and Y derivative"
    normals_0_projected = normals_0 / torch.max(torch.Tensor([0.01]), normals_0[2:3, :, :])
    normals_1_projected = normals_1 / torch.max(torch.Tensor([0.01]), normals_1[2:3, :, :])

    normals_mixed = alpha * normals_0_projected + (1.0 - alpha) * normals_1_projected
    normals_mixed = normals_mixed / torch.sqrt(torch.sum(normals_mixed**2, dim=0, keepdim=True))  # Normalization

    diffuse_mixed = alpha * diffuse_0 + (1.0 - alpha) * diffuse_1
    roughness_mixed = alpha * roughness_0 + (1.0 - alpha) * roughness_1
    specular_mixed = alpha * specular_0 + (1.0 - alpha) * specular_1

    return utils.pack_svbrdf(normals_mixed, diffuse_mixed, roughness_mixed, specular_mixed)
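The normal handling is the only non-obvious part of mix(): each normal map is divided by its (clamped) z component, which projects it to slope space, blended linearly there, and re-normalized to unit length. A standalone sketch of just that step, on two constant synthetic normal maps (illustration only, not repository code):

import torch

n0 = torch.tensor([0.0, 0.0, 1.0]).view(3, 1, 1).expand(3, 4, 4)   # facing straight up
n1 = torch.tensor([0.5, 0.0, 1.0]).view(3, 1, 1).expand(3, 4, 4)   # tilted in x
n1 = n1 / torch.sqrt(torch.sum(n1**2, dim=0, keepdim=True))

alpha = 0.5

# Project to slope space (divide by z, clamped from below), blend, re-normalize
p0 = n0 / torch.clamp(n0[2:3], min=0.01)
p1 = n1 / torch.clamp(n1[2:3], min=0.01)
mixed = alpha * p0 + (1.0 - alpha) * p1
mixed = mixed / torch.sqrt(torch.sum(mixed**2, dim=0, keepdim=True))

print(mixed[:, 0, 0])   # unit normal whose slope is the average of the two input slopes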
def read_sample(self, file_path):
    # Read the full image
    # TODO: Use utils.read_image_tensor()
    full_image = torch.Tensor(plt.imread(file_path)).permute(2, 0, 1)

    # Split the full image apart along the horizontal direction
    # Magic number 4 is the number of maps in the SVBRDF
    svbrdf_map_count = 0 if self.no_svbrdf else 4
    image_parts = torch.cat(full_image.unsqueeze(0).chunk(self.input_image_count + svbrdf_map_count, dim=-1), 0)  # [n, 3, 256, 256]

    # Read the SVBRDF (dummy if there is none in the dataset)
    svbrdf = None
    if self.no_svbrdf:
        # If there are no SVBRDFs in the data, there must be images which we can use as a size guide.
        width = image_parts[0].shape[-1]
        height = image_parts[0].shape[-2]
        normals = torch.cat([torch.zeros((2, height, width)), torch.ones((1, height, width))], dim=0)
        diffuse = torch.zeros_like(normals)
        roughness = torch.zeros_like(normals)
        specular = torch.zeros_like(normals)
    else:
        normals = image_parts[self.input_image_count + 0].unsqueeze(0)
        normals = utils.decode_from_unit_interval(normals)
        diffuse = image_parts[self.input_image_count + 1].unsqueeze(0)
        roughness = image_parts[self.input_image_count + 2].unsqueeze(0)
        specular = image_parts[self.input_image_count + 3].unsqueeze(0)
    svbrdf = utils.pack_svbrdf(normals, diffuse, roughness, specular).squeeze(0)  # [12, 256, 256]

    # We read as many input images from the disk as we can
    # FIXME: This is a little bit counter-intuitive, as we are reading the last n images, not the first n
    read_input_image_count = min(self.input_image_count, self.used_input_image_count)
    input_images = image_parts[(self.input_image_count - read_input_image_count):self.input_image_count]  # [ni, 3, 256, 256]

    return input_images, svbrdf
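read_sample() expects each dataset file to store the input photograph(s) and the four SVBRDF maps side by side in one wide image. The splitting step in isolation, run on a synthetic stand-in for plt.imread() (illustration only):

import torch

input_image_count = 1
svbrdf_map_count = 4                         # normals, diffuse, roughness, specular
full_image = torch.rand(3, 256, 256 * (input_image_count + svbrdf_map_count))
image_parts = torch.cat(full_image.unsqueeze(0).chunk(input_image_count + svbrdf_map_count, dim=-1), 0)
print(image_parts.shape)                     # torch.Size([5, 3, 256, 256])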
fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 3)
plt.imshow(diffuse.squeeze(0).permute(1, 2, 0))
plt.axis('off')

fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 4)
plt.imshow(roughness.squeeze(0).permute(1, 2, 0))
plt.axis('off')

fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 5)
plt.imshow(specular.squeeze(0).permute(1, 2, 0))
plt.axis('off')

# Render the SVBRDF and gamma-encode the result for display
rendering = utils.gamma_encode(
    renderer.render(
        scene,
        utils.pack_svbrdf(normals, diffuse, roughness, specular))).squeeze(0).permute(1, 2, 0)
fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 6)
plt.imshow(rendering)
plt.axis('off')

# Warp the rendering into a perspective view
perspective_rendering = perspective_mapping.apply(rendering.numpy())
fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 7)
plt.imshow(perspective_rendering)
plt.axis('off')

# Render the same SVBRDF with the redner-based renderer for comparison
rendering = utils.gamma_encode(
    redner_renderer.render(
        scene,
        utils.pack_svbrdf(normals, diffuse, roughness, specular))).squeeze(0).permute(1, 2, 0)
fig.add_subplot(row_count, col_count, 2 * i_row * col_count + 8)