コード例 #1
0
    def testSymbolToTensorMap(self):
        """Checks that Recurrent cell functions see the contextual symbol map.

        Both the forward cell function and the optional hand-written gradient
        resolve symbol `x` through the TENSOR_VALUES map enclosing the
        recurrent.Recurrent call. The final value and the gradients w.r.t.
        theta.w and the bound `x` tensor must therefore follow that binding.
        """
        sym_x = symbolic.Symbol('x')
        sym_y = symbolic.Symbol('y')

        def _LookupX():
            # Resolves x against the TENSOR_VALUES map currently in effect.
            return symbolic.EvalExpr(symbolic.TENSOR_VALUES, sym_x)

        def PlusWXT(theta, state, inputs):
            """state.value += theta.w * x * inputs.t."""
            updated = py_utils.NestedMap()
            x_t = _LookupX()
            updated.value = state.value + theta.w * x_t * inputs.t
            return updated, py_utils.NestedMap()

        def PlusWXTGrad(theta, state0, inputs, extras, dstate1):
            """Hand-written gradient of PlusWXT."""
            del state0, extras  # Not needed: the update is linear in state.
            x_t = _LookupX()
            upstream = dstate1.value
            dtheta = py_utils.NestedMap(w=upstream * x_t * inputs.t)
            dstate0 = py_utils.NestedMap(value=upstream)
            dinputs = py_utils.NestedMap(t=upstream * theta.w * x_t)
            return dtheta, dstate0, dinputs, None

        with self.session() as sess:
            theta = py_utils.NestedMap(w=tf.constant(1., name='w'))
            state0 = py_utils.NestedMap(value=tf.constant(0., name='value'))
            inputs = py_utils.NestedMap(t=tf.constant([1., 2., 3.], name='t'))
            # Sum of inputs.t. With theta.w == 1 this is the expected d/dx.
            t_total = 1. + 2. + 3.

            def _RunAndCheck(value_map, expected_x, **recurrent_kwargs):
                """Runs Recurrent under `value_map`; checks value and grads."""
                with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES,
                                               value_map):
                    x_t = _LookupX()
                    _, state1 = recurrent.Recurrent(theta, state0, inputs,
                                                    PlusWXT,
                                                    **recurrent_kwargs)
                    dw = tf.gradients(ys=[state1.value], xs=[theta.w])[0]
                    dx = tf.gradients(ys=[state1.value], xs=[x_t])[0]
                    final_value, x_val, dx_val, dw_val = sess.run(
                        [state1.value, x_t, dx, dw])
                self.assertEqual(x_val, expected_x)
                self.assertEqual(final_value, x_val * t_total)
                self.assertEqual(dw_val, x_val * t_total)
                self.assertEqual(dx_val, t_total)

            # Gradient derived automatically by Recurrent.
            _RunAndCheck({sym_x: tf.constant(7., name='x7'), sym_y: 8}, 7)

            # Gradient supplied explicitly through cell_grad.
            _RunAndCheck({sym_x: tf.constant(5., name='x5')}, 5,
                         cell_grad=PlusWXTGrad)
コード例 #2
0
  def testEvalExpr(self):
    """Exercises ToStatic/ToTensor under nested symbol-to-value maps."""
    sym_x = symbolic.Symbol('x')
    sym_y = symbolic.Symbol('y')
    product = sym_x * sym_y

    # With no map active, both conversions yield an expression equal to the
    # input.
    self.assertEqual(product, symbolic.ToStatic(product))
    self.assertEqual(product, symbolic.ToTensor(product))

    static = symbolic.STATIC_VALUES
    with symbolic.SymbolToValueMap(static, {sym_x: 2, sym_y: 3}):
      self.assertEqual(symbolic.ToStatic(product), 6)
      with symbolic.SymbolToValueMap(static, {sym_x: 5, sym_y: 6}):
        # The innermost map takes precedence over the enclosing one...
        self.assertEqual(symbolic.ToStatic(product), 30)
      # ...and the enclosing bindings apply again once it exits.
      self.assertEqual(symbolic.ToStatic(product), 6)

    # Tensor-valued bindings let the expression be turned into a Tensor.
    ph_a = tf.placeholder(tf.float32)
    ph_b = tf.placeholder(tf.float32)
    tensor_bindings = {sym_x: ph_a, sym_y: ph_b}
    with symbolic.SymbolToValueMap(symbolic.TENSOR_VALUES, tensor_bindings):
      with symbolic.SymbolToValueMap(static, {sym_x: 2, sym_y: 3}):
        # Static and tensor maps do not interfere with each other.
        self.assertEqual(symbolic.ToStatic(product), 6)
        as_tensor = symbolic.ToTensor(product)
        self.assertIsInstance(as_tensor, tf.Tensor)
        with self.session() as sess:
          self.assertEqual(12, sess.run(as_tensor, {ph_a: 3, ph_b: 4}))

      # Binding only some symbols yields a partially evaluated expression
      # that can be finished under a later map.
      with symbolic.SymbolToValueMap(static, {sym_y: 3}):
        partial = symbolic.ToStatic(product)
        with symbolic.SymbolToValueMap(static, {sym_x: 9}):
          self.assertEqual(27, symbolic.ToStatic(partial))
コード例 #3
0
    def testSetRnnCellNodes(self):
        """Checks how SetRnnCellNodes maps decoder dims onto RNN cell params.

        Covers rnn_cell_dim -> num_output_nodes and
        rnn_cell_hidden_dim -> num_hidden_nodes. A positive value is copied
        over, a zero leaves the cell's base value in place, and a Symbol is
        forwarded as the very same object.
        """
        decoder_p = decoder.AsrDecoder.Params()
        base_rnn_p = rnn_cell.LSTMCellSimple.Params().Set(num_output_nodes=4)

        def _ApplyToFreshCell():
            # Each case starts from an untouched copy of the base cell params.
            cell_p = base_rnn_p.Copy()
            decoder_utils.SetRnnCellNodes(decoder_p, cell_p)
            return cell_p

        # (decoder param name, RNN cell param name, positive dim to try).
        cases = (
            ('rnn_cell_dim', 'num_output_nodes', 8),
            ('rnn_cell_hidden_dim', 'num_hidden_nodes', 16),
        )
        for dec_name, cell_name, positive_dim in cases:
            # Positive dim: copied onto the cell params.
            setattr(decoder_p, dec_name, positive_dim)
            cell_p = _ApplyToFreshCell()
            self.assertEqual(getattr(cell_p, cell_name),
                             getattr(decoder_p, dec_name))

            # Non-positive dim (0): the cell keeps its base value.
            setattr(decoder_p, dec_name, 0)
            cell_p = _ApplyToFreshCell()
            self.assertEqual(getattr(cell_p, cell_name),
                             getattr(base_rnn_p, cell_name))

            # Symbolic dim: the symbol object itself is forwarded.
            setattr(decoder_p, dec_name, symbolic.Symbol(dec_name))
            cell_p = _ApplyToFreshCell()
            self.assertIs(getattr(cell_p, cell_name),
                          getattr(decoder_p, dec_name))
コード例 #4
0
 def testDeepCopy(self):
   """Copy() clones nested Params but keeps tensors and symbols shared."""
   inner = hyperparams.Params()
   inner.Define('alpha', 2, '')
   inner.Define('tensor', tf.constant(0), '')
   inner.Define('symbol', symbolic.Symbol('symbol'), '')
   outer = hyperparams.Params()
   outer.Define('beta', 1, '')
   outer.Define('inner', inner, '')

   duplicate = outer.Copy()
   src_inner, dup_inner = outer.inner, duplicate.inner

   # Both the outer and the nested Params are new objects with equal contents.
   self.assertIsNot(outer, duplicate)
   self.assertEqual(outer, duplicate)
   self.assertIsNot(src_inner, dup_inner)
   self.assertEqual(src_inner, dup_inner)
   self.assertEqual(src_inner.alpha, dup_inner.alpha)

   # Tensors and symbols are carried over by reference rather than cloned.
   self.assertIs(src_inner.tensor, dup_inner.tensor)
   self.assertIs(src_inner.symbol, dup_inner.symbol)
コード例 #5
0
    def testToFromProto(self):
        """Round-trips a nested Params through ToProto and FromProto.

        Covers scalars, a nested Params, containers (including empty ones, and
        a range that is rebuilt as a list), an enum, a proto, a dataclass, a
        namedtuple and symbolic expressions.
        """
        outer = hyperparams.Params()
        outer.Define('integer_val', 1, '')
        # type(int) is the builtin `type`, so this stores a class object.
        outer.Define('cls_type', type(int), '')
        inner = hyperparams.Params()
        inner.Define('float_val', 2.71, '')
        inner.Define('string_val', 'rosalie et adrien', '')
        inner.Define('bool_val', True, '')
        # The trailing comma is required. `({...})` is only a parenthesized
        # dict, so without it this would be a list of dicts and no non-empty
        # tuple would be exercised, despite the key's name.
        inner.Define('list_of_tuples_of_dicts', [({'string_key': 1729},)], '')
        inner.Define('range', range(1, 3), '')
        outer.Define('inner', inner, '')
        outer.Define('empty_list', [], '')
        outer.Define('empty_tuple', (), '')
        outer.Define('empty_dict', {}, '')
        outer.Define('enum', TestEnum.B, '')
        outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')
        outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
        outer.Define('namedtuple',
                     tf.io.FixedLenSequenceFeature([42], tf.float32), '')
        outer.Define('symbol_x', symbolic.Symbol('x'), '')
        outer.Define('symbol_2x', outer.symbol_x * 2, '')

        rebuilt_outer = hyperparams.InstantiableParams.FromProto(
            outer.ToProto())

        self.assertNotIn('cls', rebuilt_outer)
        self.assertEqual(outer.integer_val, rebuilt_outer.integer_val)
        self.assertEqual(outer.cls_type, rebuilt_outer.cls_type)
        self.assertNear(outer.inner.float_val, rebuilt_outer.inner.float_val,
                        1e-6)
        self.assertEqual(outer.inner.string_val,
                         rebuilt_outer.inner.string_val)
        self.assertEqual(outer.inner.bool_val, rebuilt_outer.inner.bool_val)
        self.assertEqual(outer.inner.list_of_tuples_of_dicts,
                         rebuilt_outer.inner.list_of_tuples_of_dicts)
        self.assertEqual([1, 2], rebuilt_outer.inner.range)  # Rebuilt as list.
        self.assertEqual(outer.empty_list, rebuilt_outer.empty_list)
        self.assertEqual(outer.empty_tuple, rebuilt_outer.empty_tuple)
        self.assertEqual(outer.empty_dict, rebuilt_outer.empty_dict)
        self.assertEqual(outer.enum, rebuilt_outer.enum)
        self.assertEqual(outer.proto, rebuilt_outer.proto)
        self.assertEqual(outer.dataclass, rebuilt_outer.dataclass)
        self.assertEqual(outer.namedtuple, rebuilt_outer.namedtuple)

        # Only the rebuilt `x` is bound. The rebuilt `2 * x` expression must
        # still evaluate through it after the round trip.
        with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                       {rebuilt_outer.symbol_x: 42}):
            self.assertEqual(symbolic.ToStatic(rebuilt_outer.symbol_2x), 84)
コード例 #6
0
ファイル: decoder_test.py プロジェクト: ruby11dog/lingvo
  def testDecoderFPropWithSymbolicShape(self):
    """Verifies decoder FProp runs when rnn_cell_dim is a symbolic value.

    rnn_cell_dim is replaced by a Symbol that is bound to 6 through a
    STATIC_VALUES map. The map stays active while the FProp helper is called,
    variables are initialized and the losses are evaluated.
    """
    with self.session() as sess:
      p = self._DecoderParams(
          vn_config=py_utils.VariationalNoiseParams(
              None, True, False, seed=12345))
      # Use a symbol instead of a concrete dim; its value (6) is supplied by
      # the symbol-to-value map below.
      p.rnn_cell_dim = symbolic.Symbol('rnn_cell_dim')

      with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                     {p.rnn_cell_dim: 6}):
        loss, per_sequence_loss = self._testDecoderFPropHelper(params=p)
        tf.global_variables_initializer().run()
        loss_val, per_sequence_loss_val = sess.run([loss, per_sequence_loss])

      print('loss = ', loss_val, 'per sequence loss = ', per_sequence_loss_val)
      # Target batch size is 4. Therefore, we should expect 4 here.
      self.assertEqual(per_sequence_loss_val.shape, (4,))
コード例 #7
0
 def testGetSymbol(self):
   """Checks that symbolic.Symbol produces a sympy expression."""
   x = symbolic.Symbol('x')
   self.assertIsInstance(x, sympy.Expr)