def _is_squeeze_invariant_to_perm(squeeze_axes, perm):
    """Return True if squeezing the given axes yields the same (symbolic) shape
    whether or not the permutation is applied first.

    Works on a dummy shape [0, 1, ..., rank-1] so the comparison is purely
    structural and independent of concrete tensor dimensions.
    """
    rank = len(perm)
    identity_shape = list(range(rank))
    permuted_shape = utils.apply_permutation(identity_shape, perm)
    permuted_axes = Transposer.apply_permutation_to_axes(squeeze_axes, perm)
    # Squeeze the un-permuted and the permuted dummy shape with the
    # correspondingly mapped axes; invariance holds iff the results match.
    squeezed_plain = squeezed_shape(identity_shape, squeeze_axes,
                                    can_squeeze_not_one=True)
    squeezed_permuted = squeezed_shape(permuted_shape, permuted_axes,
                                       can_squeeze_not_one=True)
    return squeezed_plain == squeezed_permuted
def apply_transpose_to_varlike(tensor, axes, transforms_by_name):
    # type: (BaseTensor, typing.List[int], typing.Dict[str, typing.List[Transform]])->None
    """Permute a variable/constant tensor's shape (and stored data) in place.

    Rank-0/1 tensors are left untouched. For variables the numpy payload is
    transposed; for constants the flat value list is reshaped, transposed and
    re-flattened (a single-element constant needs no data change). The applied
    Transpose is recorded in transforms_by_name.
    """
    if tensor.rank <= 1:
        return
    original_shape = tensor.shape
    tensor.shape = utils.apply_permutation(original_shape, axes)
    if tensor.is_variable and tensor.data.size > 0:
        tensor.data = np.transpose(tensor.data, axes)
    elif tensor.is_constant and len(tensor.data) > 1:
        reshaped = np.array(tensor.data).reshape(original_shape)
        tensor.data = reshaped.transpose(axes).flatten().tolist()
    add_transform(transforms_by_name, tensor, Transpose(axes))
def transform_remove_inverse_transposes(
        g,  # type: BaseGraph
        transforms_by_name,  # type: typing.Dict[str, typing.List[Transform]]
        merge_into_constants,  # type: bool
        merge_into_variables,  # type: bool
        driver,  # type: DataFormatOptimizationDriver
        transposable_ops=None,  # type: typing.Optional[typing.List[TransposableOperation]]
):
    # type: (...)-> None
    """Find transpose/inverse-transpose pairs in the graph and eliminate them.

    Matched subgraphs (found by _find_inverse_transposes) are rewritten in
    place: boundary transposes are removed or merged into variables/constants,
    ops in between are re-axed via their TransposableOperation handlers, and
    every tensor whose layout changed is recorded in transforms_by_name.
    Unreachable ops/tensors are pruned at the end.
    """
    if transposable_ops is None:
        transposable_ops = []
    transposable_op_by_name = {
    }  # type: typing.Dict[str, TransposableOperation]
    transposable_op_by_name.update({top.name: top for top in transposable_ops})

    # Pad under-specified transpose axes with trailing identity axes so that
    # every transpose op's axes list covers the full output rank.
    for op in g.operations:
        if op.name == driver.transpose_op_name and op.output.rank > len(
                driver.get_axes_from_transpose(op)):
            driver.set_axes_on_transpose(
                op,
                driver.get_axes_from_transpose(op) + list(range(
                    op.output.rank))[len(driver.get_axes_from_transpose(op)):])

    matches = _find_inverse_transposes(
        g,
        transposable_op_names=set(six.iterkeys(transposable_op_by_name)),
        merge_into_constants=merge_into_constants,
        merge_into_variables=merge_into_variables,
        driver=driver)

    for axes, subgraph in matches:
        # upper_perm: permutation applied at the top boundary of the match;
        # lower_perm is its inverse (applied to everything below/inside).
        upper_perm = axes if subgraph.started_down else utils.inverse_permutation(
            axes)
        lower_perm = utils.inverse_permutation(upper_perm)
        upper_boundary = [
            be for be in subgraph.boundary_elements if not be.from_up
        ]
        lower_boundary = [
            be for be in subgraph.boundary_elements if be.from_up
        ]

        # --- upper boundary: drop the producing transpose (or fold the
        # permutation into a variable/constant tensor). ---
        for _, tensor in upper_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    # The transpose output is also a graph output, so a fresh
                    # transpose must be re-created to keep that output intact.
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=utils.apply_permutation(
                            tensor.producer.input.shape, upper_perm),
                        dtype=tensor.producer.input.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=tensor.producer.input,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                elif (len(tensor.producer.input.consumers) == 1
                      and tensor.producer.input not in g.inputs
                      and tensor.producer.input not in g.outputs):
                    # Safe to steal the name: the transpose input has no other
                    # consumers and is not an io tensor.
                    tensor.producer.input.name = tensor.name
                    add_transform(transforms_by_name, tensor.producer.input,
                                  Transpose(lower_perm))
                remove_passthrough_ex(g, tensor.producer)
            else:
                # No transpose producer: the permutation is merged directly
                # into the variable/constant data (allowed by the flags).
                assert (merge_into_variables and tensor.is_variable) \
                    or (merge_into_constants and tensor.is_constant)
                apply_transpose_to_varlike(tensor, lower_perm,
                                           transforms_by_name)

        # --- interior: re-ax every skipped op via its registered handler. ---
        skipped_ops = set(
            tensor.producer for tensor in subgraph.skipped_tensors
        )  # type: typing.Set[BaseOperation]
        for op in skipped_ops:
            assert op.name in transposable_op_by_name
            transposable_op_by_name[op.name].dg_transpose(
                _transposer, g, op, lower_perm)
            for output in op.outputs:
                if output in g.outputs:
                    # Interior tensor doubles as graph output: export the
                    # original layout through an explicit transpose.
                    graph_output = driver.create_tensor(graph=g,
                                                        name=output.name,
                                                        shape=output.shape,
                                                        dtype=output.dtype)
                    driver.create_transpose_op(graph=g,
                                               input=output,
                                               axes=list(upper_perm),
                                               output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, output, graph_output)
                    output.name = None
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                else:
                    output.shape = utils.apply_permutation(
                        output.shape, lower_perm)
                    add_transform(transforms_by_name, output,
                                  Transpose(lower_perm))

        # --- lower boundary: remove the closing transpose, or re-map the
        # axes of a squeeze that absorbed the permutation. ---
        for _, tensor in lower_boundary:
            if tensor.producer is not None and tensor.producer.name == driver.transpose_op_name:
                if tensor in g.outputs:
                    # Keep the graph output alive with a copy before removing
                    # the (now passthrough) transpose op.
                    graph_output = driver.create_tensor(
                        graph=g,
                        name=tensor.name,
                        shape=tensor.producer.input.shape,
                        dtype=tensor.producer.input.dtype)
                    driver.create_copy_op(graph=g,
                                          input=tensor.producer.input,
                                          output=graph_output)
                    graph_utils.replace_tensor_in_outputs(
                        g, tensor, graph_output)
                remove_passthrough_ex(g, tensor.producer)
            elif tensor.producer is not None and tensor.producer.name == driver.squeeze_op_name:
                driver.set_axes_on_squeeze(
                    tensor.producer,
                    sorted(
                        Transposer.apply_permutation_to_axes(
                            driver.get_axes_from_squeeze(tensor.producer),
                            lower_perm)))
            else:
                # Matching guarantees the lower boundary is a transpose or a
                # squeeze; anything else indicates a matcher bug.
                assert False
    graph_utils.remove_unreachable(g)
def transform_io(g, io_transform, transforms_by_name, driver):
    # type:(BaseGraph, TrafoOrTrafoDictType, typing.Dict[str, typing.List[Transform]], DataFormatOptimizationDriver)->None
    """Apply layout transforms to the graph's input/output tensors.

    io_transform is either a single Transform applied to every io tensor, or a
    dict keyed by tensor (or tensor name) that must cover all io tensors.
    "Smart" sentinel transforms first try a filter-gradient conversion and
    fall back to a plain NHWC<->NCHW transpose; plain Transpose transforms are
    realized by inserting a transpose op at the boundary. Every change is
    recorded in transforms_by_name.
    """
    io_tensors_by_name = {t.name: t for t in list(g.inputs) + list(g.outputs)}
    transform_by_io_tensor = {}
    if isinstance(io_transform, dict):
        # Normalize dict keys to tensor objects and validate full coverage.
        for k, v in six.iteritems(io_transform):
            assert isinstance(k, (str, driver.tensor_type)), \
                "io_transform: Key type must be {} or str".format(driver.tensor_type.__name__)
            assert isinstance(
                v, Transform), "io_transform: Value type must be Transform"
            if isinstance(k, BaseTensor):
                assert k in six.itervalues(io_tensors_by_name)
            else:
                assert k in io_tensors_by_name
                k = io_tensors_by_name[k]
            transform_by_io_tensor[k] = v
        for io_tensor in six.itervalues(io_tensors_by_name):
            assert io_tensor in transform_by_io_tensor, \
                "io_transform: Please specify transform for all io tensors. " \
                "You can use graph_optimizer.IDENTITY if no change is required."
    else:
        # Single transform: broadcast it to all io tensors.
        assert isinstance(io_transform, Transform), \
            "io_transform must be Transform or Dict[str, Transform] or Dict[NNEFTensor, Transform]"
        for t in six.itervalues(io_tensors_by_name):
            transform_by_io_tensor[t] = io_transform

    for tensor, transform in six.iteritems(transform_by_io_tensor):
        assert bool(tensor in g.inputs) != bool(tensor in g.outputs), \
            "Tensor must be input or output (and not both)"
        assert isinstance(transform, (Transpose, _CustomTransform)), \
            "Unsupported io_transform"
        if isinstance(transform, _Identity):
            continue
        if isinstance(transform, _SmartTFNCHWToNCHW):
            # Best effort: only the filter-grad layout needs fixing here.
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue
        if isinstance(transform, _SmartNHWCToNCHW):
            if tensor.rank <= 2:
                continue
            # NHWC -> NCHW: move the last (channel) axis to position 1.
            transform = Transpose([0, tensor.rank - 1] +
                                  list(range(tensor.rank))[1:-1])
        if isinstance(transform, _SmartTFNHWCToNCHW):
            # Try filter-grad conversion first; on failure fall back to a
            # plain NHWC->NCHW transpose (skipping rank<=2 tensors).
            try:
                _transform_tf_filter_grad_to_nnef(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0, tensor.rank - 1] +
                                      list(range(tensor.rank))[1:-1])
        if isinstance(transform, _SmartNCHWToTFNCHW):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
            except _TransformException:
                pass
            continue
        if isinstance(transform, _SmartNCHWToNHWC):
            if tensor.rank <= 2:
                continue
            # NCHW -> NHWC: move the channel axis (1) to the end.
            transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])
        if isinstance(transform, _SmartNCHWToTFNHWC):
            try:
                _transform_nnef_filter_grad_to_tf(g, tensor,
                                                  transforms_by_name, driver)
                continue
            except _TransformException:
                if tensor.rank <= 2:
                    continue
                transform = Transpose([0] + list(range(tensor.rank))[2:] + [1])
        if isinstance(transform, _TFFilterGradToNNEF):
            _transform_tf_filter_grad_to_nnef(g, tensor, transforms_by_name,
                                              driver)
            continue
        if isinstance(transform, _NNEFFilterGradToTF):
            _transform_nnef_filter_grad_to_tf(g, tensor, transforms_by_name,
                                              driver)
            continue
        # At this point only a concrete Transpose can remain.
        assert isinstance(
            transform,
            Transpose), "Unsupported io_transform: {}".format(transform)
        assert len(transform.axes) == tensor.rank, "Transpose: invalid rank"
        if transform.is_identity():
            continue
        if tensor in g.inputs:
            assert tensor.name
            # New graph input carries the transformed layout; an inverse
            # transpose restores the original layout for existing consumers.
            new_input_tensor = driver.create_tensor(
                graph=g,
                name=tensor.name,
                shape=utils.apply_permutation(tensor.shape, transform.axes),
                dtype=tensor.dtype)
            add_transform(transforms_by_name, new_input_tensor, transform)
            transpose = driver.create_transpose_op(
                graph=g,
                input=new_input_tensor,
                axes=utils.inverse_permutation(transform.axes),
                output=driver.create_tensor(graph=g,
                                            name=None,
                                            shape=tensor.shape,
                                            dtype=tensor.dtype))
            graph_utils.replace_tensor_in_inputs(g, tensor, new_input_tensor)
            graph_utils.replace_tensor_in_consumers(g,
                                                    tensor,
                                                    transpose.output,
                                                    remove=True)
        else:  # output
            # Insert a transpose after the tensor; its output takes over the
            # name and the graph-output slot.
            transpose = driver.create_transpose_op(
                graph=g,
                input=tensor,
                axes=transform.axes,
                output=driver.create_tensor(graph=g,
                                            name=tensor.name,
                                            shape=utils.apply_permutation(
                                                tensor.shape, transform.axes),
                                            dtype=tensor.dtype))
            add_transform(transforms_by_name, transpose.output, transform)
            tensor.name = None
            graph_utils.replace_tensor_in_outputs(g, tensor, transpose.output)
def apply_permutation(list, perm):
    """Return *list* reordered according to the permutation *perm*.

    Thin delegation to utils.apply_permutation.
    """
    # NOTE(review): the parameter name shadows the builtin `list`; kept as-is
    # because renaming it would break callers that pass it by keyword.
    permuted = utils.apply_permutation(list, perm)
    return permuted
def _box_or_max_pool(
        input,  # type: torch.Tensor
        size,  # type: List[int]
        border='constant',  # type: str
        padding=None,  # type: Optional[List[Tuple[int, int]]],
        stride=None,  # type: Optional[List[int]],
        dilation=None,  # type: Optional[List[int]]
        normalize=False,  # type: bool
        is_max_pool=False,  # type: bool
):
    """Dispatch to box filtering or max pooling after normalizing the layout.

    Dimensions where size/padding/stride/dilation are all trivial ("inactive")
    are folded into a flat batch dimension so the underlying _box_impl /
    _max_pool_impl only ever see the active dimensions; low-rank inputs are
    padded with leading singleton dims instead. normalize and is_max_pool are
    mutually exclusive (normalize only makes sense for box filtering).
    """
    assert not (normalize and is_max_pool)

    rank = len(input.shape)

    # Fill in defaults for padding/stride/dilation from the input shape.
    padding, stride, dilation = _evaluate_max_pool_or_box_params(
        input_shape=list(input.shape),
        size=size,
        padding=padding,
        stride=stride,
        dilation=dilation)

    # A dimension is "active" if pooling actually does something along it.
    active = [
        size_ != 1 or padding_ != (0, 0) or stride_ != 1 or dilation_ != 1
        for size_, padding_, stride_, dilation_ in zip(
            size, padding, stride, dilation)
    ]

    if sum(active) == 0:
        # Pooling is a no-op in every dimension.
        return input

    if rank < 3:
        perm, perm_inv, inactive_shape, active_shape = None, None, None, None
    else:
        # Permutation that moves inactive dims to the front, plus its inverse
        # and the resulting (inactive, active) shape split.
        perm, perm_inv, inactive_shape, active_shape = _get_transform_for_box_or_max_pool(
            list(input.shape), active)

    if rank < 3:
        # Too low-rank for the impls: prepend two singleton (batch, channel)
        # dims and matching trivial pooling params.
        input = input.unsqueeze(0).unsqueeze(0)
        size = [1, 1] + size
        padding = [(0, 0), (0, 0)] + padding
        stride = [1, 1] + stride
        dilation = [1, 1] + dilation
    elif perm is not None:
        # Move inactive dims to the front, collapse them into one batch dim,
        # and keep only the active-dim pooling parameters.
        input = input.permute(*perm)
        size = utils.apply_permutation(size, perm)
        padding = utils.apply_permutation(padding, perm)
        stride = utils.apply_permutation(stride, perm)
        dilation = utils.apply_permutation(dilation, perm)
        active_rank = len(active_shape)
        input = input.reshape(*[utils.product(inactive_shape), 1] +
                              active_shape)
        size = [1, 1] + size[-active_rank:]
        padding = [(0, 0), (0, 0)] + padding[-active_rank:]
        stride = [1, 1] + stride[-active_rank:]
        dilation = [1, 1] + dilation[-active_rank:]

    if is_max_pool:
        output = _max_pool_impl(input=input,
                                size=size,
                                border=border,
                                padding=padding,
                                stride=stride,
                                dilation=dilation,
                                with_index=False)
    else:
        output = _box_impl(input=input,
                           size=size,
                           border=border,
                           padding=padding,
                           stride=stride,
                           dilation=dilation,
                           normalize=normalize)

    # Undo the layout normalization applied above.
    if rank < 3:
        output = output.squeeze(0).squeeze(0)
    elif perm is not None:
        active_rank = len(active_shape)
        output = output.reshape(inactive_shape +
                                list(output.shape)[-active_rank:])
        output = output.permute(*perm_inv)

    return output