Example #1
def check_amp_fuse(net,
                   data_example,
                   expected_sym=None,
                   quantized_nodes=[],
                   rtol=0.05):
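    # Compares a hybridized FP32 reference run of `net` against its low-precision copy:
    # the net is fused with the oneDNN subgraph pass, converted to AMP_DTYPE, fused with
    # the AMP subgraph pass, and the outputs (and optionally the graph structure) are checked.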
    net.hybridize()
    out_ref = net(*data_example)

    net.optimize_for(
        data_example,
        backend=SG_PASS_NAME)  # amp pass works only on oneDNN nodes
    lp_net = amp.convert_hybrid_block(net,
                                      data_example,
                                      target_dtype=AMP_DTYPE,
                                      excluded_sym_names=quantized_nodes,
                                      cast_params_offline=True,
                                      device=mx.current_context())
    lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
    out_lp_net = lp_net(*data_example)

    # check outputs
    out_ref = [out_ref] if not isinstance(out_ref, list) else out_ref
    out_lp_net = [out_lp_net] if not isinstance(out_lp_net, list) else out_lp_net
    for ref_out, lp_out in zip(out_ref, out_lp_net):
        assert_almost_equal(ref_out, lp_out, rtol=rtol, atol=1.0)

    # check graph
    if expected_sym is not None:
        lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
        same_graph_structure(lp_symnet, expected_sym, True)
Example #2
def test_bf16_offline_casting_shared_params():
  COMMON_SIZE = 4

  class TestNet(nn.HybridBlock):
    def __init__(self):
      super().__init__()
      self.lp16_op1 = nn.Dense(COMMON_SIZE)
      self.lp16_op2 = nn.Dense(COMMON_SIZE)
      self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
      self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
      self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})

    def forward(self, x):
      x = self.lp16_op1(x)
      x1 = self.lp16_op2(x)
      x2 = mx.np.expand_dims(x, 1)
      x2 = self.fp32_op(x2)
      x2 = mx.npx.batch_flatten(x2)
      x = mx.np.concat((x1, x2), axis=1)
      return x

  net = TestNet()
  net.initialize()
  data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
                                    target_dtype_ops=['FullyConnected'], fp32_ops=['Convolution'],
                                    cast_params_offline=True, device=mx.current_context())
  lp_net(data_example)
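  # After offline casting, only the Convolution ('fp32_op') params should remain float32;
  # all other params should already be bfloat16.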
  for name, data in lp_net.collect_params().items():
    assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)
Example #3
File: common.py  Project: DickJC123/mxnet
def test_amp_basic_use(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Dense(4)
            self.fc2 = nn.Dense(4)

        def forward(self, x):
            x = self.fc1(x)
            x = self.fc2(x)
            return x.reshape((-1, 2, 2))

    data_example = mx.np.random.uniform(-1, 1, (4, 4))

    net = TestNet()
    net.initialize()
    net = amp.convert_hybrid_block(net, data_example, lp_dtype)

    lp16_casts = 1  # cast for network input
    lp16_casts += 2  # cast for weights and bias of `fc1`
    lp16_casts += 2  # cast for weights and bias of `fc2`

    other_casts = 1  # cast for the network output (from lp16 to f32)

    lp16_tensors = 1  # cast network input
    lp16_tensors += 3  # cast weights and bias of `fc1`, `fc1` output
    lp16_tensors += 3  # cast weights and bias of `fc2`, `fc2` output
    lp16_tensors += 1  # reshape output
    check_amp_net_stats(lp_dtype,
                        net,
                        data_example,
                        lp16_tensors_num=lp16_tensors,
                        lp16_casts_num=lp16_casts,
                        other_casts_num=other_casts)
Example #4
def test_fp16_offline_casting():
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Conv2D(4, 3)
            self.lp16_op2 = nn.Conv2DTranspose(4, 3)
            self.fp32_op = nn.Dense(4)

        def forward(self, x):
            x = self.lp16_op1(x)
            x = self.lp16_op2(x)
            x = x.reshape(x.shape[0], -1)
            x = self.fp32_op(x)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
    lp_net = amp.convert_hybrid_block(net,
                                      data_example,
                                      target_dtype='float16',
                                      target_dtype_ops=['Convolution'],
                                      fp32_ops=['FullyConnected'],
                                      cast_params_offline=True,
                                      device=mx.current_context())
    lp_net(data_example)
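    # Only the Dense ('fp32_op') params should remain float32; the Convolution params
    # should have been cast offline to float16.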
    for name, data in lp_net.collect_params().items():
        assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
Example #5
File: common.py  Project: DickJC123/mxnet
def test_lp16_fp32_ops_order_independence(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self, lp16_fp32_is_first):
            super().__init__()
            if lp16_fp32_is_first:
                self.first = mx.npx.batch_flatten  # lp16_fp32_op
                self.second = nn.Dense(4)
            else:
                self.first = nn.Dense(4)
                self.second = mx.npx.batch_flatten  # lp16_fp32_op

        def forward(self, x):
            x = 2**x
            x1 = self.first(x)
            x2 = self.second(x)
            return x1, x2

    data_example = mx.np.random.uniform(-1, 1, (4, 16))

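    # The cast statistics must be the same regardless of whether the lp16_fp32 op
    # (batch_flatten) comes before or after the Dense layer.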
    for lp16_fp32_is_first in [False, True]:
        net = TestNet(lp16_fp32_is_first)
        net.initialize()
        net = amp.convert_hybrid_block(net,
                                       data_example,
                                       lp_dtype,
                                       cast_params_offline=True)
        check_amp_net_stats(lp_dtype,
                            net,
                            data_example,
                            lp16_tensors_num=3,
                            lp16_casts_num=1,
                            other_casts_num=2)
Example #6
File: common.py  Project: DickJC123/mxnet
def test_amp_offline_casting(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Conv2D(4, 3)
            self.lp16_op2 = nn.Conv2DTranspose(4, 3)
            self.fp32_op = nn.Dense(4)

        def forward(self, x):
            x = self.lp16_op1(x)
            x = self.lp16_op2(x)
            x = x.reshape(x.shape[0], -1)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x = self.fp32_op(x)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
    lp_net = amp.convert_hybrid_block(net,
                                      data_example,
                                      lp_dtype,
                                      cast_params_offline=True)

    check_amp_net_stats(lp_dtype,
                        lp_net,
                        data_example,
                        lp16_tensors_num=4,
                        lp16_casts_num=1,
                        other_casts_num=1)
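    # Only the op wrapped in disable_amp() ('fp32_op') should keep float32 params;
    # the remaining params should be cast offline to lp_dtype.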
    for name, data in lp_net.collect_params().items():
        assert mx.nd.get_dtype_name(
            data.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)
Example #7
File: common.py  Project: DickJC123/mxnet
def test_amp_node_excluding(lp_dtype):
    DISABLE_AMP_ATTR_DICT = {
        '__opt_constraint__':
        str(mx.gluon.HybridBlock.OptConstraint.Flag.DisableAMP.value)
    }

    data = mx.sym.var('data')
    wei = mx.sym.var('weights')
    bias = mx.sym.var('bias')
    # manually excluded
    fc1 = mx.sym.FullyConnected(data,
                                wei,
                                bias,
                                num_hidden=4,
                                name='fc1',
                                attr=DISABLE_AMP_ATTR_DICT)
    # to be excluded using the conversion API
    fc2 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc2')
    symnet = mx.sym.Group([fc1, fc2])

    net = mx.gluon.SymbolBlock(symnet, [data])
    net.initialize()

    # exclude only nodes with set attribute (only 1 node - `fc1`)
    data_example = mx.np.random.uniform(-1, 1, (4, 16))
    net_1_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype)

    lp16_tensors = 4  # cast `data`, weights and bias of `fc1`, `fc1` output
    lp16_casts = 3  # `data` cast, casts for weights and bias of `fc1`
    other_casts = 1  # cast for the network output (from lp16 to f32)
    check_amp_net_stats(lp_dtype,
                        net_1_excluded,
                        data_example,
                        lp16_tensors_num=lp16_tensors,
                        lp16_casts_num=lp16_casts,
                        other_casts_num=other_casts)

    # exclude using the `excluded_sym_names` argument (both nodes)
    net_2_excluded = amp.convert_hybrid_block(
        net, data_example, lp_dtype, excluded_sym_names=['fc1', 'fc2'])
    check_amp_net_stats(lp_dtype,
                        net_2_excluded,
                        data_example,
                        lp16_tensors_num=0,
                        lp16_casts_num=0,
                        other_casts_num=0)
Example #8
def test_amp_conversion_rnn(amp_tests):
    with mx.Device(mx.gpu(0)):
        model = nn.HybridSequential()
        model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
        model.add(nn.Dense(2))
        model.initialize()
        model.hybridize()
        out = model(mx.nd.ones((2, 3, 4)))
        new_model = amp.convert_hybrid_block(model)
        out2 = new_model(mx.nd.ones((2, 3, 4)))
        mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
Example #9
def test_amp_excluding_after_graph_pass():
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super(TestNet, self).__init__()
            self.fc1 = nn.Dense(16)
            self.fc2 = nn.Dense(16)

        def forward(self, x):
            x = self.fc1(x)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x = self.fc2(x)
            return x

    data_example = mx.np.random.uniform(-1, 1, (1, 8))
    net = TestNet()
    net.initialize()

    net_before = amp.convert_hybrid_block(net,
                                          data_example,
                                          AMP_DTYPE,
                                          cast_params_offline=True)
    check_amp_net_stats(AMP_DTYPE,
                        net_before,
                        data_example,
                        lp16_tensors_num=2,
                        lp16_casts_num=1,
                        other_casts_num=1)

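    # The same AMP statistics must hold after the backend subgraph pass has rewritten the graph.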
    net.optimize_for(data_example,
                     backend=SG_PASS_NAME)  # introduces new nodes
    net_after = amp.convert_hybrid_block(net,
                                         data_example,
                                         AMP_DTYPE,
                                         cast_params_offline=True)
    check_amp_net_stats(AMP_DTYPE,
                        net_after,
                        data_example,
                        lp16_tensors_num=2,
                        lp16_casts_num=1,
                        other_casts_num=1)
Example #10
def check_amp_with_quantization(net, data_example, quantized_nodes):
  net.optimize_for(data_example, backend=QUANTIZE_SG_PASS_NAME)
  symnet = net.export(None)[0]
  nodes = {n['name'] for n in json.loads(symnet.tojson())['nodes'] if n['op'] != 'null'}
  quant_excluded_nodes = list(nodes - set(quantized_nodes))

  _, calib_tensors1 = mx.contrib.quantization._quantize_symbol(
      symnet, mx.current_context(), excluded_symbols=quant_excluded_nodes)

  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
                                    excluded_sym_names=quantized_nodes, cast_params_offline=True,
                                    device=mx.current_context())
  lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
  lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
  _, calib_tensors2 = mx.contrib.quantization._quantize_symbol(
      lp_symnet, mx.cpu(), excluded_symbols=quant_excluded_nodes)
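  # The quantization pass should select the same calibration tensors for the original
  # network and for its AMP-converted counterpart.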
  assert calib_tensors1 == calib_tensors2
Example #11
File: common.py  Project: DickJC123/mxnet
def test_amp_offline_casting_shared_params(lp_dtype):
    COMMON_SIZE = 4

    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Dense(COMMON_SIZE)
            self.lp16_op2 = nn.Dense(COMMON_SIZE)
            self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
            self.fp32_op = nn.Dense(COMMON_SIZE)
            self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})

        def forward(self, x):
            x = self.lp16_op1(x)
            x1 = self.lp16_op2(x)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x2 = self.fp32_op(x)
            x = mx.np.concat((x1, x2), axis=1)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
    lp_net = amp.convert_hybrid_block(net,
                                      data_example,
                                      lp_dtype,
                                      cast_params_offline=True)

    check_amp_net_stats(lp_dtype,
                        lp_net,
                        data_example,
                        lp16_tensors_num=4,
                        lp16_casts_num=2,
                        other_casts_num=2)
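    # Only the params of the excluded 'fp32_op' should remain float32 after offline casting.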
    for name, data in lp_net.collect_params().items():
        assert mx.nd.get_dtype_name(
            data.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)
Example #12
                        help='Will generate random input of shape (1, 3, 224, 224) '
                             'and run a dummy inference forward pass')
    parser.add_argument('--cast-optional-params', action='store_true', default=False,
                        help='If enabled, will try to cast params to target dtype wherever possible')
    args = parser.parse_args()
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    assert args.model in gluon_models, "Please choose one of the available gluon models: {}".format(gluon_models)
    shape = None
    if args.model in segmentation_models:
        shape = (1, 3, 480, 480)
    elif args.model in calib_ssd_models:
        shape = (1, 3, 512, 544)
    elif args.model in calib_inception_models:
        shape = (1, 3, 299, 299)
    else:
        shape = (1, 3, 224, 224)
    net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
    net.hybridize()
    result_before1 = net.forward(mx.nd.random.uniform(shape=shape))
    net.export("{}".format(args.model))
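    # Convert to mixed precision and export with the amp_cast nodes kept (remove_amp_cast=False).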
    net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
    net.export("{}-amp".format(args.model), remove_amp_cast=False)
    if args.run_dummy_inference:
        logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
        result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
        result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
        logger.info("Inference run successfully")