コード例 #1
0
    def _build_tf_train_graph(
            self,
            session_data: SessionDataType) -> Tuple["tf.Tensor", "tf.Tensor"]:
        """Build the training graph and return the result of the loss/accuracy helper.

        A batch is pulled from ``self._iterator``. Three inputs are embedded:
        its text features, its label features, and the features of every
        encoded label in ``self._label_data``. The embeddings and the raw
        combined label features are then passed to
        ``train_utils.calculate_loss_acc``.

        Args:
            session_data: session data describing the layout used to unpack
                the iterator's batch into named features.

        Returns:
            Whatever ``train_utils.calculate_loss_acc`` returns. The annotation
            declares a pair of tensors; presumably loss and accuracy, which
            should be confirmed in ``train_utils``.

        Side effects:
            Sets ``self.batch_in``, ``self.message_embed``,
            ``self.label_embed`` and ``self.all_labels_embed``.
        """
        # The next training batch comes straight from the dataset iterator.
        self.batch_in = self._iterator.get_next()
        # Every encoded label, converted into the same batch layout.
        all_labels_batch = train_utils.prepare_batch(self._label_data)

        # Unpack the flat batch tuples back into named sparse/dense features.
        batch_features, _ = train_utils.batch_to_session_data(
            self.batch_in, session_data)
        all_label_features, _ = train_utils.batch_to_session_data(
            all_labels_batch, self._label_data)

        text_in = self._combine_sparse_dense_features(
            batch_features["text_features"],
            batch_features["text_mask"][0],
            "text",
        )
        label_in = self._combine_sparse_dense_features(
            batch_features["label_features"],
            batch_features["label_mask"][0],
            "label",
        )
        all_labels_in = self._combine_sparse_dense_features(
            all_label_features["label_features"],
            all_label_features["label_mask"][0],
            "label",
        )

        # With shared hidden layers, text and labels go through one network
        # named "text_label"; otherwise each side gets its own network.
        text_fnn = "text_label" if self.share_hidden_layers else "text"
        label_fnn = "text_label" if self.share_hidden_layers else "label"

        self.message_embed = self._create_tf_embed_fnn(
            text_in,
            self.hidden_layer_sizes["text"],
            fnn_name=text_fnn,
            embed_name="text",
        )
        self.label_embed = self._create_tf_embed_fnn(
            label_in,
            self.hidden_layer_sizes["label"],
            fnn_name=label_fnn,
            embed_name="label",
        )
        # Embeds all labels with the same label network settings as the
        # batch labels above.
        self.all_labels_embed = self._create_tf_embed_fnn(
            all_labels_in,
            self.hidden_layer_sizes["label"],
            fnn_name=label_fnn,
            embed_name="label",
        )

        # The arguments are positional; their order must match the
        # train_utils.calculate_loss_acc signature.
        return train_utils.calculate_loss_acc(
            self.message_embed,
            self.label_embed,
            label_in,
            self.all_labels_embed,
            all_labels_in,
            self.num_neg,
            None,  # NOTE(review): presumably an optional mask, unused here; confirm
            self.loss_type,
            self.mu_pos,
            self.mu_neg,
            self.use_max_sim_neg,
            self.C_emb,
            self.scale_loss,
        )
コード例 #2
0
    def _build_tf_pred_graph(self,
                             session_data: "SessionDataType") -> "tf.Tensor":
        """Build the placeholder-fed prediction graph and return confidences.

        One placeholder is created for each (shape, dtype) pair reported by
        ``train_utils.get_shapes_types``. The text and label features are
        embedded, and the text embeddings are compared against a frozen
        snapshot of all label embeddings.

        Args:
            session_data: session data whose shapes/types define the
                placeholders and whose layout is used to unpack them.

        Returns:
            ``train_utils.confidence_from_sim`` applied to the similarities
            between each message and all labels (``self.sim_all``).

        Side effects:
            Sets ``self.batch_in``, ``self.batch_tuple_sizes``,
            ``self.message_embed``, ``self.sim_all``, ``self.label_embed`` and
            ``self.sim``. Also replaces ``self.all_labels_embed`` with a
            constant holding its value as evaluated in ``self.session``. That
            attribute must therefore already be a tensor the session can run,
            e.g. as set by ``_build_tf_train_graph``.
        """
        shapes, types = train_utils.get_shapes_types(session_data)

        # Note the argument order: tf.placeholder takes the dtype first,
        # then the shape.
        self.batch_in = tf.tuple(
            [tf.placeholder(dtype, shape) for shape, dtype in zip(shapes, types)])

        batch_features, self.batch_tuple_sizes = (
            train_utils.batch_to_session_data(self.batch_in, session_data))

        text_in = self._combine_sparse_dense_features(
            batch_features["text_features"],
            batch_features["text_mask"][0],
            "text",
        )
        label_in = self._combine_sparse_dense_features(
            batch_features["label_features"],
            batch_features["label_mask"][0],
            "label",
        )

        # Evaluate the current label embeddings once and freeze them.
        # The prediction graph then uses these fixed values instead of the
        # original tensor.
        self.all_labels_embed = tf.constant(
            self.session.run(self.all_labels_embed))

        # Shared hidden layers mean a single "text_label" network.
        text_fnn = "text_label" if self.share_hidden_layers else "text"
        label_fnn = "text_label" if self.share_hidden_layers else "label"

        self.message_embed = self._create_tf_embed_fnn(
            text_in,
            self.hidden_layer_sizes["text"],
            fnn_name=text_fnn,
            embed_name="text",
        )

        # Insert axes so every message is compared with every label.
        self.sim_all = train_utils.tf_raw_sim(
            self.message_embed[:, tf.newaxis, :],
            self.all_labels_embed[tf.newaxis, :, :],
            None,
        )

        self.label_embed = self._create_tf_embed_fnn(
            label_in,
            self.hidden_layer_sizes["label"],
            fnn_name=label_fnn,
            embed_name="label",
        )

        self.sim = train_utils.tf_raw_sim(
            self.message_embed[:, tf.newaxis, :],
            self.label_embed,
            None,
        )

        return train_utils.confidence_from_sim(self.sim_all,
                                               self.similarity_type)