Example #1
import mxnet as mx
import numpy as np
from mxnet import amp   # assumed import path; older MXNet releases expose this as mxnet.contrib.amp

# Assumed helper: MXNet represents bfloat16 as a custom numpy structured dtype.
bfloat16 = np.dtype([('bfloat16', np.uint16)])

def test_bf16_casting():
    data = mx.sym.var("data")
    out1 = mx.sym.amp_cast(data, dtype=bfloat16)
    out2 = mx.sym.amp_cast(data, dtype="float32")
    out3 = mx.sym.amp_cast(data, dtype=bfloat16)
    # When two ops consume 'data' with different target dtypes,
    # 'data' should stay float32
    res = mx.sym.Group([out1, out2])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # When both ops consuming 'data' cast it to bfloat16,
    # 'data' should become bfloat16
    res = mx.sym.Group([out1, out3])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # AMP Multicast test where one node is float32, another is bfloat16
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
    final_res = amp.convert_symbol(out4,
                                   target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(), data2=(1, 2), data=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # AMP Multicast test where two non-input nodes are bfloat16
    # and one input node is float32
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    data3 = mx.sym.var("data3", dtype=bfloat16)
    out5 = mx.sym.amp_multicast(data,
                                mx.sym.elemwise_add(data2, data3),
                                num_outputs=2)
    final_res = amp.convert_symbol(out5,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(),
                                 data=(1, 2),
                                 data2=(1, 2),
                                 data3=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # AMP Multicast test with three input nodes: one bf16, one fp32,
    # and one of unknown dtype
    data = mx.sym.var("data", dtype=bfloat16)
    data2 = mx.sym.var("data2", dtype="float32")
    data3 = mx.sym.var("data3")
    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
    final_res = amp.convert_symbol(out6,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(),
                                 data=(1, 2),
                                 data2=(1, 2),
                                 data3=(1, 2))
    assert exe.arg_arrays[2].dtype == np.float32

    # Input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is already bf16, it should stay bf16
    data = mx.sym.var("data", dtype=bfloat16)
    data2 = mx.sym.var("data2", dtype="float32")
    out7 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype=bfloat16)
    ])
    final_res = amp.convert_symbol(out7,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # Input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is fp32, it should be changed to bf16
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    out8 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype=bfloat16)
    ])
    final_res = amp.convert_symbol(out8,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16
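
Not part of the original test: a minimal sketch, assuming the imports and the
bfloat16 dtype defined above, of how a bfloat16-converted graph can be inspected
with public Symbol methods to see what amp.convert_symbol produced.

# Sketch only (assumes mx, np, amp and bfloat16 from the imports above).
x = mx.sym.var("x")
y = mx.sym.amp_cast(x, dtype=bfloat16)
converted = amp.convert_symbol(y, target_dtype="bfloat16", cast_optional_params=True)
print(converted.get_internals().list_outputs())  # names of all intermediate nodes
print(converted.list_arguments())                # graph inputs after conversion
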
Example #2
import mxnet as mx
import numpy as np
from mxnet import amp   # assumed import path; older MXNet releases expose this as mxnet.contrib.amp

# 'amp_tests' below is presumably a pytest fixture provided by the surrounding test suite.
def test_fp16_casting(amp_tests):
    data = mx.sym.var("data")
    out1 = mx.sym.amp_cast(data, dtype="float16")
    out2 = mx.sym.amp_cast(data, dtype="float32")
    out3 = mx.sym.amp_cast(data, dtype="float16")
    # When two ops consume 'data' with different target dtypes,
    # 'data' should stay float32
    res = mx.sym.Group([out1, out2])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # When both ops consuming 'data' cast it to float16,
    # 'data' should become float16
    res = mx.sym.Group([out1, out3])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # AMP Multicast test where one node is float32, another is float16
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
    final_res = amp.convert_symbol(out4, cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # AMP Multicast test where two non-input nodes are float16
    # and one input node is float32
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    data3 = mx.sym.var("data3", dtype=np.float16)
    out5 = mx.sym.amp_multicast(data,
                                mx.sym.elemwise_add(data2, data3),
                                num_outputs=2)
    final_res = amp.convert_symbol(out5,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(),
                                 data=(1, 2),
                                 data2=(1, 2),
                                 data3=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # AMP Multicast test with three input nodes: one fp16, one fp32,
    # and one of unknown dtype
    data = mx.sym.var("data", dtype=np.float16)
    data2 = mx.sym.var("data2", dtype=np.float32)
    data3 = mx.sym.var("data3")
    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
    final_res = amp.convert_symbol(out6,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(),
                                 data=(1, 2),
                                 data2=(1, 2),
                                 data3=(1, 2))
    assert exe.arg_arrays[2].dtype == np.float32

    # Input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is already fp16, it should stay fp16
    data = mx.sym.var("data", dtype=np.float16)
    data2 = mx.sym.var("data2", dtype=np.float32)
    out7 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype="float16")
    ])
    final_res = amp.convert_symbol(out7,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # Input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is fp32, it should be changed to fp16
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    out8 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype="float16")
    ])
    final_res = amp.convert_symbol(out8,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # Check conversion of a symbol that contains slice_channel (split)
    data = mx.sym.var("data")
    data2 = mx.sym.var("data2")
    data._set_attr(__dtype__="-1")
    data2._set_attr(__dtype__="-1")
    concat_res = mx.sym.concat(data, data2)
    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
    final_res = amp.convert_symbol(out)
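
Not part of the original tests: a minimal, self-contained sketch (assuming the
imports above) of how a converted symbol like the one in the final block can be
inspected afterwards; only public Symbol methods are used, and target_dtype is
left at its float16 default as in the test.

# Sketch only (assumes mx, np and amp from the imports above).
a = mx.sym.var("a")
b = mx.sym.var("b")
graph = mx.sym.split(mx.sym.concat(a, b), axis=1, num_outputs=2)
converted = amp.convert_symbol(graph)
print(converted.get_internals().list_outputs())  # names of all intermediate nodes
print(converted.list_arguments())                # graph inputs after conversion
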