def testSymbolToTensorMap(self):
  """Tests that cell_fn can rely on the contextual symbol-to-tensor map."""
  x = symbolic.Symbol('x')
  y = symbolic.Symbol('y')

  def PlusWXT(theta, state, inputs):
    """state.value += theta.w * x * inputs.t."""
    next_state = py_utils.NestedMap()
    # Resolve the symbol `x` from the symbol-to-tensor map that was
    # active when the Recurrent() graph was built.
    x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
    next_state.value = state.value + theta.w * x_tensor * inputs.t
    return next_state, py_utils.NestedMap()

  def PlusWXTGrad(theta, state0, inputs, extras, dstate1):
    """Gradient function for PlusWXT."""
    del state0, extras
    x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
    # Manual gradients of value' = value + w * x * t w.r.t. each input.
    dtheta = py_utils.NestedMap(w=dstate1.value * x_tensor * inputs.t)
    dstate0 = py_utils.NestedMap(value=dstate1.value)
    dinputs = py_utils.NestedMap(t=dstate1.value * theta.w * x_tensor)
    return dtheta, dstate0, dinputs, None

  with self.session() as sess:
    theta = py_utils.NestedMap(w=tf.constant(1., name='w'))
    state0 = py_utils.NestedMap(value=tf.constant(0., name='value'))
    inputs = py_utils.NestedMap(t=tf.constant([1., 2., 3.], name='t'))

    # With automatic cell_grad.
    # NOTE: `y` is deliberately mapped (to a non-tensor) but never used by
    # the cell; unused map entries must be harmless.
    with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, {
        x: tf.constant(7., name='x7'),
        y: 8
    }):
      x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
      _, state1 = recurrent.Recurrent(theta, state0, inputs, PlusWXT)
      dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0]
      dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0]
      final_value, x_val, dx_val, dw_val = sess.run(
          [state1.value, x_tensor, dx, dw])
    # With w = 1 and t summing to 6: value = x * 6, d/dw = x * 6, d/dx = 6.
    self.assertEqual(x_val, 7)
    self.assertEqual(final_value, x_val * (1. + 2. + 3.))
    self.assertEqual(dw_val, x_val * (1. + 2. + 3.))
    self.assertEqual(dx_val, (1. + 2. + 3.))

    # With manual cell_grad.
    with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES,
                                   {x: tf.constant(5., name='x5')}):
      x_tensor = symbolic.EvalExpr(symbolic.TENSOR_VALUES, x)
      _, state1 = recurrent.Recurrent(
          theta, state0, inputs, PlusWXT, cell_grad=PlusWXTGrad)
      dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0]
      dx = tf.gradients(ys=[state1.value], xs=[x_tensor])[0]
      final_value, x_val, dx_val, dw_val = sess.run(
          [state1.value, x_tensor, dx, dw])
    # Same expectations as above, with x bound to 5 this time.
    self.assertEqual(x_val, 5)
    self.assertEqual(final_value, x_val * (1. + 2. + 3.))
    self.assertEqual(dw_val, x_val * (1. + 2. + 3.))
    self.assertEqual(dx_val, (1. + 2. + 3.))
def testEvalExpr(self):
  """Checks symbolic expression evaluation under scoped value maps."""
  sym_x = symbolic.Symbol('x')
  sym_y = symbolic.Symbol('y')
  product = sym_x * sym_y

  # With no symbol-to-value map installed, evaluation is the identity.
  self.assertEqual(product, symbolic.ToStatic(product))
  self.assertEqual(product, symbolic.ToTensor(product))

  with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                 {sym_x: 2, sym_y: 3}):
    self.assertEqual(symbolic.ToStatic(product), 6)
    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                   {sym_x: 5, sym_y: 6}):
      # The innermost map shadows the outer bindings.
      self.assertEqual(symbolic.ToStatic(product), 30)
    # Exiting the inner scope restores the outer bindings.
    self.assertEqual(symbolic.ToStatic(product), 6)

  # EvalExpr can also evaluate a symbolic expression to a Tensor.
  ph_x = tf.placeholder(tf.float32)
  ph_y = tf.placeholder(tf.float32)
  with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES,
                                 {sym_x: ph_x, sym_y: ph_y}):
    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                   {sym_x: 2, sym_y: 3}):
      # Value maps of different types do not affect each other.
      self.assertEqual(symbolic.ToStatic(product), 6)
      tensor_product = symbolic.ToTensor(product)
      self.assertIsInstance(tensor_product, tf.Tensor)
      with self.session() as sess:
        self.assertEqual(12, sess.run(tensor_product, {ph_x: 3, ph_y: 4}))

  # EvalExpr supports partial evaluation: only bound symbols are
  # substituted, the rest stay symbolic.
  with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {sym_y: 3}):
    three_x = symbolic.ToStatic(product)
    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, {sym_x: 9}):
      self.assertEqual(27, symbolic.ToStatic(three_x))
def testSetRnnCellNodes(self):
  """SetRnnCellNodes propagates decoder dims onto an RNN cell params."""
  decoder_p = decoder.AsrDecoder.Params()
  base_rnn_p = rnn_cell.LSTMCellSimple.Params().Set(num_output_nodes=4)

  def ApplyToFreshCopy():
    # Apply SetRnnCellNodes to a pristine copy of the base cell params.
    rnn_p = base_rnn_p.Copy()
    decoder_utils.SetRnnCellNodes(decoder_p, rnn_p)
    return rnn_p

  # A positive rnn_cell_dim overrides num_output_nodes.
  decoder_p.rnn_cell_dim = 8
  self.assertEqual(ApplyToFreshCopy().num_output_nodes,
                   decoder_p.rnn_cell_dim)

  # A non-positive rnn_cell_dim leaves num_output_nodes untouched.
  decoder_p.rnn_cell_dim = 0
  self.assertEqual(ApplyToFreshCopy().num_output_nodes,
                   base_rnn_p.num_output_nodes)

  # A symbolic rnn_cell_dim is forwarded as the very same object.
  decoder_p.rnn_cell_dim = symbolic.Symbol('rnn_cell_dim')
  self.assertIs(ApplyToFreshCopy().num_output_nodes,
                decoder_p.rnn_cell_dim)

  # A positive rnn_cell_hidden_dim overrides num_hidden_nodes.
  decoder_p.rnn_cell_hidden_dim = 16
  self.assertEqual(ApplyToFreshCopy().num_hidden_nodes,
                   decoder_p.rnn_cell_hidden_dim)

  # A non-positive rnn_cell_hidden_dim leaves num_hidden_nodes untouched.
  decoder_p.rnn_cell_hidden_dim = 0
  self.assertEqual(ApplyToFreshCopy().num_hidden_nodes,
                   base_rnn_p.num_hidden_nodes)

  # A symbolic rnn_cell_hidden_dim is forwarded as the very same object.
  decoder_p.rnn_cell_hidden_dim = symbolic.Symbol('rnn_cell_hidden_dim')
  self.assertIs(ApplyToFreshCopy().num_hidden_nodes,
                decoder_p.rnn_cell_hidden_dim)
def testDeepCopy(self):
  """Copy() deep-copies nested Params but shares tensors and symbols."""
  child = hyperparams.Params()
  child.Define('alpha', 2, '')
  child.Define('tensor', tf.constant(0), '')
  child.Define('symbol', symbolic.Symbol('symbol'), '')
  parent = hyperparams.Params()
  parent.Define('beta', 1, '')
  parent.Define('inner', child, '')
  clone = parent.Copy()
  # The copy is a distinct but equal object graph.
  self.assertIsNot(parent, clone)
  self.assertEqual(parent, clone)
  # Nested Params are themselves copied, not aliased.
  self.assertIsNot(parent.inner, clone.inner)
  self.assertEqual(parent.inner, clone.inner)
  self.assertEqual(parent.inner.alpha, clone.inner.alpha)
  # Tensors and symbols are shared by reference rather than deep-copied.
  self.assertIs(parent.inner.tensor, clone.inner.tensor)
  self.assertIs(parent.inner.symbol, clone.inner.symbol)
def testToFromProto(self):
  """Round-trips a Params tree through ToProto()/FromProto().

  Covers nested Params and a wide spread of value types: ints, floats,
  strings, bools, containers (including empty ones), range, enum, proto
  messages, dataclasses, namedtuples, and symbolic expressions.
  """
  outer = hyperparams.Params()
  outer.Define('integer_val', 1, '')
  outer.Define('cls_type', type(int), '')
  inner = hyperparams.Params()
  inner.Define('float_val', 2.71, '')
  inner.Define('string_val', 'rosalie et adrien', '')
  inner.Define('bool_val', True, '')
  inner.Define('list_of_tuples_of_dicts', [({'string_key': 1729})], '')
  inner.Define('range', range(1, 3), '')
  outer.Define('inner', inner, '')
  outer.Define('empty_list', [], '')
  outer.Define('empty_tuple', (), '')
  outer.Define('empty_dict', {}, '')
  outer.Define('enum', TestEnum.B, '')
  outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')
  outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
  outer.Define('namedtuple',
               tf.io.FixedLenSequenceFeature([42], tf.float32), '')
  # symbol_2x is a derived symbolic expression; it must survive
  # serialization and still evaluate against the rebuilt symbol_x.
  outer.Define('symbol_x', symbolic.Symbol('x'), '')
  outer.Define('symbol_2x', outer.symbol_x * 2, '')

  rebuilt_outer = hyperparams.InstantiableParams.FromProto(outer.ToProto())

  # 'cls' is an implementation detail of InstantiableParams and must not
  # leak through serialization.
  self.assertNotIn('cls', rebuilt_outer)
  self.assertEqual(outer.integer_val, rebuilt_outer.integer_val)
  self.assertEqual(outer.cls_type, rebuilt_outer.cls_type)
  self.assertNear(outer.inner.float_val, rebuilt_outer.inner.float_val,
                  1e-6)
  self.assertEqual(outer.inner.string_val, rebuilt_outer.inner.string_val)
  self.assertEqual(outer.inner.bool_val, rebuilt_outer.inner.bool_val)
  self.assertEqual(outer.inner.list_of_tuples_of_dicts,
                   rebuilt_outer.inner.list_of_tuples_of_dicts)
  self.assertEqual([1, 2], rebuilt_outer.inner.range)  # Rebuilt as list.
  self.assertEqual(outer.empty_list, rebuilt_outer.empty_list)
  self.assertEqual(outer.empty_tuple, rebuilt_outer.empty_tuple)
  self.assertEqual(outer.empty_dict, rebuilt_outer.empty_dict)
  self.assertEqual(outer.enum, rebuilt_outer.enum)
  self.assertEqual(outer.proto, rebuilt_outer.proto)
  self.assertEqual(outer.dataclass, rebuilt_outer.dataclass)
  self.assertEqual(outer.namedtuple, rebuilt_outer.namedtuple)
  # Binding the rebuilt symbol_x must make the rebuilt symbol_2x
  # expression evaluate to twice that value.
  with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                 {rebuilt_outer.symbol_x: 42}):
    self.assertEqual(symbolic.ToStatic(rebuilt_outer.symbol_2x), 84)
def testDecoderFPropWithSymbolicShape(self):
  """Create decoder with default params, and verify that FProp runs."""
  with self.session() as sess:
    p = self._DecoderParams(
        vn_config=py_utils.VariationalNoiseParams(
            None, True, False, seed=12345))
    # rnn_cell_dim is left symbolic; the concrete value (6) is supplied
    # via the scoped symbol-to-value map while the graph is built.
    p.rnn_cell_dim = symbolic.Symbol('rnn_cell_dim')
    with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                   {p.rnn_cell_dim: 6}):
      loss, per_sequence_loss = self._testDecoderFPropHelper(params=p)
    tf.global_variables_initializer().run()
    loss_val, per_sequence_loss_val = sess.run([loss, per_sequence_loss])
    print('loss = ', loss_val, 'per sequence loss = ',
          per_sequence_loss_val)
    # Target batch size is 4. Therefore, we should expect 4 here.
    self.assertEqual(per_sequence_loss_val.shape, (4,))
def testGetSymbol(self):
  """symbolic.Symbol() must hand back a sympy expression."""
  sym = symbolic.Symbol('x')
  self.assertIsInstance(sym, sympy.Expr)