Example #1
def trace_and_get_graph_from_model(model, args, training):

    orig_state_dict_keys = _unique_state_dict(model).keys()

    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc.  If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        if hasattr(torch.jit, "get_trace_graph"):
            trace, torch_out = torch.jit.get_trace_graph(model, args)
            graph = trace.graph()
        else:
            old_map = _get_trace_map()
            _init_trace_state()
            model.apply(_register_trace_fn)
            graph, torch_out = _get_trace_graph()(model, args)
            _remove_trace_fn()
            _init_trace_map(old_map)
            # torch.jit._trace_module_map = old_map

        if orig_state_dict_keys != _unique_state_dict(model).keys():
            raise RuntimeError("state_dict changed after running the tracer; "
                               "something weird is happening in your model!")

        return graph, torch_out
Example #2
def _trace_and_get_graph_from_model(model, args):

    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = _unique_state_dict(model).keys()

    trace_graph, torch_out, inputs_states = \
        torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
    warn_on_static_input_change(inputs_states)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    return trace_graph, torch_out
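
A minimal sketch of the same fail-fast idea against the public tracer (torch.jit.trace stands in here for the internal _get_trace_graph; the module and input are made up):

import torch
import torch.nn as nn

def trace_with_sanity_check(model, args):
    # Record the state_dict keys before tracing; tracing runs a forward pass.
    orig_keys = set(model.state_dict().keys())
    traced = torch.jit.trace(model, args)
    # If the forward pass registered or removed parameters/buffers, the graph
    # would no longer line up with the weights collected beforehand.
    if orig_keys != set(model.state_dict().keys()):
        raise RuntimeError("state_dict changed after running the tracer")
    return traced

traced = trace_with_sanity_check(nn.Linear(4, 2), torch.randn(1, 4))
print(traced.graph)
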
Example #3
def _export(model,
            args,
            f,
            export_params=True,
            verbose=False,
            training=False,
            input_names=None,
            output_names=None,
            aten=False):
    # Special case for common case of passing a single Variable
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = _unique_state_dict(model).keys()

    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc.  If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    _optimize_trace(trace, aten)

    _set_input_and_output_names(trace.graph(), input_names, output_names)

    if verbose:
        print(trace)

    # TODO: Don't allocate an in-memory string for the protobuf
    from torch.onnx.symbolic import _onnx_opset_version
    if export_params:
        # NB: OrderedDict values is not actually a list, but trace.export is
        # not duck-typed and expects an actual list.
        proto = trace.export(list(_unique_state_dict(model).values()),
                             _onnx_opset_version)
    else:
        proto = trace.export([], _onnx_opset_version)

    torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
    return torch_out
Example #4
def __init__(self, model):
    # sanity check.
    super(PytorchGraph, self).__init__(model)
    self.model = model
    self.state_dict = _unique_state_dict(self.model)
    self.shape_dict = dict()
    self.layer_weight_map = dict()
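
What sets _unique_state_dict apart from a plain Module.state_dict is deduplication: tied weights show up under several keys but must be collected once. A rough re-implementation of the idea (unique_state_dict_sketch is a hypothetical stand-in, not the torch.jit internal):

from collections import OrderedDict
import torch.nn as nn

def unique_state_dict_sketch(model):
    seen = set()
    out = OrderedDict()
    for name, tensor in model.state_dict(keep_vars=True).items():
        # keep_vars=True yields the parameter objects themselves, so tied
        # weights compare identical and are emitted only once.
        if id(tensor) in seen:
            continue
        seen.add(id(tensor))
        out[name] = tensor
    return out

# Tied weights: the embedding and the projection share one Parameter.
emb = nn.Embedding(10, 4)
proj = nn.Linear(4, 10, bias=False)
proj.weight = emb.weight
tied = nn.Sequential(emb, proj)
print(len(tied.state_dict()), len(unique_state_dict_sketch(tied)))  # -> 2 1
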
Example #5
def _model_to_graph(model,
                    args,
                    f,
                    verbose=False,
                    training=False,
                    input_names=None,
                    output_names=None,
                    aten=False):
    # Special case for common case of passing a single Variable
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(model, torch.jit.ScriptModule):
        torch_out = None
        try:
            method = model.__getattr__('forward')
            graph = method.propagate_shapes(args, False)
            params = method.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')
    else:
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        params = list(_unique_state_dict(model).values())

    graph = _optimize_graph(graph, aten)

    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)

    return graph, params, torch_out
Example #6
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False):
    # Special case for common case of passing a single Variable
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(model, torch.jit.ScriptModule):
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            method = model.__getattr__('forward')
            graph = method.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # Erase number types to bring the graph to a pre-NumberType state
            torch._C._jit_pass_erase_number_types(graph)
            params = method.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        params = list(_unique_state_dict(model).values())

    graph = _optimize_graph(graph, operator_export_type)

    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)

    return graph, params, torch_out
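
The ScriptModule branch cannot simply run the model, which is why example_outputs is mandatory there. A usage sketch against the exporter API of the same era (note: the example_outputs keyword was removed from torch.onnx.export in later PyTorch releases):

import io
import torch

class Doubler(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        return x * 2

model = Doubler()
x = torch.randn(1, 3)

# example_outputs tells the exporter the output shapes/types up front,
# since the script method is not traced with real inputs.
buf = io.BytesIO()
torch.onnx.export(model, x, buf, example_outputs=model(x))
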
Example #7
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False):
    # Special case for common case of passing a single Variable
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(model, torch.jit.ScriptModule):
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            method = model.__getattr__('forward')
            graph = method.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # Erase number types to bring the graph to a pre-NumberType state
            params = method.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        params = list(_unique_state_dict(model).values())

    graph = _optimize_graph(graph, operator_export_type)

    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)

    return graph, params, torch_out
Example #8
def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False,
                    _retain_param_name=False):
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(model, torch.jit.ScriptModule):
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            method = model.__getattr__('forward')
            params = method.initial_ivalues()
            graph = _propagate_and_assign_input_and_output_shapes(
                method.graph, tuple(args) + tuple(params), example_outputs, False, propagate)
            # Erase number types to bring the graph to a pre-NumberType state

        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            graph_inputs = list(graph.inputs())
            user_input_num = len(graph_inputs) - len(state_dict)
            param_names = list(state_dict.keys())
            for i, inp in enumerate(graph_inputs):
                if i >= user_input_num:
                    inp.setUniqueName(param_names[i - user_input_num])

    graph = _optimize_graph(graph, operator_export_type)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch._C._jit_flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)

    # make sure that the param dict and the graph match each other
    flatten_args, _ = torch._C._jit_flatten(args)
    assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())

    input_and_param_names = [val.uniqueName() for val in graph.inputs()]
    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
    params_dict = dict(zip(param_names, params))

    if verbose:
        print(graph)

    return graph, params_dict, torch_out
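
The renaming and params_dict construction above rest on one invariant: the trace graph's inputs are the user inputs followed by the flattened parameters, in state_dict order. The index arithmetic in plain Python (all names illustrative):

# Graph inputs as the tracer lays them out: user inputs first, then params.
graph_input_names = ['x', '1', '2', '3']           # '1'..'3' are unnamed params
state_dict_keys = ['conv.weight', 'conv.bias', 'fc.weight']

user_input_num = len(graph_input_names) - len(state_dict_keys)  # -> 1
for i in range(user_input_num, len(graph_input_names)):
    graph_input_names[i] = state_dict_keys[i - user_input_num]

print(graph_input_names)  # -> ['x', 'conv.weight', 'conv.bias', 'fc.weight']
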
Example #9
def _trace_and_get_graph_from_model(model, args, training):
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = _unique_state_dict(model).keys()

    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc.  If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    return trace.graph(), torch_out
Example #10
def unique_state_dict(model):
    """
    Wrapper of torch.jit._unique_state_dict.

    Args:
        model (Module): Torch model.

    Returns:
        dict, params.
    """
    from torch.jit import _unique_state_dict

    return _unique_state_dict(model)
Example #11
def rename_graph_param_name(model, graph):
    state_dict = _unique_state_dict(model)
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    params = []
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            set_unique_name(inp, param_names[i - user_input_num])
            params.append(unique_name(inp))
        else:
            set_unique_name(inp, 'input_' + str(i))
    return params
Example #12
def _model_to_graph(model,
                    args,
                    f,
                    verbose=False,
                    training=False,
                    input_names=None,
                    output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None,
                    propagate=False):
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(model, torch.jit.ScriptModule):
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            method = model.__getattr__('forward')
            graph = method.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # Erase number types to bring the graph to a pre-NumberType state
            params = method.params()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')
    else:
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        params = list(_unique_state_dict(model).values())

    graph = _optimize_graph(graph, operator_export_type)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch._C._jit_flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)

    return graph, params, torch_out
Example #13
def _get_jit_params(module, param_exclude, param_include):
    state_dict = _unique_state_dict(module)
    if param_exclude is not None:
        param_exclude = re.compile(param_exclude)
    if param_include is not None:
        param_include = re.compile(param_include)

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if param_exclude is not None and param_exclude.match(k) is not None:
            continue
        if param_include is not None and param_include.match(k) is None:
            continue
        if "num_batches_tracked" not in k:
            if "weight" in k or "bias" in k or "running_mean" in k or "running_var" in k:
                new_state_dict[k] = v
    params = list(new_state_dict.values())[::-1]
    return params, list(new_state_dict.keys())[::-1]
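
How the filters above interact, shown on a made-up state_dict: names matching param_exclude are dropped first, then the num_batches_tracked counters that the exporter has no use for:

import re
from collections import OrderedDict
import torch

state_dict = OrderedDict([
    ('features.0.weight', torch.zeros(1)),
    ('features.0.num_batches_tracked', torch.zeros(1)),
    ('classifier.bias', torch.zeros(1)),
])

param_exclude = re.compile(r'classifier.*')
kept = [k for k in state_dict
        if param_exclude.match(k) is None       # drop excluded names
        and 'num_batches_tracked' not in k]     # drop BN step counters
print(kept)  # -> ['features.0.weight']
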
Example #14
def pytorch_to_keras(model,
                     args,
                     input_shape,
                     change_ordering=False,
                     training=False,
                     verbose=False):
    """
    Convert the given PyTorch model to Keras, mapping each layer with the registered converters.

    Args:
        model: pytorch model
        args: pytorch model arguments
        input_shape: keras input shape (used for InputLayer creation)
        change_ordering: change NCHW to NHWC
        training: switch model to training mode
        verbose: verbose output

    Returns:
        model: created keras model.
    """

    # PyTorch JIT tracing
    args = (args, ) if isinstance(args, torch.autograd.Variable) else args

    orig_state_dict_keys = _unique_state_dict(model).keys()

    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    # _optimize_trace(trace, False)
    trace.set_graph(_optimize_graph(trace.graph(), False))

    if verbose:
        print(trace.graph())

    if verbose:
        print(list(trace.graph().outputs()))

    # Get all graph nodes
    nodes = list(trace.graph().nodes())

    # Collect graph outputs
    graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
    print('Graph outputs:', graph_outputs)

    # Collect model state dict
    state_dict = _unique_state_dict(model)
    if verbose:
        print('State dict:', list(state_dict))

    import re
    import keras
    from keras import backend as K
    K.set_image_data_format('channels_first')

    layers = dict()
    layers['input'] = keras.layers.InputLayer(input_shape=input_shape,
                                              name='input').output

    outputs = []

    for node in nodes:
        node_inputs = list(node.inputs())
        node_input_names = []
        for node_input in node_inputs:
            if node_input.node().scopeName():
                node_input_names.append(get_node_id(node_input.node()))

        if len(node_input_names) == 0:
            node_input_names.append('input')

        node_type = node.kind()

        node_scope_name = node.scopeName()
        node_id = get_node_id(node)
        node_weights_name = '.'.join(
            re.findall(r'\[([\w\d.]+)\]', node_scope_name))
        node_attrs = {k: node[k] for k in node.attributeNames()}

        node_outputs = list(node.outputs())
        node_outputs_names = []
        for node_output in node_outputs:
            if node_output.node().scopeName():
                node_outputs_names.append(node_output.node().scopeName())

        if verbose:
            print(' ____ ')
            print('graph node:', node_scope_name)
            print('type:', node_type)
            print('inputs:', node_input_names)
            print('outputs:', node_outputs_names)
            print('name in state_dict:', node_weights_name)
            print('attrs:', node_attrs)
            print('node_id:', node_id)
            print('is_terminal:', node_id in graph_outputs)
        AVAILABLE_CONVERTERS[node_type](params=node_attrs,
                                        w_name=node_weights_name,
                                        scope_name=node_id,
                                        inputs=node_input_names,
                                        layers=layers,
                                        weights=state_dict)
        if node_id in graph_outputs:
            outputs.append(layers[node_id])

    model = keras.models.Model(inputs=layers['input'], outputs=outputs)
    model.summary()

    if change_ordering:
        # Change from 'NCW' to 'NWC' ordering customary in tf
        import numpy as np
        config = model.get_config()
        output_shape = None
        for layer_type, lc in ((layer['class_name'], layer['config'])
                               for layer in config['layers']):

            if 'batch_input_shape' in lc:
                if len(lc['batch_input_shape']) == 3:
                    N, C, W = lc['batch_input_shape']
                    lc['batch_input_shape'] = (N, W, C)
                elif len(lc['batch_input_shape']) == 4:
                    N, C, H, W = lc['batch_input_shape']
                    lc['batch_input_shape'] = (N, H, W, C)
                else:
                    raise NotImplementedError(
                        "len(batch_input_shape) should be either 3 or 4")
                output_shape = lc['batch_input_shape']

            if layer_type == 'Conv1D':
                (N, W, _), K = output_shape, lc['kernel_size'][0]
                C = lc['filters']
                W -= K - 1
                output_shape = (N, W, C)

            if 'target_shape' in lc:
                lc['target_shape'] = tuple(
                    np.reshape(
                        np.array([
                            list(lc['target_shape'][1:][:]),
                            lc['target_shape'][0]
                        ]), -1))

            if 'data_format' in lc:
                lc['data_format'] = 'channels_last'

            if 'axis' in lc:
                lc['axis'] = len(output_shape) - 1

        K.set_image_data_format('channels_last')

        # # For theano:
        # from keras.utils.layer_utils import convert_all_kernels_in_model
        # convert_all_kernels_in_model(model)

        # Set the weights into the model with new ordering
        # `Dense` layers after `Flatten` have their weights transposed.
        src_weights = []
        last_was_flatten = False
        last_shape = None
        for layer in model.layers:
            W = layer.get_weights()
            if last_was_flatten and W:
                assert len(last_shape) == 3, str(last_shape)
                A, b = W
                _, C, H = last_shape
                A.shape = (C, H, -1)
                A = np.ascontiguousarray(np.swapaxes(A, 0, 1))
                A.shape = (H * C, -1)
                W = [A, b]
                last_was_flatten = False
            if isinstance(layer, keras.layers.core.Flatten):
                last_was_flatten = True
            elif not last_was_flatten:
                last_shape = layer.output_shape
            src_weights.append(W)

        if K.backend() == 'tensorflow':
            # Tensorflow needs a new graph for the converted model
            # to retain the same scopes for the operators.
            import tensorflow as tf
            tf.reset_default_graph()
            K.set_session(tf.Session())
            model_tf_ordering = keras.models.Model.from_config(config)
            for dst, src in zip(model_tf_ordering.layers, src_weights):
                dst.set_weights(src)
        else:
            model_tf_ordering = keras.models.Model.from_config(config)
            for dst, src in zip(model_tf_ordering.layers, src_weights):
                dst.set_weights(src)

        model = model_tf_ordering

    return model
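
A typical call for this version of the converter, assuming a torchvision model whose layers all have registered converters (input_shape is channels-first and excludes the batch dimension):

import numpy as np
import torch
from torch.autograd import Variable
import torchvision

model = torchvision.models.resnet18()
model.eval()

input_np = np.random.uniform(0, 1, (1, 3, 224, 224)).astype(np.float32)
input_var = Variable(torch.from_numpy(input_np))

k_model = pytorch_to_keras(model, input_var, (3, 224, 224), verbose=True)
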
Example #15
def _export(model,
            args,
            f,
            export_params=True,
            verbose=False,
            training=False,
            input_names=None,
            output_names=None,
            aten=False,
            export_type=ExportTypes.PROTOBUF_FILE):
    # Special case for common case of passing a single Variable
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = _unique_state_dict(model).keys()

    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc.  If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    trace.set_graph(_optimize_graph(trace.graph(), aten))

    _set_input_and_output_names(trace.graph(), input_names, output_names)

    if verbose:
        print(trace)

    # TODO: Don't allocate an in-memory string for the protobuf
    from torch.onnx.symbolic import _onnx_opset_version
    defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
    if export_params:
        # NB: OrderedDict values is not actually a list, but trace.export is
        # not duck-typed and expects an actual list.
        proto, export_map = trace.export(
            list(_unique_state_dict(model).values()), _onnx_opset_version,
            defer_weight_export)
    else:
        proto, export_map = trace.export([], _onnx_opset_version, False)

    if export_type == ExportTypes.PROTOBUF_FILE:
        assert (len(export_map) == 0)
        torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
    elif export_type in [
            ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE
    ]:
        import zipfile
        compression = zipfile.ZIP_DEFLATED \
            if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
            else zipfile.ZIP_STORED
        with zipfile.ZipFile(f, 'w', compression=compression) as z:
            z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
            for k, v in export_map.items():
                z.writestr(k, v)
    elif export_type == ExportTypes.DIRECTORY:
        import os
        if os.path.exists(f):
            assert (os.path.isdir(f))
        else:
            os.makedirs(f)

        model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
        torch.serialization._with_file_like(model_proto_file, "wb",
                                            lambda f: f.write(proto))

        for k, v in export_map.items():
            weight_proto_file = os.path.join(f, k)
            torch.serialization._with_file_like(weight_proto_file, "wb",
                                                lambda f: f.write(v))
    else:
        raise RuntimeError('Unknown export type')
    return torch_out
Example #16
def pytorch_to_keras(
    model,
    args,
    input_shapes,
    change_ordering=False,
    training=False,
    verbose=False,
    short_names=False,
):
    """
    Convert the given PyTorch model to Keras, mapping each layer with the registered converters.

    Args:
        model: pytorch model
        args: pytorch model arguments
        input_shapes: keras input shapes (used for each InputLayer)
        change_ordering: change CHW to HWC
        training: switch model to training mode
        verbose: verbose output
        short_names: use short names for keras layers

    Returns:
        model: created keras model.
    """

    # PyTorch JIT tracing
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # Workaround for previous versions
    if isinstance(input_shapes, tuple):
        input_shapes = [input_shapes]

    orig_state_dict_keys = _unique_state_dict(model).keys()

    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, tuple(args))

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    # _optimize_trace(trace, False)
    trace.set_graph(_optimize_graph(trace.graph(), False))

    if verbose:
        print(trace.graph())

    if verbose:
        print(list(trace.graph().outputs()))

    # Get all graph nodes
    nodes = list(trace.graph().nodes())

    # Collect graph outputs
    graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
    print('Graph outputs:', graph_outputs)

    # Collect model state dict
    state_dict = _unique_state_dict(model)
    if verbose:
        print('State dict:', list(state_dict))

    import re
    import keras
    from keras import backend as K
    K.set_image_data_format('channels_first')

    layers = dict()
    keras_inputs = []
    for i in range(len(args)):
        layers['input{0}'.format(i)] = keras.layers.InputLayer(
            input_shape=input_shapes[i], name='input{0}'.format(i)).output
        keras_inputs.append(layers['input{0}'.format(i)])

    outputs = []

    input_index = 0
    model_inputs = dict()
    for node in nodes:
        node_inputs = list(node.inputs())
        node_input_names = []
        for node_input in node_inputs:
            if node_input.node().scopeName():
                node_input_names.append(get_node_id(node_input.node()))

        if len(node_input_names) == 0:
            if len(node_inputs) > 0:
                if node_inputs[0] in model_inputs:
                    node_input_names.append(model_inputs[node_inputs[0]])
                else:
                    input_name = 'input{0}'.format(input_index)
                    node_input_names.append(input_name)
                    input_index += 1
                    model_inputs[node_inputs[0]] = input_name

        node_type = node.kind()
        # print(dir(node))

        node_scope_name = node.scopeName()
        node_id = get_node_id(node)
        node_weights_name = '.'.join(
            re.findall(r'\[([\w\d.]+)\]', node_scope_name))
        node_attrs = {k: node[k] for k in node.attributeNames()}

        node_outputs = list(node.outputs())
        node_outputs_names = []
        for node_output in node_outputs:
            if node_output.node().scopeName():
                node_outputs_names.append(node_output.node().scopeName())

        if verbose:
            print(' ____ ')
            print('graph node:', node_scope_name)
            print('type:', node_type)
            print('inputs:', node_input_names)
            print('outputs:', node_outputs_names)
            print('name in state_dict:', node_weights_name)
            print('attrs:', node_attrs)
            print('is_terminal:', node_id in graph_outputs)
        AVAILABLE_CONVERTERS[node_type](node_attrs, node_weights_name, node_id,
                                        node_input_names, layers, state_dict,
                                        short_names)
        if node_id in graph_outputs:
            outputs.append(layers[node_id])

    model = keras.models.Model(inputs=keras_inputs, outputs=outputs)

    if change_ordering:
        import numpy as np
        conf = model.get_config()

        for layer in conf['layers']:
            if layer['config'] and 'batch_input_shape' in layer['config']:
                layer['config']['batch_input_shape'] = \
                    tuple(np.reshape(np.array(
                        [
                            [None] +
                            list(layer['config']['batch_input_shape'][2:][:]) +
                            [layer['config']['batch_input_shape'][1]]
                        ]), -1
                    ))
            if layer['config'] and 'target_shape' in layer['config']:
                if len(list(layer['config']['target_shape'][1:][:])) > 0:
                    layer['config']['target_shape'] = \
                        tuple(np.reshape(np.array(
                            [
                                list(layer['config']['target_shape'][1:][:]),
                                layer['config']['target_shape'][0]
                            ]), -1
                        ),)

            if layer['config'] and 'data_format' in layer['config']:
                layer['config']['data_format'] = 'channels_last'
            if layer['config'] and 'axis' in layer['config']:
                layer['config']['axis'] = 3

        K.set_image_data_format('channels_last')
        model_tf_ordering = keras.models.Model.from_config(conf)

        # from keras.utils.layer_utils import convert_all_kernels_in_model
        # convert_all_kernels_in_model(model)

        for dst_layer, src_layer in zip(model_tf_ordering.layers,
                                        model.layers):
            dst_layer.set_weights(src_layer.get_weights())

        model = model_tf_ordering

    return model
Example #17
def _model_to_graph(model,
                    args,
                    verbose=False,
                    training=False,
                    input_names=None,
                    output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None,
                    propagate=False,
                    _retain_param_name=False,
                    do_constant_folding=False,
                    _disable_torch_constant_prop=False):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(example_outputs, torch.Tensor):
        example_outputs = [example_outputs]

    torch_out = None

    if isinstance(model, torch.jit.ScriptModule):
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        try:
            method_graph, params = model.forward._lowered_graph()
            in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
            graph = _propagate_and_assign_input_shapes(method_graph,
                                                       tuple(in_vars), False,
                                                       propagate)
        except AttributeError:
            raise RuntimeError('\'forward\' method must be a script method')
    elif isinstance(model, torch.jit.Function):
        assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
        method = model
        params = ()
        in_vars, in_desc = torch.jit._flatten(tuple(args))
        graph = _propagate_and_assign_input_shapes(model.graph, tuple(in_vars),
                                                   False, propagate)
    else:
        graph, torch_out = _trace_and_get_graph_from_model(
            model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            graph_inputs = list(graph.inputs())
            user_input_num = len(graph_inputs) - len(state_dict)
            param_names = list(state_dict.keys())
            for i, inp in enumerate(graph_inputs):
                if i >= user_input_num:
                    inp.setDebugName(param_names[i - user_input_num])

    graph = _optimize_graph(
        graph,
        operator_export_type,
        _disable_torch_constant_prop=_disable_torch_constant_prop)

    if isinstance(model, torch.jit.ScriptModule) or isinstance(
            model, torch.jit.Function):
        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
        graph = _assign_output_shapes(graph, out_vars)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch._C._jit_flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)

    # make sure that the param dict and the graph match each other
    flatten_args, _ = torch._C._jit_flatten(args)
    assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())

    input_and_param_names = [val.debugName() for val in graph.inputs()]
    param_names = input_and_param_names[len(input_and_param_names) -
                                        len(params):]
    params_dict = dict(zip(param_names, params))

    if do_constant_folding and _export_onnx_opset_version in [9, 10]:
        params_dict = torch._C._jit_pass_onnx_constant_fold(
            graph, params_dict, _export_onnx_opset_version)
        torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # This pass transforms constants of other data types into float/double plus a Cast operator.
    if _export_onnx_opset_version < 9:
        torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)

    if verbose:
        print(graph)

    return graph, params_dict, torch_out
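
The do_constant_folding path is what the public exporter surfaces as a flag; a hedged usage sketch (do_constant_folding and opset_version are standard torch.onnx.export arguments):

import io
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
x = torch.randn(1, 4)

# Precomputes subgraphs that depend only on parameters, so the exported
# ONNX graph carries folded constants instead of the ops that produced them.
buf = io.BytesIO()
torch.onnx.export(model, x, buf, do_constant_folding=True, opset_version=10)
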
Example #18
def pytorch_to_keras(
    model,
    args,
    input_names,
    input_shapes,
    change_ordering=False,
    training=False,
    verbose=False,
    names=False,
):
    """
    Convert the given PyTorch model to Keras, mapping each layer with the registered converters.

    Args:
        model: pytorch model
        args: pytorch model arguments
        input_names: keras input names (used for each InputLayer)
        input_shapes: keras input shapes (used for each InputLayer)
        change_ordering: change CHW to HWC
        training: switch model to training mode
        verbose: verbose output
        names: naming policy for keras layers: short names, random-suffixed names, or the original names

    Returns:
        model: created keras model.
    """

    # PyTorch JIT tracing
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # Workaround for previous versions
    if isinstance(input_shapes, tuple):
        input_shapes = [input_shapes]

    orig_state_dict_keys = _unique_state_dict(model).keys()

    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, tuple(args))

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    # _optimize_trace(trace, False)
    if version.parse('0.4.0') < version.parse(torch.__version__):
        trace.set_graph(
            _optimize_graph(trace.graph(), OperatorExportTypes.ONNX))
    else:
        trace.set_graph(_optimize_graph(trace.graph(), False))

    trace.graph().lint()

    if verbose:
        print(trace.graph())

    # Get all graph nodes
    nodes = list(trace.graph().nodes())

    # Optimize Flatten:
    # When we have something like this:
    #
    # %523 : Long() = onnx::Constant[value={0}](), scope: ResNet
    # %524 : Dynamic = onnx::Shape(%522), scope: ResNet
    # %526 : Long() = onnx::Gather[axis=0](%524, %523), scope: ResNet
    # %527 : Long() = onnx::Constant[value={-1}](), scope: ResNet
    # %534 : Dynamic = onnx::Unsqueeze[axes=[0]](%526)
    # %535 : Dynamic = onnx::Unsqueeze[axes=[0]](%527)
    # %536 : Dynamic = onnx::Concat[axis=0](%534, %535)
    # %529 : Float(1, 512) = onnx::Reshape(%522, %536), scope: ResNet
    #
    # It's better to replace it with onnx::Flatten
    if six.PY3:
        from types import SimpleNamespace
        seq_to_find = \
            ['onnx::Constant', 'onnx::Shape', 'onnx::Gather',
             'onnx::Constant', 'onnx::Unsqueeze', 'onnx::Unsqueeze', 'onnx::Concat', 'onnx::Reshape']
        k = 0
        s = 0
        for i, node in enumerate(nodes):
            if node.kind() == seq_to_find[k]:
                if k == 0:
                    s = i
                k += 1
                if k == len(seq_to_find):
                    reshape_op = nodes[s + k - 1]
                    flatten_op = {
                        'kind': (lambda: 'onnx::Flatten'),
                        'attributeNames': (lambda: {}),
                        'outputs': (lambda: list(reshape_op.outputs())),
                        'scopeName': (lambda: reshape_op.scopeName()),
                        'inputs': (lambda: list(reshape_op.inputs())[:1]),
                        '__str__': (lambda: reshape_op.__str__()),
                    }
                    nodes = nodes[:s] + [SimpleNamespace(**flatten_op)
                                         ] + nodes[s + k:]
                    break
            else:
                k = 0
                s = -1

    # Collect graph inputs and outputs
    graph_outputs = [get_leaf_id(n) for n in trace.graph().outputs()]
    graph_inputs = [get_leaf_id(n) for n in trace.graph().inputs()]

    # Collect model state dict
    state_dict = _unique_state_dict(model)
    if verbose:
        print('Graph inputs:', graph_inputs)
        print('Graph outputs:', graph_outputs)
        print('State dict:', list(state_dict))

    import re
    import tensorflow.keras
    from tensorflow.keras import backend as K
    K.set_image_data_format('channels_first')

    layers = dict()
    keras_inputs = []
    for i in range(len(args)):
        layers[graph_inputs[i]] = tensorflow.keras.layers.InputLayer(
            input_shape=input_shapes[i], name=input_names[i]).output
        keras_inputs.append(layers[graph_inputs[i]])

    outputs = []
    group_indices = defaultdict(lambda: 0, {})

    for node in nodes:
        node_inputs = list(node.inputs())
        node_input_names = []

        for node_input in node_inputs:
            node_input_names.append(get_leaf_id(node_input))

        node_type = node.kind()

        node_scope_name = node.scopeName()
        node_id = get_node_id(node)
        node_name_regex = re.findall(r'\[([\w\d.\-\[\]\s]+)\]',
                                     node_scope_name)

        try:
            int(node_name_regex[-1])
            node_weigth_group_name = '.'.join(node_name_regex[:-1])
            node_weights_name = node_weigth_group_name + '.' + str(
                group_indices[node_weigth_group_name])
            group_indices[node_weigth_group_name] += 1

        except ValueError:
            node_weights_name = '.'.join(node_name_regex)
        except IndexError:
            node_weights_name = '.'.join(node_input_names)

        node_attrs = {k: node[k] for k in node.attributeNames()}

        node_outputs = list(node.outputs())
        node_outputs_names = []
        for node_output in node_outputs:
            if node_output.node().scopeName():
                node_outputs_names.append(node_output.node().scopeName())

        if verbose:
            print(' ____ ')
            print('graph node:', node_scope_name)
            print('node id:', node_id)
            print('type:', node_type)
            print('inputs:', node_input_names)
            print('outputs:', node_outputs_names)
            print('name in state_dict:', node_weights_name)
            print('attrs:', node_attrs)
            print('is_terminal:', node_id in graph_outputs)
        AVAILABLE_CONVERTERS[node_type](node_attrs, node_weights_name, node_id,
                                        node_input_names, layers, state_dict,
                                        names)
        if node_id in graph_outputs:
            outputs.append(layers[node_id])

    model = tensorflow.keras.models.Model(inputs=keras_inputs, outputs=outputs)

    if change_ordering:
        import numpy as np
        conf = model.get_config()

        for layer in conf['layers']:
            if layer['config'] and 'batch_input_shape' in layer['config']:
                layer['config']['batch_input_shape'] = \
                    tuple(np.reshape(np.array(
                        [
                            [None] +
                            list(layer['config']['batch_input_shape'][2:][:]) +
                            [layer['config']['batch_input_shape'][1]]
                        ]), -1
                    ))
            if layer['config'] and 'target_shape' in layer['config']:
                if len(list(layer['config']['target_shape'][1:][:])) > 0:
                    layer['config']['target_shape'] = \
                        tuple(np.reshape(np.array(
                            [
                                list(layer['config']['target_shape'][1:][:]),
                                layer['config']['target_shape'][0]
                            ]), -1
                        ),)

            if layer['config'] and 'data_format' in layer['config']:
                layer['config']['data_format'] = 'channels_last'
            if layer['config'] and 'axis' in layer['config']:
                layer['config']['axis'] = 3

        K.set_image_data_format('channels_last')
        model_tf_ordering = tensorflow.keras.models.Model.from_config(conf)

        # from tensorflow.keras.utils.layer_utils import convert_all_kernels_in_model
        # convert_all_kernels_in_model(model)

        for dst_layer, src_layer in zip(model_tf_ordering.layers,
                                        model.layers):
            dst_layer.set_weights(src_layer.get_weights())

        model = model_tf_ordering

    print(
        'Your model was (probably) successfully converted! '
        'Please, follow the repository https://github.com/nerox8664/pytorch2keras and give a star :)'
    )
    return model
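
Finally, a usage sketch for this variant, which additionally takes explicit input names (MyModel is a placeholder for any torch.nn.Module covered by the available converters):

import numpy as np
import torch

model = MyModel()  # hypothetical convertible module
model.eval()

input_var = torch.FloatTensor(
    np.random.uniform(0, 1, (1, 3, 32, 32)).astype(np.float32))

k_model = pytorch_to_keras(
    model, input_var,
    input_names=['input_0'],
    input_shapes=[(3, 32, 32)],
    change_ordering=True,  # emit a channels-last (NHWC) Keras model
    verbose=True,
)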