def slice_exporter(op_def, context):
    """Translate a slice ``op_def`` and compute its start/end bounds.

    ``starts`` and ``sizes`` are read from the op arguments in one of three forms:

    * literal ints (``starts`` / ``sizes``);
    * a single descriptor (``*_desc``), resolved via ``helper.fetch_argument``
      against ``context.ws``;
    * a list of descriptors (``*_descs``), resolved via ``helper.fetch_arguments``
      against ``context.ws``.

    If the same role appears more than once, the last argument wins.

    Each end bound is derived from the matching size:

    * ``-1``: the full extent of that axis of the first input;
    * ``0``: ``start + 1`` (exactly one element);
    * otherwise: ``start + size``.

    A negative start is first resolved against the axis extent before the
    end is computed. This gives ``start=-1, size=1`` the end ``dim`` rather
    than ``0``; an ONNX ``Slice`` counts a negative start from the end, so
    ``[-1, 0)`` would be an empty range. The returned ``starts`` are left
    unchanged.

    Returns:
        tuple: ``(node, starts, ends)``.
    """
    # Must be the first statement: ``locals()`` has to contain only the
    # function arguments when it is forwarded to ``translate``.
    node, const_tensors = export_util.translate(**locals())
    in_shape = context.blob_shapes[op_def.input[0]]
    starts, sizes, ends = [], [], []
    for arg in op_def.arg:
        if arg.name == 'starts':
            starts = [int(e) for e in arg.ints]
        elif arg.name == 'starts_desc':
            starts = helper.fetch_argument(op_def, arg, context.ws)
        elif arg.name == 'starts_descs':
            starts = helper.fetch_arguments(op_def, arg, context.ws)
        elif arg.name == 'sizes':
            sizes = [int(e) for e in arg.ints]
        elif arg.name == 'sizes_desc':
            sizes = helper.fetch_argument(op_def, arg, context.ws)
        elif arg.name == 'sizes_descs':
            sizes = helper.fetch_arguments(op_def, arg, context.ws)
    for i, size in enumerate(sizes):
        if size == -1:
            # Slice through to the end of this axis.
            ends.append(in_shape[i])
            continue
        start = starts[i]
        if start < 0:
            # Resolve relative to the axis extent so start + size cannot wrap
            # to 0 (or another wrong negative end) for tail slices.
            start += in_shape[i]
        # size 0 selects exactly one element.
        ends.append(start + (1 if size == 0 else size))
    return node, starts, ends
def transpose_exporter(op_def, context):
    """Translate a transpose ``op_def``, forwarding its permutation.

    The permutation may be supplied in one of three forms:

    * literal ints (``perm``);
    * a single descriptor (``perm_desc``);
    * a list of descriptors (``perm_descs``).

    Descriptors are resolved against ``context.ws``. Every matching argument
    adds a ``perm`` attribute, in argument order. An empty ``perm_descs``
    list is skipped.

    Returns:
        tuple: ``(node, None)``.
    """
    # Must stay first so ``locals()`` only carries the two arguments.
    node, const_tensors = export_util.translate(**locals())
    for arg in op_def.arg:
        if arg.name == 'perm':
            perm = arg.ints
        elif arg.name == 'perm_desc':
            perm = helper.fetch_argument(op_def, arg, context.ws)
        elif arg.name == 'perm_descs' and len(arg.strings) > 0:
            perm = helper.fetch_arguments(op_def, arg, context.ws)
        else:
            # Unrelated argument, or an empty descriptor list.
            continue
        helper.add_attribute(node, 'perm', perm)
    return node, None
def pad_exporter(op_def, context):
    """Translate a pad ``op_def``; return the node, its pads and fill value.

    ``mode`` is lower-cased and attached to the node as the ``mode``
    attribute. The pads come from ``pads`` (literal ints), or from
    ``pads_desc`` / ``pads_descs`` (descriptors resolved against
    ``context.ws``). The fill value is read from the ``value`` argument's
    ``f`` field and defaults to ``0``. If an argument repeats, the last
    occurrence wins.

    Returns:
        tuple: ``(node, pads, value)``.
    """
    # Must stay first so ``locals()`` only carries the two arguments.
    node, const_tensors = export_util.translate(**locals())
    pads = []
    value = 0
    for arg in op_def.arg:
        name = arg.name
        if name == 'mode':
            helper.add_attribute(node, 'mode', arg.s.lower())
        elif name == 'value':
            value = arg.f
        elif name == 'pads':
            pads = [int(p) for p in arg.ints]
        elif name == 'pads_desc':
            pads = helper.fetch_argument(op_def, arg, context.ws)
        elif name == 'pads_descs':
            pads = helper.fetch_arguments(op_def, arg, context.ws)
    return node, pads, value
def tile_exporter(op_def, context):
    """Translate a tile ``op_def`` and attach its repeats as a constant input.

    The repeats come from ``repeats`` (literal ints), or from
    ``repeats_desc`` / ``repeats_descs`` (descriptors resolved against
    ``context.ws``). The last matching argument wins; with none, the list
    stays empty. The values are packed into an int64 array by
    ``helper.from_array`` under a unique ``<input>/tile/repeats`` name, and
    that name is appended to the node's inputs.

    Returns:
        tuple: ``(node, [repeats_tensor])``.
    """
    # Must stay first so ``locals()`` only carries the two arguments.
    node, const_tensors = export_util.translate(**locals())
    repeat_values = []
    for arg in op_def.arg:
        if arg.name == 'repeats':
            repeat_values = list(arg.ints)
        elif arg.name == 'repeats_desc':
            repeat_values = helper.fetch_argument(op_def, arg, context.ws)
        elif arg.name == 'repeats_descs':
            repeat_values = helper.fetch_arguments(op_def, arg, context.ws)
    tensor_name = context.unique_name(op_def.input[0] + '/tile/repeats')
    repeats_tensor = helper.from_array(
        numpy.array(repeat_values, 'int64'),
        tensor_name,
    )
    node.input.extend([repeats_tensor.name])
    return node, [repeats_tensor]
def reshape_exporter(op_def, context):
    """Translate a reshape ``op_def`` and attach its target shape as a constant.

    The requested dims come from one of three arguments:

    * ``dims`` (literal ints);
    * ``dims_desc`` (one descriptor);
    * ``dims_descs`` (a list of descriptors).

    Descriptors are resolved against ``context.ws``, and the last matching
    argument wins. Without any of them, the output shape recorded in
    ``context.blob_shapes`` is used unchanged.

    Requested entries ``<= 0`` are kept verbatim; in ONNX ``Reshape``, ``0``
    copies the input dim and ``-1`` is inferred. Every positive entry is
    replaced by the corresponding dim of the recorded output shape.

    The final shape is packed as an int64 array by ``helper.from_array``
    under a unique ``<input>/reshape/shape`` name, and that name is appended
    to the node's inputs.

    Returns:
        tuple: ``(node, [shape_tensor])``.
    """
    # Must stay first so ``locals()`` only carries the two arguments.
    node, const_tensors = export_util.translate(**locals())
    shape = list(context.blob_shapes[op_def.output[0]])
    # Use an independent copy rather than aliasing (``shape = dims = ...``).
    # That way the loop below never mutates the list it is iterating.
    dims = list(shape)
    for arg in op_def.arg:
        if arg.name == 'dims':
            dims = [int(e) for e in arg.ints]
        elif arg.name == 'dims_desc':
            dims = helper.fetch_argument(op_def, arg, context.ws)
        elif arg.name == 'dims_descs':
            dims = helper.fetch_arguments(op_def, arg, context.ws)
    for axis, dim in enumerate(dims):
        if dim <= 0:
            # Keep the special 0 / -1 markers; positive dims use the
            # recorded output shape.
            shape[axis] = dim
    shape_tensor = helper.from_array(
        numpy.array(shape, 'int64'),
        context.unique_name(op_def.input[0] + '/reshape/shape'),
    )
    node.input.extend([shape_tensor.name])
    return node, [shape_tensor]
def channel_normalize_exporter(op_def, context):
    """Export ``ChannelNormalize`` as an ``ATen`` node.

    ChannelNormalize is currently not supported in ai.onnx. The node's
    ``op_type`` is therefore set to ``'ATen'``, and the original op name is
    carried in an ``op_type`` attribute.

    The ``mean``, ``std``, ``axis``, ``dtype`` and ``perm`` arguments are
    copied as same-named attributes. ``perm`` may instead come from
    ``perm_desc`` or a non-empty ``perm_descs``, resolved against
    ``context.ws``. Attributes are added in argument order.

    Returns:
        tuple: ``(node, const_tensors)`` as produced by ``translate``.
    """
    # Must stay first so ``locals()`` only carries the two arguments.
    node, const_tensors = export_util.translate(**locals())
    # Currently not supported in ai.onnx.
    node.op_type = 'ATen'
    helper.add_attribute(node, 'op_type', 'ChannelNormalize')
    # Arguments forwarded verbatim under their own name, keyed to the proto
    # field that holds their value.
    same_named = {
        'mean': lambda a: a.floats,
        'std': lambda a: a.floats,
        'axis': lambda a: a.i,
        'dtype': lambda a: a.s,
        'perm': lambda a: a.ints,
    }
    for arg in op_def.arg:
        read_field = same_named.get(arg.name)
        if read_field is not None:
            helper.add_attribute(node, arg.name, read_field(arg))
        elif arg.name == 'perm_desc':
            perm = helper.fetch_argument(op_def, arg, context.ws)
            helper.add_attribute(node, 'perm', perm)
        elif arg.name == 'perm_descs' and len(arg.strings) > 0:
            perm = helper.fetch_arguments(op_def, arg, context.ws)
            helper.add_attribute(node, 'perm', perm)
    return node, const_tensors