def test_prelu_nhwc():
    """Check prelu with axis=3 against a NumPy reference on a (1, 32, 32, 3) input.

    The slope input ``a`` has shape (3,) and is applied along the last axis.
    """
    data = sym.Variable("x")
    slope = sym.Variable("a")
    out = sym.prelu(data=data, alpha=slope, axis=3)

    def forward(x, a):
        # Reshaping the slope to (1, 1, 3) lets it broadcast against the
        # trailing axis of the (1, 32, 32, 3) input.
        negative_part = (x < 0) * (x * a.reshape(1, 1, 3))
        positive_part = (x >= 0) * x
        return negative_part + positive_part

    check_function(out, forward, shape={'x': (1, 32, 32, 3), 'a': (3,)})
def test_prelu_nhwc():
    """Compare prelu (alpha along axis 3) with a NumPy reference implementation.

    Input ``x`` is (1, 32, 32, 3); the per-channel slope ``a`` is (3,).
    """
    out = sym.prelu(data=sym.Variable("x"), alpha=sym.Variable("a"), axis=3)

    def forward(x, a):
        # Negative entries are scaled by the slope, which is broadcast over
        # the last axis; non-negative entries pass through unchanged.
        scaled = x * a.reshape(1, 1, 3)
        return (x < 0) * scaled + (x >= 0) * x

    input_shapes = {'x': (1, 32, 32, 3), 'a': (3,)}
    check_function(out, forward, shape=input_shapes)
def test_prelu_nchw():
    """Check prelu without an explicit axis against a NumPy reference.

    No ``axis`` argument is passed to prelu. The reference broadcasts the (3,)
    slope over axis 1 of the (1, 3, 32, 32) float32 input.
    """
    data = sym.Variable("x")
    slope = sym.Variable("a")
    out = sym.prelu(data=data, alpha=slope)

    def forward(x, a):
        # (3, 1, 1) aligns the slope with axis 1 of the 4-D input.
        per_channel = a.reshape(3, 1, 1)
        return (x < 0) * (x * per_channel) + (x >= 0) * x

    inputs = [
        ('x', (1, 3, 32, 32), data),
        ('a', (3,), slope),
    ]
    helper(out, inputs, "float32", forward)
def check(in_shape, axis, out_shape):
    """Assert that inferring shapes for prelu(x, w) yields ``out_shape``.

    Args:
        in_shape: shape given to the data variable "x".
        axis: value forwarded to prelu's ``axis`` argument.
        out_shape: expected first inferred shape for the node named "y".
    """
    data = sym.Variable("x", shape=in_shape)
    slope = sym.Variable("w")
    node = sym.prelu(data, slope, axis=axis, name="y")
    inferred = infer_shape(node)["y"][0]
    assert tuple(inferred) == tuple(out_shape)
def check(in_shape, axis, out_shape):
    """Build prelu(x, w) named "y" and compare its inferred shape to ``out_shape``.

    Only ``x`` carries a shape hint; ``w`` is left for inference.
    The prelu ``axis`` argument is forwarded unchanged.
    """
    node = sym.prelu(
        sym.Variable("x", shape=in_shape),
        sym.Variable("w"),
        axis=axis,
        name="y",
    )
    shapes_by_name = infer_shape(node)
    expected = tuple(out_shape)
    assert tuple(shapes_by_name["y"][0]) == expected
def test_prelu():
    """Check that prelu(x, w) lists 'x' as its first input name and 'w' as its second."""
    # Arguments are evaluated left to right, so "x" is created before "w".
    out = sym.prelu(sym.Variable("x"), sym.Variable("w"))
    for position, expected in enumerate(('x', 'w')):
        assert out.list_input_names()[position] == expected