# Example #1
0
    def _build_tf_train_graph(
            self,
            session_data: SessionDataType) -> Tuple["tf.Tensor", "tf.Tensor"]:
        """Assemble the training graph and return its (loss, accuracy) tensors."""

        # Pull the next training batch out of the dataset iterator.
        self.batch_in = self._iterator.get_next()

        # Encode every known label once, in the same flat batch format.
        all_label_batch = train_utils.prepare_batch(self._label_data)

        # Recover named sparse/dense tensors from the flat batch tuples.
        batch_tensors, _ = train_utils.batch_to_session_data(
            self.batch_in, session_data)
        all_label_tensors, _ = train_utils.batch_to_session_data(
            all_label_batch, self._label_data)

        text_in = self._combine_sparse_dense_features(
            batch_tensors["text_features"],
            batch_tensors["text_mask"][0],
            "text")
        label_in = self._combine_sparse_dense_features(
            batch_tensors["label_features"],
            batch_tensors["label_mask"][0],
            "label")
        all_labels_in = self._combine_sparse_dense_features(
            all_label_tensors["label_features"],
            all_label_tensors["label_mask"][0],
            "label")

        # When hidden layers are shared, text and label FNNs reuse one scope.
        shared = self.share_hidden_layers

        self.message_embed = self._create_tf_embed_fnn(
            text_in,
            self.hidden_layer_sizes["text"],
            fnn_name="text_label" if shared else "text",
            embed_name="text",
        )
        self.label_embed = self._create_tf_embed_fnn(
            label_in,
            self.hidden_layer_sizes["label"],
            fnn_name="text_label" if shared else "label",
            embed_name="label",
        )
        self.all_labels_embed = self._create_tf_embed_fnn(
            all_labels_in,
            self.hidden_layer_sizes["label"],
            fnn_name="text_label" if shared else "label",
            embed_name="label",
        )

        return train_utils.calculate_loss_acc(
            self.message_embed,
            self.label_embed,
            label_in,
            self.all_labels_embed,
            all_labels_in,
            self.num_neg,
            None,
            self.loss_type,
            self.mu_pos,
            self.mu_neg,
            self.use_max_sim_neg,
            self.C_emb,
            self.scale_loss,
        )
    def predict_label(
            self, message: "Message"
    ) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
        """Predicts the intent of the provided message.

        Returns:
            A pair ``(label, label_ranking)`` where ``label`` is a dict with
            ``"name"`` and ``"confidence"`` keys for the best match, and
            ``label_ranking`` is a list of such dicts for the top candidates
            (at most ``ranking_length`` / ``LABEL_RANKING_LENGTH`` entries).
        """

        # Defaults returned whenever no prediction can be made.
        label = {"name": None, "confidence": 0.0}
        label_ranking = []

        if self.session is None:
            logger.error("There is no trained tf.session: "
                         "component is either not trained or "
                         "didn't receive enough training data.")
            return label, label_ranking

        # create session data from message and convert it into a batch of 1
        session_data = self._create_session_data([message])

        # if no text-features are present (e.g. incoming message is not in the
        # vocab), do not predict a random intent
        if not self._text_features_present(session_data):
            return label, label_ranking

        batch = train_utils.prepare_batch(session_data,
                                          tuple_sizes=self.batch_tuple_sizes)

        # load tf graph and session
        label_ids, message_sim = self._calculate_message_sim(batch)

        # if X contains all zeros do not predict some label
        if label_ids.size > 0:
            # Cast to plain Python float: the similarities are numpy scalars
            # (from the TF session run), which the stdlib json module cannot
            # serialize when the prediction is emitted downstream.
            label = {
                "name": self.inverted_label_dict[label_ids[0]],
                "confidence": float(message_sim[0]),
            }

            # The truthiness guard is required: ranking_length may be None,
            # and `0 < None` would raise a TypeError.
            if self.ranking_length and 0 < self.ranking_length < LABEL_RANKING_LENGTH:
                output_length = self.ranking_length
            else:
                output_length = LABEL_RANKING_LENGTH

            ranking = list(zip(list(label_ids), message_sim))[:output_length]
            label_ranking = [{
                "name": self.inverted_label_dict[label_idx],
                "confidence": float(score),
            } for label_idx, score in ranking]

        return label, label_ranking