예제 #1
0
    def testTrainerFn(self):
        """Run `taxi_utils_bqml.trainer_fn` through train, eval and export.

        Builds `TrainerFnArgs` from files under `self._testdata_path`, checks
        the types of the entries in the returned training spec, calls
        `tf.estimator.train_and_evaluate` (configured for one train step and
        one eval step), exports the eval saved model, and finally loads the
        exported serving graph inside a `tf.compat.v1.Session`.
        """
        # Prefer TEST_UNDECLARED_OUTPUTS_DIR when it is set in the environment;
        # otherwise fall back to the test case's own temp directory.
        default_root = self.get_temp_dir()
        output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     default_root)
        temp_dir = os.path.join(output_root, self._testMethodName)

        testdata = self._testdata_path
        schema_file = os.path.join(testdata, 'schema_gen/schema.pbtxt')

        # Inputs for the trainer function, resolved against the test data.
        train_pattern = os.path.join(
            testdata, 'transform/transformed_examples/train/*.gz')
        transform_dir = os.path.join(testdata, 'transform/transform_output/')
        serving_dir = os.path.join(temp_dir, 'serving_model_dir')
        eval_pattern = os.path.join(
            testdata, 'transform/transformed_examples/eval/*.gz')
        base_model_dir = os.path.join(testdata,
                                      'trainer/current/serving_model_dir')
        dataset_factory = tfxio_utils.get_tf_dataset_factory_from_artifact(
            [standard_artifacts.Examples()], [])
        accessor = DataAccessor(
            tf_dataset_factory=dataset_factory,
            record_batch_factory=None,
            data_view_decode_fn=None)

        trainer_fn_args = trainer_executor.TrainerFnArgs(
            train_files=train_pattern,
            transform_output=transform_dir,
            serving_model_dir=serving_dir,
            eval_files=eval_pattern,
            schema_file=schema_file,
            train_steps=1,
            eval_steps=1,
            base_model=base_model_dir,
            data_accessor=accessor)
        schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
        training_spec = taxi_utils_bqml.trainer_fn(trainer_fn_args, schema)

        # Pull every expected entry out of the spec before checking types, so
        # a missing key surfaces before any assertion runs.
        spec_keys = ('estimator', 'train_spec', 'eval_spec',
                     'eval_input_receiver_fn')
        estimator, train_spec, eval_spec, eval_input_receiver_fn = (
            training_spec[key] for key in spec_keys)

        expected_types = (
            (estimator, tf.estimator.Estimator),
            (train_spec, tf.estimator.TrainSpec),
            (eval_spec, tf.estimator.EvalSpec),
            (eval_input_receiver_fn, types.FunctionType),
        )
        for value, expected_type in expected_types:
            self.assertIsInstance(value, expected_type)

        # Train for one step, then eval for one step.
        metrics, export_paths = tf.estimator.train_and_evaluate(
            estimator, train_spec, eval_spec)
        self.assertGreater(metrics['loss'], 0.0)
        self.assertEqual(len(export_paths), 1)
        serving_export = export_paths[0]
        self.assertGreaterEqual(len(fileio.listdir(serving_export)), 1)

        # Export the eval saved model and make sure something was written.
        eval_export_base = path_utils.eval_model_dir(temp_dir)
        eval_savedmodel_path = tfma.export.export_eval_savedmodel(
            estimator=estimator,
            export_dir_base=eval_export_base,
            eval_input_receiver_fn=eval_input_receiver_fn)
        eval_export_entries = fileio.listdir(eval_savedmodel_path)
        self.assertGreaterEqual(len(eval_export_entries), 1)

        # Load the exported serving graph and check what the loader returns.
        with tf.compat.v1.Session() as session:
            loaded = tf.compat.v1.saved_model.loader.load(
                session, [tf.saved_model.SERVING], serving_export)
            self.assertIsInstance(loaded, tf.compat.v1.MetaGraphDef)
예제 #2
0
    def testTrainerFn(self):
        """Run `taxi_utils.trainer_fn` through train, eval and export.

        Builds `TrainerFnArgs` from files under `self._testdata_path`, checks
        the types of the entries in the returned training spec, calls
        `tf.estimator.train_and_evaluate` (configured for one train step and
        one eval step), exports the eval saved model under `output_dir`, and
        finally loads the exported serving graph inside a
        `tf.compat.v1.Session`.
        """
        # Prefer TEST_UNDECLARED_OUTPUTS_DIR when it is set in the environment;
        # otherwise fall back to the test case's own temp directory.
        default_root = self.get_temp_dir()
        output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     default_root)
        temp_dir = os.path.join(output_root, self._testMethodName)

        testdata = self._testdata_path
        schema_file = os.path.join(testdata, 'schema_gen/schema.pbtxt')
        output_dir = os.path.join(temp_dir, 'output_dir')

        # Inputs for the trainer function, resolved against the test data.
        train_pattern = os.path.join(
            testdata, 'transform/transformed_examples/train/*.gz')
        transform_dir = os.path.join(testdata, 'transform/transform_output')
        serving_dir = os.path.join(temp_dir, 'serving_model_dir')
        eval_pattern = os.path.join(
            testdata, 'transform/transformed_examples/eval/*.gz')

        trainer_fn_args = trainer_executor.TrainerFnArgs(
            train_files=train_pattern,
            transform_output=transform_dir,
            output_dir=output_dir,
            serving_model_dir=serving_dir,
            eval_files=eval_pattern,
            schema_file=schema_file,
            train_steps=1,
            eval_steps=1,
            verbosity='INFO',
            base_model=None)
        schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
        training_spec = taxi_utils.trainer_fn(trainer_fn_args, schema)

        # Pull every expected entry out of the spec before checking types, so
        # a missing key surfaces before any assertion runs.
        spec_keys = ('estimator', 'train_spec', 'eval_spec',
                     'eval_input_receiver_fn')
        estimator, train_spec, eval_spec, eval_input_receiver_fn = (
            training_spec[key] for key in spec_keys)

        expected_types = (
            (estimator, tf.estimator.DNNLinearCombinedClassifier),
            (train_spec, tf.estimator.TrainSpec),
            (eval_spec, tf.estimator.EvalSpec),
            (eval_input_receiver_fn, types.FunctionType),
        )
        for value, expected_type in expected_types:
            self.assertIsInstance(value, expected_type)

        # Train for one step, then eval for one step.
        metrics, export_paths = tf.estimator.train_and_evaluate(
            estimator, train_spec, eval_spec)
        self.assertGreater(metrics['loss'], 0.0)
        self.assertEqual(len(export_paths), 1)
        serving_export = export_paths[0]
        self.assertGreaterEqual(len(tf.io.gfile.listdir(serving_export)), 1)

        # Export the eval saved model and make sure something was written.
        eval_export_base = path_utils.eval_model_dir(output_dir)
        eval_savedmodel_path = tfma.export.export_eval_savedmodel(
            estimator=estimator,
            export_dir_base=eval_export_base,
            eval_input_receiver_fn=eval_input_receiver_fn)
        eval_export_entries = tf.io.gfile.listdir(eval_savedmodel_path)
        self.assertGreaterEqual(len(eval_export_entries), 1)

        # Load the exported serving graph and check what the loader returns.
        with tf.compat.v1.Session() as session:
            loaded = tf.compat.v1.saved_model.loader.load(
                session, [tf.saved_model.SERVING], serving_export)
            self.assertIsInstance(loaded, tf.compat.v1.MetaGraphDef)
예제 #3
0
 def testTrainerFn(self):
     """Check the entries returned by `model._create_train_and_eval_spec`.

     Calls the function with `TrainerFnArgs` built from literal
     '/path/to/...' strings and a default-constructed `schema_pb2.Schema`,
     then verifies the type of each entry in the returned mapping.
     """
     fn_kwargs = {
         'train_files': '/path/to/train.file',
         'transform_output': '/path/to/transform_output',
         'serving_model_dir': '/path/to/model_dir',
         'eval_files': '/path/to/eval.file',
         'schema_file': '/path/to/schema_file',
         'train_steps': 1000,
         'eval_steps': 100,
     }
     trainer_fn_args = trainer_executor.TrainerFnArgs(**fn_kwargs)
     schema = schema_pb2.Schema()
     result = model._create_train_and_eval_spec(trainer_fn_args, schema)  # pylint: disable=protected-access

     # Each of these entries must be an instance of the paired estimator type.
     expected_types = (
         ('estimator', tf.estimator.Estimator),
         ('train_spec', tf.estimator.TrainSpec),
         ('eval_spec', tf.estimator.EvalSpec),
     )
     for key, expected_type in expected_types:
         self.assertIsInstance(result[key], expected_type)

     # The receiver only needs to be callable; no specific type is required.
     receiver_fn = result['eval_input_receiver_fn']
     self.assertTrue(callable(receiver_fn))
예제 #4
0
 def testTrainerFactory(self):
     """Check the entries produced by a `train.trainer_factory` trainer.

     Builds a trainer function via `train.trainer_factory` using a
     `BertConfig` created from an inline dict, calls it with `TrainerFnArgs`
     built from literal '/path/to/...' strings and a default-constructed
     `schema_pb2.Schema`, then verifies the type of each entry in the
     returned mapping.
     """
     bert_dir = 'fake/path'
     vocab_file_path = os.path.join(bert_dir, 'vocab.txt')

     config_values = {
         "attention_probs_dropout_prob": 0.1,
         "directionality": "bidi",
         "hidden_act": "gelu",
         "hidden_dropout_prob": 0.1,
         "hidden_size": 768,
         "initializer_range": 0.02,
         "intermediate_size": 3072,
         "max_position_embeddings": 512,
         "num_attention_heads": 12,
         "num_hidden_layers": 12,
         "pooler_fc_size": 768,
         "pooler_num_attention_heads": 12,
         "pooler_num_fc_layers": 3,
         "pooler_size_per_head": 128,
         "pooler_type": "first_token_transform",
         "type_vocab_size": 2,
         "vocab_size": 119547,
     }
     bert_config = modeling.BertConfig.from_dict(config_values)
     bert_checkpoint_dir = os.path.join(bert_dir, 'bert_checkpoint')

     # The config object itself is handed to `bert_config_path` here.
     trainer_fn = train.trainer_factory(
         batch_size=16,
         vocab_file_path=vocab_file_path,
         bert_config_path=bert_config,
         bert_checkpoint_dir=bert_checkpoint_dir,
         max_seq_length=8,
         learning_rate=2e-5,
         hidden_layer_dims=[2, 3],
         categorical_feature_keys=['cat_1', 'cat_2'],
         label_key='label',
         warmup_prop=0.1,
         cooldown_prop=0.0,
         non_string_keys=['cat_1'],
         warm_start_from=None,
         save_summary_steps=1000,
         save_checkpoints_secs=100,
         _test_mode=True)

     fn_kwargs = {
         'train_files': '/path/to/train.file',
         'transform_output': '/path/to/transform_output',
         'serving_model_dir': '/path/to/model_dir',
         'eval_files': '/path/to/eval.file',
         'schema_file': '/path/to/schema_file',
         'train_steps': 1000,
         'eval_steps': 100,
     }
     trainer_fn_args = trainer_executor.TrainerFnArgs(**fn_kwargs)
     schema = schema_pb2.Schema()
     result = trainer_fn(trainer_fn_args, schema)

     # Each of these entries must be an instance of the paired estimator type.
     expected_types = (
         ('estimator', tf.estimator.Estimator),
         ('train_spec', tf.estimator.TrainSpec),
         ('eval_spec', tf.estimator.EvalSpec),
     )
     for key, expected_type in expected_types:
         self.assertIsInstance(result[key], expected_type)

     # The receiver only needs to be callable; no specific type is required.
     receiver_fn = result['eval_input_receiver_fn']
     self.assertTrue(callable(receiver_fn))