def residual_unit_norm_conv(rank, data, nchw_inshape, num_filter, stride, dim_match, name,
                            bottle_neck=True, workspace=256, memonger=False,
                            conv_layout='NCHW', batchnorm_layout='NCHW',
                            verbose=False, cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9,
                            conv_algo=-1, fuse_bn_relu=False, fuse_bn_add_relu=False,
                            cudnn_tensor_core_only=False,
                            bn_group=1, local_gpus=None, local_comm=None):
    """Return ResNet Unit symbol for building ResNet

    Parameters
    ----------
    data : str
        Input data
    nchw_inshape : tuple of int
        Input minibatch shape in (n, c, h, w) format independent of actual layout
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : boolean
        True means channel number between input and output is the same, otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, nchw_outshape)

    sym : the model symbol (up to this point)

    nchw_outshape : tuple
        (batch_size, features, height, width)
    """
    batch_size = nchw_inshape[0]
    shape = nchw_inshape[1:]
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape)
        # 1st NormConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
        (conv1, conv1_sum, conv1_sum_squares) = \
            mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                   num_filter=int(num_filter*0.25),
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   name=name + '_conv1', layout=conv_layout)
        shape = resnet_conv2d_log(shape, 1, int(num_filter*0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)

        # 2nd NormConvolution: Stats-Apply Relu Convolution Stats-Gen
        (conv2, conv2_sum, conv2_sum_squares) = \
            mx.sym.NormConvolution(conv1, no_norm=False,
                                   in_sum=conv1_sum, in_sum_squares=conv1_sum_squares,
                                   act_type='relu', num_filter=int(num_filter*0.25),
                                   kernel=(3, 3), stride=stride, pad=(1, 1),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv2', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, stride, int(num_filter*0.25), False)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)

        # 3rd NormConvolution: Stats-Apply Relu Convolution [no Stats-Gen]
        # The 'no stats-gen' mode is triggered by just not using the stats outputs for anything.
        (conv3, conv3_sum, conv3_sum_squares) = \
            mx.sym.NormConvolution(conv2, no_norm=False,
                                   in_sum=conv2_sum, in_sum_squares=conv2_sum_squares,
                                   act_type='relu', num_filter=num_filter,
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   name=name + '_conv3', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
        elem_count = element_count((batch_size,) + shape)
        (bn3_equiv_scale, bn3_equiv_bias, bn3_saved_mean, bn3_saved_inv_std,
         bn3_gamma_out, bn3_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv3_sum, sum_squares=conv3_sum_squares,
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   output_mean_var=True, elem_count=elem_count,
                                   name=name + '_bn3')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)

        dbar = True
        if dim_match:
            shortcut = data
            dbar = False
        else:
            if fuse_bn_add_relu:
                # NormConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
                (shortcut, conv1sc_sum, conv1sc_sum_squares) = \
                    mx.sym.NormConvolution(data, no_norm=True, act_type=None,
                                           num_filter=int(num_filter),
                                           kernel=(1, 1), stride=stride, pad=(0, 0),
                                           name=name + '_conv1sc', layout=conv_layout)
                shape = resnet_conv2d_log(shape, 1, int(num_filter), False)
                elem_count = element_count((batch_size,) + shape)
                (bn1sc_equiv_scale, bn1sc_equiv_bias, bn1sc_saved_mean, bn1sc_saved_inv_std,
                 bn1sc_gamma_out, bn1sc_beta_out) = \
                    mx.sym.BNStatsFinalize(sum=conv1sc_sum, sum_squares=conv1sc_sum_squares,
                                           eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                           output_mean_var=True, elem_count=elem_count,
                                           name=name + '_bn1sc')
                shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                             center=True, scale=True, training=True)
            else:
                dbar = False
                conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1, 1),
                                             stride=stride, no_bias=True, workspace=workspace,
                                             name=name + '_conv1sc', layout=conv_layout,
                                             cudnn_algo_verbose=verbose,
                                             cudnn_algo_fwd=conv_algo,
                                             cudnn_algo_bwd_data=conv_algo,
                                             cudnn_algo_bwd_filter=conv_algo,
                                             cudnn_tensor_core_only=cudnn_tensor_core_only)
                proj_shape = resnet_conv2d_log(nchw_inshape[1:], stride, int(num_filter), False)
                shortcut = batchnorm(rank=rank, data=conv1sc, io_layout=conv_layout,
                                     batchnorm_layout=batchnorm_layout,
                                     fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                     name=name + '_bn_sc', cudnn_off=cudnn_bn_off)
                proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom, eps=bn_eps,
                                                  center=True, scale=True, training=True)
                proj_shape = resnet_projection_log(nchw_inshape[1:], proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            if dbar:
                return dual_scale_bias_add_relu(data=conv3, addend=shortcut,
                                                io_layout=conv_layout,
                                                batchnorm_layout=batchnorm_layout,
                                                data_equiv_scale=bn3_equiv_scale,
                                                data_equiv_bias=bn3_equiv_bias,
                                                data_saved_mean=bn3_saved_mean,
                                                data_saved_inv_std=bn3_saved_inv_std,
                                                data_gamma_out=bn3_gamma_out,
                                                data_beta_out=bn3_beta_out,
                                                addend_equiv_scale=bn1sc_equiv_scale,
                                                addend_equiv_bias=bn1sc_equiv_bias,
                                                addend_saved_mean=bn1sc_saved_mean,
                                                addend_saved_inv_std=bn1sc_saved_inv_std,
                                                addend_gamma_out=bn1sc_gamma_out,
                                                addend_beta_out=bn1sc_beta_out,
                                                fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                                name=name + '_dbar3',
                                                cudnn_off=cudnn_bn_off), nchw_shape
            else:
                return scale_bias_add_relu(data=conv3, addend=shortcut,
                                           io_layout=conv_layout,
                                           batchnorm_layout=batchnorm_layout,
                                           data_equiv_scale=bn3_equiv_scale,
                                           data_equiv_bias=bn3_equiv_bias,
                                           data_saved_mean=bn3_saved_mean,
                                           data_saved_inv_std=bn3_saved_inv_std,
                                           data_gamma_out=bn3_gamma_out,
                                           data_beta_out=bn3_beta_out,
                                           fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                           name=name + '_sbar3',
                                           cudnn_off=cudnn_bn_off), nchw_shape
        else:
            bn3 = batchnorm(rank=rank, data=conv3, io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                            name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu',
                                     name=name + '_relu3'), nchw_shape
    else:
        raise NotImplementedError
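
# A minimal usage sketch (editor's addition, not part of the original recipe):
# instantiating one fused bottleneck unit in isolation, e.g. to inspect the
# returned symbol and shape. The batch size, spatial dims, and unit name below
# are hypothetical, and the NormConvolution/BNStatsFinalize operators used
# above exist only in NVIDIA's MXNet fork, so that build is assumed.
def _example_norm_conv_unit():
    data = mx.sym.Variable(name='data')
    # A stage-1-style unit: 64 -> 256 channels with a projection shortcut.
    sym, nchw_out = residual_unit_norm_conv(
        rank=0, data=data, nchw_inshape=(32, 64, 56, 56), num_filter=256,
        stride=(1, 1), dim_match=False, name='demo_unit',
        conv_layout='NHWC', batchnorm_layout='NHWC', fuse_bn_add_relu=True)
    return sym, nchw_out   # nchw_out should come back as (32, 256, 56, 56)
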
def resnet(rank, units, num_stages, filter_list, num_classes, image_shape, batch_size,
           bottle_neck=True, workspace=256, dtype='float32', memonger=False,
           input_layout='NCHW', conv_layout='NCHW', batchnorm_layout='NCHW',
           pooling_layout='NCHW', verbose=False, cudnn_bn_off=False,
           bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
           fuse_bn_relu=False, fuse_bn_add_relu=False, force_tensor_core=False,
           use_dali=True, norm_conv=True, label_smoothing=0.0,
           bn_group=1, local_gpus=None, local_comm=None):
    """Return ResNet symbol.

    Parameters
    ----------
    units : list
        Number of units in each stage
    num_stages : int
        Number of stages
    filter_list : list
        Channel size of each stage
    num_classes : int
        Output size of symbol
    image_shape : tuple of int
        A 3-element tuple comprising (features, height, width) of each image
    batch_size : int
        The number of images in the training mini-batch
    dataset : str
        Dataset type, only cifar10 and imagenet are supported
    workspace : int
        Workspace used in convolution operator
    dtype : str
        Precision (float32 or float16)
    memonger : boolean
        Activates "memory monger" to reduce the model's memory footprint
    input_layout : str
        Interpretation (e.g. NCHW vs NHWC) of data provided by the i/o pipeline
        (may introduce transposes if in conflict with 'conv_layout' below)
    conv_layout : str
        Interpretation (e.g. NCHW vs NHWC) of data for convolution operation.
    batchnorm_layout : str
        Directs which kernel performs the batchnorm (may introduce transposes
        if in conflict with 'conv_layout' above)
    pooling_layout : str
        Directs which kernel performs the pooling (may introduce transposes
        if in conflict with 'conv_layout' above)
    """
    act = 'relu' if fuse_bn_relu else None
    num_unit = len(units)
    assert num_unit == num_stages
    data = mx.sym.Variable(name='data')
    if not use_dali:
        # double buffering of data
        if dtype == 'float32':
            data = mx.sym.identity(data=data, name='id')
        elif dtype == 'float16':
            data = mx.sym.Cast(data=data, dtype=np.float16)

    (nchannel, height, width) = image_shape

    # Insert transpose as needed to get the input layout to match the desired processing layout
    data = transform_layout(data, input_layout, conv_layout)

    res_unit = residual_unit_norm_conv
    shape = image_shape
    body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7),
                              stride=(2, 2), pad=(3, 3), no_bias=True, name="conv0",
                              workspace=workspace, layout=conv_layout,
                              cudnn_algo_verbose=verbose,
                              cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo,
                              cudnn_algo_bwd_filter=conv_algo,
                              cudnn_tensor_core_only=force_tensor_core)
    shape = resnet_conv2d_log(shape, 2, filter_list[0], False)
    body = batchnorm(rank=rank, data=body, io_layout=conv_layout,
                     batchnorm_layout=batchnorm_layout,
                     fix_gamma=False, eps=bn_eps, momentum=bn_mom, name='bn0',
                     cudnn_off=cudnn_bn_off, act_type=act,
                     bn_group=bn_group, local_gpus=local_gpus, local_comm=local_comm)
    shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                 center=True, scale=True, training=True)

    if not fuse_bn_relu:
        body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
    shape = resnet_relu_log(shape)
    body = pooling(data=body, io_layout=conv_layout, pooling_layout=pooling_layout,
                   kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')
    shape = resnet_max_pool_log(shape, 2)
    nchw_shape = (batch_size, ) + shape

    for i in range(num_stages):
        body, nchw_shape = res_unit(rank, body, nchw_shape, filter_list[i+1],
                                    (1 if i == 0 else 2, 1 if i == 0 else 2), False,
                                    name='stage%d_unit%d' % (i + 1, 1),
                                    bottle_neck=bottle_neck, workspace=workspace,
                                    memonger=memonger, conv_layout=conv_layout,
                                    batchnorm_layout=batchnorm_layout,
                                    verbose=verbose, cudnn_bn_off=cudnn_bn_off,
                                    bn_eps=bn_eps, bn_mom=bn_mom, conv_algo=conv_algo,
                                    fuse_bn_relu=fuse_bn_relu,
                                    fuse_bn_add_relu=fuse_bn_add_relu,
                                    cudnn_tensor_core_only=force_tensor_core,
                                    bn_group=bn_group, local_gpus=local_gpus,
                                    local_comm=local_comm)
        for j in range(units[i]-1):
            body, nchw_shape = res_unit(rank, body, nchw_shape, filter_list[i+1],
                                        (1, 1), True,
                                        name='stage%d_unit%d' % (i + 1, j + 2),
                                        bottle_neck=bottle_neck, workspace=workspace,
                                        memonger=memonger, conv_layout=conv_layout,
                                        batchnorm_layout=batchnorm_layout,
                                        verbose=verbose, cudnn_bn_off=cudnn_bn_off,
                                        bn_eps=bn_eps, bn_mom=bn_mom, conv_algo=conv_algo,
                                        fuse_bn_relu=fuse_bn_relu,
                                        fuse_bn_add_relu=fuse_bn_add_relu,
                                        cudnn_tensor_core_only=force_tensor_core,
                                        bn_group=bn_group, local_gpus=local_gpus,
                                        local_comm=local_comm)
    shape = nchw_shape[1:]

    # Although kernel is not used here when global_pool=True, we should put one
    pool1 = pooling(data=body, io_layout=conv_layout, pooling_layout=pooling_layout,
                    global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
    flat = mx.sym.Flatten(data=pool1)
    shape = (shape[0],)
    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1',
                                cublas_algo_verbose=verbose)
    shape = resnet_dense_log(shape, num_classes)
    if dtype == 'float16':
        fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)

    ##########################################################################
    # MXNet computes Cross Entropy loss gradients without explicitly computing
    # the value of loss function.
    # See https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.SoftmaxOutput
    # for further details.
    ##########################################################################
    return mx.sym.SoftmaxOutput(data=fc1, name='softmax', smooth_alpha=label_smoothing)
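
# A minimal usage sketch (editor's addition): building the full symbol with the
# standard ResNet-50 v1.5 configuration, 4 stages of [3, 4, 6, 3] bottleneck
# units. rank and batch_size are hypothetical example values; since resnet()
# always routes through residual_unit_norm_conv, NVIDIA's MXNet fork is assumed.
def _example_resnet50_symbol():
    return resnet(
        rank=0,
        units=[3, 4, 6, 3],                       # bottleneck units per stage
        num_stages=4,
        filter_list=[64, 256, 512, 1024, 2048],   # stem width, then per-stage channels
        num_classes=1000,
        image_shape=(3, 224, 224),
        batch_size=128)
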
def residual_unit(data, shape, num_filter, stride, dim_match, name, bottle_neck=True,
                  workspace=256, memonger=False,
                  conv_layout='NCHW', batchnorm_layout='NCHW',
                  verbose=False, cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9,
                  conv_algo=-1, fuse_bn_relu=False, fuse_bn_add_relu=False,
                  cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet

    Parameters
    ----------
    data : str
        Input data
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : boolean
        True means channel number between input and output is the same, otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator
    """
    input_shape = shape
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        conv1 = mx.sym.Convolution(data=data, num_filter=int(num_filter*0.25),
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   no_bias=True, workspace=workspace,
                                   name=name + '_conv1', layout=conv_layout,
                                   cudnn_algo_verbose=verbose,
                                   cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo,
                                   cudnn_algo_bwd_filter=conv_algo,
                                   cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter*0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        bn1 = batchnorm(data=conv1, io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                        name=name + '_bn1', cudnn_off=cudnn_bn_off, act_type=act)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') \
            if not fuse_bn_relu else bn1
        shape = resnet_relu_log(shape)
        conv2 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25),
                                   kernel=(3, 3), stride=stride, pad=(1, 1),
                                   no_bias=True, workspace=workspace,
                                   name=name + '_conv2', layout=conv_layout,
                                   cudnn_algo_verbose=verbose,
                                   cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo,
                                   cudnn_algo_bwd_filter=conv_algo,
                                   cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, stride, int(num_filter*0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        bn2 = batchnorm(data=conv2, io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                        name=name + '_bn2', cudnn_off=cudnn_bn_off, act_type=act)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') \
            if not fuse_bn_relu else bn2
        shape = resnet_relu_log(shape)
        conv3 = mx.sym.Convolution(data=act2, num_filter=num_filter,
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   no_bias=True, workspace=workspace,
                                   name=name + '_conv3', layout=conv_layout,
                                   cudnn_algo_verbose=verbose,
                                   cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo,
                                   cudnn_algo_bwd_filter=conv_algo,
                                   cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        if dim_match:
            shortcut = data
        else:
            conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1, 1),
                                         stride=stride, no_bias=True, workspace=workspace,
                                         name=name + '_conv1sc', layout=conv_layout,
                                         cudnn_algo_verbose=verbose,
                                         cudnn_algo_fwd=conv_algo,
                                         cudnn_algo_bwd_data=conv_algo,
                                         cudnn_algo_bwd_filter=conv_algo,
                                         cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(input_shape, stride, int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc, io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                 name=name + '_sc', cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom, eps=bn_eps,
                                              center=True, scale=True, training=True)
            proj_shape = resnet_projection_log(input_shape, proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return batchnorm_add_relu(data=conv3, addend=shortcut, io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                                      name=name + '_bn3', cudnn_off=cudnn_bn_off), shape
        else:
            bn3 = batchnorm(data=conv3, io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                            name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu',
                                     name=name + '_relu3'), shape
    else:
        raise NotImplementedError
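
# A minimal usage sketch (editor's addition): the unfused unit returns
# (symbol, chw_shape) just as the fused residual_unit_norm_conv returns
# (symbol, nchw_shape), so the two are drop-in alternatives in the stage loop
# of resnet() above. The shape and name below are hypothetical example values.
def _example_plain_unit():
    data = mx.sym.Variable(name='data')
    sym, out_shape = residual_unit(
        data=data, shape=(64, 56, 56), num_filter=256, stride=(1, 1),
        dim_match=False, name='demo_plain_unit',
        conv_layout='NCHW', batchnorm_layout='NCHW')
    return sym, out_shape   # out_shape should come back as (256, 56, 56)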