示例#1
0
    def testInit(self):
        """Builds a two-task MultiTaskModel and checks its task bookkeeping."""
        task_names = ('a', 'b')
        model_p = base_model.MultiTaskModel.Params()
        model_p.name = 'MultiTaskModel'

        # One TestParams per task, each with a learner named 'loss'.
        task_ps = {}
        for name in task_names:
            task_p = BaseTaskTest.TestParams()
            task_p.train.learner = learner.Learner.Params().Set(name='loss')
            task_ps[name] = task_p

        seq_input = base_input_generator.BaseSequenceInputGenerator
        model_p.input = base_model_params.MultiTaskModelParams().Train()
        for name in task_names:
            model_p.input.Define(name, seq_input.Params(), '')

        model_p.task_params = hyperparams.Params()
        for name in task_names:
            model_p.task_params.Define(name, task_ps[name], '')

        # Each task is given a task_probs entry of 0.5.
        model_p.task_probs = hyperparams.Params()
        for name in task_names:
            model_p.task_probs.Define(name, 0.5, '')

        model = model_p.Instantiate()
        self.assertEqual(len(model.tasks), 2)
        self.assertEqual(set(model.task_names), set(task_names))

        task_a = model.GetTask('a')
        task_b = model.GetTask('b')
        self.assertEqual(set(model.tasks), {task_a, task_b})
        self.assertEqual(model.params.task_params.a, task_a.params)
        self.assertEqual(model.params.task_params.b, task_b.params)
示例#2
0
    def testSharedEncoderDecoderModel(self):
        """Checks that task 'p1' ends up with the same encoder/decoder as 'p0'."""
        task_names = ('p0', 'p1')
        shared_p = multitask_model.SharedEncoderDecoderModel.Params()
        shared_p.name = 'test'
        shared_p.encoder_to_share = 'p0'
        shared_p.decoder_to_share = 'p0'

        owner_p = MultiTaskModelTest._TestTask.Params()
        borrower_p = MultiTaskModelTest._TestTask.Params()
        # 'p1' is configured without an encoder/decoder of its own.
        borrower_p.encoder = None
        borrower_p.decoder = None

        plain_input = base_input_generator.BaseInputGenerator
        shared_p.input = base_model_params.MultiTaskModelParams().Train()
        for name in task_names:
            shared_p.input.Define(name, plain_input.Params(), '')

        shared_p.task_params = hyperparams.Params()
        shared_p.task_params.Define('p0', owner_p, '')
        shared_p.task_params.Define('p1', borrower_p, '')

        shared_p.task_probs = hyperparams.Params()
        for name in task_names:
            shared_p.task_probs.Define(name, 0.5, '')

        model = shared_p.Instantiate()
        self.assertEqual(model.p0.encoder, model.p1.encoder)
        self.assertEqual(model.p0.decoder, model.p1.decoder)
示例#3
0
    def testRegExSharedVariableModel(self):
        """Checks that both tasks' 'weight' variables resolve to one shared name."""
        task_names = ('p0', 'p1')
        regex_p = multitask_model.RegExSharedVariableModel.Params()
        regex_p.name = 'test'
        # Presumably rewrites 'pN/<rest>' to 'shared/<rest>', which is what the
        # assertions below expect -- confirm against RegExSharedVariableModel.
        regex_p.variable_renaming_rules = [('p./(.*)', 'shared/%s')]

        first_p = MultiTaskModelTest._TestTaskWithVars.Params()
        second_p = MultiTaskModelTest._TestTaskWithVars.Params()

        plain_input = base_input_generator.BaseInputGenerator
        regex_p.input = base_model_params.MultiTaskModelParams().Train()
        for name in task_names:
            regex_p.input.Define(name, plain_input.Params(), '')

        regex_p.task_params = hyperparams.Params()
        regex_p.task_params.Define('p0', first_p, '')
        regex_p.task_params.Define('p1', second_p, '')

        regex_p.task_probs = hyperparams.Params()
        for name in task_names:
            regex_p.task_probs.Define(name, 0.5, '')

        model_vars = regex_p.Instantiate().vars
        self.assertEqual('shared/weight/var:0', model_vars.p0.weight.name)
        self.assertEqual('shared/weight/var:0', model_vars.p1.weight.name)
 def testGetDatasetParams_MultiTaskModelParams(self):
   """GetDatasetParams(name) matches calling the split method of that name."""
   model_params = base_model_params.MultiTaskModelParams()
   for split in ('Train', 'Dev', 'Test'):
     self.assertEqual(
         getattr(model_params, split)(), model_params.GetDatasetParams(split))
   # A name with no matching split raises DatasetError.
   with self.assertRaises(base_model_params.DatasetError):
     model_params.GetDatasetParams('Invalid')
示例#5
0
    def _setUpTestSampleTask(self):
        """Returns MultiTaskModel params with two TestParams tasks 'a' and 'b'.

        Seeds NumPy's global RNG with _NUMPY_RANDOM_SEED first. task_probs is
        not set here.
        """
        np.random.seed(_NUMPY_RANDOM_SEED)

        params = base_model.MultiTaskModel.Params()
        params.name = 'MultiTaskModel'
        task_a = BaseTaskTest.TestParams()
        task_b = BaseTaskTest.TestParams()

        seq_input = base_input_generator.BaseSequenceInputGenerator
        params.input = base_model_params.MultiTaskModelParams().Train()
        for name in ('a', 'b'):
            params.input.Define(name, seq_input.Params(), '')

        params.task_params = hyperparams.Params()
        for name, task in (('a', task_a), ('b', task_b)):
            params.task_params.Define(name, task, '')

        return params
示例#6
0
    def testInitMissingInputParams(self):
        """Instantiate raises AttributeError when task 'b' has no input params."""
        task_names = ('a', 'b')
        model_p = base_model.MultiTaskModel.Params()
        model_p.name = 'MultiTaskModel'

        task_ps = {}
        for name in task_names:
            task_p = BaseTaskTest.TestParams()
            task_p.train.learner = learner.Learner.Params().Set(name='loss')
            task_ps[name] = task_p

        # Only task 'a' gets input params; 'b' is deliberately left out.
        model_p.input = base_model_params.MultiTaskModelParams().Train()
        model_p.input.Define(
            'a', base_input_generator.BaseSequenceInputGenerator.Params(), '')

        model_p.task_params = hyperparams.Params()
        for name in task_names:
            model_p.task_params.Define(name, task_ps[name], '')

        model_p.task_probs = hyperparams.Params()
        for name in task_names:
            model_p.task_probs.Define(name, 0.5, '')

        with self.assertRaises(AttributeError):
            model_p.Instantiate()