def test_merge_graph_defs_mismatch_version(self):
        graph_def_a = GraphDef()
        text_format.Parse(
            """
              node {
                name: "A"
                op: "Input"
              }
              versions {
                producer: 21
              }
          """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Parse(
            """
              node {
                name: "A"
                op: "Input"
              }
              versions {
                producer: 100
              }
          """,
            graph_def_b,
        )

        with self.assertRaisesRegex(
            ValueError, "Cannot combine GraphDefs of different versions"
        ):
            graph_util.merge_graph_defs([graph_def_a, graph_def_b])
    def test_merge_graph_defs_gradient(self):
        """Gradient entries are renamed with a per-graph prefix when merged.

        Both inputs define a gradient for "foo". The second also defines one
        for "bar". After merging, each `function_name` and `gradient_func`
        carries the `graph_<index>_` prefix of the GraphDef it came from.
        """
        first = text_format.Parse(
            """
                library {
                  gradient {
                    function_name: "foo"
                    gradient_func: "foo_grad"
                  }
                }
            """,
            GraphDef(),
        )
        second = text_format.Parse(
            """
                library {
                  gradient {
                    function_name: "foo"
                    gradient_func: "foo_grad"
                  }
                  gradient {
                    function_name: "bar"
                    gradient_func: "bar_grad"
                  }
                }
            """,
            GraphDef(),
        )

        merged = graph_util.merge_graph_defs([first, second])

        expected = """
            library {
              gradient {
                function_name: "graph_1_foo"
                gradient_func: "graph_1_foo_grad"
              }
              gradient {
                function_name: "graph_2_foo"
                gradient_func: "graph_2_foo_grad"
              }
              gradient {
                function_name: "graph_2_bar"
                gradient_func: "graph_2_bar_grad"
              }
            }
        """
        self.assertProtoEquals(expected, merged)
Beispiel #3
0
    def graph_impl(
        self,
        ctx,
        run,
        tag,
        is_conceptual,
        experiment=None,
        limit_attr_size=None,
        large_attrs_key=None,
    ):
        """Load one graph and render it as text-format protobuf.

        Three sources are distinguished:
          * `is_conceptual`: a JSON Keras model config, converted to a GraphDef.
          * `tag is None`: the run-level GraphDef stored under
            `metadata.RUN_GRAPH_NAME`.
          * otherwise: a serialized `RunMetadata` (op graph). The
            pre-optimization graphs of its function graphs are merged.

        Returns:
          A `(body, mime_type)` tuple; `body` is `str(graph)` and `mime_type`
          is "text/x-protobuf".

        Raises:
          NotFound: per this method's contract (presumably surfaced by
            `_read_blob`; confirm against its implementation).
          ValueError: from `prepare_graph_for_ui` if the limit parameters
            are invalid.
        """
        if is_conceptual:
            # Conceptual graph: a JSON-encoded Keras model configuration.
            config_blob = self._read_blob(
                ctx, experiment, [metadata.PLUGIN_NAME_KERAS_MODEL], run, tag
            )
            graph = keras_util.keras_model_to_graph_def(json.loads(config_blob))
        elif tag is None:
            # Run-level graph: a binary GraphDef under a fixed tag name.
            serialized_graph = self._read_blob(
                ctx,
                experiment,
                [metadata.PLUGIN_NAME],
                run,
                metadata.RUN_GRAPH_NAME,
            )
            graph = graph_pb2.GraphDef.FromString(serialized_graph)
        else:
            # Op graph: the blob may have been written by either of two
            # plugins (cf. `info_impl`), so both are passed as candidates.
            candidate_plugins = [
                metadata.PLUGIN_NAME_RUN_METADATA,
                metadata.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH,
            ]
            run_metadata = config_pb2.RunMetadata.FromString(
                self._read_blob(ctx, experiment, candidate_plugins, run, tag)
            )
            pre_opt_graphs = [
                fg.pre_optimization_graph for fg in run_metadata.function_graphs
            ]
            graph = graph_util.merge_graph_defs(pre_opt_graphs)

        # May raise ValueError when the limit parameters are inconsistent
        # (e.g. negative size, or a size given without a key).
        process_graph.prepare_graph_for_ui(
            graph, limit_attr_size, large_attrs_key
        )
        # The body is text-format protobuf (pbtxt).
        return str(graph), "text/x-protobuf"
    def test_merge_graph_defs_single_graph_def_no_prefix(self):
        """Merging a lone GraphDef yields it unchanged (no `graph_1/` prefix)."""
        only = text_format.Parse(
            """
              node {
                name: "A"
                op: "Input"
              }
              versions {
                producer: 21
              }
          """,
            GraphDef(),
        )

        merged = graph_util.merge_graph_defs([only])
        self.assertProtoEquals(only, merged)
Beispiel #5
0
    def graph_impl(
        self,
        ctx,
        run,
        tag,
        is_conceptual,
        experiment=None,
        limit_attr_size=None,
        large_attrs_key=None,
    ):
        """Load one graph and render it as text-format protobuf.

        If a data provider is configured, the graph is read as a blob
        (defaulting the tag to `metadata.RUN_GRAPH_NAME`). Otherwise the
        event multiplexer is consulted:
          * `is_conceptual`: the first tensor event holds a JSON Keras config.
          * truthy `tag`: the first tensor event holds a serialized
            `RunMetadata`. The pre-optimization graphs of its function
            graphs are merged.
          * neither: the run's stored graph.

        Returns:
          A `(body, mime_type)` tuple with `body == str(graph)` and
          `mime_type == "text/x-protobuf"`, or `None` when the data provider
          has no blob for the requested run/tag.

        Raises:
          ValueError: from `prepare_graph_for_ui` if the limit parameters
            are invalid.
        """
        if self._data_provider:
            lookup_tag = metadata.RUN_GRAPH_NAME if tag is None else tag
            sequences = self._data_provider.read_blob_sequences(
                ctx,
                experiment_id=experiment,
                plugin_name=metadata.PLUGIN_NAME,
                run_tag_filter=provider.RunTagFilter(
                    runs=[run], tags=[lookup_tag]
                ),
                downsample=1,
            )
            datums = sequences.get(run, {}).get(lookup_tag, ())
            try:
                first_ref = datums[0].values[0]
            except IndexError:
                # No datum, or a datum without blobs: nothing to show.
                return None
            # Always go through the blob key for now, even if a direct URL
            # is available.
            serialized_graph = self._data_provider.read_blob(
                ctx, blob_key=first_ref.blob_key
            )
            # Deserialize even though the output is pbtxt. This accepts
            # binary protobufs too, and `prepare_graph_for_ui` below needs a
            # message object anyway.
            graph = graph_pb2.GraphDef.FromString(serialized_graph)
        elif is_conceptual or tag:
            # Both cases read only the first tensor event, even when several
            # were written at different steps.
            first_event = self._multiplexer.Tensors(run, tag)[0]
            payload = first_event.tensor_proto.string_val[0]
            if is_conceptual:
                graph = keras_util.keras_model_to_graph_def(
                    json.loads(payload)
                )
            else:
                run_metadata = config_pb2.RunMetadata.FromString(payload)
                graph = graph_util.merge_graph_defs(
                    [
                        fg.pre_optimization_graph
                        for fg in run_metadata.function_graphs
                    ]
                )
        else:
            graph = self._multiplexer.Graph(run)

        # May raise ValueError when the limit parameters are inconsistent
        # (e.g. negative size, or a size given without a key).
        process_graph.prepare_graph_for_ui(
            graph, limit_attr_size, large_attrs_key
        )
        # The body is text-format protobuf (pbtxt).
        return str(graph), "text/x-protobuf"
    def test_merge_graph_defs_partitioned_call_remap(self):
        """A PartitionedCall's `f` attr is remapped to the renamed function.

        The node "X" calls function "foo". After merging with an empty
        GraphDef, the node becomes "graph_1/X". Both its `f` attr and the
        library function are renamed to "graph_1_foo".
        """
        source = text_format.Parse(
            """
                node {
                  name: "X"
                  op: "PartitionedCall"
                  attr {
                    key: "f"
                    value {
                      func {
                        name: "foo"
                      }
                    }
                  }
                }
                library {
                  function {
                    signature {
                      name: "foo"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                  }
                }
            """,
            GraphDef(),
        )
        expected = text_format.Parse(
            """
                node {
                  name: "graph_1/X"
                  op: "PartitionedCall"
                  attr {
                    key: "f"
                    value {
                      func {
                        name: "graph_1_foo"
                      }
                    }
                  }
                }
                library {
                  function {
                    signature {
                      name: "graph_1_foo"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                  }
                }
            """,
            GraphDef(),
        )

        merged = graph_util.merge_graph_defs([source, GraphDef()])
        self.assertProtoEquals(expected, merged)
    def test_merge_graph_defs_function(self):
        """Library functions get a per-graph prefix; distinct names stay distinct.

        Graph 1 defines "foo" (DT_HALF). Graph 2 defines "foo" (DT_INT32) and
        "foo_1" (DT_HALF). After merging they become "graph_1_foo",
        "graph_2_foo" and "graph_2_foo_1" respectively.
        """
        first = text_format.Parse(
            """
                library {
                  function {
                    signature {
                      name: "foo"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                    node_def {
                      name: "add"
                      op: "Add"
                      input: "x"
                      input: "y"
                    }
                  }
                }
            """,
            GraphDef(),
        )
        second = text_format.Parse(
            """
                library {
                  function {
                    signature {
                      name: "foo"
                      input_arg {
                        name: "x"
                        type: DT_INT32
                      }
                      output_arg {
                        name: "identity"
                        type: DT_INT32
                      }
                    }
                    node_def {
                      name: "add"
                      op: "Add"
                      input: "x"
                      input: "y"
                    }
                  }
                  function {
                    signature {
                      name: "foo_1"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                    node_def {
                      name: "add"
                      op: "Add"
                      input: "x"
                      input: "y"
                    }
                  }
                }
            """,
            GraphDef(),
        )

        merged = graph_util.merge_graph_defs([first, second])

        expected = """
            library {
              function {
                signature {
                  name: "graph_1_foo"
                  input_arg {
                    name: "x"
                    type: DT_HALF
                  }
                  output_arg {
                    name: "identity"
                    type: DT_HALF
                  }
                }
                node_def {
                  name: "add"
                  op: "Add"
                  input: "x"
                  input: "y"
                }
              }
              function {
                signature {
                  name: "graph_2_foo"
                  input_arg {
                    name: "x"
                    type: DT_INT32
                  }
                  output_arg {
                    name: "identity"
                    type: DT_INT32
                  }
                }
                node_def {
                  name: "add"
                  op: "Add"
                  input: "x"
                  input: "y"
                }
              }
              function {
                signature {
                  name: "graph_2_foo_1"
                  input_arg {
                    name: "x"
                    type: DT_HALF
                  }
                  output_arg {
                    name: "identity"
                    type: DT_HALF
                  }
                }
                node_def {
                  name: "add"
                  op: "Add"
                  input: "x"
                  input: "y"
                }
              }
            }
        """
        self.assertProtoEquals(expected, merged)
    def test_merge_graph_defs(self):
        """Nodes from N graphs are prefixed `graph_<i>/`, inputs included.

        Three GraphDefs sharing producer version 21 are merged. Node names
        and their `input` references gain the 1-based index prefix of their
        source graph, and the common `versions` block is kept once.
        """
        first = text_format.Parse(
            """
                node {
                  name: "X"
                  op: "Input"
                }
                node {
                  name: "W"
                  op: "Input"
                }
                node {
                  name: "Y"
                  op: "MatMul"
                  input: "X"
                  input: "W"
                }
                versions {
                  producer: 21
                }
            """,
            GraphDef(),
        )
        second = text_format.Parse(
            """
                node {
                  name: "A"
                  op: "Input"
                }
                node {
                  name: "B"
                  op: "Input"
                }
                node {
                  name: "C"
                  op: "MatMul"
                  input: "A"
                  input: "B"
                }
                versions {
                  producer: 21
                }
            """,
            GraphDef(),
        )
        third = text_format.Parse(
            """
                node {
                  name: "A"
                  op: "Input"
                }
                node {
                  name: "B"
                  op: "Input"
                }
                versions {
                  producer: 21
                }
            """,
            GraphDef(),
        )

        merged = graph_util.merge_graph_defs([first, second, third])

        expected = """
            node {
              name: "graph_1/X"
              op: "Input"
            }
            node {
              name: "graph_1/W"
              op: "Input"
            }
            node {
              name: "graph_1/Y"
              op: "MatMul"
              input: "graph_1/X"
              input: "graph_1/W"
            }
            node {
              name: "graph_2/A"
              op: "Input"
            }
            node {
              name: "graph_2/B"
              op: "Input"
            }
            node {
              name: "graph_2/C"
              op: "MatMul"
              input: "graph_2/A"
              input: "graph_2/B"
            }
            node {
              name: "graph_3/A"
              op: "Input"
            }
            node {
              name: "graph_3/B"
              op: "Input"
            }
            versions {
              producer: 21
            }
        """
        self.assertProtoEquals(expected, merged)
    def test_merge_graph_defs_name_collided_with_same_content(self):
        """Same-named nodes in different graphs stay distinct after merging.

        Both graphs contain nodes "X" and "Y". The per-graph prefixes
        (`graph_1/`, `graph_2/`) keep them apart, and each "Y" still points
        at its own graph's inputs.
        """
        first = text_format.Parse(
            """
                node {
                  name: "X"
                  op: "Input"
                }
                node {
                  name: "W"
                  op: "Input"
                }
                node {
                  name: "Y"
                  op: "MatMul"
                  input: "X"
                  input: "W"
                }
                versions {
                  producer: 21
                }
            """,
            GraphDef(),
        )
        second = text_format.Parse(
            """
                node {
                  name: "X"
                  op: "Input"
                }
                node {
                  name: "A"
                  op: "Input"
                }
                node {
                  name: "Y"
                  op: "MatMul"
                  input: "X"
                  input: "A"
                }
                versions {
                  producer: 21
                }
            """,
            GraphDef(),
        )

        merged = graph_util.merge_graph_defs([first, second])

        expected = """
            node {
              name: "graph_1/X"
              op: "Input"
            }
            node {
              name: "graph_1/W"
              op: "Input"
            }
            node {
              name: "graph_1/Y"
              op: "MatMul"
              input: "graph_1/X"
              input: "graph_1/W"
            }
            node {
              name: "graph_2/X"
              op: "Input"
            }
            node {
              name: "graph_2/A"
              op: "Input"
            }
            node {
              name: "graph_2/Y"
              op: "MatMul"
              input: "graph_2/X"
              input: "graph_2/A"
            }
            versions {
              producer: 21
            }
        """
        self.assertProtoEquals(expected, merged)