def trace_and_get_graph_from_model(model, args, training):
    """Trace ``model`` on ``args`` and return ``(graph, torch_out)``.

    When this torch build exposes ``torch.jit.get_trace_graph`` it is used
    directly; otherwise the module-level helpers (``_register_trace_fn``,
    ``_get_trace_graph`` and the trace-map functions) are used instead.

    Args:
        model: the module to trace.
        args: inputs passed to ``model`` while tracing.
        training: training mode to apply while tracing (see comment below).

    Returns:
        Tuple ``(graph, torch_out)``: the traced graph and the model output.

    Raises:
        RuntimeError: if the model's state_dict key set differs before and
            after tracing.
    """
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model. Fail fast!
    orig_state_dict_keys = _unique_state_dict(model).keys()
    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc. If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        if hasattr(torch.jit, "get_trace_graph"):
            trace, torch_out = torch.jit.get_trace_graph(model, args)
            graph = trace.graph()
        else:
            old_map = _get_trace_map()
            _init_trace_state()
            try:
                model.apply(_register_trace_fn)
                graph, torch_out = _get_trace_graph()(model, args)
            finally:
                # Always undo the per-module registration and restore the
                # previous trace map, even when tracing raises. Otherwise the
                # model and the global trace state are left modified.
                _remove_trace_fn()
                _init_trace_map(old_map)
                # torch.jit._trace_module_map = old_map
    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")
    return graph, torch_out
def _trace_and_get_graph_from_model(model, args):
    """Trace ``model`` with ``args`` and return ``(trace_graph, torch_out)``.

    Emits warnings for any static-input changes reported by the tracer.

    Raises:
        RuntimeError: if the model's state_dict key set differs before and
            after tracing.
    """
    # Snapshot the key set up front so a model that adds/removes
    # parameters or buffers while running is caught immediately.
    keys_before = _unique_state_dict(model).keys()

    traced = torch.jit._get_trace_graph(
        model,
        args,
        _force_outplace=False,
        _return_inputs_states=True,
    )
    trace_graph, torch_out, inputs_states = traced
    warn_on_static_input_change(inputs_states)

    keys_after = _unique_state_dict(model).keys()
    if keys_before != keys_after:
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")
    return trace_graph, torch_out
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False):
    """Trace ``model`` and write its ONNX protobuf to ``f``.

    Args:
        model: module to export.
        args: model inputs; a lone Variable is wrapped into a 1-tuple.
        f: path or file-like object receiving the serialized protobuf.
        export_params: embed the model's state_dict values as initializers.
        verbose: print the trace after optimization.
        training: training mode used while tracing.
        input_names / output_names: names assigned to graph inputs/outputs.
        aten: forwarded to ``_optimize_trace``.

    Returns:
        The model output produced while tracing.

    Raises:
        RuntimeError: if the state_dict key set changes during tracing.
    """
    # The common case is a single Variable; normalize it to a tuple.
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # Fail fast if running the model changes its state_dict key set.
    keys_before = _unique_state_dict(model).keys()

    # training=False (the default) avoids buffer updates and dropout while
    # tracing; True forces training mode and None keeps the current mode.
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if keys_before != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    _optimize_trace(trace, aten)
    _set_input_and_output_names(trace.graph(), input_names, output_names)
    if verbose:
        print(trace)

    # TODO: Don't allocate a in-memory string for the protobuf
    from torch.onnx.symbolic import _onnx_opset_version
    # trace.export is not duck-typed: it needs an actual list, not the
    # OrderedDict values view.
    if export_params:
        initializers = list(_unique_state_dict(model).values())
    else:
        initializers = []
    proto = trace.export(initializers, _onnx_opset_version)
    torch.serialization._with_file_like(f, "wb", lambda handle: handle.write(proto))
    return torch_out
def __init__(self, model):
    """Wrap ``model`` and prepare empty lookup tables for graph construction.

    Caches the model's de-duplicated state_dict and starts with empty
    ``shape_dict`` and ``layer_weight_map`` dictionaries.
    """
    super(PytorchGraph, self).__init__(model)
    self.model = model
    self.state_dict = _unique_state_dict(self.model)
    # Both tables start empty; presumably populated later during graph
    # building (not visible here).
    self.shape_dict = {}
    self.layer_weight_map = {}
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None, aten=False):
    """Build an optimized graph for ``model`` on ``args``.

    A ScriptModule's graph comes from its script ``forward`` via shape
    propagation. Any other module is traced.

    Returns:
        Tuple ``(graph, params, torch_out)``. ``torch_out`` is None for
        ScriptModules, because they are not executed here.

    Raises:
        RuntimeError: if a ScriptModule has no script ``forward`` method.
    """
    # The common case is a single Tensor; normalize it to a tuple.
    if isinstance(args, torch.Tensor):
        args = (args, )

    if not isinstance(model, torch.jit.ScriptModule):
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        params = list(_unique_state_dict(model).values())
    else:
        torch_out = None
        try:
            forward = model.__getattr__('forward')
            graph = forward.propagate_shapes(args, False)
            params = forward.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')

    graph = _optimize_graph(graph, aten)
    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)
    return graph, params, torch_out
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False):
    """Build an optimized graph for ``model`` on ``args``.

    ScriptModules require ``example_outputs``. Their graph is produced by
    propagating input/output shapes through the script ``forward``, after
    which number types are erased. Other modules are traced.

    Returns:
        Tuple ``(graph, params, torch_out)``; ``torch_out`` is None for
        ScriptModules.

    Raises:
        AssertionError: ScriptModule given without ``example_outputs``.
        RuntimeError: ScriptModule without a script ``forward`` method.
    """
    # The common case is a single Tensor; normalize it to a tuple.
    if isinstance(args, torch.Tensor):
        args = (args, )

    if not isinstance(model, torch.jit.ScriptModule):
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        params = list(_unique_state_dict(model).values())
    else:
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            forward = model.__getattr__('forward')
            graph = forward.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # Bring the graph back to its pre-NumberType state.
            torch._C._jit_pass_erase_number_types(graph)
            params = forward.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')

    graph = _optimize_graph(graph, operator_export_type)
    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)
    return graph, params, torch_out
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False):
    """Build an optimized graph for ``model`` on ``args``.

    ScriptModules require ``example_outputs``. Their graph is produced by
    propagating input/output shapes through the script ``forward``. Other
    modules are traced.

    Returns:
        Tuple ``(graph, params, torch_out)``; ``torch_out`` is None for
        ScriptModules.

    Raises:
        AssertionError: ScriptModule given without ``example_outputs``.
        RuntimeError: ScriptModule without a script ``forward`` method.
    """
    # The common case is a single Tensor; normalize it to a tuple.
    if isinstance(args, torch.Tensor):
        args = (args, )

    if not isinstance(model, torch.jit.ScriptModule):
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        params = list(_unique_state_dict(model).values())
    else:
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            forward = model.__getattr__('forward')
            graph = forward.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # NOTE(review): no number-type erasure pass runs on this graph
            # here; confirm that is intended for this torch version.
            params = forward.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')

    graph = _optimize_graph(graph, operator_export_type)
    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)
    return graph, params, torch_out
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False,
                    _retain_param_name=False):
    """Build an optimized graph plus a name -> parameter mapping.

    ScriptModules (which require ``example_outputs``) get their graph from
    ``_propagate_and_assign_input_and_output_shapes`` over the script
    ``forward``. Other modules are traced; with ``_retain_param_name`` the
    trailing graph inputs are renamed after their state_dict keys.

    Returns:
        Tuple ``(graph, params_dict, torch_out)`` where ``params_dict`` maps
        the debug names of the trailing graph inputs to parameter values.

    Raises:
        AssertionError: missing ``example_outputs`` for a ScriptModule, or
            the param count plus flattened arg count does not match the
            number of graph inputs.
        RuntimeError: ScriptModule without a script ``forward`` method.
    """
    # The common case is a single Tensor; normalize it to a tuple.
    if isinstance(args, torch.Tensor):
        args = (args, )

    if not isinstance(model, torch.jit.ScriptModule):
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            graph_inputs = list(graph.inputs())
            # Assumes the parameters are the trailing graph inputs, in
            # state_dict order.
            n_user_inputs = len(graph_inputs) - len(state_dict)
            state_keys = list(state_dict.keys())
            for pos, value in enumerate(graph_inputs):
                if pos < n_user_inputs:
                    continue
                value.setUniqueName(state_keys[pos - n_user_inputs])
    else:
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            forward = model.__getattr__('forward')
            params = forward.initial_ivalues()
            graph = _propagate_and_assign_input_and_output_shapes(
                forward.graph, tuple(args) + tuple(params), example_outputs,
                False, propagate)
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')

    graph = _optimize_graph(graph, operator_export_type)

    # NB: ONNX requires complete information about output types, which might
    # be erased by some optimizations, so we set it explicitly again.
    if torch_out is not None:
        flat_outputs, _ = torch._C._jit_flatten(torch_out)
        for graph_output, tensor in zip(graph.outputs(), flat_outputs):
            graph_output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)

    # The graph's inputs must be exactly the flattened args followed by the
    # parameters.
    flat_args, _ = torch._C._jit_flatten(args)
    assert len(params) + len(flat_args) == sum(1 for _ in graph.inputs())

    all_names = [value.uniqueName() for value in graph.inputs()]
    trailing_names = all_names[len(all_names) - len(params):]
    params_dict = dict(zip(trailing_names, params))

    if verbose:
        print(graph)

    return graph, params_dict, torch_out
def _trace_and_get_graph_from_model(model, args, training):
    """Trace ``model`` on ``args`` under the requested training mode.

    Returns:
        Tuple ``(graph, torch_out)`` with the traced graph and model output.

    Raises:
        RuntimeError: if the state_dict key set differs before and after
            tracing.
    """
    # Fail fast: snapshot the key set to detect a model that adds or removes
    # parameters/buffers while it runs.
    keys_before = _unique_state_dict(model).keys()

    # training=False (the default) avoids buffer updates and dropout while
    # tracing; True forces training mode and None keeps the current mode.
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    keys_unchanged = keys_before == _unique_state_dict(model).keys()
    if not keys_unchanged:
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")
    return trace.graph(), torch_out
def unique_state_dict(model):
    """
    Thin wrapper around ``torch.jit._unique_state_dict``.

    Args:
        model (Module): Torch model.

    Returns:
        dict, params.
    """
    import torch.jit
    state = torch.jit._unique_state_dict(model)
    return state
def rename_graph_param_name(model, graph):
    """Rename ``graph``'s inputs and return the parameter input names.

    Trailing inputs (as many as the model's de-duplicated state_dict has
    entries) are renamed to their state_dict keys. The leading user inputs
    become ``'input_<index>'``.

    Returns:
        List of the new names of the parameter inputs, in graph order.
    """
    state_dict = _unique_state_dict(model)
    graph_inputs = list(graph.inputs())
    # Assumes parameters are the trailing graph inputs in state_dict order.
    n_user_inputs = len(graph_inputs) - len(state_dict)
    state_keys = list(state_dict.keys())

    renamed = []
    for idx, value in enumerate(graph_inputs):
        if idx < n_user_inputs:
            set_unique_name(value, 'input_' + str(idx))
            continue
        set_unique_name(value, state_keys[idx - n_user_inputs])
        renamed.append(unique_name(value))
    return renamed
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False):
    """Build an optimized graph for ``model`` on ``args``.

    ScriptModules require ``example_outputs`` and use shape propagation
    through their script ``forward``. Other modules are traced, and their
    output types are re-inferred from the traced outputs after optimization.

    Returns:
        Tuple ``(graph, params, torch_out)``; ``torch_out`` is None for
        ScriptModules.

    Raises:
        AssertionError: ScriptModule given without ``example_outputs``.
        RuntimeError: ScriptModule without a script ``forward`` method.
    """
    # The common case is a single Tensor; normalize it to a tuple.
    if isinstance(args, torch.Tensor):
        args = (args, )

    if not isinstance(model, torch.jit.ScriptModule):
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        params = list(_unique_state_dict(model).values())
    else:
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            forward = model.__getattr__('forward')
            graph = forward.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # NOTE(review): no number-type erasure pass runs on this graph
            # here; confirm that is intended for this torch version.
            params = forward.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')

    graph = _optimize_graph(graph, operator_export_type)

    # NB: ONNX requires complete information about output types, which might
    # be erased by some optimizations, so we set it explicitly again.
    if torch_out is not None:
        flat_outputs, _ = torch._C._jit_flatten(torch_out)
        for graph_output, tensor in zip(graph.outputs(), flat_outputs):
            graph_output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)
    return graph, params, torch_out
def _get_jit_params(module, param_exclude, param_include):
    """Select a module's weight/bias/running-stat tensors, in reverse order.

    Args:
        module: the module whose de-duplicated state_dict is scanned.
        param_exclude: optional regex; keys it ``match``es are dropped.
        param_include: optional regex; keys it does not ``match`` are dropped.

    Returns:
        Tuple ``(params, names)``: the selected tensors and their keys, both
        in reverse state_dict order. Keys containing ``num_batches_tracked``
        are always skipped. Only keys containing ``weight``, ``bias``,
        ``running_mean`` or ``running_var`` are kept.
    """
    state_dict = _unique_state_dict(module)
    exclude_re = None if param_exclude is None else re.compile(param_exclude)
    include_re = None if param_include is None else re.compile(param_include)
    wanted_tags = ("weight", "bias", "running_mean", "running_var")

    selected = OrderedDict()
    for name, tensor in state_dict.items():
        if exclude_re is not None and exclude_re.match(name) is not None:
            continue
        if include_re is not None and include_re.match(name) is None:
            continue
        if "num_batches_tracked" in name:
            continue
        if any(tag in name for tag in wanted_tags):
            selected[name] = tensor

    names = list(selected)
    tensors = [selected[name] for name in names]
    names.reverse()
    tensors.reverse()
    return tensors, names
def pytorch_to_keras(model, args, input_shape,
                     change_ordering=False, training=False, verbose=False):
    """
    By given pytorch model convert layers with specified convertors.

    Args:
        model: pytorch model
        args: pytorch model arguments
        input_shape: keras input shape (using for InputLayer creation)
        change_ordering: change NCHW to NHWC
        training: switch model to training mode
        verbose: verbose output

    Returns:
        model: created keras model.

    Raises:
        RuntimeError: if the model's state_dict keys change while tracing.
        NotImplementedError: if ``change_ordering`` is set and a layer's
            ``batch_input_shape`` is neither 3-D nor 4-D.
    """
    # PyTorch JIT tracing
    args = (args, ) if isinstance(args, torch.autograd.Variable) else args

    orig_state_dict_keys = _unique_state_dict(model).keys()

    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    # _optimize_trace(trace, False)
    trace.set_graph(_optimize_graph(trace.graph(), False))

    if verbose:
        print(trace.graph())
        print(list(trace.graph().outputs()))

    # Get all graph nodes
    nodes = list(trace.graph().nodes())

    # Collect graph outputs
    graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
    print('Graph outputs:', graph_outputs)

    # Collect model state dict
    state_dict = _unique_state_dict(model)
    if verbose:
        print('State dict:', list(state_dict))

    import re
    import keras
    from keras import backend as K
    K.set_image_data_format('channels_first')

    layers = dict()
    layers['input'] = keras.layers.InputLayer(
        input_shape=input_shape, name='input'
    ).output

    outputs = []

    for node in nodes:
        node_inputs = list(node.inputs())
        node_input_names = []
        for node_input in node_inputs:
            if node_input.node().scopeName():
                node_input_names.append(get_node_id(node_input.node()))

        # Nodes with no scoped producers are fed by the single model input.
        if len(node_input_names) == 0:
            node_input_names.append('input')

        node_type = node.kind()

        node_scope_name = node.scopeName()
        node_id = get_node_id(node)
        node_weights_name = '.'.join(
            re.findall(r'\[([\w\d.]+)\]', node_scope_name))
        node_attrs = {k: node[k] for k in node.attributeNames()}

        node_outputs = list(node.outputs())
        node_outputs_names = []
        for node_output in node_outputs:
            if node_output.node().scopeName():
                node_outputs_names.append(node_output.node().scopeName())

        if verbose:
            print(' ____ ')
            print('graph node:', node_scope_name)
            print('type:', node_type)
            print('inputs:', node_input_names)
            print('outputs:', node_outputs_names)
            print('name in state_dict:', node_weights_name)
            print('attrs:', node_attrs)
            print('node_id:', node_id)
            print('is_terminal:', node_id in graph_outputs)

        AVAILABLE_CONVERTERS[node_type](
            params=node_attrs,
            w_name=node_weights_name,
            scope_name=node_id,
            inputs=node_input_names,
            layers=layers,
            weights=state_dict,
        )
        if node_id in graph_outputs:
            outputs.append(layers[node_id])

    model = keras.models.Model(inputs=layers['input'], outputs=outputs)
    model.summary()

    if change_ordering:
        # Change from 'NCW' to 'NWC' ordering customary in tf
        import numpy as np
        config = model.get_config()
        output_shape = None
        for layer_conf in config['layers']:
            layer_type, lc = layer_conf['class_name'], layer_conf['config']
            if 'batch_input_shape' in lc:
                if len(lc['batch_input_shape']) == 3:
                    N, C, W = lc['batch_input_shape']
                    lc['batch_input_shape'] = (N, W, C)
                elif len(lc['batch_input_shape']) == 4:
                    N, C, H, W = lc['batch_input_shape']
                    lc['batch_input_shape'] = (N, H, W, C)
                else:
                    raise NotImplementedError(
                        "len(batch_input_shape) should be either 3 or 4")
                output_shape = lc['batch_input_shape']
            if layer_type == 'Conv1D':
                # Only the rank of ``output_shape`` is consumed below (for
                # 'axis'). This width estimate ignores strides and padding.
                # ``kernel`` is used instead of ``K`` so the keras backend
                # alias stays intact for later calls.
                (N, W, _), kernel = output_shape, lc['kernel_size'][0]
                C = lc['filters']
                if W is not None:
                    W -= kernel - 1
                output_shape = (N, W, C)
            if 'target_shape' in lc:
                # Move the leading channel dim to the end:
                # (C, *rest) -> (*rest, C). A plain tuple is used because
                # np.array([list, int]) is ragged: it produced ([...], C) on
                # old NumPy and raises ValueError on NumPy >= 1.24.
                target = lc['target_shape']
                lc['target_shape'] = tuple(list(target[1:]) + [target[0]])
            if 'data_format' in lc:
                lc['data_format'] = 'channels_last'
            if 'axis' in lc:
                lc['axis'] = len(output_shape) - 1

        K.set_image_data_format('channels_last')

        # # For theano:
        # from keras.utils.layer_utils import convert_all_kernels_in_model
        # convert_all_kernels_in_model(model)

        # Set the weights into the model with new ordering.
        # `Dense` layers after `Flatten` have their weights transposed.
        src_weights = []
        last_was_flatten = False
        last_shape = None
        for layer in model.layers:
            W = layer.get_weights()
            if last_was_flatten and W:
                assert len(last_shape) == 3, str(last_shape)
                A, b = W
                _, C, H = last_shape
                A.shape = (C, H, -1)
                A = np.ascontiguousarray(np.swapaxes(A, 0, 1))
                A.shape = (H * C, -1)
                W = [A, b]
                last_was_flatten = False
            if isinstance(layer, keras.layers.core.Flatten):
                last_was_flatten = True
            elif not last_was_flatten:
                last_shape = layer.output_shape
            src_weights.append(W)

        if K.backend() == 'tensorflow':
            # Tensorflow needs a new graph for the converted model
            # to retain the same scopes for the operators.
            import tensorflow as tf
            tf.reset_default_graph()
            K.set_session(tf.Session())
            model_tf_ordering = keras.models.Model.from_config(config)
            for dst, src in zip(model_tf_ordering.layers, src_weights):
                dst.set_weights(src)
        else:
            model_tf_ordering = keras.models.Model.from_config(config)
            for dst, src in zip(model_tf_ordering.layers, src_weights):
                dst.set_weights(src)

        model = model_tf_ordering

    return model
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False,
            export_type=ExportTypes.PROTOBUF_FILE):
    """Trace ``model`` and write the resulting ONNX model to ``f``.

    Args:
        model: module to export.
        args: model inputs; a lone Variable is wrapped into a 1-tuple.
        f: destination. A path or file-like object for PROTOBUF_FILE and the
            zip types; a directory path for DIRECTORY.
        export_params: include the model's state_dict values.
        verbose: print the trace after optimization.
        training: training mode used while tracing.
        input_names / output_names: names assigned to graph inputs/outputs.
        aten: forwarded to ``_optimize_graph``.
        export_type: PROTOBUF_FILE writes a single protobuf.
            ZIP_ARCHIVE / COMPRESSED_ZIP_ARCHIVE write the model proto plus
            every export-map entry into a (stored / deflated) zip.
            DIRECTORY writes them as separate files.

    Returns:
        The model output produced while tracing.

    Raises:
        RuntimeError: if the state_dict key set changes during tracing, or
            for an unknown ``export_type``.
    """
    def _dump(target, payload):
        # Write raw bytes to a path or file-like object.
        torch.serialization._with_file_like(
            target, "wb", lambda handle: handle.write(payload))

    # The common case is a single Variable; normalize it to a tuple.
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # Fail fast if running the model changes its state_dict key set.
    keys_before = _unique_state_dict(model).keys()

    # training=False (the default) avoids buffer updates and dropout while
    # tracing; True forces training mode and None keeps the current mode.
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if keys_before != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    trace.set_graph(_optimize_graph(trace.graph(), aten))
    _set_input_and_output_names(trace.graph(), input_names, output_names)

    if verbose:
        print(trace)

    # TODO: Don't allocate a in-memory string for the protobuf
    from torch.onnx.symbolic import _onnx_opset_version
    defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
    if export_params:
        # trace.export is not duck-typed: it needs an actual list rather
        # than the OrderedDict values view.
        proto, export_map = trace.export(
            list(_unique_state_dict(model).values()), _onnx_opset_version,
            defer_weight_export)
    else:
        proto, export_map = trace.export([], _onnx_opset_version, False)

    if export_type == ExportTypes.PROTOBUF_FILE:
        assert (len(export_map) == 0)
        _dump(f, proto)
    elif export_type in [
            ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
        import zipfile
        if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE:
            compression = zipfile.ZIP_DEFLATED
        else:
            compression = zipfile.ZIP_STORED
        with zipfile.ZipFile(f, 'w', compression=compression) as archive:
            archive.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
            for entry_name, blob in export_map.items():
                archive.writestr(entry_name, blob)
    elif export_type == ExportTypes.DIRECTORY:
        import os
        if os.path.exists(f):
            assert (os.path.isdir(f))
        else:
            os.makedirs(f)
        _dump(os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME), proto)
        for entry_name, blob in export_map.items():
            _dump(os.path.join(f, entry_name), blob)
    else:
        raise RuntimeError('Unknown export type')
    return torch_out
def pytorch_to_keras(
    model, args, input_shapes,
    change_ordering=False, training=False, verbose=False, short_names=False,
):
    """
    By given pytorch model convert layers with specified convertors.

    Args:
        model: pytorch model
        args: pytorch model arguments
        input_shapes: keras input shapes (using for each InputLayer)
        change_ordering: change CHW to HWC
        training: switch model to training mode
        verbose: verbose output
        short_names: use short names for keras layers

    Returns:
        model: created keras model.

    Raises:
        RuntimeError: if the model's state_dict keys change while tracing.
    """
    # PyTorch JIT tracing
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # Workaround for previous versions
    if isinstance(input_shapes, tuple):
        input_shapes = [input_shapes]

    orig_state_dict_keys = _unique_state_dict(model).keys()

    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, tuple(args))

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    # _optimize_trace(trace, False)
    trace.set_graph(_optimize_graph(trace.graph(), False))

    if verbose:
        print(trace.graph())
        print(list(trace.graph().outputs()))

    # Get all graph nodes
    nodes = list(trace.graph().nodes())

    # Collect graph outputs
    graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
    print('Graph outputs:', graph_outputs)

    # Collect model state dict
    state_dict = _unique_state_dict(model)
    if verbose:
        print('State dict:', list(state_dict))

    import re
    import keras
    from keras import backend as K
    K.set_image_data_format('channels_first')

    layers = dict()
    keras_inputs = []
    for i in range(len(args)):
        layers['input{0}'.format(i)] = keras.layers.InputLayer(
            input_shape=input_shapes[i], name='input{0}'.format(i)
        ).output
        keras_inputs.append(layers['input{0}'.format(i)])

    outputs = []

    # Unscoped graph values feeding nodes are mapped to 'input<N>' names in
    # order of first use.
    input_index = 0
    model_inputs = dict()
    for node in nodes:
        node_inputs = list(node.inputs())
        node_input_names = []
        for node_input in node_inputs:
            if node_input.node().scopeName():
                node_input_names.append(get_node_id(node_input.node()))

        if len(node_input_names) == 0:
            if len(node_inputs) > 0:
                if node_inputs[0] in model_inputs:
                    node_input_names.append(model_inputs[node_inputs[0]])
                else:
                    input_name = 'input{0}'.format(input_index)
                    node_input_names.append(input_name)
                    input_index += 1
                    model_inputs[node_inputs[0]] = input_name

        node_type = node.kind()

        node_scope_name = node.scopeName()
        node_id = get_node_id(node)
        node_weights_name = '.'.join(
            re.findall(r'\[([\w\d.]+)\]', node_scope_name))
        node_attrs = {k: node[k] for k in node.attributeNames()}

        node_outputs = list(node.outputs())
        node_outputs_names = []
        for node_output in node_outputs:
            if node_output.node().scopeName():
                node_outputs_names.append(node_output.node().scopeName())

        if verbose:
            print(' ____ ')
            print('graph node:', node_scope_name)
            print('type:', node_type)
            print('inputs:', node_input_names)
            print('outputs:', node_outputs_names)
            print('name in state_dict:', node_weights_name)
            print('attrs:', node_attrs)
            print('is_terminal:', node_id in graph_outputs)

        AVAILABLE_CONVERTERS[node_type](
            node_attrs,
            node_weights_name, node_id,
            node_input_names,
            layers, state_dict,
            short_names
        )
        if node_id in graph_outputs:
            outputs.append(layers[node_id])

    model = keras.models.Model(inputs=keras_inputs, outputs=outputs)

    if change_ordering:
        conf = model.get_config()

        for layer in conf['layers']:
            layer_config = layer['config']
            if layer_config and 'batch_input_shape' in layer_config:
                # (batch, C, *spatial) -> (None, *spatial, C)
                shape = layer_config['batch_input_shape']
                layer_config['batch_input_shape'] = \
                    tuple([None] + list(shape[2:]) + [shape[1]])
            if layer_config and 'target_shape' in layer_config:
                target = layer_config['target_shape']
                if len(target[1:]) > 0:
                    # (C, *rest) -> (*rest, C). A plain tuple is used because
                    # np.array([list, int]) is ragged: it produced ([...], C)
                    # on old NumPy and raises ValueError on NumPy >= 1.24.
                    layer_config['target_shape'] = \
                        tuple(list(target[1:]) + [target[0]])

            if layer_config and 'data_format' in layer_config:
                layer_config['data_format'] = 'channels_last'
            if layer_config and 'axis' in layer_config:
                # NOTE(review): hard-codes the channels-last axis of a 4-D
                # tensor; confirm for layers operating on other ranks.
                layer_config['axis'] = 3

        K.set_image_data_format('channels_last')
        model_tf_ordering = keras.models.Model.from_config(conf)

        # from keras.utils.layer_utils import convert_all_kernels_in_model
        # convert_all_kernels_in_model(model)

        for dst_layer, src_layer in zip(
            model_tf_ordering.layers, model.layers
        ):
            dst_layer.set_weights(src_layer.get_weights())

        model = model_tf_ordering

    return model
def _model_to_graph(model, args, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False,
                    _retain_param_name=False, do_constant_folding=False,
                    _disable_torch_constant_prop=False):
    """Build an optimized ONNX-ready graph plus a name -> parameter mapping.

    Supports three model kinds:
      * ScriptModule: lowered ``forward`` graph with input shapes propagated.
      * torch.jit.Function: its graph with input shapes propagated (no params).
      * anything else: traced. With ``_retain_param_name`` the trailing graph
        inputs are renamed after their state_dict keys.
    Script models require ``example_outputs``, which are used to assign
    output shapes after optimization.

    Returns:
        Tuple ``(graph, params_dict, torch_out)``; ``torch_out`` is None
        unless the model was traced.

    Raises:
        AssertionError: missing ``example_outputs`` for a script model, or
            a mismatch between graph inputs and args + params.
        RuntimeError: ScriptModule without a script ``forward`` method.
    """
    from torch.onnx.symbolic_helper import _export_onnx_opset_version

    # The common case is a single Tensor; normalize it to a tuple.
    if isinstance(args, torch.Tensor):
        args = (args, )
    if isinstance(example_outputs, torch.Tensor):
        example_outputs = [example_outputs]

    torch_out = None
    is_script = isinstance(model, (torch.jit.ScriptModule, torch.jit.Function))

    if isinstance(model, torch.jit.ScriptModule):
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        try:
            lowered_graph, params = model.forward._lowered_graph()
            flat_inputs, _ = torch.jit._flatten(tuple(args) + tuple(params))
            graph = _propagate_and_assign_input_shapes(
                lowered_graph, tuple(flat_inputs), False, propagate)
        except AttributeError:
            raise RuntimeError('\'forward\' method must be a script method')
    elif isinstance(model, torch.jit.Function):
        assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
        params = ()
        flat_inputs, _ = torch.jit._flatten(tuple(args))
        graph = _propagate_and_assign_input_shapes(
            model.graph, tuple(flat_inputs), False, propagate)
    else:
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            graph_inputs = list(graph.inputs())
            # Assumes parameters are the trailing graph inputs, in
            # state_dict order.
            n_user_inputs = len(graph_inputs) - len(state_dict)
            state_keys = list(state_dict.keys())
            for pos, value in enumerate(graph_inputs):
                if pos < n_user_inputs:
                    continue
                value.setDebugName(state_keys[pos - n_user_inputs])

    graph = _optimize_graph(
        graph, operator_export_type,
        _disable_torch_constant_prop=_disable_torch_constant_prop)

    if is_script:
        flat_outputs, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, flat_outputs)

    # NB: ONNX requires complete information about output types, which might
    # be erased by some optimizations, so we set it explicitly again.
    if torch_out is not None:
        traced_outputs, _ = torch._C._jit_flatten(torch_out)
        for graph_output, tensor in zip(graph.outputs(), traced_outputs):
            graph_output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)

    # The graph's inputs must be exactly the flattened args followed by the
    # parameters.
    flat_args, _ = torch._C._jit_flatten(args)
    assert len(params) + len(flat_args) == sum(1 for _ in graph.inputs())

    all_names = [value.debugName() for value in graph.inputs()]
    trailing_names = all_names[len(all_names) - len(params):]
    params_dict = dict(zip(trailing_names, params))

    if do_constant_folding and _export_onnx_opset_version in [9, 10]:
        params_dict = torch._C._jit_pass_onnx_constant_fold(
            graph, params_dict, _export_onnx_opset_version)
        torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    # For ONNX opset < 9, constants only have three data types: float16,
    # float, double. This pass turns constants of other types into
    # float/double plus a cast operator.
    if _export_onnx_opset_version < 9:
        torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    if verbose:
        print(graph)

    return graph, params_dict, torch_out
def pytorch_to_keras(
    model, args, input_names, input_shapes,
    change_ordering=False, training=False, verbose=False, names=False,
):
    """
    By given pytorch model convert layers with specified convertors.

    Args:
        model: pytorch model
        args: pytorch model arguments
        input_names: keras input names (using for each InputLayer)
        input_shapes: keras input shapes (using for each InputLayer)
        change_ordering: change CHW to HWC
        training: switch model to training mode
        verbose: verbose output
        names: use short names, use random-suffix or keep original names
            for keras layers

    Returns:
        model: created keras model.

    Raises:
        RuntimeError: if the model's state_dict keys change while tracing.
    """
    # PyTorch JIT tracing
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # Workaround for previous versions
    if isinstance(input_shapes, tuple):
        input_shapes = [input_shapes]

    orig_state_dict_keys = _unique_state_dict(model).keys()

    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, tuple(args))

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    # _optimize_trace(trace, False)
    if version.parse('0.4.0') < version.parse(torch.__version__):
        trace.set_graph(
            _optimize_graph(trace.graph(), OperatorExportTypes.ONNX))
    else:
        trace.set_graph(_optimize_graph(trace.graph(), False))

    trace.graph().lint()

    if verbose:
        print(trace.graph())

    # Get all graph nodes
    nodes = list(trace.graph().nodes())

    # Optimize Flatten:
    # When we have something like this:
    #
    # %523 : Long() = onnx::Constant[value={0}](), scope: ResNet
    # %524 : Dynamic = onnx::Shape(%522), scope: ResNet
    # %526 : Long() = onnx::Gather[axis=0](%524, %523), scope: ResNet
    # %527 : Long() = onnx::Constant[value={-1}](), scope: ResNet
    # %534 : Dynamic = onnx::Unsqueeze[axes=[0]](%526)
    # %535 : Dynamic = onnx::Unsqueeze[axes=[0]](%527)
    # %536 : Dynamic = onnx::Concat[axis=0](%534, %535)
    # %529 : Float(1, 512) = onnx::Reshape(%522, %536), scope: ResNet
    #
    # it's better to replace it with onnx::Flatten. Only the first run of
    # these node kinds, occurring consecutively in node order, is replaced.
    if six.PY3:
        from types import SimpleNamespace
        seq_to_find = \
            ['onnx::Constant', 'onnx::Shape', 'onnx::Gather',
             'onnx::Constant', 'onnx::Unsqueeze', 'onnx::Unsqueeze',
             'onnx::Concat', 'onnx::Reshape']
        seq_len = len(seq_to_find)
        kinds = [node.kind() for node in nodes]
        # Compare the pattern at every start index. An incremental matcher
        # that resets on mismatch misses runs that begin inside a failed
        # partial match.
        for s in range(len(nodes) - seq_len + 1):
            if kinds[s:s + seq_len] != seq_to_find:
                continue
            reshape_op = nodes[s + seq_len - 1]
            flatten_op = {
                'kind': (lambda: 'onnx::Flatten'),
                'attributeNames': (lambda: {}),
                'outputs': (lambda: list(reshape_op.outputs())),
                'scopeName': (lambda: reshape_op.scopeName()),
                'inputs': (lambda: list(reshape_op.inputs())[:1]),
                '__str__': (lambda: reshape_op.__str__()),
            }
            nodes = nodes[:s] + [SimpleNamespace(**flatten_op)] \
                + nodes[s + seq_len:]
            break

    # Collect graph inputs and outputs
    graph_outputs = [get_leaf_id(n) for n in trace.graph().outputs()]
    graph_inputs = [get_leaf_id(n) for n in trace.graph().inputs()]

    # Collect model state dict
    state_dict = _unique_state_dict(model)
    if verbose:
        print('Graph inputs:', graph_inputs)
        print('Graph outputs:', graph_outputs)
        print('State dict:', list(state_dict))

    import re
    import tensorflow.keras
    from tensorflow.keras import backend as K
    K.set_image_data_format('channels_first')

    layers = dict()
    keras_inputs = []
    for i in range(len(args)):
        layers[graph_inputs[i]] = tensorflow.keras.layers.InputLayer(
            input_shape=input_shapes[i], name=input_names[i]
        ).output
        keras_inputs.append(layers[graph_inputs[i]])

    outputs = []
    # Per-group counters used to number weights of scopes ending in an index.
    group_indices = defaultdict(int)

    for node in nodes:
        node_inputs = list(node.inputs())
        node_input_names = []

        for node_input in node_inputs:
            node_input_names.append(get_leaf_id(node_input))

        node_type = node.kind()

        node_scope_name = node.scopeName()
        node_id = get_node_id(node)
        node_name_regex = re.findall(r'\[([\w\d.\-\[\]\s]+)\]', node_scope_name)

        try:
            int(node_name_regex[-1])
            node_weight_group_name = '.'.join(node_name_regex[:-1])
            node_weights_name = node_weight_group_name + '.' + str(
                group_indices[node_weight_group_name])
            group_indices[node_weight_group_name] += 1

        except ValueError:
            # Last scope component is not numeric: use the full scope path.
            node_weights_name = '.'.join(node_name_regex)
        except IndexError:
            # No scope components at all: fall back to the input ids.
            node_weights_name = '.'.join(node_input_names)

        node_attrs = {k: node[k] for k in node.attributeNames()}

        node_outputs = list(node.outputs())
        node_outputs_names = []
        for node_output in node_outputs:
            if node_output.node().scopeName():
                node_outputs_names.append(node_output.node().scopeName())

        if verbose:
            print(' ____ ')
            print('graph node:', node_scope_name)
            print('node id:', node_id)
            print('type:', node_type)
            print('inputs:', node_input_names)
            print('outputs:', node_outputs_names)
            print('name in state_dict:', node_weights_name)
            print('attrs:', node_attrs)
            print('is_terminal:', node_id in graph_outputs)
        AVAILABLE_CONVERTERS[node_type](
            node_attrs,
            node_weights_name, node_id,
            node_input_names,
            layers, state_dict,
            names
        )
        if node_id in graph_outputs:
            outputs.append(layers[node_id])

    model = tensorflow.keras.models.Model(inputs=keras_inputs, outputs=outputs)

    if change_ordering:
        conf = model.get_config()

        for layer in conf['layers']:
            layer_config = layer['config']
            if layer_config and 'batch_input_shape' in layer_config:
                # (batch, C, *spatial) -> (None, *spatial, C)
                shape = layer_config['batch_input_shape']
                layer_config['batch_input_shape'] = \
                    tuple([None] + list(shape[2:]) + [shape[1]])
            if layer_config and 'target_shape' in layer_config:
                target = layer_config['target_shape']
                if len(target[1:]) > 0:
                    # (C, *rest) -> (*rest, C). A plain tuple is used because
                    # np.array([list, int]) is ragged: it produced ([...], C)
                    # on old NumPy and raises ValueError on NumPy >= 1.24.
                    layer_config['target_shape'] = \
                        tuple(list(target[1:]) + [target[0]])

            if layer_config and 'data_format' in layer_config:
                layer_config['data_format'] = 'channels_last'
            if layer_config and 'axis' in layer_config:
                # NOTE(review): hard-codes the channels-last axis of a 4-D
                # tensor; confirm for layers operating on other ranks.
                layer_config['axis'] = 3

        K.set_image_data_format('channels_last')
        model_tf_ordering = tensorflow.keras.models.Model.from_config(conf)

        # from tensorflow.keras.utils.layer_utils import convert_all_kernels_in_model
        # convert_all_kernels_in_model(model)

        for dst_layer, src_layer in zip(
            model_tf_ordering.layers, model.layers
        ):
            dst_layer.set_weights(src_layer.get_weights())

        model = model_tf_ordering

    print(
        'Your model was (probably) successfully converted! '
        'Please, follow the repository https://github.com/nerox8664/pytorch2keras and give a star :)'
    )
    return model