def save_model(output_dir: str, estimator: tf.estimator.Estimator,
               *, embedding_dim: int = 768) -> None:
    """Export ``estimator`` as a SavedModel whose inputs are raw tensors.

    The exported serving signature declares two inputs, which are handed to
    the model unchanged (they are used both as features and as receiver
    tensors):

    * ``embeddings``: float32 placeholder of shape
      ``[None, None, embedding_dim]``.
    * ``input_mask``: int32 placeholder of shape ``[None, None]``.

    Args:
        output_dir: Base directory passed to ``estimator.export_saved_model``.
        estimator: The estimator to export.
        embedding_dim: Size of the last axis of the ``embeddings`` input.
            Defaults to 768, the value this function always used before it
            was made configurable.

    Raises:
        ValueError: If ``embedding_dim`` is not positive.
    """
    # Fail fast with a clear message instead of deep inside graph building.
    if embedding_dim <= 0:
        raise ValueError(
            f"embedding_dim must be a positive integer, got {embedding_dim!r}")

    def serving_input_receiver_fn() -> Any:
        # The first two dimensions are left unspecified (presumably batch and
        # sequence length -- confirm against the model), so any size is
        # accepted at serving time.
        embeddings = tf.placeholder(dtype=tf.float32,
                                    shape=[None, None, embedding_dim],
                                    name='embeddings')
        input_mask = tf.placeholder(dtype=tf.int32,
                                    shape=[None, None],
                                    name='input_mask')
        features = {'embeddings': embeddings, 'input_mask': input_mask}
        # No parsing stage: the request tensors are the model features.
        return tf.estimator.export.ServingInputReceiver(
            features=features, receiver_tensors=features)

    estimator.export_saved_model(output_dir, serving_input_receiver_fn)
Esempio n. 2
0
def train_estimator(estimator: tf.estimator.Estimator, input_config,
                    train_config, export_config, task_config: TrainTaskConfig):
    """Train/evaluate ``estimator`` and export it from a single process.

    Args:
        estimator: The estimator to train and export.
        input_config: Mapping read for ``'example_config'`` (turned into
            feature specs by ``parse_feature_specs``) and ``'label_col'``.
        train_config: Extra keyword arguments forwarded to ``get_dataset_fn``.
        export_config: Keyword arguments forwarded to
            ``get_feature_placeholders`` to describe the serving inputs.
        task_config: Cluster role information. The export runs only on the
            chief with index 0, or whenever ``num_workers`` is 1.
    """
    example_cfg = input_config['example_config']
    label_column = input_config['label_col']

    input_fn = get_dataset_fn(feature_specs=parse_feature_specs(example_cfg),
                              label_col=label_column,
                              **train_config)

    # The same input_fn drives training and a one-step evaluation.
    tf.estimator.train_and_evaluate(
        estimator,
        tf.estimator.TrainSpec(input_fn),
        tf.estimator.EvalSpec(input_fn, steps=1),
    )

    placeholders = get_feature_placeholders(**export_config)
    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        placeholders)

    # Only one process should write the SavedModel. `or` short-circuits, so
    # num_workers is consulted only when this is not chief #0.
    should_export = (
        (task_config.task_type == 'chief' and task_config.task_index == 0)
        or task_config.num_workers == 1
    )
    if not should_export:
        return

    logging.info("Start exporting...")
    estimator.export_saved_model(task_config.saved_model_dir,
                                 serving_input_receiver_fn=receiver_fn)
    logging.info("Finish exporting.")
Esempio n. 3
0
 def export(self, estimator: tf.estimator.Estimator):
     """Write ``estimator`` as a SavedModel under ``self.path_saved_model``.

     The serving signature has one batched placeholder per entry of
     ``self.fields``, keyed by that field's ``name``.

     Returns:
         Whatever ``estimator.export_saved_model`` returns (per the TF
         Estimator API, the export directory path -- confirm for the TF
         version in use).
     """
     placeholders = {}
     for field in self.fields:
         key = field.name
         placeholders[key] = field.as_placeholder(batch=True)
     receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
         placeholders)
     return estimator.export_saved_model(self.path_saved_model, receiver_fn)