Example #1
def div_exporter(op_def, context):
    node, const_tensors = export_util.translate(**locals())
    const_tensors = []  # Global scalars
    for name in op_def.input:
        if name.startswith('/share/scalar/'):
            const_tensors.append(helper.from_tensor(name, context.ws))
    return node, const_tensors
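The recurring pattern here promotes workspace scalars, named under the '/share/scalar/' prefix, into ONNX initializers so they become graph constants. A minimal, self-contained sketch of that idea using only the official onnx package (the toy workspace dict and the collect_scalar_initializers helper are illustrative, not Dragon's API):

import numpy as np
from onnx import numpy_helper

SCALAR_PREFIX = '/share/scalar/'  # shared-scalar naming convention used above

def collect_scalar_initializers(input_names, fetch_scalar):
    # Convert every input under the scalar prefix into a TensorProto so it
    # can be attached to the ONNX graph as an initializer.
    return [numpy_helper.from_array(fetch_scalar(name), name=name)
            for name in input_names if name.startswith(SCALAR_PREFIX)]

# Toy workspace standing in for helper.from_tensor(name, context.ws).
workspace = {'/share/scalar/2': np.array(2.0, dtype='float32')}
tensors = collect_scalar_initializers(['x', '/share/scalar/2'], workspace.__getitem__)
print([t.name for t in tensors])  # ['/share/scalar/2']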
Example #2
def minimum_exporter(op_def, context):
    node, const_tensors = export_util.translate(**locals())
    node.op_type = 'Min'  # Eltwise, Broadcast
    const_tensors = []  # Global scalars
    for name in op_def.input:
        if name.startswith('/share/scalar/'):
            const_tensors.append(helper.from_tensor(name, context.ws))
    return node, const_tensors
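Remapping the op type of an existing NodeProto, as minimum_exporter does, is a plain field assignment. A short sketch with the standard onnx helpers (the framework-side name 'Minimum' is illustrative):

from onnx import helper as onnx_helper

# Framework-side node spelled 'Minimum'; the ONNX eltwise/broadcast op is 'Min'.
node = onnx_helper.make_node('Minimum', inputs=['a', 'b'], outputs=['c'])
node.op_type = 'Min'
print(node.op_type)  # Min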
Example #3
def add_exporter(op_def, context):
    node, const_tensors = export_util.translate(**locals())
    dtype = str(helper.fetch_tensor(op_def.output[0], context.ws).dtype)
    node.op_type = 'Or' if dtype == 'bool' else 'Add'
    const_tensors = []  # Global scalars
    for name in op_def.input:
        if name.startswith('/share/scalar/'):
            const_tensors.append(helper.from_tensor(name, context.ws))
    return node, const_tensors
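The dtype dispatch in add_exporter, where boolean addition becomes logical 'Or', can be mirrored with numpy and the onnx helpers. A sketch, assuming the output is available as a numpy array (the make_add_node helper is hypothetical):

import numpy as np
from onnx import helper as onnx_helper

def make_add_node(inputs, output, out_array):
    # Boolean addition maps to logical 'Or' in ONNX; everything else to 'Add'.
    op_type = 'Or' if str(out_array.dtype) == 'bool' else 'Add'
    return onnx_helper.make_node(op_type, inputs=inputs, outputs=[output])

print(make_add_node(['a', 'b'], 'c', np.zeros(2, dtype='bool')).op_type)     # Or
print(make_add_node(['a', 'b'], 'c', np.zeros(2, dtype='float32')).op_type)  # Add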
Example #4
def channel_affine_exporter(op_def, context):
    node, const_tensors = export_util.translate(**locals())
    node.op_type = 'ATen'  # Currently not supported in ai.onnx
    helper.add_attribute(node, 'op_type', 'ChannelAffine')
    for arg in op_def.arg:
        if arg.name == 'axis':
            helper.add_attribute(node, 'axis', arg.i)
        elif arg.name == 'num_axes':
            helper.add_attribute(node, 'num_axes', arg.i)
    # Weights and biases
    const_tensors = [
        helper.from_tensor(e, context.ws) for e in op_def.input[1:]
    ]
    return node, const_tensors
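channel_affine_exporter falls back to a custom 'ATen' node and carries the real operator name and its arguments as attributes. The same structure can be built with onnx.helper alone; a sketch with illustrative input names and attribute values:

from onnx import helper as onnx_helper

# Custom fallback node: the real operator name and its arguments travel as attributes.
node = onnx_helper.make_node('ATen', inputs=['x', 'w', 'b'], outputs=['y'])
node.attribute.extend([
    onnx_helper.make_attribute('op_type', 'ChannelAffine'),
    onnx_helper.make_attribute('axis', 1),
    onnx_helper.make_attribute('num_axes', 1),
])
print([a.name for a in node.attribute])  # ['op_type', 'axis', 'num_axes']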
Example #5
    @classmethod
    def graph_def_to_onnx_graph(
        cls,
        graph_def,
        input_names=None,
        output_names=None,
        input_shapes=None,
        constants=None,
        value_info=None,
        opset_version=None,
        workspace=None,
        verbose=True,
    ):
        input_names = [] if input_names is None else input_names
        output_names = [] if output_names is None else output_names
        constants = {} if constants is None else constants
        value_info = {} if value_info is None else value_info

        if not nest.is_sequence(input_names):
            raise ValueError('<input_names> should be a sequence.')
        if not nest.is_sequence(output_names):
            raise ValueError('<output_names> should be a sequence.')
        if not isinstance(constants, dict):
            raise ValueError('<constants> should be a dict with name -> value.')
        if not isinstance(value_info, dict):
            raise ValueError('<value_info> should be a dict with name -> (dtype, shape).')

        # Determine the opset version used to select exporters.
        # ``_check_opset_version`` also resolves ``None`` to a default version.
        opset_version = cls._check_opset_version(opset_version)

        # Create aliases for blobs.
        blob_aliases = {}
        for i, alias in enumerate(output_names):
            blob_aliases[graph_def.output[i]] = alias
            workspace.RegisterAlias(graph_def.output[i], alias)
            if graph_def.output[i] in value_info:
                value_info[alias] = value_info[graph_def.output[i]]
        for i, alias in enumerate(input_names):
            blob_aliases[graph_def.input[i]] = alias
            workspace.RegisterAlias(graph_def.input[i], alias)
            if graph_def.input[i] in value_info:
                value_info[alias] = value_info[graph_def.input[i]]

        # Optionally rewrite the input shapes before export.
        # A common case is filling ``-1`` for dynamic dimensions
        # expected by inference runtimes such as TensorRT.
        if input_shapes is not None:
            if isinstance(input_shapes, dict):
                for k, v in input_shapes.items():
                    value_info[k] = (value_info[k][0], v)
            else:
                for k, v in zip(graph_def.input[:], input_shapes):
                    value_info[k] = (value_info[k][0], v)

        # Prepare to make the graph.
        onnx_graph = onnx.GraphProto(name=graph_def.name
                                     if len(graph_def.name) > 0
                                     else 'onnx-model')
        blob_shapes, blob_names = {}, {}
        blob_versions = collections.defaultdict(
            int, **dict((blob_aliases.get(k, k), 1)
                        for k in helper.collect_inputs(graph_def)))
        initializers, seen_initializers = [], set()

        # Build translator context.
        context = export_util.TranslatorContext(
            workspace=workspace,
            blob_names=blob_names,
            blob_shapes=blob_shapes,
            blob_versions=blob_versions,
            opset_version=opset_version,
        )

        # Add nodes.
        for op in graph_def.op:
            # Get the shape of inputs and outputs.
            for name in itertools.chain(op.input, op.output):
                impl = workspace.GetTensor(name)
                if impl is not None:
                    blob_shapes[name] = impl.dims
                else:
                    blob_shapes[name] = value_info[name][1]

            # Translate definition.
            nodes, const_tensors = cls._make_node(op, context)

            # Rewrite names: apply blob aliases and SSA versioning.
            for node in nodes:
                node.input[:] = [blob_aliases.get(e, e) for e in node.input]
                node.output[:] = [blob_aliases.get(e, e) for e in node.output]
                cls._rewrite_for_ssa(node, context)

            # If translation produced no node, export the outputs as constant tensors.
            if None in nodes:
                const_tensors = [helper.from_tensor(name, workspace)
                                 for name in op.output]
            else:
                onnx_graph.node.extend(nodes)

            # Merge constant tensors.
            if const_tensors is not None:
                value_info = {**value_info,
                              **dict((e.name, (e.data_type, e.dims))
                                     for e in const_tensors)}
                for tensor in const_tensors:
                    if tensor.name not in seen_initializers:
                        initializers.append(tensor)
                        seen_initializers.add(tensor.name)

        # Add constants.
        if constants is not None:
            for k, v in constants.items():
                initializers.append(helper.from_array(v, name=k))

        # Add inputs.
        for name in helper.collect_inputs(onnx_graph):
            try:
                onnx_graph.input.extend([
                    helper.make_tensor_value_info(
                        name=name,
                        elem_type=value_info[name][0],
                        shape=value_info[name][1])])
            except KeyError:
                impl = workspace.GetTensor(name)
                if impl is not None:
                    initializer = helper.from_tensor(name, workspace)
                    onnx_graph.input.extend([
                        helper.make_tensor_value_info(
                            name=name,
                            elem_type=initializer.data_type,
                            shape=initializer.dims)])
                    if name not in seen_initializers:
                        initializers.append(initializer)
                        seen_initializers.add(initializer.name)
                else:
                    raise ValueError(
                        'Info of tensor `{}` is missing, '
                        'specify it in <value_info>.'.format(name))

        # Add initializers.
        onnx_graph.initializer.extend(initializers)

        # Add outputs.
        onnx_graph.output.extend(
            helper.make_tensor_value_info(
                name=blob_names.get(name_v2, name_v2),
                elem_type=value_info[name_v2][0],
                shape=value_info[name_v2][1])
            for name_v2 in [blob_aliases.get(name, name)
                            for name in set(graph_def.output)])

        if verbose:
            print(helper.printable_graph(onnx_graph))

        return onnx_graph
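In miniature, graph_def_to_onnx_graph fills a GraphProto with nodes, inputs, initializers, and outputs. A self-contained sketch of that assembly using only the standard onnx helpers (tensor names and shapes are made up for illustration):

import numpy as np
import onnx
from onnx import TensorProto, helper as onnx_helper, numpy_helper

# One node, y = x + w, with the weight exported as an initializer.
node = onnx_helper.make_node('Add', inputs=['x', 'w'], outputs=['y'])
weight = numpy_helper.from_array(np.ones((1, 3), dtype='float32'), name='w')

graph = onnx_helper.make_graph(
    nodes=[node],
    name='onnx-model',
    inputs=[onnx_helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3])],
    outputs=[onnx_helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3])],
    initializer=[weight],
)
model = onnx_helper.make_model(graph)
onnx.checker.check_model(model)
print(onnx_helper.printable_graph(graph))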