Esempio n. 1
0
    def test_is_attr_legal_verbose(self):  # type: () -> None
        """Populate one, then two, value fields at random and check the verdict.

        A named attribute with a single populated field must pass
        ``checker.check_attribute``; one with two populated fields must raise
        ``checker.ValidationError``.
        """

        def _scalar_setter(kind, field, payload):
            # Returns a callable that assigns a singular field, then tags the type.
            def _apply(target):
                setattr(target, field, payload)
                target.type = kind
            return _apply

        def _repeated_setter(kind, field, payload):
            # Returns a callable that appends to a repeated field, then tags the type.
            def _apply(target):
                getattr(target, field).extend(payload)
                target.type = kind
            return _apply

        setters = [
            _scalar_setter(AttributeProto.FLOAT, "f", 1.0),
            _scalar_setter(AttributeProto.INT, "i", 1),
            _scalar_setter(AttributeProto.STRING, "s", b"str"),
            _repeated_setter(AttributeProto.FLOATS, "floats", [1.0, 2.0]),
            _repeated_setter(AttributeProto.INTS, "ints", [1, 2]),
            _repeated_setter(AttributeProto.STRINGS, "strings", [b"a", b"b"]),
        ]

        # A single randomly chosen field on its own must validate cleanly.
        for _ in range(100):
            single = AttributeProto()
            single.name = "test"
            populate = random.choice(setters)
            populate(single)
            checker.check_attribute(single)

        # Two distinct randomly chosen fields together must be rejected.
        for _ in range(100):
            double = AttributeProto()
            double.name = "test"
            first, second = random.sample(setters, 2)
            first(double)
            second(double)
            with self.assertRaises(checker.ValidationError):
                checker.check_attribute(double)
def _rename_node_input(onnx_node, old_name, new_name=None):
    """
    Builds a copy of a node in which an input name has been replaced.

    Every attribute named ``'body'`` is rebuilt as a GRAPH attribute around
    the graph returned by *_rename_graph_input*; the other attributes are
    passed through untouched. Outputs are copied as they are.

    :param onnx_node: onnx_node
    :param old_name: old name
    :param new_name: new name or None if *old_name* is a dictionary
    :return: new node
    """
    def _rebuild_body(att):
        # Apply the same renaming inside the subgraph, then wrap the result
        # into a fresh GRAPH attribute carrying the original attribute name.
        renamed_graph = _rename_graph_input(att.g, old_name, new_name)
        rebuilt = AttributeProto()
        rebuilt.name = att.name
        rebuilt.g.CopyFrom(renamed_graph)
        rebuilt.type = AttributeProto.GRAPH
        return rebuilt

    renamed_inputs = [
        _replace(name, old_name, new_name) for name in onnx_node.input]
    kept_outputs = list(onnx_node.output)

    if not hasattr(onnx_node, 'attribute'):
        # NOTE(review): this branch reads the very attribute that hasattr
        # just reported missing, so it can only raise AttributeError --
        # confirm whether an empty attribute list was intended.
        atts = onnx_node.attribute
    else:
        atts = [
            _rebuild_body(att) if att.name == 'body' else att
            for att in onnx_node.attribute]

    return _make_node(onnx_node.op_type,
                      renamed_inputs,
                      kept_outputs,
                      name=onnx_node.name,
                      domain=onnx_node.domain,
                      attributes=atts)
Esempio n. 3
0
 def test_is_attr_legal(self):  # type: () -> None
     """Check that three malformed attributes are rejected by the checker."""

     def _expect_invalid(candidate):
         # The checker must refuse *candidate* with a ValidationError.
         with self.assertRaises(checker.ValidationError):
             checker.check_attribute(candidate)

     # Neither a name nor any value field.
     blank = AttributeProto()
     _expect_invalid(blank)

     # A name, but still no value field.
     named_only = AttributeProto()
     named_only.name = "test"
     _expect_invalid(named_only)

     # A name together with two value fields.
     two_valued = AttributeProto()
     two_valued.name = "test"
     two_valued.f = 1.0
     two_valued.i = 2
     _expect_invalid(two_valued)
Esempio n. 4
0
    def test_is_attr_legal_verbose(self):  # type: () -> None
        """Randomised check that one value field is legal and two are not.

        Each round starts from a fresh attribute named ``"test"``. With one
        field populated ``checker.check_attribute`` must succeed; with two
        distinct fields populated it must raise ``checker.ValidationError``.
        """
        # One row per way of populating an attribute:
        # (attribute type, field name, value, whether the field is repeated)
        field_specs = [
            (AttributeProto.FLOAT, "f", 1.0, False),
            (AttributeProto.INT, "i", 1, False),
            (AttributeProto.STRING, "s", b"str", False),
            (AttributeProto.FLOATS, "floats", [1.0, 2.0], True),
            (AttributeProto.INTS, "ints", [1, 2], True),
            (AttributeProto.STRINGS, "strings", [b"a", b"b"], True),
        ]

        def _populate(target, spec):
            # Write the value first, then tag the attribute with its type.
            kind, field, payload, repeated = spec
            if repeated:
                getattr(target, field).extend(payload)
            else:
                setattr(target, field, payload)
            target.type = kind

        def _named_attribute():
            fresh = AttributeProto()
            fresh.name = "test"
            return fresh

        # Exactly one field: the attribute is expected to be legal.
        for _ in range(100):
            candidate = _named_attribute()
            _populate(candidate, random.choice(field_specs))
            checker.check_attribute(candidate)

        # Two different fields: the checker is expected to complain.
        for _ in range(100):
            candidate = _named_attribute()
            for spec in random.sample(field_specs, 2):
                _populate(candidate, spec)
            self.assertRaises(checker.ValidationError,
                              checker.check_attribute,
                              candidate)
Esempio n. 5
0
    def test_is_attr_legal_verbose(self):
        """Randomly populate one, then two, value fields and check the verdict.

        A named attribute with a single populated field must pass
        ``checker.check_attribute``; one with two populated fields must raise
        ``checker.ValidationError``.
        """

        def _set(attr, attr_type, field, value):
            # Assign a singular field, then record the matching type.
            setattr(attr, field, value)
            attr.type = attr_type

        def _extend(attr, attr_type, field, values):
            # Append to a repeated field, then record the matching type.
            getattr(attr, field).extend(values)
            attr.type = attr_type

        # Explicit helpers replace the former ``a() or b()`` chaining, which
        # only ran the second call because the first one returned None.
        SET_ATTR = [
            (lambda attr: _set(attr, AttributeProto.FLOAT, "f", 1.0)),
            (lambda attr: _set(attr, AttributeProto.INT, "i", 1)),
            (lambda attr: _set(attr, AttributeProto.STRING, "s", b"str")),
            (lambda attr: _extend(attr, AttributeProto.FLOATS, "floats", [1.0, 2.0])),
            (lambda attr: _extend(attr, AttributeProto.INTS, "ints", [1, 2])),
            (lambda attr: _extend(attr, AttributeProto.STRINGS, "strings", [b"a", b"b"])),
        ]
        # Randomly set one field, and the result should be legal.
        for _i in range(100):
            attr = AttributeProto()
            attr.name = "test"
            random.choice(SET_ATTR)(attr)
            checker.check_attribute(attr)
        # Randomly set two fields, and then ensure helper function catches it.
        for _i in range(100):
            attr = AttributeProto()
            attr.name = "test"
            for func in random.sample(SET_ATTR, 2):
                func(attr)
            self.assertRaises(checker.ValidationError,
                              checker.check_attribute,
                              attr)
Esempio n. 6
0
 def test_is_attr_legal(self):  # type: () -> None
     """Build three malformed attributes and require the checker to reject each."""
     # Ordered (field, value) assignments applied to a fresh AttributeProto.
     invalid_setups = [
         # no name, no field
         [],
         # name, but no field
         [("name", "test")],
         # name, with two fields
         [("name", "test"), ("f", 1.0), ("i", 2)],
     ]
     for assignments in invalid_setups:
         attr = AttributeProto()
         for field, value in assignments:
             setattr(attr, field, value)
         self.assertRaises(checker.ValidationError, checker.check_attribute, attr)
Esempio n. 7
0
def make_attribute(key: str, value) -> IAttributeProto:
    """Build an attribute named *key* whose populated field depends on *value*.

    Singular values are tried first, in this order: ``float`` -> ``f``,
    integral number -> ``i``, string/bytes -> ``s``, ``TensorProto`` -> ``t``,
    ``GraphProto`` -> ``g``. Any other iterable is stored in the matching
    repeated field (``floats``, ``ints``, ``strings``, ``tensors`` or
    ``graphs``) when all of its elements share one of those kinds. An empty
    iterable satisfies the first check and yields an empty ``floats`` list.

    :param key: name given to the attribute
    :param value: value to store
    :return: the populated ``AttributeProto``
    :raises ValueError: if *value* is an iterable of unsupported or mixed
        element kinds, or is not a supported type at all
    """
    attr = AttributeProto()
    attr.name = key

    is_iterable = isinstance(value, Iterable)
    bytes_or_false = _to_bytes_or_false(value)

    # First, singular cases.
    # float
    if isinstance(value, float):
        attr.f = value

    # integer
    elif isinstance(value, numbers.Integral):
        attr.i = value

    # string -- compare against False explicitly: an empty string converts
    # to b"", which is falsy but still a legitimate string value.
    # NOTE(review): assumes _to_bytes_or_false returns the literal False for
    # non-string input, as its name suggests -- confirm.
    elif bytes_or_false is not False:
        attr.s = bytes_or_false

    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)

    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)

    # Then, iterable cases.
    elif is_iterable:
        # Materialize once: the checks below walk the elements several
        # times, which would silently exhaust a one-shot iterator.
        values = list(value)
        byte_array = [_to_bytes_or_false(v) for v in values]
        if all(isinstance(v, float) for v in values):
            attr.floats.extend(values)

        elif all(isinstance(v, numbers.Integral) for v in values):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in values)

        elif all(b is not False for b in byte_array):
            attr.strings.extend(byte_array)

        elif all(isinstance(v, TensorProto) for v in values):
            attr.tensors.extend(values)

        elif all(isinstance(v, GraphProto) for v in values):
            attr.graphs.extend(values)

        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise ValueError(
            'Value "{}" is not valid attribute data type.'.format(value))

    return attr
Esempio n. 8
0
def _Attribute_default_value(self):  # type: ignore
    """Parse ``self._default_value`` into a new ``AttributeProto`` and return it."""
    parsed = AttributeProto()
    parsed.ParseFromString(self._default_value)
    return parsed
Esempio n. 9
0
File: helper.py Progetto: zoq/onnx
def make_attribute(
    key,  # type: Text
    value,  # type: Any
    doc_string=None  # type: Optional[Text]
):  # type: (...) -> AttributeProto
    """Makes an AttributeProto based on the value type.

    Singular values are tried first (float, integral number, string/bytes,
    ``TensorProto``, ``SparseTensorProto``, ``GraphProto``); any other
    iterable is stored in the repeated field matching the kind shared by all
    of its elements. ``attr.type`` is set alongside the value. An empty
    iterable satisfies the first check and yields an empty FLOATS attribute.

    :param key: name given to the attribute
    :param value: value to store
    :param doc_string: optional documentation string, stored when truthy
    :return: the populated ``AttributeProto``
    :raises ValueError: if *value* is an iterable whose elements do not all
        share one supported kind
    :raises TypeError: if *value* is neither a supported singular type nor
        an iterable
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Keep the old location as a Python 2 fallback.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    is_iterable = isinstance(value, Iterable)
    bytes_or_false = _to_bytes_or_false(value)
    # First, singular cases
    # float
    if isinstance(value, float):
        attr.f = value
        attr.type = AttributeProto.FLOAT
    # integer
    elif isinstance(value, numbers.Integral):
        attr.i = cast(int, value)
        attr.type = AttributeProto.INT
    # string
    elif bytes_or_false is not False:
        assert isinstance(bytes_or_false, bytes)
        attr.s = bytes_or_false
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    # Then, iterable cases
    elif is_iterable:
        # Materialize once: the checks below walk the elements several
        # times, which would silently exhaust a one-shot iterator.
        values = list(value)
        byte_array = [_to_bytes_or_false(v) for v in values]
        if all(isinstance(v, float) for v in values):
            attr.floats.extend(values)
            attr.type = AttributeProto.FLOATS
        elif all(isinstance(v, numbers.Integral) for v in values):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in values)
            attr.type = AttributeProto.INTS
        elif all(b is not False for b in byte_array):
            attr.strings.extend(cast(List[bytes], byte_array))
            attr.type = AttributeProto.STRINGS
        elif all(isinstance(v, TensorProto) for v in values):
            attr.tensors.extend(values)
            attr.type = AttributeProto.TENSORS
        elif all(isinstance(v, SparseTensorProto) for v in values):
            attr.sparse_tensors.extend(values)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif all(isinstance(v, GraphProto) for v in values):
            attr.graphs.extend(values)
            attr.type = AttributeProto.GRAPHS
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise TypeError(
            'value "{}" is not valid attribute data type.'.format(value))
    return attr
Esempio n. 10
0
def printable_attribute(attr: AttributeProto, subgraphs: bool = False) -> Union[Text, Tuple[Text, List[GraphProto]]]:
    """Render *attr* as space-joined ``name = value`` text.

    Graph-valued attributes are shown by name only; the graph objects are
    collected so that the caller can print them separately.

    :param attr: attribute to render
    :param subgraphs: when true, also return the graphs found in *attr*
    :return: the rendered text, or ``(text, graphs)`` if *subgraphs* is true
    """
    _T = TypeVar('_T')  # noqa

    def _fmt_float(number: float) -> Text:
        # Different Python versions print different numbers of trailing
        # decimals; an explicit precision keeps the output consistent.
        return '{:.15g}'.format(number)

    def _fmt_int(number: int) -> Text:
        return str(number)

    def _fmt_seq(fmt_elem: Callable[[_T], Text], items: Sequence[_T]) -> Text:
        return '[{}]'.format(', '.join(fmt_elem(item) for item in items))

    pieces = [attr.name, "="]
    # Graphs met along the way; handed back to the caller on request so it
    # can print them after this attribute.
    nested_graphs = []

    # Dispatch goes through HasField / non-empty repeated fields rather than
    # through attr.type.
    # NOTE(review): the original remark says this keeps working until a
    # switch to proto3, after which attr.type has to be used -- confirm.
    if attr.HasField("f"):
        pieces.append(_fmt_float(attr.f))
    elif attr.HasField("i"):
        pieces.append(_fmt_int(attr.i))
    elif attr.HasField("s"):
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        pieces.append(repr(_sanitize_str(attr.s)))
    elif attr.HasField("t"):
        if len(attr.t.dims) == 0:
            # A tensor without dims is printed with its stored value.
            storage_field = STORAGE_TENSOR_TYPE_TO_FIELD[attr.t.data_type]
            scalar_repr = str(getattr(attr.t, storage_field))
            pieces.append('<Scalar Tensor {}>'.format(scalar_repr))
        else:
            pieces.append("<Tensor>")
    elif attr.HasField("g"):
        pieces.append("<graph {}>".format(attr.g.name))
        nested_graphs.append(attr.g)
    elif attr.HasField("tp"):
        pieces.append("<Type Proto {}>".format(attr.tp))
    elif attr.floats:
        pieces.append(_fmt_seq(_fmt_float, attr.floats))
    elif attr.ints:
        pieces.append(_fmt_seq(_fmt_int, attr.ints))
    elif attr.strings:
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        pieces.append(str([_sanitize_str(item) for item in attr.strings]))
    elif attr.tensors:
        pieces.append("[<Tensor>, ...]")
    elif attr.type_protos:
        last = len(attr.type_protos) - 1
        pieces.append('[')
        # Every entry but the last one carries a trailing comma.
        pieces.extend(
            '<Type Proto {}>{}'.format(tp, '' if idx == last else ',')
            for idx, tp in enumerate(attr.type_protos))
        pieces.append(']')
    elif attr.graphs:
        last = len(attr.graphs) - 1
        pieces.append('[')
        pieces.extend(
            '<graph {}>{}'.format(graph.name, '' if idx == last else ',')
            for idx, graph in enumerate(attr.graphs))
        pieces.append(']')
        nested_graphs.extend(attr.graphs)
    else:
        pieces.append("<Unknown>")

    text = ' '.join(pieces)
    if subgraphs:
        return text, nested_graphs
    return text
def make_attribute(
    key,  # type: Text
    value,  # type: Any
    dtype=None,  # type: [np.float32, np.float64]
    domain='',  # type: Text
    doc_string=None  # type: Optional[Text]
):  # type: (...) -> AttributeProto
    """Makes an AttributeProto based on the value type.

    When *dtype* equals ``np.float64`` and *domain* is neither ``''`` nor
    ``'ai.onnx.ml'``, Python floats and ``np.float64`` values (singular or
    in an iterable) are stored as a DOUBLE tensor built with *make_tensor*
    instead of the float field(s). ``np.float32`` values always go to the
    float field(s).

    :param key: name given to the attribute
    :param value: value to store
    :param dtype: requested floating point type, see above
    :param domain: operator domain, see above
    :param doc_string: optional documentation string, stored when truthy
    :return: the populated ``AttributeProto``
    :raises ValueError: if no field fits *value*
    """
    def _every(kind):
        # True when each element of *value* is an instance of *kind*.
        return all(isinstance(item, kind) for item in value)

    def _store_double_tensor(dims, data):
        # Keep double precision by wrapping *data* into a DOUBLE tensor.
        result.type = AttributeProto.TENSOR
        result.t.CopyFrom(make_tensor(key, TensorProto.DOUBLE, dims, data))

    result = AttributeProto()
    result.name = key
    if doc_string:
        result.doc_string = doc_string

    iterable = isinstance(value, collections.abc.Iterable)
    as_bytes = _to_bytes_or_false(value)

    wants_double = dtype == np.float64 and domain not in ('', 'ai.onnx.ml')

    # Singular values first.
    if isinstance(value, np.float32):
        result.f = value
        result.type = AttributeProto.FLOAT
    elif isinstance(value, (float, np.float64)):
        if wants_double:
            _store_double_tensor((1, ), [value])
        else:
            result.f = value
            result.type = AttributeProto.FLOAT
    elif isinstance(value, (np.int32, np.int64, numbers.Integral)):
        result.i = value
        result.type = AttributeProto.INT
    elif as_bytes is not False:
        # string
        assert isinstance(as_bytes, bytes)
        result.s = as_bytes
        result.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        result.t.CopyFrom(value)
        result.type = AttributeProto.TENSOR
    elif (SparseTensorProto is not None
          and isinstance(value, SparseTensorProto)):
        result.sparse_tensor.CopyFrom(value)
        result.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        result.g.CopyFrom(value)
        result.type = AttributeProto.GRAPH
    elif iterable:
        # Iterable values: every element has to be of one common kind.
        encoded = [_to_bytes_or_false(item) for item in value]
        if _every(np.float32):
            result.floats.extend(value)
            result.type = AttributeProto.FLOATS
        elif _every(np.float64) or _every(float):
            if wants_double:
                _store_double_tensor((len(value), ), value)
            else:
                result.floats.extend(value)
                result.type = AttributeProto.FLOATS
        elif (_every(np.int32) or _every(np.int64)
              or _every(numbers.Integral)):
            # Turn np.int32/64 into Python built-in int.
            result.ints.extend(int(item) for item in value)
            result.type = AttributeProto.INTS
        elif all(item is not False for item in encoded):
            result.strings.extend(cast(List[bytes], encoded))
            result.type = AttributeProto.STRINGS
        elif _every(TensorProto):
            result.tensors.extend(value)
            result.type = AttributeProto.TENSORS
        elif SparseTensorProto is not None and _every(SparseTensorProto):
            result.sparse_tensors.extend(value)
            result.type = AttributeProto.SPARSE_TENSORS
        elif _every(GraphProto):
            result.graphs.extend(value)
            result.type = AttributeProto.GRAPHS
        else:
            # Report the types of (at most) the first five elements.
            sample_types = [
                type(item) for item, _ in zip(value, range(0, 5))]
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type, key='{}', type={}, dtype={}, "
                "types={}.".format(key, type(value), dtype, sample_types))
    else:
        raise ValueError(
            "Value '{}' is not valid attribute data type for attribute "
            "'{}'.".format(value, key))
    return result
Esempio n. 12
0
def make_attribute(key: str,
                   value: Any,
                   doc_string: Optional[str] = None) -> AttributeProto:
    """Makes an AttributeProto based on the value type.

    Singular values are matched first; any other iterable is stored in the
    repeated field that fits all of its elements. An iterable mixing integral
    and other real numbers is stored as floats.

    :param key: name given to the attribute
    :param value: value to store
    :param doc_string: optional documentation string, stored when truthy
    :return: the populated ``AttributeProto``
    :raises ValueError: if *value* is an iterable whose elements fit no
        repeated field
    :raises TypeError: if *value* is neither a supported singular type nor
        an iterable
    """
    result = AttributeProto()
    result.name = key
    if doc_string:
        result.doc_string = doc_string

    iterable = isinstance(value, collections.abc.Iterable)
    as_bytes = _to_bytes_or_false(value)

    # Singular values: return as soon as one of them matches.
    if isinstance(value, float):
        result.f = value
        result.type = AttributeProto.FLOAT
        return result
    if isinstance(value, numbers.Integral):
        result.i = cast(int, value)
        result.type = AttributeProto.INT
        return result
    if as_bytes is not False:
        assert isinstance(as_bytes, bytes)
        result.s = as_bytes
        result.type = AttributeProto.STRING
        return result
    if isinstance(value, TensorProto):
        result.t.CopyFrom(value)
        result.type = AttributeProto.TENSOR
        return result
    if isinstance(value, SparseTensorProto):
        result.sparse_tensor.CopyFrom(value)
        result.type = AttributeProto.SPARSE_TENSOR
        return result
    if isinstance(value, GraphProto):
        result.g.CopyFrom(value)
        result.type = AttributeProto.GRAPH
        return result
    if isinstance(value, TypeProto):
        result.tp.CopyFrom(value)
        result.type = AttributeProto.TYPE_PROTO
        return result

    if not iterable:
        raise TypeError(f'value "{value}" is not valid attribute data type.')

    # Iterable values: every element has to fit one common kind.
    encoded = [_to_bytes_or_false(item) for item in value]
    if all(isinstance(item, numbers.Integral) for item in value):
        # Turn np.int32/64 into Python built-in int.
        result.ints.extend(int(item) for item in value)
        result.type = AttributeProto.INTS
    elif all(isinstance(item, numbers.Real) for item in value):
        # Since ints and floats are members of Real, this allows a mix of
        # ints and floats (and converts the ints to floats).
        result.floats.extend(float(item) for item in value)
        result.type = AttributeProto.FLOATS
    elif all(item is not False for item in encoded):
        result.strings.extend(cast(List[bytes], encoded))
        result.type = AttributeProto.STRINGS
    elif all(isinstance(item, TensorProto) for item in value):
        result.tensors.extend(value)
        result.type = AttributeProto.TENSORS
    elif all(isinstance(item, SparseTensorProto) for item in value):
        result.sparse_tensors.extend(value)
        result.type = AttributeProto.SPARSE_TENSORS
    elif all(isinstance(item, GraphProto) for item in value):
        result.graphs.extend(value)
        result.type = AttributeProto.GRAPHS
    elif all(isinstance(item, TypeProto) for item in value):
        result.type_protos.extend(value)
        result.type = AttributeProto.TYPE_PROTOS
    else:
        raise ValueError(
            "You passed in an iterable attribute but I cannot figure out "
            "its applicable type.")
    return result
def _make_att_graph(name, new_body):
    """Return a GRAPH attribute called *name* holding a copy of *new_body*."""
    graph_att = AttributeProto()
    graph_att.type = AttributeProto.GRAPH
    graph_att.name = name
    # CopyFrom stores a copy, so *new_body* itself is left untouched.
    graph_att.g.CopyFrom(new_body)
    return graph_att
def make_attribute(
    key,  # type: Text
    value,  # type: Any
    doc_string=None  # type: Optional[Text]
):  # type: (...) -> AttributeProto
    """Makes an AttributeProto based on the value type.

    Singular values are tried first (``np.float32`` / ``np.float64`` /
    ``float``, then ``np.int32`` / ``np.int64`` / any integral number, then
    string/bytes, ``TensorProto``, ``GraphProto``); any other iterable is
    stored in the repeated field matching the kind shared by all of its
    elements. ``attr.type`` is set alongside the value. An empty iterable
    satisfies the first check and yields an empty FLOATS attribute.

    :param key: name given to the attribute
    :param value: value to store
    :param doc_string: optional documentation string, stored when truthy
    :return: the populated ``AttributeProto``
    :raises ValueError: if no field fits *value*
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Keep the old location as a Python 2 fallback.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    is_iterable = isinstance(value, Iterable)
    bytes_or_false = _to_bytes_or_false(value)

    if isinstance(value, (np.float32, np.float64, float)):
        attr.f = value
        attr.type = AttributeProto.FLOAT
    elif isinstance(value, (np.int32, np.int64, numbers.Integral)):
        # np.int64 is tagged INT like every other integer: AttributeProto
        # has no INT64 attribute type.
        attr.i = value
        attr.type = AttributeProto.INT
    # Compare against False explicitly: an empty string converts to b"",
    # which is falsy but still a legitimate string value.
    elif bytes_or_false is not False:
        assert isinstance(bytes_or_false, bytes)
        attr.s = bytes_or_false
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    # Then, iterable cases
    elif is_iterable:
        # Materialize once: the checks below walk the elements several
        # times, which would silently exhaust a one-shot iterator.
        values = list(value)
        byte_array = [_to_bytes_or_false(v) for v in values]

        def _every(kind):
            # True when each element is an instance of *kind*.
            return all(isinstance(v, kind) for v in values)

        # np.float64 is a subclass of float, so the float check covers it;
        # np.float32 is not and needs its own check.
        if _every(np.float32) or _every(float):
            attr.floats.extend(values)
            attr.type = AttributeProto.FLOATS
        elif _every(np.int32) or _every(np.int64) or _every(numbers.Integral):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in values)
            attr.type = AttributeProto.INTS
        elif all(b is not False for b in byte_array):
            attr.strings.extend(cast(List[bytes], byte_array))
            attr.type = AttributeProto.STRINGS
        elif _every(TensorProto):
            attr.tensors.extend(values)
            attr.type = AttributeProto.TENSORS
        elif _every(GraphProto):
            attr.graphs.extend(values)
            attr.type = AttributeProto.GRAPHS
        else:
            # Report the types of (at most) the first five elements.
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type, key='{}', type={}, types={}.".format(
                    key, type(value),
                    [type(v) for v in values[:5]]))
    else:
        raise ValueError(
            'Value "{}" is not valid attribute data type.'.format(value))
    return attr
Esempio n. 15
0
def _make_att_graph(name, new_body):
    """Wrap a copy of the graph *new_body* into a GRAPH attribute called *name*."""
    wrapped = AttributeProto()
    wrapped.name = name
    wrapped.type = AttributeProto.GRAPH  # pylint: disable=E1101
    wrapped.g.CopyFrom(new_body)  # pylint: disable=E1101
    return wrapped
Esempio n. 16
0
def make_attribute(
        key,  # type: Text
        value,  # type: Any
        doc_string=None  # type: Optional[Text]
):  # type: (...) -> AttributeProto
    """Makes an AttributeProto based on the value type.

    Singular values are tried first (float, integral number, string/bytes,
    ``TensorProto``, ``GraphProto``); any other iterable is stored in the
    repeated field matching the kind shared by all of its elements.
    ``attr.type`` is set alongside the value. An empty iterable satisfies
    the first check and yields an empty FLOATS attribute.

    :param key: name given to the attribute
    :param value: value to store
    :param doc_string: optional documentation string, stored when truthy
    :return: the populated ``AttributeProto``
    :raises ValueError: if no field fits *value*
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Keep the old location as a Python 2 fallback.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    is_iterable = isinstance(value, Iterable)
    bytes_or_false = _to_bytes_or_false(value)
    # First, singular cases
    # float
    if isinstance(value, float):
        attr.f = value
        attr.type = AttributeProto.FLOAT
    # integer
    elif isinstance(value, numbers.Integral):
        attr.i = cast(int, value)
        attr.type = AttributeProto.INT
    # string -- compare against False explicitly: an empty string converts
    # to b"", which is falsy but still a legitimate string value.
    elif bytes_or_false is not False:
        assert isinstance(bytes_or_false, bytes)
        attr.s = bytes_or_false
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    # Then, iterable cases
    elif is_iterable:
        # Materialize once: the checks below walk the elements several
        # times, which would silently exhaust a one-shot iterator.
        values = list(value)
        byte_array = [_to_bytes_or_false(v) for v in values]
        if all(isinstance(v, float) for v in values):
            attr.floats.extend(values)
            attr.type = AttributeProto.FLOATS
        elif all(isinstance(v, numbers.Integral) for v in values):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in values)
            attr.type = AttributeProto.INTS
        elif all(b is not False for b in byte_array):
            attr.strings.extend(cast(List[bytes], byte_array))
            attr.type = AttributeProto.STRINGS
        elif all(isinstance(v, TensorProto) for v in values):
            attr.tensors.extend(values)
            attr.type = AttributeProto.TENSORS
        elif all(isinstance(v, GraphProto) for v in values):
            attr.graphs.extend(values)
            attr.type = AttributeProto.GRAPHS
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise ValueError(
            'Value "{}" is not valid attribute data type.'.format(value))
    return attr