def test_concat_table():
    """Concatenating record-typed collections sums fixed lengths.

    A ``var`` length on either operand makes the result length ``var``.
    """
    a = symbol('a', '3 * {a: int32, b: int32}')
    # Distinct symbol name: previously both operands were named 'a' (copy-paste slip).
    b = symbol('b', '5 * {a: int32, b: int32}')
    v = symbol('v', 'var * {a: int32, b: int32}')

    assert concat(a, b).dshape == dshape('8 * {a: int32, b: int32}')
    assert concat(a, v).dshape == dshape('var * {a: int32, b: int32}')
def test_concat_arr():
    """1-D int32 arrays: fixed lengths add up, and a ``var`` operand yields ``var``."""
    three = symbol('a', '3 * int32')
    five = symbol('b', '5 * int32')
    unknown = symbol('v', 'var * int32')

    cases = [
        (three, five, '8 * int32'),
        (three, unknown, 'var * int32'),
    ]
    for left, right, expected in cases:
        assert concat(left, right).dshape == dshape(expected)
def test_concat_axis_too_great_message():
    """An out-of-range ``axis`` raises ValueError that names the valid range.

    Renamed from ``test_concat_axis_too_great``: a later test with that name
    shadowed this one, so pytest never collected it.
    """
    a = symbol('a', '3 * 5 * int32')
    b = symbol('b', '3 * 5 * int32')
    with pytest.raises(ValueError) as excinfo:
        # Both operands are 2-D, so the only valid axes are 0 and 1.
        concat(a, b, axis=2)
    assert "must be in range: [0, 2)" in str(excinfo.value)
def test_concat_different_measure_message():
    """Mismatched measures (int32 vs float64) raise TypeError with an exact message.

    Renamed from ``test_concat_different_measure``: a later test with that
    name shadowed this one, so pytest never collected it.
    """
    a = symbol('a', '3 * 5 * int32')
    b = symbol('b', '3 * 5 * float64')
    with pytest.raises(TypeError) as excinfo:
        concat(a, b)
    expected = 'Mismatched measures: {} != {}'.format(
        a.dshape.measure, b.dshape.measure,
    )
    assert expected == str(excinfo.value)
def test_concat_mat():
    """2-D int32 concat along each axis, with fixed and ``var`` dimensions."""
    base = symbol('a', '3 * 5 * int32')
    same = symbol('b', '3 * 5 * int32')
    var_rows = symbol('v', 'var * 5 * int32')
    var_cols = symbol('u', '3 * var * int32')

    # (other operand, axis, expected result dshape) -- base is always the left operand.
    cases = [
        (same, 0, '6 * 5 * int32'),
        (same, 1, '3 * 10 * int32'),
        (var_rows, 0, 'var * 5 * int32'),
        (var_cols, 1, '3 * var * int32'),
    ]
    for other, axis, expected in cases:
        assert concat(base, other, axis=axis).dshape == dshape(expected)
def test_concat_different_along_concat_axis_message():
    """Non-concatenated dimensions must agree; the error names the offending axis.

    Renamed from ``test_concat_different_along_concat_axis``: a later test with
    that name shadowed this one, so pytest never collected it.
    """
    a = symbol('a', '3 * 5 * int32')

    # Concat on axis 0 requires axis 1 to match: 5 vs 6.
    wider = symbol('b', '3 * 6 * int32')
    with pytest.raises(TypeError) as excinfo:
        concat(a, wider, axis=0)
    assert "not equal along axis 1: 5 != 6" in str(excinfo.value)

    # Concat on axis 1 requires axis 0 to match: 3 vs 4.
    taller = symbol('b', '4 * 6 * int32')
    with pytest.raises(TypeError) as excinfo:
        concat(a, taller, axis=1)
    assert "not equal along axis 0: 3 != 4" in str(excinfo.value)
def test_concat_different_measure():
    """int32 and float64 operands of the same shape cannot be concatenated."""
    ints = symbol('a', '3 * 5 * int32')
    floats = symbol('b', '3 * 5 * float64')

    with pytest.raises(TypeError):
        concat(ints, floats)
def test_concat_axis_too_great():
    """axis=2 on 2-D operands is rejected with ValueError."""
    shape = '3 * 5 * int32'
    left, right = symbol('a', shape), symbol('b', shape)

    with pytest.raises(ValueError):
        concat(left, right, axis=2)
def test_concat_negative_axis():
    """A negative ``axis`` (here -1) is rejected with ValueError."""
    shape = '3 * 5 * int32'
    left = symbol('a', shape)
    right = symbol('b', shape)

    with pytest.raises(ValueError):
        concat(left, right, axis=-1)
def test_concat_different_along_concat_axis():
    """Concat on axis 0 raises TypeError when the second dimension differs (5 vs 6)."""
    narrow = symbol('a', '3 * 5 * int32')
    wide = symbol('b', '3 * 6 * int32')

    with pytest.raises(TypeError):
        concat(narrow, wide, axis=0)