Esempio n. 1
0
 def before_run(self, run_context):
     """Flag whether this step is a summary step and request a full trace.

     Args:
         run_context: Not used by this method.

     Returns:
         SessionRunArgs with an empty fetches dict and RunOptions set to
         FULL_TRACE.
     """
     if self._next_step is None:
         request_summary = True
     else:
         request_summary = self._timer.should_trigger_for_step(
             self._next_step)
     self._request_summary = request_summary
     # NOTE(review): FULL_TRACE is requested on every step, independent of
     # self._request_summary, and no global-step fetch is made. Confirm this
     # is intended, since full tracing can add per-step overhead.
     trace_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
     return SessionRunArgs({}, options=trace_options)
Esempio n. 2
0
 def before_run(self, _):
     """Request the global step and the trainable values for this run.

     Returns:
         SessionRunArgs whose fetches dict maps ``"global_step"`` to the
         GLOBAL_STEP collection (the list returned by ``tf.get_collection``)
         and ``"trainables"`` to ``self.trainables``.
     """
     step_collection = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
     requested = dict(global_step=step_collection,
                      trainables=self.trainables)
     return SessionRunArgs(fetches=requested)
Esempio n. 3
0
 def before_run(self, _):
     """Request the global step, plus ``self.tensors`` in TRAIN mode.

     Returns:
         SessionRunArgs whose fetches dict always has ``"global_step"``. In
         ``ModeKeys.TRAIN`` mode the entries of ``self.tensors`` are merged in
         afterwards, so a ``"global_step"`` key in ``self.tensors`` would win.
     """
     step_fetch = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
     is_training = self.mode == ModeKeys.TRAIN
     extra_fetches = self.tensors if is_training else {}
     return SessionRunArgs(fetches={"global_step": step_fetch,
                                    **extra_fetches})
Esempio n. 4
0
 def before_run(self, _):
     """Request step, attention, target, label and prediction values.

     Returns:
         SessionRunArgs with a fetches dict keyed by ``"global_step"`` (the
         GLOBAL_STEP collection), ``"attention"`` (the ``"ATTENTION"``
         collection), ``"targets"``, ``"labels"`` and ``"predictions"``.
     """
     global_step = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
     attention = tf.get_collection("ATTENTION")
     requested = {
         "global_step": global_step,
         "attention": attention,
         "targets": self.targets,
         "labels": self.labels,
         "predictions": self.predictions,
     }
     return SessionRunArgs(fetches=requested)
Esempio n. 5
0
 def before_run(self, _):
     """Request the global step, and the summary op on summary steps.

     The upcoming step is taken as ``self._global_step + 1``. The summary op
     is added when that step equals 1 or is a multiple of
     ``self.summary_freq``.

     Returns:
         SessionRunArgs with ``"global_step"`` and, on summary steps,
         ``"summary"`` fetches.
     """
     next_step = self._global_step + 1
     fetches = {"global_step": tf.get_collection(tf.GraphKeys.GLOBAL_STEP)}
     if next_step % self.summary_freq == 0 or next_step == 1:
         fetches["summary"] = self.summary_op
     return SessionRunArgs(fetches=fetches)
Esempio n. 6
0
 def before_run(self, _):
     """Request the global step, with a full trace when metadata is due.

     When ``self.freq`` is a string, tracing happens only if it equals
     ``"once"`` and ``self.mode`` is ``ModeKeys.EVAL``. Otherwise
     ``self.freq`` acts as a step interval. The upcoming step is
     ``self.counter + 1``, and tracing happens when that step is 1 or a
     multiple of ``self.freq``.

     Returns:
         SessionRunArgs fetching ``"global_step"``. Its options are
         FULL_TRACE RunOptions when tracing, otherwise None.
     """
     upcoming = self.counter + 1
     if isinstance(self.freq, str):
         wants_trace = self.mode == ModeKeys.EVAL and self.freq == "once"
     else:
         wants_trace = upcoming == 1 or upcoming % self.freq == 0
     fetches = {"global_step": tf.get_collection(tf.GraphKeys.GLOBAL_STEP)}
     run_options = None
     if wants_trace:
         run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)  # pylint: disable=E1101
     return SessionRunArgs(fetches=fetches, options=run_options)
Esempio n. 7
0
 def before_run(self, run_context):
     """Request ``self._global_step_tensor`` as the only fetch.

     Args:
         run_context: Not used by this method.

     Returns:
         SessionRunArgs whose fetches is the global step tensor itself.
     """
     step_tensor = self._global_step_tensor
     return SessionRunArgs(fetches=step_tensor)
Esempio n. 8
0
 def before_run(self, _):
     """Request only the global step (the GLOBAL_STEP collection).

     Returns:
         SessionRunArgs whose fetches dict has the single key
         ``"global_step"``.
     """
     step_fetch = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
     return SessionRunArgs(fetches=dict(global_step=step_fetch))