Example 1
def caffe2_op_to_node_def(op_def, env):
    node_def = onnx_pb2.NodeProto()
    # NB: This must happen BEFORE we start freshening inplace outputs
    node_def.input.extend(map(env.rename, op_def.input))
    node_def.op_type = get_node_op_type(op_def)

    # Determine which outputs are in-place updates of inputs
    input_set = set(op_def.input)
    output_set = set(op_def.output)

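    # The ONNX schema reports which inputs an op may consume (overwrite);
    # translate caffe2's in-place naming into the 'consumed_inputs' attribute.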
    schema = onnx.defs.get_schema(node_def.op_type)
    # Repeated ints fields cannot be assigned directly, so build a plain list first
    consumes = []
    for i, x in enumerate(op_def.input):
        is_consumed, output_idx = schema.consumed(i)
        if is_consumed == onnx.defs.OpSchema.UseType.CONSUME_ENFORCED:
            consumes.append(1)
        elif is_consumed == onnx.defs.OpSchema.UseType.CONSUME_ALLOWED:
            if x in output_set:
                consumes.append(1)
            else:
                consumes.append(0)
        else:
            if x in output_set:
                raise RuntimeError(
                    "schema says consume not allowed, but caffe2 used inplace syntax"
                )
            consumes.append(0)
    if any(consumes):
        consumes_attr = onnx_pb2.AttributeProto()
        consumes_attr.name = "consumed_inputs"
        consumes_attr.ints.extend(consumes)
    else:
        consumes_attr = None

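    # An output that reuses an input name (caffe2 in-place syntax) gets a fresh
    # name so that no value name is defined twice in the ONNX graph.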
    def fresh_or_rename(out):
        if out in input_set:
            return env.fresh(out)
        else:
            return env.rename(out)

    node_def.output.extend(map(fresh_or_rename, op_def.output))
    # TODO: refactor frontend to allow special handling for individual ops
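    # caffe2's Concat also emits a second, split-information output that ONNX
    # Concat does not have, so drop it.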
    if node_def.op_type == 'Concat':
        assert len(node_def.output) == 2
        del node_def.output[1]

    node_def.name = op_def.name
    attrs = get_onnx_attrs(node_def.op_type, op_def)
    if consumes_attr:
        attrs.append(consumes_attr)
    node_def.attribute.extend(attrs)
    checker.check_node(node_def)
    return node_def
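
Example 1 depends on a renaming environment env that must provide rename and fresh. The class below is a minimal sketch of such an environment, assuming an identity rename and a simple counter for fresh names; the class name and numbering scheme are illustrative, not taken from the original module.

class DummyRenameEnv(object):
    # Hypothetical stand-in for the renaming environment expected by
    # caffe2_op_to_node_def; a real environment may track a full name map.
    def __init__(self):
        self._counter = 0

    def rename(self, name):
        # Identity rename: keep the caffe2 blob name as the ONNX value name.
        return name

    def fresh(self, name):
        # Return a new, unique name for an in-place output so the ONNX graph
        # never defines the same value name twice.
        self._counter += 1
        return '{}_{}'.format(name, self._counter)

# Usage (illustrative): node = caffe2_op_to_node_def(op_def, DummyRenameEnv())
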
Example 2
def get_onnx_attrs(op_type, op_def):
    onnx_attrs = []
    args = {a.name: get_caffe2_arg_type_and_val(a) for a in op_def.arg}
    if op_type in [
            'Conv', 'ConvTranspose', 'MaxPool', 'GlobalMaxPool', 'AveragePool',
            'GlobalAveragePool', 'ChannelShuffle'
    ]:

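        # Collapse caffe2's per-dimension arguments (kernel_h/kernel_w, or
        # pad_t/pad_l/pad_b/pad_r) into a single repeated-ints ONNX attribute
        # such as 'kernels' or 'pads'.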
        def apply_trans(args, k, dim=2):
            onnx_attr = None
            if dim == 2:
                k_h, k_w, ks = k + '_h', k + '_w', k + 's'
            else:
                k_t, k_l, k_b, k_r, ks = k + '_t', k + '_l', k + '_b', k + '_r', k + 's'
            if dim == 2 and k_h in args and k_w in args:
                assert not onnx_attr
                onnx_attr = onnx_pb2.AttributeProto()
                onnx_attr.name = ks
                onnx_attr_assign(onnx_attr, ArgType.INTS,
                                 [args[k_h][1], args[k_w][1]])
                del args[k_h]
                del args[k_w]
            if dim == 4 and k_t in args and k_l in args and k_b in args and k_r in args:
                assert not onnx_attr
                onnx_attr = onnx_pb2.AttributeProto()
                onnx_attr.name = ks
                onnx_attr_assign(
                    onnx_attr, ArgType.INTS,
                    [args[k_t][1], args[k_l][1], args[k_b][1], args[k_r][1]])
                del args[k_t]
                del args[k_l]
                del args[k_b]
                del args[k_r]
            if k in args:
                assert not onnx_attr
                onnx_attr = onnx_pb2.AttributeProto()
                onnx_attr.name = ks
                onnx_attr_assign(onnx_attr, ArgType.INTS, [args[k][1]] * dim)
                del args[k]
            if onnx_attr:
                if op_type in ['GlobalMaxPool', 'GlobalAveragePool']:
                    # TODO: check the values are equal to the default values in c2
                    pass
                else:
                    onnx_attrs.append(onnx_attr)

        apply_trans(args, 'kernel')
        apply_trans(args, 'stride')
        apply_trans(args, 'dilation')
        apply_trans(args, 'adj')
        apply_trans(args, 'pad', 4)

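    # Copy the remaining caffe2 arguments over, validating values that have a
    # known whitelist and skipping blacklisted arguments.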
    for a in args:
        t, val = args[a]
        if a in _expected_arg_values:
            if val not in _expected_arg_values[a]:
                raise Exception('value {} not in the expected value list ({}) '
                                'for argument {}'.format(
                                    val, _expected_arg_values[a], a))
        if a not in _blacklist_caffe2_args:
            onnx_attr = onnx_pb2.AttributeProto()
            onnx_attr.name = a
            onnx_attr_assign(onnx_attr, t, val)
            onnx_attrs.append(onnx_attr)

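    # Finally, apply per-op attribute renames for arguments whose ONNX name
    # differs from the caffe2 one.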
    for attr in onnx_attrs:
        if op_type in _renamed_args and attr.name in _renamed_args[op_type]:
            attr.name = _renamed_args[op_type][attr.name]
    return onnx_attrs
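
The helper onnx_attr_assign is referenced but not defined in these examples. The sketch below is one plausible reconstruction, assuming ArgType enumerates the caffe2 argument kinds; only ArgType.INTS actually appears above, so the other members and the exact field mapping are assumptions, not taken from the original module.

def onnx_attr_assign(onnx_attr, arg_type, val):
    # Hypothetical reconstruction: copy a (type, value) pair into the matching
    # field of an onnx_pb2.AttributeProto.
    if arg_type == ArgType.INT:
        onnx_attr.i = val
    elif arg_type == ArgType.FLOAT:
        onnx_attr.f = val
    elif arg_type == ArgType.STRING:
        # AttributeProto.s is a bytes field.
        onnx_attr.s = val if isinstance(val, bytes) else val.encode('utf-8')
    elif arg_type == ArgType.INTS:
        onnx_attr.ints.extend(val)
    elif arg_type == ArgType.FLOATS:
        onnx_attr.floats.extend(val)
    else:
        raise ValueError('unsupported argument type: {}'.format(arg_type))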