Example #1
    def check_amp_convert_model():
        # Test with real world model, default inputs for convert_model
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_path = os.path.join(dir_path, 'model')
        if not os.path.isdir(model_path):
            os.mkdir(model_path)
        prefix, epoch = download_model("imagenet1k-resnet-18",
                                       dst_dir=model_path)

        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

        # Test with real world model, tweak inputs for convert_model
        result_sym, result_arg_params, result_aux_params = amp.convert_model(
            sym,
            arg_params,
            aux_params,
            target_dtype="float16",
            target_dtype_ops=["Convolution"])
        mod = mx.mod.Module(result_sym,
                            data_names=["data"],
                            label_names=["softmax_label"],
                            context=mx.gpu())
        mod.bind(data_shapes=[['data', (1, 3, 224, 224)]],
                 label_shapes=[['softmax_label', (1, )]])

        mod.set_params(result_arg_params, result_aux_params)
        mod.forward(
            mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                            label=[mx.nd.ones((1, ))]))
        mod.get_outputs()[0].asnumpy()
        assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float32

        # Call convert_model with cast_optional_params set to True
        result_sym, result_arg_params, result_aux_params = amp.convert_model(
            sym,
            arg_params,
            aux_params,
            target_dtype="float16",
            target_dtype_ops=["Convolution"],
            cast_optional_params=True)
        mod = mx.mod.Module(result_sym,
                            data_names=["data"],
                            label_names=["softmax_label"],
                            context=mx.gpu())
        mod.bind(data_shapes=[['data', (1, 3, 224, 224)]],
                 label_shapes=[['softmax_label', (1, )]])
        mod.set_params(result_arg_params, result_aux_params)
        mod.forward(
            mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                            label=[mx.nd.ones((1, ))]))
        mod.get_outputs()[0].asnumpy()
        assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float16
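
The converted symbol and parameters from the call above can also be written back to disk with the regular checkpoint API; a minimal sketch that could follow the last assert inside the helper above, assuming the same `mx` import (the checkpoint prefix and epoch are illustrative, not from the original test):

        # Persist the mixed-precision model; prefix and epoch are illustrative.
        mx.model.save_checkpoint("fp16_resnet-18", 0, result_sym,
                                 result_arg_params, result_aux_params)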
Example #2
def low_precison_convert(model_name,
                         low_precision,
                         sym,
                         arg_params,
                         aux_params,
                         excluded_sym_names=None):
    # Use None instead of a mutable default list, which would otherwise be
    # shared (and extended) across calls.
    if excluded_sym_names is None:
        excluded_sym_names = []
    if low_precision == 'bfloat16':
        if model_name.find('imagenet1k-resnet-152') != -1:
            excluded_sym_names += ['conv0']
        elif model_name.find('imagenet1k-inception-bn') != -1:
            excluded_sym_names += ['conv_1']
        elif model_name.find('resnet') != -1 and model_name.find('v1') != -1:
            excluded_sym_names += ['resnetv10_conv0_fwd']
        elif model_name.find('resnet') != -1 and model_name.find('v2') != -1:
            excluded_sym_names += ['resnetv20_conv0_fwd']
        elif model_name.find('vgg') != -1:
            excluded_sym_names += ['vgg0_conv0_fwd']
        elif model_name.find('squeezenet1') != -1:
            excluded_sym_names += ['squeezenet0_conv0_fwd']
        elif model_name.find('mobilenet') != -1 and model_name.find(
                'v2') == -1:
            excluded_sym_names += ['mobilenet0_conv0_fwd']
        elif model_name.find('mobilenet') != -1 and model_name.find(
                'v2') != -1:
            excluded_sym_names += ['mobilenetv20_conv0_fwd']
        elif model_name.find('inceptionv3') != -1:
            excluded_sym_names += ['inception30_conv0_fwd']
    return amp.convert_model(sym,
                             arg_params,
                             aux_params,
                             target_dtype=low_precision,
                             excluded_sym_names=excluded_sym_names,
                             cast_optional_params=True)
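
A call site for this helper might look like the following; the checkpoint name and epoch are illustrative, not part of the original example, and `mxnet` is assumed to be imported as `mx`:

# Hypothetical usage: load an fp32 checkpoint and convert it to bfloat16,
# letting the helper exclude the first convolution of the matched model family.
sym, arg_params, aux_params = mx.model.load_checkpoint('imagenet1k-resnet-152', 0)
sym_bf16, arg_params_bf16, aux_params_bf16 = low_precison_convert(
    'imagenet1k-resnet-152', 'bfloat16', sym, arg_params, aux_params)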
Example #3
    def check_amp_convert_fc_accuracy(data_shape, num_hidden, cast_optional_params):
        Batch = collections.namedtuple('Batch', ['data'])
        data = mx.sym.Variable(name='data')
        data_low = 0.0
        data_high = 100.0
        fc = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, name='fc')
        fc_exe_fp32 = mx.mod.Module(symbol=fc, label_names=None, context=mx.cpu())
        fc_exe_fp32.bind(data_shapes=[('data', data_shape)])
        fc_exe_fp32.init_params()
        data_fp32 = [mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('float32')]
        fc_exe_fp32.forward(Batch(data_fp32), is_train=False)
        arg_params, aux_params = fc_exe_fp32.get_params()
        output_fp32 = fc_exe_fp32.get_outputs()[0]

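        # Convert the fp32 FullyConnected graph and its parameters to bfloat16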
        fc_bf16, arg_params_bf16, aux_params_bf16 = amp.convert_model(
            fc, arg_params, aux_params,
            target_dtype="bfloat16",
            target_dtype_ops=["FullyConnected"],
            cast_optional_params=cast_optional_params)

        fc_exe_bf16 = mx.mod.Module(symbol=fc_bf16, label_names=None, context=mx.cpu())
        fc_exe_bf16.bind(data_shapes=[('data', data_shape)])
        fc_exe_bf16.set_params(arg_params_bf16, aux_params_bf16)
        fc_exe_bf16.forward(Batch(data_fp32), is_train=False)

        output_bf16 = fc_exe_bf16.get_outputs()[0]
        output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")

        assert_almost_equal(output_bf16_2_fp32, output_fp32, rtol=1e-1, atol=2e-1)
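
The helper above would then be invoked with concrete shapes; the values below are illustrative rather than taken from the original test:

# Illustrative invocation of the accuracy check, with and without casting
# the optional (weight/bias) parameters to bfloat16.
for cast_optional in [False, True]:
    check_amp_convert_fc_accuracy(data_shape=(3, 32), num_hidden=64,
                                  cast_optional_params=cast_optional)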
Example #4
    def check_amp_convert_conv_accuracy(data_shape, kernel, num_filter, pad, stride, no_bias, cast_optional_params):
        Batch = collections.namedtuple('Batch', ['data'])
        data = mx.sym.Variable(name='data')
        data_low = 0.0
        data_high = 100.0
        conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
                                    no_bias=no_bias, cudnn_off=False, name='conv2d')
        conv_exe_fp32 = mx.mod.Module(symbol=conv2d, label_names=None, context=mx.cpu())
        conv_exe_fp32.bind(data_shapes=[('data', data_shape)])
        conv_exe_fp32.init_params()
        data_fp32 = [mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('float32')]
        conv_exe_fp32.forward(Batch(data_fp32), is_train=False)
        arg_params, aux_params = conv_exe_fp32.get_params()
        output_fp32 = conv_exe_fp32.get_outputs()[0]

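        # Convert the fp32 Convolution graph and its parameters to bfloat16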
        conv2d_bf16, arg_params_bf16, aux_params_bf16 = amp.convert_model(
            conv2d, arg_params, aux_params,
            target_dtype="bfloat16",
            target_dtype_ops=["Convolution"],
            cast_optional_params=cast_optional_params)

        conv_exe_bf16 = mx.mod.Module(symbol=conv2d_bf16, label_names=None, context=mx.cpu())
        conv_exe_bf16.bind(data_shapes=[('data', data_shape)])
        conv_exe_bf16.set_params(arg_params=arg_params_bf16, aux_params=aux_params_bf16)
        conv_exe_bf16.forward(Batch(data_fp32), is_train=False)
        output_bf16 = conv_exe_bf16.get_outputs()[0]
        output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")

        assert_almost_equal(output_bf16_2_fp32, output_fp32, rtol=1e-1, atol=2e-1)
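
The convolution check would be driven the same way; the shapes, kernel, and filter count below are illustrative:

# Illustrative invocation of the convolution accuracy check.
check_amp_convert_conv_accuracy(data_shape=(1, 3, 28, 28), kernel=(3, 3),
                                num_filter=16, pad=(1, 1), stride=(1, 1),
                                no_bias=False, cast_optional_params=True)
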
def test_module_backward_compatibility():
    channel_num = 10
    conv_layer_filter_dims = [2, 3]
    conv_layer_strides = [1, 1]
    dimension = 5
    data_len = 10

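    # Build a small fp32 network: Convolution -> BatchNorm -> FullyConnected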
    data = mx.sym.var("data")
    conv = mx.sym.Convolution(data,
                              num_filter=channel_num,
                              kernel=tuple(conv_layer_filter_dims),
                              stride=tuple(conv_layer_strides))

    bn = mx.sym.BatchNorm(conv,
                          eps=0.001,
                          momentum=0.9,
                          fix_gamma=False,
                          use_global_stats=False,
                          output_mean_var=False,
                          name="conv0_batchnorm")
    fc = mx.sym.FullyConnected(bn, num_hidden=10, name="fullyconnected")
    mod = mx.mod.Module(fc, data_names=["data"], context=mx.cpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]])
    mod.init_params()

    arg_params, aux_params = mod.get_params()
    for param_key, param_val in arg_params.items():
        assert param_val.dtype == np.float32, "Incorrect inference type for arg_params, " \
                                              "please check simple_bind for module executor"
    for param_key, param_val in aux_params.items():
        assert param_val.dtype == np.float32, "Incorrect inference type for aux_params, " \
                                              "please check simple_bind for module executor"

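    # Convert Convolution ops to bfloat16; the dtypes of the remaining
    # parameters are re-inferred from the converted graph.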
    sym, arg_params, aux_params = amp.convert_model(
        mod._symbol,
        mod._arg_params,
        mod._aux_params,
        target_dtype="bfloat16",
        target_dtype_ops=["Convolution"])
    mod = mx.mod.Module(sym, data_names=["data"], context=mx.cpu())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]])
    mod.set_params(arg_params, aux_params)
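    # `bfloat16` below is assumed to be the numpy-visible bfloat16 dtype defined
    # at module level in the test file, e.g. np.dtype([('bfloat16', np.uint16)]).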
    assert arg_params["fullyconnected_weight"].dtype == bfloat16, \
        "Module API is overwriting the inferred dtype for a mixed precision model"