Beispiel #1
0
  def testTrainEval(self, strategy, interaction, use_orbit=True):
    """Runs `train.main` once and checks that `params.yaml` was produced.

    Args:
      strategy: Forwarded to `_get_params_override` as `strategy`.
      interaction: Forwarded to `_get_params_override` as `interaction`.
      use_orbit: Forwarded to `_get_params_override` as `use_orbit`.
    """
    # The test does not set the mode itself; it requires that the flag
    # already holds `train_and_eval` when the test starts.
    self.assertEqual(FLAGS.mode, 'train_and_eval')

    # Small fixed vocabulary sizes for the trainer configuration.
    override_kwargs = dict(
        vocab_sizes=[40, 12, 11, 13],
        interaction=interaction,
        use_orbit=use_orbit,
        strategy=strategy)
    FLAGS.params_override = _get_params_override(**override_kwargs)

    train.main('unused_args')

    # The run is expected to leave a `params.yaml` in the model directory.
    params_pattern = os.path.join(self._model_dir, 'params.yaml')
    matched_files = tf.io.gfile.glob(params_pattern)
    self.assertNotEmpty(matched_files)
Beispiel #2
0
  def testTrainThenEval(self, strategy, interaction, use_orbit=True):
    """Runs `train.main` in 'train' mode, then again in 'eval' mode.

    After the training run, checks that `params.yaml` exists under
    `self._model_dir`. The evaluation run is only required to complete
    without raising.

    Args:
      strategy: Forwarded to `_get_params_override` as `strategy`.
      interaction: Forwarded to `_get_params_override` as `interaction`.
      use_orbit: Forwarded to `_get_params_override` as `use_orbit`.
    """
    # `FLAGS.mode` is process-global. Restore its original value when the test
    # finishes (pass or fail) so the overrides below do not leak into other
    # tests that rely on the flag's initial value.
    self.addCleanup(setattr, FLAGS, 'mode', FLAGS.mode)

    # Small fixed vocabulary sizes for the trainer configuration.
    vocab_sizes = [40, 12, 11, 13]

    FLAGS.params_override = _get_params_override(vocab_sizes=vocab_sizes,
                                                 interaction=interaction,
                                                 use_orbit=use_orbit,
                                                 strategy=strategy)
    # Training.
    FLAGS.mode = 'train'
    train.main('unused_args')
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(self._model_dir, 'params.yaml')))

    # Evaluation.
    FLAGS.mode = 'eval'
    train.main('unused_args')