def test_set_symbolic_shape():
    """Check that ``set_symbolic_shape`` changes how a loaded var reports its shape.

    A traced function is dumped to an in-memory buffer and reloaded as a
    ``Net``. With symbolic shape enabled, the input var's ``shape`` is a
    ``VarNode``; with it disabled, ``shape`` equals ``partial_shape``.
    The global flag is restored in a ``finally`` block so that a failing
    assertion does not leak the setting into later tests.
    """
    a = Tensor([1.0, 2.0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a):
        return F.relu(a * 2)

    # Execute the traced function once before dumping (presumably needed so
    # the trace has a recorded graph to dump — confirm against trace docs).
    fwd(a)

    orig_model = io.BytesIO()
    fwd.dump(
        orig_model,
        arg_names=["a"],
        output_names=["o"],
        optimize_for_inference=False,
    )
    orig_model.seek(0)

    net = Net.load(orig_model)
    var_a = net.input_vars[0]

    # set_symbolic_shape returns the previous setting; it is restored below.
    saved_symbolic_shape = set_symbolic_shape(True)
    try:
        assert isinstance(var_a.shape, VarNode)
        set_symbolic_shape(False)
        assert var_a.shape == var_a.partial_shape
    finally:
        set_symbolic_shape(saved_symbolic_shape)
def test_squeeze(is_varnode):
    """Compare ``F.squeeze`` with ``np.squeeze`` for several ``axis`` values.

    Args:
        is_varnode: if truthy, the input is created on a ``Network`` and
            symbolic shape is switched off for the duration of the test.
            The previous setting is restored even if an assertion fails.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
        xx = make_tensor(x, network)
        # Axes 0 (== -4) and 3 are the size-1 axes of x; None squeezes all
        # of them, and a tuple squeezes both explicitly.
        for axis in [None, 3, -4, (3, -4)]:
            y = np.squeeze(x, axis)
            yy = F.squeeze(xx, axis)
            np.testing.assert_equal(y, yy.numpy())
    finally:
        # Restore the global flag so it does not leak into other tests.
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
def test_split_basic(is_varnode):
    """Check ``F.split`` against ``np.split``, including its error cases.

    Splitting by a section count (2) and by explicit split points ([3])
    along axis 3 must both match ``np.split(data, [3, 5], axis=3)`` for the
    first two pieces. Invalid arguments must raise ``ValueError``.

    Args:
        is_varnode: if truthy, the input is created on a ``Network`` and
            symbolic shape is switched off for the duration of the test.
            The previous setting is restored even if an assertion fails.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        data = np.random.random((2, 3, 4, 5)).astype(np.float32)
        inp = make_tensor(data, network)
        mge_out0 = F.split(inp, 2, axis=3)
        mge_out1 = F.split(inp, [3], axis=3)
        # Axis 3 has length 5, so np.split yields [0:3], [3:5] and an empty
        # trailing piece [5:5]; zip below compares only the first two.
        np_out = np.split(data, [3, 5], axis=3)

        assert len(mge_out0) == 2
        assert len(mge_out1) == 2
        for by_count, by_index, expected in zip(mge_out0, mge_out1, np_out):
            np.testing.assert_equal(by_count.numpy(), expected)
            np.testing.assert_equal(by_index.numpy(), expected)

        # Explicit raise instead of `assert False`, which is stripped under -O.
        try:
            F.split(inp, 4)
        except ValueError:
            pass
        else:
            raise AssertionError("F.split(inp, 4) should raise ValueError")

        # Non-increasing split points are rejected with a specific message.
        try:
            F.split(inp, [3, 2, 5], axis=3)
        except ValueError as e:
            assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
        else:
            raise AssertionError(
                "F.split(inp, [3, 2, 5], axis=3) should raise ValueError"
            )
    finally:
        # Restore the global flag so it does not leak into other tests.
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
def test_reshape_shape_inference(is_varnode):
    """Check ``reshape`` output shapes for known/unknown input and target shapes.

    The input is either a tensor with a literal shape or one whose shape comes
    from a tensor value (``broadcast_to`` with a computed shape). The target
    shape is built via ``astensor1d`` from tensors, from Python ints, or from
    ints with a ``-1`` placeholder. Every combination must reshape to
    ``(2, 2)``.

    Args:
        is_varnode: if truthy, inputs are created on a ``Network`` and
            symbolic shape is switched off for the duration of the test.
            The previous setting is restored even if an assertion fails.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        x_shape_known = make_tensor([1, 2, 3, 4], network)
        # The target shape is a tensor (sum of ones == 4), so the output shape
        # is presumably not statically inferable here — hence "unknown".
        x_shape_unknown = F.broadcast_to(
            make_tensor([1.0], network),
            shape=make_tensor([1, 1, 1, 1], network).sum(),
        )
        tshp_unknown = astensor1d(
            (make_tensor([2], network), make_tensor([2], network)), x_shape_known
        )
        tshp_known = astensor1d((2, 2), x_shape_known)
        tshp_known_unspec = astensor1d((2, -1), x_shape_known)

        def check_shape(output, target):
            # The shape may come back as a tensor; convert it for comparison.
            source = output.shape
            if isinstance(source, Tensor):
                source = source.numpy()
            np.testing.assert_equal(source, target.shape)

        def func(x, target_shape):
            return x.reshape(target_shape)

        input_pairs = [
            (x_shape_known, tshp_unknown),
            (x_shape_unknown, tshp_unknown),
            (x_shape_known, tshp_known),
            (x_shape_known, tshp_known_unspec),
            (x_shape_unknown, tshp_known),
            (x_shape_unknown, tshp_known_unspec),
        ]
        # Every combination must yield a (2, 2) result; each case gets its own
        # reference array.
        cases = [
            {"input": [x, tshp], "output": [np.zeros((2, 2))]}
            for x, tshp in input_pairs
        ]

        opr_test(
            cases, func, compare_fn=check_shape, test_trace=True, network=network
        )
    finally:
        # Restore the global flag so it does not leak into other tests.
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)