def add_summary_by_name(summary_name: str, summary_value: tf.Tensor,
                        max_outputs_tb: int = 1, *,
                        sample_rate: float = 16000):
    """
    Add a tensorboard summary whose type is encoded in its name.

    The summary type is chosen by the first type prefix found anywhere in
    ``summary_name``. Prefixes are checked in the order ``scalar_``,
    ``image_``, ``histogram_``, ``text_``, ``audio_``. Every occurrence of
    the matched prefix is removed from the name before the summary is
    created. If the name contains a ``/``, the first component of the
    original (unstripped) name is passed as the tensorboard ``family`` for
    scalar, image and audio summaries. Names without any known prefix are
    skipped with a ``RuntimeWarning``.

    Parameters
    ----------
    summary_name : str
        name of the summary; must contain one of the type prefixes above
    summary_value : tf.Tensor or dict
        value of the summary; for histogram summaries this may also be a
        dict mapping sub-names to values, in which case one histogram per
        item is added under the name ``<stripped name>_<sub-name>``
    max_outputs_tb : int
        maximum number of outputs shown in tensorboard for image and audio
        summaries
    sample_rate : float
        sample rate in Hz used for audio summaries; defaults to 16000, the
        value that was previously hard-coded

    Warns
    -----
    RuntimeWarning
        if ``summary_name`` does not contain a known type prefix
    """
    name_parts = summary_name.split('/')
    # NOTE: family is taken from the unstripped name, so a type prefix in the
    # first path component (e.g. 'scalar_train/loss') also ends up in it
    family = name_parts[0] if len(name_parts) > 1 else None

    if 'scalar_' in summary_name:
        tf.summary.scalar(summary_name.replace('scalar_', ''), summary_value,
                          family=family)
    elif 'image_' in summary_name:
        tf.summary.image(summary_name.replace('image_', ''), summary_value,
                         max_outputs=max_outputs_tb, family=family)
    elif 'histogram_' in summary_name:
        main_name = summary_name.replace('histogram_', '')
        if isinstance(summary_value, dict):
            # one histogram per dict item, named '<main_name>_<key>'
            for each_summary_name, each_summary_value in summary_value.items():
                histogram_name = '_'.join([main_name, each_summary_name])
                add_histogram_summary(histogram_name, each_summary_value)
        else:
            add_histogram_summary(main_name, summary_value)
    elif 'text_' in summary_name:
        tf.summary.text(summary_name.replace('text_', ''), summary_value)
    elif 'audio_' in summary_name:
        tf.summary.audio(summary_name.replace('audio_', ''), summary_value,
                         sample_rate=sample_rate,
                         max_outputs=max_outputs_tb, family=family)
    else:
        msg = ('Warning: summary with name {} will not be added '
               'to tensorboard!'.format(summary_name))
        warnings.warn(msg, RuntimeWarning, stacklevel=2)
def __call__(self, values: Tensor, raveled: bool = True) -> Tensor:
    """
    Apply this object to ``values``.

    A non-dict input is handed directly to ``_call_with_tensor_values``,
    with ``raveled`` forwarded unchanged. A dict input is processed entry by
    entry by calling this object again on each value. The result is a new
    dict with the same keys, so nested dicts are handled recursively. Note
    that, despite the ``Tensor`` annotation, a dict input yields a dict.
    """
    if not isinstance(values, dict):
        return self._call_with_tensor_values(values=values, raveled=raveled)

    mapped = {}
    for key, item in values.items():
        mapped[key] = self(item, raveled)
    return mapped