def add_unsqueeze(onnx_graph, onnx_tensor, axes, unsqueezed_tensor=None):
    # type: (ONNXGraph, ONNXTensor, typing.List[int], typing.Optional[ONNXTensor])->ONNXTensor
    """Create an ONNX ``Unsqueeze`` operation on ``onnx_graph`` and return its output.

    Args:
        onnx_graph: graph passed as ``graph=`` to the new operation (and new tensor).
        onnx_tensor: input tensor of the Unsqueeze operation.
        axes: value stored in the operation's ``axes`` attribute.
        unsqueezed_tensor: tensor to use as the operation's output. When omitted,
            a fresh ONNXTensor is created with the same dtype as ``onnx_tensor``
            and shape ``transforms.unsqueezed_shape(onnx_tensor.shape, axes)``.

    Returns:
        The ``output`` of the newly created Unsqueeze operation.
    """
    target = unsqueezed_tensor
    if target is None:
        target_shape = transforms.unsqueezed_shape(onnx_tensor.shape, axes)
        target = ONNXTensor(graph=onnx_graph, shape=target_shape, dtype=onnx_tensor.dtype)

    unsqueeze_op = ONNXOperation(graph=onnx_graph,
                                 name="Unsqueeze",
                                 inputs=onnx_tensor,
                                 attribs={"axes": axes},
                                 outputs=target)
    return unsqueeze_op.output
def generic_convert_argminmax(op, target_name):
    # type: (TFOperation, str)->None
    """Rewrite an argmin/argmax ``op`` in place into ``target_name`` plus a SQUEEZE.

    Steps, in order:
      * rename ``op`` to ``target_name``;
      * keep its data input and append a scalar INT32 tensor holding the axis;
      * replace its attributes with ``output_type=4`` (the value the original
        code names ``kTfLiteInt64``);
      * give ``op`` a new output whose shape is the old output shape with a
        dimension inserted at ``axis`` (via ``unsqueezed_shape``);
      * add a SQUEEZE op (``squeeze_dims=[axis]``) from that new output to the
        old output tensor, so downstream consumers still read the same tensor.

    NOTE(review): presumably the intermediate keeps the reduced dimension because
    the target op's output is expected to; confirm against the target format.
    """
    tflite_int64 = 4  # same value the original binds to kTfLiteInt64
    reduce_axis = op.attribs["axis"]

    op.name = target_name
    data_input = op.input
    axis_input = TFTensor(graph=op.graph, shape=[], dtype='INT32', data=[reduce_axis])
    op.inputs = (data_input, axis_input)
    op.attribs = {"output_type": tflite_int64}

    # Detach the original output so it can become the SQUEEZE result instead.
    final_output = op.output
    kept_dims_shape = unsqueezed_shape(shape=final_output.shape, axes=[reduce_axis])
    op.outputs = (TFTensor(graph=op.graph, shape=kept_dims_shape, dtype=final_output.dtype),)

    TFOperation(graph=op.graph,
                name="SQUEEZE",
                inputs=op.output,
                outputs=final_output,
                attribs={"squeeze_dims": [reduce_axis]})
def merge_transforms_into_varlikes(g, transforms_by_name, merge_into_constants, merge_into_variables, driver):
    # type: (BaseGraph, typing.Dict[str, typing.List[Transform]], bool, bool, DataFormatOptimizationDriver)->None
    """Fold squeeze/unsqueeze/reshape/transpose/copy ops into the tensors they read.

    Only the following tensors are considered:
      * variables, when ``merge_into_variables`` is set;
      * constants, when ``merge_into_constants`` is set.

    A tensor is processed repeatedly while all of these hold:
      * its first consumer is one of the driver's transform ops;
      * that consumer reads it as ``inputs[0]``;
      * neither the tensor nor that consumer's output is a graph output;
      * every other consumer is the same op kind with the same parameter.

    On each pass:
      * ``tensor.shape`` is updated (and, for non-empty variables, ``tensor.data``);
      * the transform is recorded in ``transforms_by_name`` via ``add_transform``,
        or by ``apply_transpose_to_varlike`` for transposes;
      * all consumers are removed with ``remove_passthrough_ex``.

    Args:
        g: graph to modify in place.
        transforms_by_name: mapping that collects the applied transforms.
        merge_into_constants: whether constant tensors may absorb transforms.
        merge_into_variables: whether variable tensors may absorb transforms.
        driver: supplies the op names and parameter getters of the framework.
    """
    transform_ops = [
        driver.squeeze_op_name,
        driver.unsqueeze_op_name,
        driver.reshape_op_name,
        driver.transpose_op_name,
        driver.copy_op_name,
    ]

    def get_param(op):
        # Parameter that characterizes transform `op`: axes, permutation or
        # target shape (None for copy). Must read `op` itself; reading the
        # enclosing loop's `tensor.consumers[0]` made every consumer look
        # identical to the first one.
        if op.name == driver.squeeze_op_name:
            return driver.get_axes_from_squeeze(op)
        elif op.name == driver.unsqueeze_op_name:
            return driver.get_axes_from_unsqueeze(op)
        elif op.name == driver.transpose_op_name:
            return driver.get_axes_from_transpose(op)
        elif op.name == driver.reshape_op_name:
            return driver.get_shape_from_reshape(op)
        elif op.name == driver.copy_op_name:
            return None
        raise AssertionError("Unexpected transform op: {}".format(op.name))

    def first_consumer_is_mergeable(tensor):
        # Guard for one merge step; re-evaluated after every step because
        # merging removes consumers and exposes the next transform in line.
        if not ((merge_into_variables and tensor.is_variable)
                or (merge_into_constants and tensor.is_constant)):
            return False
        if not tensor.consumers:
            return False
        first = tensor.consumers[0]
        return (first.name in transform_ops
                and tensor is first.inputs[0]  # needed for ONNX graphs
                and all(t not in g.outputs for t in [tensor, first.output]))

    for tensor in list(g.tensors):
        while first_consumer_is_mergeable(tensor):
            first = tensor.consumers[0]
            op_name = first.name
            op_param = get_param(first)

            # The tensor is rewritten once for all consumers, so every consumer
            # must apply the same transform to it as its data input.
            if not all(op.name == op_name and op.inputs[0] is tensor and get_param(op) == op_param
                       for op in tensor.consumers[1:]):
                break

            if op_name == driver.squeeze_op_name:
                axes = op_param
                tensor.shape = squeezed_shape(tensor.shape, axes)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.squeeze(tensor.data, tuple(axes))
                # Constant data is left unchanged.
                add_transform(transforms_by_name, tensor, Squeeze(axes))
            elif op_name == driver.unsqueeze_op_name:
                axes = op_param
                tensor.shape = unsqueezed_shape(tensor.shape, axes)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                # Constant data is left unchanged.
                add_transform(transforms_by_name, tensor, Unsqueeze(axes))
            elif op_name == driver.reshape_op_name:
                tensor.shape = _reshaped_shape(tensor.shape, op_param)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                # Constant data is left unchanged.
                add_transform(transforms_by_name, tensor, Reshape(tensor.shape))
            elif op_name == driver.transpose_op_name:
                apply_transpose_to_varlike(tensor, op_param, transforms_by_name)
            elif op_name == driver.copy_op_name:
                pass
            else:
                raise AssertionError("Unexpected transform op: {}".format(op_name))

            # The transform now lives in the tensor itself; drop the ops.
            for op in list(tensor.consumers):
                remove_passthrough_ex(g, op)