コード例 #1
0
def verify_expand_like(in_shape, out_shape, axis, exclude):
    """Check ``sym.expand_like`` against a NumPy reference using ``helper``.

    Args:
        in_shape: shape fed for the input ``x``.
        out_shape: shape fed for ``y``; the reference forward tiles ``x`` up
            to these sizes along the expanded axes.
        axis: output axes (negative values count from the end of
            ``out_shape``) along which ``x`` is expanded.
        exclude: if true, ``x`` is expanded along every output axis that is
            *not* listed in ``axis``.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.expand_like(x, y, axis=axis, exclude=exclude)

    # Resolve the expanded output axes once; forward and backward share them.
    odim = len(out_shape)
    real_axis = sorted(i if i >= 0 else i + odim for i in axis)
    if exclude:
        # Set iteration order is not guaranteed, but the expand_dims loop
        # below needs ascending axes, so sort the complement explicitly.
        real_axis = sorted(set(range(odim)) - set(real_axis))

    def forward(x, y):
        # Insert a size-1 dim at each expanded axis; doing it in ascending
        # order makes every inserted axis land at its final output index.
        for i in real_axis:
            x = np.expand_dims(x, i)
        # Then repeat each of those size-1 axes up to the output size.
        for i in real_axis:
            x = np.repeat(x, out_shape[i], axis=i)
        return x

    def backward(head_grads, x, y):
        # x's gradient sums head_grads over the expanded axes; y's gradient
        # is taken to be zero.
        return [np.sum(head_grads, axis=tuple(real_axis)),
                np.zeros_like(y)]

    dtype = "float32"
    inputs = [('x', in_shape, x),
              ('y', out_shape, y)]
    helper(z, inputs, dtype, forward, backward, need_input=False)
コード例 #2
0
ファイル: test_top_level4.py プロジェクト: zhangquan920/tasn
def verify_expand_like(in_shape, out_shape, axis, exclude):
    """Check ``sym.expand_like`` against a NumPy reference via ``check_function``.

    Args:
        in_shape: shape fed for the input ``x``.
        out_shape: shape fed for ``y``; the reference output has this shape.
        axis: output axes (negative values count from the end of
            ``out_shape``) along which ``x`` is expanded.
        exclude: if true, ``x`` is expanded along every output axis that is
            *not* listed in ``axis``.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.expand_like(x, y, axis=axis, exclude=exclude)

    # Resolve the expanded output axes once; forward and backward share them.
    odim = len(out_shape)
    real_axis = sorted(i if i >= 0 else i + odim for i in axis)
    if exclude:
        # Set iteration order is not guaranteed, but the expand_dims loop
        # below needs ascending axes, so sort the complement explicitly.
        real_axis = sorted(set(range(odim)) - set(real_axis))

    def forward(x, y):
        # Same-rank inputs: plain NumPy broadcasting of x to y's shape.
        if len(x.shape) == len(y.shape):
            return np.broadcast_to(x, y.shape)

        # A (1,)-shaped x is treated as a scalar so that the loop below can
        # insert every output axis around it.
        if x.shape == (1, ) and len(y.shape) == odim:
            x = np.reshape(x, ())

        # Insert size-1 dims in ascending order (each lands at its final
        # index), then repeat each one up to the output size.
        for i in real_axis:
            x = np.expand_dims(x, i)
        for i in real_axis:
            x = np.repeat(x, out_shape[i], axis=i)
        return x

    def backward(head_grads, x, y):
        # Same-rank inputs were broadcast in place, so keep reduced axes.
        keepdims = len(x.shape) == len(y.shape)
        grad_x = np.sum(head_grads, axis=tuple(real_axis), keepdims=keepdims)
        # Return the gradient in x's own shape: e.g. for a (1,)-shaped x whose
        # every output axis was reduced, the sum above is 0-d.
        return [np.reshape(grad_x, x.shape), np.zeros_like(y)]

    shape = {'x': in_shape, 'y': out_shape}
    check_function(z, forward, backward, shape=shape)
コード例 #3
0
ファイル: test_top_level4.py プロジェクト: bddppq/tvm
def verify_expand_like(in_shape, out_shape, axis, exclude):
    """Check ``sym.expand_like`` against a NumPy reference via ``check_function``.

    Args:
        in_shape: shape fed for the input ``x``.
        out_shape: shape fed for ``y``; the reference output has this shape.
        axis: output axes (negative values count from the end of
            ``out_shape``) along which ``x`` is expanded.
        exclude: if true, ``x`` is expanded along every output axis that is
            *not* listed in ``axis``.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.expand_like(x, y, axis=axis, exclude=exclude)

    # Resolve the expanded output axes once; forward and backward share them.
    odim = len(out_shape)
    real_axis = sorted(i if i >= 0 else i + odim for i in axis)
    if exclude:
        # Set iteration order is not guaranteed, but the expand_dims loop
        # below needs ascending axes, so sort the complement explicitly.
        real_axis = sorted(set(range(odim)) - set(real_axis))

    def forward(x, y):
        # Same-rank inputs: plain NumPy broadcasting of x to y's shape.
        if len(x.shape) == len(y.shape):
            return np.broadcast_to(x, y.shape)

        # A (1,)-shaped x is treated as a scalar so that the loop below can
        # insert every output axis around it.
        if x.shape == (1,) and len(y.shape) == odim:
            x = np.reshape(x, ())

        # Insert size-1 dims in ascending order (each lands at its final
        # index), then repeat each one up to the output size.
        for i in real_axis:
            x = np.expand_dims(x, i)
        for i in real_axis:
            x = np.repeat(x, out_shape[i], axis=i)
        return x

    def backward(head_grads, x, y):
        # Same-rank inputs were broadcast in place, so keep reduced axes.
        keepdims = len(x.shape) == len(y.shape)
        grad_x = np.sum(head_grads, axis=tuple(real_axis), keepdims=keepdims)
        # Return the gradient in x's own shape: e.g. for a (1,)-shaped x whose
        # every output axis was reduced, the sum above is 0-d.
        return [np.reshape(grad_x, x.shape), np.zeros_like(y)]

    shape = {'x': in_shape, 'y': out_shape}
    check_function(z, forward, backward, shape=shape)