def residual_unit_norm_conv(rank, data, nchw_inshape, num_filter, stride, dim_match, name, bottle_neck=True,
                  workspace=256, memonger=False, conv_layout='NCHW', batchnorm_layout='NCHW',
                  verbose=False, cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
                  fuse_bn_relu=False, fuse_bn_add_relu=False, cudnn_tensor_core_only=False,
                  bn_group=1, local_gpus=None, local_comm=None):
    """Return ResNet Unit symbol for building ResNet
    Parameters
    ----------
    data : Symbol
        Input data
    nchw_inshape : tuple of int
        Input minibatch shape in (n, c, h, w) format independent of actual layout
    num_filter : int
        Number of output channels
    bottle_neck : bool
        Whether to use the 3-layer bottleneck variant (the only variant implemented here)
    stride : tuple
        Stride used in convolution
    dim_match : bool
        True if input and output have the same number of channels (identity shortcut);
        False if a projection shortcut is required
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, nchw_outshape)

    sym : the model symbol (up to this point)

    nchw_outshape : tuple
        (batch_size, features, height, width)
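
    Example
    -------
    A minimal sketch; values are illustrative, and an NVIDIA MXNet build with
    NormConvolution/BNStatsFinalize support is assumed::

        data = mx.sym.Variable('data')
        sym, out_shape = residual_unit_norm_conv(
            rank=0, data=data, nchw_inshape=(128, 256, 56, 56),
            num_filter=256, stride=(1, 1), dim_match=True,
            name='stage2_unit1', fuse_bn_add_relu=True)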
    """

    batch_size = nchw_inshape[0]
    shape = nchw_inshape[1:]
    act = 'relu' if fuse_bn_relu else None  # unused here: the fused path sets act_type explicitly per NormConvolution
    if bottle_neck:
        shape = resnet_begin_block_log(shape)
        # 1st NormalizedConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
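        # NormConvolution fuses the convolution with batchnorm statistics generation:
        # alongside the conv output it also emits per-channel sum and sum-of-squares,
        # which the next NormConvolution (or BNStatsFinalize) turns into the
        # normalization it applies.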
        (conv1, conv1_sum, conv1_sum_squares) = \
            mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                   num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
                                   name=name + '_conv1', layout=conv_layout)
        shape = resnet_conv2d_log(shape, 1, int(num_filter*0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

        # Second NormalizedConvolution: Stats-Apply Relu Convolution Stats-Gen
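        # Here the previous layer's statistics (in_sum/in_sum_squares) are applied as
        # batchnorm + ReLU on the input before the convolution, replacing a separate
        # BatchNorm/Activation pair.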
        (conv2, conv2_sum, conv2_sum_squares) = \
            mx.sym.NormConvolution(conv1, no_norm=False, in_sum=conv1_sum, in_sum_squares=conv1_sum_squares, act_type='relu',
                                   num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv2', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, stride, int(num_filter*0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

        # Third NormalizedConvolution: Stats-Apply Relu Convolution [no Stats-Gen]
        # The 'no stats-gen' mode is triggered by just not using the stats outputs for anything.
        (conv3, conv3_sum, conv3_sum_squares) = \
            mx.sym.NormConvolution(conv2, no_norm=False, in_sum=conv2_sum, in_sum_squares=conv2_sum_squares, act_type='relu',
                                   num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv3', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
        elem_count = element_count((batch_size,) + shape)
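        # BNStatsFinalize converts conv3's raw per-channel statistics into the equivalent
        # scale/bias (plus saved mean/inv-std and updated gamma/beta) consumed by the
        # fused scale-bias-add-relu epilogues below.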
        (bn3_equiv_scale, bn3_equiv_bias, bn3_saved_mean, bn3_saved_inv_std, bn3_gamma_out, bn3_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv3_sum, sum_squares=conv3_sum_squares,
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   output_mean_var=True, elem_count=elem_count, name=name + '_bn3')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
        dbar = True  # use dual_scale_bias_add_relu (normalizes both conv3 and the projection shortcut)
        if dim_match:
            shortcut = data
            dbar = False
        else:
            if fuse_bn_add_relu:
                # NormalizedConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
                (shortcut, conv1sc_sum, conv1sc_sum_squares) = \
                    mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                           num_filter=int(num_filter), kernel=(1,1), stride=stride, pad=(0,0),
                                           name=name + '_conv1sc', layout=conv_layout)
                # 'shape' already reflects the main-branch stride (applied at conv2), so the
                # strided shortcut conv is logged with stride 1 to land on the same output shape.
                shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
                elem_count = element_count((batch_size,) + shape)
                (bn1sc_equiv_scale, bn1sc_equiv_bias, bn1sc_saved_mean, bn1sc_saved_inv_std, bn1sc_gamma_out, bn1sc_beta_out) = \
                    mx.sym.BNStatsFinalize(sum=conv1sc_sum, sum_squares=conv1sc_sum_squares,
                                           eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                           output_mean_var=True, elem_count=elem_count, name=name + '_bn1sc')
                shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

            else:
                dbar = False
                conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                                             workspace=workspace, name=name+'_conv1sc', layout=conv_layout,
                                             cudnn_algo_verbose=verbose,
                                             cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
                                             cudnn_tensor_core_only=cudnn_tensor_core_only)
                proj_shape = resnet_conv2d_log(nchw_inshape[1:], stride, int(num_filter), False)
                shortcut = batchnorm(rank=rank, data=conv1sc, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                                     fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn_sc', cudnn_off=cudnn_bn_off)
                proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
                proj_shape = resnet_projection_log(nchw_inshape[1:], proj_shape)

        if memonger:
            # memonger: mark the shortcut so it can be mirrored (recomputed) in backward
            shortcut._set_attr(mirror_stage='True')
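        # Fused epilogue: apply the finalized batchnorm scale/bias to conv3, add the
        # shortcut, and apply ReLU in one op; the 'dual' variant additionally normalizes
        # the projection shortcut using its own BNStatsFinalize outputs.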
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size,) + shape
            if dbar:
                return dual_scale_bias_add_relu(
                    data=conv3, addend=shortcut,
                    io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                    data_equiv_scale=bn3_equiv_scale, data_equiv_bias=bn3_equiv_bias,
                    data_saved_mean=bn3_saved_mean, data_saved_inv_std=bn3_saved_inv_std,
                    data_gamma_out=bn3_gamma_out, data_beta_out=bn3_beta_out,
                    addend_equiv_scale=bn1sc_equiv_scale, addend_equiv_bias=bn1sc_equiv_bias,
                    addend_saved_mean=bn1sc_saved_mean, addend_saved_inv_std=bn1sc_saved_inv_std,
                    addend_gamma_out=bn1sc_gamma_out, addend_beta_out=bn1sc_beta_out,
                    fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                    name=name + '_dbar3', cudnn_off=cudnn_bn_off), nchw_shape
            else:
                return scale_bias_add_relu(
                    data=conv3, addend=shortcut,
                    io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                    data_equiv_scale=bn3_equiv_scale, data_equiv_bias=bn3_equiv_bias,
                    data_saved_mean=bn3_saved_mean, data_saved_inv_std=bn3_saved_inv_std,
                    data_gamma_out=bn3_gamma_out, data_beta_out=bn3_beta_out,
                    fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                    name=name + '_sbar3', cudnn_off=cudnn_bn_off), nchw_shape
        else:
            bn3 = batchnorm(rank=rank, data=conv3, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size,) + shape
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu', name=name + '_relu3'), nchw_shape

    else:
        raise NotImplementedError("residual_unit_norm_conv supports bottle_neck=True only")


def residual_unit(data,
                  shape,
                  num_filter,
                  stride,
                  dim_match,
                  name,
                  bottle_neck=True,
                  workspace=256,
                  memonger=False,
                  conv_layout='NCHW',
                  batchnorm_layout='NCHW',
                  verbose=False,
                  cudnn_bn_off=False,
                  bn_eps=2e-5,
                  bn_mom=0.9,
                  conv_algo=-1,
                  fuse_bn_relu=False,
                  fuse_bn_add_relu=False,
                  cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet
    Parameters
    ----------
    data : Symbol
        Input data
    shape : tuple of int
        Input shape in (channels, height, width) format, excluding the batch dimension
    num_filter : int
        Number of output channels
    bottle_neck : bool
        Whether to use the 3-layer bottleneck variant (the only variant implemented here)
    stride : tuple
        Stride used in convolution
    dim_match : bool
        True if input and output have the same number of channels (identity shortcut);
        False if a projection shortcut is required
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, shape)

    sym : the unit symbol (up to this point)

    shape : tuple
        Output shape in (features, height, width) format
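
    Example
    -------
    A minimal sketch (values are illustrative)::

        data = mx.sym.Variable('data')
        sym, out_shape = residual_unit(
            data, shape=(256, 56, 56), num_filter=256,
            stride=(1, 1), dim_match=True, name='stage2_unit1')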
    """
    input_shape = shape
    act = 'relu' if fuse_bn_relu else None
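    # With fuse_bn_relu set, ReLU is folded into the batchnorm helper (act_type='relu')
    # and the explicit Activation symbols below are skipped.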
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        conv1 = mx.sym.Convolution(
            data=data,
            num_filter=int(num_filter * 0.25),
            kernel=(1, 1),
            stride=(1, 1),
            pad=(0, 0),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv1',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        bn1 = batchnorm(data=conv1,
                        io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False,
                        eps=bn_eps,
                        momentum=bn_mom,
                        name=name + '_bn1',
                        cudnn_off=cudnn_bn_off,
                        act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)

        act1 = bn1 if fuse_bn_relu else mx.sym.Activation(
            data=bn1, act_type='relu', name=name + '_relu1')
        shape = resnet_relu_log(shape)

        conv2 = mx.sym.Convolution(
            data=act1,
            num_filter=int(num_filter * 0.25),
            kernel=(3, 3),
            stride=stride,
            pad=(1, 1),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv2',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, stride, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        bn2 = batchnorm(data=conv2,
                        io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False,
                        eps=bn_eps,
                        momentum=bn_mom,
                        name=name + '_bn2',
                        cudnn_off=cudnn_bn_off,
                        act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)

        act2 = bn2 if fuse_bn_relu else mx.sym.Activation(
            data=bn2, act_type='relu', name=name + '_relu2')
        shape = resnet_relu_log(shape)

        conv3 = mx.sym.Convolution(
            data=act2,
            num_filter=num_filter,
            kernel=(1, 1),
            stride=(1, 1),
            pad=(0, 0),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv3',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        if dim_match:
            shortcut = data
        else:
            conv1sc = mx.sym.Convolution(
                data=data,
                num_filter=num_filter,
                kernel=(1, 1),
                stride=stride,
                no_bias=True,
                workspace=workspace,
                name=name + '_conv1sc',
                layout=conv_layout,
                cudnn_algo_verbose=verbose,
                cudnn_algo_fwd=conv_algo,
                cudnn_algo_bwd_data=conv_algo,
                cudnn_algo_bwd_filter=conv_algo,
                cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(input_shape, stride,
                                           int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc,
                                 io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False,
                                 eps=bn_eps,
                                 momentum=bn_mom,
                                 name=name + '_sc',
                                 cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape,
                                              momentum=bn_mom,
                                              eps=bn_eps,
                                              center=True,
                                              scale=True,
                                              training=True)
            proj_shape = resnet_projection_log(input_shape, proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
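        # batchnorm_add_relu fuses BN(conv3), the shortcut addition, and the final ReLU
        # into a single operator, replacing the separate bn3 + add + Activation below.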
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape,
                                         momentum=bn_mom,
                                         eps=bn_eps,
                                         center=True,
                                         scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return batchnorm_add_relu(data=conv3,
                                      addend=shortcut,
                                      io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False,
                                      eps=bn_eps,
                                      momentum=bn_mom,
                                      name=name + '_bn3',
                                      cudnn_off=cudnn_bn_off), shape
        else:
            bn3 = batchnorm(data=conv3,
                            io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False,
                            eps=bn_eps,
                            momentum=bn_mom,
                            name=name + '_bn3',
                            cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape,
                                         momentum=bn_mom,
                                         eps=bn_eps,
                                         center=True,
                                         scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return mx.sym.Activation(data=bn3 + shortcut,
                                     act_type='relu',
                                     name=name + '_relu3'), shape

    else:
        raise NotImplementedError("residual_unit supports bottle_neck=True only")