def residual_unit_norm_conv(rank, data, nchw_inshape, num_filter, stride, dim_match,
                            name, bottle_neck=True, workspace=256, memonger=False,
                            conv_layout='NCHW', batchnorm_layout='NCHW', verbose=False,
                            cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
                            fuse_bn_relu=False, fuse_bn_add_relu=False,
                            cudnn_tensor_core_only=False, bn_group=1, local_gpus=None,
                            local_comm=None):
    """Return ResNet Unit symbol for building ResNet

    Parameters
    ----------
    rank : int
        Rank of this process (forwarded to the batchnorm helper)
    data : str
        Input data
    nchw_inshape : tuple of int
        Input minibatch shape in (n, c, h, w) format independent of actual layout
    num_filter : int
        Number of output channels
    stride : tuple
        Stride used in convolution
    dim_match : bool
        True means channel number between input and output is the same,
        otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, nchw_outshape)

    sym : the model symbol (up to this point)
    nchw_outshape : tuple of int
        Output shape in (batch_size, features, height, width) format
    """
    batch_size = nchw_inshape[0]
    shape = nchw_inshape[1:]
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape)
        # 1st NormConvolution: [no Stats-Apply] [no Relu] Convolution Stats-Gen
        (conv1, conv1_sum, conv1_sum_squares) = \
            mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                   num_filter=int(num_filter * 0.25),
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   name=name + '_conv1', layout=conv_layout)
        shape = resnet_conv2d_log(shape, 1, int(num_filter * 0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        # 2nd NormConvolution: Stats-Apply Relu Convolution Stats-Gen
        (conv2, conv2_sum, conv2_sum_squares) = \
            mx.sym.NormConvolution(conv1, no_norm=False, in_sum=conv1_sum,
                                   in_sum_squares=conv1_sum_squares, act_type='relu',
                                   num_filter=int(num_filter * 0.25),
                                   kernel=(3, 3), stride=stride, pad=(1, 1),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv2', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, stride, int(num_filter * 0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        # 3rd NormConvolution: Stats-Apply Relu Convolution [no Stats-Gen]
        # The 'no stats-gen' mode is triggered by simply not using the stats
        # outputs for anything.
        (conv3, conv3_sum, conv3_sum_squares) = \
            mx.sym.NormConvolution(conv2, no_norm=False, in_sum=conv2_sum,
                                   in_sum_squares=conv2_sum_squares, act_type='relu',
                                   num_filter=num_filter,
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv3', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
        elem_count = element_count((batch_size,) + shape)
        (bn3_equiv_scale, bn3_equiv_bias, bn3_saved_mean, bn3_saved_inv_std,
         bn3_gamma_out, bn3_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv3_sum, sum_squares=conv3_sum_squares,
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   output_mean_var=True, elem_count=elem_count,
                                   name=name + '_bn3')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)

        dbar = True
        if dim_match:
            shortcut = data
            dbar = False
        else:
            if fuse_bn_add_relu:
                # NormConvolution: [no Stats-Apply] [no Relu] Convolution Stats-Gen
                (shortcut, conv1sc_sum, conv1sc_sum_squares) = \
                    mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                           num_filter=int(num_filter),
                                           kernel=(1, 1), stride=stride, pad=(0, 0),
                                           name=name + '_conv1sc', layout=conv_layout)
                shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
                elem_count = element_count((batch_size,) + shape)
                (bn1sc_equiv_scale, bn1sc_equiv_bias, bn1sc_saved_mean,
                 bn1sc_saved_inv_std, bn1sc_gamma_out, bn1sc_beta_out) = \
                    mx.sym.BNStatsFinalize(sum=conv1sc_sum,
                                           sum_squares=conv1sc_sum_squares,
                                           eps=bn_eps, momentum=bn_mom,
                                           fix_gamma=False, output_mean_var=True,
                                           elem_count=elem_count,
                                           name=name + '_bn1sc')
                shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                             center=True, scale=True, training=True)
            else:
                dbar = False
                conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter,
                                             kernel=(1, 1), stride=stride,
                                             no_bias=True, workspace=workspace,
                                             name=name + '_conv1sc',
                                             layout=conv_layout,
                                             cudnn_algo_verbose=verbose,
                                             cudnn_algo_fwd=conv_algo,
                                             cudnn_algo_bwd_data=conv_algo,
                                             cudnn_algo_bwd_filter=conv_algo,
                                             cudnn_tensor_core_only=cudnn_tensor_core_only)
                proj_shape = resnet_conv2d_log(nchw_inshape[1:], stride,
                                               int(num_filter), False)
                shortcut = batchnorm(rank=rank, data=conv1sc, io_layout=conv_layout,
                                     batchnorm_layout=batchnorm_layout,
                                     fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                     name=name + '_bn_sc', cudnn_off=cudnn_bn_off)
                proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom,
                                                  eps=bn_eps, center=True,
                                                  scale=True, training=True)
                proj_shape = resnet_projection_log(nchw_inshape[1:], proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size,) + shape
            if dbar:
                return dual_scale_bias_add_relu(
                    data=conv3, addend=shortcut,
                    io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                    data_equiv_scale=bn3_equiv_scale, data_equiv_bias=bn3_equiv_bias,
                    data_saved_mean=bn3_saved_mean,
                    data_saved_inv_std=bn3_saved_inv_std,
                    data_gamma_out=bn3_gamma_out, data_beta_out=bn3_beta_out,
                    addend_equiv_scale=bn1sc_equiv_scale,
                    addend_equiv_bias=bn1sc_equiv_bias,
                    addend_saved_mean=bn1sc_saved_mean,
                    addend_saved_inv_std=bn1sc_saved_inv_std,
                    addend_gamma_out=bn1sc_gamma_out,
                    addend_beta_out=bn1sc_beta_out,
                    fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                    name=name + '_dbar3', cudnn_off=cudnn_bn_off), nchw_shape
            else:
                return scale_bias_add_relu(
                    data=conv3, addend=shortcut,
                    io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                    data_equiv_scale=bn3_equiv_scale, data_equiv_bias=bn3_equiv_bias,
                    data_saved_mean=bn3_saved_mean,
                    data_saved_inv_std=bn3_saved_inv_std,
                    data_gamma_out=bn3_gamma_out, data_beta_out=bn3_beta_out,
                    fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                    name=name + '_sbar3', cudnn_off=cudnn_bn_off), nchw_shape
        else:
            bn3 = batchnorm(rank=rank, data=conv3, io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                            name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size,) + shape
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu',
                                     name=name + '_relu3'), nchw_shape
    else:
        raise NotImplementedError
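
# A minimal usage sketch (an editor-hedged illustration, not part of the build):
# it chains two bottleneck units the way a stage builder might call
# residual_unit_norm_conv. The function name, unit names, batch size, and the
# (64, 56, 56) input shape are hypothetical; the fused NormConvolution /
# BNStatsFinalize operators used above exist only in NVIDIA's MXNet fork, so
# this sketch assumes that build.
def _example_norm_conv_stage(rank=0, batch_size=256):
    body = mx.sym.Variable(name='data')
    nchw_shape = (batch_size, 64, 56, 56)  # hypothetical (n, c, h, w) input
    # The first unit of a stage changes the channel count, so the shortcut is
    # projected (dim_match=False).
    body, nchw_shape = residual_unit_norm_conv(
        rank, body, nchw_shape, num_filter=256, stride=(1, 1), dim_match=False,
        name='stage1_unit1', fuse_bn_add_relu=True)
    # Subsequent units keep the channel count; the identity shortcut suffices.
    body, nchw_shape = residual_unit_norm_conv(
        rank, body, nchw_shape, num_filter=256, stride=(1, 1), dim_match=True,
        name='stage1_unit2', fuse_bn_add_relu=True)
    return body, nchw_shape
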
def residual_unit(data, shape, num_filter, stride, dim_match, name, bottle_neck=True,
                  workspace=256, memonger=False, conv_layout='NCHW',
                  batchnorm_layout='NCHW', verbose=False, cudnn_bn_off=False,
                  bn_eps=2e-5, bn_mom=0.9, conv_algo=-1, fuse_bn_relu=False,
                  fuse_bn_add_relu=False, cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet

    Parameters
    ----------
    data : str
        Input data
    shape : tuple of int
        Input shape in (c, h, w) format independent of actual layout
    num_filter : int
        Number of output channels
    stride : tuple
        Stride used in convolution
    dim_match : bool
        True means channel number between input and output is the same,
        otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, shape)

    sym : the model symbol (up to this point)
    shape : tuple of int
        Output shape in (features, height, width) format
    """
    input_shape = shape
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        conv1 = mx.sym.Convolution(data=data, num_filter=int(num_filter * 0.25),
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   no_bias=True, workspace=workspace,
                                   name=name + '_conv1', layout=conv_layout,
                                   cudnn_algo_verbose=verbose,
                                   cudnn_algo_fwd=conv_algo,
                                   cudnn_algo_bwd_data=conv_algo,
                                   cudnn_algo_bwd_filter=conv_algo,
                                   cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        bn1 = batchnorm(data=conv1, io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                        name=name + '_bn1', cudnn_off=cudnn_bn_off, act_type=act)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') \
            if not fuse_bn_relu else bn1
        shape = resnet_relu_log(shape)
        conv2 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.25),
                                   kernel=(3, 3), stride=stride, pad=(1, 1),
                                   no_bias=True, workspace=workspace,
                                   name=name + '_conv2', layout=conv_layout,
                                   cudnn_algo_verbose=verbose,
                                   cudnn_algo_fwd=conv_algo,
                                   cudnn_algo_bwd_data=conv_algo,
                                   cudnn_algo_bwd_filter=conv_algo,
                                   cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, stride, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        bn2 = batchnorm(data=conv2, io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                        name=name + '_bn2', cudnn_off=cudnn_bn_off, act_type=act)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') \
            if not fuse_bn_relu else bn2
        shape = resnet_relu_log(shape)
        conv3 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(1, 1),
                                   stride=(1, 1), pad=(0, 0), no_bias=True,
                                   workspace=workspace, name=name + '_conv3',
                                   layout=conv_layout, cudnn_algo_verbose=verbose,
                                   cudnn_algo_fwd=conv_algo,
                                   cudnn_algo_bwd_data=conv_algo,
                                   cudnn_algo_bwd_filter=conv_algo,
                                   cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        if dim_match:
            shortcut = data
        else:
            conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter,
                                         kernel=(1, 1), stride=stride, no_bias=True,
                                         workspace=workspace, name=name + '_conv1sc',
                                         layout=conv_layout,
                                         cudnn_algo_verbose=verbose,
                                         cudnn_algo_fwd=conv_algo,
                                         cudnn_algo_bwd_data=conv_algo,
                                         cudnn_algo_bwd_filter=conv_algo,
                                         cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(input_shape, stride, int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc, io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                 name=name + '_sc', cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom,
                                              eps=bn_eps, center=True, scale=True,
                                              training=True)
            proj_shape = resnet_projection_log(input_shape, proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return batchnorm_add_relu(data=conv3, addend=shortcut,
                                      io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                      name=name + '_bn3',
                                      cudnn_off=cudnn_bn_off), shape
        else:
            bn3 = batchnorm(data=conv3, io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                            name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu',
                                     name=name + '_relu3'), shape
    else:
        raise NotImplementedError
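
# A minimal usage sketch (an editor-hedged illustration, not part of the build):
# the same stage pattern built with the unfused residual_unit, which tracks
# shape as (c, h, w) rather than (n, c, h, w). The function name, unit names,
# and input shape below are hypothetical.
def _example_residual_stage():
    body = mx.sym.Variable(name='data')
    shape = (64, 56, 56)  # hypothetical (c, h, w) input shape
    # Projected shortcut for the channel-changing first unit ...
    body, shape = residual_unit(body, shape, num_filter=256, stride=(1, 1),
                                dim_match=False, name='stage1_unit1')
    # ... identity shortcut once input and output channels match.
    body, shape = residual_unit(body, shape, num_filter=256, stride=(1, 1),
                                dim_match=True, name='stage1_unit2')
    return body, shape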