def testTrainEval(self, strategy, interaction, use_orbit=True):
    """Trains and evaluates with synthetic data in a single `train.main` call.

    Relies on the default `train_and_eval` mode, then checks that a
    `params.yaml` file exists in the model directory afterwards.
    """
    # This test depends on the default mode being `train_and_eval`.
    self.assertEqual(FLAGS.mode, 'train_and_eval')

    # Simple trainer setup over synthetic data with four small vocabularies.
    FLAGS.params_override = _get_params_override(
        vocab_sizes=[40, 12, 11, 13],
        interaction=interaction,
        use_orbit=use_orbit,
        strategy=strategy)
    train.main('unused_args')

    params_yaml = os.path.join(self._model_dir, 'params.yaml')
    self.assertNotEmpty(tf.io.gfile.glob(params_yaml))
def testTrainThenEval(self, strategy, interaction, use_orbit=True):
    """Runs training and evaluation as two separate `train.main` calls.

    First runs `train.main` with `FLAGS.mode = 'train'` and checks that a
    `params.yaml` file exists in the model directory. It then runs
    `train.main` again with `FLAGS.mode = 'eval'`.
    """
    # FLAGS is process-global. Restore the original mode when this test
    # finishes so that tests expecting the default (testTrainEval asserts
    # 'train_and_eval') do not depend on test execution order.
    original_mode = FLAGS.mode
    self.addCleanup(setattr, FLAGS, 'mode', original_mode)

    # Set up simple trainer with synthetic data.
    vocab_sizes = [40, 12, 11, 13]
    FLAGS.params_override = _get_params_override(
        vocab_sizes=vocab_sizes,
        interaction=interaction,
        use_orbit=use_orbit,
        strategy=strategy)

    # Training.
    FLAGS.mode = 'train'
    train.main('unused_args')
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(self._model_dir, 'params.yaml')))

    # Evaluation.
    FLAGS.mode = 'eval'
    train.main('unused_args')