def testNestedCellFn(self): """Tests when cell_fn calls another function.""" @function.Defun(tf.float32) def RandWithCoeff(coeff): return coeff * tf.random_uniform(shape=[], dtype=coeff.dtype) def Rand(theta, state, inputs): del theta next_state = py_utils.NestedMap() next_state.value = state.value + RandWithCoeff(inputs.coeff) return next_state, py_utils.NestedMap() @function.Defun(tf.float32) def Coeff(coeff): return coeff * 2 def Deterministic(theta, state, inputs): del theta next_state = py_utils.NestedMap() next_state.value = state.value + Coeff(inputs.coeff) return next_state, py_utils.NestedMap() with self.session(): theta = py_utils.NestedMap() state = py_utils.NestedMap() state.value = tf.constant(0.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3.]) recurrent.Recurrent( theta, state, inputs, Deterministic, check_stateful_ops=True) with self.assertRaisesRegexp(ValueError, 'stateful.*RandWithCoeff'): recurrent.Recurrent(theta, state, inputs, Rand, check_stateful_ops=True)
def testSymbolToTensorMap(self): """Tests that cell_fn can rely on the contextual symbol-to-tensor map.""" x = symbolic.Symbol('x') y = symbolic.Symbol('y') def PlusWXT(theta, state, inputs): """state.value += theta.w * x * inputs.t.""" next_state = py_utils.NestedMap() x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x) next_state.value = state.value + theta.w * x_tensor * inputs.t return next_state, py_utils.NestedMap() def PlusWXTGrad(theta, state0, inputs, extras, dstate1): """Gradient function for PlusWXT.""" del state0, extras x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x) dtheta = py_utils.NestedMap(w=dstate1.value * x_tensor * inputs.t) dstate0 = py_utils.NestedMap(value=dstate1.value) dinputs = py_utils.NestedMap(t=dstate1.value * theta.w * x_tensor) return dtheta, dstate0, dinputs, None with self.session() as sess: theta = py_utils.NestedMap(w=tf.constant(1., name='w')) state0 = py_utils.NestedMap(value=tf.constant(0., name='value')) inputs = py_utils.NestedMap(t=tf.constant([1., 2., 3.], name='t')) # With automatic cell_grad. with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, { x: tf.constant(7., name='x7'), y: 8 }): x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x) _, state1 = recurrent.Recurrent(theta, state0, inputs, PlusWXT) dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0] dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0] final_value, x_val, dx_val, dw_val = sess.run( [state1.value, x_tensor, dx, dw]) self.assertEqual(x_val, 7) self.assertEqual(final_value, x_val * (1. + 2. + 3.)) self.assertEqual(dw_val, x_val * (1. + 2. + 3.)) self.assertEqual(dx_val, (1. + 2. + 3.)) # With manual cell_grad. with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, {x: tf.constant(5., name='x5')}): x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x) _, state1 = recurrent.Recurrent(theta, state0, inputs, PlusWXT, cell_grad=PlusWXTGrad) dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0] dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0] final_value, x_val, dx_val, dw_val = sess.run( [state1.value, x_tensor, dx, dw]) self.assertEqual(x_val, 5) self.assertEqual(final_value, x_val * (1. + 2. + 3.)) self.assertEqual(dw_val, x_val * (1. + 2. + 3.)) self.assertEqual(dx_val, (1. + 2. + 3.))
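# A quick check of the numbers asserted in the first block above (x bound to 7,
# w = 1, inputs.t = [1., 2., 3.]): the cell accumulates
#   state1.value = w * x * (t_1 + t_2 + t_3) = 1 * 7 * 6 = 42 = x_val * 6,
# so d(state1.value)/dw = x * 6 = 42 and d(state1.value)/dx = w * 6 = 6,
# which is exactly what the dw_val / dx_val assertions verify.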
def testCaptureDisallowed(self): with self.session() as unused_sess: theta = py_utils.NestedMap() theta.x = tf.constant(2.0) state0 = py_utils.NestedMap() state0.value = tf.constant(0.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3.]) captured = tf.constant(1.2, name='captured_const') def CellFn(theta, state, inputs): next_state = py_utils.NestedMap() # Captured is pulled from outside of the function and captured via the # internal Defun. next_state.value = state.value + inputs.coeff * captured * theta.x return next_state, py_utils.NestedMap() # Run real fprop implementation. with self.assertRaisesRegexp(AssertionError, 'implicit capture is disabled'): unused_real_acc, unused_real_staten = recurrent.Recurrent( theta, state0, inputs, CellFn, allow_implicit_capture=False)
def FProp(self, theta, *args): """Runs p.repeat copies of self.body.FProp independently. Args: theta: Layer model parameters. The shape of each variable in theta is always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of the i-th copy of self.body. *args: Input arguments. The shape of each tensor in args is always [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs to the i-th copy of self.body.FProp. Returns: The accumulated output_tensors. Each tensor t in the return has the shape [p.repeat, ....] and the tuple (t[i] for i in output_tensors) is the return tuple of the i-th self.body.FProp. """ p = self.params for arg in args: if arg is not None: arg = py_utils.HasShape(arg, [p.repeat], ndims=1) theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat) inputs = py_utils.NestedMap(theta=theta_stack, args=list(args)) # Infer out_shapes from FPropMeta. out_shapes = self._InferOutShapes(args) def _CellFn(unused_theta, unused_state0, inputs): """Recurrent cell function wrapper of body.FProp.""" # Sets shapes for both theta and inputs to self.body.FProp. for dst, src in zip(inputs.args + inputs.theta.Flatten(), list(args) + theta_stack.Flatten()): if src is not None: dst.set_shape(tf.TensorShape(src.shape.as_list()[1:])) # Runs the actual body.FProp fprop_outputs = self.body.FProp(inputs.theta, *inputs.args) fprop_outputs = _ToTuple(fprop_outputs) assert len(fprop_outputs) == len(out_shapes) # Passes fprop outputs to the next layer through state. state1 = py_utils.NestedMap(outputs=list(fprop_outputs)) return state1, py_utils.NestedMap() with tf.name_scope(p.name): # Initiate state0 with inferred output shapes. state0 = py_utils.NestedMap( outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes]) # Runs body.FProp p.repeat times using Recurrent. acc_states, _ = recurrent.Recurrent( theta=py_utils.NestedMap(), state0=state0, inputs=inputs, cell_fn=_CellFn) # Retrieves fprop outputs from state1 and sets shapes. output_tensors = tuple(acc_states.outputs) for out_idx in range(len(output_tensors)): output_tensors[out_idx].set_shape( tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list())) return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
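# For intuition, the parallel-repeat contract documented above is roughly the
# eager-style sketch below (illustrative only; `body_fprop` is a stand-in, not
# the real implementation, and shapes are assumed static):
def _NaiveParallelRepeat(body_fprop, theta_stack, *args):
  """Applies body_fprop independently to the i-th slice of theta and args."""
  repeat = args[0].shape[0]
  outs = []
  for i in range(repeat):
    theta_i = tf.nest.map_structure(lambda t: t[i], theta_stack)
    outs.append(_ToTuple(body_fprop(theta_i, *[a[i] for a in args])))
  # Stack the per-copy outputs back along a leading [repeat] axis.
  return tuple(tf.stack(ts) for ts in zip(*outs))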
def testRecurrent(self): """Run on GPU locally with --test_output=streamed --config=cuda.""" def Sum(theta, state, inputs): next_state = py_utils.NestedMap() v = tf.reduce_sum(tf.one_hot(inputs.one_hot, depth=2) * theta.x, axis=0) next_state.sum = state.sum + v return next_state, py_utils.NestedMap() with self.session(use_gpu=True) as sess: theta = py_utils.NestedMap() theta.x = tf.constant([-1.0, 2.0]) state = py_utils.NestedMap() state.sum = tf.constant(0.0) inputs = py_utils.NestedMap() inputs.one_hot = tf.constant([0, 1, 1], dtype=tf.int32) # sum = -1 + 2 + 2 ret = recurrent.Recurrent(theta, state, inputs, Sum) acc, state = sess.run(ret) self.assertAllClose(acc.sum, [-1., 1., 3.]) self.assertAllClose(state.sum, 3.) y = ret[1].sum dx, d_inputs = tf.gradients(ys=[y], xs=[theta.x, inputs.one_hot]) tf.logging.info('d(inputs) = %s', d_inputs) dx_val = sess.run(dx) self.assertAllClose(dx_val, [1, 2])
def testCapture(self): with self.session(): theta = py_utils.NestedMap() theta.x = tf.constant(2.0) state0 = py_utils.NestedMap() state0.value = tf.constant(0.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3.]) captured = tf.constant(1.2, name='captured_const') def CellFn(theta, state, inputs): next_state = py_utils.NestedMap() # Captured is pulled from outside of the function and captured via the # internal Defun. next_state.value = state.value + inputs.coeff * captured * theta.x return next_state, py_utils.NestedMap() # Run static fprop reference implementation. ref_acc, ref_staten = _ReferenceStaticUnroll( theta, state0, inputs, CellFn) ref_acc_values, ref_staten_values = self.evaluate( (ref_acc, ref_staten)) print('Ref fprop: acc =', ref_acc, ', stateN =', ref_staten) print('Ref fprop values: acc =', ref_acc_values, ', stateN =', ref_staten_values) self.assertAllClose(ref_acc_values.value, [2.4, 7.2, 14.4]) self.assertAllClose(ref_staten_values.value, 14.4) # Run real fprop implementation. real_acc, real_staten = recurrent.Recurrent( theta, state0, inputs, CellFn, allow_implicit_capture=True) real_acc_values, real_staten_values = self.evaluate( (real_acc, real_staten)) print('Real fprop: acc =', real_acc, ', stateN =', real_staten) print('Real fprop values: acc =', real_acc_values, ', stateN =', real_staten_values) self.assertAllClose(ref_acc_values.value, real_acc_values.value) self.assertAllClose(ref_staten_values.value, real_staten_values.value) # BProp real vs ref of stateN. ref_dx, ref_dcaptured = tf.gradients(ys=[ref_staten.value], xs=[theta.x, captured]) ref_dx_values, ref_dcaptured_values = self.evaluate( [ref_dx, ref_dcaptured]) real_dx, real_dcaptured = tf.gradients(ys=[real_staten.value], xs=[theta.x, captured]) real_dx_values, real_dcaptured_values = self.evaluate( [real_dx, real_dcaptured]) print('Ref Dstate/[dx,dcaptured] =', ref_dx_values, ', ', ref_dcaptured_values) print('Real Dstate/[dx,dcaptured] =', real_dx_values, ', ', real_dcaptured_values) self.assertAllClose(ref_dx_values, 7.2) self.assertAllClose(ref_dcaptured_values, 12.0) self.assertAllClose(ref_dx_values, real_dx_values) self.assertAllClose(ref_dcaptured_values, real_dcaptured_values)
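# The expected values above follow directly from the cell: with x = 2,
# captured = 1.2 and coeff = [1., 2., 3.], state_t = cumsum(coeff)_t * 1.2 * 2,
# i.e. acc.value = [2.4, 7.2, 14.4] and stateN = 14.4. Likewise
# d(stateN)/d(x) = sum(coeff) * captured = 6 * 1.2 = 7.2 and
# d(stateN)/d(captured) = sum(coeff) * x = 6 * 2 = 12.0.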
def FProp(self, theta, *args): p = self.params # Collects all variable keys and values into sets. theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat) def _ArgsToState(arg_list): """Returns a NestedMap from a list of FProp args.""" state = py_utils.NestedMap() # Maintains a mapping from arg_idx to tensor. states cannot contain # None tensors. for idx in range(len(args)): if arg_list[idx] is not None: state['_s{}'.format(idx)] = arg_list[idx] return state def _StateToArgs(state): """Returns a list of FProp args from a NestedMap.""" arg_list = [] for idx in range(len(args)): attr = '_s{}'.format(idx) arg_list.append(state[attr] if attr in state else None) if arg_list[-1] is not None: arg_list[-1].set_shape(args[idx].shape) return arg_list def _CellFn(unused_theta, state0, theta_i): """Recurrent cell function wrapper of body.FProp.""" # Retrieves fprop arguments from state and sets shapes. fprop_inputs = _StateToArgs(state0) # Sets shapes for theta_i as well. for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()): if src is not None: dst.set_shape(tf.TensorShape(src.shape.as_list()[1:])) # Runs the actual body.FProp fprop_outputs = self.body.FProp(theta_i, *fprop_inputs) fprop_outputs = _ToTuple(fprop_outputs) assert len(fprop_outputs) == len(fprop_inputs) # Passes fprop outputs to the next layer through state. state1 = _ArgsToState(fprop_outputs) return state1, py_utils.NestedMap() with tf.name_scope(p.name): # Add FProp arg list to state0. state0 = _ArgsToState(args) # Runs body.FProp k times using Recurrent where k = dim 0 of var_nmap. _, state1 = recurrent.Recurrent( theta=py_utils.NestedMap(), state0=state0, inputs=theta_stack, # Pass cell_fn theta through inputs. cell_fn=_CellFn) # Retrieves fprop outputs from state1 and sets shapes. output_tensors = _StateToArgs(state1) return output_tensors[0] if len(args) == 1 else tuple( output_tensors)
def _testElmanHelper(self, seqlen, use_grad, stop_fn=None): with self.session() as sess: tf.set_random_seed(342462) batch = 3 dims = 4 theta = py_utils.NestedMap() theta.w = self.Rand([2 * dims, dims]) theta.b = self.Rand([dims]) state0 = py_utils.NestedMap() state0.h = self.Rand([batch, dims]) inputs = py_utils.NestedMap() inputs.x = self.Rand([seqlen, batch, dims]) # Static unrolled. s = state0 out = [] for i in range(seqlen): inp = py_utils.NestedMap() inp.x = inputs.x[i, :] s, _ = self.Elman(theta, s, inp) out += [s.h] if stop_fn and stop_fn(i + 1, theta, s): out += [ tf.zeros_like(out[-1]) for _ in range(seqlen - i - 1) ] break acc0, final0 = tf.stack(out), s.h loss0 = tf.reduce_sum(acc0) + tf.reduce_sum(final0) (dw0, db0, dh0, di0) = tf.gradients(loss0, [theta.w, theta.b, state0.h, inputs.x]) # Uses the Recurrent() library. acc1, final1 = recurrent.Recurrent( theta=theta, state0=state0, inputs=inputs, cell_fn=self.Elman, cell_grad=self.ElmanGrad if use_grad else None, stop_fn=stop_fn) acc1, final1 = acc1.h, final1.h loss1 = tf.reduce_sum(acc1) + tf.reduce_sum(final1) (dw1, db1, dh1, di1) = tf.gradients(loss1, [theta.w, theta.b, state0.h, inputs.x]) # Fetches a bunch of values and compare them. (acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0, di1) = sess.run([ acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0, di1 ]) self.assertAllClose(acc0, acc1) self.assertAllClose(final0, final1) self.assertAllClose(dw0, dw1) self.assertAllClose(db0, db1) self.assertAllClose(dh0, dh1) self.assertAllClose(di0, di1)
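# For reference, an Elman-style cell of the kind this helper exercises looks
# roughly like the sketch below (illustrative only; the real self.Elman /
# self.ElmanGrad helpers are defined elsewhere in this test and may use a
# different nonlinearity):
def _ElmanCellSketch(theta, state0, inputs):
  """One Elman step: h1 = sigmoid([x, h0] . w + b)."""
  xh = tf.concat([inputs.x, state0.h], axis=1)         # [batch, 2 * dims]
  h1 = tf.sigmoid(tf.matmul(xh, theta.w) + theta.b)    # [batch, dims]
  return py_utils.NestedMap(h=h1), py_utils.NestedMap()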
def testRnnStepRecurrent(self): with self.session(use_gpu=False) as sess: p = rnn_steps.RnnStep.Params() p.name = 'rnn_step' p.cell.name = 'rnn_step_cell' p.cell.output_nonlinearity = True p.cell.params_init = py_utils.WeightInit.Uniform(1.24, 429891685) p.cell.bias_init = py_utils.WeightInit.Uniform(1.24, 429891685) p.cell.vn.global_vn = False p.cell.vn.per_step_vn = False p.cell.num_input_nodes = 2 p.cell.num_output_nodes = 2 p.cell.inputs_arity = 2 rnn_step = p.Instantiate() external = tf.constant([[3]], tf.float32) packed = rnn_step.PrepareExternalInputs(rnn_step.theta, external) padding = tf.constant([0.0], dtype=tf.float32) def cell_fn(theta, state0, inputs): output, state1 = rnn_step.FProp(theta, state0.packed, inputs, state0.padding, state0.state) return py_utils.NestedMap( output=output, state=state1, packed=state0.packed, padding=state0.padding), py_utils.NestedMap() _, state2 = recurrent.Recurrent( theta=rnn_step.theta, state0=py_utils.NestedMap( state=rnn_step.ZeroState(rnn_step.theta, packed, 1), output=py_utils.NestedMap(output=tf.zeros([1, 2], tf.float32), extra=py_utils.NestedMap(), padding=padding), padding=padding, packed=packed), cell_fn=cell_fn, inputs=py_utils.NestedMap( inputs=[tf.constant([[[4]], [[4]]], tf.float32)])) tf.global_variables_initializer().run() state2 = sess.run(state2) self.assertAllClose(state2.state.m, [[-0.32659757, 0.87739915]]) self.assertAllClose(state2.state.c, [[-1.9628618, 1.4194499]]) self.assertAllClose(state2.output.output, [[-0.32659757, 0.87739915]])
def testStopFnNotTriggeredBeforeEOS(self): with self.session() as sess: def StopFn(t, unused_theta, unused_state): # The input sequence is only length 4, so this is never true. # However, the Recurrent call should still terminate after iteration 4. return t >= 5 theta = py_utils.NestedMap() theta.x = tf.constant(2.0) state = py_utils.NestedMap() state.value = tf.constant(0.0) state.x_power = tf.constant(1.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3., 4.]) # x = 2 # 1 + 2*x + 3*x^2 + 4*x^3 ret = recurrent.Recurrent(theta, state, inputs, _Poly, stop_fn=StopFn) acc, state = sess.run(ret) self.assertAllClose([1., 5., 17., 49.], acc.value) self.assertAllClose([2., 4., 8., 16.], acc.x_power) self.assertAllClose(49., state.value) self.assertAllClose(16., state.x_power) y = ret[1].value dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 2 + 6*x + 12*x^2 self.assertAllClose(62., dx_val) self.assertAllClose([1., 2., 4., 8.], d_coeff_val) # acc = [1, 1+2x, 1+2x+3x^2, 1+2x+3x^2+4x^3] # sum(acc) = 4 + 6x + 6x^2 + 4x^3 acc = ret[0].value dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 6 + 12*x + 12*x^2 self.assertAllClose(78., dx_val) self.assertAllClose([4., 6., 8., 8.], d_coeff_val)
def FProp(self, theta, prepared_inputs, inputs, padding, state0, **kwargs): """Runs a Step layer over multiple timesteps using Recurrent. Args: theta: A NestedMap containing weights' values of this layer and its children layers. prepared_inputs: External inputs returned by Step.PrepareExternalInputs(). inputs: A NestedMap of inputs of shape [time, batch_size, dim]. padding: A 0/1 float tensor of shape [time, batch_size]; 1.0 means that this batch element is empty in this step. state0: A NestedMap containing the initial recurrent state. **kwargs: Additional kwargs to pass to Recurrent. Returns: A tuple (outputs, state1). - outputs: A NestedMap containing the accumulated outputs of all steps, containing Tensors shaped [time, batch_size, dim]. - state1: A NestedMap containing the accumulated recurrent states, containing Tensors shaped [time, batch_size, dim]. """ def RnnStep(recurrent_theta, recurrent_state0, recurrent_inputs): """Compute a single timestep.""" output, state1 = self.step.FProp( theta=recurrent_theta.theta, prepared_inputs=recurrent_theta.prepared_inputs, step_inputs=recurrent_inputs.inputs, padding=recurrent_inputs.padding, state0=recurrent_state0.state) recurrent_state1 = py_utils.NestedMap(output=output, state=state1) return recurrent_state1, py_utils.NestedMap() # In order to pass Step outputs through Recurrent, they need to be # included as part of state. output0, _ = self.step.FProp(theta.step, prepared_inputs, inputs.Transform(lambda x: x[0]), padding[0], state0) accumulated_states, _ = recurrent.Recurrent( theta=py_utils.NestedMap( theta=theta.step, prepared_inputs=prepared_inputs), state0=py_utils.NestedMap(output=output0, state=state0), inputs=py_utils.NestedMap(inputs=inputs, padding=padding), cell_fn=RnnStep, **kwargs) return accumulated_states.output, accumulated_states.state
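# Note on the pattern above: Recurrent() needs state0 to already contain the
# full structure (and shapes) of everything it accumulates, so one step is run
# eagerly on timestep 0 just to build an `output` template for state0. The
# outputs actually returned come from the states accumulated inside the loop,
# so the step for timestep 0 is effectively computed twice.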
def testStateBasedStopFn(self): with self.session() as sess: def StopFn(unused_t, unused_theta, state): # This stops after 3 iterations. return state.value >= 15. theta = py_utils.NestedMap() theta.x = tf.constant(2.0) state = py_utils.NestedMap() state.value = tf.constant(0.0) state.x_power = tf.constant(1.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3., 4.]) # x = 2 # 1 + 2*x + 3*x^2 ret = recurrent.Recurrent(theta, state, inputs, _Poly, stop_fn=StopFn) acc, state = sess.run(ret) self.assertAllClose([1., 5., 17., 0.], acc.value) self.assertAllClose([2., 4., 8., 0.], acc.x_power) self.assertAllClose(17., state.value) self.assertAllClose(8., state.x_power) y = ret[1].value dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 2 + 6*x self.assertAllClose(14., dx_val) self.assertAllClose([1., 2., 4., 0.], d_coeff_val) # acc = [1, 1+2x, 1+2x+3x^2] # sum(acc) = 3 + 4x + 3x^2 acc = ret[0].value dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 4 + 6*x self.assertAllClose(16., dx_val) self.assertAllClose([3., 4., 4., 0.], d_coeff_val)
def testBasic(self): with self.session() as sess: theta = py_utils.NestedMap() theta.x = tf.constant(2.0) state = py_utils.NestedMap() state.value = tf.constant(0.0) state.x_power = tf.constant(1.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3.]) def Poly(theta, state, inputs): next_state = py_utils.NestedMap() next_state.value = state.value + inputs.coeff * state.x_power next_state.x_power = state.x_power * theta.x return next_state, py_utils.NestedMap() # x = 2 # 1 + 2*x + 3*x^2 ret = recurrent.Recurrent(theta, state, inputs, Poly) acc, state = sess.run(ret) self.assertAllClose(acc.value, [1., 5., 17.]) self.assertAllClose(acc.x_power, [2., 4., 8.]) self.assertAllClose(state.value, 17.) self.assertAllClose(state.x_power, 8.) y = ret[1].value dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 2 + 6*x self.assertAllClose(dx_val, 14.) self.assertAllClose(d_coeff_val, [1., 2., 4.]) # acc = [1, 1+2x, 1+2x+3x^2] # sum(acc) = 3 + 4x + 3x^2 acc = ret[0].value dx, d_coeff = tf.gradients( ys=[tf.reduce_sum(acc)], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 4 + 6*x self.assertAllClose(dx_val, 16.) self.assertAllClose(d_coeff_val, [3., 4., 4.])
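# Worked out, the cell above maintains value_t = sum_{i<=t} coeff_i * x**(i-1)
# and x_power_t = x**t. With x = 2 and coeff = [1., 2., 3.]:
#   acc.value   = [1, 1 + 2x, 1 + 2x + 3x^2] = [1, 5, 17]
#   acc.x_power = [x, x^2, x^3]               = [2, 4, 8]
#   d(final value)/dx = 2 + 6x = 14, and for sum(acc) = 3 + 4x + 3x^2 the
#   gradient is 4 + 6x = 16, matching the assertions.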
def FProp(self, theta, in_nmap): p = self.params if not p.remat: return self._FProp(theta, in_nmap) def CellFn(theta, state0, unused_inputs): out_nmap = self._FProp(theta, state0) return out_nmap, py_utils.NestedMap() _, state1 = recurrent.Recurrent( theta=theta, state0=in_nmap, inputs=py_utils.NestedMap( inputs=tf.zeros([1, 0])), # A dummy input of shape [T, ?]. cell_fn=CellFn, allow_implicit_capture=p.allow_implicit_capture) return state1
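# The length-1 Recurrent above is a rematerialization trick rather than a real
# loop: Recurrent's backward pass re-runs cell_fn step by step, so wrapping
# _FProp in a single-step loop means the layer's intermediate activations are
# recomputed during backprop instead of being kept alive, trading compute for
# activation memory. The tf.zeros([1, 0]) input only supplies the required
# leading time dimension of 1.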
def testStatefulCellFn(self): def Rand(theta, state, inputs): del theta next_state = py_utils.NestedMap() next_state.value = ( state.value + inputs.coeff * tf.random_uniform(shape=[], dtype=state.value.dtype)) return next_state, py_utils.NestedMap() with self.session(): theta = py_utils.NestedMap() state = py_utils.NestedMap() state.value = tf.constant(0.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3.]) with self.assertRaisesRegexp(ValueError, 'stateful.*random_uniform'): recurrent.Recurrent(theta, state, inputs, Rand, check_stateful_ops=True)
def FProp(self, theta, inputs, paddings): p = self.params if not p.remat: return self._FProp(theta, inputs, paddings) def CellFn(theta, state0, unused_inputs): outs, out_paddings = self._FProp(theta, state0.inputs, state0.paddings) return py_utils.NestedMap( inputs=outs, paddings=out_paddings), py_utils.NestedMap() state0 = py_utils.NestedMap(inputs=inputs, paddings=paddings) _, state1 = recurrent.Recurrent( theta=theta, state0=state0, inputs=py_utils.NestedMap( inputs=tf.zeros([1, 0])), # A dummy input of shape [T, ?]. cell_fn=CellFn, allow_implicit_capture=p.allow_implicit_capture) return state1.inputs, state1.paddings
def EmbLookup(self, theta, ids, state0): """Use this layer like a regular EmbeddingLayer (ids -> embeddings). Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. ids: A tensor of token ids with dimensions [t, b]. state0: A NestedMap containing the state of previous tokens. - prev_ids: A Tensor containing the n previous token ids. [batch, num_prev_tokens]. Each row is the token ids at t-1, ..., t-n. - embeddings: The output embeddings of the previous step. Returns: Embeddings time major [t, b, d] and next state """ def _FPropWrapper(theta, state0, inputs): embedding, state1 = self.FProp( theta, None, inputs, None, state0, ) state1.embedding = embedding.output return state1, py_utils.NestedMap() steps_input = py_utils.NestedMap(inputs=[tf.cast(ids, tf.float32)]) state1, _ = recurrent.Recurrent( theta=theta, state0=state0, inputs=steps_input, cell_fn=_FPropWrapper, ) # for training, we want the output across all timesteps sequence_of_embeddings = state1.embedding # for inference, we want just the last timestep's state, not all states state1.embedding = state1.embedding[-1] state1.prev_ids = state1.prev_ids[-1] return sequence_of_embeddings, state1
def testDropoutInRecurrent(self, graph_seed): with self.session() as sess: if graph_seed: tf.random.set_seed(12345) l = lingvo_layers.DeterministicDropoutLayer.Params().Set( name='dropout', keep_prob=0.7).Instantiate() # Input variable. w = tf.get_variable('w', shape=[9, 20], initializer=tf.ones_initializer()) sess.run(tf.global_variables_initializer()) prev_sum = np.sum(np.isclose(sess.run(w), 0.0)) def Step(theta, state0, unused_inputs): w = l.FProp(theta.l, state0.w) state1 = py_utils.NestedMap(w=w) return state1, py_utils.NestedMap() acc, final = recurrent.Recurrent( theta=py_utils.NestedMap(l=l.theta), state0=py_utils.NestedMap(w=w), inputs=py_utils.NestedMap(x=tf.zeros([4])), cell_fn=Step) acc_w = sess.run(acc.w) self.assertLen(acc_w, 4) for acc_w_i in acc_w: next_sum = np.sum(np.isclose(acc_w_i, 0.0)) self.assertGreater(next_sum, prev_sum) prev_sum = next_sum # Construct loss function such that gradients = final activation. loss = tf.reduce_sum(final.w) grads = py_utils.ComputeGradients(loss, py_utils.NestedMap(w=w)) w_val, grads_val = sess.run([final.w, grads.w.grad]) self.assertAllClose(w_val, grads_val)
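# Why the final check expects gradients == final activations (assuming the
# layer applies the usual inverted-dropout scaling): w is initialized to all
# ones and each Step multiplies the running value elementwise by an
# independent mask scaled by 1/keep_prob, so final.w = w * prod_t(mask_t /
# keep_prob). With loss = reduce_sum(final.w), d(loss)/dw = prod_t(mask_t /
# keep_prob), which equals final.w exactly because w == 1 everywhere.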
def _Repeat(self, theta, fn, *, common_input=None, layerwise_inputs=None, iterative_input_0=None, layerwise_output_0=None): """Invokes 'fn' for each sub-layer. Args: theta: layer parameters. fn: A function with args (theta, common_input, layerwise_input, iterative) returning a NestedMap(layerwise_output=..., iterative=...). common_input: a NestedMap with common inputs for every sub-layer. layerwise_inputs: a nested tensor with separate inputs for each sub-layer, where each leaf value T is a tensor of shape [p.repeat, ...] and T[i, ...] represents layerwise inputs to the i'th sub-layer. iterative_input_0: a nested tensor for the iterative input of the 0'th sub-layer. layerwise_output_0: a nested tensor representing the layerwise output of the first sub layer. The actual tensor values are not used. Only the structure and tensor shapes are. In particular, if layerwise_output_0 is None, calls fn(...) to compute the output. Returns: A NestedMap with: - iterative: a nested tensor with the same structure as iterative_input_0 representing the iterative output of the last sub-layer. - layerwise: a nested tensor where each leaf value T is a tensor of shape [p.repeat, ...] and T[i, ...] represents layerwise output from the i'th sub-layer. """ p = self.params if common_input is None: common_input = py_utils.NestedMap() if layerwise_inputs is None: layerwise_inputs = py_utils.NestedMap() if iterative_input_0 is None: iterative_input_0 = py_utils.NestedMap() if p.per_layer_vars: all_iters = [theta['body_iter_%05d' % i] for i in range(p.repeat)] theta_stack = py_utils.NestedMap(body=tf.nest.map_structure( lambda *t: tf.stack(list(t)), *all_iters)) else: theta_stack = self._MaybeStackExtraTheta(theta) theta_0 = tf.nest.map_structure(lambda t: t[0, ...], theta_stack) layerwise_input_0 = tf.nest.map_structure(lambda t: t[0, ...], layerwise_inputs) if layerwise_output_0 is None: # Run 'fn' for sublayer 0 to get a template for layerwise output. layerwise_output_0 = fn( theta_0, common_input=common_input, layerwise_input=layerwise_input_0, iterative=iterative_input_0).layerwise_output # Split 'common_input' to static vs. tensor values. static_common_input = tf.nest.map_structure( lambda x: None if isinstance(x, tf.Tensor) else x, common_input) dummy_tensor = tf.zeros([], dtype=p.dtype) tensor_common_input = tf.nest.map_structure( lambda x: x if isinstance(x, tf.Tensor) else dummy_tensor, common_input) def _ToRecurState(*, iterative, layerwise_output): """Returns a recurrent state.""" # Make sure each value in the NestedMap is a tensor. recur_state = py_utils.NestedMap(iterative=iterative, layerwise_output=layerwise_output) for key, value in recur_state.FlattenItems(): if not isinstance(value, tf.Tensor) or value.shape.rank is None: raise ValueError( 'Each value in the recurrent state must be a tensor ' f'with a known rank: {key}={value}') return recur_state def _CellFn(recur_theta, recur_state0, recur_input_i): """Recurrent cell function wrapper of body.FProp.""" # Retrieves fprop arguments from state and sets shapes. theta_i = _CopyShapes(theta_0, recur_input_i.theta) # Use non-tensor values from 'static_common_input' if available. merged_common_input = tf.nest.map_structure( (lambda x, s, t: t if isinstance(x, tf.Tensor) else s), common_input, static_common_input, recur_theta) # Call 'fn'. fn_results = fn(theta_i, common_input=merged_common_input, layerwise_input=_CopyShapes( layerwise_input_0, recur_input_i.layerwise), iterative=_CopyShapes(iterative_input_0, recur_state0.iterative)) # Pack results. 
return _ToRecurState(**fn_results), py_utils.NestedMap() with tf.name_scope('repeat'): recur_state0 = _ToRecurState(iterative=iterative_input_0, layerwise_output=layerwise_output_0) # Runs body.FProp k times using Recurrent where k = dim 0 of var_nmap. acc_states, final_recur_state = recurrent.Recurrent( theta=tensor_common_input, state0=recur_state0, inputs=py_utils.NestedMap(theta=theta_stack, layerwise=layerwise_inputs), cell_fn=_CellFn, allow_implicit_capture=p.allow_implicit_capture) # Retrieves outputs from recur_state1 and sets shapes. iterative_output = _CopyShapes(iterative_input_0, final_recur_state.iterative) # Set shapes of layer-wise outputs according to layerwise_output_0. layerwise_outputs = acc_states.layerwise_output tf.nest.map_structure( lambda t_stack, t_0: t_stack.set_shape( [p.repeat] + t_0.shape.as_list()), layerwise_outputs, layerwise_output_0) return py_utils.NestedMap(iterative=iterative_output, layerwise=layerwise_outputs)
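# A minimal example of an `fn` that _Repeat can drive (names such as `scale`,
# `bias` and `act` are illustrative, not part of the API): it consumes one
# per-layer input, threads a running activation through `iterative`, and also
# reports the per-layer activation as `layerwise_output`.
def _ExampleRepeatFn(theta, *, common_input, layerwise_input, iterative):
  del common_input  # Unused in this sketch.
  act = iterative.act * theta.scale + layerwise_input.bias
  return py_utils.NestedMap(
      iterative=py_utils.NestedMap(act=act),
      layerwise_output=py_utils.NestedMap(act=act))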
def Sample(self, decoder_theta, encoder_outputs, random_seed, init_state_callback, pre_step_callback, post_step_callback): """Samples target sequences, one target sequence per source sequence. (Please see beam_search_helper.py for description of decoder callbacks.) Args: decoder_theta: A NestedMap object containing weights' values of the decoder layer and its children layers, to be passed to decoder callbacks. encoder_outputs: the outputs of the encoder, to be passed to callbacks. random_seed: a scalar int32 tensor representing the random seed. init_state_callback: decoder._InitBeamSearchStateCallback. pre_step_callback: decoder._PreBeamSearchStepCallback. post_step_callback: decoder._PostBeamSearchStepCallback. Returns: A NestedMap containing the following tensors: - 'logits': [batch, max_target_length, vocab_size], representing the distribution from which target sequences are sampled. - 'ids': [batch, max_target_length] of int32, representing the target sequence ids, not including target_sos_id, but maybe ending with target_eos_id if end-of-sequence is reached before target_seq_len. - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a padded timestep. """ p = self.params assert p.temperature > 0 # 'recurrent_theta' represents all cross-timestep information used by the # recurrent loop below, including layer theta and encoder outputs. recurrent_theta = py_utils.NestedMap( theta=decoder_theta, random_seed=random_seed, encoder_outputs=encoder_outputs) bs_result, bs_state = init_state_callback( recurrent_theta.theta, encoder_outputs, num_hyps_per_beam=1) batch = tf.shape(bs_result.log_probs)[0] recurrent_state0 = py_utils.NestedMap( timestep=tf.zeros(shape=[], dtype=tf.int32), logits=bs_result.log_probs, # Start with target_sos_id. ids=tf.fill([batch], tf.to_int32(p.target_sos_id)), bs_state=bs_state) inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch])) def Step(recurrent_theta, state0, inputs): """Computes one decoder step.""" del inputs with tf.name_scope('single_sampler_step'): # Compute logits and states. bs_result, bs_state1 = pre_step_callback( recurrent_theta.theta, recurrent_theta.encoder_outputs, tf.expand_dims(state0.ids, 1), # [batch, 1]. state0.bs_state, num_hyps_per_beam=1) batch = tf.shape(bs_result.log_probs)[0] state1 = py_utils.NestedMap(timestep=state0.timestep + 1) state1.logits = bs_result.log_probs # Sample ids from logits. [batch]. state1.ids = tf.reshape( tf.random.stateless_multinomial( state1.logits / p.temperature, num_samples=1, seed=tf.stack([recurrent_theta.random_seed, state0.timestep]), output_dtype=state0.ids.dtype, name='sample_next_id'), [batch]) if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0: state1.ids = tf.where( tf.logical_and(bs_result.is_last_chunk, tf.equal(state1.ids, p.target_eoc_id)), tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids) state1.bs_state = post_step_callback(recurrent_theta.theta, recurrent_theta.encoder_outputs, state1.ids, bs_state1) return state1, py_utils.NestedMap() accumulated_states, _ = recurrent.Recurrent(recurrent_theta, recurrent_state0, inputs, Step) result = py_utils.NestedMap( logits=tf.transpose(accumulated_states.logits, [1, 0, 2]), ids=tf.transpose(accumulated_states.ids)) result.paddings = tf.cast( _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype) # Force ids to be eos_id if the timestep is padded. 
result.ids = tf.where( tf.equal(result.paddings, 0), result.ids, tf.fill(tf.shape(result.ids), p.target_eos_id)) static_batch_size = bs_result.log_probs.shape[0] result.ids.set_shape([static_batch_size, p.target_seq_len]) result.paddings.set_shape([static_batch_size, p.target_seq_len]) return result
def FProp(self, theta, *args): p = self.params if self.do_eval and p.unrolled_in_eval: return self._unrolled_fprop(theta, *args) if p.per_layer_vars: all_iters = [theta['body_iter_%05d' % i] for i in range(p.repeat)] theta_stack = tf.nest.map_structure(lambda *t: tf.stack(list(t)), *all_iters) else: theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat) def _ArgsToState(arg_list): """Returns a NestedMap from a list of FProp args.""" state = py_utils.NestedMap() # Maintains a mapping from arg_idx to tensor. states cannot contains # None tensors. for idx in range(len(args)): if isinstance(arg_list[idx], py_utils.NestedMap): # Make sure each value in the NestedMap is a tensor. if not all(isinstance(t, tf.Tensor) for t in arg_list[idx].Flatten()): raise ValueError( 'Each value in the input NestedMap must be a tensor.') if arg_list[idx] is not None: state['_s{}'.format(idx)] = arg_list[idx] return state def _StateToArgs(state): """Returns a list of FProp args from a NestedMap.""" arg_list = [] for idx in range(len(args)): attr = '_s{}'.format(idx) arg_list.append(state[attr] if attr in state else None) if isinstance(arg_list[-1], py_utils.NestedMap): assert isinstance(args[idx], py_utils.NestedMap) py_utils.SetShapes(arg_list[-1], args[idx]) elif isinstance(arg_list[-1], tf.Tensor): if arg_list[-1] is not None: arg_list[-1].set_shape(args[idx].shape) return arg_list def _CellFn(unused_theta, state0, theta_i): """Recurrent cell function wrapper of body.FProp.""" # Retrieves fprop arguments from state and sets shapes. fprop_inputs = _StateToArgs(state0) # Sets shapes for theta_i as well. for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()): if src is not None: dst.set_shape(tf.TensorShape(src.shape.as_list()[1:])) # Runs the actual body.FProp fprop_outputs = self._body.FProp(theta_i, *fprop_inputs) fprop_outputs = _ToTuple(fprop_outputs) assert len(fprop_outputs) == len(fprop_inputs) # Passes fprop outputs to the next layer through state. state1 = _ArgsToState(fprop_outputs) return state1, py_utils.NestedMap() with tf.name_scope(p.name): # Add FProp arg list to state0. state0 = _ArgsToState(args) # Runs body.FProp k times using Recurrent where k = dim 0 of var_nmap. _, state1 = recurrent.Recurrent( theta=py_utils.NestedMap(), state0=state0, inputs=theta_stack, # Pass cell_fn theta through inputs. cell_fn=_CellFn, allow_implicit_capture=p.allow_implicit_capture) # Retrieves fprop outputs from state1 and sets shapes. output_tensors = _StateToArgs(state1) return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
def FProp(self, theta, inputs, paddings, state0=None, segment_id=None): """Computes LSTM forward pass. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. inputs: A single tensor or a tuple of tensors with cardinality equal to rnn_cell.inputs_arity. For every input tensor, the first dimension is assumed to be time, second dimension batch, and third dimension depth. paddings: A tensor. First dim is time, second dim is batch, and third dim is expected to be 1. state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to the cell's zero-state. segment_id: A tensor to support packed inputs. First dim is time, second dim is batch, and third dim is expected to be 1. Returns: A tensor of [time, batch, dims]. The final recurrent state. """ p = self.params assert isinstance(self.cell, rnn_cell.RNNCell) if not isinstance(inputs, (list, tuple)): inputs = [inputs] # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular # LSTM baseline. # Keeping slicing within the loop gives only < 3% speedup. cell_theta = theta.cell.copy() num_input_nodes = p.cell.num_input_nodes cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :] cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :] tf.logging.vlog(1, 'cell_theta: %r', cell_theta) if p.packed_input: assert segment_id is not None reset_mask = rnn_layers.GeneratePackedInputResetMask( segment_id, is_reverse=False) reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings)) else: reset_mask = tf.zeros_like(paddings) if p.reverse: inputs = [tf.reverse(x, [0]) for x in inputs] paddings = tf.reverse(paddings, [0]) reset_mask = tf.reverse(reset_mask, [0]) if not state0: batch_size = py_utils.GetShape(paddings)[1] state0 = self.cell.zero_state(cell_theta, batch_size) # [T, B, H] proj_inputs = self.cell.ProjectInputSequence( cell_theta, py_utils.NestedMap(act=inputs)) proj_inputs = py_utils.NestedMap(proj_inputs=proj_inputs, padding=paddings, reset_mask=reset_mask) acc_state, final_state = recurrent.Recurrent( theta=cell_theta, state0=state0, inputs=proj_inputs, cell_fn=self.cell.FPropWithProjectedInput, cell_type=self.cell.layer_type, accumulator_layer=self, allow_implicit_capture=p.allow_implicit_capture) act = self.cell.GetOutput(acc_state) if p.reverse: act = tf.reverse(act, [0]) return act, final_state
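# Shape intuition for the wm_i / wm_h split above: the cell's fused weight
# matrix stacks input and recurrent weights row-wise (the exact column count
# depends on the cell, e.g. 4 * num_output_nodes for an LSTM):
#   wm_i = wm[:num_input_nodes, :]   multiplies the input x_t
#   wm_h = wm[num_input_nodes:, :]   multiplies the recurrent state h_{t-1}
# Splitting once outside the loop lets ProjectInputSequence apply wm_i to the
# whole [time, batch, ...] input sequence up front, which is presumably where
# the quoted ~20% speedup over the regular LSTM baseline comes from.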
def _BuildStackedRecurrentElman(self, seqlen, trailing_pad_len, batch, dims, layers): tf.set_random_seed(342462) np.random.seed(32540) seqlen += trailing_pad_len dtype = tf.float64 def CreateTheta(): return py_utils.NestedMap( w=tf.constant(np.random.uniform(0, 0.2, (2 * dims, dims)), dtype=dtype), b=tf.constant(np.random.uniform(0, 0.2, (dims, )), dtype=dtype)) def CreateState0(): return py_utils.NestedMap(h=tf.constant(np.random.uniform( 0, 0.2, (batch, dims)), dtype=dtype), padding=tf.constant([[0]] * batch, dtype=dtype)) devices = ['/cpu:0'] * layers cell_fns = [self.Elman] * layers cell_grads = [self.ElmanGrad] * layers cell_outs = [self.ElmanOut] * layers cell_out_grads = [self.ElmanOutGrad] * layers thetas = [CreateTheta() for _ in range(layers)] init_states = [CreateState0() for _ in range(layers)] padding = np.zeros((seqlen, batch, 1)) padding[-trailing_pad_len:, :, :] = 1. padding[-trailing_pad_len - 3:-trailing_pad_len - 1, :, :] = 1. inputs = py_utils.NestedMap(x=tf.constant(np.random.uniform( 0, 0.2, (seqlen, batch, dims)), dtype=dtype), padding=tf.constant(padding, dtype=dtype)) output, _ = recurrent.StackedRecurrent(devices=devices, cell_fns=cell_fns, cell_grads=cell_grads, cell_outs=cell_outs, cell_out_grads=cell_out_grads, thetas=thetas, init_states=init_states, inputs=inputs) o = output.x if 'padding' in inputs: o *= (1 - inputs.padding) loss = tf.reduce_sum(tf.square(o)) xs = recurrent.Flatten(thetas + [py_utils.NestedMap(x=inputs.x)]) dxs = tf.gradients(ys=loss, xs=xs) # Reference implementation using Recurrent(). ref = inputs for i in range(layers): ref = self.ElmanOut( recurrent.Recurrent(cell_fn=cell_fns[i], cell_grad=cell_grads[i], theta=thetas[i], state0=init_states[i], inputs=ref)[0]) return ref.x, output.x, loss, xs, dxs
def Sample(self, decoder_theta, encoder_outputs, random_seed, init_state_callback, pre_step_callback, post_step_callback, init_step_ids=None): """Samples target sequences, one target sequence per source sequence. (Please see beam_search_helper.py for description of decoder callbacks.) Args: decoder_theta: A NestedMap object containing weights' values of the decoder layer and its children layers, to be passed to decoder callbacks. encoder_outputs: the outputs of the encoder, to be passed to callbacks. random_seed: a scalar int32 tensor representing the random seed. init_state_callback: decoder._InitBeamSearchStateCallback. pre_step_callback: decoder._PreBeamSearchStepCallback. post_step_callback: decoder._PostBeamSearchStepCallback. init_step_ids: [batch], optional init step ids, default to SOS. Returns: A NestedMap containing the following tensors - 'logits': [batch, max_target_length, vocab_size], representing the distribution from which target sequences are sampled. - 'ids': [batch, max_target_length] of int32, representing the target sequence ids, not including target_sos_id, but maybe ending with target_eos_id if end-of-sequence is reached before target_seq_len. - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a padded timestep. """ p = self.params assert p.temperature > 0 assert p.top_k >= 0 assert p.num_hyps_per_beam >= 1 if getattr(encoder_outputs, 'segment_id', 1) is None: # Remove None values, which are not supported by recurrent. del encoder_outputs['segment_id'] # init_state_callback may modify 'encoder_outputs', e.g., by inserting # 'packed_src'. bs_result, bs_state = init_state_callback(decoder_theta, encoder_outputs, p.num_hyps_per_beam) # 'recurrent_theta' represents all cross-timestep information used by the # recurrent loop below, including layer theta and encoder outputs. recurrent_theta = py_utils.NestedMap(random_seed=random_seed, encoder_outputs=encoder_outputs) batch = tf.shape(bs_result.log_probs)[0] recurrent_state0 = py_utils.NestedMap( timestep=tf.zeros(shape=[], dtype=tf.int32), logits=bs_result.log_probs, # Start with target_sos_id. ids=init_step_ids if init_step_ids is not None else tf.fill( [batch], tf.cast(p.target_sos_id, tf.int32)), bs_state=bs_state) if p.use_recurrent: inputs = py_utils.NestedMap( dummy=tf.zeros([p.target_seq_len, batch])) else: inputs = py_utils.NestedMap( ids=tf.TensorArray(dtype=tf.int32, size=p.target_seq_len), logits=tf.TensorArray(dtype=bs_result.log_probs.dtype, size=p.target_seq_len), ) def Step(recurrent_theta, state0, inputs): """Computes one decoder step.""" if p.use_recurrent: del inputs with tf.name_scope('single_sampler_step'): # Compute logits and states. bs_result, bs_state1 = pre_step_callback( decoder_theta, recurrent_theta.encoder_outputs, tf.expand_dims(state0.ids, 1), # [batch, 1]. state0.bs_state, num_hyps_per_beam=p.num_hyps_per_beam) batch = tf.shape(bs_result.log_probs)[0] state1 = py_utils.NestedMap(timestep=state0.timestep + 1) state1.logits = bs_result.log_probs if p.top_k > 0: topk_logits, topk_ids = tf.math.top_k(state1.logits, k=p.top_k) sample_logits = tf.nn.log_softmax( topk_logits) if p.top_k_renormalize else topk_logits else: sample_logits = state1.logits # Sample ids from logits. [batch]. 
ids = tf.reshape( tf.random.stateless_categorical( sample_logits / p.temperature, num_samples=1, seed=tf.stack( [recurrent_theta.random_seed, state0.timestep]), dtype=state0.ids.dtype, name='sample_next_id'), [batch]) state1.ids = tf.gather(topk_ids, ids, axis=1, batch_dims=1) if p.top_k > 0 else ids if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0: state1.ids = tf.where( tf.math.logical_and( bs_result.is_last_chunk, tf.equal(state1.ids, p.target_eoc_id)), tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids) state1.bs_state = post_step_callback( decoder_theta, recurrent_theta.encoder_outputs, state1.ids, bs_state1) if p.use_recurrent: return state1, py_utils.NestedMap() else: inputs.ids = inputs.ids.write(state0.timestep, state1.ids) inputs.logits = inputs.logits.write(state0.timestep, state1.logits) return (recurrent_theta, state1, inputs) if p.use_recurrent: def StopFn(t, theta, state): del t, theta # Unused: this stop function only uses the state ids. return tf.equal(state.ids, p.target_eos_id) else: def StopFn(recurrent_theta, state, inputs): del recurrent_theta, inputs return tf.logical_not( tf.reduce_all(tf.equal(state.ids, p.target_eos_id))) if p.use_stop_fn: stop_fn = StopFn else: stop_fn = None if p.use_recurrent: accumulated_states, _ = recurrent.Recurrent( recurrent_theta, recurrent_state0, inputs, Step, stop_fn=stop_fn, allow_implicit_capture=True) else: loop_vars = (recurrent_theta, recurrent_state0, inputs) (_, _, accumulated_states) = tf.while_loop( StopFn, Step, loop_vars=loop_vars, shape_invariants=_GetShapes(loop_vars, none_shapes=True), back_prop=False, maximum_iterations=p.target_seq_len) accumulated_states.ids = accumulated_states.ids.stack() accumulated_states.logits = accumulated_states.logits.stack() result = py_utils.NestedMap(logits=tf.transpose( accumulated_states.logits, [1, 0, 2]), ids=tf.transpose(accumulated_states.ids)) result.paddings = tf.cast( _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype) # Force ids to be eos_id if the timestep is padded. result.ids = tf.where(tf.equal(result.paddings, 0), result.ids, tf.fill(tf.shape(result.ids), p.target_eos_id)) static_batch_size = bs_result.log_probs.shape[0] result.ids.set_shape([static_batch_size, p.target_seq_len]) result.paddings.set_shape([static_batch_size, p.target_seq_len]) return result
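# A self-contained sketch of the top-k sampling path above (the literal
# logits, temperature and seed are illustrative): indices are drawn within the
# top-k slice and then mapped back to vocabulary ids with a batched gather.
logits = tf.constant([[0.1, 2.0, 0.3, 1.5]])                 # [batch=1, vocab=4]
topk_logits, topk_ids = tf.math.top_k(logits, k=2)
sample_logits = tf.nn.log_softmax(topk_logits)               # top_k_renormalize=True
idx = tf.random.stateless_categorical(
    sample_logits / 0.7, num_samples=1, seed=[1, 2])         # temperature=0.7
sampled_ids = tf.gather(
    topk_ids, tf.reshape(idx, [-1]), axis=1, batch_dims=1)   # back to vocab ids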
def testStatefulEmbeddingStep(self): with self.session(use_gpu=False): tf.random.set_seed(398847392) p = embedding_steps.StatefulEmbeddingStep.Params().Set( name='emb_step', num_prev_tokens=1, include_current_token=False, target_sos_id=1, embedding_dim=3, ) p.emb.Set( vocab_size=10, embedding_dim=3, max_num_shards=1, params_init=py_utils.WeightInit.Gaussian(0.01), ) p.emb.vn.global_vn = False p.emb.vn.per_step_vn = False emb = p.Instantiate() # Verify that nothing bad happens when these methods are called. packed = emb.PrepareExternalInputs(None, None) state0 = emb.ZeroState(emb.theta, packed, 2) # Test FProp of the unit out1, state1 = emb.FProp( emb.theta, packed, py_utils.NestedMap(inputs=[tf.constant([4, 3], tf.int32)]), tf.constant([0.0], dtype=tf.float32), state0) self.evaluate(tf.global_variables_initializer()) out1, state1 = self.evaluate([out1, state1]) self.assertAllEqual(state1.prev_ids, np.array([[4], [3]])) self.assertAllClose( out1.output, np.array([[-0.00740041, -0.00746862, 0.00093992], [-0.00740041, -0.00746862, 0.00093992]])) # Test FProp and BProp when integrated with Recurrent() def _FProp(theta, state0, inputs): embedding, state1 = emb.FProp( theta, None, inputs, None, state0, ) state1.embedding = embedding.output return state1, py_utils.NestedMap() inputs = py_utils.NestedMap(inputs=[ tf.constant([[1., 2.], [3., 2.], [0., 1.], [2., 3.], [3., 0.]]) ]) acc, _ = recurrent.Recurrent( emb.theta, state0, inputs, _FProp, ) loss = tf.math.l2_normalize(acc.embedding) grad = tf.gradients(loss, emb.emb.theta.wm[0]) self.evaluate(tf.global_variables_initializer()) acc_, _, grad_ = self.evaluate([acc, emb.emb.theta.wm[0], grad]) prev_ids_expected = np.array([ [[1.], [2.]], [[3.], [2.]], [[0.], [1.]], [[2.], [3.]], [[3.], [0.]], ]) grad_expected = np.array([[21.952698, 20.50312, 19.037958], [79.622116, 72.15271, 106.34329], [41.631985, 70.19292, 75.52608], [53.644493, 36.28507, 36.64856], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) self.assertAllClose(acc_.prev_ids, prev_ids_expected) self.assertAllClose(grad_[0], grad_expected)
def testBasicWithAccumulator(self): with self.session() as sess: p = _SampleAccumulatorLayer.Params() p.name = 'sample' accum_layer = _SampleAccumulatorLayer(p) accum_obj = accum_layer.accumulators[accum_layer.accumulator_name] theta = py_utils.NestedMap() theta.x = tf.constant(2.0) state = py_utils.NestedMap() state.value = tf.constant(0.0) state.x_power = tf.constant(1.0) inputs = py_utils.NestedMap() inputs.coeff = tf.constant([1., 2., 3.]) def _CellFn(theta, state, inputs): print('TEST ACCUM WITHIN CellFn = ', accum_obj.GetValue()) accum_obj.Update(inputs.coeff) return _Poly(theta, state, inputs) # By doing one accumulate prior to recurrent, we ensure that incoming # recurrent state is preserved. accum_obj.Update(10.) # x = 2 # 1 + 2*x + 3*x^2 ret = recurrent.Recurrent(theta, state, inputs, _CellFn, accumulator_layer=accum_layer) # Verify bprop. y = ret[1].value dx, d_coeff = tf.gradients(ys=[y], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 2 + 6*x self.assertAllClose(dx_val, 14.) self.assertAllClose(d_coeff_val, [1., 2., 4.]) # acc = [1, 1+2x, 1+2x+3x^2] # sum(acc) = 3 + 4x + 3x^2 acc = ret[0].value dx, d_coeff = tf.gradients(ys=[tf.reduce_sum(acc)], xs=[theta.x, inputs.coeff]) dx_val, d_coeff_val = sess.run([dx, d_coeff]) # 4 + 6*x self.assertAllClose(dx_val, 16.) self.assertAllClose(d_coeff_val, [3., 4., 4.]) # Verify fprop. (acc, state), accum_obj_value = sess.run( (ret, accum_obj.GetValue())) # Verify that accumulators don't change fprop results. self.assertAllClose(acc.value, [1., 5., 17.]) self.assertAllClose(acc.x_power, [2., 4., 8.]) self.assertAllClose(state.value, 17.) self.assertAllClose(state.x_power, 8.) # Verify accumulator (should be 10 (initial increment) + 1 + 2 + 3). self.assertEqual(0, accum_obj._disable_count) self.assertAllClose([accum_obj_value], [16.0])