Ejemplo n.º 1
0
def gen_data(begin, begin_mask, dtype, ellipsis_mask, end, end_mask,
             new_axis_mask, shape, shrink_axis_mask, strides):
    """Generate data for testing the op.

    Builds a random input array, the NumPy reference result of slicing it,
    and an output buffer for the kernel under test.

    Args:
        begin, end, strides: per-axis slice bounds and steps, forwarded to
            ``strided_slice.args_to_slices``.
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask:
            masks forwarded unchanged to ``strided_slice.args_to_slices``
            (presumably TensorFlow-style bit masks -- confirm there).
        dtype: element type of the input and output arrays.
        shape: shape of the random input.

    Returns:
        tuple ``(expect, input, output)``: the NumPy reference result, the
        random input drawn uniformly from [-1.0, 1.0) and cast to ``dtype``,
        and a NaN-filled buffer shaped like ``expect`` (or ``(1,)`` when
        ``expect`` has shape ``(0,)``).
    """
    data = np.random.uniform(low=-1.0, high=1.0, size=tuple(shape))
    data = data.astype(dtype)

    # Reference result: let NumPy apply the equivalent index tuple.
    index = tuple(strided_slice.args_to_slices(
        begin, end, strides, begin_mask, end_mask, ellipsis_mask,
        new_axis_mask, shrink_axis_mask))
    reference = data[index]

    # NOTE(review): a 1-D empty result gets a one-element buffer instead,
    # presumably because the kernel cannot take a zero-size output -- confirm.
    if reference.shape == (0, ):
        buffer_shape = (1, )
    else:
        buffer_shape = reference.shape

    # NaN sentinel presumably lets the test spot elements the kernel never
    # wrote; it is only meaningful for floating-point dtypes.
    result_buffer = np.full(buffer_shape, np.nan, dtype)
    return reference, data, result_buffer
Ejemplo n.º 2
0
def strided_slice_python(input_shape,
                         begin,
                         end,
                         strides,
                         begin_mask,
                         end_mask,
                         ellipsis_mask,
                         new_axis_mask,
                         shrink_axis_mask,
                         grad,
                         dtype=np.float16):
    """Compute the reference gradient of strided_slice with NumPy.

    Writes ``grad`` into a zero array of ``input_shape`` at the positions
    selected by the strided slice; every other position stays zero.

    Args:
        input_shape: shape of the forward input and of the returned gradient.
        begin, end, strides: per-axis slice bounds and steps, forwarded to
            ``strided_slice.args_to_slices``.
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask:
            masks forwarded unchanged to ``strided_slice.args_to_slices``.
        grad: incoming gradient; must be assignable (per NumPy broadcasting)
            to the sliced region of the result.
        dtype: dtype of the returned gradient (default ``np.float16``).

    Returns:
        numpy.ndarray of shape ``input_shape`` and dtype ``dtype``.
    """
    slices = strided_slice.args_to_slices(begin, end, strides, begin_mask,
                                          end_mask, ellipsis_mask,
                                          new_axis_mask, shrink_axis_mask)

    # Allocate directly in the target dtype to avoid building a float64
    # temporary and then copying it through a cast.
    dx = np.zeros(input_shape, dtype=dtype)
    dx[tuple(slices)] = grad
    return dx
Ejemplo n.º 3
0
def check_grad_shape(input_shape,
                     begin,
                     end,
                     strides,
                     begin_mask,
                     end_mask,
                     ellipsis_mask,
                     new_axis_mask,
                     shrink_axis_mask,
                     grad_shape_given,
                     dtype=np.float16):
    """Check that a gradient shape matches the strided slice's output shape.

    Args:
        input_shape: shape of the forward input being sliced.
        begin, end, strides: per-axis slice bounds and steps, forwarded to
            ``strided_slice.args_to_slices``.
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask:
            masks forwarded unchanged to ``strided_slice.args_to_slices``.
        grad_shape_given: the gradient shape to validate.
        dtype: unused; presumably kept so the signature mirrors
            ``strided_slice_python``.

    Returns:
        The index sequence produced by ``strided_slice.args_to_slices``.

    Raises:
        AssertionError: if the sliced shape differs from
            ``grad_shape_given``. Raised explicitly rather than via
            ``assert`` so the check still runs under ``python -O``.
    """
    slices = strided_slice.args_to_slices(begin, end, strides, begin_mask,
                                          end_mask, ellipsis_mask,
                                          new_axis_mask, shrink_axis_mask)

    # Only the resulting shape is needed, so slice a zero-stride broadcast
    # view instead of materialising a full array of size input_shape.
    probe = np.broadcast_to(np.zeros((), dtype=np.bool_), input_shape)
    grad_shape = list(probe[tuple(slices)].shape)
    expected_shape = list(grad_shape_given)

    if grad_shape != expected_shape:
        raise AssertionError(
            "parameters invalid: grad shape should be {} but given is {}"
            .format(grad_shape, expected_shape))
    return slices