def testGinConfigSaverHookIncludeStepFalse(self):
    """With include_step_in_filename=False the file name carries no step."""
    output_dir, _ = self.run_log_config_hook_maybe_with_summary(
        global_step_value=7, include_step_in_filename=False)
    config_path = os.path.join(output_dir, 'operative_config.gin')
    with tf.io.gfile.GFile(config_path) as config_file:
        written_config = config_file.read()
    self.assertEqual(written_config, config.operative_config_str())
def after_create_session(self, session=None, coord=None):
    """Writes out Gin's operative config, and maybe adds a summary of it.

    The config is written to `<output_dir>/<base_name>-<global_step>.gin`, or
    to `<output_dir>/<base_name>.gin` when step inclusion is disabled. If
    `self._summarize_config` is set, a Markdown rendering of the config is
    also emitted as a 'text'-plugin summary tagged `gin/<base_name>`.

    Args:
      session: Optional session used to evaluate the global step. When it is
        None, or no global step tensor exists, step 0 is used.
      coord: Unused here (presumably required by the hook interface).
    """
    config_str = config.operative_config_str()
    if not tf.gfile.IsDirectory(self._output_dir):
        tf.gfile.MakeDirs(self._output_dir)

    global_step_val = 0
    if session is not None:
        global_step = tf.train.get_global_step()
        if global_step is not None:
            global_step_val = session.run(global_step)

    # NOTE(review): `_include_step_in_filename` is expected to be set by the
    # constructor (not visible here); defaulting to True preserves the
    # original `<base_name>-<step>.gin` naming for hooks that never set it.
    if getattr(self, '_include_step_in_filename', True):
        filename = '%s-%s.gin' % (self._base_name, global_step_val)
    else:
        filename = '%s.gin' % self._base_name
    config_path = os.path.join(self._output_dir, filename)
    with tf.gfile.GFile(config_path, 'w') as f:
        f.write(config_str)

    if not self._summarize_config:
        return

    md_config_str = self._markdownify_operative_config_str(config_str)
    summary_metadata = summary_pb2.SummaryMetadata()
    summary_metadata.plugin_data.plugin_name = 'text'
    summary_metadata.plugin_data.content = b'{}'
    text_tensor = tf.make_tensor_proto(md_config_str)
    summary = summary_pb2.Summary()
    summary.value.add(tag='gin/' + self._base_name, tensor=text_tensor,
                      metadata=summary_metadata)
    if not self._summary_writer:
        # Creating the FileWriter also creates the events file, so it should be
        # done here (where it is most likely to only occur on chief workers), as
        # opposed to in the constructor.
        self._summary_writer = tf.summary.FileWriterCache.get(
            self._output_dir)
    self._summary_writer.add_summary(summary, global_step_val)
    self._summary_writer.flush()
def testGinConfigSaverHookWithoutSummary(self):
    """With summarize_config=False the config is saved but nothing summarized."""
    step = 7
    output_dir, summary_writer = self.run_log_config_hook_maybe_with_summary(
        global_step_value=step, summarize_config=False)
    config_path = os.path.join(output_dir, 'operative_config-%d.gin' % step)
    with tf.io.gfile.GFile(config_path) as config_file:
        written_config = config_file.read()
    self.assertEqual(written_config, config.operative_config_str())
    self.assertEmpty(summary_writer.summaries)
def testKwOnlyArgs(self):
    """Checks that fn_with_kw_only_args' parameters can be bound via Gin."""
    config_str = """
        fn_with_kw_only_args.arg1 = 'arg1'
        fn_with_kw_only_args.kwarg1 = 'kwarg1'
    """

    # Before config_str is parsed, both returned values are None and the
    # operative config records kwarg1 as None.
    arg, kwarg = fn_with_kw_only_args(None)
    self.assertIsNone(arg)
    self.assertIsNone(kwarg)
    self.assertIn('fn_with_kw_only_args.kwarg1 = None',
                  config.operative_config_str())

    # After parsing, the configured kwarg1 value is returned and recorded.
    config.parse_config(config_str)
    arg, kwarg = fn_with_kw_only_args('arg1')
    self.assertEqual(arg, 'arg1')
    self.assertEqual(kwarg, 'kwarg1')
    self.assertIn("fn_with_kw_only_args.kwarg1 = 'kwarg1'",
                  config.operative_config_str())
def testGinConfigSaverHookWithoutGlobalStep(self):
    """Without a global step the file and summary are keyed by step 0."""
    output_dir, summary_writer = self.run_log_config_hook_maybe_with_summary(
        global_step_value=None)
    config_path = os.path.join(output_dir, 'operative_config-0.gin')
    with tf.io.gfile.GFile(config_path) as config_file:
        written_config = config_file.read()
    self.assertEqual(written_config, config.operative_config_str())
    first_summary = summary_writer.summaries[0][0]
    self.assertEqual(first_summary.value[0].tag, 'gin/operative_config')
def testGinConfigSaverHookWithSummary(self):
    """Checks the config file and Markdown summary under a custom base name."""
    global_step_value = 7
    output_dir, summary_writer = self.run_log_config_hook_maybe_with_summary(
        global_step_value=global_step_value, base_name='custom_name')
    expected_file_name = 'custom_name-%d.gin' % global_step_value
    # tf.io.gfile.GFile, matching the sibling tests (tf.gfile.Open is legacy).
    with tf.io.gfile.GFile(os.path.join(output_dir, expected_file_name)) as f:
        operative_config_str = f.read()
    self.assertEqual(operative_config_str, config.operative_config_str())

    summary_writer.assert_summaries(test_case=self, expected_logdir=output_dir)
    summary = summary_writer.summaries[global_step_value][0]
    self.assertEqual(summary.value[0].tag, 'gin/custom_name')

    def _as_text(value):
        # TensorProto.string_val holds bytes; on Python 3 bytes never compare
        # equal to str, so normalize both sides to text before comparing.
        return value.decode('utf-8') if isinstance(value, bytes) else value

    summary_text = _as_text(summary.value[0].tensor.string_val[0])
    markdown = _as_text(GinConfigSaverHookTest.EXPECTED_MARKDOWN)
    summary_lines = [line.strip() for line in summary_text.splitlines()]
    markdown_lines = [line.strip() for line in markdown.strip().splitlines()]
    # Comparing whole lists checks length and content and yields a clear diff.
    self.assertEqual(summary_lines, markdown_lines)