def is_boundary(tensor2, from_up, started_down):
    """Decide whether `tensor2` is a boundary of the matched transpose subgraph.

    A tensor already visited is never a boundary.  Approached from above
    (`from_up`), the boundary is a producer transpose whose axes equal the
    expected permutation, or a squeeze that is invariant to it.  Approached
    from below, it is a producer transpose with the inverse permutation, or
    a constant/variable (when merging into those is enabled) whose rank is
    0, 1 or full.

    NOTE(review): reads `visited_tensors`, `axes`, `driver`,
    `merge_into_constants` and `merge_into_variables` from the enclosing
    scope.
    """
    if tensor2 in visited_tensors:
        return False

    # When the search started downward, `axes` belongs to the other side,
    # so the expected permutation here is its inverse.
    perm = utils.inverse_permutation(axes) if started_down else axes
    producer = tensor2.producer

    if from_up:
        if producer is not None:
            if (producer.name == driver.transpose_op_name
                    and driver.get_axes_from_transpose(producer) == perm):
                return True
            if (producer.name == driver.squeeze_op_name
                    and _is_squeeze_invariant_to_perm(
                        driver.get_axes_from_squeeze(producer), perm)):
                return True
        return False

    matches_inverse_transpose = (
        producer is not None
        and producer.name == driver.transpose_op_name
        and driver.get_axes_from_transpose(producer)
        == utils.inverse_permutation(perm))
    mergeable_varlike = (
        ((merge_into_constants and tensor2.is_constant)
         or (merge_into_variables and tensor2.is_variable))
        and len(tensor2.shape) in [0, 1, len(axes)])
    return matches_inverse_transpose or mergeable_varlike
def evaluate_shape_of_transpose_grad(op):
    # type: (TFOperation)->None
    """Compute the output shape of a transpose-grad op from its grad input.

    `op.inputs[2]` is the incoming gradient; the output shape is that shape
    permuted by the inverse of ``orig_perm`` (or reversed when ``orig_perm``
    is None, matching tf.transpose's default of reversing all axes).
    """
    shape_of_grad = op.inputs[2].shape
    orig_perm = op.attribs["orig_perm"]
    if orig_perm is None:
        permuted = reversed(shape_of_grad)
    else:
        permuted = _np_permute(shape_of_grad, utils.inverse_permutation(orig_perm))
    op.output.shape = _unify_shape(permuted)
def transform_transpose_grad(g, op):
    # type: (TFGraph, TFOperation)->None
    """Replace a transpose-grad op by a tf.transpose of the incoming gradient.

    The gradient of a transpose is a transpose by the inverse permutation,
    so a plain "tf.transpose" with ``inverse_permutation(orig_perm)`` over
    the grad input produces the same outputs; the original op is then
    removed from the graph.
    """
    incoming_grad = op.inputs[2]
    inverse_axes = utils.inverse_permutation(op.attribs["orig_perm"])
    TFOperation(graph=g,
                name="tf.transpose",
                inputs=incoming_grad,
                attribs=dict(perm=inverse_axes),
                outputs=op.outputs)
    g.remove_operation(op, unlink=True)
def _get_transform_for_box_or_max_pool(input_shape, active):
    # type: (List[int], List[bool])->Any
    """Compute the permutation needed to run a sliding-window op directly.

    Returns ``(None, None, None, None)`` when the op is directly supported
    (rank 3..5 with the first two dimensions inactive); otherwise returns
    ``(perm, perm_inv, inactive_shape, active_shape)`` where `perm` moves
    all inactive dimensions in front of the active ones.

    Raises utils.NNEFToolsException for more than 3 active dimensions.
    """
    rank = len(input_shape)
    assert rank >= 3
    assert rank == len(active)

    num_active = sum(active)
    if num_active > 3:
        raise utils.NNEFToolsException(
            "Sliding window operations are not supported if they have more than 3 'active' dimensions. "
            "We have: {}".format(num_active))

    if 3 <= rank <= 5 and not active[0] and not active[1]:
        # Directly supported layout; no transform required.
        return None, None, None, None

    inactive_dims = [axis for axis, is_active in enumerate(active) if not is_active]
    active_dims = [axis for axis, is_active in enumerate(active) if is_active]
    # Both dim lists are ascending, so indexing preserves the original order.
    inactive_shape = [input_shape[axis] for axis in inactive_dims]
    active_shape = [input_shape[axis] for axis in active_dims]
    perm = inactive_dims + active_dims
    return perm, utils.inverse_permutation(perm), inactive_shape, active_shape
def transform_remove_inverse_transposes(
        g,  # type: BaseGraph
        transforms_by_name,  # type: typing.Dict[str, typing.List[Transform]]
        merge_into_constants,  # type: bool
        merge_into_variables,  # type: bool
        driver,  # type: DataFormatOptimizationDriver
        transposable_ops=None,  # type: typing.Optional[typing.List[TransposableOperation]]
):
    # type: (...)-> None
    """Remove pairs of mutually inverse transposes from the graph in place.

    Matches subgraphs bounded above and below by transposes that cancel each
    other (found by `_find_inverse_transposes`), then deletes the boundary
    transposes, permutes the shapes of the skipped ops in between, and
    records every applied permutation in `transforms_by_name`.  Graph inputs
    and outputs are preserved by inserting fresh transpose/copy ops where a
    removed tensor was an output.
    """
    if transposable_ops is None:
        transposable_ops = []
    transposable_op_by_name = {}  # type: typing.Dict[str, TransposableOperation]
    transposable_op_by_name.update({top.name: top for top in transposable_ops})

    # Pad under-specified transpose axes with the trailing identity axes so
    # every transpose's axes list covers the full output rank.
    for op in g.operations:
        if op.name == driver.transpose_op_name and op.output.rank > len(
                driver.get_axes_from_transpose(op)):
            driver.set_axes_on_transpose(
                op,
                driver.get_axes_from_transpose(op) + list(range(
                    op.output.rank))[len(driver.get_axes_from_transpose(op)):])

    matches = _find_inverse_transposes(
        g,
        transposable_op_names=set(six.iterkeys(transposable_op_by_name)),
        merge_into_constants=merge_into_constants,
        merge_into_variables=merge_into_variables,
        driver=driver)

    for axes, subgraph in matches:
        # `axes` describes the side the search started from; normalize to the
        # permutation of the upper boundary and its inverse for the lower one.
        upper_perm = axes if subgraph.started_down else utils.inverse_permutation(
            axes)
        lower_perm = utils.inverse_permutation(upper_perm)
        # NOTE(review): boundary elements reached "from up" end up on the
        # lower boundary and vice versa — the flag records approach direction.
        upper_boundary = [
            be for be in subgraph.boundary_elements if not be.from_up
        ]
        lower_boundary = [
            be for be in subgraph.boundary_elements if be.from_up
        ]

        for _, tensor in upper_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    # The transpose's output is a graph output: re-create it
                    # from the transpose's input so the output keeps its name,
                    # shape and dtype after the original op is removed.
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=utils.apply_permutation(
                            tensor.producer.input.shape, upper_perm),
                        dtype=tensor.producer.input.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=tensor.producer.input,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                elif (len(tensor.producer.input.consumers) == 1
                      and tensor.producer.input not in g.inputs
                      and tensor.producer.input not in g.outputs):
                    # Safe to move the name up onto the transpose's input.
                    tensor.producer.input.name = tensor.name
                add_transform(transforms_by_name, tensor.producer.input,
                              Transpose(lower_perm))
                remove_passthrough_ex(g, tensor.producer)
            else:
                # Otherwise the boundary must be a mergeable constant/variable.
                assert (merge_into_variables and tensor.is_variable) \
                    or (merge_into_constants and tensor.is_constant)
                apply_transpose_to_varlike(tensor, lower_perm,
                                           transforms_by_name)

        # Ops strictly between the boundaries: apply the permutation to them.
        skipped_ops = set(
            tensor.producer for tensor in subgraph.skipped_tensors
        )  # type: typing.Set[BaseOperation]
        for op in skipped_ops:
            assert op.name in transposable_op_by_name
            transposable_op_by_name[op.name].dg_transpose(
                _transposer, g, op, lower_perm)
            for output in op.outputs:
                if output in g.outputs:
                    # Keep the observable graph output unpermuted by adding a
                    # compensating transpose, then permute the internal tensor.
                    graph_output = driver.create_tensor(graph=g,
                                                        name=output.name,
                                                        shape=output.shape,
                                                        dtype=output.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=output,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, output, graph_output)
                    output.name = None
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                else:
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                    add_transform(transforms_by_name, output,
                                  Transpose(lower_perm))

        for _, tensor in lower_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    # Preserve the graph output via a copy before removing the
                    # now-redundant lower transpose.
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=tensor.producer.input.shape,
                        dtype=tensor.producer.input.dtype)
                    driver.create_copy_op(graph=g,
                                          input=tensor.producer.input,
                                          output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                remove_passthrough_ex(g, tensor.producer)
            elif tensor.producer is not None and tensor.producer.name == driver.squeeze_op_name:
                # A squeeze at the lower boundary: remap its axes through the
                # permutation instead of removing it.
                driver.set_axes_on_squeeze(
                    tensor.producer,
                    sorted(
                        Transposer.apply_permutation_to_axes(
                            driver.get_axes_from_squeeze(tensor.producer),
                            lower_perm)))
            else:
                assert False

    graph_utils.remove_unreachable(g)
def transform_io(g, io_transform, transforms_by_name, driver):
    # type:(BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Apply the requested layout transform to every graph input and output.

    `io_transform` is either a single Transform applied to all io tensors, or
    a dict keyed by tensor (or tensor name) that must cover all io tensors.
    "Smart" transforms first try the filter-grad pattern rewrite and fall
    back to (or skip) a plain Transpose depending on the tensor's rank;
    every applied change is recorded in `transforms_by_name`.
    """
    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}
    transform_by_io_tensor = {}

    if isinstance(io_transform, dict):
        for k, v in six.iteritems(io_transform):
            assert isinstance(k, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(
                v, Transform), "io_transform: Value type must be Transform"
            if isinstance(k, BaseTensor):
                assert k in six.itervalues(io_tensors_by_name)
            else:
                # Name key: resolve it to the actual tensor object.
                assert k in io_tensors_by_name
                k = io_tensors_by_name[k]
            transform_by_io_tensor[k] = v
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        # A single transform: apply it uniformly to all io tensors.
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        for t in six.itervalues(io_tensors_by_name):
            transform_by_io_tensor[t] = io_transform

    for tensor, transform in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"
        assert isinstance(transform, (Transpose, _CustomTransform)), \
            "Unsupported io_transform"

        if isinstance(transform, _Identity):
            continue

        if isinstance(transform, _SmartTFNCHWToNCHW):
            # Best effort: only the filter-grad pattern needs rewriting here.
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNHWCToNCHW):
            if tensor.rank <= 2:
                continue
            # N,H,...,C -> N,C,H,...
            transform = Transpose([0, tensor.rank - 1] +
                                  list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartTFNHWCToNCHW):
            # Prefer the filter-grad rewrite; fall back to a plain transpose.
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0, tensor.rank - 1] +
                                      list(range(tensor.rank))[1:-1])

        if isinstance(transform, _SmartNCHWToTFNCHW):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue

        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                continue
            # N,C,H,... -> N,H,...,C
            transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])

        if isinstance(transform, _TFFilterGradToNNEF):
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name,
                                              driver)
            continue

        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name,
                                              driver)
            continue

        # From here on only a concrete Transpose remains.
        assert isinstance(
            transform,
            Transpose), "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"
        if transform.is_identity():
            continue

        if tensor in g.inputs:
            assert tensor.name
            # New permuted input keeps the public name; the inverse transpose
            # restores the old layout for existing consumers.
            new_input_tensor = driver.create_tensor(
                graph=g,
                name=tensor.name,
                shape=utils.apply_permutation(tensor.shape, transform.axes),
                dtype=tensor.dtype)
            add_transform(transforms_by_name, new_input_tensor, transform)
            transpose = driver.create_transpose_op(
                graph=g,
                input=new_input_tensor,
                axes=utils.inverse_permutation(transform.axes),
                output=driver.create_tensor(graph=g,
                                            name=None,
                                            shape=tensor.shape,
                                            dtype=tensor.dtype))
            graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
            graph_utils.replace_tensor_in_consumers(g,
                                                    tensor,
                                                    transpose.output,
                                                    remove=True)
        else:  # output
            # Append a transpose whose output takes over the public name.
            transpose = driver.create_transpose_op(
                graph=g,
                input=tensor,
                axes=transform.axes,
                output=driver.create_tensor(graph=g,
                                            name=tensor.name,
                                            shape=utils.apply_permutation(
                                                tensor.shape, transform.axes),
                                            dtype=tensor.dtype))
            add_transform(transforms_by_name, transpose.output, transform)
            tensor.name = None
            graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)
def _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name, driver):
    # type: (BaseGraph, BaseTensor, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Fold a conv-grad-filter -> [reshape ->] transpose chain onto the cgf output.

    Matches either ``conv_grad_filter -> transpose`` or
    ``conv_grad_filter -> reshape -> transpose`` ending at `tensor`, removes
    the trailing ops, moves the final tensor's name onto the cgf output, and
    records the undone Transpose (and Reshape) in `transforms_by_name`.

    Raises:
        _TransformException: if `tensor` has no producer, neither pattern
            matches, or an intermediate tensor has extra consumers / is a
            graph output, so the chain cannot be removed safely.
    """
    assert driver.conv_grad_filter_op_names

    # Pattern 1: conv_grad_filter -> transpose
    cgf1_output = matcher.Tensor()
    cgf1 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf1_output)
    transpose1 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=cgf1_output)

    # Pattern 2: conv_grad_filter -> reshape -> transpose
    cgf2_output = matcher.Tensor()
    cgf2 = matcher.Operation(name=driver.conv_grad_filter_op_names,
                             outputs=cgf2_output)
    reshape2_output = matcher.Tensor()
    reshape2 = matcher.Operation(name=driver.reshape_op_name,
                                 inputs=cgf2_output,
                                 outputs=reshape2_output)
    transpose2 = matcher.Operation(name=driver.transpose_op_name,
                                   inputs=reshape2_output)

    # BUG FIX: the exception messages previously said
    # "Cannot apply TF_FILTER_GRAD_TO_NNEF" — copy-pasted from the inverse
    # transformation (_transform_tf_filter_grad_to_nnef). This function
    # implements the NNEF -> TF direction.
    if tensor.producer is None:
        raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")

    m = matcher.match(g, tensor.producer,
                      matcher.OrPattern(transpose1, transpose2))

    if transpose1 in m:
        cgf = m[cgf1]  # type: BaseOperation
        transpose = m[transpose1]  # type: BaseOperation
        # Only safe when nothing else observes the intermediate tensors.
        if not (len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")
        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name, cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose])
    elif transpose2 in m:
        cgf = m[cgf2]  # type: BaseOperation
        reshape = m[reshape2]  # type: BaseOperation
        transpose = m[transpose2]  # type: BaseOperation
        if not (len(reshape.output.consumers) <= 1
                and len(transpose.output.consumers) <= 1
                and cgf.output not in g.outputs):
            raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")
        cgf.output.name = transpose.output.name
        add_transform(
            transforms_by_name, cgf.output,
            Transpose(
                utils.inverse_permutation(
                    driver.get_axes_from_transpose(transpose))))
        add_transform(transforms_by_name, cgf.output,
                      Reshape(cgf.output.shape))
        graph_utils.replace_tensor_in_outputs(g, transpose.output, cgf.output)
        graph_utils.remove_subgraph(g, [transpose, reshape])
    else:
        raise _TransformException("Cannot apply NNEF_FILTER_GRAD_TO_TF")
def inverse_permutation(perm):
    """Return the inverse of permutation `perm` (delegates to utils)."""
    return utils.inverse_permutation(perm)
def expand_softmax(tf_graph, tf_op):
    """Rewrite a softmax over an arbitrary axis into a rank-2, axis=-1 softmax.

    Normalizes `axis`, then — when the axis is not the last one — wraps the op
    in a transpose/inverse-transpose pair, and — when the rank is not 2 —
    additionally wraps it in a reshape to [-1, channels] and back.  The op's
    original output tensor is preserved as the output of the final wrapper op.
    """
    assert tf_op.input.rank != 0
    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1  # TF default: softmax over the last dimension
    if axis < 0:
        axis += tf_op.input.rank
    # After this function the op always operates on the last axis.
    tf_op.attribs['axis'] = -1
    if tf_op.input.rank == 2 and axis == 1:
        # Already in the canonical [batch, channels] / last-axis form.
        return
    if axis != tf_op.input.rank - 1:
        # Move the softmax axis to the end, then transpose back afterwards.
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(
                                                     input=tf_op.input.shape,
                                                     axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        # Fresh intermediate output for the softmax; the original output
        # tensor becomes the result of the inverse transpose below.
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=tf_op.input.shape,
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)
    if tf_op.input.rank != 2:
        # Flatten to [-1, channels]; softmax is applied row-wise either way.
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(
                                                   input=tf_op.input.shape,
                                                   shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        # Restore the original shape onto the original output tensor.
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=old_output.shape),
                    outputs=old_output)