def backward(ctx, grad_soft_colors, grad_p2f_info=None, grad_aggrs_info=None):
    """Backward pass of the soft rasterizer (PyTorch autograd).

    Computes gradients with respect to the face vertices and textures by
    delegating to ``soft_rasterize_cuda.backward_soft_rasterize``, using the
    tensors and rasterization settings stored on ``ctx`` by the forward pass.

    Args:
        ctx: Autograd context. Its ``saved_tensors`` must unpack to
            (face_vertices, textures, soft_colors, faces_info, aggrs_info).
            It must also carry ``device``, ``image_size``, ``near``, ``far``,
            ``eps``, ``sigma_val``, ``func_dist_type``, ``dist_eps``,
            ``gamma_val``, ``func_rgb_type``, ``func_alpha_type``,
            ``texture_type`` and ``fill_back``.
        grad_soft_colors: Upstream gradient for the soft-color output.
        grad_p2f_info: Upstream gradient for an additional forward output,
            presumably (TODO confirm against forward). It is ignored.
        grad_aggrs_info: Same as ``grad_p2f_info``; ignored.

    Returns:
        A 15-tuple: ``grad_faces``, ``grad_textures``, then 13 ``None``
        entries.
    """
    face_vertices, textures, soft_colors, faces_info, aggrs_info = ctx.saved_tensors

    # Zero-initialised float32 gradient buffers handed to the extension.
    # contiguous() is applied because the CUDA extension presumably indexes
    # raw memory (TODO confirm).
    grad_faces = torch.zeros_like(face_vertices, dtype=torch.float32).to(ctx.device).contiguous()
    grad_textures = torch.zeros_like(textures, dtype=torch.float32).to(ctx.device).contiguous()
    grad_soft_colors = grad_soft_colors.contiguous()

    # Argument order must match the extension's signature exactly.
    grad_faces, grad_textures = soft_rasterize_cuda.backward_soft_rasterize(
        face_vertices, textures, soft_colors, faces_info, aggrs_info,
        grad_faces, grad_textures, grad_soft_colors,
        ctx.image_size, ctx.near, ctx.far, ctx.eps, ctx.sigma_val,
        ctx.func_dist_type, ctx.dist_eps, ctx.gamma_val,
        ctx.func_rgb_type, ctx.func_alpha_type, ctx.texture_type,
        ctx.fill_back,
    )

    # autograd.Function.backward must return one value per forward input.
    # Only the first two get a gradient; the remaining 13 are presumably
    # non-differentiable settings (the count must match forward's signature).
    return (grad_faces, grad_textures) + (None,) * 13
def grad(self, grad_soft_colors):
    """Backward pass of the soft rasterizer (Jittor ``Function.grad``).

    Computes gradients with respect to the face vertices and textures by
    delegating to ``soft_rasterize_cuda.backward_soft_rasterize``, using the
    variables and rasterization settings stored on ``self`` by the forward
    pass.

    Args:
        grad_soft_colors: Upstream gradient for the soft-color output.

    Returns:
        Tuple ``(grad_faces, grad_textures)`` as returned by the extension.

    Notes:
        ``self.save_vars`` must unpack to (face_vertices, textures,
        soft_colors, faces_info, aggrs_info). ``self`` must also carry
        ``image_size``, ``near``, ``far``, ``eps``, ``sigma_val``,
        ``func_dist_type``, ``dist_eps``, ``gamma_val``, ``func_rgb_type``,
        ``func_alpha_type``, ``texture_type`` and ``fill_back``.
    """
    face_vertices, textures, soft_colors, faces_info, aggrs_info = self.save_vars

    # Zero-initialised gradient buffers with the same shapes as the inputs.
    grad_faces = jt.zeros(face_vertices.shape)
    grad_textures = jt.zeros(textures.shape)

    # Argument order must match the extension's signature exactly.
    grad_faces, grad_textures = soft_rasterize_cuda.backward_soft_rasterize(
        face_vertices, textures, soft_colors, faces_info, aggrs_info,
        grad_faces, grad_textures, grad_soft_colors,
        self.image_size, self.near, self.far, self.eps, self.sigma_val,
        self.func_dist_type, self.dist_eps, self.gamma_val,
        self.func_rgb_type, self.func_alpha_type, self.texture_type,
        # Cast to int; presumably the Jittor op expects an integer flag
        # rather than a bool (TODO confirm against the op signature).
        int(self.fill_back),
    )
    return grad_faces, grad_textures