def verify_squeeze(dshape, axis):
    """Compare ``sym.squeeze(x, axis) + 1`` with a NumPy reference.

    Parameters
    ----------
    dshape : tuple of int
        Shape given for the input variable ``x``.
    axis : int, tuple of int, or None
        Axis/axes to squeeze; ``None`` means squeeze every size-1 dimension
        (``sym.squeeze`` is called without ``axis`` and ``np.squeeze`` gets
        ``axis=None``).
    """
    x = sym.Variable("x")
    # Test for None explicitly rather than truthiness: axis=0 is falsy, and
    # treating it as "no axis" would squeeze every unit dim on the symbol side
    # while np.squeeze(axis=0) drops only axis 0, so the results would differ.
    if axis is not None:
        y = sym.squeeze(x, axis=axis)
    else:
        y = sym.squeeze(x)
    y = y + 1

    def forward(x):
        # NumPy reference for the forward computation.
        return np.squeeze(x, axis=axis) + 1

    dtype = "float32"
    # NOTE(review): input spec format {name: (shape, symbol)} is whatever
    # `helper` expects; its exact contract is defined outside this view.
    inputs = {'x': (dshape, x)}
    helper(y, inputs, dtype, forward)
def verify_squeeze(shape, axis):
    """Check ``sym.squeeze`` followed by ``+ 1`` against NumPy via check_function.

    Parameters
    ----------
    shape : tuple of int
        Input shape forwarded to ``check_function``.
    axis : int, tuple of int, or None
        Axis/axes to squeeze; ``None`` squeezes all size-1 dimensions.
    """
    # Only pass `axis` to the symbolic op when the caller actually gave one.
    squeeze_kwargs = {} if axis is None else {"axis": axis}
    data = sym.Variable("x")
    result = sym.squeeze(data, **squeeze_kwargs) + 1

    # Parameter names `x` / `head_grads` are kept as-is: the checker may bind
    # the NumPy inputs to them by name.
    def forward(x):
        return np.squeeze(x, axis=axis) + 1

    def backward(head_grads, x):
        # squeeze just drops unit dims and `+ 1` has unit derivative, so the
        # input gradient is the upstream gradient restored to x's shape.
        return [np.reshape(head_grads, x.shape)]

    check_function(result, forward, backward, shape=shape)
def verify_squeeze(dshape, axis):
    """Compare ``sym.squeeze(x, axis) + 1`` with NumPy, forward and backward.

    Parameters
    ----------
    dshape : tuple of int
        Shape given for the input variable ``x``.
    axis : int, tuple of int, or None
        Axis/axes to squeeze; ``None`` means squeeze every size-1 dimension
        (``sym.squeeze`` is called without ``axis`` and ``np.squeeze`` gets
        ``axis=None``).
    """
    x = sym.Variable("x")
    # Test for None explicitly rather than truthiness: axis=0 is falsy, and
    # treating it as "no axis" would squeeze every unit dim on the symbol side
    # while np.squeeze(axis=0) drops only axis 0, so the results would differ.
    if axis is not None:
        y = sym.squeeze(x, axis=axis)
    else:
        y = sym.squeeze(x)
    y = y + 1

    def forward(x):
        # NumPy reference for the forward computation.
        return np.squeeze(x, axis=axis) + 1

    def backward(head_grads, x):
        # squeeze only removes size-1 dims and `+ 1` has unit derivative, so
        # the input gradient is head_grads reshaped back to x's shape.
        return [np.reshape(head_grads, x.shape)]

    dtype = "float32"
    # NOTE(review): input spec format [(name, shape, symbol)] is whatever
    # `helper` expects; its exact contract is defined outside this view.
    inputs = [('x', dshape, x)]
    helper(y, inputs, dtype, forward, backward)
def verify_squeeze(dshape, axis):
    """Build and run ``sym.squeeze(x, axis) + 1`` on every test target.

    For each ``(target, ctx)`` from ``ctx_list()``, the graph is compiled
    with input shape ``dshape``, run on uniform-random float32 data, and
    compared with ``np.squeeze(data, axis=axis) + 1`` (atol=rtol=1e-5).

    Parameters
    ----------
    dshape : tuple of int
        Shape of the input ``x``.
    axis : int, tuple of int, or None
        Axis/axes to squeeze; ``None`` means squeeze every size-1 dimension.
    """
    x = sym.Variable("x")
    # Test for None explicitly rather than truthiness: axis=0 is falsy, and
    # treating it as "no axis" would squeeze every unit dim in the graph while
    # np.squeeze(axis=0) drops only axis 0. Because the output buffer below is
    # sized from the NumPy result, that mismatch could go unnoticed when the
    # element counts agree, e.g. (3,) vs (3, 1).
    if axis is not None:
        y = sym.squeeze(x, axis=axis)
    else:
        y = sym.squeeze(x)
    y = y + 1
    dtype = "float32"
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = graph_runtime.create(graph, lib, ctx)
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        m.run(x=data)
        out_np = np.squeeze(data.asnumpy(), axis=axis) + 1
        out = m.get_output(0, tvm.nd.empty(out_np.shape))
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_squeeze():
    """Shape inference for ``sym.squeeze`` with explicit and implicit axes."""

    def _check(data, expected, **kwargs):
        # Squeeze `data` (forwarding any axis kwarg), infer shapes, and
        # compare the node's first output shape with `expected`.
        out = sym.squeeze(data, name='squeeze', **kwargs)
        shape_dict = infer_shape(out)
        assert shape_dict['squeeze'][0] == expected

    # Explicit middle axes of a 4-D input.
    _check(sym.Variable("x", shape=(1, 1, 1, 10)), [1, 10], axis=(1, 2))

    # One (1, 3, 1) input reused for the remaining cases.
    x = sym.Variable("x", shape=(1, 3, 1))
    _check(x, [3])                      # no axis: drop every unit dim
    _check(x, [3, 1], axis=0)           # single leading axis
    _check(x, [3], axis=(0, 2))         # both unit dims named explicitly