示例#1
0
def tf_summary_to_dict(tf_summary_str_or_pb, namespace=""):
    """Convert a Tensorboard Summary to a dictionary.

    Args:
        tf_summary_str_or_pb: A ``Summary`` protobuf, an ``Event`` protobuf
            (its ``summary`` field is converted and its ``step`` /
            ``wall_time`` are stored under ``"global_step"`` and
            ``"_timestamp"``), or a serialized ``Summary`` that
            ``Summary.ParseFromString`` can parse.
        namespace: Passed to ``namespaced_tag`` when building the keys of
            scalar and histogram values.

    Returns:
        A dict with one entry per supported summary value:
        scalars map to their ``simple_value``, images to ``wandb.Image``
        (or a list of them when the tag ends in ``/<digits>``), and
        histograms to ``wandb.Histogram``. Other kinds are ignored.
    """
    values = {}
    if isinstance(tf_summary_str_or_pb, Summary):
        summary_pb = tf_summary_str_or_pb
    elif isinstance(tf_summary_str_or_pb, Event):
        summary_pb = tf_summary_str_or_pb.summary
        values["global_step"] = tf_summary_str_or_pb.step
        values["_timestamp"] = tf_summary_str_or_pb.wall_time
    else:
        summary_pb = Summary()
        summary_pb.ParseFromString(tf_summary_str_or_pb)

    for value in summary_pb.value:
        kind = value.WhichOneof("value")
        if kind == "simple_value":
            values[namespaced_tag(value.tag, namespace)] = value.simple_value
        elif kind == "image":
            from PIL import Image
            image = wandb.Image(
                Image.open(six.BytesIO(value.image.encoded_image_string)))
            tag_idx = value.tag.rsplit('/', 1)
            if len(tag_idx) > 1 and tag_idx[1].isdigit():
                # Tags such as "name/0", "name/1" are grouped into a single
                # list stored under the key derived from "name".
                values.setdefault(history_image_key(tag_idx[0]),
                                  []).append(image)
            else:
                values[history_image_key(value.tag)] = image
        # TODO: audio summaries (kind == "audio") are not converted yet.
        elif kind == "histo":
            limits = list(value.histo.bucket_limit)
            if len(limits) < 3:
                # The edge extrapolation below needs limits[1] and
                # limits[-3]; with fewer limits it would raise IndexError.
                wandb.termwarn(
                    "Not logging key \"{}\".  Histogram has fewer than 3 "
                    "bucket limits.".format(
                        namespaced_tag(value.tag, namespace)))
                continue
            # Extrapolate one bucket width below the first limit, and
            # replace the final limit (dropped by [:-1]) with an
            # extrapolation from the two limits before it.
            # NOTE(review): presumably the final limit is an open-ended
            # sentinel in TensorBoard histograms -- confirm.
            first = limits[0] + limits[0] - limits[1]
            last = limits[-2] + limits[-2] - limits[-3]
            np_histogram = (list(value.histo.bucket),
                            [first] + limits[:-1] + [last])
            # Pass the namespace so histogram keys are namespaced the same
            # way as scalar keys.
            values[namespaced_tag(
                value.tag,
                namespace)] = wandb.Histogram(np_histogram=np_histogram)

    return values
示例#2
0
def tf_summary_to_dict(tf_summary_str_or_pb, namespace=""):
    """Convert a Tensorboard Summary to a dictionary.

    Args:
        tf_summary_str_or_pb: One of
            * an object with a ``summary`` attribute (an Event); its
              ``step`` is stored under the namespaced ``"global_step"`` key
              and its ``wall_time`` under ``"_timestamp"``;
            * a ``str`` / ``bytes`` / ``bytearray`` holding a serialized
              ``Summary``;
            * an object with a ``value`` attribute (a Summary).
        namespace: Passed to ``namespaced_tag`` / ``history_image_key`` when
            building result keys.

    Returns:
        A dict with one entry per supported summary value (scalars, tensors,
        images / GIFs, histograms). Kinds listed in ``IGNORE_KINDS`` are
        skipped. hparams session-start summaries are not returned; they are
        written into ``wandb.run.config`` instead.

    Raises:
        ValueError: If the argument is none of the accepted inputs.
    """
    values = {}
    if hasattr(tf_summary_str_or_pb, "summary"):
        summary_pb = tf_summary_str_or_pb.summary
        values[namespaced_tag("global_step",
                              namespace)] = tf_summary_str_or_pb.step
        values["_timestamp"] = tf_summary_str_or_pb.wall_time
    elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)):
        summary_pb = Summary()
        summary_pb.ParseFromString(tf_summary_str_or_pb)
    elif hasattr(tf_summary_str_or_pb, "value"):
        summary_pb = tf_summary_str_or_pb
    else:
        # Interpolate the offending object; wrapping it in a tuple keeps the
        # formatting correct even when the object is itself a tuple.
        raise ValueError(
            "Can't log %s, only Event, Summary, or Summary proto buffer strings are accepted"
            % (tf_summary_str_or_pb,))

    for value in summary_pb.value:
        kind = value.WhichOneof("value")
        if kind in IGNORE_KINDS:
            continue
        if kind == "simple_value":
            values[namespaced_tag(value.tag, namespace)] = value.simple_value
        elif kind == "tensor":
            values[namespaced_tag(value.tag,
                                  namespace)] = make_ndarray(value.tensor)
        elif kind == "image":
            from PIL import Image
            img_str = value.image.encoded_image_string
            # Supports gifs from TboardX
            if img_str.startswith(b"GIF"):
                image = wandb.Video(six.BytesIO(img_str), format="gif")
            else:
                image = wandb.Image(Image.open(six.BytesIO(img_str)))
            tag_idx = value.tag.rsplit('/', 1)
            if len(tag_idx) > 1 and tag_idx[1].isdigit():
                # Tags such as "name/0", "name/1" are grouped into a single
                # list stored under the key derived from "name".
                values.setdefault(history_image_key(tag_idx[0], namespace),
                                  []).append(image)
            else:
                values[history_image_key(value.tag, namespace)] = [image]
        # TODO: audio summaries (kind == "audio") are not converted yet.
        elif kind == "histo":
            tag = namespaced_tag(value.tag, namespace)
            limits = list(value.histo.bucket_limit)
            if len(limits) >= 3:
                # Extrapolate one bucket width below the first limit, and
                # replace the final limit (dropped by [:-1]) with an
                # extrapolation from the two limits before it.
                first = limits[0] + limits[0] - limits[1]
                last = limits[-2] + limits[-2] - limits[-3]
                np_histogram = (list(value.histo.bucket),
                                [first] + limits[:-1] + [last])
                try:
                    #TODO: we should just re-bin if there are too many buckets
                    values[tag] = wandb.Histogram(np_histogram=np_histogram)
                except ValueError:
                    wandb.termwarn(
                        "Not logging key \"{}\".  Histograms must have fewer than {} bins"
                        .format(tag, wandb.Histogram.MAX_LENGTH),
                        repeat=False)
            else:
                #TODO: is there a case where we can render this?
                wandb.termwarn(
                    "Not logging key \"{}\".  Found a histogram with only 2 bins."
                    .format(tag),
                    repeat=False)
        elif value.tag == "_hparams_/session_start_info":
            if wandb.util.get_module("tensorboard.plugins.hparams"):
                from tensorboard.plugins.hparams import plugin_data_pb2
                plugin_data = plugin_data_pb2.HParamsPluginData()
                plugin_data.ParseFromString(value.metadata.plugin_data.content)
                for key, param in six.iteritems(
                        plugin_data.session_start_info.hparams):
                    if not wandb.run.config.get(key):
                        # Pick the field that is actually populated: an
                        # ``or`` chain would turn a legitimate 0 or ""
                        # hyperparameter into ``False``.
                        set_kind = param.WhichOneof("kind")
                        if set_kind in ("number_value", "string_value",
                                        "bool_value"):
                            wandb.run.config[key] = getattr(param, set_kind)
                        else:
                            # Unset or non-scalar kinds keep the previous
                            # fallback behaviour.
                            wandb.run.config[key] = (param.number_value
                                                     or param.string_value
                                                     or param.bool_value)
            else:
                wandb.termerror(
                    "Received hparams tf.summary, but could not import the hparams plugin from tensorboard"
                )
    return values
示例#3
0
def tf_summary_to_dict(tf_summary_str_or_pb, namespace=""):
    """Convert a Tensorboard Summary to a dictionary.

    Args:
        tf_summary_str_or_pb: One of
            * an object with a ``summary`` attribute (an Event); its
              ``step`` is stored under the namespaced ``"global_step"`` key
              and its ``wall_time`` under ``"_timestamp"``;
            * a ``str`` / ``bytes`` / ``bytearray`` holding a serialized
              ``Summary``;
            * an object with a ``value`` attribute (a Summary).
        namespace: Passed to ``namespaced_tag`` / ``history_image_key`` when
            building result keys.

    Returns:
        A dict with one entry per supported summary value (scalars, images,
        histograms). hparams session-start summaries are not returned; they
        are written into ``wandb.run.config`` instead.

    Raises:
        ValueError: If the argument is none of the accepted inputs.
    """
    values = {}
    if hasattr(tf_summary_str_or_pb, "summary"):
        summary_pb = tf_summary_str_or_pb.summary
        values[namespaced_tag("global_step",
                              namespace)] = tf_summary_str_or_pb.step
        values["_timestamp"] = tf_summary_str_or_pb.wall_time
    elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)):
        summary_pb = Summary()
        summary_pb.ParseFromString(tf_summary_str_or_pb)
    elif hasattr(tf_summary_str_or_pb, "value"):
        summary_pb = tf_summary_str_or_pb
    else:
        # Interpolate the offending object; wrapping it in a tuple keeps the
        # formatting correct even when the object is itself a tuple.
        raise ValueError(
            "Can't log %s, only Event, Summary, or Summary proto buffer strings are accepted"
            % (tf_summary_str_or_pb,))

    for value in summary_pb.value:
        kind = value.WhichOneof("value")
        if kind == "simple_value":
            values[namespaced_tag(value.tag, namespace)] = value.simple_value
        elif kind == "image":
            from PIL import Image
            image = wandb.Image(
                Image.open(six.BytesIO(value.image.encoded_image_string)))
            tag_idx = value.tag.rsplit('/', 1)
            if len(tag_idx) > 1 and tag_idx[1].isdigit():
                # Tags such as "name/0", "name/1" are grouped into a single
                # list stored under the key derived from "name".
                values.setdefault(history_image_key(tag_idx[0], namespace),
                                  []).append(image)
            else:
                values[history_image_key(value.tag, namespace)] = image
        # TODO: audio summaries (kind == "audio") are not converted yet.
        elif kind == "histo":
            tag = namespaced_tag(value.tag, namespace)
            limits = list(value.histo.bucket_limit)
            if len(limits) < 3:
                # The edge extrapolation below needs limits[1] and
                # limits[-3]; with fewer limits it would raise IndexError.
                wandb.termwarn(
                    "Not logging key \"{}\".  Histogram has fewer than 3 "
                    "bucket limits.".format(tag))
                continue
            # Extrapolate one bucket width below the first limit, and
            # replace the final limit (dropped by [:-1]) with an
            # extrapolation from the two limits before it.
            first = limits[0] + limits[0] - limits[1]
            last = limits[-2] + limits[-2] - limits[-3]
            np_histogram = (list(value.histo.bucket),
                            [first] + limits[:-1] + [last])
            values[tag] = wandb.Histogram(np_histogram=np_histogram)
        elif value.tag == "_hparams_/session_start_info":
            if wandb.util.get_module("tensorboard.plugins.hparams"):
                from tensorboard.plugins.hparams import plugin_data_pb2
                plugin_data = plugin_data_pb2.HParamsPluginData()
                plugin_data.ParseFromString(value.metadata.plugin_data.content)
                for key, param in six.iteritems(
                        plugin_data.session_start_info.hparams):
                    if not wandb.run.config.get(key):
                        # Pick the field that is actually populated: an
                        # ``or`` chain would turn a legitimate 0 or ""
                        # hyperparameter into ``False``.
                        set_kind = param.WhichOneof("kind")
                        if set_kind in ("number_value", "string_value",
                                        "bool_value"):
                            wandb.run.config[key] = getattr(param, set_kind)
                        else:
                            # Unset or non-scalar kinds keep the previous
                            # fallback behaviour.
                            wandb.run.config[key] = (param.number_value
                                                     or param.string_value
                                                     or param.bool_value)
            else:
                wandb.termerror(
                    "Received hparams tf.summary, but could not import the hparams plugin from tensorboard"
                )

    return values