コード例 #1
0
ファイル: summary.py プロジェクト: tech-pi/dxlearn
 def summary(self, feeds=None):
     from ..scalar import current_step
     if self.as_tensor() is None:
         return
     value = get_default_session().run(self.as_tensor(),
                                       self.get_feed_dict(feeds))
     self.nodes['summary_writer'].add_summary(value, current_step())
コード例 #2
0
 def summary(self, feeds=None):
     from ..scalar import current_step
     mrf = self._get_multiple_run_feeds(feeds)
     feed_dict = self.get_feed_dict(feeds)
     feed_dict.update(mrf)
     value = get_default_session().run(self.nodes[SummaryKeys.MERGED],
                                       feed_dict=feed_dict)
     self.nodes['summary_writer'].add_summary(value, current_step())
コード例 #3
0
ファイル: saver.py プロジェクト: Hong-Xiang/dxl
 def save(self, feeds):
     from ..scalar import global_step
     if self._saver is None:
         self._saver = tf.train.Saver()
     from dxpy.learn.session import get_default_session
     sess = get_default_session()
     step = sess.run(global_step())
     print("[SAVE] model to: {}.".format(self._model_path()))
     self._saver.save(sess, self._model_path(), global_step=step)
コード例 #4
0
 def _get_multiple_run_feeds(self, feeds=None):
     result = dict()
     for i in range(self.nb_max_runs):
         current_result = get_default_session().run(
             self.inputs, self.get_feed_dict(feeds))
         for k in self.inputs:
             if k in self.nb_runs and self.nb_runs[k] > i:
                 result[self.multi_runs[k][i]] = current_result[k]
     return result
コード例 #5
0
ファイル: saver.py プロジェクト: Hong-Xiang/dxl
 def load(self, feeds):
     from ..scalar import global_global_step
     import sys
     if self._saver is None:
         self._saver = tf.train.Saver()
     from dxpy.learn.session import get_default_session
     sess = get_default_session()
     path_load, flag = self.__resolve_path_load(feeds)
     if flag is False:
         if isinstance(path_load, int):
             msg = "[ERROR][LOAD] Save for given step {} not found. Skip restore."
             print(msg.format(path_load), file=sys.stderr)
             return
         else:
             msg = "[ERROR][LOAD] Checkpoint file {} not found. Skip restore."
             print(msg.format(path_load), file=sys.stderr)
             return
     print("[LOAD] model from: {}.".format(path_load))
     self._saver.restore(sess, path_load)
コード例 #6
0
ファイル: summary.py プロジェクト: tech-pi/dxlearn
 def __create_writer(self, feeds):
     self.register_node(
         'summary_writer',
         tf.summary.FileWriter(self.param('path', feeds),
                               get_default_session().graph))
コード例 #7
0
 def set_value(self, feeds):
     from dxpy.learn.session import get_default_session
     get_default_session().run(self.assign_op,
                               feed_dict={self.nodes['new_value']: feeds})
コード例 #8
0
def get_value():
    from dxpy.learn.session import get_default_session
    return get_default_session().run(global_step())
コード例 #9
0
ファイル: trainer.py プロジェクト: Hong-Xiang/dxl
 def _train(self, feeds):
     sess = get_default_session()
     sess.run(self.as_tensor(), feed_dict=feeds)
コード例 #10
0
ファイル: base.py プロジェクト: Hong-Xiang/dxl
 def get_value(self):
     from dxpy.learn.session import get_default_session
     return get_default_session().run(self.as_tensor())
コード例 #11
0
 def session(self):
     from dxpy.learn.session import get_default_session
     return get_default_session()
コード例 #12
0
def get_value():
    from dxpy.learn.session import get_default_session
    return get_default_session().run(keep_prob())