Beispiel #1
0
def diff(network_config_1,
         network_config_2,
         render_to=None,
         dpi=800, lazy=False, merge_blocks=None):
    """Compute the edit cost and edit operations between two networks.

    Both configs are split by the ``["meta"]["block"]`` value of their nodes.
    Nodes belonging to any block in ``merge_blocks`` are pooled and diffed as
    one group. Every other block in ``registry.BLOCKS`` is diffed on its own
    via ``_diff_graph_list``. Costs are summed and operations concatenated.

    Args:
        network_config_1: First network. A ``list`` is used as an
            already-parsed graph list; anything else is passed to
            ``dsl_parser.parse``.
        network_config_2: Second network, same forms as the first.
        render_to: ``None`` (default) skips drawing. ``""`` shows the
            side-by-side annotated trees on screen. Any other value is the
            output path handed to ``fig.savefig``.
        dpi: Resolution used by ``core.draw`` and the matplotlib figure.
        lazy: Forwarded to ``dsl_parser.parse`` when a config is parsed.
        merge_blocks: Block names whose nodes are diffed together as a
            single group. Defaults to no merged blocks.

    Returns:
        Tuple ``(cost, operations)``: the summed edit cost and the
        concatenated list of edit operations.
    """
    # ``None`` sentinel instead of a mutable ``[]`` default shared across calls.
    if merge_blocks is None:
        merge_blocks = []

    def _in_block(graph_list, block):
        # Nodes of ``graph_list`` whose metadata assigns them to ``block``.
        return [node for node in graph_list if node["meta"]["block"] == block]

    # Lists are taken as already-parsed graphs; anything else is parsed.
    if not isinstance(network_config_1, list):
        graph_list1 = dsl_parser.parse(network_config_1, lazy=lazy)
    else:
        graph_list1 = network_config_1
    if not isinstance(network_config_2, list):
        graph_list2 = dsl_parser.parse(network_config_2, lazy=lazy)
    else:
        graph_list2 = network_config_2
    tree_full1 = param_count.ParamCount(graph_list1).annotate_tree()
    tree_full2 = param_count.ParamCount(graph_list2).annotate_tree()

    operations = []
    cost = 0

    # Pool all merged blocks and diff them as one group. This call is made
    # even when ``merge_blocks`` is empty (it then receives two empty lists).
    merged1 = []
    merged2 = []
    for block in merge_blocks:
        merged1 += _in_block(graph_list1, block)
        merged2 += _in_block(graph_list2, block)
    _cost, _operations = _diff_graph_list(merged1,
                                          merged2,
                                          tree_full1, tree_full2)
    cost += _cost
    operations += _operations

    # Every remaining registered block is diffed independently.
    for block in registry.BLOCKS:
        if block in merge_blocks:
            continue
        _cost, _operations = _diff_graph_list(_in_block(graph_list1, block),
                                              _in_block(graph_list2, block),
                                              tree_full1, tree_full2)
        cost += _cost
        operations += _operations

    if render_to is not None:
        print("Done computing diff. Rendering image")
        annotation = annotate_ops(operations)
        img1 = core.draw(tree_full1, None, annotation[1], dpi=dpi)
        img2 = core.draw(tree_full2, None, annotation[2], dpi=dpi)
        fig, axs = plt.subplots(1, 2, dpi=dpi)
        axs[0].imshow(img1)
        axs[0].axis("off")
        axs[1].imshow(img2)
        axs[1].axis("off")
        fig.tight_layout()
        if render_to != "":
            fig.savefig(render_to, dpi="figure")
            # Release the figure; otherwise pyplot keeps every saved diff open.
            plt.close(fig)
            print("Diff images written to: %s" % render_to)
        else:
            print("Diff images rendered to screen")
            plt.show()

    return cost, operations
Beispiel #2
0
 def test_annotate_complexity(self):
     """Annotate the config's tree with complexity values and draw it.

     The image goes to ``complexity.png`` under ``const.CACHE_BASE_PATH``.
     Nodes of type ``hyperparam`` are excluded.
     """
     measure = complexity_measure.ComplexityMeasure(self.config_path)
     tree = measure.annotate_tree()
     output_path = os.path.join(const.CACHE_BASE_PATH, 'complexity.png')
     core.draw(tree,
               graph_path=output_path,
               label_field="complexity",
               excluded_types=["hyperparam"])
Beispiel #3
0
 def test_annotate_trainbale_params(self):
     """Draw the param-count-annotated tree twice into the cache directory.

     The first image is labelled by parameter count, the second by shape.
     Nodes of type ``hyperparam`` are excluded from both.
     """
     tree = param_count.ParamCount(self.config_path).annotate_tree()
     renders = (('trainable_param_count.png', "count"),
                ('shape.png', "shape"))
     for file_name, field in renders:
         core.draw(tree,
                   graph_path=os.path.join(const.CACHE_BASE_PATH, file_name),
                   label_field=field,
                   excluded_types=["hyperparam"])
Beispiel #4
0
 def test_ted_on_conv(self):
     """Check that the conv fixture and its modified copy are 7 edits apart.

     Both trees are then drawn with their edit annotations.
     """
     trees = (core.json_to_tree(self.conv, "conv"),
              core.json_to_tree(self.modified_conv, "conv"))
     costs, operations = compare.ted(*trees)
     self.assertEqual(costs, 7)
     annotation = compare.annotate_ops(operations)
     file_names = ("test_compare_conv1.png", "test_compare_conv2.png")
     # annotation is keyed 1 / 2 for the first / second tree.
     for side, (tree, file_name) in enumerate(zip(trees, file_names), start=1):
         core.draw(tree,
                   os.path.join(const.CACHE_BASE_PATH, file_name),
                   annotation[side])
Beispiel #5
0
 def test_ted(self):
     """Check that the two toy trees are 8 edits apart.

     Both trees are then drawn with their edit annotations.
     """
     first = core.json_to_tree(self.toy_json)
     second = core.json_to_tree(self.toy_json2)
     costs, operations = compare.ted(first, second)
     annotation = compare.annotate_ops(operations)
     self.assertEqual(costs, 8)
     outputs = [(first, "test_compare_toytree1.png"),
                (second, "test_compare_toytree2.png")]
     # annotation is keyed 1 / 2 for the first / second tree.
     for index, (tree, file_name) in enumerate(outputs, start=1):
         core.draw(tree,
                   os.path.join(const.CACHE_BASE_PATH, file_name),
                   annotation[index])
Beispiel #6
0
def _dist_graph_list(tree1, tree2, exclude_types=None, render_to=None, dpi=800):
    """Compute the tree edit distance between two trees, optionally drawing it.

    Args:
        tree1: First tree, passed to ``ted`` and ``core.draw``.
        tree2: Second tree, same form as the first.
        exclude_types: Currently unused by this function and kept only for
            interface compatibility. It defaults to ``None`` rather than a
            shared mutable list.
        render_to: ``None`` (default) skips drawing. ``""`` shows the
            side-by-side annotated trees on screen. Any other value is the
            output path handed to ``fig.savefig``.
        dpi: Resolution used by ``core.draw`` and the matplotlib figure.

    Returns:
        ``(cost, operations)`` exactly as returned by ``ted``.
    """
    cost, operations = ted(tree1, tree2)
    if render_to is not None:
        print("Done computing distance. Rendering image")
        annotation = annotate_ops(operations)
        img1 = core.draw(tree1, None, annotation[1], dpi=dpi)
        img2 = core.draw(tree2, None, annotation[2], dpi=dpi)
        fig, axs = plt.subplots(1, 2, dpi=dpi)
        axs[0].imshow(img1)
        axs[0].axis("off")
        axs[1].imshow(img2)
        axs[1].axis("off")
        fig.tight_layout()
        if render_to != "":
            fig.savefig(render_to, dpi="figure")
            # Release the figure; otherwise pyplot keeps every saved plot open.
            plt.close(fig)
            print("Distance images written to: %s" % render_to)
        else:
            print("Distance images rendered to screen")
            plt.show()
    return cost, operations
Beispiel #7
0
def matched_ingredients(network_config_1,
                        network_config_2,
                        render_to=None,
                        dpi=800):
    """Collect the parameter-bearing nodes that ``diff`` matched between two networks.

    ``diff`` is run with ``model_block`` and ``loss_block`` merged into one
    group. From the resulting operations, only those tagged ``"MATCH"`` are
    kept, and only when their label maps (via ``label_to_value``) to a type
    listed in ``registry.PARAMS``.

    Args:
        network_config_1: First network config, passed to both
            ``param_count.ParamCount`` and ``diff``.
        network_config_2: Second network config, same forms as the first.
        render_to: ``None`` (default) skips drawing. ``""`` shows the
            side-by-side annotated trees on screen. Any other value is the
            output path handed to ``fig.savefig``.
        dpi: Resolution used by ``core.draw`` and the matplotlib figure.

    Returns:
        dict mapping ``operation[0][0]`` to ``operation[0][1]`` for each kept
        operation. These are presumably the node in the first network and its
        counterpart in the second (TODO confirm against the operation format
        produced by ``diff``). If a key repeats, the last match wins.
    """
    # Trees annotated with naive=False; they are only used for drawing below.
    tree1 = param_count.ParamCount(network_config_1,
                                   naive=False).annotate_tree()
    tree2 = param_count.ParamCount(network_config_2,
                                   naive=False).annotate_tree()
    _, operations = diff(network_config_1, network_config_2,
                         merge_blocks=["model_block", "loss_block"])
    if render_to is not None:
        print("Done computing diff. Rendering image")
        annotation = annotate_ops(operations)
        img1 = core.draw(tree1, None, annotation[1], dpi=dpi)
        img2 = core.draw(tree2, None, annotation[2], dpi=dpi)
        fig, axs = plt.subplots(1, 2, dpi=dpi)
        axs[0].imshow(img1)
        axs[0].axis("off")
        axs[1].imshow(img2)
        axs[1].axis("off")
        fig.tight_layout()
        # render_to is known to be non-None here; only "" means "show".
        if render_to != "":
            fig.savefig(render_to, dpi="figure")
            # Release the figure; otherwise pyplot keeps every saved diff open.
            plt.close(fig)
            print("Diff images written to: %s" % render_to)
        else:
            print("Diff images rendered to screen")
            plt.show()

    matched = {}
    for operation in operations:
        if operation[1][0] != "MATCH":
            continue
        component_type = label_to_value(operation[1][1])
        if component_type in registry.PARAMS:
            # Direct assignment: same result as re-merging a fresh dict each
            # time, without copying the whole mapping on every match.
            matched[operation[0][0]] = operation[0][1]
    return matched