Пример #1
0
def propagate_tile(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of a TensorFlow Tile operation.

    The second input holds the per-axis repeat counts. Its constant value is
    looked up in ``const_value_by_tensor`` and converted with ``.tolist()``.

    Returns ``([output_shape], [dtype])``; the dtype is ``op.attribs['T']``.
    """
    data, multiples_tensor = op.inputs
    repeat_counts = const_value_by_tensor[multiples_tensor].tolist()  # type: typing.List[int]

    output_shape = infer.tile(input=data.shape, repeat=repeat_counts)
    return [output_shape], [op.attribs['T']]
Пример #2
0
def tile_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shape and dtype of a Caffe2 Tile operation.

    ``tiles`` and ``axis`` default to the op attributes (1 and 0 when
    absent). They are overridden by the optional 2nd and 3rd inputs, which
    must carry constant data. The resolved values are written back into
    ``op.attribs`` and ``op.inputs`` is reduced to the data tensor alone.

    Returns the tiled shape and the dtype of the data input.
    Raises utils.NNEFToolsException if a tiles/axis input has no data.
    """
    assert len(op.inputs) in [1, 2, 3]

    def constant_scalar(tensor):
        # Only constant (pre-computed) tiles/axis inputs can be folded.
        if tensor.data is None:
            raise utils.NNEFToolsException(
                'Tile is not supported with calculated sizes.')
        return tensor.data.item()

    tiles = op.attribs.get('tiles', 1)
    axis = op.attribs.get('axis', 0)

    input_count = len(op.inputs)
    if input_count > 1:
        tiles = constant_scalar(op.inputs[1])
    if input_count > 2:
        axis = constant_scalar(op.inputs[2])

    data = op.inputs[0]
    # Repeat only the selected axis; list indexing also accepts a negative axis.
    repeats = [1] * data.rank
    repeats[axis] = tiles

    op.attribs['tiles'] = tiles
    op.attribs['axis'] = axis
    op.inputs = (data, )

    return infer.tile(data.shape, repeats), data.dtype
Пример #3
0
def propagate_tile(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX Tile operation.

    Two input layouts are accepted:
      * ``(input, tiles, axis)``: the dimension at ``axis`` is multiplied by
        the constant scalar ``tiles``.
      * ``(input, repeats)``: the shape is tiled by the constant ``repeats``
        tensor via ``infer.tile``.

    Returns ``([output_shape], [dtype])`` where dtype is the data input's.
    Raises AssertionError for any other number of inputs.
    """
    if len(op.inputs) == 3:
        data, tiles, axis = op.inputs

        tiles = evaluate_scalar_int_tensor_simple(tiles)
        axis = evaluate_scalar_int_tensor_simple(axis)

        output_shape = list(data.shape)
        # List indexing also accepts a negative axis (counted from the end).
        output_shape[axis] *= tiles

        # The dtype comes from the data tensor, as in the 2-input branch;
        # ``op.input`` is the wrong accessor for an op with three inputs.
        return [output_shape], [data.dtype]

    if len(op.inputs) == 2:
        data, repeats = op.inputs
        return [
            infer.tile(input=data.shape,
                       repeat=evaluate_shape_tensor_simple(repeats))
        ], [data.dtype]

    # Explicit raise (same exception type as ``assert False``) so the check
    # survives ``python -O`` instead of silently returning None.
    raise AssertionError('Tile must have 2 or 3 inputs')
Пример #4
0
def propagate_tile(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX Tile operation.

    The constant tiling inputs are folded into ``op.attribs['repeats']``,
    and ``op.inputs`` is reduced to the data tensor alone:
      * ``(input, tiles, axis)``: only ``axis`` is repeated ``tiles`` times.
      * ``(input, repeats)``: the per-axis repeat list is taken directly.

    Returns ``([output_shape], [dtype])`` where dtype is the data input's.
    """
    input_count = len(op.inputs)
    if input_count == 2:
        source, repeats_tensor = op.inputs

        op.inputs = (source, )
        op.attribs['repeats'] = evaluate_shape_tensor_simple(repeats_tensor)
    elif input_count == 3:
        source, tiles_tensor, axis_tensor = op.inputs

        tile_count = evaluate_scalar_int_tensor_simple(tiles_tensor)
        tile_axis = evaluate_scalar_int_tensor_simple(axis_tensor)
        # One repeat per dimension; list indexing also accepts a negative axis.
        per_axis = [1] * source.rank
        per_axis[tile_axis] = tile_count

        op.inputs = (source, )
        op.attribs['repeats'] = per_axis
    else:
        assert False, 'Tile must have 2 or 3 inputs'

    output_shape = infer.tile(input=source.shape,
                              repeat=op.attribs['repeats'])
    return [output_shape], [source.dtype]
Пример #5
0
 def test_tile(self):
     """infer.tile multiplies each input dimension by its repeat count."""
     input_shape = [1, 2, 3, 4]
     repeat_counts = [4, 3, 2, 1]
     expected_shape = [4, 6, 6, 4]
     self.assertEqual(expected_shape,
                      infer.tile(input=input_shape, repeat=repeat_counts))
Пример #6
0
def tile_shape(op):
    """Infer the output shape of a tile op that repeats a single axis.

    Reads ``op.attrib['axis']`` and ``op.attrib['tiles']``. The dimension at
    ``axis`` is repeated ``tiles`` times and every other dimension once.
    A negative ``axis`` counts from the end of the shape.
    """
    input_shape = op.inputs[0].shape
    # The rank is the number of dimensions, not the shape list itself
    # (``range`` over a list would raise TypeError).
    rank = len(input_shape)
    axis = op.attrib['axis']
    if axis < 0:
        axis += rank
    tiles = op.attrib['tiles']
    return shapes.tile(
        input_shape,
        repeat=[tiles if i == axis else 1 for i in range(rank)])