def test_is_attr_legal_verbose(self):
    """Randomized check of ``helper.is_attribute_legal``.

    An AttributeProto with a name and exactly one populated value field is
    expected to be reported as legal; populating two different value fields
    is expected to be reported as illegal.
    """

    def set_f(attr):
        attr.f = 1.0

    def set_i(attr):
        attr.i = 1

    def set_s(attr):
        attr.s = b"str"

    def set_floats(attr):
        attr.floats.extend([1.0, 2.0])

    def set_ints(attr):
        attr.ints.extend([1, 2])

    def set_strings(attr):
        attr.strings.extend([b"a", b"b"])

    def set_tensors(attr):
        attr.tensors.extend([TensorProto(), TensorProto()])

    def set_graphs(attr):
        attr.graphs.extend([GraphProto(), GraphProto()])

    setters = [
        set_f,
        set_i,
        set_s,
        set_floats,
        set_ints,
        set_strings,
        set_tensors,
        set_graphs,
    ]

    # One randomly chosen field populated: must be accepted.
    for _ in range(100):
        single = AttributeProto()
        single.name = "test"
        populate = random.choice(setters)
        populate(single)
        self.assertTrue(helper.is_attribute_legal(single))

    # Two distinct randomly chosen fields populated: must be rejected.
    for _ in range(100):
        double = AttributeProto()
        double.name = "test"
        for populate in random.sample(setters, 2):
            populate(double)
        self.assertFalse(helper.is_attribute_legal(double))
def test_is_attr_legal_verbose(self):
    """Randomized check of ``checker.check_attribute``.

    Each setter populates one value field and stamps the matching
    ``AttributeProto`` type tag.  Applying one setter must pass the checker;
    applying two different setters must raise ``checker.ValidationError``.
    """

    def scalar_setter(field, value, type_tag):
        # Build a callable that assigns a singular field and its type tag.
        def apply(attr):
            setattr(attr, field, value)
            attr.type = type_tag
        return apply

    def repeated_setter(field, values, type_tag):
        # Build a callable that appends to a repeated field and sets its type tag.
        def apply(attr):
            getattr(attr, field).extend(values)
            attr.type = type_tag
        return apply

    setters = [
        scalar_setter("f", 1.0, AttributeProto.FLOAT),
        scalar_setter("i", 1, AttributeProto.INT),
        scalar_setter("s", b"str", AttributeProto.STRING),
        repeated_setter("floats", [1.0, 2.0], AttributeProto.FLOATS),
        repeated_setter("ints", [1, 2], AttributeProto.INTS),
        repeated_setter("strings", [b"a", b"b"], AttributeProto.STRINGS),
    ]

    # One randomly chosen field populated: the checker must not raise.
    for _ in range(100):
        single = AttributeProto()
        single.name = "test"
        populate = random.choice(setters)
        populate(single)
        checker.check_attribute(single)

    # Two distinct randomly chosen fields populated: the checker must raise.
    for _ in range(100):
        double = AttributeProto()
        double.name = "test"
        for populate in random.sample(setters, 2):
            populate(double)
        with self.assertRaises(checker.ValidationError):
            checker.check_attribute(double)
def test_is_attr_legal(self):
    """Check three hand-built attributes that must be reported as illegal."""
    # Neither a name nor any value field.
    blank = AttributeProto()
    self.assertFalse(helper.is_attribute_legal(blank))

    # A name, but still no value field.
    name_only = AttributeProto()
    name_only.name = "test"
    self.assertFalse(helper.is_attribute_legal(name_only))

    # A name plus two value fields populated at once.
    two_values = AttributeProto()
    two_values.name = "test"
    two_values.f = 1.0
    two_values.i = 2
    self.assertFalse(helper.is_attribute_legal(two_values))
def test_is_attr_legal(self):
    """Check three hand-built attributes that the checker must reject."""
    # Neither a name nor any value field.
    blank = AttributeProto()
    with self.assertRaises(checker.ValidationError):
        checker.check_attribute(blank)

    # A name, but still no value field.
    name_only = AttributeProto()
    name_only.name = "test"
    with self.assertRaises(checker.ValidationError):
        checker.check_attribute(name_only)

    # A name plus two value fields populated at once.
    two_values = AttributeProto()
    two_values.name = "test"
    two_values.f = 1.0
    two_values.i = 2
    with self.assertRaises(checker.ValidationError):
        checker.check_attribute(two_values)
def make_attribute(key, value, doc_string=None):
    """Makes an AttributeProto based on the value type.

    Args:
        key: Name stored in ``attr.name``.
        value: A float, an integer, a string/bytes value, a TensorProto, a
            GraphProto, or an iterable whose elements are all of one of
            those kinds.
        doc_string: Optional text stored in ``attr.doc_string``; ignored
            when falsy.

    Returns:
        The populated AttributeProto, with ``type`` set to match the field
        that was filled in.

    Raises:
        ValueError: If ``value`` is an iterable of mixed or unsupported
            element types, or is not a supported type at all.
    """
    # Imported locally so the function works everywhere: the
    # ``collections.Iterable`` alias was removed in Python 3.10, while
    # ``collections.abc`` does not exist on Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    # NOTE(review): assumes _to_bytes_or_false returns a ``bytes`` object for
    # string-like input and a non-bytes value (False) otherwise -- confirm.
    bytes_or_false = _to_bytes_or_false(value)

    # First, singular cases
    # float
    if isinstance(value, float):
        attr.f = value
        attr.type = AttributeProto.FLOAT
    # integer
    elif isinstance(value, numbers.Integral):
        # Turn np.int32/64 into Python built-in int, as done for lists below.
        attr.i = int(value)
        attr.type = AttributeProto.INT
    # string
    # Test the type rather than truthiness so that an empty string is still
    # stored as a STRING instead of falling through to the iterable branch.
    elif isinstance(bytes_or_false, bytes):
        attr.s = bytes_or_false
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    # Second, iterable cases
    elif isinstance(value, Iterable):
        # Materialize once: the elements are scanned several times below and
        # a one-shot iterator would be exhausted after the first pass.
        values = list(value)
        byte_array = [_to_bytes_or_false(v) for v in values]
        # An empty iterable satisfies the first check and yields empty FLOATS.
        if all(isinstance(v, float) for v in values):
            attr.floats.extend(values)
            attr.type = AttributeProto.FLOATS
        elif all(isinstance(v, numbers.Integral) for v in values):
            # Turn np.int32/64 into Python built-in int.
            attr.ints.extend(int(v) for v in values)
            attr.type = AttributeProto.INTS
        # Type test (not truthiness) so lists containing "" are accepted.
        elif all(isinstance(b, bytes) for b in byte_array):
            attr.strings.extend(byte_array)
            attr.type = AttributeProto.STRINGS
        elif all(isinstance(v, TensorProto) for v in values):
            attr.tensors.extend(values)
            attr.type = AttributeProto.TENSORS
        elif all(isinstance(v, GraphProto) for v in values):
            attr.graphs.extend(values)
            attr.type = AttributeProto.GRAPHS
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise ValueError(
            'Value "{}" is not valid attribute data type.'.format(value))
    return attr
def map_attr(attr):
    """Convert one ONNX attribute into a Caffe2 argument.

    Attributes whose name is a key of ``depluralizer`` are rewritten into a
    new AttributeProto named ``depluralizer[attr.name]`` that carries only
    ``attr.ints[0]`` in its ``i`` field; all other attributes are passed
    through unchanged.  Either way the result comes from
    ``cls._onnx_arg_to_caffe2_arg(op_def.type, ...)``.

    Raises:
        ValueError: If the attribute must be de-pluralized but
            ``can_be_singular(attr.ints)`` is false.
    """
    if attr.name not in depluralizer:
        return cls._onnx_arg_to_caffe2_arg(op_def.type, attr)

    # TODO: replace this with a version test
    if not can_be_singular(attr.ints):
        # Raise a real exception: ``raise "<str>"`` is itself a TypeError, so
        # the explanatory message would never reach the caller.
        raise ValueError(
            "Caffe2 doesn't support plural kernel_shape/strides/pads prior to 6cb4d1ecb0dfb553f797f6a8a61dd6966909cb0b; if you know your Caffe2 is recent enough, comment out this test")
    # In fact, this code is MANDATORY, because prior to
    # https://github.com/caffe2/caffe2/commit/6cb4d1ecb0dfb553f797f6a8a61dd6966909cb0b
    # the pluralized versions were not supported.
    # You'll get an error like
    # "[enforce fail at conv_transpose_unpool_op_base.h:54] kernel_h_ > 0"
    # if your Caffe2 is too old and you actually use the plural
    # version
    singular_attr = AttributeProto()
    singular_attr.name = depluralizer[attr.name]
    # NOTE(review): only the first entry is kept; presumably
    # can_be_singular guarantees the entries are interchangeable -- confirm.
    singular_attr.i = attr.ints[0]
    return cls._onnx_arg_to_caffe2_arg(op_def.type, singular_attr)
def make_attribute(key, value):
    """Makes an AttributeProto based on the value type.

    Args:
        key: Name stored in ``attr.name``.
        value: A float, an integer, a string/bytes value, a TensorProto, a
            GraphProto, or an iterable whose elements are all of one of
            those kinds.

    Returns:
        The AttributeProto with the field matching ``value`` populated.

    Raises:
        ValueError: If ``value`` is an iterable of mixed or unsupported
            element types, or is not a supported type at all.
    """
    # Imported locally so the function works everywhere: the
    # ``collections.Iterable`` alias was removed in Python 3.10, while
    # ``collections.abc`` does not exist on Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    attr = AttributeProto()
    attr.name = key

    # NOTE(review): assumes _to_bytes_or_false returns a ``bytes`` object for
    # string-like input and a non-bytes value (False) otherwise -- confirm.
    bytes_or_false = _to_bytes_or_false(value)

    # First, singular cases
    # float
    if isinstance(value, float):
        attr.f = value
    # integer
    elif isinstance(value, numbers.Integral):
        # Normalize integer subclasses (e.g. np.int32/64) to built-in int.
        attr.i = int(value)
    # string
    # Test the type rather than truthiness so that an empty string is still
    # stored in ``s`` instead of falling through to the iterable branch.
    elif isinstance(bytes_or_false, bytes):
        attr.s = bytes_or_false
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
    # Second, iterable cases
    elif isinstance(value, Iterable):
        # Materialize once: the elements are scanned several times below and
        # a one-shot iterator would be exhausted after the first pass.
        values = list(value)
        byte_array = [_to_bytes_or_false(v) for v in values]
        # An empty iterable satisfies the first check and leaves ``floats`` empty.
        if all(isinstance(v, float) for v in values):
            attr.floats.extend(values)
        elif all(isinstance(v, numbers.Integral) for v in values):
            # Normalize integer subclasses (e.g. np.int32/64) to built-in int.
            attr.ints.extend(int(v) for v in values)
        # Type test (not truthiness) so lists containing "" are accepted.
        elif all(isinstance(b, bytes) for b in byte_array):
            attr.strings.extend(byte_array)
        elif all(isinstance(v, TensorProto) for v in values):
            attr.tensors.extend(values)
        elif all(isinstance(v, GraphProto) for v in values):
            attr.graphs.extend(values)
        else:
            raise ValueError(
                "You passed in an iterable attribute but I cannot figure out "
                "its applicable type.")
    else:
        raise ValueError(
            'Value "{}" is not valid attribute data type.'.format(value))
    return attr