示例#1
0
    def train(self, training_data, cfg, **kwargs):
        """Train the estimator on ``training_data`` and build a predictor.

        Sets ``self.label_list``, ``self.estimator``, ``self.session`` and
        ``self.predict_fn``. ``cfg`` and ``kwargs`` are accepted but not read
        by this method.
        """
        # Optionally wipe checkpoints left over from an earlier run.
        if self.checkpoint_remove_before_training:
            if os.path.exists(self.checkpoint_dir):
                shutil.rmtree(self.checkpoint_dir, ignore_errors=True)

        self.label_list = run_classifier.get_labels(training_data)

        estimator_config = tf.estimator.RunConfig(
            model_dir=self.checkpoint_dir,
            save_summary_steps=self.save_summary_steps,
            save_checkpoints_steps=self.save_checkpoints_steps,
        )

        examples = run_classifier.get_train_examples(training_data.training_examples)
        example_count = len(examples)

        # Total optimisation steps across all epochs; a configured fraction of
        # them is handed to the model builder as warm-up steps.
        total_steps = int(example_count / self.batch_size * self.epochs)
        warmup_steps = int(total_steps * self.warmup_proportion)

        log = tf.logging.info
        log("***** Running training *****")
        log("Num examples = %d", example_count)
        log("Batch size = %d", self.batch_size)
        log("Num steps = %d", total_steps)
        log("Num epochs = %d", self.epochs)

        classifier_model_fn = run_classifier.model_fn_builder(
            bert_tfhub_module_handle=self.bert_tfhub_module_handle,
            num_labels=len(self.label_list),
            learning_rate=self.learning_rate,
            num_train_steps=total_steps,
            num_warmup_steps=warmup_steps,
        )

        self.estimator = tf.estimator.Estimator(
            model_fn=classifier_model_fn,
            config=estimator_config,
            params=dict(batch_size=self.batch_size),
        )

        features = run_classifier.convert_examples_to_features(
            examples, self.label_list, self.max_seq_length, self.tokenizer
        )
        input_fn = run_classifier.input_fn_builder(
            features=features,
            seq_length=self.max_seq_length,
            is_training=True,
            drop_remainder=True,
        )

        # Run the training loop, stopping once total_steps has been reached.
        self.estimator.train(input_fn=input_fn, max_steps=total_steps)

        self.session = tf.Session()

        # Build a predictor up front in case evaluation follows training.
        serving_input_fn = run_classifier.serving_input_fn_builder(self.max_seq_length)
        self.predict_fn = predictor.from_estimator(self.estimator, serving_input_fn)
示例#2
0
    def process(self, message: Message, **kwargs: Any) -> None:
        """Classify ``message`` and attach the predicted intent and ranking.

        The message is converted to model features, run through
        ``self.predict_fn``, and the results are stored on the message under
        ``"intent"`` (best label with its confidence) and ``"intent_ranking"``
        (every label, sorted by descending confidence).

        Args:
            message: The message to classify; it is mutated in place.
            **kwargs: Accepted for interface compatibility; not used here.
        """

        # Classifier needs this to be non empty, so we set to first label.
        message.data["intent"] = self.label_list[0]

        predict_examples = get_test_examples([message])
        predict_features = convert_examples_to_features(
            predict_examples, self.label_list, self.max_seq_length, self.tokenizer
        )

        # Only one message is classified per call, so take the single feature.
        example = predict_features[0]

        seq_shape = (-1, self.max_seq_length)
        result = self.predict_fn(
            {
                "input_ids": np.array(example.input_ids).reshape(seq_shape),
                "input_mask": np.array(example.input_mask).reshape(seq_shape),
                "label_ids": np.array(example.label_id).reshape(-1),
                "segment_ids": np.array(example.segment_ids).reshape(seq_shape),
            }
        )

        # NOTE(review): the predictor output is exponentiated, which presumes
        # "probabilities" holds log-probabilities -- confirm against model_fn.
        probabilities = [float(p) for p in np.exp(result["probabilities"])[0]]

        # Pick the best label with NumPy: building a TensorFlow op and opening
        # a fresh session for every message would grow the graph and leak
        # sessions, for a computation that needs no TensorFlow at all.
        best_index = int(np.argmax(probabilities))

        intent = {
            "name": self.label_list[best_index],
            "confidence": probabilities[best_index],
        }
        intent_ranking = sorted(
            [
                {"name": self.label_list[i], "confidence": confidence}
                for i, confidence in enumerate(probabilities)
            ],
            key=lambda entry: entry["confidence"],
            reverse=True,
        )

        message.set("intent", intent, add_to_output=True)
        message.set("intent_ranking", intent_ranking, add_to_output=True)