def calculate_gradient(self, param: Parametrization) -> np.ndarray:
    struct = param.get_structure()
    efields = self.sim.simulate(struct)
    total_df_dz = self.calculate_df_dz(efields, struct)
    # TODO(logansu): Cache gradient calculation.
    dz_dp = aslinearoperator(param.calculate_gradient())
    df_dp = np.conj(
        dz_dp.adjoint() @ np.conj(
            total_df_dz + self.calculate_partial_df_dz(efields, struct)))
    return df_dp
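
# Illustrative sketch (an assumption for exposition, not part of the library):
# the adjoint chain rule used in `calculate_gradient` above is
#     df/dp = conj((dz/dp)^H @ conj(df/dz)),
# which, for a dense Jacobian, reduces to (dz/dp)^T @ df/dz. The hypothetical
# helper below only exists to make the conjugation pattern easy to verify.
def _example_adjoint_chain_rule() -> None:
    import numpy as np
    from scipy.sparse.linalg import aslinearoperator

    rng = np.random.default_rng(0)
    dz_dp = rng.standard_normal((5, 3)) + 1j * rng.standard_normal((5, 3))
    df_dz = rng.standard_normal(5) + 1j * rng.standard_normal(5)

    # Same composition as in `calculate_gradient`, with the Jacobian wrapped
    # as a LinearOperator.
    df_dp = np.conj(aslinearoperator(dz_dp).adjoint() @ np.conj(df_dz))

    # Contracting the dense Jacobian directly gives the same result.
    assert np.allclose(df_dp, dz_dp.T @ df_dz)
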
def eval_grad(fun: "OptimizationFunction",
              param: Parametrization) -> np.ndarray:
    """Evaluates the gradient of `fun` with respect to `param`.

    This function evaluates the gradient by reverse-mode autodiff, using a
    technique similar to the one described in `eval_fun`. Reverse-mode
    autodiff occurs in two phases: in the first "forward" phase, the values
    of the functions at every node are computed; in the second "backward"
    phase, the gradients are computed starting from the output function and
    moving towards the inputs (a toy illustration of these two phases follows
    this function).

    This function assumes that `param.get_structure()` will replace up to one
    `Variable` in the computational graph. If no `Variable` nodes exist, the
    value will not be used. If multiple `Variable` nodes exist, a `ValueError`
    is raised.

    Args:
        fun: The function to evaluate.
        param: Parametrization to evaluate with.

    Returns:
        The gradient of `fun` with respect to `param`.

    Raises:
        ValueError: If more than one `Variable` node was found.
    """
    fun_map, graph, in_nodes, out_nodes, heavy_nodes = _create_computational_graph(
        [fun])
    out_node = out_nodes[0]

    # Forward pass (phase 1): Compute function values. This is nearly
    # identical to the behavior of `eval_fun`.

    # Dictionary mapping function ids to values.
    fun_vals = {}
    # Check that we have at most one input node.
    # It is actually possible to have no input node because an old-style
    # function might be masking it (i.e. not listing it as a dependency).
    if len(in_nodes) == 1:
        in_node = list(in_nodes)[0]
        fun_vals[in_node] = param.get_structure()
    elif len(in_nodes) > 1:
        raise ValueError("Multiple Variable nodes detected.")

    top_sorted_nodes = _top_sort_affinity(graph, heavy_nodes)
    _eval_fun_vals(fun_vals, fun_map, graph, top_sorted_nodes, param)

    # Backward pass (phase 2): The gradients are computed moving from
    # the last node to the first.

    # Dictionary mapping function ids to gradients.
    grad_vals = {node: 0.0 for node in fun_map.keys()}
    grad_vals[out_node] = 1

    def _eval_grad(node, input_vals, grad_val):
        if _is_old_fun(fun_map[node]):
            return fun_map[node].calculate_gradient(param)
        return fun_map[node].grad(input_vals, grad_val)

    # We need to maintain a separate list of gradients computed by old-style
    # functions, since old-style functions do not list their dependencies.
    # These gradients are summed with the new-style gradients at the end.
    old_node_grads = []
    try:
        node_iter = iter(top_sorted_nodes)
        node = next(node_iter)
        while True:
            while True:
                # Postpone any heavy computation.
                if _is_heavy_fun(fun_map[node]):
                    break
                if _is_old_fun(fun_map[node]):
                    old_grad = fun_map[node].calculate_gradient(param)
                    # Use `np.dot` since either operand could be a scalar and
                    # matrix multiplication (@) disallows scalars.
                    old_node_grads.append(np.dot(grad_vals[node], old_grad))
                else:
                    # Get the gradients of the output function with respect
                    # to the inputs of function `node`.
                    input_vals = [
                        fun_vals[next_node] for next_node in graph[node]
                    ]
                    node_grad_vals = fun_map[node].grad(
                        input_vals, grad_vals[node])
                    # Update the gradient values.
                    for next_node, next_fun_grad in zip(
                            graph[node], node_grad_vals):
                        grad_vals[next_node] += next_fun_grad
                node = next(node_iter)

            # At this point, we have processed all the possible nodes.
            # We now parallelize the computation of the `heavy_compute_nodes`.
            heavy_node_block = []
            last_node_reached = False
            try:
                while _is_heavy_fun(fun_map[node]):
                    heavy_node_block.append(node)
                    node = next(node_iter)
            except StopIteration:
                last_node_reached = True

            # Parallelize heavy compute.
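            # Each heavy node's gradient is evaluated by `_eval_grad` on its
            # own worker thread; `executor.map` preserves input order, so the
            # results can be zipped back against `heavy_node_block` below.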
            arg_list = []
            for node_ in heavy_node_block:
                arg_list.append((
                    node_,
                    [fun_vals[in_node] for in_node in graph[node_]],
                    grad_vals[node_],
                ))
            with concurrent.futures.ThreadPoolExecutor() as executor:
                heavy_node_vals = executor.map(lambda args: _eval_grad(*args),
                                               arg_list)

            for node_, node_grad_vals in zip(heavy_node_block,
                                             heavy_node_vals):
                for next_node, next_fun_grad in zip(graph[node_],
                                                    node_grad_vals):
                    grad_vals[next_node] += next_fun_grad

            if last_node_reached:
                break
    except StopIteration:
        pass

    # Compute the final gradients. This consists of any gradients computed
    # using old-style functions as well as the new-style function.
    total_grads = old_node_grads
    if len(in_nodes) >= 1:
        total_grads += [
            np.array(grad_vals[in_node]) @ param.calculate_gradient()
        ]
    return np.sum(total_grads, axis=0)
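
# Toy illustration (an assumption for exposition, independent of the graph
# machinery above) of the two phases described in the `eval_grad` docstring.
# The tiny graph computes f(x) = sin(x) * x: the forward phase fills `vals`
# in topological order, and the backward phase pushes gradients from the
# output node back towards the input node.
def _example_two_phase_reverse_mode() -> None:
    import numpy as np

    x = 0.7
    # Forward phase: evaluate every node.
    vals = {"x": x, "s": np.sin(x)}
    vals["f"] = vals["s"] * vals["x"]

    # Backward phase: seed the output gradient with 1 and accumulate the
    # gradient contributions into each node's inputs.
    grads = {"x": 0.0, "s": 0.0, "f": 1.0}
    # f = s * x  ->  df/ds = x, df/dx = s.
    grads["s"] += grads["f"] * vals["x"]
    grads["x"] += grads["f"] * vals["s"]
    # s = sin(x)  ->  ds/dx = cos(x).
    grads["x"] += grads["s"] * np.cos(vals["x"])

    # Matches d/dx [sin(x) * x] = sin(x) + x * cos(x).
    assert np.isclose(grads["x"], np.sin(x) + x * np.cos(x))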