def residual_unit_norm_conv(rank, data, nchw_inshape, num_filter, stride, dim_match, name, bottle_neck=True,
                  workspace=256, memonger=False, conv_layout='NCHW', batchnorm_layout='NCHW',
                  verbose=False, cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
                  fuse_bn_relu=False, fuse_bn_add_relu=False, cudnn_tensor_core_only=False,
                  bn_group=1, local_gpus=None, local_comm=None):
    """Return ResNet Unit symbol for building ResNet
    Parameters
    ----------
    data : str
        Input data
    nchw_inshape : tuple of int
        Input minibatch shape in (n, c, h, w) format independent of actual layout
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : Boolean
        True means channel number between input and output is the same, otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, nchw_outshape)

    sym : the model symbol (up to this point)

    nchw_output : tuple
        ( batch_size, features, height, width)
    """

    batch_size = nchw_inshape[0]
    shape = nchw_inshape[1:]
    act = 'relu' if fuse_bn_relu else None
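    # Fused-op bottleneck: each NormConvolution applies the previous layer's BN
    # statistics (scale/bias, optional ReLU) on its input and generates raw
    # sum / sum_squares statistics for the next normalization, so no standalone
    # BatchNorm kernel runs in the main path; BNStatsFinalize converts the last
    # statistics into the values consumed by the fused add+relu epilogue.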
    if bottle_neck:
        shape = resnet_begin_block_log(shape)
        # 1st NormalizedConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
        (conv1, conv1_sum, conv1_sum_squares) = \
            mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                   num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
                                   name=name + '_conv1', layout=conv_layout)
        shape = resnet_conv2d_log(shape, 1, int(num_filter*0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

        # Second NormalizedConvolution: Stats-Apply Relu Convolution Stats-Gen
        (conv2, conv2_sum, conv2_sum_squares) = \
            mx.sym.NormConvolution(conv1, no_norm=False, in_sum=conv1_sum, in_sum_squares=conv1_sum_squares, act_type='relu',
                                   num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv2', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, stride, int(num_filter*0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

        # Third NormalizedConvolution: Stats-Apply Relu Convolution [no Stats-Gen]
        # The 'no stats-gen' mode is triggered by just not using the stats outputs for anything.
        (conv3, conv3_sum, conv3_sum_squares) = \
            mx.sym.NormConvolution(conv2, no_norm=False, in_sum=conv2_sum, in_sum_squares=conv2_sum_squares, act_type='relu',
                                   num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv3', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
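        # Element count of the full (n, c, h, w) conv3 output; BNStatsFinalize
        # needs it to turn the accumulated sum / sum_squares into BN statistics.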
        elem_count = element_count((batch_size,) + shape)
        (bn3_equiv_scale, bn3_equiv_bias, bn3_saved_mean, bn3_saved_inv_std, bn3_gamma_out, bn3_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv3_sum, sum_squares=conv3_sum_squares,
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   output_mean_var=True, elem_count=elem_count, name=name + '_bn3')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
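        # dbar selects the dual-tensor fused epilogue: both the main path (conv3)
        # and the shortcut carry unfinalized BN statistics into the add + ReLU.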
        dbar = True
        if dim_match:
            shortcut = data
            dbar = False
        else:
            if fuse_bn_add_relu:
                # NormConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
                (shortcut, conv1sc_sum, conv1sc_sum_squares) = \
                    mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                           num_filter=int(num_filter), kernel=(1,1), stride=stride, pad=(0,0),
                                           name=name + '_conv1sc', layout=conv_layout)
                shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
                elem_count = element_count((batch_size,) + shape)
                (bn1sc_equiv_scale, bn1sc_equiv_bias, bn1sc_saved_mean, bn1sc_saved_inv_std, bn1sc_gamma_out, bn1sc_beta_out) = \
                    mx.sym.BNStatsFinalize(sum=conv1sc_sum, sum_squares=conv1sc_sum_squares,
                                           eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                           output_mean_var=True, elem_count=elem_count, name=name + '_bn1sc')
                shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

            else:
                dbar = False
                conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                                             workspace=workspace, name=name+'_conv1sc', layout=conv_layout,
                                             cudnn_algo_verbose=verbose,
                                             cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
                                             cudnn_tensor_core_only=cudnn_tensor_core_only)
                proj_shape = resnet_conv2d_log(nchw_inshape[1:], stride, int(num_filter), False)
                shortcut = batchnorm(rank=rank, data=conv1sc, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn_sc', cudnn_off=cudnn_bn_off)
                proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
                proj_shape = resnet_projection_log(nchw_inshape[1:], proj_shape)

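        # 'mirror_stage' marks the shortcut for the memory-monger pass, which may
        # recompute it in the backward pass instead of keeping it resident.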
        if memonger:
            shortcut._set_attr(mirror_stage='True')
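        # Fused epilogue: dual_scale_bias_add_relu normalizes both conv3 and the
        # shortcut from their unfinalized BN statistics before the add + ReLU;
        # scale_bias_add_relu covers the case where the shortcut is already a
        # finished tensor (identity, or a projection finalized by batchnorm above).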
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            if dbar:
                return dual_scale_bias_add_relu(data=conv3, addend=shortcut, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                        data_equiv_scale=bn3_equiv_scale, data_equiv_bias=bn3_equiv_bias, data_saved_mean=bn3_saved_mean,
                        data_saved_inv_std=bn3_saved_inv_std, data_gamma_out=bn3_gamma_out, data_beta_out=bn3_beta_out,
                        addend_equiv_scale=bn1sc_equiv_scale, addend_equiv_bias=bn1sc_equiv_bias, addend_saved_mean=bn1sc_saved_mean,
                        addend_saved_inv_std=bn1sc_saved_inv_std, addend_gamma_out=bn1sc_gamma_out, addend_beta_out=bn1sc_beta_out,
                        fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_dbar3', cudnn_off=cudnn_bn_off), nchw_shape
            else:
                return scale_bias_add_relu(data=conv3, addend=shortcut, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                        data_equiv_scale=bn3_equiv_scale, data_equiv_bias=bn3_equiv_bias, data_saved_mean=bn3_saved_mean,
                        data_saved_inv_std=bn3_saved_inv_std, data_gamma_out=bn3_gamma_out, data_beta_out=bn3_beta_out,
                        fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_sbar3', cudnn_off=cudnn_bn_off), nchw_shape
        else:
            bn3 = batchnorm(rank=rank, data=conv3, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu', name=name + '_relu3'), nchw_shape

    else:
        raise NotImplementedError

def resnet(rank, units, num_stages, filter_list, num_classes, image_shape, batch_size, bottle_neck=True, workspace=256, dtype='float32', memonger=False,
           input_layout='NCHW', conv_layout='NCHW',  batchnorm_layout='NCHW', pooling_layout='NCHW', verbose=False,
           cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
           fuse_bn_relu=False, fuse_bn_add_relu=False, force_tensor_core=False, use_dali=True, norm_conv=True, label_smoothing=0.0,
           bn_group=1, local_gpus=None, local_comm=None):
    """Return ResNet symbol of
    Parameters
    ----------
    units : list
        Number of units in each stage
    num_stages : int
        Number of stage
    filter_list : list
        Channel size of each stage
    num_classes : int
        Output size of symbol
    image_shape : tuple of int
        A 3-element tuple comprising (features, height, width) of each image
    batch_size : int
        The number of images in the training mini-batch
    workspace : int
        Workspace used in convolution operator
    dtype : str
        Precision (float32 or float16)
    memonger : boolean
        Activates "memory monger" to reduce the model's memory footprint
    input_layout : str
        interpretation (e.g. NCHW vs NHWC) of data provided by the i/o pipeline (may introduce
        transposes if in conflict with 'conv_layout' below)
    conv_layout : str
        interpretation (e.g. NCHW vs NHWC) of data for convolution operation.
    batchnorm_layout : str
        directs which kernel performs the batchnorm (may introduce transposes if in conflict with 'conv_layout' above)
    pooling_layout : str
        directs which kernel performs the pooling (may introduce transposes if in conflict with 'conv_layout' above)
    """

    act = 'relu' if fuse_bn_relu else None
    num_unit = len(units)
    assert(num_unit == num_stages)
    data = mx.sym.Variable(name='data')
    if not use_dali:
        # double buffering of data
        if dtype == 'float32':
            data = mx.sym.identity(data=data, name='id')
        elif dtype == 'float16':
            data = mx.sym.Cast(data=data, dtype=np.float16)
    (nchannel, height, width) = image_shape

    # Insert transpose as needed to get the input layout to match the desired processing layout
    data = transform_layout(data, input_layout, conv_layout)
    res_unit = residual_unit_norm_conv

    shape = image_shape
    body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
                              no_bias=True, name="conv0", workspace=workspace, layout=conv_layout,
                              cudnn_algo_verbose=verbose,
                              cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
                              cudnn_tensor_core_only=force_tensor_core)
    shape = resnet_conv2d_log(shape, 2, filter_list[0], False)
    body = batchnorm(rank=rank, data=body, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                     fix_gamma=False, eps=bn_eps, momentum=bn_mom, name='bn0', cudnn_off=cudnn_bn_off, act_type=act, 
                     bn_group=bn_group, local_gpus=local_gpus, local_comm=local_comm)
    shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
    if not fuse_bn_relu:
        body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
    shape = resnet_relu_log(shape)
    body = pooling(data=body, io_layout=conv_layout, pooling_layout=pooling_layout,
                   kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')
    shape = resnet_max_pool_log(shape, 2)

    nchw_shape = (batch_size, ) + shape

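    # First unit of each stage downsamples (stride 2, except stage 1) and uses a
    # projection shortcut (dim_match=False); the remaining units in the stage are
    # identity units (dim_match=True, stride 1).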
    for i in range(num_stages):
        body, nchw_shape = res_unit(rank, body, nchw_shape, filter_list[i+1], (1, 1) if i == 0 else (2, 2), False,
                                    name='stage%d_unit%d' % (i + 1, 1),
                                    bottle_neck=bottle_neck, workspace=workspace,
                                    memonger=memonger, conv_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                                    verbose=verbose, cudnn_bn_off=cudnn_bn_off, bn_eps=bn_eps, bn_mom=bn_mom,
                                    conv_algo=conv_algo, fuse_bn_relu=fuse_bn_relu, fuse_bn_add_relu=fuse_bn_add_relu,
                                    cudnn_tensor_core_only=force_tensor_core,
                                    bn_group=bn_group, local_gpus=local_gpus, local_comm=local_comm)
        for j in range(units[i]-1):
            body, nchw_shape = res_unit(rank, body, nchw_shape, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
                                        bottle_neck=bottle_neck, workspace=workspace,
                                        memonger=memonger, conv_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                                        verbose=verbose, cudnn_bn_off=cudnn_bn_off, bn_eps=bn_eps, bn_mom=bn_mom,
                                        conv_algo=conv_algo, fuse_bn_relu=fuse_bn_relu, fuse_bn_add_relu=fuse_bn_add_relu,
                                        cudnn_tensor_core_only=force_tensor_core,
                                        bn_group=bn_group, local_gpus=local_gpus, local_comm=local_comm)
    shape = nchw_shape[1:]
    # 'kernel' is ignored when global_pool=True, but the operator still requires a value.
    pool1 = pooling(data=body, io_layout=conv_layout, pooling_layout=pooling_layout,
                    global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
    flat = mx.sym.Flatten(data=pool1)
    shape = shape[0]  # feature count after global average pooling and flatten
    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', cublas_algo_verbose=verbose)
    shape = resnet_dense_log(shape, num_classes)

    if dtype == 'float16':
        fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)

    ##########################################################################
    # MXNet computes Cross Entropy loss gradients without explicitly computing
    # the value of loss function.
    # Take a look here:
    # https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.SoftmaxOutput
    # for further details
    ##########################################################################
    return mx.sym.SoftmaxOutput(data=fc1, name='softmax', smooth_alpha=label_smoothing)
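
# Illustrative usage (assumed, not from the original source): a standard
# ResNet-50 v1.5 would be built from this constructor roughly as follows;
# units=[3, 4, 6, 3] and filter_list=[64, 256, 512, 1024, 2048] are the
# conventional ResNet-50 stage sizes.
#
#   sym = resnet(rank=0, units=[3, 4, 6, 3], num_stages=4,
#                filter_list=[64, 256, 512, 1024, 2048], num_classes=1000,
#                image_shape=(3, 224, 224), batch_size=256, dtype='float16',
#                input_layout='NHWC', conv_layout='NHWC',
#                batchnorm_layout='NHWC', pooling_layout='NHWC')
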
def residual_unit(data,
                  shape,
                  num_filter,
                  stride,
                  dim_match,
                  name,
                  bottle_neck=True,
                  workspace=256,
                  memonger=False,
                  conv_layout='NCHW',
                  batchnorm_layout='NCHW',
                  verbose=False,
                  cudnn_bn_off=False,
                  bn_eps=2e-5,
                  bn_mom=0.9,
                  conv_algo=-1,
                  fuse_bn_relu=False,
                  fuse_bn_add_relu=False,
                  cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet
    Parameters
    ----------
    data : str
        Input data
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : Boolean
        True means channel number between input and output is the same, otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator
    """
    input_shape = shape
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
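        # Classic bottleneck: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand,
        # each followed by batchnorm; ReLU is fused into the BN kernel when
        # fuse_bn_relu is set, otherwise applied as a separate Activation.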
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        conv1 = mx.sym.Convolution(
            data=data,
            num_filter=int(num_filter * 0.25),
            kernel=(1, 1),
            stride=(1, 1),
            pad=(0, 0),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv1',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        bn1 = batchnorm(data=conv1,
                        io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False,
                        eps=bn_eps,
                        momentum=bn_mom,
                        name=name + '_bn1',
                        cudnn_off=cudnn_bn_off,
                        act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)

        act1 = bn1 if fuse_bn_relu else mx.sym.Activation(
            data=bn1, act_type='relu', name=name + '_relu1')
        shape = resnet_relu_log(shape)

        conv2 = mx.sym.Convolution(
            data=act1,
            num_filter=int(num_filter * 0.25),
            kernel=(3, 3),
            stride=stride,
            pad=(1, 1),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv2',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, stride, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        bn2 = batchnorm(data=conv2,
                        io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False,
                        eps=bn_eps,
                        momentum=bn_mom,
                        name=name + '_bn2',
                        cudnn_off=cudnn_bn_off,
                        act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)

        act2 = bn2 if fuse_bn_relu else mx.sym.Activation(
            data=bn2, act_type='relu', name=name + '_relu2')
        shape = resnet_relu_log(shape)

        conv3 = mx.sym.Convolution(
            data=act2,
            num_filter=num_filter,
            kernel=(1, 1),
            stride=(1, 1),
            pad=(0, 0),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv3',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        if dim_match:
            shortcut = data
        else:
            conv1sc = mx.sym.Convolution(
                data=data,
                num_filter=num_filter,
                kernel=(1, 1),
                stride=stride,
                no_bias=True,
                workspace=workspace,
                name=name + '_conv1sc',
                layout=conv_layout,
                cudnn_algo_verbose=verbose,
                cudnn_algo_fwd=conv_algo,
                cudnn_algo_bwd_data=conv_algo,
                cudnn_algo_bwd_filter=conv_algo,
                cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(input_shape, stride,
                                           int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc,
                                 io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False,
                                 eps=bn_eps,
                                 momentum=bn_mom,
                                 name=name + '_sc',
                                 cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape,
                                              momentum=bn_mom,
                                              eps=bn_eps,
                                              center=True,
                                              scale=True,
                                              training=True)
            proj_shape = resnet_projection_log(input_shape, proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
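        # batchnorm_add_relu fuses the final BN, the shortcut addition, and the
        # ReLU into a single kernel; the unfused path below runs BN, then an
        # explicit elementwise add + Activation.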
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape,
                                         momentum=bn_mom,
                                         eps=bn_eps,
                                         center=True,
                                         scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return batchnorm_add_relu(data=conv3,
                                      addend=shortcut,
                                      io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False,
                                      eps=bn_eps,
                                      momentum=bn_mom,
                                      name=name + '_bn3',
                                      cudnn_off=cudnn_bn_off), shape
        else:
            bn3 = batchnorm(data=conv3,
                            io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False,
                            eps=bn_eps,
                            momentum=bn_mom,
                            name=name + '_bn3',
                            cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape,
                                         momentum=bn_mom,
                                         eps=bn_eps,
                                         center=True,
                                         scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return mx.sym.Activation(data=bn3 + shortcut,
                                     act_type='relu',
                                     name=name + '_relu3'), shape

    else:
        raise NotImplementedError