def caffe2_op_to_node_def(op_def, env):
    """Translate one Caffe2 operator into a checked ONNX ``NodeProto``.

    Inputs are renamed through ``env``. Any output whose name is also an
    input name (Caffe2 in-place syntax) is given a fresh ONNX name, and a
    ``consumed_inputs`` attribute records which inputs the node consumes.

    Args:
        op_def: the Caffe2 operator definition; its ``input``, ``output``
            and ``name`` fields are read here, and ``get_onnx_attrs`` reads
            its ``arg`` list.
        env: naming environment exposing ``rename(name)`` and
            ``fresh(name)``. Presumably ``rename`` returns the current ONNX
            name of a blob and ``fresh`` allocates a new one -- confirm
            against the env implementation.

    Returns:
        The populated ``NodeProto``, after ``checker.check_node`` ran on it.

    Raises:
        RuntimeError: if Caffe2 wrote an input in place although the ONNX
            schema does not allow consuming it, or if a ``Concat`` node does
            not have exactly two outputs.
    """
    node_def = onnx_pb2.NodeProto()
    # NB: This must happen BEFORE we start freshening inplace outputs
    node_def.input.extend(map(env.rename, op_def.input))
    node_def.op_type = get_node_op_type(op_def)

    # A blob name that appears among both inputs and outputs was updated
    # in place by the Caffe2 op.
    input_set = set(op_def.input)
    output_set = set(op_def.output)

    schema = onnx.defs.get_schema(node_def.op_type)
    # One 0/1 flag per input; copied into the attribute's repeated ``ints``
    # field only if at least one input is actually consumed.
    consumes = []
    for i, x in enumerate(op_def.input):
        is_consumed, _ = schema.consumed(i)
        if is_consumed == onnx.defs.OpSchema.UseType.CONSUME_ENFORCED:
            consumes.append(1)
        elif is_consumed == onnx.defs.OpSchema.UseType.CONSUME_ALLOWED:
            consumes.append(1 if x in output_set else 0)
        else:
            if x in output_set:
                raise RuntimeError(
                    "schema says consume not allowed, but caffe2 used inplace syntax"
                )
            consumes.append(0)

    def fresh_or_rename(out):
        # In-place outputs must not reuse the input's ONNX name.
        if out in input_set:
            return env.fresh(out)
        return env.rename(out)

    node_def.output.extend(map(fresh_or_rename, op_def.output))

    # TODO: refactor frontend to allow special handling for individual ops
    if node_def.op_type == 'Concat':
        # The second Caffe2 Concat output has no ONNX counterpart and is
        # dropped (presumably Caffe2's split-info blob -- confirm). Checked
        # explicitly rather than with ``assert``, which ``-O`` strips and
        # would then let the wrong output be deleted.
        if len(node_def.output) != 2:
            raise RuntimeError(
                'Concat is expected to have exactly 2 outputs, got {}'.format(
                    len(node_def.output)))
        del node_def.output[1]

    node_def.name = op_def.name
    attrs = get_onnx_attrs(node_def.op_type, op_def)
    if any(consumes):
        consumes_attr = onnx_pb2.AttributeProto()
        consumes_attr.name = "consumed_inputs"
        consumes_attr.ints.extend(consumes)
        attrs.append(consumes_attr)
    node_def.attribute.extend(attrs)
    checker.check_node(node_def)
    return node_def
def apply_trans(args, k, dim=2, op_type=None, onnx_attrs=None):
    """Fold Caffe2 per-axis or shared argument *k* into one ONNX INTS attribute.

    With ``dim == 2`` the per-axis keys are ``k_h``/``k_w``; with ``dim == 4``
    they are ``k_t``/``k_l``/``k_b``/``k_r``. Values are taken in that order.
    If all per-axis keys are present they are used. Otherwise, a shared key
    ``k`` is replicated ``dim`` times. Matched keys are removed from *args*.
    The resulting attribute is named ``k + 's'``.

    Args:
        args: dict mapping argument name to a ``(type, value)`` pair (the
            shape ``get_onnx_attrs`` builds); mutated in place.
        k: base argument name, e.g. ``'kernel'`` or ``'pad'``.
        dim: number of values in the attribute (2 or 4 for per-axis forms).
        op_type: ONNX op type. For ``GlobalMaxPool``/``GlobalAveragePool``
            the matched arguments are consumed but no attribute is kept.
        onnx_attrs: optional list to append the new attribute to.

    Returns:
        The created ``AttributeProto``, or ``None`` if nothing matched or the
        attribute was dropped for a global pooling op.

    Raises:
        ValueError: if both the per-axis keys and the shared key ``k`` are
            present; *args* is left unchanged in that case.
    """
    suffixes = {2: ('_h', '_w'), 4: ('_t', '_l', '_b', '_r')}.get(dim, ())
    axis_keys = [k + suffix for suffix in suffixes]
    # ``all([])`` is True, so require a known per-axis layout first.
    has_axis_form = bool(axis_keys) and all(key in args for key in axis_keys)
    has_shared_form = k in args

    if has_axis_form and has_shared_form:
        raise ValueError(
            'conflicting arguments for {}: got both {} and {}'.format(
                k + 's', ', '.join(axis_keys), k))

    if has_axis_form:
        values = [args.pop(key)[1] for key in axis_keys]
    elif has_shared_form:
        values = [args.pop(k)[1]] * dim
    else:
        return None

    onnx_attr = onnx_pb2.AttributeProto()
    onnx_attr.name = k + 's'
    onnx_attr_assign(onnx_attr, ArgType.INTS, values)

    if op_type in ('GlobalMaxPool', 'GlobalAveragePool'):
        # TODO: check the values are equal to the default values in c2
        return None
    if onnx_attrs is not None:
        onnx_attrs.append(onnx_attr)
    return onnx_attr
def get_onnx_attrs(op_type, op_def):
    """Convert the Caffe2 arguments of *op_def* into ONNX attributes.

    For ``Conv``/``ConvTranspose``/pooling/``ChannelShuffle`` ops, the
    ``kernel``, ``stride``, ``dilation``, ``adj`` (2 values) and ``pad``
    (4 values) arguments are folded into INTS attributes named with an
    ``s`` suffix. They can come from per-axis keys or from a shared key.
    Every remaining argument is checked against ``_expected_arg_values``,
    skipped if listed in ``_blacklist_caffe2_args``, and copied otherwise.
    Finally, attribute names are rewritten via ``_renamed_args[op_type]``.

    Args:
        op_type: ONNX op type of the node being built.
        op_def: Caffe2 operator definition whose ``arg`` list is converted.

    Returns:
        A list of ``onnx_pb2.AttributeProto``.

    Raises:
        ValueError: if an argument value is not in its expected value list,
            or if both per-axis and shared forms of one argument are given.
    """
    onnx_attrs = []
    # argument name -> (arg type, value)
    args = {a.name: get_caffe2_arg_type_and_val(a) for a in op_def.arg}

    if op_type in [
            'Conv', 'ConvTranspose', 'MaxPool', 'GlobalMaxPool', 'AveragePool',
            'GlobalAveragePool', 'ChannelShuffle'
    ]:

        def apply_trans(args, k, dim=2):
            # Fold k_h/k_w (dim 2) or k_t/k_l/k_b/k_r (dim 4), else a shared
            # k replicated ``dim`` times, into one INTS attribute ``k + 's'``.
            suffixes = {2: ('_h', '_w'), 4: ('_t', '_l', '_b', '_r')}.get(dim, ())
            axis_keys = [k + suffix for suffix in suffixes]
            # ``all([])`` is True, so require a known per-axis layout first.
            has_axis_form = bool(axis_keys) and all(
                key in args for key in axis_keys)
            if has_axis_form and k in args:
                # Checked before mutating ``args``; was an ``assert``.
                raise ValueError(
                    'conflicting arguments for {}: got both {} and {}'.format(
                        k + 's', ', '.join(axis_keys), k))
            if has_axis_form:
                values = [args.pop(key)[1] for key in axis_keys]
            elif k in args:
                values = [args.pop(k)[1]] * dim
            else:
                return
            onnx_attr = onnx_pb2.AttributeProto()
            onnx_attr.name = k + 's'
            onnx_attr_assign(onnx_attr, ArgType.INTS, values)
            if op_type in ['GlobalMaxPool', 'GlobalAveragePool']:
                # TODO: check the values are equal to the default values in c2
                return
            onnx_attrs.append(onnx_attr)

        apply_trans(args, 'kernel')
        apply_trans(args, 'stride')
        apply_trans(args, 'dilation')
        apply_trans(args, 'adj')
        apply_trans(args, 'pad', 4)

    for name, (arg_type, val) in args.items():
        if name in _expected_arg_values and val not in _expected_arg_values[name]:
            raise ValueError(
                'value {} not in the expected value list ({}) '
                'for argument {}'.format(val, _expected_arg_values[name], name))
        if name in _blacklist_caffe2_args:
            continue
        onnx_attr = onnx_pb2.AttributeProto()
        onnx_attr.name = name
        onnx_attr_assign(onnx_attr, arg_type, val)
        onnx_attrs.append(onnx_attr)

    if onnx_attrs and op_type in _renamed_args:
        renames = _renamed_args[op_type]
        for attr in onnx_attrs:
            if attr.name in renames:
                attr.name = renames[attr.name]
    return onnx_attrs