コード例 #1
0
 def testGinConfigSaverHookIncludeStepFalse(self):
   """Config file name omits the step when include_step_in_filename=False."""
   log_dir, _unused_writer = self.run_log_config_hook_maybe_with_summary(
       global_step_value=7, include_step_in_filename=False)
   config_path = os.path.join(log_dir, 'operative_config.gin')
   with tf.io.gfile.GFile(config_path) as config_file:
     written_config = config_file.read()
   # The file on disk must match the current operative config verbatim.
   self.assertEqual(written_config, config.operative_config_str())
コード例 #2
0
    def after_create_session(self, session=None, coord=None):
        """Writes out Gin's operative config, and maybe adds a summary of it.

        The config is written to `<output_dir>/<base_name>-<step>.gin`, or to
        `<output_dir>/<base_name>.gin` when the hook defines an
        `_include_step_in_filename` attribute set to False. When
        `self._summarize_config` is truthy, a markdown rendering of the config
        is also emitted as a 'text'-plugin summary tagged `gin/<base_name>`.

        Args:
          session: The newly created session, or None. Only used to read the
            current global step; if it is None, or no global step exists, the
            step defaults to 0.
          coord: Unused by this method.
        """
        config_str = config.operative_config_str()
        if not tf.gfile.IsDirectory(self._output_dir):
            tf.gfile.MakeDirs(self._output_dir)

        global_step_val = 0
        if session is not None:
            global_step = tf.train.get_global_step()
            if global_step is not None:
                global_step_val = session.run(global_step)

        # Hooks constructed without an `_include_step_in_filename` attribute
        # keep the historical `<base_name>-<step>.gin` naming.
        if getattr(self, '_include_step_in_filename', True):
            filename = '%s-%s.gin' % (self._base_name, global_step_val)
        else:
            filename = '%s.gin' % self._base_name
        config_path = os.path.join(self._output_dir, filename)
        with tf.gfile.GFile(config_path, 'w') as f:
            f.write(config_str)

        if self._summarize_config:
            md_config_str = self._markdownify_operative_config_str(config_str)
            summary_metadata = summary_pb2.SummaryMetadata()
            summary_metadata.plugin_data.plugin_name = 'text'
            summary_metadata.plugin_data.content = b'{}'
            text_tensor = tf.make_tensor_proto(md_config_str)
            summary = summary_pb2.Summary()
            summary.value.add(tag='gin/' + self._base_name,
                              tensor=text_tensor,
                              metadata=summary_metadata)
            if not self._summary_writer:
                # Creating the FileWriter also creates the events file, so it should be
                # done here (where it is most likely to only occur on chief workers), as
                # opposed to in the constructor.
                self._summary_writer = tf.summary.FileWriterCache.get(
                    self._output_dir)
            self._summary_writer.add_summary(summary, global_step_val)
            self._summary_writer.flush()
コード例 #3
0
 def testGinConfigSaverHookWithoutSummary(self):
   """With summarize_config=False the file is written but no summary is."""
   step = 7
   log_dir, writer = self.run_log_config_hook_maybe_with_summary(
       global_step_value=step, summarize_config=False)
   config_path = os.path.join(log_dir, 'operative_config-%d.gin' % step)
   with tf.io.gfile.GFile(config_path) as config_file:
     written_config = config_file.read()
   self.assertEqual(written_config, config.operative_config_str())
   # No summary should have reached the writer.
   self.assertEmpty(writer.summaries)
コード例 #4
0
  def testKwOnlyArgs(self):
    """Keyword-only parameters can be bound and show up in the operative config."""
    config_str = """
      fn_with_kw_only_args.arg1 = 'arg1'
      fn_with_kw_only_args.kwarg1 = 'kwarg1'
    """

    # Before any bindings are parsed, kwarg1 comes back as None and the
    # operative config records it as None.
    arg, kwarg = fn_with_kw_only_args(None)
    self.assertIsNone(arg)
    self.assertIsNone(kwarg)
    self.assertIn('fn_with_kw_only_args.kwarg1 = None',
                  config.operative_config_str())

    config.parse_config(config_str)

    # After parsing, the bound keyword-only value is supplied automatically.
    # NOTE(review): 'arg1' is passed explicitly here, so the arg1 binding in
    # config_str is not itself exercised by this assertion.
    arg, kwarg = fn_with_kw_only_args('arg1')
    self.assertEqual(arg, 'arg1')
    self.assertEqual(kwarg, 'kwarg1')
    self.assertIn("fn_with_kw_only_args.kwarg1 = 'kwarg1'",
                  config.operative_config_str())
コード例 #5
0
  def testGinConfigSaverHookWithoutGlobalStep(self):
    """With global_step_value=None, outputs are keyed by step 0."""
    log_dir, writer = self.run_log_config_hook_maybe_with_summary(
        global_step_value=None)
    config_path = os.path.join(log_dir, 'operative_config-0.gin')
    with tf.io.gfile.GFile(config_path) as config_file:
      written_config = config_file.read()
    self.assertEqual(written_config, config.operative_config_str())

    # The first summary recorded at step 0 carries the default base-name tag.
    first_summary = writer.summaries[0][0]
    self.assertEqual(first_summary.value[0].tag, 'gin/operative_config')
コード例 #6
0
    def testGinConfigSaverHookWithSummary(self):
        """Checks the config file and the markdown text summary for a custom base name."""
        global_step_value = 7
        output_dir, summary_writer = self.run_log_config_hook_maybe_with_summary(
            global_step_value=global_step_value, base_name='custom_name')
        expected_file_name = 'custom_name-%d.gin' % global_step_value
        with tf.gfile.Open(os.path.join(output_dir, expected_file_name)) as f:
            operative_config_str = f.read()
        self.assertEqual(operative_config_str, config.operative_config_str())
        summary_writer.assert_summaries(test_case=self,
                                        expected_logdir=output_dir)

        summary = summary_writer.summaries[global_step_value][0]
        self.assertEqual(summary.value[0].tag, 'gin/custom_name')

        # TensorProto.string_val is a repeated `bytes` field, so under Python 3
        # its lines must be decoded before comparing them with the expected
        # markdown (assumed to be a str).
        raw_lines = summary.value[0].tensor.string_val[0].splitlines()
        summary_lines = [
            line.decode('utf-8') if isinstance(line, bytes) else line
            for line in raw_lines
        ]
        markdown = GinConfigSaverHookTest.EXPECTED_MARKDOWN
        markdown_lines = markdown.strip().splitlines()
        self.assertEqual(len(summary_lines), len(markdown_lines))
        # Compare line by line, ignoring leading/trailing whitespace.
        for summary_line, expected_line in zip(summary_lines, markdown_lines):
            self.assertEqual(summary_line.strip(), expected_line.strip())