Exemple #1
0
def annotate(components, lazy=True):
    """Validate every component and annotate it with shape information.

    ``components`` is converted with ``list_to_dict``. For each component, a
    validator is obtained from ``validation.get``, and the schema for the
    component's type is loaded once and cached. The component is validated
    against the schema (when the schema is not None). Then
    ``meta["input_shape"]`` and ``meta["shape"]`` are filled in. When ``lazy``
    is False, ``eval_input_channels`` and ``validate_connection`` are also run.

    Args:
        components: collection of component dicts accepted by ``list_to_dict``.
        lazy: if True, skip input-channel evaluation and connection validation.

    Returns:
        list: the annotated components (values of the ``list_to_dict`` result).

    Raises:
        Exception: if any step fails for a component. The message names the
            component and includes the original error, which is chained as
            ``__cause__``.
    """
    schemas = {}  # component type -> loaded schema; a None schema skips validation
    components = list_to_dict(components)
    for name, component in components.items():
        try:
            component_type = component["meta"]["type"]
            validator = validation.get(component)
            if component_type not in schemas:
                schemas[component_type] = validator.load_schema(component_type)
            if schemas[component_type] is not None:
                validator.validate_schema(component, schemas[component_type])

            component["meta"]["input_shape"] = validator.get_input_shape(component, components)
            component["meta"]["shape"] = validator.get_shape(component, components)
            if not lazy:
                validator.eval_input_channels(component, components)
                validator.validate_connection(component, components)

            # Fail early (KeyError) if the component has no "inputs" entry.
            if "inputs" not in component["meta"]:
                raise KeyError("inputs")
        except Exception as err:
            msg = "Component %s annotation failed. \n" % (name)
            msg += str(err)
            logger.error(msg)
            raise Exception(msg) from err
    return list(components.values())
Exemple #2
0
 def validate_connection(self, component, components):
     """Check that the component's tensor dimensions fit one of its transformations.

     ``self.get_transformation()`` returns None or an iterable of candidate
     transformations ``trans``:
       * ``trans[0]`` holds one symbol list per input;
       * ``trans[1]`` is the symbol list of the output.

     ``self.get_dimensions(component, components)`` returns ``dims`` with the
     same layout:
       * ``dims[0]`` holds one shape per input;
       * ``dims[1]`` is the output shape.

     A candidate matches when all of the following hold:
       * it declares as many inputs as there are input shapes;
       * every input symbol list has the same length (rank) as its shape;
       * the output symbol list has the same rank as the output shape;
       * any two positions carrying the same symbol have equal sizes.
         Sizes that are None are treated as unknown and are not compared.

     Args:
         component: component dict with ``meta.name`` and ``meta.type``.
         components: all components, forwarded to ``get_dimensions``.

     Returns:
         True when there is nothing to validate (no transformations, or the
         output shape is None), or when at least one candidate matches.

     Raises:
         Exception: if no candidate matches. The message lists why each
             candidate was rejected.
     """
     transformations = self.get_transformation()
     dims = self.get_dimensions(component, components)
     if transformations is None or dims[1] is None:
         # No constraints declared, or the output shape is not known yet.
         return True

     def _describe():
         # Arguments for mismatch messages. Evaluated lazily so that "meta"
         # is only read when a candidate actually fails.
         return (component["meta"]["name"], component["meta"]["type"], str(dims))

     def _check(trans):
         # Raise an Exception describing why ``trans`` does not fit ``dims``.
         if len(trans[0]) != len(dims[0]):
             raise Exception(
                 "Number of inputs mismatch %s (ingredient: %s, dim %s)" % _describe())
         input_ranks_match = all(
             len(symbols) == len(shape) for symbols, shape in zip(trans[0], dims[0]))
         if not input_ranks_match or len(trans[1]) != len(dims[1]):
             raise Exception(
                 "Tensor rank mismatch %s (ingredient: %s, dim: %s)" % _describe())
         symbols = [s for sub in trans[0] for s in sub] + list(trans[1])
         sizes = [d for sub in dims[0] for d in sub] + list(dims[1])
         # Every pair of positions sharing a symbol must share a (known) size.
         for a in range(len(symbols) - 1):
             for b in range(a + 1, len(symbols)):
                 if symbols[a] != symbols[b]:
                     continue
                 if sizes[a] is None or sizes[b] is None:
                     continue
                 if sizes[a] != sizes[b]:
                     raise Exception(
                         "Tensor dimension mismatch %s (ingredient: %s, dim: %s)"
                         % _describe())

     msg = ""
     for trans in transformations:
         try:
             _check(trans)
         except Exception as err:
             msg += "%s\n " % str(err)
             continue
         return True

     logger.error("Validation failed %s, %s" %
                  (component["meta"]["name"], msg))
     logger.error("Dimension %s" % dims)
     raise Exception("Validation failed %s, %s" % (component["meta"]["name"], msg))
Exemple #3
0
 def eval_input_channels(self, component, components):
     """Write the last dimension of the first input into the ``input_shape`` hyperparameter.

     The input shapes come from ``self.get_input_shape``. When they are None,
     or their first entry is None, nothing is written. Otherwise the last
     entry of the first input shape (presumably its channel count) is passed
     to ``util.replace_key`` under the ``"input_shape"`` key of
     ``component["value"]["hyperparams"]``.

     This is best effort: any error while doing so is logged and not re-raised.
     """
     input_shapes = self.get_input_shape(component, components)
     try:
         if input_shapes is None or input_shapes[0] is None:
             return
         hyperparams = component["value"]["hyperparams"]
         util.replace_key(hyperparams, "input_shape", input_shapes[0][-1])
     except Exception as exc:
         logger.error(exc)
         meta = component["meta"]
         logger.error(
             "Ingredient %s (type: %s) invalid inputs" % (meta["name"], meta["type"]))
Exemple #4
0
def connect_to(components, current_component_type, current_component_hyperparams=None, range=None, types=None, scope="all"):
    """Collect references to the earlier components the current one connects to.

    The current component is ``components[-1]``. Candidates are taken from
    ``components`` sliced with the bounds parsed from ``range``. The default
    slice is ``components[-2:-1]``, i.e. only the component just before the
    current one. Candidates are then filtered by ``types`` and ``scope``.

    Args:
        components: list of component dicts; the last one is the current component.
        current_component_type: type of the current component (not used here).
        current_component_hyperparams: hyperparams dict of the current component.
            Defaults to ``components[-1]["value"]["hyperparams"]``.
        range: spec parsed by ``_parse_range`` into ``(lower, upper)`` slice
            bounds. An empty-string bound leaves that side of the slice open.
        types: mapping from accepted component type to the key appended to the
            connection name. ``None`` accepts every type, but then no key can be
            looked up, so any match raises ``TypeError``.
        scope: ``"all"``, or ``"current"`` to keep only components in the same
            name scope as the current component.

    Returns:
        list of ``"<component name>/<types[component type]>"`` strings.

    Side effects:
        ``delete_key`` removes ``"connect_to"`` from the hyperparams of every
        sliced candidate (matching or not) and of the current component.

    Raises:
        Exception: if ``scope`` is neither ``"all"`` nor ``"current"``.
        TypeError: if a component matches while ``types`` is None.
    """

    def _get_name_scope(_component):
        # Everything before the last "/" of the component name.
        return "/".join(_component["meta"]["name"].split("/")[:-1])

    if scope not in ("all", "current"):
        logger.error("Scope has to be 'current' or 'all'!")
        raise Exception("Scope has to be 'current' or 'all'!")
    if current_component_hyperparams is None:
        current_component_hyperparams = components[-1]["value"]["hyperparams"]
    current_name_scope = _get_name_scope(components[-1])
    if range is not None:
        lower, upper = _parse_range(range)
    else:
        lower, upper = -2, -1

    # An empty-string bound means "open on that side" of the slice.
    start = None if lower == "" else lower
    stop = None if upper == "" else upper
    candidates = components[start:stop]

    connected_component = []
    for component in candidates:
        component_type = component["meta"]["type"]
        type_match = types is None or component_type in types
        scope_match = scope != "current" or _get_name_scope(component) == current_name_scope
        if type_match and scope_match:
            if types is None:
                raise TypeError(
                    "connect_to: 'types' must map component types to hyperparameter "
                    "keys, got None")
            connected_component.append("%s/%s" % (component["meta"]["name"],
                                                  types[component_type]))
        delete_key(component["value"]["hyperparams"], "connect_to")
    delete_key(current_component_hyperparams, "connect_to")
    return connected_component
Exemple #5
0
def draw_graph(graph_list, level=2, graph_path='example.png', show="name"):
    """Render a component list as a left-to-right PNG graph with pydot.

    Each component name is truncated to its first ``level`` "/"-separated
    parts. Consecutive components sharing the same truncated name are
    collapsed, and only the first of each run is drawn, as a box node.

    For components whose type is not in ``registry.DATA_BLOCK``, an edge is
    added from each input to the node. Input names are truncated the same way,
    except for components whose type is in ``registry.REGULARIZERS``.

    Args:
        graph_list: iterable of component dicts.
        level: number of leading name-scope parts kept in node names.
        graph_path: path of the PNG file to write.
        show: ``"name"`` to label nodes with the last part of the truncated
            name. Otherwise, the key in ``component["meta"]`` whose value is used
            as label (``"type"`` is rejected).

    Returns:
        The ``pydot.Dot`` graph that was written to ``graph_path``.

    Raises:
        Exception: if ``show`` is not usable as a label.
    """
    graph = pydot.Dot(graph_type='digraph', rankdir='LR', dpi=800, size=5, fontsize=18)
    prev = None
    for component in graph_list:
        meta = component[const.META]
        name = "/".join(meta[const.NAME].split("/")[:level])
        # Collapse runs of consecutive components with the same truncated name.
        # The first component is always drawn (prev starts as None).
        if name == prev:
            continue
        prev = name

        node = pydot.Node(name, shape="box")
        if show == "name":
            label = name.split("/")[-1]
        else:
            if show not in component["meta"] or show == "type":
                msg = "Label '%s' is not available" % show
                logger.error(msg)
                raise Exception(msg)
            label = str(component["meta"][show])
        node.set_label(label)
        graph.add_node(node)

        _type = meta[const.TYPE]
        if _type not in registry.DATA_BLOCK:
            for _input in meta[const.INPUTS]:
                if _type not in registry.REGULARIZERS:
                    _input = "/".join(_input.split("/")[:level])
                graph.add_edge(pydot.Edge(_input, name))

    graph.write_png(graph_path)
    return graph
Exemple #6
0
def parse(yml_file, return_dict=False, lazy=True):
    """Parse a DSL configuration into an annotated list of components.

    Args:
        yml_file: path to a YAML file (read with ``util.read_yaml``), or a dict.
            A dict is deep-copied, so the caller's dict is not modified.
        return_dict: if True, return the graph converted with ``list_to_dict``.
        lazy: forwarded to ``annotate``.

    Returns:
        The annotated graph as a list of component dicts. When ``return_dict``
        is True, the ``list_to_dict`` conversion of that list is returned instead.

    Raises:
        Exception: for any failure while parsing, including an unsupported
            ``yml_file`` type. The message is the original error's message,
            and the original error is chained as ``__cause__``.
    """
    global global_count0, global_count1, component_count, prev_component
    # Reset module-level parser state before every parse.
    global_count0 = 0
    global_count1 = 0
    component_count = 0
    prev_component = None
    try:
        # Read the yml file
        if isinstance(yml_file, str):
            user_defined = util.read_yaml(yml_file)
        elif isinstance(yml_file, dict):
            user_defined = deepcopy(yml_file)
        else:
            msg = "First argument has to be a path or a dict type"
            logger.error(msg)
            raise TypeError(msg)
        reader = alex_reader(user_defined)
        # Step 1: transform yml into a (recursive) dict
        # Each node in the dict has the following keys:
        # hyperparams: list or dict, default: {}
        # type: string
        # inputs: list, default: None
        # name: string, default: None
        # (optional) dtype: data type
        # (optional) shape: list
        _graph:dict = config_to_inter_graph(user_defined, reader)
        # Step 2: transform _graph into an intermediate ast, where each key is a function
        # name
        # inputs
        # repeat
        # e.g. ("inputs", inputs_list, ("name", name_str, ("repeat", n, [])))
        ast:list = inter_graph_to_ast(_graph)

        # Step 3: first pass
        # A list of dicts
        # The reason of having a list of dicts as the main underlying data structure is due to its simplicity. This data structure can be transformed into the final ast
        # Each dict has the following structure:
        # "meta": {"inputs": None or list,
        #          "name": string (scope/.../name),
        #          "scope": the scope of the node,
        #          "type": type of the deep learning component}
        # "value": {"hyperparams": dict of hyperparameters,
        #           "stats": {}
        #           "tensor": library specific tensor object (not easily serializable),
        #           "value": a serializable object that contains the same info as tensor,
        #           "var": }
        graph:list = eval_ast_to_graph(ast)
        # Step 4: second pass: same data structure; populate inputs
        graph:list = global_update(graph)
        graph = annotate(graph, lazy=lazy) # json schema
        # Step 5: load graph and states from checkpoint; check if the structure matches the one defined in the DSL
    except Exception as err:
        # Wrap yml_file in a tuple so %-formatting cannot fail (e.g. when
        # yml_file is itself a tuple) and mask the original error.
        msg = "Error during parsing configuration %s" % (yml_file,)
        msg += "\n %s" % str(err)
        logger.error(msg)
        raise Exception(str(err)) from err
    if return_dict:
        graph = list_to_dict(graph)
    return graph
Exemple #7
0
    def _draw_hyperparam_tree(_graph,
                              parent_node: str,
                              graph):
        """Recursively add the nodes of a hyperparameter subtree to ``graph``.

        ``_graph`` is first passed through ``clone``. Then:

        * A "simple" value (per ``_is_simple``) becomes a single child node of
          ``parent_node``, labelled ``str(_graph)``.
        * Otherwise ``_graph`` is iterated (dict keys or list elements). Each
          key or element must be simple; if not, it is pretty-printed and a
          bare ``Exception`` is raised.
        * A node named ``<parent_node>/<label>`` is added for each key or
          element. List elements get an ``_<index>`` suffix.
        * For a dict entry whose value is a component (``_is_component``):
          - if its ``["meta"]["visible"]`` is falsy, the entry is skipped
            entirely (no node, no children);
          - otherwise its node is labelled with the value's ``const.TYPE``
            entry.
        * For dict entries, the value is handled by kind:
          - a simple value becomes a leaf under the key's node;
          - a list value adds one leaf per element;
          - a dict value is recursed into.
        * When recursing into a dict value that is a component, the function
          descends into:
          - ``const.SUBGRAPH`` when ``_is_recipe`` holds;
          - ``const.HYPERPARAMS`` when ``recipe_defined`` holds for its
            ``const.TYPE``.
          A non-component dict is recursed into directly.

        Args:
            _graph: hyperparameter subtree (simple value, dict or list).
            parent_node: name of the node this subtree is attached to.
            graph: graph object passed to, and returned by, ``_add_node``.

        Returns:
            The graph returned by the last ``_add_node`` call or recursion.
            ``graph`` is returned unchanged if nothing was added.
        """

        _graph = clone(_graph)
        if _is_simple(_graph):
            # Leaf value: attach it directly below the parent node.
            graph = _add_node(parent_node=parent_node,
                              label=str(_graph),
                              name="%s/%s" % (parent_node, str(_graph)),
                              graph=graph)
        else:
            # Container: iterate over dict keys or list elements.
            for i, _node in enumerate(_graph):
                # NOTE(review): the string literal below is a no-op expression
                # statement evaluated on every iteration; it only acts as a comment.
                """ The assumption here is that _graph can be either a dict
                        or hyperparams and hyperparams have a very simple
                        structure, which means that _node is either a string
                        or a scalar. If _graph is a list and _node is a complex
                        data structure, it is not handled.
                """

                if _is_simple(_node):
                    label = str(_node) # node is simple
                else:
                    pprint(_node)
                    logger.error("Unrecognized node type")
                    raise Exception

                name = "%s/%s" % (parent_node, label)
                if isinstance(_graph, list):
                    # Append the index, presumably so repeated list elements
                    # still get distinct node names.
                    name = "%s_%i" % (name, i)

                if isinstance(_graph, dict) and _is_component(_graph[_node]):
                    # Hidden components are skipped together with everything
                    # below them; visible ones are labelled by their type.
                    # FIXME:
                    if not _graph[_node]["meta"]["visible"]:
                        continue
                    label = clone(_graph[_node][const.TYPE])

                graph = _add_node(parent_node, label, name, graph)

                if isinstance(_graph, dict):
                    # From here on, _node is the dict VALUE (it was the key above).
                    _node = clone(_graph[_node])
                    if _is_simple(_node):
                        graph = _add_node(parent_node=name,
                                          label=str(_node),
                                          name="%s/%s" % (name, str(_node)),
                                          graph=graph)
                    elif isinstance(_node, list):
                        # One leaf per list element, suffixed with its index.
                        for ii, __node in enumerate(_node):
                            graph = _add_node(parent_node=name,
                                              label=str(__node),
                                              name="%s/%s_%i" % (name, str(__node), ii),
                                              graph=graph)


                if isinstance(_node, dict):
                    # component or hyperparams
                    if _is_component(_node): # is a component (ingredient or recipe)
                        if _is_recipe(_node):
                            # Recipes: descend into their sub-graph.
                            graph = _draw_hyperparam_tree(_node[const.SUBGRAPH],
                                                          parent_node=name,
                                                          graph=graph)
                        # go in the hyperparam tree
                        if recipe_defined(_node[const.TYPE]):
                            graph = _draw_hyperparam_tree(_node[const.HYPERPARAMS],
                                                          parent_node=name,
                                                          graph=graph)
                    else: # in the hyperparam tree
                        graph = _draw_hyperparam_tree(_node,
                                                      parent_node=name,
                                                      graph=graph)

        return graph