def check_amp_convert_symbol():
    x = mx.sym.var("x")
    y = mx.sym.var("y")
    z = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
    siny = mx.sym.sin(y)
    res = z + siny
    # Compare symbols with similar computation graphs created using convert_symbol and manually.
    res_converted = amp.convert_symbol(res, target_dtype="float16",
                                       target_dtype_ops=["FullyConnected"],
                                       fp32_ops=["sin"])

    x_fp16 = mx.sym.amp_cast(x, dtype="float16")
    y_fp16 = mx.sym.amp_cast(y, dtype="float16")
    siny = mx.sym.sin(y)
    z = mx.sym.FullyConnected(x_fp16, y_fp16, num_hidden=10, no_bias=True)
    amp_casted_z = mx.sym.amp_cast(z, dtype="float32")
    res_expected = amp_casted_z + siny
    assert same_symbol_structure(res_converted, res_expected), \
        "convert_symbol generating wrong computation graph"

    # convert_symbol called with incorrect inputs
    pytest.raises(AssertionError, amp.convert_symbol, res,
                  target_dtype="float16", target_dtype_ops=["FullyConnected"],
                  fp32_ops=["elemwise_add"])
    pytest.raises(AssertionError, amp.convert_symbol, res,
                  target_dtype="float16", target_dtype_ops=["FullyConnected"],
                  fp32_ops=["Activation"],
                  conditional_fp32_ops=[('Activation', 'act_type', ['selu'])])
    pytest.raises(AssertionError, amp.convert_symbol, res,
                  target_dtype="float16", target_dtype_ops=["Activation"],
                  fp32_ops=["Activation"],
                  conditional_fp32_ops=[('Activation', 'act_type', ['selu'])])
    pytest.raises(AssertionError, amp.convert_symbol, res,
                  target_dtype="float16", target_dtype_ops=["FullyConnected"],
                  fp32_ops=["FullyConnected"])

    # Test for op in conditional ops with condition not satisfied
    x = mx.sym.var("x")
    y = mx.sym.var("y")
    fc_cond = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
    res_converted = amp.convert_symbol(fc_cond, target_dtype="float16",
                                       target_dtype_ops=[],
                                       fp32_ops=["sin"],
                                       conditional_fp32_ops=[("FullyConnected", "no_bias", ["False"])])
    res_expected = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
    assert same_symbol_structure(res_converted, res_expected), \
        "convert_symbol generating wrong computation graph when conditional ops is used"

    # Test for op in conditional ops with condition satisfied
    res_converted = amp.convert_symbol(fc_cond, target_dtype="float16",
                                       target_dtype_ops=[],
                                       fp32_ops=["sin"],
                                       conditional_fp32_ops=[("FullyConnected", "no_bias", ["True"])])
    x_fp32 = mx.sym.amp_cast(x, dtype="float32")
    y_fp32 = mx.sym.amp_cast(y, dtype="float32")
    res_expected = mx.sym.FullyConnected(x_fp32, y_fp32, num_hidden=10, no_bias=True)
    assert same_symbol_structure(res_converted, res_expected), \
        "convert_symbol generating wrong computation graph when conditional ops used with satisfying condition"

    # Test with a real world model, default inputs for convert_symbol
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    inputs = {}
    inputs['data'] = mx.nd.ones((1, 3, 224, 224))
    inputs.update(arg_params)
    converted_sym = amp.convert_symbol(sym)
    exe = converted_sym.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null')
    exe.forward(is_train=False, **inputs)
    exe.outputs[0].asnumpy()

    inputs2 = {}
    inputs2['data'] = mx.nd.ones((1, 3, 224, 224))
    inputs2['fc1_weight'] = inputs['fc1_weight'].astype(np.float16)
    inputs2['fc1_bias'] = inputs['fc1_bias'].astype(np.float16)

    # Test with a real world model, tweak inputs for convert_symbol
    converted_sym = amp.convert_symbol(sym, target_dtype="float16",
                                       target_dtype_ops=["Convolution"],
                                       data_names=["data"],
                                       cast_optional_params=True)
    converted_sym2 = amp.convert_symbol(sym, target_dtype="float16",
                                        target_dtype_ops=["Convolution"],
                                        data_names=["data"],
                                        cast_optional_params=False)
    exe = converted_sym.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null')
    exe2 = converted_sym2.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null')

    converted_args = converted_sym.list_arguments()
    converted_auxs = converted_sym.list_auxiliary_states()
    for i, key in enumerate(exe.arg_arrays):
        if converted_args[i] in arg_params:
            arg_params[converted_args[i]] = arg_params[converted_args[i]].astype(exe.arg_arrays[i].dtype)
    for i, key in enumerate(exe.aux_arrays):
        if converted_auxs[i] in aux_params:
            aux_params[converted_auxs[i]] = aux_params[converted_auxs[i]].astype(exe.aux_arrays[i].dtype)

    inputs2.update(arg_params)
    exe.forward(is_train=False, **inputs2)
    exe.outputs[0].wait_to_read()

    inputs['fc1_weight'] = inputs['fc1_weight'].astype(np.float16)
    inputs['fc1_bias'] = inputs['fc1_bias'].astype(np.float16)
    exe2.forward(is_train=False, **inputs)
    exe2.outputs[0].wait_to_read()
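# Illustrative sketch (not from the original file): check_* helpers such as
# check_amp_convert_symbol are normally driven by a thin pytest entry point.
# The wrapper name and the `amp_tests` fixture are assumptions modelled on
# test_fp16_casting below; adjust to match the fixtures this file defines.
def test_amp_convert_symbol(amp_tests):
    check_amp_convert_symbol()
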
def test_fp16_casting(amp_tests):
    data = mx.sym.var("data")
    out1 = mx.sym.amp_cast(data, dtype="float16")
    out2 = mx.sym.amp_cast(data, dtype="float32")
    out3 = mx.sym.amp_cast(data, dtype="float16")
    # When two ops from data, with different dtypes,
    # data should be float32
    res = mx.sym.Group([out1, out2])
    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # When two ops from data, both casted to float16,
    # data should be float16
    res = mx.sym.Group([out1, out3])
    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # AMP Multicast test where one node is float32, another is float16
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
    final_res = amp.convert_symbol(out4, cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # AMP Multicast test where two non input nodes are float16,
    # and one input node is float32
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    data3 = mx.sym.var("data3", dtype=np.float16)
    out5 = mx.sym.amp_multicast(data, mx.sym.elemwise_add(data2, data3),
                                num_outputs=2)
    final_res = amp.convert_symbol(out5, target_dtype_ops=[],
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # AMP Multicast test where three input nodes one fp16, one fp32
    # one unknown
    data = mx.sym.var("data", dtype=np.float16)
    data2 = mx.sym.var("data2", dtype=np.float32)
    data3 = mx.sym.var("data3")
    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
    final_res = amp.convert_symbol(out6, target_dtype_ops=[],
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
    assert exe.arg_arrays[2].dtype == np.float32

    # Input node to amp_multicast and amp_cast, if dtypes conflict
    # and input node is already fp16, it should still be fp16
    data = mx.sym.var("data", dtype=np.float16)
    data2 = mx.sym.var("data2", dtype=np.float32)
    out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2),
                         mx.sym.amp_cast(data, dtype="float16")])
    final_res = amp.convert_symbol(out7, target_dtype_ops=[],
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # Input node to amp_multicast and amp_cast, if dtypes conflict
    # and input node is already fp32, it should be changed to fp16
    data = mx.sym.var("data", dtype=np.float32)
    data2 = mx.sym.var("data2", dtype=np.float16)
    out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2),
                         mx.sym.amp_cast(data, dtype="float16")])
    final_res = amp.convert_symbol(out8, target_dtype_ops=[],
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float16

    # Check for symbol which has slice channel; only checks that conversion
    # succeeds, no executor is bound for this case
    data = mx.sym.var("data")
    data2 = mx.sym.var("data2")
    data._set_attr(__dtype__="-1")
    data2._set_attr(__dtype__="-1")
    concat_res = mx.sym.concat(data, data2)
    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
    final_res = amp.convert_symbol(out)
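# Illustrative sketch (assumed helper, not part of the original tests): the fp16
# checks above repeat the same convert / bind / inspect-dtype pattern, which
# could be factored out as below. `shapes` and `convert_kwargs` are passed
# through unchanged, so each case keeps its own convert_symbol arguments.
def _assert_input_dtype(sym, input_name, expected_dtype, ctx=None,
                        shapes=None, **convert_kwargs):
    converted = amp.convert_symbol(sym, cast_optional_params=True, **convert_kwargs)
    exe = converted.simple_bind(ctx=ctx or mx.gpu(), **(shapes or {}))
    # arg_arrays are ordered like list_arguments(), so look up the input by name
    idx = converted.list_arguments().index(input_name)
    assert exe.arg_arrays[idx].dtype == expected_dtype, \
        "%s: expected %s, got %s" % (input_name, expected_dtype,
                                     exe.arg_arrays[idx].dtype)
    # Example, mirroring the "both casts to float16" case above:
    # _assert_input_dtype(mx.sym.Group([out1, out3]), "data", np.float16,
    #                     shapes={"data": (1, 2)}, data_names=[])
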
def test_bf16_casting():
    data = mx.sym.var("data")
    out1 = mx.sym.amp_cast(data, dtype=bfloat16)
    out2 = mx.sym.amp_cast(data, dtype="float32")
    out3 = mx.sym.amp_cast(data, dtype=bfloat16)
    # When two ops from data, with different dtypes,
    # data should be float32
    res = mx.sym.Group([out1, out2])
    final_res = amp.convert_symbol(res, data_names=[], target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # When two ops from data, both casted to bfloat16,
    # data should be bfloat16
    res = mx.sym.Group([out1, out3])
    final_res = amp.convert_symbol(res, data_names=[], target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # AMP Multicast test where one node is float32, another is bfloat16
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
    final_res = amp.convert_symbol(out4, target_dtype="bfloat16",
                                   cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data2=(1, 2), data=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # AMP Multicast test where two non input nodes are bfloat16,
    # and one input node is float32
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    data3 = mx.sym.var("data3", dtype=bfloat16)
    out5 = mx.sym.amp_multicast(data, mx.sym.elemwise_add(data2, data3),
                                num_outputs=2)
    final_res = amp.convert_symbol(out5, target_dtype_ops=[], target_dtype="bfloat16",
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
    assert exe.arg_arrays[0].dtype == np.float32

    # AMP Multicast test where three input nodes one bf16, one fp32
    # one unknown
    data = mx.sym.var("data", dtype=bfloat16)
    data2 = mx.sym.var("data2", dtype="float32")
    data3 = mx.sym.var("data3")
    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
    final_res = amp.convert_symbol(out6, target_dtype_ops=[], target_dtype="bfloat16",
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
    assert exe.arg_arrays[2].dtype == np.float32

    # Input node to amp_multicast and amp_cast, if dtypes conflict
    # and input node is already bf16, it should still be bf16
    data = mx.sym.var("data", dtype=bfloat16)
    data2 = mx.sym.var("data2", dtype="float32")
    out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2),
                         mx.sym.amp_cast(data, dtype=bfloat16)])
    final_res = amp.convert_symbol(out7, target_dtype_ops=[], target_dtype="bfloat16",
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16

    # Input node to amp_multicast and amp_cast, if dtypes conflict
    # and input node is already fp32, it should be changed to bf16
    data = mx.sym.var("data", dtype="float32")
    data2 = mx.sym.var("data2", dtype=bfloat16)
    out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2),
                         mx.sym.amp_cast(data, dtype=bfloat16)])
    final_res = amp.convert_symbol(out8, target_dtype_ops=[], target_dtype="bfloat16",
                                   fp32_ops=[], cast_optional_params=True)
    exe = final_res.simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
    assert exe.arg_arrays[0].dtype == bfloat16
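
# Illustrative sketch (assumption, not exercised by the tests above): the same
# AMP conversion is exposed for Gluon models via amp.convert_hybrid_block; the
# model choice (resnet18_v1) and input shape are placeholders.
def check_convert_hybrid_block_sketch():
    from mxnet.gluon.model_zoo import vision
    net = vision.resnet18_v1(pretrained=True, ctx=mx.gpu(0))
    net.hybridize()
    # Run one forward pass so the cached symbolic graph exists before conversion
    net(mx.nd.ones((1, 3, 224, 224), ctx=mx.gpu(0)))
    converted_net = amp.convert_hybrid_block(net, target_dtype="float16",
                                             cast_optional_params=True)
    out = converted_net(mx.nd.ones((1, 3, 224, 224), ctx=mx.gpu(0)))
    out.wait_to_read()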