示例#1
0
    def initialize_predictors(self):
        """Create and cache one predictor per network held by this object.

        ``self.net12``, ``self.net24`` and ``self.net48`` are each wrapped with
        ``predictor.from_estimator`` (each in its own fresh ``tf.Graph``) and
        stored as ``self.pred12``, ``self.pred24`` and ``self.pred48``.
        Building the predictors once up front lets later online predictions
        reuse them instead of reloading an estimator every time, which would
        slow the process down considerably.
        """
        def serving_input_fn():
            # One float32 placeholder named 'x' (no static shape) is used as
            # both the features and the receiver tensors of the model.
            x_input = tf.placeholder(dtype=tf.float32, name='x')
            receiver_tensors = {'x': x_input}
            return tf.estimator.export.ServingInputReceiver(receiver_tensors,
                                                            receiver_tensors)

        def build(estimator):
            # A separate graph per predictor, presumably to keep the three
            # models' ops apart instead of sharing the default graph.
            return predictor.from_estimator(estimator,
                                            serving_input_fn,
                                            graph=tf.Graph())

        self.pred12 = build(self.net12)
        self.pred24 = build(self.net24)
        self.pred48 = build(self.net48)
def create_predictor(estimator, serving_input_func):
    """Wrap an estimator in a predictor for in-process inference.

    Args:
      estimator: TF estimator to serve predictions from.
      serving_input_func: serving input receiver function describing the
        model inputs; forwarded unchanged to
        `contrib_predictor.from_estimator`.

    Returns:
      The predictor object produced by `contrib_predictor.from_estimator`
      (not an estimator).
    """
    wrapped = contrib_predictor.from_estimator(estimator, serving_input_func)
    return wrapped
示例#3
0
    def train(self, training_data, cfg, **kwargs):
        """Train this component.

        Steps performed:
          1. If ``self.checkpoint_remove_before_training`` is set, delete
             ``self.checkpoint_dir`` so training starts from scratch.
          2. Derive ``self.label_list`` and the training examples from
             ``training_data`` via ``run_classifier``.
          3. Build ``self.estimator`` around ``run_classifier.model_fn_builder``
             and train it for ``len(examples) / batch_size * epochs`` steps
             (truncated, but never fewer than one step).
          4. Build ``self.predict_fn`` from the trained estimator so the
             component can predict/evaluate without reloading it.

        Args:
            training_data: passed to ``run_classifier.get_labels``; its
                ``training_examples`` attribute supplies the examples.
            cfg: not used in this method.
            **kwargs: not used in this method.

        Side effects:
            Sets ``self.label_list``, ``self.estimator``, ``self.session`` and
            ``self.predict_fn``; the estimator uses ``self.checkpoint_dir`` as
            its ``model_dir``.
        """
        # Clean up checkpoint from a previous run before training.
        if self.checkpoint_remove_before_training and os.path.exists(self.checkpoint_dir):
            shutil.rmtree(self.checkpoint_dir, ignore_errors=True)

        self.label_list = run_classifier.get_labels(training_data)

        run_config = tf.estimator.RunConfig(
            model_dir=self.checkpoint_dir,
            save_summary_steps=self.save_summary_steps,
            save_checkpoints_steps=self.save_checkpoints_steps)

        train_examples = run_classifier.get_train_examples(training_data.training_examples)
        num_examples = len(train_examples)
        num_train_steps = int(num_examples / self.batch_size * self.epochs)
        if num_train_steps < 1:
            # int() truncates to 0 whenever num_examples * epochs < batch_size,
            # and Estimator.train rejects max_steps <= 0 with a ValueError --
            # after the old checkpoint above has already been removed.
            tf.logging.warning(
                "Computed %d training steps (%d examples, batch size %d, "
                "%s epochs); training for 1 step instead.",
                num_train_steps, num_examples, self.batch_size, self.epochs)
            num_train_steps = 1
        num_warmup_steps = int(num_train_steps * self.warmup_proportion)

        tf.logging.info("***** Running training *****")
        tf.logging.info("Num examples = %d", num_examples)
        tf.logging.info("Batch size = %d", self.batch_size)
        tf.logging.info("Num steps = %d", num_train_steps)
        tf.logging.info("Num epochs = %d", self.epochs)

        model_fn = run_classifier.model_fn_builder(
            bert_tfhub_module_handle=self.bert_tfhub_module_handle,
            num_labels=len(self.label_list),
            learning_rate=self.learning_rate,
            num_train_steps=num_train_steps,
            num_warmup_steps=num_warmup_steps)

        self.estimator = tf.estimator.Estimator(
            model_fn=model_fn,
            config=run_config,
            params={"batch_size": self.batch_size})

        train_features = run_classifier.convert_examples_to_features(
            train_examples, self.label_list, self.max_seq_length, self.tokenizer)

        train_input_fn = run_classifier.input_fn_builder(
            features=train_features,
            seq_length=self.max_seq_length,
            is_training=True,
            drop_remainder=True)

        # Start training
        self.estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

        # NOTE(review): this session is not used within this method; it is
        # kept in case other code reads self.session -- confirm.
        self.session = tf.Session()

        # Create predictor in case evaluation is run afterwards.
        self.predict_fn = predictor.from_estimator(self.estimator,
                                                   run_classifier.serving_input_fn_builder(self.max_seq_length))
示例#4
0
    model = LISAModel(hparams, model_config, layer_task_config, layer_attention_config, feature_idx_map, label_idx_map,
                      vocab)
    # ws = WarmStartSettings(ckpt_to_initialize_from=path,
    #                   vars_to_warm_start=".*")
    # print("debug <loading from>:", path)
    estimator = tf.estimator.Estimator(model_fn=model.model_fn, model_dir=path)
    return estimator

if not args.ensemble:
  # Single model: `constrcut_predictor` builds an Estimator over
  # args.save_dir, which is then wrapped as an in-process predictor.
  estimator = constrcut_predictor(args.save_dir)
  predict_fns = [
      predictor.from_estimator(
          estimator,
          serving_input_receiver_fn=train_utils.serving_input_receiver_fn),
  ]
else:
  # Ensemble: one SavedModel predictor per entry returned by
  # util.get_immediate_subdirectories(args.save_dir) (presumably one
  # exported model per ensemble member -- confirm).
  predict_fns = [
      predictor.from_saved_model("{}/{}".format(args.save_dir, subdir))
      for subdir in util.get_immediate_subdirectories(args.save_dir)
  ]


def dev_input_fn():
  """Return the input function built over the dev files.

  Uses `hparams.batch_size`, a single epoch (`num_epochs=1`) and no
  shuffling (`shuffle=False`). Relies on the module-level names `vocab`,
  `data_config`, `dev_filenames`, `embedding_files` and `hparams`.
  """
  return train_utils.get_input_fn(
      vocab,
      data_config,
      dev_filenames,
      hparams.batch_size,
      num_epochs=1,
      shuffle=False,
      embedding_files=embedding_files,
      is_token_based_batching=hparams.is_token_based_batching,
  )


def eval_fn(input_op, sess, input_source):
  if args.eval_with_transformation:
    task_config['srl']['eval_fns']['srl_f1']['name'] = 'conll_srl_eval_with_transformation'
    pass

  eval_accumulators = eval_fns.get_accumulators(task_config)
  eval_results = OrderedDict({})
  i = 0