def _print(pass_through_tensor, values):
    """Call tf.Print on a flattened version of `values`.

    Each entry of `values` is expanded before printing:
      * an entry exposing `_fields` (i.e. a namedtuple) contributes, for
        every field, the field name followed by `_to_str` of the field value;
      * a list or tuple contributes `_to_str` of each of its elements;
      * any other entry contributes `_to_str` of itself.

    Args:
      pass_through_tensor: forwarded unchanged as the first argument of
        tf.Print.
      values: iterable of entries to flatten and print.

    Returns:
      Whatever tf.Print(pass_through_tensor, <flattened list>) returns.
    """
    printable = []
    for item in values:
        if hasattr(item, '_fields'):
            # namedtuple: emit "name, value" pairs so the output is labelled.
            for name in item._fields:
                printable += [name, _to_str(getattr(item, name))]
        elif isinstance(item, (list, tuple)):
            printable.extend(_to_str(element) for element in item)
        else:
            printable.append(_to_str(item))
    return tf.Print(pass_through_tensor, printable)
def print_dataset(features):
    """Wrap every value of `features` in a tf.Print op, for debugging.

    Args:
      features: mapping to iterate with `.items()`; each key is concatenated
        with ': ' to build the message handed to tf.Print.

    Returns:
      A new dict with the same keys, where each value is
      tf.Print(value, [value], key + ': ').
    """
    printed = {}
    for name, tensor in features.items():
        printed[name] = tf.Print(tensor, [tensor], name + ': ')
    return printed
def my_fn(x):
    """Return a dict mapping each key of `x` to a tf.Print of its value.

    Args:
      x: mapping to iterate with `.items()`; each key is concatenated with
        ": " to build the message handed to tf.Print.

    Returns:
      A new dict with the same keys, where each value is
      tf.Print(value, [value], key + ": ").
    """
    wrapped = {}
    for key, val in x.items():
        message = key + ": "
        wrapped[key] = tf.Print(val, [val], message)
    return wrapped