def forward(self):
    """Run one forward pass and cache every intermediate result on ``self``.

    Reads ``self.real_A``, ``self.real_B``, ``self.real_RT`` and
    ``self.intrinsics``. Nothing is returned; the outputs are stored as
    attributes (``z*``, ``conv*``, ``depth_*``, ``fake_A``, ``fake_B*`` and,
    through ``self.warp``, ``warp_fake_B`` / ``flow`` / ``mask``).
    """
    # Encode both input views: a latent code plus four extra encoder outputs.
    self.z, self.conv0, self.conv2, self.conv3, self.conv4 = self.encode(
        self.real_A)
    (self.z_b, self.conv0_b, self.conv2_b, self.conv3_b,
     self.conv4_b) = self.encode(self.real_B)

    # Transform A's latent code with real_RT and decode depth from the
    # transformed, original and B latent codes.
    self.z_tf = self.transform(self.z, self.real_RT)
    self.depth_tf = self.depthdecode(self.z_tf)
    self.depth_a = self.depthdecode(self.z)
    self.depth_b = self.depthdecode(self.z_b)

    # Warp the input image with the transformed depth (this sets
    # self.warp_fake_B, which is re-encoded right below), then re-encode it.
    self.warp(self.real_A, self.depth_tf, self.real_RT)
    _, self.conv0_w, self.conv2_w, self.conv3_w, self.conv4_w = self.encode(
        self.warp_fake_B)

    # Decode from the untransformed code and features. The whole return value
    # of decode() is kept here, whereas the fake_B call below unpacks four.
    self.fake_A = self.decode(
        self.z, self.conv0, self.conv2, self.conv3, self.conv4)

    def warp_at_scale(feature, scale):
        # Resize the depth map by `scale`, fetch intrinsics for that scale via
        # get_K, and inverse-warp `feature` with them.
        # `interpolate` replaces the deprecated `upsample` alias; the default
        # mode ('nearest') is unchanged.
        scaled_depth = torch.nn.functional.interpolate(
            self.depth_tf, scale_factor=scale)
        warped, _, _ = inverse_warp(
            feature, scaled_depth, self.real_RT,
            self.get_K(self.intrinsics, scale))
        return warped

    # conv0 is warped with the unscaled depth and intrinsics; the other
    # features use 1/4, 1/8 and 1/16 scale (presumably matching each feature
    # map's resolution relative to the depth map -- TODO confirm).
    self.conv0_tf, _, _ = inverse_warp(
        self.conv0, self.depth_tf, self.real_RT, self.intrinsics)
    self.conv2_tf = warp_at_scale(self.conv2, 0.25)
    self.conv3_tf = warp_at_scale(self.conv3, 0.125)
    self.conv4_tf = warp_at_scale(self.conv4, 0.0625)

    # Decode from the transformed code and the warped features.
    self.fake_B, self.fake_B3, self.fake_B2, self.fake_B1 = self.decode(
        self.z_tf, self.conv0_tf, self.conv2_tf, self.conv3_tf, self.conv4_tf)
def get_high_res(self, image, pose, z=None):
    """Warp a full-resolution image using depth predicted at 256x256.

    Args:
        image: array accepted by ``cv2.resize`` and ``torch.from_numpy``.
            It is permuted with ``(2, 0, 1)``, so it must have three axes
            (assumed height x width x channel, and assumed to be 4x the
            256x256 working resolution -- TODO confirm).
        pose: passed to ``self.get_RT`` to obtain the transform.
        z: optional precomputed latent code; when ``None`` it is computed
            from the down-sized image with ``self.enc``.

    Returns:
        The warped image after ``tensor2im`` and a BGR->RGB conversion by
        ``cv2.cvtColor``.
    """
    # Network input: resize to 256x256, rescale values with /128 - 1, reorder
    # the axes and add a leading batch axis.
    image_small = cv2.resize(image, (256, 256))
    image_small = torch.from_numpy(image_small / 128. - 1).permute(
        (2, 0, 1)).contiguous().unsqueeze(0)
    # The former Variable(...) wrapper was dropped: it is a no-op on tensors.
    image_small = image_small.to(self.device).float()

    RT = self.get_RT(pose)

    # NOTE(review): forward() in this file uses self.encode (five outputs),
    # a five-argument self.decode and self.depthdecode for depth. Confirm
    # that self.enc exists and that self.decode(z_tf) returns a depth map.
    z = self.enc(image_small) if z is None else z
    z_tf = self.transform(z, RT)
    depth = self.decode(z_tf)

    # F.upsample is deprecated; F.interpolate is its replacement.
    # align_corners=False is what the old call already defaulted to.
    depth = F.interpolate(
        depth, scale_factor=4, mode='bilinear', align_corners=False)

    # Scale the first intrinsics matrix by 4 to go with the 4x depth map, then
    # reset entry [2, 2] to 1 after the blanket multiplication. The product is
    # a new tensor, so self.intrinsics itself is left untouched.
    intrinsics = self.intrinsics[:1, :, :] * 4
    intrinsics[0, 2, 2] = 1

    # Prepare the full-resolution image the same way as the small one.
    image = image / 128. - 1
    image = torch.from_numpy(image).permute(
        (2, 0, 1)).contiguous().unsqueeze(0).to(self.device).float()

    image, _, _ = inverse_warp(image, depth, RT, intrinsics)
    image = tensor2im(image.data.detach())
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image
def warp(self, image, depth, RT):
    """Inverse-warp ``image`` and store the three results on ``self``.

    Calls ``inverse_warp(image, depth, RT, self.intrinsics)`` and assigns its
    three return values, in order, to ``self.warp_fake_B``, ``self.flow`` and
    ``self.mask``. Nothing is returned.
    """
    warp_outputs = inverse_warp(image, depth, RT, self.intrinsics)
    self.warp_fake_B, self.flow, self.mask = warp_outputs