def _is_squeeze_invariant_to_perm(squeeze_axes, perm):
    """Check whether squeezing `squeeze_axes` gives the same result with or without `perm`.

    A symbolic shape ``[0, 1, ..., len(perm) - 1]`` is built, where each
    dimension is labelled by its own index. It is squeezed along
    `squeeze_axes` directly. Separately, it is permuted by `perm` and then
    squeezed along the correspondingly permuted axes. The function returns
    True iff both paths leave the same labelled dimensions in the same order.
    """
    labels = list(range(len(perm)))
    permuted_labels = utils.apply_permutation(labels, perm)
    permuted_axes = Transposer.apply_permutation_to_axes(squeeze_axes, perm)
    # The labels are arbitrary ints rather than 1s. `can_squeeze_not_one=True`
    # presumably lets squeezed_shape drop them anyway — confirm against its
    # definition.
    squeezed_directly = squeezed_shape(labels, squeeze_axes, can_squeeze_not_one=True)
    squeezed_after_perm = squeezed_shape(permuted_labels, permuted_axes, can_squeeze_not_one=True)
    return squeezed_directly == squeezed_after_perm
def add_squeeze(caffe2_graph, caffe2_tensor, dims):
    # type: (Caffe2Graph, Caffe2Tensor, typing.List[int])->Caffe2Tensor
    """Append a Caffe2 ``Squeeze`` op on `caffe2_tensor` to `caffe2_graph`.

    A fresh output tensor is created in the graph, with the shape given by
    ``transforms.squeezed_shape(caffe2_tensor.shape, dims)`` and the input's
    dtype. The op stores `dims` in its ``dims`` attribute.

    Returns:
        The output tensor of the newly created ``Squeeze`` operation.
    """
    result_shape = transforms.squeezed_shape(caffe2_tensor.shape, dims)
    result_tensor = Caffe2Tensor(graph=caffe2_graph,
                                 shape=result_shape,
                                 dtype=caffe2_tensor.dtype)
    squeeze_op = Caffe2Operation(graph=caffe2_graph,
                                 name="Squeeze",
                                 inputs=caffe2_tensor,
                                 attribs=dict(dims=dims),
                                 outputs=result_tensor)
    return squeeze_op.output
def add_squeeze(onnx_graph, onnx_tensor, axes, squeezed_tensor=None):
    # type: (ONNXGraph, ONNXTensor, typing.List[int], typing.Optional[ONNXTensor])->ONNXTensor
    """Append an ONNX ``Squeeze`` op on `onnx_tensor` to `onnx_graph`.

    Args:
        onnx_graph: Graph that receives the new operation.
        onnx_tensor: Input tensor to squeeze.
        axes: Axes to remove; stored in the op's ``axes`` attribute.
        squeezed_tensor: Optional pre-existing tensor to use as the output.
            When omitted, a new tensor is created with shape
            ``transforms.squeezed_shape(onnx_tensor.shape, axes)`` and the
            input's dtype.

    Returns:
        The output tensor of the newly created ``Squeeze`` operation.
    """
    target = squeezed_tensor
    if target is None:
        target = ONNXTensor(graph=onnx_graph,
                            shape=transforms.squeezed_shape(onnx_tensor.shape, axes),
                            dtype=onnx_tensor.dtype)
    squeeze_op = ONNXOperation(graph=onnx_graph,
                               name="Squeeze",
                               inputs=onnx_tensor,
                               attribs=dict(axes=axes),
                               outputs=target)
    return squeeze_op.output
def generic_convert_argminmax(op, target_name):
    # type: (TFOperation, str)->None
    """Rewrite an ARG_MIN/ARG_MAX op in place as `target_name` followed by ``tf.expand_dims``.

    The steps are:

    - The constant axis tensor (``op.inputs[1]``) is folded into the ``axis``
      attribute and dropped from the inputs.
    - ``output_type`` is remapped through a small code table.
    - The op is given a new output whose shape is the original output's shape
      with `axis` squeezed out.
    - A ``tf.expand_dims`` op is added that re-inserts `axis` and writes into
      the original output tensor.

    Raises:
        AssertionError: if the axis input has no constant data.
    """
    # NOTE(review): presumably maps TFLite TensorType codes (INT64=4, INT32=2)
    # to TF DataType enum values (DT_INT64=9, DT_INT32=3) — verify.
    tflite_to_tf_dtype = {4: 9, 2: 3}
    op.name = target_name

    axis_input = op.inputs[1]
    assert axis_input.data is not None, "ARG_MIN/ARG_MAX is only supported with constant axis (inputs[1])"
    axis = axis_input.data.tolist()

    new_output_type = tflite_to_tf_dtype[op.attribs["output_type"]]
    op.attribs = dict(axis=axis, output_type=new_output_type)
    op.inputs = (op.inputs[0],)

    # Keep the original output tensor for the expand_dims op. Give the
    # arg-reduction a new output with `axis` squeezed out.
    kept_output = op.output
    reduced_shape = squeezed_shape(shape=kept_output.shape, axes=[axis])
    op.outputs = (TFTensor(graph=op.graph, shape=reduced_shape, dtype=kept_output.dtype),)

    TFOperation(graph=op.graph,
                name="tf.expand_dims",
                inputs=op.output,
                outputs=kept_output,
                attribs=dict(axis=axis))
def merge_transforms_into_varlikes(g, transforms_by_name, merge_into_constants, merge_into_variables, driver):
    # type: (BaseGraph, typing.Dict[str, typing.List[Transform]], bool, bool, DataFormatOptimizationDriver)->None
    """Fold shape/layout transform ops into the variable/constant tensors feeding them.

    A tensor qualifies when all of the following hold:

    - It is a variable (if `merge_into_variables`) or a constant (if
      `merge_into_constants`).
    - Its first consumer is a squeeze, unsqueeze, reshape, transpose or copy
      op (names taken from `driver`), and the tensor is that op's first input.
    - Neither the tensor nor that op's output is a graph output.

    The merge happens only if *every* consumer is the same kind of op with the
    same parameter. In that case the transform is applied to the tensor's
    shape, and to its data when it is a variable with non-empty data. The
    transform is recorded via ``add_transform``, except for transpose, which
    is delegated to ``apply_transpose_to_varlike``, and copy, which changes
    nothing. All consumers are then removed with ``remove_passthrough_ex``.
    This repeats per tensor while it still qualifies.

    Args:
        g: The graph to modify in place.
        transforms_by_name: Accumulator of applied transforms, updated in place.
        merge_into_constants: Whether constant tensors may absorb transforms.
        merge_into_variables: Whether variable tensors may absorb transforms.
        driver: Supplies op names and parameter accessors for the graph format.
    """
    transform_ops = [
        driver.squeeze_op_name,
        driver.unsqueeze_op_name,
        driver.reshape_op_name,
        driver.transpose_op_name,
        driver.copy_op_name,
    ]

    def get_param(op):
        """Return the parameter of transform op `op`, as extracted by `driver`.

        This is the squeeze/unsqueeze/transpose axes, the reshape target shape,
        or None for copy.
        """
        # Read from `op` itself, not from the first consumer of the enclosing
        # loop's `tensor`. Otherwise the all-consumers-equal check below would
        # compare the first consumer's parameter with itself.
        if op.name == driver.squeeze_op_name:
            return driver.get_axes_from_squeeze(op)
        elif op.name == driver.unsqueeze_op_name:
            return driver.get_axes_from_unsqueeze(op)
        elif op.name == driver.transpose_op_name:
            return driver.get_axes_from_transpose(op)
        elif op.name == driver.reshape_op_name:
            return driver.get_shape_from_reshape(op)
        elif op.name == driver.copy_op_name:
            return None
        else:
            assert False

    for tensor in list(g.tensors):
        while (((merge_into_variables and tensor.is_variable)
                or (merge_into_constants and tensor.is_constant))
               and len(tensor.consumers) >= 1
               and tensor.consumers[0].name in transform_ops
               # need to check for onnx graph (presumably because there some
               # transforms, e.g. Reshape, take parameters as extra inputs)
               and tensor is tensor.consumers[0].inputs[0]
               # NOTE(review): only the first consumer's output is checked
               # against g.outputs; other consumers are removed below as well.
               and all(t not in g.outputs for t in [tensor, tensor.consumers[0].output])):
            op_name = tensor.consumers[0].name
            op_param = get_param(tensor.consumers[0])

            # Merge only if every consumer applies the identical transform.
            # The `op.name == op_name` test short-circuits before get_param,
            # so get_param only sees recognised transform ops.
            if not all(op.name == op_name and get_param(op) == op_param
                       for op in tensor.consumers[1:]):
                break

            if op_name == driver.squeeze_op_name:
                axes = op_param
                tensor.shape = squeezed_shape(tensor.shape, axes)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.squeeze(tensor.data, tuple(axes))
                elif tensor.is_constant:
                    pass  # good as it is
                add_transform(transforms_by_name, tensor, Squeeze(axes))
            elif op_name == driver.unsqueeze_op_name:
                axes = op_param
                tensor.shape = unsqueezed_shape(tensor.shape, axes)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                elif tensor.is_constant:
                    pass  # good as it is
                add_transform(transforms_by_name, tensor, Unsqueeze(axes))
            elif op_name == driver.reshape_op_name:
                tensor.shape = _reshaped_shape(tensor.shape, op_param)
                if tensor.is_variable and tensor.data.size > 0:
                    tensor.data = np.reshape(tensor.data, tensor.shape)
                elif tensor.is_constant:
                    pass  # good as it is
                add_transform(transforms_by_name, tensor, Reshape(tensor.shape))
            elif op_name == driver.transpose_op_name:
                apply_transpose_to_varlike(tensor, op_param, transforms_by_name)
            elif op_name == driver.copy_op_name:
                pass
            else:
                assert False

            # The tensor now already carries the transform, so each consumer
            # has become a pass-through and is removed.
            for op in list(tensor.consumers):
                remove_passthrough_ex(g, op)