def persist(self, path: Text) -> None:
    """Persist the trained policy into the directory ``path``.

    Writes, in this order:

    * the featurizer, via ``self.featurizer.persist(path)``;
    * ``embedding_policy.json`` holding ``priority``, ``loss_type`` and
      ``ranking_length``;
    * the TensorFlow checkpoint ``tensorflow_embedding.ckpt`` of
      ``self.session``, after registering the named inference tensors with
      ``train_utils.persist_tensor``;
    * ``tensorflow_embedding.ckpt.tf_config.pkl`` with the pickled
      ``self._tf_config``.

    If no model has been trained yet (``self.session is None``) nothing is
    written and a debug message is logged instead.

    Args:
        path: Target directory. Its creation is requested up front via
            ``rasa.utils.io.create_directory_for_file`` on the checkpoint path.
    """
    if self.session is None:
        logger.debug(
            "Method `persist(...)` was called "
            "without a trained model present. "
            "Nothing to persist then!"
        )
        return

    file_name = "tensorflow_embedding.ckpt"
    checkpoint = os.path.join(path, file_name)
    # Create the directory before anything is written into it, so writing
    # the metadata JSON below does not rely on `featurizer.persist` having
    # created `path` as a side effect.
    rasa.utils.io.create_directory_for_file(checkpoint)

    self.featurizer.persist(path)

    meta = {
        "priority": self.priority,
        "loss_type": self.loss_type,
        "ranking_length": self.ranking_length,
    }
    meta_file = os.path.join(path, "embedding_policy.json")
    rasa.utils.io.dump_obj_as_json_to_file(meta_file, meta)

    with self.graph.as_default():
        # Register each tensor under a fixed name; presumably the loading
        # code looks them up by these exact names — TODO confirm against load().
        tensors_to_persist = (
            ("user_placeholder", self.a_in),
            ("bot_placeholder", self.b_in),
            ("similarity_all", self.sim_all),
            ("pred_confidence", self.pred_confidence),
            ("similarity", self.sim),
            ("dial_embed", self.dial_embed),
            ("bot_embed", self.bot_embed),
            ("all_bot_embed", self.all_bot_embed),
            ("attention_weights", self.attention_weights),
        )
        for tensor_name, tensor in tensors_to_persist:
            train_utils.persist_tensor(tensor_name, tensor, self.graph)

        saver = tf.train.Saver()
        saver.save(self.session, checkpoint)

    with open(os.path.join(path, file_name + ".tf_config.pkl"), "wb") as f:
        pickle.dump(self._tf_config, f)
def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
    """Persist this model into the directory ``model_dir``.

    Files written (each named ``file_name`` plus a suffix, inside
    ``model_dir``):

    * ``.ckpt`` - TensorFlow checkpoint of ``self.session``, saved after the
      named inference tensors are registered via ``train_utils.persist_tensor``;
    * ``.inv_label_dict.pkl`` - pickled ``self.inverted_label_dict``;
    * ``.tf_config.pkl`` - pickled ``self._tf_config``;
    * ``.batch_tuple_sizes.pkl`` - pickled ``self.batch_tuple_sizes``.

    Args:
        file_name: Base name shared by every persisted file.
        model_dir: Target directory; created (with parents) if missing.

    Returns:
        The metadata necessary to load the model again: ``{"file": file_name}``,
        or ``{"file": None}`` when there is no trained session, in which case
        nothing is written.
    """
    if self.session is None:
        return {"file": None}

    checkpoint = os.path.join(model_dir, file_name + ".ckpt")
    checkpoint_dir = os.path.dirname(checkpoint)
    # `exist_ok=True` replaces the manual EEXIST check. An empty dirname means
    # the current working directory, which needs no creation
    # (`os.makedirs("")` would raise).
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    with self.graph.as_default():
        # Register each tensor under a fixed name; presumably the loading
        # code looks them up by these exact names — TODO confirm against load().
        tensors_to_persist = (
            ("batch_placeholder", self.batch_in),
            ("similarity_all", self.sim_all),
            ("pred_confidence", self.pred_confidence),
            ("similarity", self.sim),
            ("message_embed", self.message_embed),
            ("label_embed", self.label_embed),
            ("all_labels_embed", self.all_labels_embed),
        )
        for tensor_name, tensor in tensors_to_persist:
            train_utils.persist_tensor(tensor_name, tensor, self.graph)

        saver = tf.train.Saver()
        saver.save(self.session, checkpoint)

    # Auxiliary objects stored next to the checkpoint, one pickle file each.
    pickled_objects = (
        (".inv_label_dict.pkl", self.inverted_label_dict),
        (".tf_config.pkl", self._tf_config),
        (".batch_tuple_sizes.pkl", self.batch_tuple_sizes),
    )
    for suffix, obj in pickled_objects:
        with open(os.path.join(model_dir, file_name + suffix), "wb") as f:
            pickle.dump(obj, f)

    return {"file": file_name}