Example #1
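This example exercises mxnet.contrib.amp.convert_symbol. It is shown as it appears nested inside a larger test in the MXNet test suite, which explains the extra level of indentation. A minimal sketch of the imports it relies on follows; the helper locations are assumptions based on MXNet 1.x (same_symbol_structure and download_model from mxnet.test_utils, amp from mxnet.contrib) and may differ between versions.

import os

import numpy as np
import pytest

import mxnet as mx
from mxnet.contrib import amp  # MXNet 1.x AMP module (assumed import path)
# Helper locations below are assumptions based on the MXNet 1.x test utilities.
from mxnet.test_utils import download_model, same_symbol_structure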
    def check_amp_convert_symbol():
        x = mx.sym.var("x")
        y = mx.sym.var("y")
        z = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
        siny = mx.sym.sin(y)
        res = z + siny
        # Compare the graph produced by convert_symbol with an equivalent graph built manually.
        res_converted = amp.convert_symbol(res,
                                           target_dtype="float16",
                                           target_dtype_ops=["FullyConnected"],
                                           fp32_ops=["sin"])

        x_fp16 = mx.sym.amp_cast(x, dtype="float16")
        y_fp16 = mx.sym.amp_cast(y, dtype="float16")
        siny = mx.sym.sin(y)
        z = mx.sym.FullyConnected(x_fp16, y_fp16, num_hidden=10, no_bias=True)
        amp_casted_z = mx.sym.amp_cast(z, dtype="float32")
        res_expected = amp_casted_z + siny
        assert same_symbol_structure(res_converted, res_expected), \
            "convert_symbol generating wrong computation graph"

        # convert_symbol should raise AssertionError for invalid or conflicting op lists
        pytest.raises(AssertionError,
                      amp.convert_symbol,
                      res,
                      target_dtype="float16",
                      target_dtype_ops=["FullyConnected"],
                      fp32_ops=["elemwise_add"])
        pytest.raises(AssertionError,
                      amp.convert_symbol,
                      res,
                      target_dtype="float16",
                      target_dtype_ops=["FullyConnected"],
                      fp32_ops=["Activation"],
                      conditional_fp32_ops=[('Activation', 'act_type',
                                             ['selu'])])
        pytest.raises(AssertionError,
                      amp.convert_symbol,
                      res,
                      target_dtype="float16",
                      target_dtype_ops=["Activation"],
                      fp32_ops=["Activation"],
                      conditional_fp32_ops=[('Activation', 'act_type',
                                             ['selu'])])
        pytest.raises(AssertionError,
                      amp.convert_symbol,
                      res,
                      target_dtype="float16",
                      target_dtype_ops=["FullyConnected"],
                      fp32_ops=["FullyConnected"])

        # Test an op listed in conditional_fp32_ops whose condition is not satisfied
        x = mx.sym.var("x")
        y = mx.sym.var("y")
        fc_cond = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
        res_converted = amp.convert_symbol(fc_cond,
                                           target_dtype="float16",
                                           target_dtype_ops=[],
                                           fp32_ops=["sin"],
                                           conditional_fp32_ops=[
                                               ("FullyConnected", "no_bias",
                                                ["False"])
                                           ])

        res_expected = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
        assert same_symbol_structure(res_converted, res_expected), \
            "convert_symbol generating wrong computation graph when conditional ops is used"

        # Test an op listed in conditional_fp32_ops whose condition is satisfied
        res_converted = amp.convert_symbol(fc_cond,
                                           target_dtype="float16",
                                           target_dtype_ops=[],
                                           fp32_ops=["sin"],
                                           conditional_fp32_ops=[
                                               ("FullyConnected", "no_bias",
                                                ["True"])
                                           ])
        x_fp32 = mx.sym.amp_cast(x, dtype="float32")
        y_fp32 = mx.sym.amp_cast(y, dtype="float32")
        res_expected = mx.sym.FullyConnected(x_fp32,
                                             y_fp32,
                                             num_hidden=10,
                                             no_bias=True)
        assert same_symbol_structure(res_converted, res_expected), \
            "convert_symbol generating wrong computation graph when conditional ops used with satisfying condition"

        # Test with a real-world model, using convert_symbol's default arguments
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_path = os.path.join(dir_path, 'model')
        if not os.path.isdir(model_path):
            os.mkdir(model_path)

        prefix, epoch = download_model("imagenet1k-resnet-18",
                                       dst_dir=model_path)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
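        # Build an input dict with the data blob plus the checkpoint's fp32 parameters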
        inputs = {}
        inputs['data'] = mx.nd.ones((1, 3, 224, 224))
        inputs.update(arg_params)
        converted_sym = amp.convert_symbol(sym)
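        # Bind the converted symbol on a GPU and run one inference pass with the fp32 inputs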
        exe = converted_sym.simple_bind(mx.gpu(0),
                                        data=(1, 3, 224, 224),
                                        grad_req='null')
        exe.forward(is_train=False, **inputs)
        exe.outputs[0].asnumpy()

        inputs2 = {}
        inputs2['data'] = mx.nd.ones((1, 3, 224, 224))
        inputs2['fc1_weight'] = inputs['fc1_weight'].astype(np.float16)
        inputs2['fc1_bias'] = inputs['fc1_bias'].astype(np.float16)

        # Test with a real-world model, customizing convert_symbol's arguments
        converted_sym = amp.convert_symbol(sym,
                                           target_dtype="float16",
                                           target_dtype_ops=["Convolution"],
                                           data_names=["data"],
                                           cast_optional_params=True)
        converted_sym2 = amp.convert_symbol(sym,
                                            target_dtype="float16",
                                            target_dtype_ops=["Convolution"],
                                            data_names=["data"],
                                            cast_optional_params=False)

        exe = converted_sym.simple_bind(mx.gpu(0),
                                        data=(1, 3, 224, 224),
                                        grad_req='null')
        exe2 = converted_sym2.simple_bind(mx.gpu(),
                                          data=(1, 3, 224, 224),
                                          grad_req='null')

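        # Cast the saved parameters to the dtypes expected by the converted symbol's executor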
        converted_args = converted_sym.list_arguments()
        converted_auxs = converted_sym.list_auxiliary_states()
        for i, key in enumerate(exe.arg_arrays):
            if converted_args[i] in arg_params:
                arg_params[converted_args[i]] = arg_params[
                    converted_args[i]].astype(exe.arg_arrays[i].dtype)
        for i, key in enumerate(exe.aux_arrays):
            if converted_auxs[i] in aux_params:
                aux_params[converted_auxs[i]] = aux_params[
                    converted_auxs[i]].astype(exe.aux_arrays[i].dtype)

        inputs2.update(arg_params)
        exe.forward(is_train=False, **inputs2)
        exe.outputs[0].wait_to_read()

        inputs['fc1_weight'] = inputs['fc1_weight'].astype(np.float16)
        inputs['fc1_bias'] = inputs['fc1_bias'].astype(np.float16)
        exe2.forward(is_train=False, **inputs)
        exe2.outputs[0].wait_to_read()
Example #2
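As with the first example, a minimal sketch of the assumed imports is given below. The fp16 cases bind executors on GPU and the bfloat16 cases on CPU. bfloat16 is defined here directly as the structured NumPy dtype MXNet (1.7+) uses to represent it; depending on the version, the same alias may instead be importable from MXNet itself.

import numpy as np

import mxnet as mx
from mxnet.contrib import amp  # MXNet 1.x AMP module (assumed import path)

# Assumed to match the structured dtype MXNet uses to expose bfloat16 to NumPy.
bfloat16 = np.dtype([('bfloat16', np.uint8)])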
def test_fp16_casting(amp_tests):
    data = mx.sym.var("data")
    out1 = mx.sym.amp_cast(data, dtype="float16")
    out2 = mx.sym.amp_cast(data, dtype="float32")
    out3 = mx.sym.amp_cast(data, dtype="float16")
    # When two ops consume data and request different dtypes,
    # data should stay float32
    res = mx.sym.Group([out1, out2])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # When both ops consuming data cast it to float16,
    # data should be float16
    res = mx.sym.Group([out1, out3])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # AMP Multicast test where one node is float32, another is float16
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
    final_res = amp.convert_symbol(out4, cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # AMP Multicast test where two non-input nodes are float16
    # and one input node is float32
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    data3 = mx.sym.var("data3", dtype=np.float16)
    out5 = mx.sym.amp_multicast(data,
                                mx.sym.elemwise_add(data2, data3),
                                num_outputs=2)
    final_res = amp.convert_symbol(out5,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(),
                                data=(1, 2),
                                data2=(1, 2),
                                data3=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # AMP Multicast test with three input nodes: one fp16, one fp32,
    # one unknown
    data = mx.sym.var("data", dtype=np.float16)
    data2 = mx.sym.var("data2", dtype=np.float32)
    data3 = mx.sym.var("data3")
    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
    final_res = amp.convert_symbol(out6,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(),
                                data=(1, 2),
                                data2=(1, 2),
                                data3=(1, 2))
    assert exe.arg_arrays[2].dtype == np.float32

    # An input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is already fp16, it should stay fp16
    data = mx.sym.var("data", dtype=np.float16)
    data2 = mx.sym.var("data2", dtype=np.float32)
    out7 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype="float16")
    ])
    final_res = amp.convert_symbol(out7,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # An input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is fp32, it should be changed to fp16
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    out8 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype="float16")
    ])
    final_res = amp.convert_symbol(out8,
                                   target_dtype_ops=[],
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # Check a symbol that contains slice_channel (split)
    data = mx.sym.var("data")
    data2 = mx.sym.var("data2")
    data._set_attr(__dtype__="-1")
    data2._set_attr(__dtype__="-1")
    concat_res = mx.sym.concat(data, data2)
    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
    final_res = amp.convert_symbol(out)


def test_bf16_casting():
    data = mx.sym.var("data")
    out1 = mx.sym.amp_cast(data, dtype=bfloat16)
    out2 = mx.sym.amp_cast(data, dtype="float32")
    out3 = mx.sym.amp_cast(data, dtype=bfloat16)
    # When two ops consume data and request different dtypes,
    # data should stay float32
    res = mx.sym.Group([out1, out2])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # When both ops consuming data cast it to bfloat16,
    # data should be bfloat16
    res = mx.sym.Group([out1, out3])
    final_res = amp.convert_symbol(res,
                                   data_names=[],
                                   target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # AMP Multicast test where one node is float32, another is bfloat16
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
    final_res = amp.convert_symbol(out4,
                                   target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data2=(1, 2), data=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # AMP Multicast test where two non-input nodes are bfloat16
    # and one input node is float32
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    data3 = mx.sym.var("data3", dtype=bfloat16)
    out5 = mx.sym.amp_multicast(data,
                                mx.sym.elemwise_add(data2, data3),
                                num_outputs=2)
    final_res = amp.convert_symbol(out5,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(),
                                data=(1, 2),
                                data2=(1, 2),
                                data3=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # AMP Multicast test with three input nodes: one bf16, one fp32,
    # one unknown
    data = mx.sym.var("data", dtype=bfloat16)
    data2 = mx.sym.var("data2", dtype="float32")
    data3 = mx.sym.var("data3")
    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
    final_res = amp.convert_symbol(out6,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(),
                                data=(1, 2),
                                data2=(1, 2),
                                data3=(1, 2))
    assert exe.arg_arrays[2].dtype == np.float32

    # An input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is already bf16, it should stay bf16
    data = mx.sym.var("data", dtype=bfloat16)
    data2 = mx.sym.var("data2", dtype="float32")
    out7 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype=bfloat16)
    ])
    final_res = amp.convert_symbol(out7,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # An input node feeding both amp_multicast and amp_cast: if dtypes conflict
    # and the input node is fp32, it should be changed to bf16
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    out8 = mx.sym.Group([
        mx.sym.amp_multicast(data, data2, num_outputs=2),
        mx.sym.amp_cast(data, dtype=bfloat16)
    ])
    final_res = amp.convert_symbol(out8,
                                   target_dtype_ops=[],
                                   target_dtype="bfloat16",
                                   fp32_ops=[],
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16