def __init__(self, **kwargs):
    """Build the model from either a model config or an existing NetworkX graph.

    :param config: model config to init from (a ``Dict`` or anything ``Dict()``
        accepts); takes precedence if ``graph`` is also given
    :param graph: NetworkX instance to init from
    :param target_device: optional target device, only used together with
        ``config``; defaults to ``'ANY'``
    :raises TypeError: if neither ``config`` nor ``graph`` is supplied
    """
    # Validate first so no partial state is built for an unusable call.
    if 'config' not in kwargs and 'graph' not in kwargs:
        raise TypeError(
            'Unable to load model. Invalid keyword arguments. '
            'Either model config (config=) or NetworkX graph (graph=) is expected '
            'with optional target_device parameter. Got {}'.format(
                list(kwargs)))

    self._models = []
    self._prefix_is_applied = False
    self._cache = Dict()

    if 'config' in kwargs:
        self._from_config(kwargs['config'], kwargs.get('target_device', 'ANY'))
    else:
        # The guard above guarantees 'graph' is present here, so the former
        # trailing `else: raise TypeError(...)` branch was unreachable.
        self._from_graph(kwargs['graph'])

    # NOTE(review): _from_config/_from_graph already call add_fullname_for_nodes;
    # this pass looks redundant but is kept in case a subclass overrides a loader.
    for model in self._models:
        ge.add_fullname_for_nodes(model['model'])
def _from_graph(self, graph):
    """Wrap an existing NetworkX graph as the single (non-cascade) model.

    :param graph: NetworkX graph to wrap; its ``graph`` attribute dict may carry
        an ``'ir_version'`` entry
    :raises AssertionError: if the graph reports IR version 10, which POT does
        not support (exception type kept as-is for existing callers)
    """
    # Use .get so graphs built without IR metadata are not rejected with a
    # KeyError; a graph with no recorded version is not a version-10 graph.
    if graph.graph.get('ir_version') == 10:
        raise AssertionError(
            'POT does not support version 10 of IR. '
            'Please convert the model with the newer version of OpenVINO '
            'or use the POT from OpenVINO 2021.4.2 to work with version 10 of IR.')
    ge.add_fullname_for_nodes(graph)
    self._models.append({'model': graph})
    self._is_cascade = False
def _from_config(self, model_config, target_device='ANY'):
    """Load a single model, or every model of a cascade, from ``model_config``.

    For a cascade, each entry of ``model_config.cascade`` is merged over a deep
    copy of the shared config before loading. When the cascade has more than one
    entry, that entry's ``name`` is stored both on the list entry and on the
    loaded graph. Afterwards the model name is recorded, model prefixes are
    applied for real cascades, and node full names are (re)computed.

    :param model_config: ``Dict`` (or anything ``Dict()`` accepts) with the model settings
    :param target_device: device passed through to ``load_graph``
    """
    if not isinstance(model_config, Dict):
        model_config = Dict(model_config)

    cascade = model_config.cascade
    if not cascade:
        self._models.append({'model': load_graph(model_config, target_device)})
    else:
        # Names are only attached when there is more than one cascade member.
        has_several = len(cascade) > 1
        for sub_config in cascade:
            merged_config = model_config.deepcopy()
            merged_config.update(sub_config)
            entry = {'model': load_graph(merged_config, target_device)}
            self._models.append(entry)
            if has_several:
                entry['name'] = sub_config.name
                entry['model'].name = sub_config.name

    self.name = model_config.model_name
    self._is_cascade = len(self._models) > 1
    if self._is_cascade:
        self._add_models_prefix()
    for entry in self._models:
        ge.add_fullname_for_nodes(entry['model'])
def _from_graph(self, graph):
    """Register ``graph`` as the only model of this instance (no cascade).

    Node full names are computed on the graph before it is stored.

    :param graph: NetworkX graph to wrap
    """
    # NOTE(review): this chunk also contains a _from_graph variant that rejects
    # IR version 10; if both live in one class, this later definition wins.
    ge.add_fullname_for_nodes(graph)
    single_entry = dict(model=graph)
    self._models.append(single_entry)
    self._is_cascade = False