def __getitem__(self, key: str):
    """Return the parsed value of the attribute stored under ``key``.

    The attribute's ``type`` string decides which raw field is read and
    how it is wrapped: scalar types ('BOOL', 'INT', 'FLOAT', 'STRING')
    return the raw field, list/map types are wrapped in the matching
    vector/map class, and 'UNDEFINED' yields None.

    Raises:
        NotImplementedError: if the attribute's type is not one of the
            recognised type strings.
    """
    xattr = self._xattr_map[key]
    # Entries are looked up by name, so the stored name must match the key.
    assert xattr.name == key

    # Ordered (type string, converter) pairs; the first pair whose type
    # string compares equal to the attribute's type wins.
    converters = (
        ('UNDEFINED', lambda a: None),
        ('BOOL', lambda a: a.b),
        ('INT', lambda a: a.i),
        ('INTS', lambda a: IntVector(a.ints)),
        ('INTS2D', lambda a: IntVector2D(a.ints2d)),
        ('FLOAT', lambda a: a.f),
        ('FLOATS', lambda a: FloatVector(a.floats)),
        ('STRING', lambda a: a.s),
        ('STRINGS', lambda a: StrVector(a.strings)),
        ('MAP_STR_STR', lambda a: MapStrStr(a.map_str_str)),
        ('MAP_STR_VSTR', lambda a: MapStrVectorStr(a.map_str_vstr)),
    )

    attr_type = xattr.type
    for type_name, convert in converters:
        if attr_type == type_name:
            return convert(xattr)

    raise NotImplementedError(
        "Unsupported attribute: {} of type: {}".format(
            xattr, xattr.type))
def get_layer_names(self):  # type: () -> List[str]
    """Return the names of all layers, in topological order.

    The names reported by the underlying graph object are wrapped in a
    StrVector before being handed back.
    """
    names = self._xgraph.get_layer_names()
    return StrVector(names)
def get_output_names(self):  # type: () -> List[str]
    """Return the graph's output names wrapped in a StrVector."""
    names = self._xgraph.get_output_names()
    return StrVector(names)
def layer(self):
    """Return the wrapped layer's ``layer`` field as a StrVector."""
    raw = self._xlayer.layer
    return StrVector(raw)
def targets(self):
    """Return the wrapped layer's ``targets`` field as a StrVector."""
    raw = self._xlayer.targets
    return StrVector(raw)
def bottoms(self):
    """Return the wrapped layer's ``bottoms`` field as a StrVector."""
    raw = self._xlayer.bottoms
    return StrVector(raw)
def tops(self):
    """Return the wrapped layer's ``tops`` field as a StrVector."""
    raw = self._xlayer.tops
    return StrVector(raw)
def type(self):
    """Return the wrapped layer's ``xtype`` field as a StrVector."""
    raw = self._xlayer.xtype
    return StrVector(raw)
class OpaqueFunc(object):
    """Python-side wrapper around a library ``lpx.OpaqueFunc``.

    A Python callable is registered together with one ``TypeCode`` per
    argument. ``type_codes_`` maps every supported TypeCode to a pair of
    converters used when values cross the Python/C++ boundary.
    """

    # TypeCode conversion functions
    # First: C++ -> Python
    # Second: Python -> C++
    type_codes_ = {
        TypeCode.vInt: (
            lambda arg_: IntVector(arg_.ints),
            lambda arg_: lpx.OpaqueValue(lpx.IntVector(arg_))),
        TypeCode.Str: (
            lambda arg_: arg_.s,
            lambda arg_: lpx.OpaqueValue(arg_)),
        TypeCode.Byte: (
            lambda arg_: arg_.bytes,
            lambda arg_: lpx.OpaqueValue(arg_)),
        TypeCode.vStr: (
            lambda arg_: StrVector(arg_.strings),
            lambda arg_: lpx.OpaqueValue(lpx.StrVector(arg_))),
        TypeCode.StrContainer: (
            lambda arg_: StrContainer.from_lib(arg_.str_c),
            lambda arg_: lpx.OpaqueValue(arg_._str_c)),
        TypeCode.BytesContainer: (
            lambda arg_: BytesContainer.from_lib(arg_.bytes_c),
            lambda arg_: lpx.OpaqueValue(arg_._bytes_c)),
        TypeCode.XGraph: (
            lambda arg_: XGraph._from_xgraph(arg_.xg),
            lambda arg_: lpx.OpaqueValue(arg_._xgraph)),
        TypeCode.XBuffer: (
            lambda arg_: XBuffer.from_lib(arg_.xb),
            lambda arg_: lpx.OpaqueValue(arg_._xb)),
        TypeCode.vXBuffer: (
            lambda arg_: [XBuffer.from_lib(e) for e in arg_.xbuffers],
            lambda arg_: lpx.OpaqueValue(
                lpx.XBufferHolderVector([xb._xb for xb in arg_]))),
        TypeCode.OpaqueFunc: (
            lambda arg_: OpaqueFunc.from_lib(arg_.of),
            lambda arg_: lpx.OpaqueValue(arg_._of))
    }

    def __init__(self,
                 func: Callable = None,
                 type_codes: List[TypeCode] = None) -> None:
        """Create an empty lpx.OpaqueFunc, optionally registering ``func``.

        Args:
            func: Python callable to register; when None nothing is
                registered.
            type_codes: argument TypeCodes for ``func``; None is treated as
                an empty list (no arguments).
        """
        self._of = lpx.OpaqueFunc()
        if type_codes is None:
            type_codes = []
        if func is not None:
            self.set_func(func, type_codes)

    @classmethod
    def from_lib(cls, _of: lpx.OpaqueFunc) -> 'OpaqueFunc':
        """Wrap an existing lpx.OpaqueFunc without calling ``__init__``."""
        of = OpaqueFunc.__new__(cls)
        of._of = _of
        return of

    def set_func(self, func: Callable, type_codes: List[TypeCode]):
        """Register ``func`` as the callable behind the wrapped lpx.OpaqueFunc.

        Args:
            func: Python callable invoked with the arguments converted to
                their Python form.
            type_codes: one TypeCode per argument of ``func``. Any iterable
                is accepted; it is copied into a private list once.

        Raises:
            NotImplementedError: if a type code has no converter in
                ``type_codes_``.
        """
        # Snapshot the codes once so that validation, the wrapper closure and
        # the registered lpx.IntVector all see the same sequence: a one-shot
        # iterator would otherwise be exhausted by the validation loop, and a
        # caller mutating its list later would desynchronise the wrapper.
        type_codes = list(type_codes)

        for tc in type_codes:
            if tc not in OpaqueFunc.type_codes_:
                raise NotImplementedError("Function with argument of"
                                          " unsupported type: {} provided"
                                          .format(tc.name))

        def opaque_func_wrapper(args):
            # Convert each incoming library value to its Python form
            # according to the registered type codes, then call ``func``.
            new_args = []
            for tc, arg_ in zip(type_codes, args):
                if tc not in OpaqueFunc.type_codes_:
                    raise ValueError(f"Unsupported type code: {tc}")
                new_args.append(OpaqueFunc.type_codes_[tc][0](arg_))
            func(*new_args)

        arg_type_codes_ = lpx.IntVector([tc.value for tc in type_codes])
        self._of.set_func(opaque_func_wrapper, arg_type_codes_)

    def __call__(self, *args: Any) -> None:
        """Call internal lib OpaqueFunc with provided args.

        Each argument is converted to a library value with the converter
        registered for its TypeCode.

        Raises:
            ValueError: if the number of arguments differs from the number of
                registered type codes, or a type code has no converter.
        """
        args_type_codes = self.get_arg_type_codes()
        if len(args) != len(args_type_codes):
            raise ValueError("Invalid number of arguments detected."
                             " OpaqueFunc is expecting {} arguments"
                             " but got: {}"
                             .format(len(args_type_codes), len(args)))
        oa_v = []
        for tc, arg_ in zip(args_type_codes, args):
            if tc not in OpaqueFunc.type_codes_:
                raise ValueError(f"Unsupported type code: {tc}")
            oa_v.append(OpaqueFunc.type_codes_[tc][1](arg_))
        oa = lpx.OpaqueArgs(oa_v)

        self._of(oa)

    def get_arg_type_codes(self):
        """Return the registered argument type codes as TypeCode values."""
        return [TypeCode(i) for i in self._of.get_arg_type_codes()]

    def get_nb_type_codes(self):
        """Return the number of registered argument type codes."""
        return len(self.get_arg_type_codes())

    def __del__(self):
        # NOTE(review): no-op finalizer; it performs no cleanup and looks
        # safe to remove — kept to avoid any behavioural change.
        pass
def test_str_vector(self):
    """Exercise the StrVector wrapper around an lpx.StrVector.

    Each section mutates or queries the wrapper ``ivx`` and checks that the
    wrapped ``iv`` reflects the same contents, since the wrapper shares the
    underlying vector rather than copying it.
    """
    iv = lpx.StrVector(['a', 'b', 'c'])
    ivx = StrVector(iv)
    assert ivx == iv
    assert ivx == ['a', 'b', 'c']

    # Append (through the wrapper; the wrapped vector must grow too)
    ivx.append('d')
    assert len(iv) == 4
    assert len(ivx) == 4

    # Contains
    assert 'b' in iv
    assert 'b' in ivx
    assert 'e' not in iv
    assert 'e' not in ivx

    # Delete
    del ivx[3]
    assert ivx == ['a', 'b', 'c']
    assert iv == lpx.StrVector(['a', 'b', 'c'])

    # Equal (against both plain lists and lpx vectors)
    assert ivx == ['a', 'b', 'c']
    assert ivx == lpx.StrVector(['a', 'b', 'c'])
    assert iv == lpx.StrVector(['a', 'b', 'c'])

    # Extend
    ivx.extend(['d', 'e'])
    assert len(iv) == 5
    assert len(ivx) == 5

    # Get item (negative indices supported, out of range raises IndexError)
    assert ivx[3] == 'd'
    assert ivx[-1] == 'e'
    with self.assertRaises(IndexError):
        ivx[6]

    # Iter
    c = ['a', 'b', 'c', 'd', 'e']
    for i, e in enumerate(ivx):
        assert e == c[i]
    for i, e in enumerate(iv):
        assert e == c[i]

    # Length
    assert len(ivx) == len(iv)
    assert len(ivx) == 5

    # Not equal
    assert iv != lpx.StrVector(['a', 'b', 'c'])
    assert ivx != ['a', 'b', 'c', 'd']

    # Repr
    assert repr(iv) == "StrVector[a, b, c, d, e]"
    assert repr(ivx) == "StrVector[a, b, c, d, e]"

    # Str
    assert str(iv) == "StrVector[a, b, c, d, e]"
    assert str(ivx) == "StrVector[a, b, c, d, e]"

    # Set (assignment through the wrapper updates the wrapped vector)
    ivx[0] = 'z'
    assert ivx == ['z', 'b', 'c', 'd', 'e']
    assert iv == lpx.StrVector(['z', 'b', 'c', 'd', 'e'])
    with self.assertRaises(IndexError):
        ivx[6] = 'z'

# NOTE(review): disabled draft. It is a copy of the FloatVector test and has
# not been adapted to XBufferVector: it still wraps with FloatVector and
# expects float values and "FloatVector[...]" reprs. `np.array(1, 2, 3)`
# also would not build a 1-D array (presumably `np.array([1, 2, 3])` was
# intended) — TODO rewrite before re-enabling.
# def test_xbuffer_vector(self):
#     iv = lpx.XBufferVector([np.array(1, 2, 3)])
#     ivx = FloatVector(iv)
#     assert ivx == iv
#     assert ivx == [1, 1.5, 3]

#     # Append
#     ivx.append(4)
#     assert len(iv) == 4
#     assert len(ivx) == 4

#     # Contains
#     assert 1.5 in iv
#     assert 1.5 in ivx
#     assert 1.51 not in iv
#     assert 1.51 not in ivx
#     with self.assertRaises(TypeError):
#         assert 'a' not in ivx

#     # Delete
#     del ivx[3]
#     assert ivx == [1, 1.5, 3]
#     assert iv == lpx.FloatVector([1, 1.5, 3])

#     # Equal
#     assert ivx == [1, 1.5, 3]
#     assert ivx == lpx.FloatVector([1, 1.5, 3])
#     assert iv == lpx.FloatVector([1, 1.5, 3])

#     # Extend
#     ivx.extend([4, 5])
#     assert len(iv) == 5
#     assert len(ivx) == 5

#     # Get item
#     assert ivx[3] == 4
#     assert ivx[-1] == 5
#     with self.assertRaises(IndexError):
#         ivx[6]

#     # Iter
#     c = [1, 1.5, 3, 4, 5]
#     for i, e in enumerate(ivx):
#         assert e == c[i]
#     for i, e in enumerate(iv):
#         assert e == c[i]

#     # Length
#     assert len(ivx) == len(iv)
#     assert len(ivx) == 5

#     # Not equal
#     assert iv != lpx.FloatVector([1, 1.5, 3])
#     assert ivx != [1, 1.5, 3, 4]

#     # Repr
#     assert repr(iv) == "FloatVector[1, 1.5, 3, 4, 5]"
#     assert repr(ivx) == "FloatVector[1, 1.5, 3, 4, 5]"

#     # Str
#     assert str(iv) == "FloatVector[1, 1.5, 3, 4, 5]"
#     assert str(ivx) == "FloatVector[1, 1.5, 3, 4, 5]"

#     # Set
#     ivx[0] = -1
#     assert ivx == [-1, 1.5, 3, 4, 5]
#     assert iv == lpx.FloatVector([-1, 1.5, 3, 4, 5])
#     with self.assertRaises(IndexError):
#         ivx[6] = -1