示例#1
0
    def testGetAllParams(self):
        """Check that GetAllParams() returns every added parameter in order."""
        helper = seq2seq_model_helper.Seq2SeqModelHelper()

        # One trainable and one non-trainable parameter; both are expected
        # back, in the order they were added.
        trainable_param = helper.AddParam(
            'test_param1', init_value=1, trainable=True)
        frozen_param = helper.AddParam(
            'test_param2', init_value=2, trainable=False)
        expected_params = [trainable_param, frozen_param]

        self.assertEqual(helper.GetAllParams(), expected_params)
    def testConstuctor(self):
        """Check the attributes exposed by a helper constructed with a name."""
        expected_name = 'TestModel'
        expected_arg_scope = {
            'use_cudnn': True,
            'cudnn_exhaustive_search': False,
            'order': 'NHWC',
        }

        helper = seq2seq_model_helper.Seq2SeqModelHelper(name=expected_name)

        self.assertEqual(helper.name, expected_name)
        # Compared with assertEqual (not assertTrue) so the value must equal
        # True rather than merely be truthy.
        self.assertEqual(helper.init_params, True)
        self.assertEqual(helper.arg_scope, expected_arg_scope)
示例#3
0
    def testGetNonTrainableParams(self):
        """Check GetNonTrainableParams() outside and inside a name scope."""
        helper = seq2seq_model_helper.Seq2SeqModelHelper()

        # The trainable parameter is registered but its handle is not needed.
        helper.AddParam('test_param1', init_value=1, trainable=True)
        frozen_outer = helper.AddParam(
            'test_param2', init_value=2, trainable=False)

        self.assertEqual(helper.GetNonTrainableParams(), [frozen_outer])

        with scope.NameScope('A', reset=True):
            frozen_scoped = helper.AddParam(
                'test_param3', init_value=3, trainable=False)
            # While scope 'A' is active, only the parameter added within it
            # is expected back.
            self.assertEqual(helper.GetNonTrainableParams(), [frozen_scoped])

        # After leaving the scope, both non-trainable parameters are expected,
        # in the order they were added.
        self.assertEqual(
            helper.GetNonTrainableParams(), [frozen_outer, frozen_scoped])
    def testAddParam(self):
        """Check that an added parameter's string form equals its name."""
        helper = seq2seq_model_helper.Seq2SeqModelHelper()
        expected_name = 'test_param'

        added = helper.AddParam(expected_name, init_value=1)
        rendered_name = str(added)

        self.assertEqual(rendered_name, expected_name)