def propagate_tile(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of a TensorFlow Tile operation.

    :param op: the Tile operation; exactly two inputs are expected,
        the data tensor and the multiples tensor.
    :param const_value_by_tensor: lookup from tensor to its constant value;
        it must contain the multiples tensor. The value is converted with
        ``.tolist()`` (presumably a numpy array — confirm against callers).
    :return: ``([output_shape], [output_dtype])`` where the shape comes from
        ``infer.tile`` and the dtype is read from ``op.attribs['T']``.
    """
    data, multiples_tensor = op.inputs
    repeat = const_value_by_tensor[multiples_tensor].tolist()  # type: typing.List[int]
    output_shape = infer.tile(input=data.shape, repeat=repeat)
    output_dtype = op.attribs['T']
    return [output_shape], [output_dtype]
def tile_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shape/dtype of a Caffe2 Tile op and normalize its inputs.

    ``tiles`` and ``axis`` default to the op's attribs (1 and 0 when absent).
    They are overridden by the optional second and third inputs, which must
    carry constant data. The resolved values are written back into
    ``op.attribs`` and the op is reduced to its single data input.

    :raises utils.NNEFToolsException: if an optional input has no constant data.
    :return: ``(output_shape, dtype)`` of the tiled data tensor.
    """
    assert len(op.inputs) in [1, 2, 3]

    def _constant_item(tensor):
        # Only compile-time constant sizes can be folded into attribs.
        if tensor.data is None:
            raise utils.NNEFToolsException(
                'Tile is not supported with calculated sizes.')
        return tensor.data.item()

    tiles = op.attribs.get('tiles', 1)
    axis = op.attribs.get('axis', 0)
    input_count = len(op.inputs)
    if input_count >= 2:
        tiles = _constant_item(op.inputs[1])
    if input_count >= 3:
        axis = _constant_item(op.inputs[2])

    data = op.inputs[0]
    # One repeat count per dimension: `tiles` on the chosen axis, 1 elsewhere.
    repeats = [1] * data.rank
    repeats[axis] = tiles

    op.attribs['tiles'] = tiles
    op.attribs['axis'] = axis
    op.inputs = (data, )
    return infer.tile(data.shape, repeats), data.dtype
def propagate_tile(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX Tile operation.

    Two input layouts are accepted:

    * ``(data, tiles, axis)`` — scalar ``tiles`` applied to the single
      dimension ``axis``;
    * ``(data, repeats)`` — one repeat count per dimension.

    :return: ``([output_shape], [output_dtype])``; the dtype is the data
        tensor's dtype in both layouts.
    :raises AssertionError: if the op has neither 2 nor 3 inputs.
    """
    if len(op.inputs) == 3:
        data, tiles, axis = op.inputs
        tiles = evaluate_scalar_int_tensor_simple(tiles)
        axis = evaluate_scalar_int_tensor_simple(axis)
        output_shape = list(data.shape)
        output_shape[axis] *= tiles
        # The dtype must come from the data tensor, as in the 2-input branch.
        # `op.input` cannot denote it when the op has three inputs.
        return [output_shape], [data.dtype]
    elif len(op.inputs) == 2:
        data, repeats = op.inputs
        return [
            infer.tile(input=data.shape, repeat=evaluate_shape_tensor_simple(repeats))
        ], [data.dtype]
    else:
        # Explicit raise (not `assert False`) so the check survives `python -O`.
        # Otherwise the function would silently return None.
        raise AssertionError('Tile must have 2 or 3 inputs')
def propagate_tile(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape/dtype of an ONNX Tile op and fold its size inputs into attribs.

    Two input layouts are accepted:

    * ``(data, tiles, axis)`` — converted to a per-dimension ``repeats`` list
      with ``tiles`` on ``axis`` and 1 elsewhere;
    * ``(data, repeats)`` — the repeats tensor is evaluated directly.

    In both cases the op is reduced to its single data input and the repeat
    counts are stored in ``op.attribs['repeats']``.

    :return: ``([output_shape], [output_dtype])``.
    :raises AssertionError: if the op has neither 2 nor 3 inputs.
    """
    if len(op.inputs) == 3:
        data, tiles, axis = op.inputs
        tiles = evaluate_scalar_int_tensor_simple(tiles)
        axis = evaluate_scalar_int_tensor_simple(axis)
        repeats = [1] * data.rank
        repeats[axis] = tiles
        op.inputs = (data, )
        op.attribs['repeats'] = repeats
    elif len(op.inputs) == 2:
        data, repeats = op.inputs
        op.inputs = (data, )
        op.attribs['repeats'] = evaluate_shape_tensor_simple(repeats)
    else:
        # Explicit raise (not `assert False`) so the check survives `python -O`.
        # Otherwise execution would reach the return below with `data` unbound.
        raise AssertionError('Tile must have 2 or 3 inputs')
    return [infer.tile(input=data.shape, repeat=op.attribs['repeats'])], [data.dtype]
def test_tile(self):
    """``infer.tile`` multiplies each input dimension by its repeat count."""
    input_shape = [1, 2, 3, 4]
    repeats = [4, 3, 2, 1]
    expected = [4, 6, 6, 4]
    self.assertEqual(expected, infer.tile(input=input_shape, repeat=repeats))
def tile_shape(op):
    """Compute the output shape of a Tile op that repeats a single axis.

    Reads ``op.attrib['axis']`` and ``op.attrib['tiles']``. It builds a
    per-dimension repeat list (``tiles`` on ``axis``, 1 elsewhere) and
    delegates to ``shapes.tile``. A negative ``axis`` counts from the last
    dimension.

    :return: whatever ``shapes.tile`` returns for the input shape and repeats.
    """
    input_shape = op.inputs[0].shape
    # The previous code bound the shape list itself to `rank`, so
    # `range(rank)` raised TypeError on every call.
    rank = len(input_shape)
    axis = op.attrib['axis']
    if axis < 0:
        # Normalize so a negative axis matches one of the indices below.
        # Otherwise the tiling would be silently dropped.
        axis += rank
    tiles = op.attrib['tiles']
    return shapes.tile(
        input_shape,
        repeat=[tiles if i == axis else 1 for i in range(rank)])