def testInit(self):
    """Checks that a two-task MultiTaskModel registers tasks 'a' and 'b'.

    Each task gets its own params object with a single learner named 'loss'.
    After instantiation the model must expose exactly those two tasks, and
    each task's params must equal the corresponding `task_params` entry.
    """
    params = base_model.MultiTaskModel.Params()
    params.name = 'MultiTaskModel'

    task_a = BaseTaskTest.TestParams()
    task_a.train.learner = learner.Learner.Params().Set(name='loss')
    task_b = BaseTaskTest.TestParams()
    task_b.train.learner = learner.Learner.Params().Set(name='loss')

    task_names = ('a', 'b')

    params.input = base_model_params.MultiTaskModelParams().Train()
    for name in task_names:
        params.input.Define(
            name, base_input_generator.BaseSequenceInputGenerator.Params(), '')

    params.task_params = hyperparams.Params()
    for name, task in zip(task_names, (task_a, task_b)):
        params.task_params.Define(name, task, '')

    # Uniform sampling between the two tasks.
    params.task_probs = hyperparams.Params()
    for name in task_names:
        params.task_probs.Define(name, 0.5, '')

    model = params.Instantiate()

    self.assertEqual(len(model.tasks), 2)
    self.assertEqual(set(model.task_names), {'a', 'b'})
    self.assertEqual(
        set(model.tasks), {model.GetTask('a'), model.GetTask('b')})
    for name in task_names:
        self.assertEqual(
            getattr(model.params.task_params, name), model.GetTask(name).params)
def testSharedEncoderDecoderModel(self):
    """Checks that task 'p1' ends up with the same encoder/decoder as 'p0'.

    The model is configured to share the encoder and decoder of task 'p0'.
    Task 'p1' is given no encoder/decoder of its own (both set to None).
    """
    params = multitask_model.SharedEncoderDecoderModel.Params()
    params.name = 'test'
    params.encoder_to_share = 'p0'
    params.decoder_to_share = 'p0'

    owner_task = MultiTaskModelTest._TestTask.Params()
    borrower_task = MultiTaskModelTest._TestTask.Params()
    borrower_task.encoder = None
    borrower_task.decoder = None

    task_names = ('p0', 'p1')

    params.input = base_model_params.MultiTaskModelParams().Train()
    for name in task_names:
        params.input.Define(
            name, base_input_generator.BaseInputGenerator.Params(), '')

    params.task_params = hyperparams.Params()
    for name, task in zip(task_names, (owner_task, borrower_task)):
        params.task_params.Define(name, task, '')

    params.task_probs = hyperparams.Params()
    for name in task_names:
        params.task_probs.Define(name, 0.5, '')

    model = params.Instantiate()

    self.assertEqual(model.p0.encoder, model.p1.encoder)
    self.assertEqual(model.p0.decoder, model.p1.decoder)
def testRegExSharedVariableModel(self):
    """Checks that a regex renaming rule maps both tasks' `weight` to one name.

    The rule rewrites variable names matching 'p./(.*)' to 'shared/%s'.
    Tasks 'p0' and 'p1' are then expected to report the same variable name.
    """
    params = multitask_model.RegExSharedVariableModel.Params()
    params.name = 'test'
    params.variable_renaming_rules = [('p./(.*)', 'shared/%s')]

    first_task = MultiTaskModelTest._TestTaskWithVars.Params()
    second_task = MultiTaskModelTest._TestTaskWithVars.Params()

    task_names = ('p0', 'p1')

    params.input = base_model_params.MultiTaskModelParams().Train()
    for name in task_names:
        params.input.Define(
            name, base_input_generator.BaseInputGenerator.Params(), '')

    params.task_params = hyperparams.Params()
    for name, task in zip(task_names, (first_task, second_task)):
        params.task_params.Define(name, task, '')

    params.task_probs = hyperparams.Params()
    for name in task_names:
        params.task_probs.Define(name, 0.5, '')

    model = params.Instantiate()
    variables = model.vars

    expected_name = 'shared/weight/var:0'
    self.assertEqual(expected_name, variables.p0.weight.name)
    self.assertEqual(expected_name, variables.p1.weight.name)
def testGetDatasetParams_MultiTaskModelParams(self):
    """Checks GetDatasetParams against the Train/Dev/Test accessors.

    Each known split name must return the same params as calling the
    same-named method directly. An unknown name must raise
    `base_model_params.DatasetError`.
    """
    model_params = base_model_params.MultiTaskModelParams()

    for split in ('Train', 'Dev', 'Test'):
        self.assertEqual(
            getattr(model_params, split)(), model_params.GetDatasetParams(split))

    with self.assertRaises(base_model_params.DatasetError):
        model_params.GetDatasetParams('Invalid')
def _setUpTestSampleTask(self):
    """Builds MultiTaskModel params with two tasks, 'a' and 'b'.

    Seeds numpy's global RNG with `_NUMPY_RANDOM_SEED` first. Both tasks use
    default `BaseTaskTest.TestParams()` and a default
    `BaseSequenceInputGenerator` input.

    NOTE: `task_probs` is left unset; presumably callers fill it in.

    Returns:
      The configured (not yet instantiated) MultiTaskModel params.
    """
    np.random.seed(_NUMPY_RANDOM_SEED)

    model_params = base_model.MultiTaskModel.Params()
    model_params.name = 'MultiTaskModel'

    first_task = BaseTaskTest.TestParams()
    second_task = BaseTaskTest.TestParams()

    task_names = ('a', 'b')

    model_params.input = base_model_params.MultiTaskModelParams().Train()
    for name in task_names:
        model_params.input.Define(
            name, base_input_generator.BaseSequenceInputGenerator.Params(), '')

    model_params.task_params = hyperparams.Params()
    for name, task in zip(task_names, (first_task, second_task)):
        model_params.task_params.Define(name, task, '')

    return model_params
def testInitMissingInputParams(self):
    """Checks that instantiation fails when a task has no input params.

    Tasks 'a' and 'b' are both declared in `task_params` and `task_probs`,
    but `input` only defines 'a'. Instantiation must raise AttributeError.
    """
    params = base_model.MultiTaskModel.Params()
    params.name = 'MultiTaskModel'

    task_a = BaseTaskTest.TestParams()
    task_a.train.learner = learner.Learner.Params().Set(name='loss')
    task_b = BaseTaskTest.TestParams()
    task_b.train.learner = learner.Learner.Params().Set(name='loss')

    # Deliberately define input for task 'a' only.
    params.input = base_model_params.MultiTaskModelParams().Train()
    params.input.Define(
        'a', base_input_generator.BaseSequenceInputGenerator.Params(), '')

    params.task_params = hyperparams.Params()
    for name, task in (('a', task_a), ('b', task_b)):
        params.task_params.Define(name, task, '')

    params.task_probs = hyperparams.Params()
    for name in ('a', 'b'):
        params.task_probs.Define(name, 0.5, '')

    with self.assertRaises(AttributeError):
        params.Instantiate()