def _get_nncf_graph_from_sequential(model: tf.keras.Model) -> NNCFGraph:
    """Build an NNCFGraph from the config of a sequential Keras model.

    Every entry of ``model.get_config()['layers']`` becomes one node with a
    single input port and a single output port (both ``0``) and
    ``is_shared=False``. Consecutive layers are connected with an edge from
    the previous layer to the current one.

    :param model: model whose config lists layers in execution order.
    :return: the populated graph.
    """
    graph = NNCFGraph()
    previous_name = None

    for layer_cfg in model.get_config()['layers']:
        name = layer_cfg['config']['name']
        node_attrs = {
            'type': _get_layer_type(layer_cfg),
            'dtype': _get_layer_dtype(layer_cfg),
            'data_format': layer_cfg['config'].get('data_format'),
            'in_ports': [0],
            'out_ports': [0],
            'is_shared': False,
        }

        # Convolution-like layers additionally carry their module attributes.
        if node_attrs['type'] in GENERAL_CONV_LAYERS:
            node_attrs[NNCFGraph.MODULE_ATTRIBUTES] = _get_module_attributes(
                model.get_layer(name), node_attrs)

        graph.add_node(name, **node_attrs)

        # The first layer has no producer, so it gets no incoming edge.
        if previous_name is not None:
            shapes = _prepare_shape(model.get_layer(name).input_shape)
            graph.add_edge(
                previous_name,
                name,
                **{
                    NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR: shapes[0],
                    NNCFGraph.IN_PORT_NAME_EDGE_ATTR: 0,
                })

        previous_name = name

    return graph
def get_total_quantizations(model: tf.keras.Model) -> int:
    """Count how many times FakeQuantize layers are applied in the model.

    Scans ``model.get_config()['layers']`` and, for every entry whose
    ``class_name`` equals ``'FakeQuantize'``, adds the number of items in
    its ``inbound_nodes`` list.

    :param model: object exposing ``get_config()`` with a ``'layers'`` list.
    :return: the summed inbound-node count over all FakeQuantize entries.
    """
    count = 0
    for layer_cfg in model.get_config()['layers']:
        if layer_cfg['class_name'] == 'FakeQuantize':
            count += len(layer_cfg['inbound_nodes'])
    return count
def _prepare_raw_nodes(model: tf.keras.Model) -> Dict:
    """Collect per-layer, per-call-instance attributes from the model config.

    The result maps ``layer_name -> instance_index -> attributes`` where the
    attributes are ``type``, ``dtype``, ``data_format``, ``is_shared``,
    ``input_shape``, ``in_ports`` and ``out_ports`` (plus
    ``NNCFGraph.MODULE_ATTRIBUTES`` for layers in ``GENERAL_CONV_LAYERS``).

    NOTE(review): relies on ``Dict`` creating nested entries on first access
    (presumably ``addict.Dict``) -- confirm against the file's imports.

    :param model: model whose config entries carry ``inbound_nodes``.
    :return: the nested mapping after ``_process_outputs`` has been applied
        and every ``out_ports`` collection has been turned into a sorted list.
    """
    config = model.get_config()
    raw_nodes = Dict()

    def attach_module_attributes(keras_layer, node, node_type):
        # Only convolution-like layers carry module attributes.
        if node_type in GENERAL_CONV_LAYERS:
            node.update({
                NNCFGraph.MODULE_ATTRIBUTES:
                    _get_module_attributes(keras_layer, node)
            })

    for layer_cfg in config['layers']:
        name = layer_cfg['name']
        node_type = _get_layer_type(layer_cfg)
        node_dtype = _get_layer_dtype(layer_cfg)
        data_format = layer_cfg['config'].get('data_format')
        keras_layer = model.get_layer(name)
        calls = layer_cfg['inbound_nodes']

        if not calls:
            # Layer without inbound nodes: a single instance and no input
            # ports. Its out_ports are filled in by consumers, if any.
            node = raw_nodes[name][0]
            node['type'] = node_type
            node['dtype'] = node_dtype
            node['data_format'] = data_format
            node['is_shared'] = False
            node['in_ports'] = []
            node['input_shape'] = _prepare_shape(keras_layer.input_shape)
            attach_module_attributes(keras_layer, node, node_type)
            continue

        # A layer called more than once is shared between its instances.
        shared = len(calls) > 1
        for call_index, call in enumerate(calls):
            shape = _prepare_shape(
                keras_layer.inbound_nodes[call_index].input_shapes)

            node = raw_nodes[name][call_index]
            node['type'] = node_type
            node['dtype'] = node_dtype
            node['data_format'] = data_format
            node['is_shared'] = shared
            node['input_shape'] = shape
            node['in_ports'] = list(range(len(call)))

            # Keep ports that were already recorded for this instance;
            # otherwise start from an empty set.
            if not node['out_ports']:
                node['out_ports'] = set()

            attach_module_attributes(keras_layer, node, node_type)

            # Register the output port this call consumes on each parent.
            for parent_name, parent_index, parent_port, _ in call:
                parent = raw_nodes[parent_name][parent_index]
                known_ports = parent['out_ports']
                if known_ports:
                    known_ports.add(parent_port)
                else:
                    parent['out_ports'] = {parent_port}

    raw_nodes = _process_outputs(config['output_layers'], raw_nodes)

    # Replace every out_ports collection with a sorted list.
    for per_layer in raw_nodes.values():
        for node in per_layer.values():
            node['out_ports'] = sorted(node['out_ports'])

    return raw_nodes
def hparam_search(hparam_options: Dict,
                  model: tf.keras.Model,
                  dataset: 'DatasetDF',
                  log_root=None,
                  verbose=False,
                  debug=False) -> List[Dict]:
    """Train a fresh clone of ``model`` for every hyperparameter combination.

    For each combination produced by ``hparam_combninations(hparam_options)``
    the model is rebuilt from its config, then compiled and fitted via
    ``hparam.model_compile_fit``. Combinations whose log directory already
    exists are skipped.

    :param hparam_options: hyperparameter search space, passed to the
        ``hparam_*`` helpers.
    :param model: template model; only its config and class are used, and a
        new model is rebuilt from that config for every trial.
    :param dataset: dataset forwarded to ``hparam.model_compile_fit``.
    :param log_root: forwarded to ``hparam_logdir``.
    :param verbose: forwarded to ``hparam.model_compile_fit``.
    :param debug: when true, only print the planned trials without training.
    :return: the accumulated stats of all trials that were actually run.
    """

    def onexit(log_dir):
        # Registered with atexit while a trial is running, so an interrupted
        # trial does not leave a partial log directory behind.
        print('Ctrl-C KeyboardInterrupt')
        if os.path.isdir(log_dir):
            shutil.rmtree(log_dir)  # remove logs for incomplete trainings
            print(f'rm -rf {log_dir}')

    model_config = model.get_config()
    # Decide by class, not by name: Keras uniquifies auto-generated names
    # ('sequential_1', ...) and users may pass their own, so comparing the
    # config name against 'sequential' misclassifies Sequential models.
    is_sequential = isinstance(model, tf.keras.Sequential)

    combninations = hparam_combninations(hparam_options)
    logdir = hparam_logdir(combninations[0], hparam_options, log_root)
    stats_history = []

    print(f"--- Testing {len(combninations)} combinations in {logdir}")
    print("--- hparam_options: ", hparam_options)
    for index, hparams in enumerate(combninations):
        run_name = hparam_run_name(hparams, hparam_options)
        logdir = hparam_logdir(hparams, hparam_options, log_root)

        print("")
        print(
            f"--- Starting trial {index+1}/{len(combninations)}: {logdir.split('/')[-2]} | {run_name}"
        )
        print(hparams)
        if os.path.exists(logdir):
            print('Exists: skipping')
            continue
        if debug:
            continue

        atexit.register(onexit, logdir)

        # DOCS: https://www.tensorflow.org/guide/keras/save_and_serialize
        if is_sequential:
            model_clone = tf.keras.Sequential.from_config(model_config)
        else:
            model_clone = tf.keras.Model.from_config(model_config)

        stats = hparam.model_compile_fit(hparams,
                                         model_clone,
                                         dataset,
                                         log_dir=logdir,
                                         verbose=verbose)
        # NOTE(review): ``+=`` extends the list, so this presumably expects
        # ``stats`` to be a list of dicts -- confirm against
        # hparam.model_compile_fit.
        stats_history += stats
        print(stats)

        # The trial finished, so its logs must survive interpreter exit.
        atexit.unregister(onexit)

    print("")
    print("--- Stats History")
    print(stats_history)
    print("--- Finished")
    return stats_history
def clone_model(model: tf.keras.Model) -> tf.keras.Model:
    """Return a new model built like ``model``.

    ``tf.keras.models.clone_model`` is tried first. If it raises
    ``ValueError`` the model is treated as a subclassed one and is rebuilt
    through its own class' ``from_config``, using ``model.get_config()`` or
    an empty config when ``get_config`` raises ``NotImplementedError``.
    """
    try:
        # Handles sequential and functional models.
        return tf.keras.models.clone_model(model)
    except ValueError:
        # Subclassed model: rebuild it from whatever config it can provide.
        try:
            subclass_config = model.get_config()
        except NotImplementedError:
            subclass_config = {}
        model_cls = model.__class__
        return model_cls.from_config(subclass_config)
def _get_nncf_graph_from_functional(model: tf.keras.Model) -> NNCFGraph:
    """Build an NNCFGraph for ``model`` from its config and raw nodes.

    The model config is read first, then the raw nodes are prepared with
    ``_prepare_raw_nodes``; both are handed to
    ``_get_nncf_graph_from_raw_nodes``, whose result is returned.
    """
    # Arguments are evaluated left to right: config first, raw nodes second.
    return _get_nncf_graph_from_raw_nodes(
        model.get_config(),
        _prepare_raw_nodes(model),
    )