def _trace_view_rays_out(self, rays, recursive_depth):
    """Trace a batch of view rays through the scene and return their colors.

    For every ray that hits a shape, the color is the light contribution
    returned by ``_get_color_contribution_from_lights`` plus the recursively
    traced reflection, scaled by the hit shape's reflection coefficient
    (``self.shape_coefficients[:, 5]``). Rays that hit nothing stay black.

    Args:
        rays: batch of rays; ``rays.directions.shape[0]`` is the ray count.
        recursive_depth: remaining number of bounces. At zero (or below)
            recursion stops and every ray is black.

    Returns:
        An ``[n, 3]`` float32 color array with one row per input ray.
    """
    n = rays.directions.shape[0]
    colors = device_control.get_device_float32_array([n, 3], 0)
    # `<= 0` rather than `== 0`: a negative depth must not recurse forever.
    if recursive_depth <= 0:
        return colors

    closest_intersections = ray_surface_intersection.RaySurfaceIntersection()
    # The helper populates `closest_intersections` and returns per-ray shape
    # indices (used below to index `self.shape_coefficients`).
    closest_shapes = self._get_closest_ray_intersections(
        rays, closest_intersections)

    # Only rays that actually hit something need shading.
    hit_indices = torch.nonzero(closest_intersections.intersected).view(-1)
    if hit_indices.shape[0] == 0:
        return colors

    hits = closest_intersections.mask(hit_indices)
    hit_shapes = closest_shapes[hit_indices]
    hit_colors = hits.colors_at_intersection

    local_colors = self._get_color_contribution_from_lights(
        hit_shapes, hits, hit_colors)

    # Reflection: trace a new ray from each hit point along the reflected
    # direction. advance_by_epsilon() presumably nudges the origin off the
    # surface so the bounce does not re-hit it immediately.
    reflected_rays = ray.Ray(hits.intersection_points,
                             hits.reflection_directions)
    reflected_rays.advance_by_epsilon()
    reflection_scale = self.shape_coefficients[hit_shapes, 5].view(-1, 1)
    reflection_colors = self._trace_view_rays_out(
        reflected_rays, recursive_depth - 1) * reflection_scale

    # TODO: refraction is not implemented. Refractable shapes should blend a
    # refracted contribution (weighted by a refraction alpha that accounts
    # for total internal reflection) with the reflection (weighted by
    # 1 - alpha) instead of adding the reflection unconditionally.
    colors[hit_indices, :] = local_colors + reflection_colors
    return colors
def find_intersections(self, rays):
    """Intersect a batch of rays with this sphere.

    Solves ``|o + t*d - center|^2 = radius^2``, i.e. ``a*t^2 + b*t + c = 0``,
    per ray and keeps the smallest non-negative root. A ray starting inside
    the sphere therefore gets the positive (exit) root.

    Returns:
        A ``RaySurfaceIntersection`` filled with per-ray ``intersected``
        mask, ``t``, ``intersection_points``, ``intersection_normals``,
        ``reflection_directions``, ``incident_rays`` and
        ``colors_at_intersection``. NOTE: for rays that miss (negative
        discriminant) ``t`` and the derived points are NaN; consumers must
        check ``intersected``.
    """
    a = rays.directions.dot(rays.directions)
    o_minus_c = rays.starts - self.center
    b = 2 * rays.directions.dot(o_minus_c)
    c = o_minus_c.dot(o_minus_c) - self.radius * self.radius
    no_intersection_mask = b * b < 4.0 * a * c
    root_del = torch.sqrt(b * b - 4 * a * c)
    t1 = (-b + root_del) / (2.0 * a)
    t2 = (-b - root_del) / (2.0 * a)
    # Both roots behind the ray origin: the sphere is entirely behind it.
    no_intersection_mask = no_intersection_mask | ((t1 < 0) & (t2 < 0))

    intersection = ray_surface_intersection.RaySurfaceIntersection()
    # `~` rather than `1 - mask`: subtracting from a bool tensor raises in
    # PyTorch >= 1.2, while `~` also inverts legacy uint8 masks.
    intersection.intersected = ~no_intersection_mask

    # Discard negative roots (replace with 1e30) and take the nearer one.
    inf_tensor = device_control.get_device_float32_array([t1.shape[0]], 1e30)
    intersection.t = torch.min(torch.where(t1 < 0, inf_tensor, t1),
                               torch.where(t2 < 0, inf_tensor, t2))
    intersection.intersection_points = rays.get_point_on_line(
        intersection.t)
    intersection.intersection_normals = (
        intersection.intersection_points - self.center).unit_vectors()
    intersection.reflection_directions = (
        rays.directions.get_reflection_directions(
            intersection.intersection_normals))
    intersection.incident_rays = rays
    intersection.colors_at_intersection = self.get_color_at_point(
        intersection.intersection_points)
    return intersection
def _phong_illumination_color(self, closest_shapes, colors_at_intersections, n, s, r, v):
    """Diffuse + specular (Phong) contribution of a single light.

    Both terms are divided by ``len(self.lights)`` so the sum over all
    lights stays normalized.

    Args:
        closest_shapes: per-point indices into ``self.shape_coefficients``;
            column 1 scales the diffuse term, column 2 the specular term and
            column 3 is the specular (shininess) exponent.
        colors_at_intersections: ``[count, 3]`` surface colors.
        n: surface normals.
        s: directions toward the light (``_get_color_contribution_from_lights``
            passes the ray-to-light directions).
        r: reflection directions of the light about the normals.
        v: directions toward the viewer.

    Returns:
        ``[count, 3]`` color contribution; the specular part is white.
    """
    weight = 1.0 / len(self.lights)
    k_diffuse = self.shape_coefficients[closest_shapes, 1]
    k_specular = self.shape_coefficients[closest_shapes, 2]
    shininess = self.shape_coefficients[closest_shapes, 3]

    # Diffuse term, clamped so surfaces facing away from the light get none.
    diffuse_scale = torch.clamp(weight * k_diffuse * s.dot(n), min=0.0)
    diffuse_color = colors_at_intersections * diffuse_scale.view(-1, 1)

    # Clamp r.v at 0 BEFORE exponentiation: pow() of a negative base gives
    # NaN for fractional exponents and a bogus highlight for even ones.
    r_dot_v = torch.clamp(r.dot(v), min=0.0)
    specular_scale = (weight * k_specular *
                      torch.pow(r_dot_v, shininess)).view(-1, 1)
    # White highlight: the same scale on all three channels.
    specular_color = specular_scale.expand(-1, 3)

    return diffuse_color + specular_color
def get_intersection_t(self, rays):
    """Return the ray parameter of each ray's nearest forward hit on this sphere.

    Solves ``a*t^2 + b*t + c = 0`` for ``|o + t*d - center| = radius`` and
    keeps the smallest non-negative root.

    Returns:
        A 1-D tensor of ``t`` values; rays that miss (negative discriminant,
        or both roots behind the origin) get the sentinel ``-1e10``.
    """
    a = rays.directions.dot(rays.directions)
    o_minus_c = rays.starts - self.center
    b = 2 * rays.directions.dot(o_minus_c)
    c = o_minus_c.dot(o_minus_c) - self.radius * self.radius
    no_intersection_mask = b * b < 4.0 * a * c
    # NaN for misses; those entries are replaced by the sentinel below.
    root_del = torch.sqrt(b * b - 4 * a * c)
    t1 = (-b + root_del) / (2.0 * a)
    t2 = (-b - root_del) / (2.0 * a)
    no_intersection_mask = no_intersection_mask | ((t1 < 0) & (t2 < 0))
    # `~` rather than `1 - mask`: subtracting from a bool tensor raises in
    # PyTorch >= 1.2, while `~` also inverts legacy uint8 masks.
    intersected = ~no_intersection_mask

    # Discard negative roots (replace with 1e30) and take the nearer one.
    inf_tensor = device_control.get_device_float32_array([t1.shape[0]], 1e30)
    t = torch.min(torch.where(t1 < 0, inf_tensor, t1),
                  torch.where(t2 < 0, inf_tensor, t2))
    no_hit = device_control.get_device_float32_array([t.shape[0]], -1e10)
    return torch.where(intersected, t, no_hit)
def get_intersection_t(self, rays):
    """Return the ray parameter where each ray meets this plane.

    The plane is ``normal . p + d = 0``, so for ``p = o + t*dir`` we get
    ``t = (-d - normal . o) / (normal . dir)``.

    Returns:
        A 1-D tensor of ``t`` values. Rays parallel to the plane, hitting it
        behind their origin (``t < 0``) or implausibly far away
        (``t > 1e15``) get the "no intersection" sentinel ``-1e10``. This is
        the same sentinel the sphere and triangle use, and it is negative so
        callers testing ``t >= 0`` correctly ignore it.
    """
    parallel_mask = self.are_parallel_to(rays.directions)
    numerator = -self.d - self.normal_directions.dot(rays.starts)
    denominator = self.normal_directions.dot(rays.directions)
    # Parallel rays divide by ~0 here; their t is discarded below.
    t = numerator / denominator
    bad_t_mask = (t < 0) | (t > 1e15)
    # A hit is valid only if the ray is NOT parallel AND its t is usable,
    # i.e. invalid if parallel OR bad. `~` works for bool and uint8 masks.
    valid_mask = ~(parallel_mask | bad_t_mask)
    n = valid_mask.shape[0]
    no_hit = device_control.get_device_float32_array([n], -1e10)
    return torch.where(valid_mask, t, no_hit)
def get_intersection_t(self, rays):
    """Return the ray parameter where each ray hits this triangle.

    Intersects the rays with the triangle's supporting plane, then keeps only
    the hits that lie inside the triangle ``a -> b -> c``.

    Returns:
        A 1-D tensor of ``t`` values; rays whose plane hit falls outside the
        triangle get the sentinel ``-1e10``.
    """
    t = self.plane.get_intersection_t(rays)
    p = rays.get_point_on_line(t)
    normal = self.plane.normal_directions
    # Inside-outside test: p is inside iff, for every directed edge
    # (a->b, b->c, c->a), normal . (edge x (p - edge_start)) >= 0.
    inside0 = normal.dot((self.b - self.a).cross(p - self.a)) >= 0
    inside1 = normal.dot((self.c - self.b).cross(p - self.b)) >= 0
    inside2 = normal.dot((self.a - self.c).cross(p - self.c)) >= 0
    # Combine with `&`: adding bool masks is a logical OR in PyTorch, so the
    # former `sum == 3` test could never be true for bool masks.
    inside_mask = inside0 & inside1 & inside2
    no_hit = device_control.get_device_float32_array([t.shape[0]], -1e10)
    return torch.where(inside_mask, t, no_hit)
def _get_color_contribution_from_lights(self, closest_shapes, closest_intersections, colors_at_intersections):
    """Local shading of the given intersections from all lights.

    Starts from a light-independent term (surface color scaled by
    ``self.shape_coefficients[:, 0]``). For each light position in
    ``self.lights`` it then adds the Phong term from
    ``_phong_illumination_color``, but only where the light is not occluded
    by any shape in ``self.shapes``.

    Returns:
        ``[n, 3]`` colors, one row per intersection.
    """
    # TODO improve performance for pure black colors
    points = closest_intersections.intersection_points
    normals = closest_intersections.intersection_normals

    res = colors_at_intersections * self.shape_coefficients[
        closest_shapes, 0].view(-1, 1)
    # The direction toward the viewer does not depend on the light.
    v = closest_intersections.incident_rays.directions.reverse_vector(
    ).unit_vectors()

    for light_pos in self.lights:
        t_light = points.distances_to(light_pos)
        rays_to_light = ray.Ray(points, light_pos - points)
        rays_to_light.advance_by_epsilon()

        # Seed the visibility mask from a comparison so every mask combined
        # below has the same dtype (bool in PyTorch >= 1.2, uint8 before);
        # mixing uint8 and bool would give torch.where an invalid condition.
        visible = t_light > 0
        for shape in self.shapes:
            new_t = shape.get_intersection_t(rays_to_light)
            # Negative t is the shapes' "no hit" sentinel.
            # NOTE(review): compares a ray parameter with a distance; this is
            # only correct if Ray directions are unit length — confirm.
            obstacle = (new_t >= 0) & (new_t < t_light)
            visible = visible & ~obstacle

        r = rays_to_light.directions.reverse_vector(
        ).get_reflection_directions(normals)
        color_from_light = self._phong_illumination_color(
            closest_shapes, colors_at_intersections, normals,
            rays_to_light.directions, r, v)
        res += torch.where(visible.view(-1, 1), color_from_light,
                           torch.zeros_like(color_from_light))
    return res