def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rtol=0.05):
    net.hybridize()
    out_ref = net(*data_example)

    net.optimize_for(data_example, backend=SG_PASS_NAME)  # amp pass works only on oneDNN nodes
    lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
                                      excluded_sym_names=quantized_nodes, cast_params_offline=True,
                                      device=mx.current_context())
    lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
    out_lp_net = lp_net(*data_example)

    # check outputs
    out_ref = [out_ref] if not isinstance(out_ref, list) else out_ref
    out_lp_net = [out_lp_net] if not isinstance(out_lp_net, list) else out_lp_net
    for ref_out, lp_out in zip(out_ref, out_lp_net):
        assert_almost_equal(ref_out, lp_out, rtol=rtol, atol=1.0)

    # check graph
    if expected_sym is not None:
        lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
        same_graph_structure(lp_symnet, expected_sym, True)

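# A minimal usage sketch of `check_amp_fuse` (an illustration, not part of the original
# test suite; the network, input shape, and test name below are assumptions).
def test_single_dense_amp_fuse_example():
    net = nn.HybridSequential()
    net.add(nn.Dense(8))
    net.initialize()
    data_example = [mx.np.random.uniform(-1, 1, (4, 16))]
    # With expected_sym=None the graph-structure comparison is skipped and only the
    # low-precision outputs are compared against the FP32 reference.
    check_amp_fuse(net, data_example)
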
def test_bf16_offline_casting_shared_params():
    COMMON_SIZE = 4

    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Dense(COMMON_SIZE)
            self.lp16_op2 = nn.Dense(COMMON_SIZE)
            self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
            self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
            self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})

        def forward(self, x):
            x = self.lp16_op1(x)
            x1 = self.lp16_op2(x)
            x2 = mx.np.expand_dims(x, 1)
            x2 = self.fp32_op(x2)
            x2 = mx.npx.batch_flatten(x2)
            x = mx.np.concat((x1, x2), axis=1)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
    lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
                                      target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
                                      cast_params_offline=True, device=mx.current_context())
    lp_net(data_example)
    for name, data in lp_net.collect_params().items():
        assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)

def test_amp_basic_use(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Dense(4)
            self.fc2 = nn.Dense(4)

        def forward(self, x):
            x = self.fc1(x)
            x = self.fc2(x)
            return x.reshape((-1, 2, 2))

    data_example = mx.np.random.uniform(-1, 1, (4, 4))
    net = TestNet()
    net.initialize()
    net = amp.convert_hybrid_block(net, data_example, lp_dtype)

    lp16_casts = 1   # cast of the network input
    lp16_casts += 2  # casts of the weights and bias of `fc1`
    lp16_casts += 2  # casts of the weights and bias of `fc2`

    other_casts = 1  # cast of the network output (from lp16 to f32)

    lp16_tensors = 1   # cast network input
    lp16_tensors += 3  # cast weights and bias of `fc1`, `fc1` output
    lp16_tensors += 3  # cast weights and bias of `fc2`, `fc2` output
    lp16_tensors += 1  # reshape output
    check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=lp16_tensors,
                        lp16_casts_num=lp16_casts, other_casts_num=other_casts)

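# `check_amp_net_stats` is defined elsewhere in this file. As an illustration of what
# such a check could look like, the sketch below counts `amp_cast` nodes in the exported
# symbol graph by their target dtype. This is an assumption about the helper, not its
# actual implementation, and the JSON attribute names may differ between MXNet versions.
def count_amp_casts_sketch(net, data_example, lp_dtype):
    net(data_example)  # run once so the hybridized graph can be exported
    symnet = net.export(None, remove_amp_cast=False)[0]
    lp16_casts, other_casts = 0, 0
    for node in json.loads(symnet.tojson())['nodes']:
        if node['op'] == 'amp_cast':
            if node.get('attrs', {}).get('dtype') == lp_dtype:
                lp16_casts += 1
            else:
                other_casts += 1
    return lp16_casts, other_casts
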
def test_fp16_offline_casting():
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Conv2D(4, 3)
            self.lp16_op2 = nn.Conv2DTranspose(4, 3)
            self.fp32_op = nn.Dense(4)

        def forward(self, x):
            x = self.lp16_op1(x)
            x = self.lp16_op2(x)
            x = x.reshape(x.shape[0], -1)
            x = self.fp32_op(x)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
    lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
                                      target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
                                      cast_params_offline=True, device=mx.current_context())
    lp_net(data_example)
    for name, data in lp_net.collect_params().items():
        assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')

def test_lp16_fp32_ops_order_independence(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self, lp16_fp32_is_first):
            super().__init__()
            if lp16_fp32_is_first:
                self.first = mx.npx.batch_flatten  # lp16_fp32_op
                self.second = nn.Dense(4)
            else:
                self.first = nn.Dense(4)
                self.second = mx.npx.batch_flatten  # lp16_fp32_op

        def forward(self, x):
            x = 2**x
            x1 = self.first(x)
            x2 = self.second(x)
            return x1, x2

    data_example = mx.np.random.uniform(-1, 1, (4, 16))
    for lp16_fp32_is_first in [False, True]:
        net = TestNet(lp16_fp32_is_first)
        net.initialize()
        net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
        check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=3, lp16_casts_num=1,
                            other_casts_num=2)

def test_amp_offline_casting(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Conv2D(4, 3)
            self.lp16_op2 = nn.Conv2DTranspose(4, 3)
            self.fp32_op = nn.Dense(4)

        def forward(self, x):
            x = self.lp16_op1(x)
            x = self.lp16_op2(x)
            x = x.reshape(x.shape[0], -1)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x = self.fp32_op(x)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
    lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
    check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4, lp16_casts_num=1,
                        other_casts_num=1)
    for name, data in lp_net.collect_params().items():
        assert mx.nd.get_dtype_name(data.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)

def test_amp_node_excluding(lp_dtype):
    DISABLE_AMP_ATTR_DICT = {
        '__opt_constraint__': str(mx.gluon.HybridBlock.OptConstraint.Flag.DisableAMP.value)
    }

    data = mx.sym.var('data')
    wei = mx.sym.var('weights')
    bias = mx.sym.var('bias')
    # manually excluded
    fc1 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc1',
                                attr=DISABLE_AMP_ATTR_DICT)
    # to be excluded using the conversion API
    fc2 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc2')
    symnet = mx.sym.Group([fc1, fc2])
    net = mx.gluon.SymbolBlock(symnet, [data])
    net.initialize()

    # exclude only the node with the attribute set (a single node: `fc1`)
    data_example = mx.np.random.uniform(-1, 1, (4, 16))
    net_1_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype)
    lp16_tensors = 4  # cast `data`, cast weights and bias of `fc2`, `fc2` output
    lp16_casts = 3    # `data` cast, casts of the weights and bias of `fc2`
    other_casts = 1   # cast of the network output (from lp16 to f32)
    check_amp_net_stats(lp_dtype, net_1_excluded, data_example, lp16_tensors_num=lp16_tensors,
                        lp16_casts_num=lp16_casts, other_casts_num=other_casts)

    # exclude using the `excluded_sym_names` argument (both nodes)
    net_2_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype,
                                              excluded_sym_names=['fc1', 'fc2'])
    check_amp_net_stats(lp_dtype, net_2_excluded, data_example, lp16_tensors_num=0,
                        lp16_casts_num=0, other_casts_num=0)

def test_amp_conversion_rnn(amp_tests):
    with mx.Device(mx.gpu(0)):
        model = nn.HybridSequential()
        model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
        model.add(nn.Dense(2))
        model.initialize()
        model.hybridize()
        out = model(mx.nd.ones((2, 3, 4)))

        new_model = amp.convert_hybrid_block(model)
        out2 = new_model(mx.nd.ones((2, 3, 4)))
        mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)

def test_amp_excluding_after_graph_pass():
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super(TestNet, self).__init__()
            self.fc1 = nn.Dense(16)
            self.fc2 = nn.Dense(16)

        def forward(self, x):
            x = self.fc1(x)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x = self.fc2(x)
            return x

    data_example = mx.np.random.uniform(-1, 1, (1, 8))
    net = TestNet()
    net.initialize()

    net_before = amp.convert_hybrid_block(net, data_example, AMP_DTYPE, cast_params_offline=True)
    check_amp_net_stats(AMP_DTYPE, net_before, data_example, lp16_tensors_num=2,
                        lp16_casts_num=1, other_casts_num=1)

    net.optimize_for(data_example, backend=SG_PASS_NAME)  # introduces new nodes
    net_after = amp.convert_hybrid_block(net, data_example, AMP_DTYPE, cast_params_offline=True)
    check_amp_net_stats(AMP_DTYPE, net_after, data_example, lp16_tensors_num=2,
                        lp16_casts_num=1, other_casts_num=1)

def check_amp_with_quantization(net, data_example, quantized_nodes):
    net.optimize_for(data_example, backend=QUANTIZE_SG_PASS_NAME)
    symnet = net.export(None)[0]
    nodes = {n['name'] for n in json.loads(symnet.tojson())['nodes'] if n['op'] != 'null'}
    quant_excluded_nodes = list(nodes - set(quantized_nodes))

    _, calib_tensors1 = mx.contrib.quantization._quantize_symbol(
        symnet, mx.current_context(), excluded_symbols=quant_excluded_nodes)

    lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
                                      excluded_sym_names=quantized_nodes, cast_params_offline=True,
                                      device=mx.current_context())
    lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
    lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
    _, calib_tensors2 = mx.contrib.quantization._quantize_symbol(
        lp_symnet, mx.cpu(), excluded_symbols=quant_excluded_nodes)
    assert calib_tensors1 == calib_tensors2

def test_amp_offline_casting_shared_params(lp_dtype):
    COMMON_SIZE = 4

    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Dense(COMMON_SIZE)
            self.lp16_op2 = nn.Dense(COMMON_SIZE)
            self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
            self.fp32_op = nn.Dense(COMMON_SIZE)
            self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})

        def forward(self, x):
            x = self.lp16_op1(x)
            x1 = self.lp16_op2(x)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x2 = self.fp32_op(x)
            x = mx.np.concat((x1, x2), axis=1)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
    lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
    check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4, lp16_casts_num=2,
                        other_casts_num=2)
    for name, data in lp_net.collect_params().items():
        assert mx.nd.get_dtype_name(data.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)

                    help='Will generate random input of shape (1, 3, 224, 224) '
                         'and run a dummy inference forward pass')
parser.add_argument('--cast-optional-params', action='store_true', default=False,
                    help='If enabled, will try to cast params to target dtype wherever possible')
args = parser.parse_args()

logging.basicConfig()
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)

assert args.model in gluon_models, \
    "Please choose one of the available gluon models: {}".format(gluon_models)

if args.model in segmentation_models:
    shape = (1, 3, 480, 480)
elif args.model in calib_ssd_models:
    shape = (1, 3, 512, 544)
elif args.model in calib_inception_models:
    shape = (1, 3, 299, 299)
else:
    shape = (1, 3, 224, 224)

net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
net.hybridize()
result_before1 = net.forward(mx.nd.random.uniform(shape=shape))
net.export("{}".format(args.model))

net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
net.export("{}-amp".format(args.model), remove_amp_cast=False)

if args.run_dummy_inference:
    logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
    result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
    result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
    logger.info("Inference run successfully")
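
# Illustrative follow-up (a sketch, not part of the original script): the exported
# mixed-precision model could be loaded back for inference with SymbolBlock.imports.
# The file names rely on HybridBlock.export's default naming ("<prefix>-symbol.json" /
# "<prefix>-0000.params"); the input name 'data' is an assumption about the exported model.
if args.run_dummy_inference:
    net_amp = mx.gluon.SymbolBlock.imports("{}-amp-symbol.json".format(args.model),
                                           ['data'],
                                           "{}-amp-0000.params".format(args.model),
                                           ctx=mx.gpu(0))
    result_reloaded = net_amp(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
    logger.info("Reloaded mixed precision model ran successfully")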