Пример #1
0
 def test_is_attr_legal(self):
     """helper.is_attribute_legal must reject malformed AttributeProtos.

     Three malformed inputs are checked, in this order:
     no name and no value field; a name but no value field; and a name
     with two value fields (``f`` and ``i``) set at the same time.
     """
     nameless_empty = AttributeProto()

     named_empty = AttributeProto()
     named_empty.name = "test"

     named_two_fields = AttributeProto()
     named_two_fields.name = "test"
     named_two_fields.f = 1.0
     named_two_fields.i = 2

     for candidate in (nameless_empty, named_empty, named_two_fields):
         self.assertFalse(helper.is_attribute_legal(candidate))
Пример #2
0
 def test_is_attr_legal(self):
     """checker.check_attribute must reject malformed AttributeProtos.

     Three malformed inputs are tried, in this order: no name and no value
     field; a name with no value field; and a name with two value fields
     (``f`` and ``i``) set together. Each must raise
     checker.ValidationError.
     """
     missing_everything = AttributeProto()

     missing_value = AttributeProto()
     missing_value.name = "test"

     two_values = AttributeProto()
     two_values.name = "test"
     two_values.f = 1.0
     two_values.i = 2

     for bad in (missing_everything, missing_value, two_values):
         self.assertRaises(checker.ValidationError, checker.check_attribute,
                           bad)
Пример #3
0
 def map_attr(attr):
     """Convert one ONNX attribute into a Caffe2 argument for ``op_def``.

     If ``attr.name`` is a key of ``depluralizer``, a new AttributeProto is
     built. Its name is ``depluralizer[attr.name]`` and its ``i`` field is
     the first element of ``attr.ints``. That singular attribute is
     converted instead of the original. Every other attribute is passed to
     ``cls._onnx_arg_to_caffe2_arg`` unchanged.

     NOTE(review): judging by the error message, ``depluralizer`` maps the
     plural kernel_shape/strides/pads names to their singular forms -
     confirm against its definition. ``depluralizer``, ``can_be_singular``,
     ``cls`` and ``op_def`` all come from the enclosing scope.

     Raises:
         RuntimeError: if ``can_be_singular(attr.ints)`` is false for a
             pluralized attribute.
     """
     if attr.name not in depluralizer:
         return cls._onnx_arg_to_caffe2_arg(op_def.type, attr)

     # TODO: replace this with a version test
     if not can_be_singular(attr.ints):
         # Raising a bare string is itself a TypeError on Python 3 (and on
         # Python 2.6+), which would hide this message; raise a real exception.
         raise RuntimeError(
             "Caffe2 doesn't support plural kernel_shape/strides/pads prior to "
             "6cb4d1ecb0dfb553f797f6a8a61dd6966909cb0b; if you know your Caffe2 is "
             "recent enough, comment out this test")
     # In fact, this code is MANDATORY, because prior to
     # https://github.com/caffe2/caffe2/commit/6cb4d1ecb0dfb553f797f6a8a61dd6966909cb0b
     # the pluralized versions were not supported.
     # You'll get an error like
     # "[enforce fail at conv_transpose_unpool_op_base.h:54] kernel_h_ > 0"
     # if your Caffe2 is too old and you actually use the plural
     # version
     singular_attr = AttributeProto()
     singular_attr.name = depluralizer[attr.name]
     singular_attr.i = attr.ints[0]
     return cls._onnx_arg_to_caffe2_arg(op_def.type, singular_attr)
Пример #4
0
    def test_is_attr_legal_verbose(self):
        """Exactly one populated value field is legal; two are not.

        Each entry in ``setters`` populates a single value field of an
        AttributeProto. In 100 trials, one randomly chosen setter is
        applied and helper.is_attribute_legal must return True. In another
        100 trials, two distinct random setters are applied and it must
        return False.
        """
        setters = [
            lambda a: setattr(a, "f", 1.0),
            lambda a: setattr(a, "i", 1),
            lambda a: setattr(a, "s", b"str"),
            lambda a: a.floats.extend([1.0, 2.0]),
            lambda a: a.ints.extend([1, 2]),
            lambda a: a.strings.extend([b"a", b"b"]),
            lambda a: a.tensors.extend([TensorProto(), TensorProto()]),
            lambda a: a.graphs.extend([GraphProto(), GraphProto()]),
        ]

        def named_attr():
            fresh = AttributeProto()
            fresh.name = "test"
            return fresh

        # A single populated field must be accepted.
        for _ in range(100):
            candidate = named_attr()
            random.choice(setters)(candidate)
            self.assertTrue(helper.is_attribute_legal(candidate))

        # Two distinct populated fields must be rejected.
        for _ in range(100):
            candidate = named_attr()
            first, second = random.sample(setters, 2)
            first(candidate)
            second(candidate)
            self.assertFalse(helper.is_attribute_legal(candidate))
Пример #5
0
    def test_is_attr_legal_verbose(self):
        """check_attribute accepts one typed value field and rejects two.

        Every entry in ``setters`` fills one value field of an AttributeProto
        and sets the matching ``type``. In 100 trials, one randomly chosen
        setter is applied and checker.check_attribute must not raise. In
        another 100 trials, two distinct random setters are applied and it
        must raise checker.ValidationError.
        """
        # Each ``or`` chain relies on its first call returning a falsy value
        # (None), so the type assignment on the right always runs too.
        setters = [
            lambda a: (setattr(a, "f", 1.0)
                       or setattr(a, 'type', AttributeProto.FLOAT)),
            lambda a: (setattr(a, "i", 1)
                       or setattr(a, 'type', AttributeProto.INT)),
            lambda a: (setattr(a, "s", b"str")
                       or setattr(a, 'type', AttributeProto.STRING)),
            lambda a: (a.floats.extend([1.0, 2.0])
                       or setattr(a, 'type', AttributeProto.FLOATS)),
            lambda a: (a.ints.extend([1, 2])
                       or setattr(a, 'type', AttributeProto.INTS)),
            lambda a: (a.strings.extend([b"a", b"b"])
                       or setattr(a, 'type', AttributeProto.STRINGS)),
        ]

        def named_attr():
            fresh = AttributeProto()
            fresh.name = "test"
            return fresh

        # One typed field: check_attribute must not raise.
        for _ in range(100):
            candidate = named_attr()
            random.choice(setters)(candidate)
            checker.check_attribute(candidate)

        # Two typed fields: check_attribute must raise ValidationError.
        for _ in range(100):
            candidate = named_attr()
            first, second = random.sample(setters, 2)
            first(candidate)
            second(candidate)
            self.assertRaises(checker.ValidationError, checker.check_attribute,
                              candidate)
Пример #6
0
def make_attribute(key, value):
    """Makes an AttributeProto based on the value type.

    The result is named *key*. Scalar values are tried first:

    - ``float`` goes to ``f``.
    - ``numbers.Integral`` goes to ``i``.
    - A truthy result of ``_to_bytes_or_false(value)`` goes to ``s``.
    - ``TensorProto`` goes to ``t``; ``GraphProto`` goes to ``g``.

    Otherwise, if *value* is iterable, all of its elements must share one
    of those kinds. They then fill the matching repeated field: ``floats``,
    ``ints`` (converted with ``int()``), ``strings``, ``tensors`` or
    ``graphs``.

    Raises:
        ValueError: if *value*, or the elements of an iterable *value*,
            match none of the supported kinds.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` (Python 3.3+). Fall back for older interpreters.
    try:
        from collections.abc import Iterable
    except ImportError:  # pragma: no cover - Python 2
        from collections import Iterable

    attr = AttributeProto()
    attr.name = key

    is_iterable = isinstance(value, Iterable)
    bytes_or_false = _to_bytes_or_false(value)
    # NOTE(review): an empty str/bytes gives a falsy result here and falls
    # through to the iterable branch, producing an empty ``floats``
    # attribute. Confirm whether that is intended.
    # First, singular cases
    # float
    if isinstance(value, float):
        attr.f = value
    # integer
    elif isinstance(value, numbers.Integral):
        attr.i = value
    # string
    elif bytes_or_false:
        attr.s = bytes_or_false
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
    # Second, iterable cases
    elif is_iterable:
        # Materialize once: the element checks below iterate *value* several
        # times. A generator or iterator would be exhausted by the first
        # pass, and the vacuous ``all()`` would then yield an empty
        # ``floats`` attribute.
        value = list(value)
        byte_array = [_to_bytes_or_false(v) for v in value]
        if all(isinstance(v, float) for v in value):
            attr.floats.extend(value)
        elif all(isinstance(v, numbers.Integral) for v in value):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in value)
        elif all(byte_array):
            attr.strings.extend(byte_array)
        elif all(isinstance(v, TensorProto) for v in value):
            attr.tensors.extend(value)
        elif all(isinstance(v, GraphProto) for v in value):
            attr.graphs.extend(value)
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise ValueError(
            'Value "{}" is not valid attribute data type.'.format(value))
    return attr