def _GRUBlockCellGrad(op, *grad):
  r"""Compute the gradients of a GRUBlockCell op with respect to its inputs.

  The `gru_block_cell_grad` kernel returns the gradients wrt `x` and `h_prev`
  together with two intermediate gradients, `d_c_bar` and `d_r_bar_u_bar`.
  The weight and bias gradients are then derived from those intermediates
  here, in Python.

  Args:
    op: The GRUBlockCell op being differentiated. Its six inputs are unpacked
      as `(x, h_prev, w_ru, w_c, b_ru, b_c)` and its four outputs as
      `(r, u, c, <unused>)`.
    *grad: Gradients of the optimization function wrt the four outputs of the
      op. Only the last entry (`d_h`) is used.

  Returns:
    A 6-tuple, in the same order as the op inputs:
      d_x: Gradient wrt x.
      d_h_prev: Gradient wrt h_prev.
      d_w_ru: Gradient wrt w_ru.
      d_w_c: Gradient wrt w_c.
      d_b_ru: Gradient wrt b_ru.
      d_b_c: Gradient wrt b_c.

  Mathematics documented for the gradient kernel (not visible from this
  wrapper, reproduced as originally noted):

  ```
  d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
  d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)

  d_r_bar_u_bar = [d_r_bar d_u_bar]

  [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T

  [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T

  d_x = d_x_component_1 + d_x_component_2

  d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
  ```

  NOTE(review): the trailing `+ u` term above presumably stands for
  `d_h \circ u`, and `h` in `d_u_bar` for `h_prev` -- confirm against the
  kernel implementation.

  Calculation performed in this Python wrapper (not in the gradient kernel):

  ```
  d_w_ru = x_h_prev^T * d_r_bar_u_bar
  d_w_c = x_h_prevr^T * d_c_bar

  d_b_ru = bias_add_grad(d_r_bar_u_bar)
  d_b_c = bias_add_grad(d_c_bar)
  ```
  where `x_h_prev = [x h_prev]` and `x_h_prevr = [x (h_prev \circ r)]` are
  concatenations along axis 1.
  """
  x, h_prev, w_ru, w_c, b_ru, b_c = op.inputs
  r, u, c, _ = op.outputs
  # Only the gradient flowing into the last output is consumed.
  _, _, _, d_h = grad

  kernel_grads = gen_gru_ops.gru_block_cell_grad(
      x, h_prev, w_ru, w_c, b_ru, b_c, r, u, c, d_h)
  d_x, d_h_prev, d_c_bar, d_r_bar_u_bar = kernel_grads

  # Gate parameters (w_ru, b_ru): driven by the concatenation [x, h_prev].
  gate_input = array_ops.concat([x, h_prev], 1)
  d_w_ru = math_ops.matmul(gate_input, d_r_bar_u_bar, transpose_a=True)
  d_b_ru = nn_ops.bias_add_grad(d_r_bar_u_bar)

  # Candidate parameters (w_c, b_c): driven by [x, h_prev * r].
  reset_h_prev = h_prev * r
  candidate_input = array_ops.concat([x, reset_h_prev], 1)
  d_w_c = math_ops.matmul(candidate_input, d_c_bar, transpose_a=True)
  d_b_c = nn_ops.bias_add_grad(d_c_bar)

  return d_x, d_h_prev, d_w_ru, d_w_c, d_b_ru, d_b_c
def _LSTMBlockCellGrad(op, *grad):
  """Gradient for LSTMBlockCell.

  The `lstm_block_cell_grad` kernel produces the gradient wrt `cs_prev`, the
  combined gate gradient `dicfo` and the peephole-weight gradients. The
  gradients wrt `x`, `h_prev`, `w` and `b` are derived from `dicfo` here.

  Args:
    op: The LSTMBlockCell op being differentiated. Its eight inputs are
      unpacked as `(x, cs_prev, h_prev, w, wci, wcf, wco, b)` and its seven
      outputs as `(i, cs, f, o, ci, co, <unused>)`.
    *grad: Gradients wrt the seven outputs of the op. Only the second
      (`cs_grad`) and the last (`h_grad`) are used.

  Returns:
    A tuple `(x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
    wco_grad, b_grad)`, one gradient per op input, in input order.

  Raises:
    ValueError: If the second dimension of `x` (input size) or of `cs_prev`
      (cell size) is not statically known.
  """
  # The peephole weights are unpacked as (wci, wcf, wco): the same order in
  # which they are passed to the gradient kernel and in which their gradients
  # are returned below. Unpacking them as (wci, wco, wcf) would hand the
  # forget/output peephole weights to the kernel swapped.
  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, _) = op.outputs
  (_, cs_grad, _, _, _, _, h_grad) = grad

  batch_size = x.get_shape().with_rank(2)[0].value
  if batch_size is None:
    # Unknown batch size: -1 is used as the slice size along that dimension.
    batch_size = -1
  input_size = x.get_shape().with_rank(2)[1].value
  if input_size is None:
    raise ValueError("input_size from `x` should not be None.")
  cell_size = cs_prev.get_shape().with_rank(2)[1].value
  if cell_size is None:
    raise ValueError("cell_size from `cs_prev` should not be None.")

  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
   wco_grad) = _lstm_ops_so.lstm_block_cell_grad(
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  # Backprop from dicfo to xh, then split the result column-wise into the
  # part belonging to x and the part belonging to h_prev.
  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)

  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
  # The merge_with results are discarded; presumably these calls only serve
  # as static shape-compatibility checks.
  x_grad.get_shape().merge_with(x.get_shape())

  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
                                (batch_size, cell_size))
  h_prev_grad.get_shape().merge_with(h_prev.get_shape())

  # Backprop from dicfo to w. `concat` is called as (values, axis), matching
  # the other gradient functions in this file.
  xh = array_ops.concat([x, h_prev], 1)
  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
  w_grad.get_shape().merge_with(w.get_shape())

  # Backprop from dicfo to b.
  b_grad = nn_ops.bias_add_grad(dicfo)
  b_grad.get_shape().merge_with(b.get_shape())

  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
          wco_grad, b_grad)
def _LSTMBlockCellGrad(op, *grad):
  """Gradient for LSTMBlockCell.

  The `lstm_block_cell_grad` kernel produces the gradient wrt `cs_prev`, the
  combined gate gradient `dicfo` and the peephole-weight gradients. The
  gradients wrt `x`, `h_prev`, `w` and `b` are derived from `dicfo` here.

  Args:
    op: The LSTMBlockCell op being differentiated. Its eight inputs are
      unpacked as `(x, cs_prev, h_prev, w, wci, wcf, wco, b)` and its seven
      outputs as `(i, cs, f, o, ci, co, <unused>)`.
    *grad: Gradients wrt the seven outputs of the op. Only the second
      (`cs_grad`) and the last (`h_grad`) are used.

  Returns:
    A tuple `(x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
    wco_grad, b_grad)`, one gradient per op input, in input order.

  Raises:
    ValueError: If the second dimension of `x` (input size) or of `cs_prev`
      (cell size) is not statically known.
  """
  # The peephole weights are unpacked as (wci, wcf, wco): the same order in
  # which they are passed to the gradient kernel and in which their gradients
  # are returned below. Unpacking them as (wci, wco, wcf) would hand the
  # forget/output peephole weights to the kernel swapped.
  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, _) = op.outputs
  (_, cs_grad, _, _, _, _, h_grad) = grad

  batch_size = x.get_shape().with_rank(2)[0].value
  if batch_size is None:
    # Unknown batch size: -1 is used as the slice size along that dimension.
    batch_size = -1
  input_size = x.get_shape().with_rank(2)[1].value
  if input_size is None:
    raise ValueError("input_size from `x` should not be None.")
  cell_size = cs_prev.get_shape().with_rank(2)[1].value
  if cell_size is None:
    raise ValueError("cell_size from `cs_prev` should not be None.")

  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
   wco_grad) = gen_lstm_ops.lstm_block_cell_grad(
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  # Backprop from dicfo to xh, then split the result column-wise into the
  # part belonging to x and the part belonging to h_prev.
  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)

  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
  # The merge_with results are discarded; presumably these calls only serve
  # as static shape-compatibility checks.
  x_grad.get_shape().merge_with(x.get_shape())

  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
                                (batch_size, cell_size))
  h_prev_grad.get_shape().merge_with(h_prev.get_shape())

  # Backprop from dicfo to w.
  xh = array_ops.concat([x, h_prev], 1)
  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
  w_grad.get_shape().merge_with(w.get_shape())

  # Backprop from dicfo to b.
  b_grad = nn_ops.bias_add_grad(dicfo)
  b_grad.get_shape().merge_with(b.get_shape())

  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
          wco_grad, b_grad)