def _build_tf_train_graph(
    self, session_data: SessionDataType
) -> Tuple["tf.Tensor", "tf.Tensor"]:
    """Assemble the training part of the TensorFlow graph.

    Stores the input batch in ``self.batch_in`` and the three embedding
    tensors in ``self.message_embed``, ``self.label_embed`` and
    ``self.all_labels_embed``.

    Args:
        session_data: Passed to ``train_utils.batch_to_session_data``
            together with the iterator batch to rebuild the keyed
            feature structure.

    Returns:
        Whatever ``train_utils.calculate_loss_acc`` returns; annotated as a
        pair of tensors (presumably loss and accuracy -- confirm against
        ``train_utils``).
    """
    # Fetch the input tensors for the next batch from the iterator.
    self.batch_in = self._iterator.get_next()
    # Bring the encoded data of all labels into the same batch format.
    all_labels_batch = train_utils.prepare_batch(self._label_data)

    # Turn the flat batch format back into keyed feature/mask entries.
    batch_data, _ = train_utils.batch_to_session_data(
        self.batch_in, session_data
    )
    label_data, _ = train_utils.batch_to_session_data(
        all_labels_batch, self._label_data
    )

    text_input = self._combine_sparse_dense_features(
        batch_data["text_features"],
        batch_data["text_mask"][0],
        "text",
    )
    label_input = self._combine_sparse_dense_features(
        batch_data["label_features"],
        batch_data["label_mask"][0],
        "label",
    )
    all_labels_input = self._combine_sparse_dense_features(
        label_data["label_features"],
        label_data["label_mask"][0],
        "label",
    )

    # With shared hidden layers both towers use the same fnn name.
    share_layers = self.share_hidden_layers
    text_fnn_name = "text_label" if share_layers else "text"
    label_fnn_name = "text_label" if share_layers else "label"

    self.message_embed = self._create_tf_embed_fnn(
        text_input,
        self.hidden_layer_sizes["text"],
        fnn_name=text_fnn_name,
        embed_name="text",
    )
    # The batch labels and the full label set go through the same
    # fnn/embed names.
    self.label_embed = self._create_tf_embed_fnn(
        label_input,
        self.hidden_layer_sizes["label"],
        fnn_name=label_fnn_name,
        embed_name="label",
    )
    self.all_labels_embed = self._create_tf_embed_fnn(
        all_labels_input,
        self.hidden_layer_sizes["label"],
        fnn_name=label_fnn_name,
        embed_name="label",
    )

    return train_utils.calculate_loss_acc(
        self.message_embed,
        self.label_embed,
        label_input,
        self.all_labels_embed,
        all_labels_input,
        self.num_neg,
        None,
        self.loss_type,
        self.mu_pos,
        self.mu_neg,
        self.use_max_sim_neg,
        self.C_emb,
        self.scale_loss,
    )
def _build_tf_pred_graph(self, session_data: "SessionDataType") -> "tf.Tensor":
    """Assemble the prediction part of the TensorFlow graph.

    Creates one placeholder per (shape, type) pair reported by
    ``train_utils.get_shapes_types`` and stores them in ``self.batch_in``.
    Also sets ``self.batch_tuple_sizes``, ``self.message_embed``,
    ``self.label_embed``, ``self.sim_all`` and ``self.sim``, and replaces
    ``self.all_labels_embed`` with a constant.

    NOTE(review): reads ``self.all_labels_embed`` and ``self.session``, so
    both must already be set when this runs -- presumably by the training
    graph builder; confirm against the caller.

    Args:
        session_data: Used to derive the placeholder shapes/types and to
            rebuild the keyed feature structure from the placeholders.

    Returns:
        The result of ``train_utils.confidence_from_sim`` applied to
        ``self.sim_all`` and ``self.similarity_type``.
    """
    shapes, types = train_utils.get_shapes_types(session_data)

    # One placeholder per input component, in the order reported above.
    placeholders = [
        tf.placeholder(dtype, shape) for shape, dtype in zip(shapes, types)
    ]
    self.batch_in = tf.tuple(placeholders)

    batch_data, self.batch_tuple_sizes = train_utils.batch_to_session_data(
        self.batch_in, session_data
    )

    text_input = self._combine_sparse_dense_features(
        batch_data["text_features"],
        batch_data["text_mask"][0],
        "text",
    )
    label_input = self._combine_sparse_dense_features(
        batch_data["label_features"],
        batch_data["label_mask"][0],
        "label",
    )

    # Evaluate the current label embeddings once in the session and freeze
    # the values into the graph as a constant.
    frozen_label_values = self.session.run(self.all_labels_embed)
    self.all_labels_embed = tf.constant(frozen_label_values)

    text_fnn_name = "text_label" if self.share_hidden_layers else "text"
    self.message_embed = self._create_tf_embed_fnn(
        text_input,
        self.hidden_layer_sizes["text"],
        fnn_name=text_fnn_name,
        embed_name="text",
    )

    # Similarity of each message embedding against the frozen label
    # embeddings; a new axis is inserted at position 1 of the messages and
    # position 0 of the labels.
    message_for_all = self.message_embed[:, tf.newaxis, :]
    all_labels_expanded = self.all_labels_embed[tf.newaxis, :, :]
    self.sim_all = train_utils.tf_raw_sim(
        message_for_all, all_labels_expanded, None
    )

    label_fnn_name = "text_label" if self.share_hidden_layers else "label"
    self.label_embed = self._create_tf_embed_fnn(
        label_input,
        self.hidden_layer_sizes["label"],
        fnn_name=label_fnn_name,
        embed_name="label",
    )

    # Similarity against the labels supplied with the batch itself.
    message_for_batch = self.message_embed[:, tf.newaxis, :]
    self.sim = train_utils.tf_raw_sim(
        message_for_batch, self.label_embed, None
    )

    return train_utils.confidence_from_sim(self.sim_all, self.similarity_type)