def test_pick_partial_shape():
    """Check that ``pick`` along axis 1 keeps an unknown leading (batch) dim.

    The first inference marks the unknown dimension with 0; inside
    ``mx.np_compat(True)`` the same unknown dimension is marked with -1.
    """
    data = mx.sym.Variable("x")
    idx = mx.sym.Variable("index")
    picked = mx.sym.pick(data, idx, axis=1)

    # batch size unknown (legacy marker: 0)
    _, out_shapes, _ = picked.infer_shape_partial(x=(0, 3, 3), index=(0, 3))
    assert out_shapes == [(0, 3)]

    # batch size unknown (NumPy-compatible marker: -1)
    with mx.np_compat(True):
        _, out_shapes, _ = picked.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3))
        assert out_shapes == [(-1, 3)]
def test_embedding_partial_shape():
    """Check that ``Embedding`` keeps the unknown batch dim of its data input.

    The weight shape is fully known in both cases.  The unknown batch dim is
    written as 0 first, then as -1 inside ``mx.np_compat(True)``.
    """
    tokens = mx.sym.Variable("x")
    weight = mx.sym.Variable("w")
    embedded = mx.sym.Embedding(
        data=tokens, weight=weight, input_dim=100, output_dim=10
    )
    weight_shape = (100, 10)

    # testing embedding with batch size unknown
    _, out_shapes, _ = embedded.infer_shape_partial(x=(0, 5), w=weight_shape)
    assert out_shapes == [(0, 5, 10)]

    with mx.np_compat(True):
        _, out_shapes, _ = embedded.infer_shape_partial(x=(-1, 5), w=weight_shape)
        assert out_shapes == [(-1, 5, 10)]
def test_dot_partial_shape():
    """Check that ``dot`` keeps an unknown first dim of the lhs operand.

    The rhs is always fully known as (4, 5).  The unknown lhs dim is written
    as 0 first, then as -1 inside ``mx.np_compat(True)``.
    """
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    product = mx.sym.dot(lhs, rhs)

    def out_shape(lhs_shape):
        # Only the lhs shape varies between the two checks below.
        _, shapes, _ = product.infer_shape_partial(x=lhs_shape, y=(4, 5))
        return shapes

    # batch size (first dim) of lhs unknown
    assert out_shape((0, 3, 4)) == [(0, 3, 5)]
    with mx.np_compat(True):
        assert out_shape((-1, 3, 4)) == [(-1, 3, 5)]
def test_shape_completely_unknown():
    """Check inference when no input shape is given at all.

    Without any shape hints, the first inference reports both the argument
    shape and the output shape as ``()``.  Inside ``mx.np_compat(True)`` the
    same fully unknown shapes are reported as ``None``.
    """
    data = mx.sym.var("data")
    ret = mx.sym.sin(data)
    arg_shapes, out_shapes, _ = ret.infer_shape_partial()
    assert arg_shapes[0] == ()
    assert out_shapes[0] == ()

    # Pass True explicitly so this matches the other partial-shape tests in
    # this file; the symbols are rebuilt inside the scope, as before.
    with mx.np_compat(True):
        data = mx.sym.var("data")
        ret = mx.sym.sin(data)
        arg_shapes, out_shapes, _ = ret.infer_shape_partial()
        assert arg_shapes[0] is None
        assert out_shapes[0] is None
def test_transpose_partial_shape():
    """Check that ``transpose`` keeps an unknown batch dim.

    The permutation [0, 3, 2, 1] converts the tensor from channels first to
    channels last.  The unknown batch dim is written as 0 first, then as -1
    inside ``mx.np_compat(True)``.
    """
    perm = [0, 3, 2, 1]
    src = mx.sym.Variable("x")
    transposed = mx.sym.transpose(src, axes=perm)

    # batch size unknown (legacy marker: 0)
    _, out_shapes, _ = transposed.infer_shape_partial(x=(0, 3, 224, 224))
    assert out_shapes == [(0, 224, 224, 3)]

    # batch size unknown (NumPy-compatible marker: -1)
    with mx.np_compat(True):
        _, out_shapes, _ = transposed.infer_shape_partial(x=(-1, 3, 224, 224))
        assert out_shapes == [(-1, 224, 224, 3)]
def test_where_partial_shape():
    """Check that ``where`` infers no output shape unless ``cond`` is fully known.

    When ``cond`` is only partially known, the first inference reports the
    output shape as ``()``.  Inside ``mx.np_compat(True)`` it is reported as
    ``None``.
    """
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    condition = mx.sym.Variable("cond")
    selected = mx.sym.where(condition, lhs, rhs)

    # condition must be fully known to infer shape
    _, out_shapes, _ = selected.infer_shape_partial(cond=(0, 2), x=(0, 2), y=(0, 2))
    assert out_shapes == [()]
    _, out_shapes, _ = selected.infer_shape_partial(cond=(0,), x=(2, 2), y=(2, 2))
    assert out_shapes == [()]

    with mx.np_compat(True):
        _, out_shapes, _ = selected.infer_shape_partial(
            cond=(-1, 2), x=(-1, 2), y=(-1, 2)
        )
        assert out_shapes == [None]
        _, out_shapes, _ = selected.infer_shape_partial(
            cond=(-1,), x=(2, 2), y=(2, 2)
        )
        assert out_shapes == [None]
def test_batch_dot_partial_shape():
    """Check ``batch_dot`` shape inference with unknown batch/inner dims.

    The case where only the batch dims are unknown still yields an output
    shape.  The case where the rhs second dim is also unknown yields ``()``
    at first, and ``None`` inside ``mx.np_compat(True)``.
    """
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    batched = mx.sym.batch_dot(lhs, rhs)

    # lhs and rhs batch size unknown
    _, out_shapes, _ = batched.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert out_shapes == [(0, 3, 5)]
    # rhs second dim unknown
    _, out_shapes, _ = batched.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
    assert out_shapes == [()]

    with mx.np_compat(True):
        _, out_shapes, _ = batched.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert out_shapes == [(-1, 3, 5)]
        _, out_shapes, _ = batched.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5))
        assert out_shapes == [None]