コード例 #1
0
ファイル: plot.py プロジェクト: vishalbelsare/baikal
 def build_output_edges(self, model, outer_port, container):
     """Add a dummy node and an incoming edge for each internal output.

     The keys obtained from ``self.get_innermost_outputs_keys`` are paired
     with ``model._internal_outputs`` via ``safezip2``. For every pair, a
     node named ``<model.name>:<outer_port>`` joined with the output name
     (both joins done by ``make_name``) is added to ``container``, together
     with an edge from the already registered source node to it.

     Args:
         model: Object providing ``name`` and ``_internal_outputs``.
         outer_port: Port identifier used to build the root name and to
             look up the innermost output keys.
         container: Graph-like object exposing ``add_node`` and ``add_edge``.
     """
     prefix = make_name(model.name, outer_port, sep=":")
     innermost_keys = self.get_innermost_outputs_keys(model, outer_port)
     paired = safezip2(innermost_keys, model._internal_outputs)

     # Each key unpacks into (port, inner output); a distinct loop name is
     # used so the ``outer_port`` argument is not rebound while iterating.
     for (port, inner_output), output in paired:
         source = self.node_names[port, inner_output.node]
         target = make_name(prefix, output.name)
         edge_label = output.name
         container.add_node(dummy_dot_node(target))
         # The edge is labelled with the output name; "black" is passed as
         # the last argument of dot_edge (presumably the edge color).
         container.add_edge(dot_edge(source, target, edge_label, "black"))
コード例 #2
0
ファイル: plot.py プロジェクト: vitalyvels/baikal
    def build_dot_from(model_, output_, container=None):
        """Recursively draw the step that produces ``output_`` and its ancestors.

        Relies on names from the enclosing scope: ``dot_graph`` (default
        container), ``nodes_built`` (set of steps already drawn),
        ``expand_nested`` (whether nested models are drawn as clusters) and
        ``build_edge``.

        Args:
            model_: Model whose ``_internal_inputs`` decide which steps are
                drawn as input nodes.
            output_: Object with a ``step`` attribute; that step is drawn.
            container: Graph or cluster receiving the nodes; falls back to
                ``dot_graph`` when ``None``.
        """
        target = dot_graph if container is None else container
        step = output_.step

        # A step is drawn at most once.
        if step in nodes_built:
            return

        if not (isinstance(step, Model) and expand_nested):
            # Regular step: steps feeding the model's internal inputs get a
            # green "invhouse" node, every other step a plain rectangle.
            internal_input_steps = [
                placeholder.step for placeholder in model_._internal_inputs
            ]
            if step in internal_input_steps:
                node = pydot.Node(name=step.name,
                                  shape="invhouse",
                                  color="green")
            else:
                node = pydot.Node(name=step.name, shape="rect")
            target.add_node(node)

            # Incoming edges: when the producer is a nested model that is
            # being expanded, the edge starts at the step behind the matching
            # internal output instead of at the nested model itself.
            for incoming in step.inputs:
                producer = incoming.step
                if isinstance(producer, Model) and expand_nested:
                    position = producer.outputs.index(incoming)
                    producer = producer._internal_outputs[position].step
                build_edge(producer, step, incoming.name, target)

        else:
            # Nested model: draw its internals inside a dashed cluster.
            cluster = pydot.Cluster(name=step.name,
                                    label=step.name,
                                    style="dashed")
            for inner_output in step._internal_outputs:
                build_dot_from(step, inner_output, cluster)
            target.add_subgraph(cluster)

            # Connect the outer producers with the nested model's own inputs.
            for outer_input, inner_input in safezip2(step.inputs,
                                                     step._internal_inputs):
                build_edge(outer_input.step, inner_input.step,
                           outer_input.name, target)

        nodes_built.add(step)

        # Walk upstream.
        # NOTE(review): the container is not forwarded here, so upstream
        # steps are added to dot_graph even when this call targeted a
        # cluster -- confirm this is intended.
        for upstream in step.inputs:
            build_dot_from(model_, upstream)
コード例 #3
0
 def _update_cache(cache, output_data, node):
     """Store ``output_data`` in ``cache`` keyed by ``node.outputs``.

     Args:
         cache: Mapping updated in place with (output, data) pairs.
         output_data: Sized collection of computed values.
         node: Object exposing ``outputs`` and ``step.name``.

     Raises:
         RuntimeError: If pairing the outputs with the data raises
             ``ValueError``; the original error is chained as the cause.
     """
     try:
         pairs = safezip2(node.outputs, output_data)
         cache.update(pairs)
     except ValueError as err:
         template = (
             "The number of output data elements ({}) does not match "
             "the number of {} outputs ({})."
         )
         details = template.format(
             len(output_data), node.step.name, len(node.outputs)
         )
         raise RuntimeError(details) from err
コード例 #4
0
def test_split(x, indices_or_sections, teardown):
    """Check that a ``Split`` model predicts the same pieces as ``np.split``.

    ``teardown`` is not used in the body (presumably a fixture requested for
    its side effects -- confirm against the test configuration).
    """
    placeholder = Input()
    split_outputs = Split(indices_or_sections, axis=0)(placeholder)
    model = Model(placeholder, split_outputs)

    expected_parts = np.split(x, indices_or_sections, axis=0)
    predicted_parts = listify(model.predict(x))

    # safezip2 pairs the predicted pieces with the expected ones.
    for predicted, expected in safezip2(predicted_parts, expected_parts):
        assert_array_equal(predicted, expected)
コード例 #5
0
    def _compute_step(step, Xs, cache):
        """Run ``step.compute`` on ``Xs`` and record the results in ``cache``.

        Args:
            step: Object exposing ``compute``, ``outputs`` and ``name``.
            Xs: Iterable of positional arguments unpacked into ``compute``.
            cache: Mapping updated in place with (output, data) pairs.

        Raises:
            RuntimeError: If pairing ``step.outputs`` with the computed data
                raises ``ValueError``; the original error is chained.
        """
        # TODO: Raise warning if computed output is already in cache.
        # This happens when recomputing a step that had a subset of its
        # outputs already passed in the inputs.
        # TODO: Some regressors have extra options in their predict method,
        # and they return a tuple of arrays.
        # https://scikit-learn.org/stable/glossary.html#term-predict
        results = listify(step.compute(*Xs))

        try:
            cache.update(safezip2(step.outputs, results))
        except ValueError as err:
            template = (
                "The number of output data elements ({}) does not match "
                "the number of {} outputs ({})."
            )
            details = template.format(
                len(results), step.name, len(step.outputs))
            raise RuntimeError(details) from err
コード例 #6
0
    def _normalize_list(
        data: ArrayLikes, data_placeholders: List[DataPlaceholder]
    ) -> Dict[DataPlaceholder, ArrayLike]:
        """Map each placeholder to the array at the same position in ``data``.

        ``data`` is first passed through ``listify`` and the result is then
        paired with ``data_placeholders`` via ``safezip2``.

        Raises:
            ValueError: If the pairing raises ``ValueError``; a message
                reporting both lengths is raised with the original error
                chained as the cause.
        """
        arrays = listify(data)

        try:
            return dict(safezip2(data_placeholders, arrays))
        except ValueError as err:
            # TODO: Improve this message
            template = (
                "When passing inputs/outputs as a list or a single array, "
                "the number of arrays must match the number of inputs/outputs "
                "specified at instantiation. "
                "Got {}, expected: {}."
            )
            details = template.format(len(arrays), len(data_placeholders))
            raise ValueError(details) from err
コード例 #7
0
def test_safezip2(x, y, raises):
    """Check that ``safezip2`` yields the same pairs as the builtin ``zip``.

    ``raises`` is used as a context manager around the whole check
    (presumably an expected-exception context supplied by the test's
    parametrization -- confirm against the decorator).
    """
    with raises:
        actual = list(safezip2(x, y))
        expected = list(zip(x, y))
        assert actual == expected