Пример #1
0
    def test_strided_slice(self):
        """Check infer.strided_slice against hand-computed output shapes.

        Covers a new_axis/shrink_axis combination, full-range slices (begin and
        end masks set on every axis) with positive and negative strides, and an
        ellipsis given both as a per-axis list mask and as an integer bit mask.
        """
        # Filler for begin/end/stride entries whose value the masks make
        # irrelevant (new-axis and ellipsis positions, or masked begin/end).
        anything = 0xcafecafe

        # [1:3, newaxis, 2:5, 2 (shrunk away), 0:3] on [10, 32, 32, 3] -> [2, 1, 3, 3]
        self.assertEqual([2, 1, 3, 3], infer.strided_slice(input=[10, 32, 32, 3],
                                                           begin=[1, anything, 2, 2, 0],
                                                           end=[3, anything, 5, 3, 3],
                                                           stride=[1, anything, 1, 1, 1],
                                                           new_axis_mask=[0, 1, 0, 0, 0],
                                                           shrink_axis_mask=[0, 0, 0, 1, 0],
                                                           begin_mask=[0, 0, 0, 0, 0],
                                                           end_mask=[0, 0, 0, 0, 0],
                                                           ellipsis_mask=[0, 0, 0, 0, 0]))

        # [::1, ::-1, ::3, ::-2] on [10, 32, 32, 3]. With begin/end masked, each
        # output length is ceil(dim / |stride|): 32 with step 3 keeps indices
        # 0, 3, ..., 30 (11 items) and 3 with step -2 keeps indices 2, 0 (2 items),
        # as in TensorFlow / NumPy slicing.
        self.assertEqual([10, 32, 11, 2], infer.strided_slice(input=[10, 32, 32, 3],
                                                              begin=[anything, anything, anything, anything],
                                                              end=[anything, anything, anything, anything],
                                                              stride=[1, -1, 3, -2],
                                                              new_axis_mask=[0, 0, 0, 0],
                                                              shrink_axis_mask=[0, 0, 0, 0],
                                                              begin_mask=[1, 1, 1, 1],
                                                              end_mask=[1, 1, 1, 1],
                                                              ellipsis_mask=[0, 0, 0, 0]))

        # [0:1, ..., 0:1] on [10, 32, 32, 3]: the ellipsis spans the two middle axes.
        self.assertEqual([1, 32, 32, 1], infer.strided_slice(input=[10, 32, 32, 3],
                                                             begin=[0, anything, 0],
                                                             end=[1, anything, 1],
                                                             stride=[1, anything, 1],
                                                             new_axis_mask=[0, 0, 0],
                                                             shrink_axis_mask=[0, 0, 0],
                                                             begin_mask=[0, 0, 0],
                                                             end_mask=[0, 0, 0],
                                                             ellipsis_mask=[0, 1, 0]))

        # Same slice as above, with the masks given as integer bit masks
        # (ellipsis_mask=2 sets bit 1, i.e. the second spec entry).
        self.assertEqual([1, 32, 32, 1], infer.strided_slice(input=[10, 32, 32, 3],
                                                             begin=[0, anything, 0],
                                                             end=[1, anything, 1],
                                                             stride=[1, anything, 1],
                                                             new_axis_mask=0,
                                                             shrink_axis_mask=0,
                                                             begin_mask=0,
                                                             end_mask=0,
                                                             ellipsis_mask=2))
Пример #2
0
def propagate_strided_slice(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate the output shape and dtype through a StridedSlice operation.

    The begin, end and strides operands are looked up in
    ``const_value_by_tensor`` and converted with ``.tolist()`` (presumably the
    stored values are numpy arrays -- confirm against the caller). The
    ellipsis/new_axis/shrink_axis/begin/end masks come from ``op.attribs``,
    and the shape computation is delegated to ``infer.strided_slice``.

    Returns:
        A pair ``([output_shape], [op.attribs['T']])``: one output shape and
        its dtype attribute.
    """
    data_tensor, begin_tensor, end_tensor, strides_tensor = op.inputs

    # Only the slice bounds/strides need constant values; the data tensor
    # contributes just its shape.
    begin_values, end_values, stride_values = [
        const_value_by_tensor[operand].tolist()
        for operand in (begin_tensor, end_tensor, strides_tensor)
    ]
    data_shape = data_tensor.shape

    mask_names = ('ellipsis_mask', 'new_axis_mask', 'shrink_axis_mask', 'begin_mask', 'end_mask')
    mask_kwargs = {mask_name: op.attribs[mask_name] for mask_name in mask_names}

    output_shape = infer.strided_slice(input=data_shape,
                                       begin=begin_values,
                                       end=end_values,
                                       stride=stride_values,
                                       **mask_kwargs)
    output_dtype = op.attribs['T']
    return [output_shape], [output_dtype]