import collections
import os

import mxnet as mx
import numpy as np
from mxnet.contrib import amp
from mxnet.test_utils import assert_almost_equal, download_model

# MXNet represents bfloat16 on the numpy side as a one-byte structured dtype;
# defining it here keeps the dtype asserts below self-contained.
bfloat16 = np.dtype([('bfloat16', np.uint8)])


def check_amp_convert_model():
    # Test with real world model, default inputs for convert_model
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    # Test with real world model, tweaked inputs for convert_model
    result_sym, result_arg_params, result_aux_params = amp.convert_model(
        sym, arg_params, aux_params, target_dtype="float16",
        target_dtype_ops=["Convolution"])
    mod = mx.mod.Module(result_sym, data_names=["data"],
                        label_names=["softmax_label"], context=mx.gpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]],
             label_shapes=[['softmax_label', (1,)]])
    mod.set_params(result_arg_params, result_aux_params)
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                label=[mx.nd.ones((1,))]))
    mod.get_outputs()[0].asnumpy()
    # Without cast_optional_params, converted weights stay in float32.
    assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float32

    # Call convert_model with cast_optional_params set to True
    result_sym, result_arg_params, result_aux_params = amp.convert_model(
        sym, arg_params, aux_params, target_dtype="float16",
        target_dtype_ops=["Convolution"], cast_optional_params=True)
    mod = mx.mod.Module(result_sym, data_names=["data"],
                        label_names=["softmax_label"], context=mx.gpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]],
             label_shapes=[['softmax_label', (1,)]])
    mod.set_params(result_arg_params, result_aux_params)
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                label=[mx.nd.ones((1,))]))
    mod.get_outputs()[0].asnumpy()
    # With cast_optional_params, eligible weights are cast to float16.
    assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float16
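# check_amp_convert_model binds its Module to mx.gpu(), so it can only run on
# a CUDA machine. A minimal sketch of a guarded entry point, assuming
# mx.context.num_gpus() (MXNet 1.x); the wrapper name and silent skip are
# illustrative, not part of the suite.
def run_amp_convert_model_check():
    # Illustrative guard: bail out when no CUDA device is present.
    if mx.context.num_gpus() == 0:
        return
    check_amp_convert_model()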
def low_precision_convert(model_name, low_precision, sym, arg_params,
                          aux_params, excluded_sym_names=None):
    # Use None rather than a mutable default list: the list is extended with
    # += below, so a shared default would leak entries across calls.
    if excluded_sym_names is None:
        excluded_sym_names = []
    if low_precision == 'bfloat16':
        # Keep the first convolution of each architecture in float32.
        if model_name.find('imagenet1k-resnet-152') != -1:
            excluded_sym_names += ['conv0']
        elif model_name.find('imagenet1k-inception-bn') != -1:
            excluded_sym_names += ['conv_1']
        elif model_name.find('resnet') != -1 and model_name.find('v1') != -1:
            excluded_sym_names += ['resnetv10_conv0_fwd']
        elif model_name.find('resnet') != -1 and model_name.find('v2') != -1:
            excluded_sym_names += ['resnetv20_conv0_fwd']
        elif model_name.find('vgg') != -1:
            excluded_sym_names += ['vgg0_conv0_fwd']
        elif model_name.find('squeezenet1') != -1:
            excluded_sym_names += ['squeezenet0_conv0_fwd']
        elif model_name.find('mobilenet') != -1 and model_name.find('v2') == -1:
            excluded_sym_names += ['mobilenet0_conv0_fwd']
        elif model_name.find('mobilenet') != -1 and model_name.find('v2') != -1:
            excluded_sym_names += ['mobilenetv20_conv0_fwd']
        elif model_name.find('inceptionv3') != -1:
            excluded_sym_names += ['inception30_conv0_fwd']
    return amp.convert_model(sym, arg_params, aux_params,
                             target_dtype=low_precision,
                             excluded_sym_names=excluded_sym_names,
                             cast_optional_params=True)
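# For reference, a minimal sketch of how low_precision_convert might be
# exercised end to end. The model name and download directory are illustrative
# assumptions, and bfloat16 execution assumes an MKL/oneDNN-enabled build.
def run_low_precision_convert_sketch():
    # Illustrative: download a symbolic checkpoint and convert it to bfloat16,
    # excluding the first convolution chosen by the helper above.
    model_name = 'imagenet1k-resnet-152'
    prefix, epoch = download_model(model_name, dst_dir='model')
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    bf16_sym, bf16_arg_params, bf16_aux_params = low_precision_convert(
        model_name, 'bfloat16', sym, arg_params, aux_params)
    # With cast_optional_params=True (set inside the helper), most weights are
    # stored in bfloat16, while excluded layers keep their float32 params.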
def check_amp_convert_fc_accuracy(data_shape, num_hidden, cast_optional_params):
    Batch = collections.namedtuple('Batch', ['data'])
    data = mx.sym.Variable(name='data')
    data_low = 0.0
    data_high = 100.0
    fc = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, name='fc')

    fc_exe_fp32 = mx.mod.Module(symbol=fc, label_names=None, context=mx.cpu())
    fc_exe_fp32.bind(data_shapes=[('data', data_shape)])
    fc_exe_fp32.init_params()
    data_fp32 = [mx.nd.random.uniform(low=data_low, high=data_high,
                                      shape=data_shape).astype('float32')]
    fc_exe_fp32.forward(Batch(data_fp32), is_train=False)
    arg_params, aux_params = fc_exe_fp32.get_params()
    output_fp32 = fc_exe_fp32.get_outputs()[0]

    fc_bf16, arg_params_bf16, aux_params_bf16 = amp.convert_model(
        fc, arg_params, aux_params, target_dtype="bfloat16",
        target_dtype_ops=["FullyConnected"],
        cast_optional_params=cast_optional_params)

    fc_exe_bf16 = mx.mod.Module(symbol=fc_bf16, label_names=None,
                                context=mx.cpu())
    fc_exe_bf16.bind(data_shapes=[('data', data_shape)])
    fc_exe_bf16.set_params(arg_params_bf16, aux_params_bf16)
    fc_exe_bf16.forward(Batch(data_fp32), is_train=False)
    output_bf16 = fc_exe_bf16.get_outputs()[0]
    output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")
    assert_almost_equal(output_bf16_2_fp32, output_fp32, rtol=1e-1, atol=2e-1)
def check_amp_convert_conv_accuracy(data_shape, kernel, num_filter, pad,
                                    stride, no_bias, cast_optional_params):
    Batch = collections.namedtuple('Batch', ['data'])
    data = mx.sym.Variable(name='data')
    data_low = 0.0
    data_high = 100.0
    conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter,
                                pad=pad, stride=stride, no_bias=no_bias,
                                cudnn_off=False, name='conv2d')

    conv_exe_fp32 = mx.mod.Module(symbol=conv2d, label_names=None,
                                  context=mx.cpu())
    conv_exe_fp32.bind(data_shapes=[('data', data_shape)])
    conv_exe_fp32.init_params()
    data_fp32 = [mx.nd.random.uniform(low=data_low, high=data_high,
                                      shape=data_shape).astype('float32')]
    conv_exe_fp32.forward(Batch(data_fp32), is_train=False)
    arg_params, aux_params = conv_exe_fp32.get_params()
    output_fp32 = conv_exe_fp32.get_outputs()[0]

    conv2d_bf16, arg_params_bf16, aux_params_bf16 = amp.convert_model(
        conv2d, arg_params, aux_params, target_dtype="bfloat16",
        target_dtype_ops=["Convolution"],
        cast_optional_params=cast_optional_params)

    conv_exe_bf16 = mx.mod.Module(symbol=conv2d_bf16, label_names=None,
                                  context=mx.cpu())
    conv_exe_bf16.bind(data_shapes=[('data', data_shape)])
    conv_exe_bf16.set_params(arg_params=arg_params_bf16,
                             aux_params=aux_params_bf16)
    conv_exe_bf16.forward(Batch(data_fp32), is_train=False)
    output_bf16 = conv_exe_bf16.get_outputs()[0]
    output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")
    assert_almost_equal(output_bf16_2_fp32, output_fp32, rtol=1e-1, atol=2e-1)
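# Both accuracy checks above are parameterized and meant to be driven over a
# grid of shapes. A minimal sketch of such a driver; the shapes and flags here
# are illustrative assumptions, not the values used by the actual suite.
def run_bf16_accuracy_checks():
    # Illustrative parameter grid covering both cast_optional_params modes.
    for cast_optional_params in [False, True]:
        check_amp_convert_fc_accuracy(data_shape=(3, 32), num_hidden=64,
                                      cast_optional_params=cast_optional_params)
        check_amp_convert_conv_accuracy(data_shape=(1, 3, 32, 32),
                                        kernel=(3, 3), num_filter=16,
                                        pad=(1, 1), stride=(1, 1),
                                        no_bias=False,
                                        cast_optional_params=cast_optional_params)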
def test_module_backward_compatibility():
    channel_num = 10
    conv_layer_filter_dims = [2, 3]
    conv_layer_strides = [1, 1]
    data = mx.sym.var("data")
    conv = mx.sym.Convolution(data,
                              num_filter=channel_num,
                              kernel=tuple(conv_layer_filter_dims),
                              stride=tuple(conv_layer_strides))
    bn = mx.sym.BatchNorm(conv, eps=0.001, momentum=0.9, fix_gamma=False,
                          use_global_stats=False, output_mean_var=False,
                          name="conv0_batchnorm")
    fc = mx.sym.FullyConnected(bn, num_hidden=10, name="fullyconnected")
    mod = mx.mod.Module(fc, data_names=["data"], context=mx.cpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]])
    mod.init_params()
    arg_params, aux_params = mod.get_params()
    for param_key, param_val in arg_params.items():
        assert param_val.dtype == np.float32, \
            "Incorrect inference type for arg_params, " \
            "please check simple_bind for module executor"
    for param_key, param_val in aux_params.items():
        assert param_val.dtype == np.float32, \
            "Incorrect inference type for aux_params, " \
            "please check simple_bind for module executor"

    sym, arg_params, aux_params = amp.convert_model(
        mod._symbol, mod._arg_params, mod._aux_params,
        target_dtype="bfloat16", target_dtype_ops=["Convolution"])
    mod = mx.mod.Module(sym, data_names=["data"], context=mx.cpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]])
    mod.set_params(arg_params, aux_params)
    assert arg_params["fullyconnected_weight"].dtype == bfloat16, \
        "Module API is overwriting the inferred dtype for a mixed precision model"