def apply_transpose_to_varlike(tensor, axes, transforms_by_name):
    # type: (BaseTensor, typing.List[int], typing.Dict[str, typing.List[Transform]])->None
    if tensor.rank <= 1:
        return

    old_shape = tensor.shape
    tensor.shape = utils.apply_permutation(old_shape, axes)

    if tensor.is_variable and tensor.data.size > 0:
        tensor.data = np.transpose(tensor.data, axes)
    elif tensor.is_constant:
        if len(tensor.data) > 1:
            tensor.data = (np.array(tensor.data)
                           .reshape(old_shape)
                           .transpose(axes)
                           .flatten()
                           .tolist())

    add_transform(transforms_by_name, tensor, Transpose(axes))
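
# Illustrative sketch (not part of the original module): how a constant's flat
# data is permuted by apply_transpose_to_varlike. The name `const` and the
# [2, 3] shape are hypothetical.
#
#   const.shape == [2, 3]            # stored flat: [1, 2, 3, 4, 5, 6]
#   apply_transpose_to_varlike(const, [1, 0], transforms_by_name)
#   const.shape == [3, 2]            # data reshaped, transposed, re-flattened:
#   const.data  == [1, 4, 2, 5, 3, 6]
#
# Rank-0 and rank-1 tensors are returned unchanged, because every permutation
# of a single axis is the identity.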
class IOTransform(object):
    Transpose = Transpose
    NDHWC_TO_NCDHW = Transpose([0, 4, 1, 2, 3])
    NCDHW_TO_NDHWC = Transpose([0, 2, 3, 4, 1])
    NHWC_TO_NCHW = Transpose([0, 3, 1, 2])
    NCHW_TO_NHWC = Transpose([0, 2, 3, 1])
    # For rank-3 tensors, swapping the W and C axes is its own inverse,
    # so both directions use the same permutation.
    NWC_TO_NCW = Transpose([0, 2, 1])
    NCW_TO_NWC = Transpose([0, 2, 1])
    IDENTITY = _Identity()
    TF_FILTER_GRAD_TO_NNEF = _TFFilterGradToNNEF()
    NNEF_FILTER_GRAD_TO_TF = _NNEFFilterGradToTF()
    SMART_NCHW_TO_NHWC = _SmartNCHWToNHWC()
    SMART_NHWC_TO_NCHW = _SmartNHWCToNCHW()
    SMART_TF_NHWC_TO_NCHW = _SmartTFNHWCToNCHW()
    SMART_TF_NCHW_TO_NCHW = _SmartTFNCHWToNCHW()
    SMART_NCHW_TO_TF_NHWC = _SmartNCHWToTFNHWC()
    SMART_NCHW_TO_TF_NCHW = _SmartNCHWToTFNCHW()
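
# Usage sketch (assumed calling convention, based on the constants above and
# on transform_io below): io_transform may be a single Transform applied to
# every graph input and output, or a per-tensor dict. The tensor names
# "input" and "output" here are hypothetical.
#
#   transform_io(g,
#                io_transform={"input": IOTransform.NHWC_TO_NCHW,
#                              "output": IOTransform.IDENTITY},
#                transforms_by_name=transforms_by_name,
#                driver=driver)
#
# NHWC_TO_NCHW == Transpose([0, 3, 1, 2]) moves the channel axis from last to
# second place, e.g. shape [1, 224, 224, 3] -> [1, 3, 224, 224].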
def transform_remove_inverse_transposes(
        g,  # type: BaseGraph
        transforms_by_name,  # type: typing.Dict[str, typing.List[Transform]]
        merge_into_constants,  # type: bool
        merge_into_variables,  # type: bool
        driver,  # type: DataFormatOptimizationDriver
        transposable_ops=None,  # type: typing.Optional[typing.List[TransposableOperation]]
):
    # type: (...)->None
    if transposable_ops is None:
        transposable_ops = []

    transposable_op_by_name = {}  # type: typing.Dict[str, TransposableOperation]
    transposable_op_by_name.update({top.name: top for top in transposable_ops})

    # Pad partial transpose axes up to the full output rank, so that
    # permutations can be compared and inverted uniformly below.
    for op in g.operations:
        if op.name == driver.transpose_op_name and op.output.rank > len(driver.get_axes_from_transpose(op)):
            driver.set_axes_on_transpose(
                op,
                driver.get_axes_from_transpose(op)
                + list(range(op.output.rank))[len(driver.get_axes_from_transpose(op)):])

    matches = _find_inverse_transposes(
        g,
        transposable_op_names=set(six.iterkeys(transposable_op_by_name)),
        merge_into_constants=merge_into_constants,
        merge_into_variables=merge_into_variables,
        driver=driver)

    for axes, subgraph in matches:
        upper_perm = axes if subgraph.started_down else utils.inverse_permutation(axes)
        lower_perm = utils.inverse_permutation(upper_perm)
        upper_boundary = [be for be in subgraph.boundary_elements if not be.from_up]
        lower_boundary = [be for be in subgraph.boundary_elements if be.from_up]

        # Upper boundary: remove the leading transposes, or fold the
        # permutation directly into variables/constants.
        for _, tensor in upper_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    # The transpose output is also a graph output: keep an
                    # explicit transpose for it before removing the original.
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=utils.apply_permutation(tensor.producer.input.shape, upper_perm),
                        dtype=tensor.producer.input.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=tensor.producer.input,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(g, tensor, graph_output)
                elif (len(tensor.producer.input.consumers) == 1
                      and tensor.producer.input not in g.inputs
                      and tensor.producer.input not in g.outputs):
                    tensor.producer.input.name = tensor.name
                add_transform(transforms_by_name, tensor.producer.input, Transpose(lower_perm))
                remove_passthrough_ex(g, tensor.producer)
            else:
                assert ((merge_into_variables and tensor.is_variable)
                        or (merge_into_constants and tensor.is_constant))
                apply_transpose_to_varlike(tensor, lower_perm, transforms_by_name)

        # Interior: rewrite the skipped (transpose-invariant) ops and permute
        # the shapes of their outputs.
        skipped_ops = set(tensor.producer
                          for tensor in subgraph.skipped_tensors)  # type: typing.Set[BaseOperation]

        for op in skipped_ops:
            assert op.name in transposable_op_by_name
            transposable_op_by_name[op.name].dg_transpose(_transposer, g, op, lower_perm)

            for output in op.outputs:
                if output in g.outputs:
                    graph_output = driver.create_tensor(graph=g,
                                                        name=output.name,
                                                        shape=output.shape,
                                                        dtype=output.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=output,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(g, output, graph_output)
                    output.name = None
                    output.shape = utils.apply_permutation(output.shape, lower_perm)
                else:
                    output.shape = utils.apply_permutation(output.shape, lower_perm)
                add_transform(transforms_by_name, output, Transpose(lower_perm))

        # Lower boundary: drop the trailing inverse transposes (replacing them
        # with a copy where they produced a graph output) and re-map squeeze axes.
        for _, tensor in lower_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    graph_output = driver.create_tensor(graph=g,
                                                        name=tensor.name,
                                                        shape=tensor.producer.input.shape,
                                                        dtype=tensor.producer.input.dtype)
                    driver.create_copy_op(graph=g,
                                          input=tensor.producer.input,
                                          output=graph_output)
                    graph_utils.replace_tensor_in_outputs(g, tensor, graph_output)
                remove_passthrough_ex(g, tensor.producer)
            elif tensor.producer is not None and tensor.producer.name == driver.squeeze_op_name:
                driver.set_axes_on_squeeze(
                    tensor.producer,
                    sorted(Transposer.apply_permutation_to_axes(
                        driver.get_axes_from_squeeze(tensor.producer), lower_perm)))
            else:
                assert False

    graph_utils.remove_unreachable(g)
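
# Illustrative effect on a hypothetical graph: a matched pair of inverse
# transposes enclosing transpose-invariant ops is removed,
#
#   x -> transpose([0, 2, 3, 1]) -> relu -> transpose([0, 3, 1, 2]) -> y
#
# becomes
#
#   x -> relu -> y
#
# with the intermediate shapes permuted accordingly and each change recorded
# in transforms_by_name. When merge_into_variables/merge_into_constants is
# set, a transpose fed by a variable or constant is instead folded into its
# stored data via apply_transpose_to_varlike.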
def transform_io(g, io_transform, transforms_by_name, driver):
    # type: (BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}

    # Normalize io_transform into a {tensor: transform} dict, accepting either
    # a single Transform for all io tensors or a per-tensor dict keyed by name
    # or by tensor.
    transform_by_io_tensor = {}
    if isinstance(io_transform, dict):
        for k, v in six.iteritems(io_transform):
            assert isinstance(k, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(v, Transform), "io_transform: Value type must be Transform"
            if isinstance(k, BaseTensor):
                assert k in six.itervalues(io_tensors_by_name)
            else:
                assert k in io_tensors_by_name
                k = io_tensors_by_name[k]
            transform_by_io_tensor[k] = v
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify a transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        for t in six.itervalues(io_tensors_by_name):
            transform_by_io_tensor[t] = io_transform

    for tensor, transform in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"
        assert isinstance(transform, (Transpose, _CustomTransform)), "Unsupported io_transform"

        if isinstance(transform, _Identity):
            continue

        # The "smart" transforms first try the filter-grad rewrites and fall
        # back to a plain Transpose (or to no change for rank <= 2 tensors).
        if isinstance(transform, _SmartTFNCHWToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNHWCToNCHW):
            if tensor.rank <= 2:
                continue
            transform = Transpose([0, tensor.rank - 1] + list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartTFNHWCToNCHW):
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0, tensor.rank - 1] + list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartNCHWToTFNCHW):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                continue
            transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _TFFilterGradToNNEF):
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name, driver)
            continue

        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver)
            continue

        assert isinstance(transform, Transpose), "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"

        if transform.is_identity():
            continue

        if tensor in g.inputs:
            assert tensor.name
            # The transposed tensor becomes the new graph input; an inverse
            # transpose restores the original layout for existing consumers.
            new_input_tensor = driver.create_tensor(
                graph=g,
                name=tensor.name,
                shape=utils.apply_permutation(tensor.shape, transform.axes),
                dtype=tensor.dtype)
            add_transform(transforms_by_name, new_input_tensor, transform)
            transpose = driver.create_transpose_op(
                graph=g,
                input=new_input_tensor,
                axes=utils.inverse_permutation(transform.axes),
                output=driver.create_tensor(graph=g,
                                            name=None,
                                            shape=tensor.shape,
                                            dtype=tensor.dtype))
            graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
            graph_utils.replace_tensor_in_consumers(g, tensor, transpose.output, remove=True)
        else:  # output: append a forward transpose and let it take over the name
            transpose = driver.create_transpose_op(
                graph=g,
                input=tensor,
                axes=transform.axes,
                output=driver.create_tensor(graph=g,
                                            name=tensor.name,
                                            shape=utils.apply_permutation(tensor.shape, transform.axes),
                                            dtype=tensor.dtype))
            add_transform(transforms_by_name, transpose.output, transform)
            tensor.name = None
            graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)
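
# Note on the asymmetry above (a sketch with hypothetical shapes): for a graph
# input, the transposed tensor becomes the new input and an inverse transpose
# restores the original layout for existing consumers; for a graph output, a
# forward transpose is appended and takes over the output's name. With
# transform=Transpose([0, 2, 3, 1]), an input originally of shape
# [1, 3, 224, 224] is declared as [1, 224, 224, 3] and transposed back to
# [1, 3, 224, 224] inside the graph.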
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    assert driver.conv_grad_filter_op_names

    # Pattern 1: conv_grad_filter output consumed directly by a transpose.
    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names, outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name, inputs=cgf1_output)

    # Pattern 2: conv_grad_filter output reshaped, then transposed.
    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names, outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name, inputs=cgf2_output, outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name, inputs=reshape2_output)

    if tensor.producer is None:
        raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")

    m = matcher.match(g, tensor.producer, matcher.OrPattern(transpose1, transpose2))

    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        transpose = m[transpose1]  # type: BaseOperation
        if not (len(transpose.output.consumers) <= 1 and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")
        cgf.output.name = transpose.output.name
        add_transform(transforms_by_name, cgf.output,
                      Transpose(utils.inverse_permutation(driver.get_axes_from_transpose(transpose))))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose])
    elif transpose2 in m:
        cgf = m[cgf2]  # type: BaseOperation
        reshape = m[reshape2]  # type: BaseOperation
        transpose = m[transpose2]  # type: BaseOperation
        if not (len(reshape.output.consumers) <= 1
                and len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")
        cgf.output.name = transpose.output.name
        add_transform(transforms_by_name, cgf.output,
                      Transpose(utils.inverse_permutation(driver.get_axes_from_transpose(transpose))))
        add_transform(transforms_by_name, cgf.output, Reshape(cgf.output.shape))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose, reshape])
    else:
        raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")
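
# Usage note (inferred from transform_io above): this rewrite is reached via
# IOTransform.NNEF_FILTER_GRAD_TO_TF or one of the SMART_NCHW_TO_TF_* modes.
# On success the trailing reshape/transpose ops disappear from the graph, the
# conv-grad-filter output inherits the removed tensor's name, and
# transforms_by_name records the inverse Transpose (plus a Reshape for the
# reshaped pattern) so the original data layout can still be reconstructed.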