Example #1
0
    def test_parse_tensor(self):
        """Check that parse.parse_attr converts int32 tensor attrs to MIL values.

        Covers two cases:
          * a rank-0 tensor with no shape and no values set, which is expected
            to parse to an ``mil_types.int32`` scalar holding 0;
          * a rank-3 tensor of shape (1, 2, 3) carrying the int values
            [55, 56, 57], which is expected to parse to a MIL tensor whose
            primitive type is int32.
        """
        # Zero-rank tensor: only dtype/version are set, no shape, no int_val.
        attr = attr_value.AttrValue()
        attr.tensor.version_number = 1
        attr.tensor.dtype = types.DataType.DT_INT32
        t = parse.parse_attr(attr)
        # assertIsInstance reports the offending type on failure, unlike
        # assertTrue(isinstance(...)) which only reports "False is not true".
        self.assertIsInstance(t, mil_types.int32)
        self.assertEqual(0, t.val)

        # Non-zero rank: shape (1, 2, 3) built from named dims.
        attr = attr_value.AttrValue()
        attr.tensor.version_number = 1
        attr.tensor.dtype = types.DataType.DT_INT32
        shaped_attr = self._attr_with_shape([(1, "outer"), (2, "middle"),
                                             (3, "inner")])
        attr.tensor.tensor_shape.dim.extend(shaped_attr.shape.dim)
        attr.tensor.int_val.extend([55, 56, 57])

        t = parse.parse_attr(attr)
        self.assertEqual([55, 56, 57], t.val.tolist())
        self.assertEqual("tensor", mil_types.get_type_info(t).name)

        # Note that the result of t.get_primitive() is a function that returns a type
        # rather than an instance of that type as it is when the tensor has rank zero.
        self.assertIsInstance(t.get_primitive()(), mil_types.int32)
        self.assertEqual((1, 2, 3), t.get_shape())
Example #2
0
    def visit(self, graph, node, nodename_prefix=""):
        """Append DOT statements for ``node`` and, recursively, its data inputs.

        For a node not yet in ``self.visited_memo`` this appends, in order, to
        ``self.result``:
          * one node statement (with fill colour if the node name is a key of
            ``self.highlights``, and font colour "violetred" when
            ``node.attr[self.annotation]`` is truthy, else "black");
          * one solid edge per entry of ``node.inputs``;
          * one dotted edge per entry of ``node.control_inputs``.
        The node is then recorded in ``self.visited_memo`` and each data input
        (not control input) is visited recursively.

        Args:
            graph: Mapping-like object from node name to node; supports ``in``
                and ``[]``. Every (bare) input/control-input name of ``node``
                must be present.
            node: Node to render. Reads ``name``, ``op``, ``attr`` (dict-like),
                ``datatype``, ``inputs``, ``control_inputs`` and ``outputs``.
            nodename_prefix: String prepended to every node name written to the
                DOT output; passed unchanged to the recursive calls.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            KeyError: If an input or control input of ``node`` is not in
                ``graph``.
        """
        if node.name in self.visited_memo:
            return self

        def bare_name(name):
            # A leading "^" (presumably a TensorFlow-style control-dependency
            # marker) is not part of the graph key. Strip it once so the DOT
            # edges, the existence check and the recursion all agree; the
            # original code stripped it only when recursing, after an
            # unstripped lookup that would already have failed.
            return name[1:] if name.startswith("^") else name

        # The datatype is only used for display; a symbolic datatype takes
        # precedence over the concrete one.
        symbolic_dtype = node.attr.get("symbolic_datatype", None)
        if symbolic_dtype is not None:
            dtype = str(types.get_type_info(symbolic_dtype))
        elif node.datatype is not None:
            dtype = str(types.get_type_info(node.datatype))
        else:
            dtype = "Unknown"

        if self.alternate_labeller is not None:
            label = self.alternate_labeller(node)
        else:
            # "\\n" is a literal backslash-n in the emitted text, i.e. a line
            # break escape inside a Graphviz label. Op-specific suffixes win
            # over the generic "{name}" suffix used for placeholders,
            # function entries and nodes without outputs.
            if node.op == "while":
                suffix = ("\\n{body: " + node.attr["body_function"] +
                          " cond:" + node.attr["cond_function"] + "}")
            elif node.op == "function":
                suffix = "\\n{body: " + node.attr["function_name"] + "}"
            elif (node.op == "function_entry" or "Placeholder" in node.op or
                  len(node.outputs) == 0):
                suffix = "\\n{" + node.name + "}"
            else:
                suffix = ""
            label = node.op + ":" + dtype + suffix

        fontcolor = ("violetred" if node.attr.get(self.annotation, False)
                     else "black")
        quoted_name = '"' + nodename_prefix + node.name + '"'
        if node.name in self.highlights:
            node_attrs = ',fillcolor=%s,style=filled,fontcolor=%s]' % (
                self.highlights[node.name], fontcolor)
        else:
            node_attrs = ',fontcolor=%s]' % fontcolor
        self.result.append(quoted_name + '[label="' + label + '"' + node_attrs)

        input_names = [bare_name(name) for name in node.inputs]
        control_names = [bare_name(name) for name in node.control_inputs]

        # Data edges are solid, control edges dotted. Both endpoints must be
        # known to the graph; a missing input raises KeyError.
        for input_name in input_names:
            if input_name not in graph:
                raise KeyError(input_name)
            self.result.append(
                '"' + nodename_prefix + input_name + '" -> ' + quoted_name)

        for input_name in control_names:
            if input_name not in graph:
                raise KeyError(input_name)
            self.result.append('"' + nodename_prefix + input_name + '" -> ' +
                               quoted_name + " [style=dotted]")

        # Mark before recursing so that cycles in the graph terminate.
        self.visited_memo[node.name] = 1

        # Only data inputs are traversed; control inputs get edges but are
        # not visited from here.
        for input_name in input_names:
            self.visit(graph, graph[input_name], nodename_prefix)
        return self