def _build_tf_train_graph(
    self, session_data: SessionDataType
) -> Tuple["tf.Tensor", "tf.Tensor"]:
    """Assemble the TensorFlow training graph.

    Wires the input iterator and the encoded label set into message/label
    embedding networks and returns the (loss, accuracy) tensors produced
    by ``train_utils.calculate_loss_acc``.
    """
    # Next training batch, as produced by the dataset iterator.
    self.batch_in = self._iterator.get_next()

    # Every encoded label, laid out in the same flat batch format as
    # the training data so it can go through the same conversion below.
    label_batch = train_utils.prepare_batch(self._label_data)

    # Reconstruct sparse/dense tensors from the flat batch tuples.
    batch_data, _ = train_utils.batch_to_session_data(self.batch_in, session_data)
    label_data, _ = train_utils.batch_to_session_data(label_batch, self._label_data)

    message_in = self._combine_sparse_dense_features(
        batch_data["text_features"], batch_data["text_mask"][0], "text"
    )
    label_in = self._combine_sparse_dense_features(
        batch_data["label_features"], batch_data["label_mask"][0], "label"
    )
    all_labels_in = self._combine_sparse_dense_features(
        label_data["label_features"], label_data["label_mask"][0], "label"
    )

    # When hidden layers are shared, text and label go through the same FNN.
    shared = self.share_hidden_layers

    self.message_embed = self._create_tf_embed_fnn(
        message_in,
        self.hidden_layer_sizes["text"],
        fnn_name="text_label" if shared else "text",
        embed_name="text",
    )
    self.label_embed = self._create_tf_embed_fnn(
        label_in,
        self.hidden_layer_sizes["label"],
        fnn_name="text_label" if shared else "label",
        embed_name="label",
    )
    self.all_labels_embed = self._create_tf_embed_fnn(
        all_labels_in,
        self.hidden_layer_sizes["label"],
        fnn_name="text_label" if shared else "label",
        embed_name="label",
    )

    return train_utils.calculate_loss_acc(
        self.message_embed,
        self.label_embed,
        label_in,
        self.all_labels_embed,
        all_labels_in,
        self.num_neg,
        None,
        self.loss_type,
        self.mu_pos,
        self.mu_neg,
        self.use_max_sim_neg,
        self.C_emb,
        self.scale_loss,
    )
def predict_label(
    self, message: "Message"
) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
    """Predicts the intent of the provided message."""
    # Defaults returned whenever prediction is impossible.
    label = {"name": None, "confidence": 0.0}
    label_ranking = []

    # Guard: without a trained session there is nothing to run.
    if self.session is None:
        logger.error(
            "There is no trained tf.session: "
            "component is either not trained or "
            "didn't receive enough training data."
        )
        return label, label_ranking

    # Wrap the single message as a batch of one.
    session_data = self._create_session_data([message])

    # No text features (e.g. message entirely out of vocabulary):
    # refuse to emit an arbitrary intent.
    if not self._text_features_present(session_data):
        return label, label_ranking

    batch = train_utils.prepare_batch(
        session_data, tuple_sizes=self.batch_tuple_sizes
    )

    # Run the graph to get label ids ranked by similarity.
    label_ids, message_sim = self._calculate_message_sim(batch)

    # Empty result means the input was effectively all zeros — keep defaults.
    if label_ids.size > 0:
        label = {
            "name": self.inverted_label_dict[label_ids[0]],
            "confidence": message_sim[0],
        }

        # Cap the ranking at the configured length, bounded by the global max.
        if self.ranking_length and 0 < self.ranking_length < LABEL_RANKING_LENGTH:
            output_length = self.ranking_length
        else:
            output_length = LABEL_RANKING_LENGTH

        top = list(zip(list(label_ids), message_sim))[:output_length]
        label_ranking = [
            {"name": self.inverted_label_dict[idx], "confidence": sim}
            for idx, sim in top
        ]

    return label, label_ranking