Esempio n. 1
0
    def transform(self):
        """Apply every registered transformation and rebuild the Keras model.

        The weights of the current model are captured first, keyed by the
        name of the underlying (unwrapped) layer. After the transformations
        have edited the model config, a new model is built from that config
        and the captured weights are copied onto every rebuilt layer whose
        unwrapped name matches and whose stored weights are truthy.

        :return: The transformed Keras model
        """
        def unwrap(keras_layer):
            # An NNCFWrapper exposes the layer it wraps through ``.layer``.
            if isinstance(keras_layer, NNCFWrapper):
                return keras_layer.layer
            return keras_layer

        saved_weights = {
            unwrap(src_layer).name: self._get_layer_weights(src_layer)
            for src_layer in self._model.layers
        }

        for transformation in self._transformations:
            self._apply_transformation(transformation)

        model_cls = (tf.keras.Model if is_functional_model(self._model)
                     else tf.keras.Sequential)
        transformed_model = model_cls.from_config(self._model_config,
                                                  self._custom_objects)

        for dst_layer in transformed_model.layers:
            weights = saved_weights.get(unwrap(dst_layer).name)
            if not weights:
                # No (or empty) weights were captured under this name.
                continue
            self._set_layer_weights(dst_layer, weights)

        return transformed_model
Esempio n. 2
0
 def _find_layer_config(self, layer_name):
     """Locate the entry for ``layer_name`` in ``self._model_config['layers']``.

     For functional models the layer name is read from the entry's top-level
     ``'name'`` key. Otherwise it is read from the nested ``'config'`` dict.

     :param layer_name: Name of the layer to look for.
     :return: ``(index, layer_config)`` of the first matching entry, or
         ``(None, None)`` if no entry matches.
     """
     # The model kind cannot change while scanning, so evaluate it once
     # rather than on every iteration.
     functional = is_functional_model(self._model)
     for idx, layer in enumerate(self._model_config['layers']):
         layer_name_ = layer['name'] if functional else layer['config']['name']
         if layer_name_ == layer_name:
             return idx, layer
     return None, None
Esempio n. 3
0
    def _replace_config(self, layer_name, replace_layer_config):
        """Replace the config of ``layer_name`` with ``replace_layer_config``.

        For functional models, a top-level ``'name'`` key is added to
        ``replace_layer_config`` when it is missing. Its value is copied from
        ``replace_layer_config['config']['name']``. Note that this mutates
        the passed dict. Afterwards the layer mapping is updated from the old
        name to the new one.

        :param layer_name: Name of the layer whose config is replaced.
        :param replace_layer_config: Serialized config of the new layer.
        """
        new_name = replace_layer_config['config']['name']
        if not is_functional_model(self._model):
            self._replace_sequential(layer_name, replace_layer_config)
        else:
            replace_layer_config.setdefault('name', new_name)
            self._replace_functional(layer_name, replace_layer_config)
        self._update_layer_mapping(layer_name, new_name)
Esempio n. 4
0
    def _insert_layers_after(self, layer_name, instance_index, out_port, layers):
        """Insert ``layers`` into the model config after ``layer_name``.

        All layers are serialized first. For functional models each config
        also gets a top-level ``'name'`` and an inbound node
        ``[layer_name, instance_index, out_port, {}]``. The configs are then
        inserted one by one, in the given order.

        NOTE(review): every inserted layer gets the same inbound node, so they
        are not chained to each other here; confirm that this is intended.

        :param layer_name: Name of the layer after which to insert.
        :param instance_index: Node index of ``layer_name`` (functional only).
        :param out_port: Output port of ``layer_name`` (functional only).
        :param layers: Keras layer objects to insert.
        """
        is_functional = is_functional_model(self._model)

        def to_config(new_layer):
            cfg = tf.keras.utils.serialize_keras_object(new_layer)
            if is_functional:
                cfg['name'] = cfg['config']['name']
                cfg['inbound_nodes'] = [[[layer_name, instance_index, out_port, {}]]]
            return cfg

        new_configs = [to_config(new_layer) for new_layer in layers]

        for cfg in new_configs:
            if not is_functional:
                self._insert_layer_after_sequential(layer_name, cfg)
                continue
            self._insert_layer_after_functional(layer_name, instance_index, cfg)
Esempio n. 5
0
def convert_keras_model_to_nncf_graph(model: tf.keras.Model) -> NNCFGraph:
    """
    Convert Keras model graph to the NNCFGraph

    :param model: Keras model
    :return: NNCFGraph
    :raises RuntimeError: if the model is neither a functional nor a
        sequential Keras model
    """
    func_model = is_functional_model(model)
    seq_model = is_sequential_model(model)

    if not func_model and not seq_model:
        # The exception must be raised, not just constructed; otherwise an
        # unsupported model would silently fall through to the sequential
        # converter below.
        raise RuntimeError('convert_keras_model_to_nncf_graph function supports '
                           'only sequential or functional models')

    if func_model:
        return _get_nncf_graph_from_functional(model)
    return _get_nncf_graph_from_sequential(model)
Esempio n. 6
0
def convert_keras_model_to_nxmodel(model):
    """
    Convert Keras model graph to the NetworkX directed graph

    :param model: Keras model
    :return: NetworkX directed graph
    :raises RuntimeError: if the model is neither a functional nor a
        sequential Keras model
    """
    func_model = is_functional_model(model)
    seq_model = is_sequential_model(model)

    if not func_model and not seq_model:
        # The exception must be raised, not just constructed; otherwise an
        # unsupported model would silently fall through to the sequential
        # converter below.
        raise RuntimeError('convert_keras_model_to_nxmodel function supports '
                           'only sequential or functional models')

    if func_model:
        return _get_nxmodel_from_functional(model)
    return _get_nxmodel_from_sequential(model)