Example #1
    def persist(self, path: Text) -> None:
        """Persists the policy to a storage."""

        if self.session is None:
            logger.debug("Method `persist(...)` was called "
                         "without a trained model present. "
                         "Nothing to persist then!")
            return

        self.featurizer.persist(path)

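        # hyperparameters that must survive a save/load round trip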
        meta = {
            "priority": self.priority,
            "loss_type": self.loss_type,
            "ranking_length": self.ranking_length,
        }

        meta_file = os.path.join(path, "embedding_policy.json")
        rasa.utils.io.dump_obj_as_json_to_file(meta_file, meta)

        file_name = "tensorflow_embedding.ckpt"
        checkpoint = os.path.join(path, file_name)
        rasa.utils.io.create_directory_for_file(checkpoint)

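        # register each tensor under a fixed name so it can be looked
        # up again when the model is loaded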
        with self.graph.as_default():
            train_utils.persist_tensor("user_placeholder", self.a_in,
                                       self.graph)
            train_utils.persist_tensor("bot_placeholder", self.b_in,
                                       self.graph)

            train_utils.persist_tensor("similarity_all", self.sim_all,
                                       self.graph)
            train_utils.persist_tensor("pred_confidence", self.pred_confidence,
                                       self.graph)
            train_utils.persist_tensor("similarity", self.sim, self.graph)

            train_utils.persist_tensor("dial_embed", self.dial_embed,
                                       self.graph)
            train_utils.persist_tensor("bot_embed", self.bot_embed, self.graph)
            train_utils.persist_tensor("all_bot_embed", self.all_bot_embed,
                                       self.graph)

            train_utils.persist_tensor("attention_weights",
                                       self.attention_weights, self.graph)

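            # write every variable in the session to the checkpoint file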
            saver = tf.train.Saver()
            saver.save(self.session, checkpoint)

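        # the session configuration is pickled next to the checkpoint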
        with open(os.path.join(path, file_name + ".tf_config.pkl"), "wb") as f:
            pickle.dump(self._tf_config, f)
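
For orientation, here is a minimal sketch of the matching load path. It assumes, as this snippet (apparently `EmbeddingPolicy.persist` from Rasa 1.x) suggests, that `train_utils.persist_tensor` stores each tensor in a TensorFlow graph collection under the given name, so that `tf.get_collection` can recover it after the graph is re-imported. The function name `load_policy` is illustrative; this is not Rasa's actual loading code:

    import os
    import pickle
    from typing import Text

    import tensorflow as tf

    def load_policy(path: Text):
        """Restore a policy persisted by the method above (sketch only)."""
        checkpoint = os.path.join(path, "tensorflow_embedding.ckpt")

        # read back the pickled session configuration
        with open(checkpoint + ".tf_config.pkl", "rb") as f:
            tf_config = pickle.load(f)

        graph = tf.Graph()
        with graph.as_default():
            session = tf.Session(config=tf_config)
            # import the graph definition written next to the checkpoint ...
            saver = tf.train.import_meta_graph(checkpoint + ".meta")
            # ... and restore the trained variable values into the session
            saver.restore(session, checkpoint)

            # assumed: persist_tensor() stored each tensor in a graph
            # collection, so the same names retrieve them here
            a_in = tf.get_collection("user_placeholder")[0]
            b_in = tf.get_collection("bot_placeholder")[0]
            pred_confidence = tf.get_collection("pred_confidence")[0]

        return session, graph, a_in, b_in, pred_confidence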
Example #2
    def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
        """Persist this model into the passed directory.

        Return the metadata necessary to load the model again.
        """

        if self.session is None:
            return {"file": None}

        checkpoint = os.path.join(model_dir, file_name + ".ckpt")

        # create the checkpoint directory; it may already exist
        os.makedirs(os.path.dirname(checkpoint), exist_ok=True)

        with self.graph.as_default():
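            # register the tensors needed at prediction time under fixed names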
            train_utils.persist_tensor("batch_placeholder", self.batch_in,
                                       self.graph)

            train_utils.persist_tensor("similarity_all", self.sim_all,
                                       self.graph)
            train_utils.persist_tensor("pred_confidence", self.pred_confidence,
                                       self.graph)
            train_utils.persist_tensor("similarity", self.sim, self.graph)

            train_utils.persist_tensor("message_embed", self.message_embed,
                                       self.graph)
            train_utils.persist_tensor("label_embed", self.label_embed,
                                       self.graph)
            train_utils.persist_tensor("all_labels_embed",
                                       self.all_labels_embed, self.graph)

            saver = tf.train.Saver()
            saver.save(self.session, checkpoint)

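        # pickle the remaining artefacts needed at load time: the
        # index-to-label mapping, the session config and the batch tuple sizes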
        with open(os.path.join(model_dir, file_name + ".inv_label_dict.pkl"),
                  "wb") as f:
            pickle.dump(self.inverted_label_dict, f)

        with open(os.path.join(model_dir, file_name + ".tf_config.pkl"),
                  "wb") as f:
            pickle.dump(self._tf_config, f)

        with open(
                os.path.join(model_dir, file_name + ".batch_tuple_sizes.pkl"),
                "wb") as f:
            pickle.dump(self.batch_tuple_sizes, f)

        return {"file": file_name}
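
The returned dict is the component's metadata, which Rasa hands back to the corresponding `load` method. Below is a minimal sketch of reading the pickled artefacts back from that metadata, relying only on the file layout established above; the helper names `load_pickled_parts` and `_unpickle` are illustrative, not part of Rasa's API:

    import os
    import pickle
    from typing import Any, Dict, Text

    def load_pickled_parts(meta: Dict[Text, Any], model_dir: Text):
        """Read back the artefacts written by persist() (sketch only)."""
        file_name = meta["file"]
        if file_name is None:
            # persist() returned {"file": None}: no trained session existed
            return None

        def _unpickle(suffix: Text):
            with open(os.path.join(model_dir, file_name + suffix), "rb") as f:
                return pickle.load(f)

        inverted_label_dict = _unpickle(".inv_label_dict.pkl")
        tf_config = _unpickle(".tf_config.pkl")
        batch_tuple_sizes = _unpickle(".batch_tuple_sizes.pkl")
        return inverted_label_dict, tf_config, batch_tuple_sizes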