def PostProcessing(self):
  """Perform postprocessing at the end of gradients().

  We have created the gradient graph at this point, so this function can be
  used to perform any postprocessing on the gradient graph. We currently
  perform the following postprocessing:
    1. Patch the gradient graph if the output of a loop variable doesn't
       depend on its input.

  For each backprop Merge recorded in a grad state's `switch_map`, if the
  Merge's two inputs are still the same tensor, its second input was never
  rewired to a loop back edge. In that case input 1 of the Merge is replaced
  by a NextIteration of a zeros tensor with the same dtype as the Merge's
  first input. Every iteration after the first then receives a zero gradient.
  """
  # pylint: disable=protected-access
  for grad_state in self._map.values():
    for b_merge in grad_state.switch_map.values():
      merge_op = b_merge.op
      first_input = merge_op.inputs[0]
      if first_input == merge_op.inputs[1]:
        # The value of this loop variable at iteration i+1 doesn't
        # depend on its value at iteration i. So use zeros as the
        # gradients for all iterations > 0.
        dtype = first_input.dtype
        shape = first_input.get_shape()
        if shape.is_fully_defined():
          # Static shape known: create the zeros directly inside the
          # gradient loop context and use it for iterations > 0.
          grad_state.grad_context.Enter()
          try:
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = control_flow_ops._NextIteration(grad_val)
          finally:
            grad_state.grad_context.Exit()
        else:
          # Static shape unknown: create the zeros in the outer grad context,
          # sized from the runtime shape of the input of the op producing the
          # Merge's first input (presumably an Enter op -- confirm).
          outer_grad_ctxt = grad_state.grad_context.outer_context
          if outer_grad_ctxt:
            outer_grad_ctxt.Enter()
          try:
            enter_grad_op = first_input.op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            # Pass dtype explicitly: zeros() would otherwise default to
            # float32 and mismatch the Merge's other input for e.g. float64.
            grad_val = array_ops.zeros(grad_shape, dtype=dtype)
          finally:
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
          # Use the zeros for iterations > 0.
          grad_state.grad_context.Enter()
          try:
            next_grad_val = control_flow_ops._NextIteration(grad_val)
          finally:
            grad_state.grad_context.Exit()
        merge_op._update_input(1, next_grad_val)
  # pylint: enable=protected-access
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on second visit.

  Args:
    op: The forward Switch operation being differentiated.
    *grad: Incoming gradients for the Switch's two outputs. grad[0] is for
      output 0 (the false / Exit branch) and grad[1] is for output 1.

  Returns:
    A pair (gradient w.r.t. the data input, gradient w.r.t. the predicate).
    The predicate gradient is always None. The data gradient is None on the
    second visit of a loop Switch.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # switch_map maps a forward Switch op to the output *tensor* of the
    # backprop Merge built for it. Storing the tensor (rather than its op)
    # matches readers such as PostProcessing, which access
    # `b_merge.op.inputs` and `b_merge.op._update_input`.
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    # Compare against None explicitly: a Tensor cannot be used as a bool.
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      # pylint: disable=protected-access
      merge_grad.op._update_input(1, control_flow_ops._NextIteration(grad[1]))
      # pylint: enable=protected-access
      return None, None
    # This is the first time this Switch is visited. It always comes
    # from the Exit branch, which is grad[0]. grad[1] is empty at this point.
    # Use grad[0] for both inputs to merge for now, but update the second
    # input of merge when we see this Switch the second time.
    merge_fn = control_flow_ops._Merge  # pylint: disable=protected-access
    merge_grad = merge_fn([grad[0], grad[0]], name="b_switch")[0]
    grad_ctxt.grad_state.switch_map[op] = merge_grad
    return merge_grad, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # If we are in a grad context, this switch is part of a cond within a
    # loop. In this case, we have called ControlFlowState.ZeroLike() so grad
    # is ready for merge. Otherwise, we need a switch to control zero_grad.
    if not (grad_ctxt and grad_ctxt.grad_state):
      dtype = good_grad.dtype
      branch = op_ctxt.branch
      zero_grad = switch(zero_grad, op_ctxt.pred, dtype=dtype)[1 - branch]
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    # Plain Switch outside any control-flow context: route each incoming
    # gradient through a Switch on the same predicate (op.inputs[1]) and
    # merge the two.
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None