Example #1
  def _write_custom_summaries(self, step, logs=None):
    """Writes metrics out as custom scalar summaries.

    Arguments:
        step: the global step to use for TensorBoard.
        logs: dict. Keys are scalar summary names, values are
            NumPy scalars.

    """
    logs = logs or {}
    if context.executing_eagerly():
      # use v2 summary ops
      with self.writer.as_default(), summary_ops_v2.always_record_summaries():
        for name, value in logs.items():
          if isinstance(value, np.ndarray):
            value = value.item()
          summary_ops_v2.scalar(name, value, step=step)
    else:
      # use FileWriter from v1 summary
      for name, value in logs.items():
        if isinstance(value, np.ndarray):
          value = value.item()
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, step)
    self.writer.flush()
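
The eager/graph split handled above disappears in TF 2.x, where the public tf.summary API covers eager execution directly. A minimal sketch of the same scalar-dict logging, assuming placeholder values for the logdir and the logs dict:

  import tensorflow as tf

  # Minimal TF 2.x sketch: write a dict of scalar metrics at one step.
  # '/tmp/logdir' and the logs dict are illustrative placeholders.
  writer = tf.summary.create_file_writer('/tmp/logdir')
  logs = {'loss': 0.25, 'accuracy': 0.9}
  with writer.as_default():
    for name, value in logs.items():
      tf.summary.scalar(name, value, step=1)
  writer.flush()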
Example #2
 def testWriterInitAndClose(self):
   logdir = self.get_temp_dir()
   get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
   with summary_ops.always_record_summaries():
     writer = summary_ops.create_file_writer(
         logdir, max_queue=100, flush_millis=1000000)
     self.assertEqual(1, get_total())  # file_version Event
     # Calling init() again while writer is open has no effect
     writer.init()
     self.assertEqual(1, get_total())
     try:
       # Not using .as_default() to avoid implicit flush when exiting
       writer.set_as_default()
       summary_ops.scalar('one', 1.0, step=1)
       self.assertEqual(1, get_total())
       # Calling .close() should do an implicit flush
       writer.close()
       self.assertEqual(2, get_total())
       # Calling init() on a closed writer should start a new file
       time.sleep(1.1)  # Ensure filename has a different timestamp
       writer.init()
       files = sorted(gfile.Glob(os.path.join(logdir, '*tfevents*')))
       self.assertEqual(2, len(files))
       get_total = lambda: len(summary_test_util.events_from_file(files[1]))
       self.assertEqual(1, get_total())  # file_version Event
       summary_ops.scalar('two', 2.0, step=2)
       writer.close()
       self.assertEqual(2, get_total())
     finally:
       # Clean up by resetting default writer
       summary_ops.create_file_writer(None).set_as_default()
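
The init/flush/close lifecycle exercised here carries over to the public TF 2.x tf.summary.SummaryWriter, which exposes the same methods. A hedged sketch of the eager lifecycle, with a placeholder logdir:

  import tensorflow as tf

  writer = tf.summary.create_file_writer('/tmp/logdir', max_queue=100,
                                         flush_millis=1000000)
  with writer.as_default():
    tf.summary.scalar('one', 1.0, step=1)  # buffered until a flush
  writer.close()  # close() performs an implicit flush
  writer.init()   # mirroring the test above, init() after close()
                  # starts a new event file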
Example #3
 def testWriterInitAndClose(self):
   logdir = self.get_temp_dir()
   with summary_ops.always_record_summaries():
     writer = summary_ops.create_file_writer(
         logdir, max_queue=100, flush_millis=1000000)
     with writer.as_default():
       summary_ops.scalar('one', 1.0, step=1)
   with self.cached_session() as sess:
     sess.run(summary_ops.summary_writer_initializer_op())
     get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
     self.assertEqual(1, get_total())  # file_version Event
     # Running init() again while writer is open has no effect
     sess.run(writer.init())
     self.assertEqual(1, get_total())
     sess.run(summary_ops.all_summary_ops())
     self.assertEqual(1, get_total())
     # Running close() should do an implicit flush
     sess.run(writer.close())
     self.assertEqual(2, get_total())
     # Running init() on a closed writer should start a new file
     time.sleep(1.1)  # Ensure filename has a different timestamp
     sess.run(writer.init())
     sess.run(summary_ops.all_summary_ops())
     sess.run(writer.close())
     files = sorted(gfile.Glob(os.path.join(logdir, '*tfevents*')))
     self.assertEqual(2, len(files))
     self.assertEqual(2, len(summary_test_util.events_from_file(files[1])))
Example #4
 def testEagerMemory(self):
   training_util.get_or_create_global_step()
   logdir = self.get_temp_dir()
   with summary_ops.create_file_writer(
       logdir, max_queue=0,
       name='t0').as_default(), summary_ops.always_record_summaries():
     summary_ops.generic('tensor', 1, '')
     summary_ops.scalar('scalar', 2.0)
     summary_ops.histogram('histogram', [1.0])
     summary_ops.image('image', [[[[1.0]]]])
     summary_ops.audio('audio', [[1.0]], 1.0, 1)
Example #5
 def testSummaryName(self):
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(logdir, max_queue=0)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.scalar('scalar', 2.0, step=1)
   with self.cached_session() as sess:
     sess.run(summary_ops.summary_writer_initializer_op())
     sess.run(summary_ops.all_summary_ops())
   events = summary_test_util.events_from_logdir(logdir)
   self.assertEqual(2, len(events))
   self.assertEqual('scalar', events[1].summary.value[0].tag)
Example #6
  def testSummaryGlobalStep(self):
    step = training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t2').as_default(), summary_ops.always_record_summaries():

      summary_ops.scalar('scalar', 2.0, step=step)

      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].tag, 'scalar')
Example #7
 def testMaxQueue(self):
   logs = tempfile.mkdtemp()
   with summary_ops.create_file_writer(
       logs, max_queue=1, flush_millis=999999,
       name='lol').as_default(), summary_ops.always_record_summaries():
     get_total = lambda: len(summary_test_util.events_from_logdir(logs))
     # Note: First tf.Event is always file_version.
     self.assertEqual(1, get_total())
     summary_ops.scalar('scalar', 2.0, step=1)
     self.assertEqual(1, get_total())
     # Should flush after second summary since max_queue = 1
     summary_ops.scalar('scalar', 2.0, step=2)
     self.assertEqual(3, get_total())
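
The property under test, that max_queue bounds how many events are buffered before an automatic flush, behaves the same way with the public TF 2.x API. A small sketch with a placeholder logdir:

  import tensorflow as tf

  # max_queue=1 holds one event; flush_millis is large so only queue
  # size can trigger the flush.
  writer = tf.summary.create_file_writer('/tmp/logdir', max_queue=1,
                                         flush_millis=999999)
  with writer.as_default():
    tf.summary.scalar('scalar', 2.0, step=1)  # buffered
    tf.summary.scalar('scalar', 2.0, step=2)  # queue exceeded, both flushed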
Example #8
 def testSummaryGlobalStep(self):
   training_util.get_or_create_global_step()
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(logdir, max_queue=0)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.scalar('scalar', 2.0)
   with self.cached_session() as sess:
     sess.run(variables.global_variables_initializer())
     sess.run(summary_ops.summary_writer_initializer_op())
     step, _ = sess.run(
         [training_util.get_global_step(), summary_ops.all_summary_ops()])
   events = summary_test_util.events_from_logdir(logdir)
   self.assertEqual(2, len(events))
   self.assertEqual(step, events[1].step)
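
Where this test leans on the v1 global step, TF 2.x lets a default step be registered once so later calls can omit step=. A sketch using tf.summary.experimental.set_step; the variable name and logdir are illustrative:

  import tensorflow as tf

  global_step = tf.Variable(0, dtype=tf.int64)
  tf.summary.experimental.set_step(global_step)  # default for later calls
  writer = tf.summary.create_file_writer('/tmp/logdir')
  with writer.as_default():
    tf.summary.scalar('scalar', 2.0)  # picks up the registered step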
Example #9
 def testSummaryOps(self):
   training_util.get_or_create_global_step()
   logdir = tempfile.mkdtemp()
   with summary_ops.create_file_writer(
       logdir, max_queue=0,
       name='t0').as_default(), summary_ops.always_record_summaries():
     summary_ops.generic('tensor', 1, '')
     summary_ops.scalar('scalar', 2.0)
     summary_ops.histogram('histogram', [1.0])
     summary_ops.image('image', [[[[1.0]]]])
     summary_ops.audio('audio', [[1.0]], 1.0, 1)
     # The working condition of the ops is tested in the C++ test so we just
     # test here that we're calling them correctly.
     self.assertTrue(gfile.Exists(logdir))
Example #10
 def testWriterFlush(self):
   logdir = self.get_temp_dir()
   with summary_ops.always_record_summaries():
     writer = summary_ops.create_file_writer(
         logdir, max_queue=100, flush_millis=1000000)
     with writer.as_default():
       summary_ops.scalar('one', 1.0, step=1)
   with self.cached_session() as sess:
     sess.run(summary_ops.summary_writer_initializer_op())
     get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
     self.assertEqual(1, get_total())  # file_version Event
     sess.run(summary_ops.all_summary_ops())
     self.assertEqual(1, get_total())
     sess.run(writer.flush())
     self.assertEqual(2, get_total())
Example #11
 def testSummaryOps(self):
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(logdir, max_queue=0)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.generic('tensor', 1, step=1)
     summary_ops.scalar('scalar', 2.0, step=1)
     summary_ops.histogram('histogram', [1.0], step=1)
     summary_ops.image('image', [[[[1.0]]]], step=1)
     summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)
   with self.cached_session() as sess:
     sess.run(summary_ops.summary_writer_initializer_op())
     sess.run(summary_ops.all_summary_ops())
   # The working condition of the ops is tested in the C++ test so we just
   # test here that we're calling them correctly.
   self.assertTrue(gfile.Exists(logdir))
Example #12
 def testWriterFlush(self):
   logdir = self.get_temp_dir()
   get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
   with summary_ops.always_record_summaries():
     writer = summary_ops.create_file_writer(
         logdir, max_queue=100, flush_millis=1000000)
     self.assertEqual(1, get_total())  # file_version Event
     with writer.as_default():
       summary_ops.scalar('one', 1.0, step=1)
       self.assertEqual(1, get_total())
       writer.flush()
       self.assertEqual(2, get_total())
       summary_ops.scalar('two', 2.0, step=2)
     # Exiting the "as_default()" should do an implicit flush of the "two" tag
     self.assertEqual(3, get_total())
Example #13
 def testDbURIOpen(self):
   tmpdb_path = os.path.join(self.get_temp_dir(), 'tmpDbURITest.sqlite')
   tmpdb_uri = six.moves.urllib_parse.urljoin("file:", tmpdb_path)
   tmpdb_writer = summary_ops.create_db_writer(
       tmpdb_uri,
       "experimentA",
       "run1",
       "user1")
   with summary_ops.always_record_summaries():
     with tmpdb_writer.as_default():
       summary_ops.scalar('t1', 2.0)
   tmpdb = sqlite3.connect(tmpdb_path)
   num = get_one(tmpdb, 'SELECT count(*) FROM Tags WHERE tag_name = "t1"')
   self.assertEqual(num, 1)
   tmpdb.close()
Example #14
 def testMaxQueue(self):
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(
       logdir, max_queue=1, flush_millis=999999)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.scalar('scalar', 2.0, step=1)
   with self.cached_session() as sess:
     sess.run(summary_ops.summary_writer_initializer_op())
     get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
     # Note: First tf.Event is always file_version.
     self.assertEqual(1, get_total())
     sess.run(summary_ops.all_summary_ops())
     self.assertEqual(1, get_total())
     # Should flush after second summary since max_queue = 1
     sess.run(summary_ops.all_summary_ops())
     self.assertEqual(3, get_total())
Example #15
  def testScalarSummaryNameScope(self):
    """Test record_summaries_every_n_global_steps and all_summaries()."""
    with ops.Graph().as_default(), self.cached_session() as sess:
      global_step = training_util.get_or_create_global_step()
      global_step.initializer.run()
      with ops.device('/cpu:0'):
        step_increment = state_ops.assign_add(global_step, 1)
      sess.run(step_increment)  # Increment global step from 0 to 1

      logdir = tempfile.mkdtemp()
      with summary_ops.create_file_writer(logdir, max_queue=0,
                                          name='t2').as_default():
        with summary_ops.record_summaries_every_n_global_steps(2):
          summary_ops.initialize()
          with ops.name_scope('scope'):
            summary_op = summary_ops.scalar('my_scalar', 2.0)

          # Neither of these should produce a summary because
          # global_step is 1 and "1 % 2 != 0"
          sess.run(summary_ops.all_summary_ops())
          sess.run(summary_op)
          events = summary_test_util.events_from_logdir(logdir)
          self.assertEqual(len(events), 1)

          # Increment global step from 1 to 2 and check that the summary
          # is now written
          sess.run(step_increment)
          sess.run(summary_ops.all_summary_ops())
          events = summary_test_util.events_from_logdir(logdir)
          self.assertEqual(len(events), 2)
          self.assertEqual(events[1].summary.value[0].tag, 'scope/my_scalar')
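
record_summaries_every_n_global_steps is the tf.contrib-era spelling of conditional recording; the TF 2.x equivalent is tf.summary.record_if, which takes a boolean or a callable. A sketch of recording only on even steps (names and logdir are illustrative):

  import tensorflow as tf

  step = tf.Variable(0, dtype=tf.int64)
  writer = tf.summary.create_file_writer('/tmp/logdir')
  with writer.as_default():
    # The callable is evaluated when each summary op runs.
    with tf.summary.record_if(lambda: step % 2 == 0):
      tf.summary.scalar('my_scalar', 2.0, step=step)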
Example #16
 def define_ops():
   result = []
   # TF 2.0 summary ops
   result.append(summary_ops.write('write', 1, step=0))
   result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb'))
   # TF 1.x tf.contrib.summary ops
   result.append(summary_ops.generic('tensor', 1, step=1))
   result.append(summary_ops.scalar('scalar', 2.0, step=1))
   result.append(summary_ops.histogram('histogram', [1.0], step=1))
   result.append(summary_ops.image('image', [[[[1.0]]]], step=1))
   result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1))
   return result
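
The first two entries exercise the low-level TF 2.0 write ops, whose public names are tf.summary.write and tf.summary.experimental.write_raw_pb. A sketch of calling them directly, with placeholder values:

  import tensorflow as tf

  writer = tf.summary.create_file_writer('/tmp/logdir')
  with writer.as_default():
    tf.summary.write('write', 1, step=0)               # any tensor value
    tf.summary.experimental.write_raw_pb(b'', step=0)  # serialized Summary bytes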
Example #17
 def testFlushFunction(self):
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(
       logdir, max_queue=999999, flush_millis=999999)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.scalar('scalar', 2.0, step=1)
     flush_op = summary_ops.flush()
   with self.cached_session() as sess:
     sess.run(summary_ops.summary_writer_initializer_op())
     get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
     # Note: First tf.Event is always file_version.
     self.assertEqual(1, get_total())
     sess.run(summary_ops.all_summary_ops())
     self.assertEqual(1, get_total())
     sess.run(flush_op)
     self.assertEqual(2, get_total())
     # Test "writer" parameter
     sess.run(summary_ops.all_summary_ops())
     sess.run(summary_ops.flush(writer=writer))
     self.assertEqual(3, get_total())
     sess.run(summary_ops.all_summary_ops())
     sess.run(summary_ops.flush(writer=writer._resource))  # pylint:disable=protected-access
     self.assertEqual(4, get_total())
Example #18
 def testFlushFunction(self):
   logs = tempfile.mkdtemp()
   writer = summary_ops.create_file_writer(
       logs, max_queue=999999, flush_millis=999999, name='lol')
   with writer.as_default(), summary_ops.always_record_summaries():
     get_total = lambda: len(summary_test_util.events_from_logdir(logs))
     # Note: First tf.Event is always file_version.
     self.assertEqual(1, get_total())
     summary_ops.scalar('scalar', 2.0, step=1)
     summary_ops.scalar('scalar', 2.0, step=2)
     self.assertEqual(1, get_total())
     summary_ops.flush()
     self.assertEqual(3, get_total())
     # Test "writer" parameter
     summary_ops.scalar('scalar', 2.0, step=3)
     summary_ops.flush(writer=writer)
     self.assertEqual(4, get_total())
     summary_ops.scalar('scalar', 2.0, step=4)
     summary_ops.flush(writer=writer._resource)  # pylint:disable=protected-access
     self.assertEqual(5, get_total())
Example #19
  def testSharedName(self):
    logdir = self.get_temp_dir()
    with summary_ops.always_record_summaries():
      # Create with default shared name (should match logdir)
      writer1 = summary_ops.create_file_writer(logdir)
      with writer1.as_default():
        summary_ops.scalar('one', 1.0, step=1)
      # Create with explicit logdir shared name (should be same resource/file)
      shared_name = 'logdir:' + logdir
      writer2 = summary_ops.create_file_writer(logdir, name=shared_name)
      with writer2.as_default():
        summary_ops.scalar('two', 2.0, step=2)
      # Create with different shared name (should be separate resource/file)
      writer3 = summary_ops.create_file_writer(logdir, name='other')
      with writer3.as_default():
        summary_ops.scalar('three', 3.0, step=3)

    with self.cached_session() as sess:
      # Run init ops across writers sequentially to avoid race condition.
      # TODO(nickfelt): fix race condition in resource manager lookup or create
      sess.run(writer1.init())
      sess.run(writer2.init())
      time.sleep(1.1)  # Ensure filename has a different timestamp
      sess.run(writer3.init())
      sess.run(summary_ops.all_summary_ops())
      sess.run([writer1.flush(), writer2.flush(), writer3.flush()])

    event_files = iter(sorted(gfile.Glob(os.path.join(logdir, '*tfevents*'))))

    # First file has tags "one" and "two"
    events = summary_test_util.events_from_file(next(event_files))
    self.assertEqual('brain.Event:2', events[0].file_version)
    tags = [e.summary.value[0].tag for e in events[1:]]
    self.assertItemsEqual(['one', 'two'], tags)

    # Second file has tag "three"
    events = summary_test_util.events_from_file(next(event_files))
    self.assertEqual('brain.Event:2', events[0].file_version)
    tags = [e.summary.value[0].tag for e in events[1:]]
    self.assertItemsEqual(['three'], tags)

    # No more files
    self.assertRaises(StopIteration, lambda: next(event_files))
Example #20
  def testSharedName(self):
    logdir = self.get_temp_dir()
    with summary_ops.always_record_summaries():
      # Create with default shared name (should match logdir)
      writer1 = summary_ops.create_file_writer(logdir)
      with writer1.as_default():
        summary_ops.scalar('one', 1.0, step=1)
        summary_ops.flush()
      # Create with explicit logdir shared name (should be same resource/file)
      shared_name = 'logdir:' + logdir
      writer2 = summary_ops.create_file_writer(logdir, name=shared_name)
      with writer2.as_default():
        summary_ops.scalar('two', 2.0, step=2)
        summary_ops.flush()
      # Create with different shared name (should be separate resource/file)
      time.sleep(1.1)  # Ensure filename has a different timestamp
      writer3 = summary_ops.create_file_writer(logdir, name='other')
      with writer3.as_default():
        summary_ops.scalar('three', 3.0, step=3)
        summary_ops.flush()

    event_files = iter(sorted(gfile.Glob(os.path.join(logdir, '*tfevents*'))))

    # First file has tags "one" and "two"
    events = iter(summary_test_util.events_from_file(next(event_files)))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual('one', next(events).summary.value[0].tag)
    self.assertEqual('two', next(events).summary.value[0].tag)
    self.assertRaises(StopIteration, lambda: next(events))

    # Second file has tag "three"
    events = iter(summary_test_util.events_from_file(next(event_files)))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual('three', next(events).summary.value[0].tag)
    self.assertRaises(StopIteration, lambda: next(events))

    # No more files
    self.assertRaises(StopIteration, lambda: next(event_files))
Example #21
    def testSharing_withExplicitSummaryFileWriters(self):
        logdir = self.get_temp_dir()
        with session.Session() as sess:
            # Initial file writer via FileWriter(session=?)
            writer1 = writer.FileWriter(session=sess, logdir=logdir)
            writer1.add_summary(self._createTaggedSummary("one"), 1)
            writer1.flush()

            # Next one via create_file_writer(), should use same file
            writer2 = summary_ops_v2.create_file_writer(logdir=logdir)
            with summary_ops_v2.always_record_summaries(), writer2.as_default(
            ):
                summary2 = summary_ops_v2.scalar("two", 2.0, step=2)
            sess.run(writer2.init())
            sess.run(summary2)
            sess.run(writer2.flush())

            # Next has different shared name, should be in separate file
            time.sleep(1.1)  # Ensure filename has a different timestamp
            writer3 = summary_ops_v2.create_file_writer(logdir=logdir,
                                                        name="other")
            with summary_ops_v2.always_record_summaries(), writer3.as_default(
            ):
                summary3 = summary_ops_v2.scalar("three", 3.0, step=3)
            sess.run(writer3.init())
            sess.run(summary3)
            sess.run(writer3.flush())

            # Next uses a second session, should be in separate file
            time.sleep(1.1)  # Ensure filename has a different timestamp
            with session.Session() as other_sess:
                writer4 = summary_ops_v2.create_file_writer(logdir=logdir)
                with summary_ops_v2.always_record_summaries(
                ), writer4.as_default():
                    summary4 = summary_ops_v2.scalar("four", 4.0, step=4)
                other_sess.run(writer4.init())
                other_sess.run(summary4)
                other_sess.run(writer4.flush())

                # Next via FileWriter(session=?) uses same second session, should be in
                # same separate file. (This checks sharing in the other direction)
                writer5 = writer.FileWriter(session=other_sess, logdir=logdir)
                writer5.add_summary(self._createTaggedSummary("five"), 5)
                writer5.flush()

            # One more via create_file_writer(), should use same file
            writer6 = summary_ops_v2.create_file_writer(logdir=logdir)
            with summary_ops_v2.always_record_summaries(), writer6.as_default(
            ):
                summary6 = summary_ops_v2.scalar("six", 6.0, step=6)
            sess.run(writer6.init())
            sess.run(summary6)
            sess.run(writer6.flush())

        event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))

        # First file should have tags "one", "two", and "six"
        events = summary_iterator.summary_iterator(next(event_paths))
        self.assertEqual("brain.Event:2", next(events).file_version)
        self.assertEqual("one", next(events).summary.value[0].tag)
        self.assertEqual("two", next(events).summary.value[0].tag)
        self.assertEqual("six", next(events).summary.value[0].tag)
        self.assertRaises(StopIteration, lambda: next(events))

        # Second file should have just "three"
        events = summary_iterator.summary_iterator(next(event_paths))
        self.assertEqual("brain.Event:2", next(events).file_version)
        self.assertEqual("three", next(events).summary.value[0].tag)
        self.assertRaises(StopIteration, lambda: next(events))

        # Third file should have "four" and "five"
        events = summary_iterator.summary_iterator(next(event_paths))
        self.assertEqual("brain.Event:2", next(events).file_version)
        self.assertEqual("four", next(events).summary.value[0].tag)
        self.assertEqual("five", next(events).summary.value[0].tag)
        self.assertRaises(StopIteration, lambda: next(events))

        # No more files
        self.assertRaises(StopIteration, lambda: next(event_paths))
Example #22
 def run_fn():
     """Function executed for each replica."""
     with summary_writer.as_default():
         replica_id = ds_context.get_replica_context().replica_id_in_sync_group
         return summary_ops.scalar("a", replica_id)
Example #23
 def write():
   summary_ops.scalar('scalar', 2.0)
Example #24
 def write():
     summary_ops.scalar('scalar', 2.0)
Example #25
 def write_summary_f():
   summary_ops.scalar(name=self.name, tensor=t)
   return t
Example #26
 def call(self, inputs):
     summary_ops_v2.scalar('mean', math_ops.reduce_mean(inputs))
     return inputs
Example #27
 def run():
   with writer.as_default():
     summary_ops.scalar("result", step * 2, step=step)
     step.assign_add(1)
Example #28
 def run_fn():
   """Function executed for each replica."""
   with summary_writer.as_default():
     replica_id = ds_context.get_replica_context().replica_id_in_sync_group
     return summary_ops.scalar("a", replica_id)
Example #29
 def body(unused_pred):
   summary_ops.scalar('scalar', 2.0)
   return constant_op.constant(False)
Example #30
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/hparams.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-model_name', type=str, default='P_S_Transformer_debug',
                        help='model name')
    # parser.add_argument('-batches_per_allreduce', type=int, default=1,
    #                     help='number of batches processed locally before '
    #                          'executing allreduce across workers; it multiplies '
    #                          'total batch size.')
    parser.add_argument('-num_workers', type=int, default=0,
                        help='how many subprocesses to use for data loading. '
                             '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    log_name = opt.model_name or config.model.name
    log_folder = os.path.join(os.getcwd(),'logdir/logging',log_name)
    if not os.path.isdir(log_folder):
        os.makedirs(log_folder)
    logger = init_logger(log_folder+'/'+opt.log)

    # TODO: build dataloader
    train_datafeeder = DataFeeder(config,'debug')

    # TODO: build model or load pre-trained model
    global global_step
    global_step = 0
    learning_rate = CustomSchedule(config.model.d_model)
    # learning_rate = 0.00002
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=config.optimizer.beta1, beta_2=config.optimizer.beta2,
                                         epsilon=config.optimizer.epsilon)
    logger.info('config.optimizer.beta1:' + str(config.optimizer.beta1))
    logger.info('config.optimizer.beta2:' + str(config.optimizer.beta2))
    logger.info('config.optimizer.epsilon:' + str(config.optimizer.epsilon))
    # print(str(config))
    model = Speech_transformer(config=config,logger=logger)

    #Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.
    checkpoint_path = log_folder
    ckpt = tf.train.Checkpoint(transformer=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logger.info('Latest checkpoint restored!!')
    else:
        logger.info('Start new run')


    # define metrics and summary writer
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    # summary_writer = tf.keras.callbacks.TensorBoard(log_dir=log_folder)
    summary_writer = summary_ops_v2.create_file_writer_v2(log_folder+'/train')


    # @tf.function
    def train_step(batch_data):
        inp = batch_data['the_inputs'] # batch*time*feature
        tar = batch_data['the_labels'] # batch*time
        # inp_len = batch_data['input_length']
        # tar_len = batch_data['label_length']
        gtruth = batch_data['ground_truth']
        tar_inp = tar
        tar_real = gtruth
        # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp[:,:,0], tar_inp)
        combined_mask = create_combined_mask(tar=tar_inp)
        with tf.GradientTape() as tape:
            predictions, _ = model(inp, tar_inp, True, None,
                                   combined_mask, None)
            # logger.info('config.train.label_smoothing_epsilon:' + str(config.train.label_smoothing_epsilon))
            loss = LableSmoothingLoss(tar_real, predictions,config.model.vocab_size,config.train.label_smoothing_epsilon)
        gradients = tape.gradient(loss, model.trainable_variables)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
        optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))
        train_loss(loss)
        train_accuracy(tar_real, predictions)

    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    logger.info('config.train.epoches:' + str(config.train.epoches))
    first_time = True
    for epoch in range(config.train.epoches):
        logger.info('start epoch '+ str(epoch))
        logger.info('total wavs: '+ str(len(train_datafeeder)))
        logger.info('batch size: ' + str(train_datafeeder.batch_size))
        logger.info('batch per epoch: ' + str(len(train_datafeeder)//train_datafeeder.batch_size))
        train_data = train_datafeeder.get_batch()
        start_time = time.time()
        train_loss.reset_states()
        train_accuracy.reset_states()

        for step in range(len(train_datafeeder)//train_datafeeder.batch_size):
            batch_data = next(train_data)
            step_time = time.time()
            train_step(batch_data)
            if first_time:
                model.summary()
                first_time = False
            time_window.append(time.time()-step_time)
            loss_window.append(train_loss.result())
            acc_window.append(train_accuracy.result())
            message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % (
                    global_step, time_window.average, train_loss.result(), loss_window.average, train_accuracy.result(),acc_window.average)
            logger.info(message)

            if global_step % 10 == 0:
                with summary_ops_v2.always_record_summaries():
                    with summary_writer.as_default():
                        summary_ops_v2.scalar('train_loss', train_loss.result(), step=global_step)
                        summary_ops_v2.scalar('train_acc', train_accuracy.result(), step=global_step)

            global_step += 1

        ckpt_save_path = ckpt_manager.save()
        logger.info('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
        logger.info('Time taken for 1 epoch: {} secs\n'.format(time.time() - start_time))
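
The summary_ops_v2 names above are the pre-release internal API. Assuming the writer were created with the public tf.summary.create_file_writer instead, the periodic logging block could be written as the following sketch; tf.summary records by default in eager mode, so no always_record_summaries context is needed:

            # Sketch: equivalent periodic logging with the public tf.summary API.
            if global_step % 10 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar('train_loss', train_loss.result(), step=global_step)
                    tf.summary.scalar('train_acc', train_accuracy.result(), step=global_step)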
Example #31
 def body(unused_pred):
     summary_ops.scalar('scalar', 2.0)
     return constant_op.constant(False)
Example #32
 def f():
     summary_ops.scalar('scalar', 2.0)
     return constant_op.constant(True)
Example #33
 def write_summary_f():
     summary_ops.scalar(name=self.name, tensor=t)
     return t
Example #34
 def f():
   summary_ops.scalar('scalar', 2.0)
   return constant_op.constant(True)
Example #35
 def result(self):
   t = self.numer / self.denom
   summary_ops.scalar(name=self.name, tensor=t)
   return t
Example #36
  def testSharing_withExplicitSummaryFileWriters(self):
    logdir = self.get_temp_dir()
    with session.Session() as sess:
      # Initial file writer via FileWriter(session=?)
      writer1 = writer.FileWriter(session=sess, logdir=logdir)
      writer1.add_summary(self._createTaggedSummary("one"), 1)
      writer1.flush()

      # Next one via create_file_writer(), should use same file
      writer2 = summary_ops_v2.create_file_writer(logdir=logdir)
      with summary_ops_v2.always_record_summaries(), writer2.as_default():
        summary2 = summary_ops_v2.scalar("two", 2.0, step=2)
      sess.run(writer2.init())
      sess.run(summary2)
      sess.run(writer2.flush())

      # Next has different shared name, should be in separate file
      time.sleep(1.1)  # Ensure filename has a different timestamp
      writer3 = summary_ops_v2.create_file_writer(logdir=logdir, name="other")
      with summary_ops_v2.always_record_summaries(), writer3.as_default():
        summary3 = summary_ops_v2.scalar("three", 3.0, step=3)
      sess.run(writer3.init())
      sess.run(summary3)
      sess.run(writer3.flush())

      # Next uses a second session, should be in separate file
      time.sleep(1.1)  # Ensure filename has a different timestamp
      with session.Session() as other_sess:
        writer4 = summary_ops_v2.create_file_writer(logdir=logdir)
        with summary_ops_v2.always_record_summaries(), writer4.as_default():
          summary4 = summary_ops_v2.scalar("four", 4.0, step=4)
        other_sess.run(writer4.init())
        other_sess.run(summary4)
        other_sess.run(writer4.flush())

        # Next via FileWriter(session=?) uses same second session, should be in
        # same separate file. (This checks sharing in the other direction)
        writer5 = writer.FileWriter(session=other_sess, logdir=logdir)
        writer5.add_summary(self._createTaggedSummary("five"), 5)
        writer5.flush()

      # One more via create_file_writer(), should use same file
      writer6 = summary_ops_v2.create_file_writer(logdir=logdir)
      with summary_ops_v2.always_record_summaries(), writer6.as_default():
        summary6 = summary_ops_v2.scalar("six", 6.0, step=6)
      sess.run(writer6.init())
      sess.run(summary6)
      sess.run(writer6.flush())

    event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))

    # First file should have tags "one", "two", and "six"
    events = summary_iterator.summary_iterator(next(event_paths))
    self.assertEqual("brain.Event:2", next(events).file_version)
    self.assertEqual("one", next(events).summary.value[0].tag)
    self.assertEqual("two", next(events).summary.value[0].tag)
    self.assertEqual("six", next(events).summary.value[0].tag)
    self.assertRaises(StopIteration, lambda: next(events))

    # Second file should have just "three"
    events = summary_iterator.summary_iterator(next(event_paths))
    self.assertEqual("brain.Event:2", next(events).file_version)
    self.assertEqual("three", next(events).summary.value[0].tag)
    self.assertRaises(StopIteration, lambda: next(events))

    # Third file should have "four" and "five"
    events = summary_iterator.summary_iterator(next(event_paths))
    self.assertEqual("brain.Event:2", next(events).file_version)
    self.assertEqual("four", next(events).summary.value[0].tag)
    self.assertEqual("five", next(events).summary.value[0].tag)
    self.assertRaises(StopIteration, lambda: next(events))

    # No more files
    self.assertRaises(StopIteration, lambda: next(event_paths))
Example #37
 def host_computation(x):
     summary.scalar("x", x, step=0)
     return x * 2.0