def get_rec_iter(args, trainpipes, valpipes, data_paths, kv=None): (rank, num_workers) = _get_rank_and_worker_count(args, kv) # now data is available in the provided paths to DALI, it ensures that the data has not been touched # user need to clean up the /tmp from the created symlinks # DALIClassificationIterator() does the init so we need to provide the real data here if args.dali_cache_size > 0 and args.lazy_init_sanity: link_to_tmp_file(args.data_train, data_paths["train_data_tmp"]) link_to_tmp_file(args.data_train_idx, data_paths["train_idx_tmp"]) link_to_tmp_file(args.data_val, data_paths["val_data_tmp"]) link_to_tmp_file(args.data_val_idx, data_paths["val_idx_tmp"]) mx_resnet_print(key=mlperf_constants.TRAIN_SAMPLES, val=args.num_examples) dali_train_iter = DALIClassificationIterator( trainpipes, args.num_examples // num_workers) if args.num_examples < trainpipes[0].epoch_size("Reader"): warnings.warn( "{} training examples will be used, although full training set contains {} examples" .format(args.num_examples, trainpipes[0].epoch_size("Reader"))) worker_val_examples = valpipes[0].epoch_size("Reader") mx_resnet_print(key=mlperf_constants.EVAL_SAMPLES, val=worker_val_examples) if not args.separ_val: worker_val_examples = worker_val_examples // num_workers if rank < valpipes[0].epoch_size("Reader") % num_workers: worker_val_examples += 1 dali_val_iter = DALIClassificationIterator( valpipes, worker_val_examples, fill_last_batch=False) if args.data_val else None return dali_train_iter, dali_val_iter
def _get_lr_scheduler(args, kv):
    """Build the learning-rate schedule from parsed CLI arguments.

    Returns
    -------
    (lr, scheduler) : tuple
        ``scheduler`` is None (constant LR), a PolyScheduler when
        ``args.lr_step_epochs`` matches ``"pow..."``, or a
        MultiFactorScheduler for comma-separated step epochs.
    """
    # No decay requested: lr_factor absent or >= 1 means LR never shrinks.
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    epoch_size = get_epoch_size(args, kv)
    begin_epoch = 0
    mx_resnet_print(key=mlperf_constants.OPT_BASE_LR, val=args.lr)
    mx_resnet_print(key=mlperf_constants.OPT_LR_WARMUP_EPOCHS,
                    val=args.warmup_epochs)
    # Polynomial decay: specs like "pow2" / "pow-2" encode the power.
    if 'pow' in args.lr_step_epochs:
        lr = args.lr
        max_up = args.num_epochs * epoch_size
        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
        return (lr, poly_sched)
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')] if len(
        args.lr_step_epochs.strip()) else []
    lr = args.lr
    # Pre-apply decay for steps already passed. begin_epoch is hard-coded
    # to 0 here, so this loop only matters if resume-from-epoch is added.
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d', lr,
                     begin_epoch)
    # Remaining decay points, expressed in update steps.
    steps = [
        epoch_size * (x - begin_epoch) for x in step_epochs
        if x - begin_epoch > 0
    ]
    if steps:
        if kv:
            num_workers = kv.num_workers
        else:
            num_workers = 1
        # NOTE(review): epoch_size is recomputed per-worker here for the
        # warmup length; presumably get_epoch_size() above differs — confirm.
        epoch_size = math.ceil(
            int(args.num_examples / num_workers) / args.batch_size)
        return (lr,
                mx.lr_scheduler.MultiFactorScheduler(
                    step=steps,
                    factor=args.lr_factor,
                    base_lr=args.lr,
                    warmup_steps=epoch_size * args.warmup_epochs,
                    warmup_mode=args.warmup_strategy))
    else:
        return (lr, None)
def __init__(self, batch_size, num_threads, device_id, rec_path, idx_path,
             shard_id, num_shards, crop_shape, nvjpeg_padding,
             prefetch_queue=3, seed=12, resize_shp=None,
             output_layout=types.NCHW, pad_output=True, dtype='float16',
             mlperf_print=True):
    """Validation DALI pipeline.

    Stages: MXNet RecordIO read -> mixed CPU/GPU JPEG decode ->
    optional shorter-edge resize -> center crop + mean/std normalize.

    Parameters
    ----------
    rec_path, idx_path : str
        Paths to the .rec data file and its .idx index.
    shard_id, num_shards : int
        This pipeline's shard of the dataset (no shuffling for eval).
    crop_shape : tuple
        Output (H, W) of the center crop.
    nvjpeg_padding : int
        Device/host memory padding (bytes) for the nvJPEG decoder.
    resize_shp : int or None
        Shorter-edge resize target; skipped entirely when falsy.
    pad_output : bool
        Pad RGB to a 4th channel when the model expects 4 channels.
    mlperf_print : bool
        Emit MLPerf compliance log entries (only the first GPU does this).
    """
    # Seed is offset by device_id so per-GPU pipelines are decorrelated.
    super(HybridValPipe, self).__init__(batch_size, num_threads, device_id,
                                        seed=seed + device_id,
                                        prefetch_queue_depth=prefetch_queue)
    self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
                                 random_shuffle=False, shard_id=shard_id,
                                 num_shards=num_shards)
    self.decode = ops.nvJPEGDecoder(device="mixed", output_type=types.RGB,
                                    device_memory_padding=nvjpeg_padding,
                                    host_memory_padding=nvjpeg_padding)
    # Resize op is only constructed when a target size is given.
    self.resize = ops.Resize(
        device="gpu", resize_shorter=resize_shp) if resize_shp else None
    self.cmnp = ops.CropMirrorNormalize(
        device="gpu",
        output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
        output_layout=output_layout,
        crop=crop_shape,
        pad_output=pad_output,
        image_type=types.RGB,
        mean=_mean_pixel,
        std=_std_pixel)
    if mlperf_print:
        mx_resnet_print(key=mlperf_log.INPUT_MEAN_SUBTRACTION,
                        val=_mean_pixel)
        mx_resnet_print(key=mlperf_log.INPUT_RESIZE_ASPECT_PRESERVING)
        mx_resnet_print(key=mlperf_log.INPUT_CENTRAL_CROP)
def get_rec_iter(args, trainpipes, valpipes, cvalpipes, kv=None):
    """Create DALI train/val/cval iterators from already-built pipelines.

    ``cvalpipes`` feed the non-augmented "cval" phase used once training
    reaches ``args.no_augument_epoch``; the cval iterator is None when that
    phase never happens.

    Returns
    -------
    (dali_train_iter, dali_val_iter, dali_cval_iter)
        dali_val_iter is None when args.data_val is not set.
    """
    rank = kv.rank if kv else 0
    nWrk = kv.num_workers if kv else 1
    # Each worker iterates over its 1/nWrk share of the training set.
    dali_train_iter = DALIClassificationIterator(trainpipes,
                                                 args.num_examples // nWrk)
    if args.no_augument_epoch < args.num_epochs:
        dali_cval_iter = DALIClassificationIterator(cvalpipes,
                                                    args.num_examples // nWrk)
    else:
        dali_cval_iter = None
    mx_resnet_print(key=mlperf_log.INPUT_SIZE,
                    val=trainpipes[0].epoch_size("Reader"))
    mx_resnet_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES,
                    val=trainpipes[0].epoch_size("Reader"))
    if args.data_val:
        mx_resnet_print(key=mlperf_log.EVAL_SIZE,
                        val=valpipes[0].epoch_size("Reader"))
        mx_resnet_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                        val=valpipes[0].epoch_size("Reader"))
    if args.num_examples < trainpipes[0].epoch_size("Reader"):
        warnings.warn(
            "{} training examples will be used, although full training set contains {} examples"
            .format(args.num_examples, trainpipes[0].epoch_size("Reader")))
    worker_val_examples = valpipes[0].epoch_size("Reader")
    if not args.separ_val:
        # Shard the validation set; the first (epoch_size % nWrk) ranks
        # take one extra example each.
        worker_val_examples = worker_val_examples // nWrk
        if rank < valpipes[0].epoch_size("Reader") % nWrk:
            worker_val_examples += 1
    dali_val_iter = DALIClassificationIterator(
        valpipes, worker_val_examples,
        fill_last_batch=False) if args.data_val else None
    return dali_train_iter, dali_val_iter, dali_cval_iter
def get_rec_iter(args, kv=None):
    """Build per-GPU DALI train/val pipelines and wrap them in iterators.

    Self-contained variant: constructs HybridTrainPipe/HybridValPipe for
    each GPU listed in ``args.gpus``, builds them, emits MLPerf log entries
    and returns ``(dali_train_iter, dali_val_iter)`` where the val iterator
    is None when ``args.data_val`` is not set.
    """
    # resize is default base length of shorter edge for dataset;
    # all images will be reshaped to this size
    resize = int(args.resize)
    # target shape is final shape of images pipelined to network;
    # all images will be cropped to this size
    target_shape = tuple([int(l) for l in args.image_shape.split(',')])
    # A leading channel count of 4 means the pipeline pads RGB with a 4th channel.
    pad_output = target_shape[0] == 4
    gpus = list(map(int, filter(None, args.gpus.split(','))))  # filter to not encount eventually empty strings
    # Per-GPU batch size: the CLI batch size is the per-worker total.
    batch_size = args.batch_size//len(gpus)
    mx_resnet_print(
        key=mlperf_log.INPUT_BATCH_SIZE,
        val=batch_size)  # TODO MPI WORLD SIZE
    num_threads = args.dali_threads
    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW
    rank = kv.rank if kv else 0
    nWrk = kv.num_workers if kv else 1
    # One training pipeline per GPU; shard ids are globally unique across
    # workers. Only the first GPU emits MLPerf log entries.
    trainpipes = [HybridTrainPipe(batch_size      = batch_size,
                                  num_threads     = num_threads,
                                  device_id       = gpu_id,
                                  rec_path        = args.data_train,
                                  idx_path        = args.data_train_idx,
                                  shard_id        = gpus.index(gpu_id) + len(gpus)*rank,
                                  num_shards      = len(gpus)*nWrk,
                                  crop_shape      = target_shape[1:],
                                  min_random_area = args.min_random_area,
                                  max_random_area = args.max_random_area,
                                  min_random_aspect_ratio = args.min_random_aspect_ratio,
                                  max_random_aspect_ratio = args.max_random_aspect_ratio,
                                  nvjpeg_padding  = args.dali_nvjpeg_memory_padding * 1024 * 1024,
                                  prefetch_queue  = args.dali_prefetch_queue,
                                  seed            = args.seed,
                                  output_layout   = output_layout,
                                  pad_output      = pad_output,
                                  dtype           = args.dtype,
                                  mlperf_print    = gpu_id == gpus[0]) for gpu_id in gpus]
    # Validation pipelines: with separ_val each worker evaluates the whole
    # set (single shard); otherwise eval is sharded like training.
    valpipes = [HybridValPipe(batch_size     = batch_size,
                              num_threads    = num_threads,
                              device_id      = gpu_id,
                              rec_path       = args.data_val,
                              idx_path       = args.data_val_idx,
                              shard_id       = 0 if args.separ_val else gpus.index(gpu_id) + len(gpus)*rank,
                              num_shards     = 1 if args.separ_val else len(gpus)*nWrk,
                              crop_shape     = target_shape[1:],
                              nvjpeg_padding = args.dali_nvjpeg_memory_padding * 1024 * 1024,
                              prefetch_queue = args.dali_prefetch_queue,
                              seed           = args.seed,
                              resize_shp     = resize,
                              output_layout  = output_layout,
                              pad_output     = pad_output,
                              dtype          = args.dtype,
                              mlperf_print   = gpu_id == gpus[0]) for gpu_id in gpus] if args.data_val else None
    # Building the first pipeline is enough to query epoch_size("Reader").
    trainpipes[0].build()
    if args.data_val:
        valpipes[0].build()
    mx_resnet_print(
        key=mlperf_log.INPUT_SIZE,
        val=trainpipes[0].epoch_size("Reader"))
    mx_resnet_print(
        key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES,
        val=trainpipes[0].epoch_size("Reader"))
    if args.data_val:
        mx_resnet_print(
            key=mlperf_log.EVAL_SIZE,
            val=valpipes[0].epoch_size("Reader"))
        mx_resnet_print(
            key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
            val=valpipes[0].epoch_size("Reader"))
    if args.num_examples < trainpipes[0].epoch_size("Reader"):
        warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
    dali_train_iter = DALIClassificationIterator(trainpipes, args.num_examples // nWrk)
    worker_val_examples = valpipes[0].epoch_size("Reader")
    if not args.separ_val:
        # Shard eval examples; first (epoch_size % nWrk) ranks get one extra.
        worker_val_examples = worker_val_examples // nWrk
        if rank < valpipes[0].epoch_size("Reader") % nWrk:
            worker_val_examples += 1
    dali_val_iter = DALIClassificationIterator(valpipes, worker_val_examples, fill_last_batch = False) if args.data_val else None
    return dali_train_iter, dali_val_iter
def build_input_pipeline(args, kv=None):
    """Construct per-GPU DALI pipelines and return a loader factory.

    Builds one HybridTrainPipe per GPU (and one HybridValPipe per GPU when
    ``args.data_val`` is set), eagerly builds all of them, and returns a
    callable ``(args, kv) -> iterators`` that defers iterator creation to
    :func:`get_rec_iter`.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments.
    kv : kvstore, optional
        Used to derive (rank, num_workers) for dataset sharding.

    Returns
    -------
    callable
        ``lambda args, kv: get_rec_iter(args, trainpipes, valpipes,
        data_paths, kv)``.
    """
    # resize is default base length of shorter edge for dataset;
    # all images will be reshaped to this size
    resize = int(args.resize)
    # target shape is final shape of images pipelined to network;
    # all images will be cropped to this size
    target_shape = tuple([int(l) for l in args.image_shape.split(',')])
    # A leading channel count of 4 means RGB output is padded with a 4th channel.
    pad_output = target_shape[0] == 4
    gpus = list(map(int, filter(None, args.gpus.split(
        ','))))  # filter to not encount eventually empty strings
    # Per-GPU batch size: the CLI batch size is the per-worker total.
    batch_size = args.batch_size // len(gpus)
    mx_resnet_print(key=mlperf_constants.MODEL_BN_SPAN, val=batch_size)
    num_threads = args.dali_threads
    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW
    (rank, num_workers) = _get_rank_and_worker_count(args, kv)
    # In DALI-cache + lazy-init sanity mode the pipelines are pointed at tmp
    # placeholder files; get_rec_iter() later symlinks the real data into
    # them, proving DALI did not touch the data during pipeline init.
    data_paths = {}
    if args.dali_cache_size > 0 and args.lazy_init_sanity:
        data_paths["train_data_tmp"] = get_tmp_file()
        data_paths["train_idx_tmp"] = get_tmp_file()
        data_paths["val_data_tmp"] = get_tmp_file()
        data_paths["val_idx_tmp"] = get_tmp_file()
    else:
        data_paths["train_data_tmp"] = args.data_train
        data_paths["train_idx_tmp"] = args.data_train_idx
        data_paths["val_data_tmp"] = args.data_val
        data_paths["val_idx_tmp"] = args.data_val_idx
    # One training pipeline per GPU; shard ids are globally unique across
    # workers. Only the first GPU emits MLPerf log entries.
    trainpipes = [
        HybridTrainPipe(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=gpu_id,
                        rec_path=data_paths["train_data_tmp"],
                        idx_path=data_paths["train_idx_tmp"],
                        shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                        num_shards=len(gpus) * num_workers,
                        crop_shape=target_shape[1:],
                        min_random_area=args.min_random_area,
                        max_random_area=args.max_random_area,
                        min_random_aspect_ratio=args.min_random_aspect_ratio,
                        max_random_aspect_ratio=args.max_random_aspect_ratio,
                        nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                        prefetch_queue=args.dali_prefetch_queue,
                        seed=args.seed,
                        output_layout=output_layout,
                        pad_output=pad_output,
                        dtype=args.dtype,
                        mlperf_print=gpu_id == gpus[0],
                        use_roi_decode=args.dali_roi_decode,
                        cache_size=args.dali_cache_size) for gpu_id in gpus
    ]
    # Validation pipelines: with separ_val each worker evaluates the whole
    # set (single shard); otherwise eval is sharded like training.
    valpipes = [
        HybridValPipe(batch_size=batch_size,
                      num_threads=num_threads,
                      device_id=gpu_id,
                      rec_path=data_paths["val_data_tmp"],
                      idx_path=data_paths["val_idx_tmp"],
                      shard_id=0 if args.separ_val
                               else gpus.index(gpu_id) + len(gpus) * rank,
                      num_shards=1 if args.separ_val else len(gpus) * num_workers,
                      crop_shape=target_shape[1:],
                      nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                      prefetch_queue=args.dali_prefetch_queue,
                      seed=args.seed,
                      resize_shp=resize,
                      output_layout=output_layout,
                      pad_output=pad_output,
                      dtype=args.dtype,
                      mlperf_print=gpu_id == gpus[0],
                      cache_size=args.dali_cache_size) for gpu_id in gpus
    ] if args.data_val else None
    # Plain loops instead of list comprehensions used only for their side
    # effects: clearer intent and no throwaway list of Nones.
    for trainpipe in trainpipes:
        trainpipe.build()
    if args.data_val:
        for valpipe in valpipes:
            valpipe.build()
    return lambda args, kv: get_rec_iter(args, trainpipes, valpipes,
                                         data_paths, kv)
def build_input_pipeline(args, kv=None):
    """Construct per-GPU DALI train/val/cval pipelines, return a loader factory.

    In addition to train and val pipelines, this variant optionally builds
    "cval" pipelines used for the non-augmented phase starting at
    ``args.no_augument_epoch``. Returns a callable
    ``(args, kv) -> iterators`` that defers to :func:`get_rec_iter`.
    """
    # resize is default base length of shorter edge for dataset;
    # all images will be reshaped to this size
    resize = int(args.resize)
    # target shape is final shape of images pipelined to network;
    # all images will be cropped to this size
    target_shape = tuple([int(l) for l in args.image_shape.split(',')])
    # A leading channel count of 4 means RGB output is padded with a 4th channel.
    pad_output = target_shape[0] == 4
    gpus = list(map(int, filter(None, args.gpus.split(
        ','))))  # filter to not encount eventually empty strings
    # Per-GPU batch size: the CLI batch size is the per-worker total.
    batch_size = args.batch_size // len(gpus)
    mx_resnet_print(key=mlperf_constants.MODEL_BN_SPAN, val=batch_size)
    num_threads = args.dali_threads
    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW
    rank = kv.rank if kv else 0
    nWrk = kv.num_workers if kv else 1
    # One training pipeline per GPU; shard ids are globally unique across
    # workers. Only the first GPU emits MLPerf log entries.
    trainpipes = [
        HybridTrainPipe(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=gpu_id,
                        rec_path=args.data_train,
                        idx_path=args.data_train_idx,
                        shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                        num_shards=len(gpus) * nWrk,
                        crop_shape=target_shape[1:],
                        min_random_area=args.min_random_area,
                        max_random_area=args.max_random_area,
                        min_random_aspect_ratio=args.min_random_aspect_ratio,
                        max_random_aspect_ratio=args.max_random_aspect_ratio,
                        nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                        prefetch_queue=args.dali_prefetch_queue,
                        seed=args.seed,
                        output_layout=output_layout,
                        pad_output=pad_output,
                        dtype=args.dtype,
                        mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
    ]
    # cval pipelines only exist when the non-augmented phase is reached.
    if args.no_augument_epoch >= args.num_epochs:
        cvalpipes = None
    else:
        # use_new_cval: reuse the train pipeline with a secondary set of
        # random-crop parameters; otherwise use a dedicated HybridCvalPipe
        # over the training data with a plain resize.
        cvalpipes = [
            HybridTrainPipe(
                batch_size=batch_size,
                num_threads=num_threads,
                device_id=gpu_id,
                rec_path=args.data_train,
                idx_path=args.data_train_idx,
                shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                num_shards=len(gpus) * nWrk,
                crop_shape=target_shape[1:],
                min_random_area=args.min_random_area_2,
                max_random_area=args.max_random_area_2,
                min_random_aspect_ratio=args.min_random_aspect_ratio_2,
                max_random_aspect_ratio=args.max_random_aspect_ratio_2,
                nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                prefetch_queue=args.dali_prefetch_queue,
                seed=args.seed,
                output_layout=output_layout,
                pad_output=pad_output,
                dtype=args.dtype,
                mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
        ] if args.use_new_cval else [
            HybridCvalPipe(batch_size=batch_size,
                           num_threads=num_threads,
                           device_id=gpu_id,
                           rec_path=args.data_train,
                           idx_path=args.data_train_idx,
                           shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                           num_shards=len(gpus) * nWrk,
                           crop_shape=target_shape[1:],
                           nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                           prefetch_queue=args.dali_prefetch_queue,
                           seed=args.seed,
                           resize_shp=resize,
                           output_layout=output_layout,
                           pad_output=pad_output,
                           dtype=args.dtype,
                           mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
        ]
    # Validation pipelines: with separ_val each worker evaluates the whole
    # set (single shard); otherwise eval is sharded like training.
    valpipes = [
        HybridValPipe(batch_size=batch_size,
                      num_threads=num_threads,
                      device_id=gpu_id,
                      rec_path=args.data_val,
                      idx_path=args.data_val_idx,
                      shard_id=0 if args.separ_val else
                      gpus.index(gpu_id) + len(gpus) * rank,
                      num_shards=1 if args.separ_val else len(gpus) * nWrk,
                      crop_shape=target_shape[1:],
                      nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                      prefetch_queue=args.dali_prefetch_queue,
                      seed=args.seed,
                      resize_shp=resize,
                      output_layout=output_layout,
                      pad_output=pad_output,
                      dtype=args.dtype,
                      mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
    ] if args.data_val else None
    # Building the first pipeline of each kind is enough for epoch_size().
    trainpipes[0].build()
    if args.no_augument_epoch < args.num_epochs:
        cvalpipes[0].build()
    if args.data_val:
        valpipes[0].build()
    return lambda args, kv: get_rec_iter(args, trainpipes, valpipes,
                                         cvalpipes, kv)
def residual_unit(data, shape, num_filter, stride, dim_match, name,
                  bottle_neck=True, workspace=256, memonger=False,
                  conv_layout='NCHW', batchnorm_layout='NCHW', verbose=False,
                  cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
                  fuse_bn_relu=False, fuse_bn_add_relu=False,
                  cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet

    Parameters
    ----------
    data : str
        Input data
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : Boolean
        True means channel number between input and output is the same,
        otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, shape) : the unit's output symbol and the logged output shape.
    """
    input_shape = shape
    # When BN+ReLU are fused, the activation is applied inside batchnorm().
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        # 1x1 reduce conv (bottleneck compresses to num_filter/4 channels).
        conv1 = mx.sym.Convolution(
            data=data, num_filter=int(num_filter * 0.25), kernel=(1, 1),
            stride=(1, 1), pad=(0, 0), no_bias=True, workspace=workspace,
            name=name + '_conv1', layout=conv_layout,
            cudnn_algo_verbose=verbose, cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        bn1 = batchnorm(data=conv1, io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout, fix_gamma=False,
                        eps=bn_eps, momentum=bn_mom, name=name + '_bn1',
                        cudnn_off=cudnn_bn_off, act_type=act)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        # Explicit ReLU only when it was not fused into the BN above.
        act1 = mx.sym.Activation(
            data=bn1, act_type='relu',
            name=name + '_relu1') if not fuse_bn_relu else bn1
        shape = resnet_relu_log(shape)
        # 3x3 conv carries the (possibly strided) spatial work.
        conv2 = mx.sym.Convolution(
            data=act1, num_filter=int(num_filter * 0.25), kernel=(3, 3),
            stride=stride, pad=(1, 1), no_bias=True, workspace=workspace,
            name=name + '_conv2', layout=conv_layout,
            cudnn_algo_verbose=verbose, cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, stride, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        bn2 = batchnorm(data=conv2, io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout, fix_gamma=False,
                        eps=bn_eps, momentum=bn_mom, name=name + '_bn2',
                        cudnn_off=cudnn_bn_off, act_type=act)
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        act2 = mx.sym.Activation(
            data=bn2, act_type='relu',
            name=name + '_relu2') if not fuse_bn_relu else bn2
        shape = resnet_relu_log(shape)
        # 1x1 expand conv restores the full channel count.
        conv3 = mx.sym.Convolution(
            data=act2, num_filter=num_filter, kernel=(1, 1), stride=(1, 1),
            pad=(0, 0), no_bias=True, workspace=workspace,
            name=name + '_conv3', layout=conv_layout,
            cudnn_algo_verbose=verbose, cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        if dim_match:
            # Identity shortcut when input/output shapes already match.
            shortcut = data
        else:
            # Projection shortcut: strided 1x1 conv + BN to match dims.
            conv1sc = mx.sym.Convolution(
                data=data, num_filter=num_filter, kernel=(1, 1),
                stride=stride, no_bias=True, workspace=workspace,
                name=name + '_conv1sc', layout=conv_layout,
                cudnn_algo_verbose=verbose, cudnn_algo_fwd=conv_algo,
                cudnn_algo_bwd_data=conv_algo,
                cudnn_algo_bwd_filter=conv_algo,
                cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(input_shape, stride,
                                           int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc, io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False, eps=bn_eps,
                                 momentum=bn_mom, name=name + '_sc',
                                 cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom,
                                              eps=bn_eps, center=True,
                                              scale=True, training=True)
            proj_shape = resnet_projection_log(input_shape, proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            # Fused BN + residual-add + ReLU in a single op.
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return batchnorm_add_relu(data=conv3, addend=shortcut,
                                      io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False, eps=bn_eps,
                                      momentum=bn_mom, name=name + '_bn3',
                                      cudnn_off=cudnn_bn_off), shape
        else:
            bn3 = batchnorm(data=conv3, io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                            name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu',
                                     name=name + '_relu3'), shape
    else:
        # Only the bottleneck variant is implemented in this file.
        raise NotImplementedError
def residual_unit_norm_conv(data, nchw_inshape, num_filter, stride, dim_match,
                            name, bottle_neck=True, workspace=256,
                            memonger=False, conv_layout='NCHW',
                            batchnorm_layout='NCHW', verbose=False,
                            cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9,
                            conv_algo=-1, fuse_bn_relu=False,
                            fuse_bn_add_relu=False,
                            cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet

    Variant of residual_unit that uses the fused NormalizedConvolution /
    BNStatsFinalize operators so each conv also produces the BN statistics
    the next stage needs.

    Parameters
    ----------
    data : str
        Input data
    nchw_inshape : tuple of int
        Input minibatch shape in (n, c, h, w) format independent of actual layout
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : Boolean
        True means channel number between input and output is the same,
        otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator

    Returns
    -------
    (sym, nchw_outshape)
        sym : the model symbol (up to this point)
        nchw_output : tuple ( batch_size, features, height, width)
    """
    batch_size = nchw_inshape[0]
    # Shape bookkeeping below is done on (c, h, w) only.
    shape = nchw_inshape[1:]
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        # 1st NormalizedConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
        (conv1, conv1_sum, conv1_sum_squares) = \
            mx.sym.NormalizedConvolution(data, no_equiv_scale_bias=True,
                                         act_type=None,
                                         num_filter=int(num_filter*0.25),
                                         kernel=(1,1), stride=(1,1),
                                         pad=(0,0), name=name + '_conv1',
                                         layout=conv_layout)
        shape = resnet_conv2d_log(shape, 1, int(num_filter*0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        # As prep for 2nd NormalizedConvolution: Finalize kernel converts
        # sum,sum_squares to equiv_scale,equiv_bias
        elem_count = element_count((batch_size,) + shape)
        (bn1_equiv_scale, bn1_equiv_bias, bn1_saved_mean, bn1_saved_inv_std,
         bn1_gamma_out, bn1_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv1_sum,
                                   sum_squares=conv1_sum_squares,
                                   eps=bn_eps, momentum=bn_mom,
                                   fix_gamma=False, output_mean_var=True,
                                   elem_count=elem_count, name=name + '_bn1')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        # Second NormalizedConvolution: Stats-Apply Relu Convolution Stats-Gen
        (conv2, conv2_sum, conv2_sum_squares) = \
            mx.sym.NormalizedConvolution(conv1, no_equiv_scale_bias=False,
                                         act_type='relu',
                                         num_filter=int(num_filter*0.25),
                                         kernel=(3,3), stride=stride,
                                         pad=(1,1),
                                         equiv_scale=bn1_equiv_scale,
                                         equiv_bias=bn1_equiv_bias,
                                         mean=bn1_saved_mean,
                                         var=bn1_saved_inv_std,
                                         gamma=bn1_gamma_out,
                                         beta=bn1_beta_out,
                                         name=name + '_conv2',
                                         layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, stride, int(num_filter*0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        # As prep for 3nd NormalizedConvolution: Finalize kernel converts
        # sum,sum_squares to equiv_scale,equiv_bias
        elem_count = element_count((batch_size,) + shape)
        (bn2_equiv_scale, bn2_equiv_bias, bn2_saved_mean, bn2_saved_inv_std,
         bn2_gamma_out, bn2_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv2_sum,
                                   sum_squares=conv2_sum_squares,
                                   eps=bn_eps, momentum=bn_mom,
                                   fix_gamma=False, output_mean_var=True,
                                   elem_count=elem_count, name=name + '_bn2')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                     center=True, scale=True, training=True)
        # Third NormalizedConvolution: Stats-Apply Relu Convolution [no Stats-Gen]
        # The 'no stats-gen' mode is triggered by just not using the stats
        # outputs for anything.
        (conv3, _, _) = \
            mx.sym.NormalizedConvolution(conv2, no_equiv_scale_bias=False,
                                         act_type='relu',
                                         num_filter=num_filter,
                                         kernel=(1,1), stride=(1,1),
                                         pad=(0,0),
                                         equiv_scale=bn2_equiv_scale,
                                         equiv_bias=bn2_equiv_bias,
                                         mean=bn2_saved_mean,
                                         var=bn2_saved_inv_std,
                                         gamma=bn2_gamma_out,
                                         beta=bn2_beta_out,
                                         name=name + '_conv3',
                                         layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)
        if dim_match:
            # Identity shortcut when input/output shapes already match.
            shortcut = data
        else:
            # Projection shortcut: plain strided 1x1 conv + BN to match dims.
            conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter,
                                         kernel=(1,1), stride=stride,
                                         no_bias=True, workspace=workspace,
                                         name=name+'_conv1sc',
                                         layout=conv_layout,
                                         cudnn_algo_verbose=verbose,
                                         cudnn_algo_fwd=conv_algo,
                                         cudnn_algo_bwd_data=conv_algo,
                                         cudnn_algo_bwd_filter=conv_algo,
                                         cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(nchw_inshape[1:], stride,
                                           int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc, io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False, eps=bn_eps,
                                 momentum=bn_mom, name=name + '_bn_sc',
                                 cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom,
                                              eps=bn_eps, center=True,
                                              scale=True, training=True)
            proj_shape = resnet_projection_log(nchw_inshape[1:], proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            # Fused BN + residual-add + ReLU in a single op.
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            return batchnorm_add_relu(data=conv3, addend=shortcut,
                                      io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False, eps=bn_eps,
                                      momentum=bn_mom, name=name + '_bn3',
                                      cudnn_off=cudnn_bn_off), nchw_shape
        else:
            bn3 = batchnorm(data=conv3, io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom,
                            name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps,
                                         center=True, scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu',
                                     name=name + '_relu3'), nchw_shape
    else:
        # Only the bottleneck variant is implemented in this file.
        raise NotImplementedError
def fit(args, kv, model, initializer, data_loader, devs, arg_params,
        aux_params, **kwargs):
    """ train a model
    args : argparse returns
    model : loaded model of the neural network
    initializer : weight initializer
    data_loader : function that returns the train and val data iterators
    devs : devices for training
    arg_params : model parameters
    aux_params : model parameters
    """
    # Optional MXNet profiling of the kvstore server and/or this worker.
    if args.profile_server_suffix:
        mx.profiler.set_config(filename=args.profile_server_suffix,
                               profile_all=True, profile_process='server')
        mx.profiler.set_state(state='run', profile_process='server')
    if args.profile_worker_suffix:
        if kv.num_workers > 1:
            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
        else:
            filename = args.profile_worker_suffix
        mx.profiler.set_config(filename=filename, profile_all=True,
                               profile_process='worker')
        mx.profiler.set_state(state='run', profile_process='worker')

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    epoch_size = get_epoch_size(args, kv)

    # save model
    epoch_end_callbacks = []

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # Global step budget for this worker's optimizer/explorer schedules.
    total_steps = math.ceil(args.num_examples * args.num_epochs /
                            kv.num_workers / args.batch_size)
    warmup_steps = get_epoch_size(args, kv) * args.warmup_epochs
    if args.decay_steps > 0:
        decay_steps = args.decay_steps
        decay_epochs = math.ceil(args.decay_steps / get_epoch_size(args, kv))
    else:
        if args.decay_after_warmup:
            decay_steps = total_steps - warmup_steps
            decay_epochs = args.num_epochs - args.warmup_epochs
        else:
            decay_steps = total_steps
            decay_epochs = args.num_epochs

    # Hyper-parameter exploration config consumed by Explorer.create_explorer
    # inside mlperf_fit(); keys mirror the CLI flags of the same names.
    explorer_params = {
        'burn_in_iter': args.burn_in,
        'lr_range_max': args.lr_range_max,
        'lr_range_min': args.lr_range_min,
        'wd_range_max': args.wd_range_max,
        'wd_range_min': args.wd_range_min,
        'cg_range_max': args.cg_range_max,
        'cg_range_min': args.cg_range_min,
        'start_epoch': 0,
        'end_epoch': args.num_epochs,
        'lr_decay': args.lr_decay,
        # -1 means "inherit lr_decay" for the weight-decay schedule.
        'wd_decay': args.wd_decay if args.wd_decay != -1 else args.lr_decay,
        'num_grps': args.num_grps,
        'num_cg_grps': args.num_cg_grps,
        'lr_decay_mode': args.lr_decay_mode,
        'wd_decay_mode': args.wd_decay_mode,
        'wd_step_epochs':
            [int(s.strip()) for s in args.wd_step_epochs.split(',')]
            if len(args.wd_step_epochs.strip()) else [],
        'wd_factor': args.wd_factor,
        'lr_rate': args.lr,
        'warmup_step': warmup_steps,
        'warmup_epochs': args.warmup_epochs,
        'wd_warmup': args.wd_warmup,
        'ds_upper_factor': args.ds_upper_factor,
        'ds_lower_factor': args.ds_lower_factor,
        'ds_fix_min': args.ds_fix_min,
        'ds_fix_max': args.ds_fix_max,
        'explore_freq': args.explore_freq,
        'natan_turn_epoch': args.natan_turn_epoch,
        'natan_final_ratio': args.natan_final_ratio,
        'explore_start_epoch': args.explore_start_epoch,
        'momentum': args.mom,
        'momentum_end': args.mom_end if args.mom_end else args.mom,
        'epoch_size': epoch_size,
        'smooth_decay': args.smooth_decay,
        # Unset phase boundaries default to num_epochs (phase never starts).
        'add_one_fwd_epoch': args.add_one_fwd_epoch
            if args.add_one_fwd_epoch is not None else args.num_epochs,
        'no_augument_epoch': args.no_augument_epoch
            if args.no_augument_epoch is not None else args.num_epochs,
        'decay_after_warmup': args.decay_after_warmup,
        'end_lr_ratio': args.end_lr_ratio,
        'total_steps': total_steps,
        'decay_steps': decay_steps,
        'decay_epochs': decay_epochs,
    }

    optimizer_params = {
        'learning_rate': lr,
        # NOTE(review): wd is scaled by lr here — presumably the LARS
        # optimizer expects this coupling; confirm against the optimizer.
        'wd': args.wd * args.lr,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True
    }

    mx_resnet_print(key=mlperf_constants.OPT_NAME,
                    val='lars')  #args.optimizer
    mx_resnet_print(key=mlperf_constants.LARS_EPSILON, val=1e-9)
    mx_resnet_print(key=mlperf_constants.LARS_OPT_WEIGHT_DECAY, val=args.wd)
    mx_resnet_print(key=mlperf_constants.LARS_OPT_LR_DECAY_POLY_POWER,
                    val=args.lr_decay)
    mx_resnet_print(key=mlperf_constants.LARS_OPT_END_LR,
                    val=args.lr * args.end_lr_ratio)
    mx_resnet_print(key=mlperf_constants.LARS_OPT_LR_DECAY_STEPS,
                    val=decay_steps)

    ##########################################################################
    # MXNet excludes BN layers from L2 Penalty by default,
    # so this won't be explicitly stated anywhere in the code
    ##########################################################################
    mx_resnet_print(key=mlperf_log.MODEL_EXCLUDE_BN_FROM_L2, val=True)

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom
    if args.optimizer == 'sgd':
        optimizer_params['bias_wd'] = args.bias_wd
        optimizer_params['bn_lr_decay'] = args.bn_lr_decay

    ### copy from nvidia-mxnet/3rdparty/horovod/example/mxnet/common/fit.py
    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        nworkers = kv.num_workers
        epoch_size = args.num_examples / args.batch_size / nworkers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        #batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        #optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs
    ###

    # evaluation metrices
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(
            mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = []
    if 'horovod' in args.kv_store:
        # if using horovod, only report on rank 0 with global batch size
        if kv.rank == 0:
            batch_end_callbacks.append(
                mx.callback.Speedometer(kv.num_workers * args.batch_size,
                                        args.disp_batches))
            mx_resnet_print(key=mlperf_constants.GLOBAL_BATCH_SIZE,
                            val=kv.num_workers * args.batch_size)
    else:
        batch_end_callbacks.append(
            mx.callback.Speedometer(args.batch_size, args.disp_batches))
        mx_resnet_print(key=mlperf_constants.GLOBAL_BATCH_SIZE,
                        val=args.batch_size)

    mx_resnet_print(key=mlperf_log.EVAL_TARGET, val=args.accuracy_threshold)

    # run
    last_epoch = mlperf_fit(
        model,
        args,
        data_loader,
        epoch_size,
        begin_epoch=0,
        num_epoch=args.num_epochs,
        eval_metric=eval_metrics,
        kvstore=kv,
        optimizer=args.optimizer,
        optimizer_params=optimizer_params,
        explorer=args.explorer,
        explorer_params=explorer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        batch_end_callback=batch_end_callbacks,
        epoch_end_callback=
        epoch_end_callbacks,  #checkpoint if args.use_dali else ,,
        allow_missing=True,
        eval_offset=args.eval_offset,
        eval_period=args.eval_period,
        accuracy_threshold=args.accuracy_threshold)

    # Only rank 0 (or every worker under non-horovod kvstore) saves weights.
    if ('horovod' not in args.kv_store) or kv.rank == 0:
        arg_params, aux_params = model.get_params()
        model.set_params(arg_params, aux_params)
        model.save_checkpoint('MLPerf-RN50v15', last_epoch,
                              save_optimizer_states=False)

    # When using horovod, ensure all ops scheduled by the engine complete before exiting
    if 'horovod' in args.kv_store:
        mx.ndarray.waitall()

    # NOTE(review): state='run' is passed again at shutdown; to flush the
    # profile one would expect state='stop' — confirm intended behavior.
    if args.profile_server_suffix:
        mx.profiler.set_state(state='run', profile_process='server')
    if args.profile_worker_suffix:
        mx.profiler.set_state(state='run', profile_process='worker')
def mlperf_fit(self,
               args,
               data_loader,
               epoch_size,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               explorer='linear',
               explorer_params=None,
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):
    """MLPerf-instrumented training loop with explorer-driven LR/WD tuning.

    Trains this module for up to ``num_epoch`` epochs, emitting the MLPerf
    log markers (INIT_STOP, RUN_START, EPOCH_START/STOP, EVAL_START/STOP,
    BLOCK_START/STOP, RUN_STOP) and evaluating every ``eval_period`` epochs
    starting at ``eval_offset``.

    Beyond the standard Module.fit-style arguments:
        args: parsed CLI namespace (reads args.kv_store, args.use_dali).
        data_loader: callable ``(args, kvstore) -> (train, eval, cval)``
            returning the augmented train iterator, the validation iterator
            and the non-augmented ("cval") train iterator.
        epoch_size: batches per machine per epoch (used by ResizeIter).
        explorer / explorer_params: hyper-parameter explorer attached to
            the optimizer; explorer_params must provide 'explore_freq',
            'explore_start_epoch', 'add_one_fwd_epoch', 'no_augument_epoch'
            and 'smooth_decay'.
        accuracy_threshold: stop with RUN_STOP status 'success' once the
            all-reduced validation accuracy exceeds this value.

    Returns:
        The epoch index at which ``accuracy_threshold`` was reached, or
        ``num_epoch`` if the run finished without reaching it ('aborted').
    """
    assert num_epoch is not None, 'please specify number of epochs'

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    # This mxnet can not use several optimizers without sgd series.
    explorer = Explorer.create_explorer(name=explorer,
                                        optimizer=self._optimizer,
                                        explorer_params=explorer_params)
    explorer.set_best_coeff(0)
    explorer.set_best_wd_coeff(0)
    explorer.set_best_cg(0)
    exp_freq = explorer_params['explore_freq']
    exp_start_epoch = explorer_params['explore_start_epoch']

    if validation_metric is None:
        validation_metric = eval_metric

    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]
    validation_metric = mx.metric.create(validation_metric)
    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm
    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    # Rank/size from Open MPI when launched under mpirun; single-process
    # defaults otherwise.  (Was a bare `except:`; narrowed to the two
    # exceptions the lookup/conversion can actually raise.)
    try:
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except (KeyError, ValueError):
        world_rank = 0
        world_size = 1

    # Whether any epoch will run on the non-augmented (cval) pipeline.
    use_cval_data = \
        explorer_params['add_one_fwd_epoch'] < num_epoch \
        or explorer_params['no_augument_epoch'] < num_epoch

    self.prepare_states()

    mx_resnet_print(key=mlperf_constants.INIT_STOP, sync=True)
    mx_resnet_print(key=mlperf_constants.RUN_START, sync=True)

    # data iterators
    (train_data, eval_data, cval_data) = data_loader(args, kvstore)
    if 'dist' in args.kv_store and not 'async' in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # Resize train iter so every machine has the same number of batches
        # per epoch; otherwise dist_sync can hang at the end with one machine
        # waiting for the others.
        if not args.use_dali:
            # Fix: the resized iterator was previously bound to an unused
            # local (`train`), so the resize never took effect.
            train_data = mx.io.ResizeIter(train_data, epoch_size)

    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_period)
    if block_epoch_count < 0:
        block_epoch_count += eval_period
    mx_resnet_print(key=mlperf_constants.BLOCK_START,
                    metadata={
                        'first_epoch_num': block_epoch_start + 1,
                        'epoch_count': block_epoch_count
                    })

    ############################################################################
    # training loop
    ############################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_constants.EPOCH_START,
                        metadata={'epoch_num': epoch + 1})
        tic = time.time()
        eval_metric.reset()
        nbatch = 0

        # After 'no_augument_epoch' epochs, switch to the non-augmented
        # (cval) pipeline for the remaining epochs.
        use_normal_data_batch = epoch < explorer_params['no_augument_epoch']
        if not use_normal_data_batch:
            if world_rank == 0:
                self.logger.info('use non-augumented batch')

        end_of_batch = False
        if use_normal_data_batch:
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
        else:
            cval_iter = iter(cval_data)
            next_cval_batch = next(cval_iter)

        smooth_decay = explorer_params['smooth_decay']
        if not smooth_decay:
            # Step-wise decay applied once per epoch.
            explorer.apply_lr_decay_epoch(epoch)
            explorer.apply_wd_decay_epoch(epoch)
        explorer.set_mom(epoch)

        while not end_of_batch:
            if use_normal_data_batch:
                data_batch = next_data_batch
            else:
                cval_batch = next_cval_batch

            if monitor is not None:
                monitor.tic()

            if use_normal_data_batch:
                self.forward_backward(data_batch)
            else:
                self.forward_backward(cval_batch)

            if smooth_decay:
                # Smooth decay applied once per iteration instead.
                explorer.apply_lr_decay_iter()
                explorer.apply_wd_decay_iter()
            explorer.apply_wd_warmup()
            explorer.apply_burn_in()

            # Periodically explore new LR/WD coefficients (always on the
            # very first iteration, then every exp_freq batches once
            # exp_start_epoch is reached).
            use_explorer = (epoch == 0 and nbatch == 0) or \
                           (epoch >= exp_start_epoch and
                            nbatch % exp_freq == 0)
            if use_explorer:
                explorer.set_tmp_coeff(world_rank)
                explorer.set_tmp_wd_coeff(world_rank)
                explorer.set_tmp_cg(world_rank)
                explorer.set_best_coeff(0)
                explorer.set_best_wd_coeff(0)
                explorer.set_best_cg(world_rank)

            self.update()

            if use_normal_data_batch:
                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)
            else:
                if isinstance(cval_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in cval_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, cval_batch.label)

            # Pre-fetch the next batch and let the module prepare for it.
            if use_normal_data_batch:
                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True
                if not end_of_batch:
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
            else:
                try:
                    next_cval_batch = next(cval_iter)
                except StopIteration:
                    end_of_batch = True
                if not end_of_batch:
                    self.prepare(next_cval_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        mx_resnet_print(key=mlperf_constants.EPOCH_STOP,
                        metadata={"epoch_num": epoch + 1})

        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        # ----------------------------------------
        # evaluation on validation set
        if eval_data and epoch >= eval_offset and \
                (epoch - eval_offset) % eval_period == 0:
            mx_resnet_print(key=mlperf_constants.EVAL_START,
                            metadata={'epoch_num': epoch + 1})
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            # TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f',
                                         epoch, name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f',
                                     epoch, name, val)

            # Global accuracy = all-reduced correct count / total count.
            res = dict(res)
            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]

            mx_resnet_print(key=mlperf_constants.EVAL_STOP,
                            metadata={'epoch_num': epoch + 1})
            mx_resnet_print(key=mlperf_constants.EVAL_ACCURACY,
                            val=acc,
                            metadata={'epoch_num': epoch + 1})
            mx_resnet_print(
                key=mlperf_constants.BLOCK_STOP,
                metadata={'first_epoch_num': block_epoch_start + 1})

            if acc > accuracy_threshold:
                mx_resnet_print(key=mlperf_constants.RUN_STOP,
                                metadata={'status': 'success'})
                return epoch

            # Open the next MLPerf block for the remaining epochs.
            if epoch < (num_epoch - 1):
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_period:
                    block_epoch_count = eval_period
                mx_resnet_print(key=mlperf_constants.BLOCK_START,
                                metadata={
                                    'first_epoch_num': block_epoch_start + 1,
                                    'epoch_count': block_epoch_count
                                })

        # end of 1 epoch, reset the data-iter for another epoch
        if use_normal_data_batch:
            train_data.reset()
        else:
            cval_data.reset()

    # Target accuracy never reached.
    mx_resnet_print(key=mlperf_constants.RUN_STOP,
                    metadata={'status': 'aborted'})
    return num_epoch
) args = parser.parse_args() # select gpu for horovod process if 'horovod' in args.kv_store: args.gpus = _get_gpu(args.gpus) # kvstore kv = mx.kvstore.create(args.kv_store) # load network from importlib import import_module net = import_module('symbols.'+args.network) mx_resnet_print(key=mlperf_log.EVAL_EPOCH_OFFSET, val=args.eval_offset) mx_resnet_print(key=mlperf_log.RUN_START, sync=True) if args.seed is None: args.seed = int(random.SystemRandom().randint(0, 2**16 - 1)) if 'horovod' in args.kv_store: all_seeds = np.random.randint(2**16, size=(int(os.environ['OMPI_COMM_WORLD_SIZE']))) args.seed = int(all_seeds[int(os.environ['OMPI_COMM_WORLD_RANK'])]) mx_resnet_print(key=mlperf_log.RUN_SET_RANDOM_SEED, val=args.seed, uniq=False) random.seed(args.seed) np.random.seed(args.seed) mx.random.seed(args.seed)
def fit(args, kv, network, data_loader, **kwargs):
    """Train a model.

    args : argparse namespace with the training options
    kv : kvstore used for device/machine synchronization
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # Optional MXNet profiler setup for the parameter-server process.
    if args.profile_server_suffix:
        mx.profiler.set_config(filename=args.profile_server_suffix,
                               profile_all=True,
                               profile_process='server')
        mx.profiler.set_state(state='run', profile_process='server')

    # Optional profiler setup for this worker; file name is rank-prefixed
    # when several workers would otherwise collide on one file.
    if args.profile_worker_suffix:
        if kv.num_workers > 1:
            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
        else:
            filename = args.profile_worker_suffix
        mx.profiler.set_config(filename=filename,
                               profile_all=True,
                               profile_process='worker')
        mx.profiler.set_state(state='run', profile_process='worker')

    # logging: prefix every record with the kvstore rank of this node
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    epoch_size = get_epoch_size(args, kv)

    # data iterators
    (train, val) = data_loader(args, kv)
    if 'dist' in args.kv_store and not 'async' in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # resize train iter to ensure each machine has same number of batches per epoch
        # if not, dist_sync can hang at the end with one machine waiting for other machines
        if not args.use_dali:
            train = mx.io.ResizeIter(train, epoch_size)

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        # NOTE(review): `sym` is always None here, so the assert below is
        # dead code — presumably a remnant of checkpoint loading; confirm.
        sym, arg_params, aux_params = None, None, None
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    epoch_end_callbacks = []

    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # create model
    model = mx.mod.Module(context=devs, symbol=network)

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True
    }

    mx_resnet_print(key=mlperf_log.OPT_NAME, val=args.optimizer)
    mx_resnet_print(key=mlperf_log.OPT_LR, val=lr)
    mx_resnet_print(key=mlperf_log.OPT_MOMENTUM, val=args.mom)
    mx_resnet_print(key=mlperf_log.MODEL_L2_REGULARIZATION, val=args.wd)
    ##########################################################################
    # MXNet excludes BN layers from L2 Penalty by default,
    # so this won't be explicitly stated anywhere in the code
    ##########################################################################
    mx_resnet_print(key=mlperf_log.MODEL_EXCLUDE_BN_FROM_L2, val=True)

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(
            mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = []
    if 'horovod' in args.kv_store:
        # if using horovod, only report on rank 0 with global batch size
        if kv.rank == 0:
            batch_end_callbacks.append(
                mx.callback.Speedometer(kv.num_workers * args.batch_size,
                                        args.disp_batches))
    else:
        batch_end_callbacks.append(
            mx.callback.Speedometer(args.batch_size, args.disp_batches))

    mx_resnet_print(key=mlperf_log.EVAL_TARGET, val=args.accuracy_threshold)

    # run; returns the epoch at which training stopped
    last_epoch = mlperf_fit(
        model,
        train,
        begin_epoch=0,
        num_epoch=args.num_epochs,
        eval_data=val,
        eval_metric=eval_metrics,
        kvstore=kv,
        optimizer=args.optimizer,
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        batch_end_callback=batch_end_callbacks,
        epoch_end_callback=epoch_end_callbacks,
        allow_missing=True,
        eval_offset=args.eval_offset,
        eval_period=args.eval_period,
        accuracy_threshold=args.accuracy_threshold)

    mx_resnet_print(key=mlperf_log.RUN_STOP, sync=True)

    # Only one process writes the checkpoint under horovod.
    if ('horovod' not in args.kv_store) or kv.rank == 0:
        model.save_checkpoint('MLPerf-RN50v15',
                              last_epoch,
                              save_optimizer_states=False)

    # When using horovod, ensure all ops scheduled by the engine complete before exiting
    if 'horovod' in args.kv_store:
        mx.ndarray.waitall()

    # NOTE(review): state='run' at teardown looks like it should be 'stop'
    # to end profiling — confirm against the mx.profiler documentation.
    if args.profile_server_suffix:
        mx.profiler.set_state(state='run', profile_process='server')
    if args.profile_worker_suffix:
        mx.profiler.set_state(state='run', profile_process='worker')
def mlperf_fit(self,
               train_data,
               eval_data=None,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):
    """Train this module on `train_data`, MLPerf-style.

    Mirrors the shape of mx.mod.BaseModule.fit, with MLPerf log markers
    added and an early return as soon as the all-reduced validation
    accuracy exceeds `accuracy_threshold`.

    Returns the epoch index at which `accuracy_threshold` was first
    exceeded, or `num_epoch` if it never was.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]
    validation_metric = mx.metric.create(validation_metric)
    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm
    # Correct/total counts are what the accuracy all-reduce below consumes.
    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    mx_resnet_print(key=mlperf_log.TRAIN_LOOP)
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_log.TRAIN_EPOCH, val=epoch)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        # NOTE(review): this schedule assumes eval_offset < eval_period
        # (otherwise the condition can never hold); the sibling variant
        # uses `epoch >= eval_offset and (epoch - eval_offset) %
        # eval_period == 0` — confirm which is intended.
        if eval_data and epoch % eval_period == eval_offset:
            mx_resnet_print(key=mlperf_log.EVAL_START)
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            # TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f',
                                         epoch, name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                     name, val)

            # Global accuracy = all-reduced correct count / total count.
            res = dict(res)
            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_log.EVAL_ACCURACY,
                            val={
                                "epoch": epoch,
                                "value": acc
                            })
            mx_resnet_print(key=mlperf_log.EVAL_STOP)
            # Early exit once the target accuracy is reached.
            if acc > accuracy_threshold:
                return epoch

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    return num_epoch
mx.ndarray.random.normal(0, 0.01, out=arg) else: return super()._init_weight(name, arg) class BNZeroInit(mx.init.Xavier): def _init_gamma(self, name, arg): if name.endswith("bn3_gamma"): arg[:] = 0.0 else: arg[:] = 1.0 if __name__ == '__main__': LOGGER.propagate = False mx_resnet_print(key=mlperf_constants.INIT_START, uniq=False) # parse args parser = argparse.ArgumentParser( description="MLPerf RN50v1.5 training script", formatter_class=argparse.ArgumentDefaultsHelpFormatter) add_general_args(parser) fit.add_fit_args(parser) dali.add_dali_args(parser) parser.set_defaults( # network network='resnet-v1b', num_layers=50, # data resize=256,
def __init__(self, batch_size, num_threads, device_id, rec_path, idx_path,
             shard_id, num_shards, crop_shape,
             min_random_area, max_random_area,
             min_random_aspect_ratio, max_random_aspect_ratio,
             nvjpeg_padding, prefetch_queue=3,
             seed=12,
             output_layout=types.NCHW, pad_output=True, dtype='float16',
             mlperf_print=True):
    """Build the DALI training pipeline operators.

    Configures: a sharded MXNet RecordIO reader with shuffling, a
    mixed-backend (CPU+GPU) JPEG decoder, a GPU random-resized-crop, a
    GPU crop/mirror/normalize stage, and a coin-flip source used for
    random horizontal mirroring.  When `mlperf_print` is true, the
    corresponding MLPerf input-pipeline log keys are emitted.

    rec_path/idx_path : RecordIO data file and its index file
    shard_id/num_shards : this worker's shard of the dataset
    crop_shape : output crop size
    min/max_random_area, min/max_random_aspect_ratio : random crop ranges
    nvjpeg_padding : device/host memory padding for the nvJPEG decoder
    dtype : 'float16' selects FLOAT16 pipeline output, anything else FLOAT
    """
    # Per-GPU seed so shards do not produce identical shuffles.
    super(HybridTrainPipe, self).__init__(
        batch_size,
        num_threads,
        device_id,
        seed=seed + device_id,
        prefetch_queue_depth=prefetch_queue)
    if mlperf_print:
        # Shuffling is done inside ops.MXNetReader
        mx_resnet_print(key=mlperf_log.INPUT_ORDER)
    self.input = ops.MXNetReader(path=[rec_path],
                                 index_path=[idx_path],
                                 random_shuffle=True,
                                 shard_id=shard_id,
                                 num_shards=num_shards)
    self.decode = ops.nvJPEGDecoder(device="mixed",
                                    output_type=types.RGB,
                                    device_memory_padding=nvjpeg_padding,
                                    host_memory_padding=nvjpeg_padding)
    self.rrc = ops.RandomResizedCrop(
        device="gpu",
        random_area=[min_random_area, max_random_area],
        random_aspect_ratio=[min_random_aspect_ratio,
                             max_random_aspect_ratio],
        size=crop_shape)
    self.cmnp = ops.CropMirrorNormalize(
        device="gpu",
        output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
        output_layout=output_layout,
        crop=crop_shape,
        pad_output=pad_output,
        image_type=types.RGB,
        mean=_mean_pixel,
        std=_std_pixel)
    # Probability source for random horizontal flips.
    self.coin = ops.CoinFlip(probability=0.5)
    if mlperf_print:
        mx_resnet_print(key=mlperf_log.INPUT_CROP_USES_BBOXES, val=False)
        mx_resnet_print(key=mlperf_log.INPUT_DISTORTED_CROP_RATIO_RANGE,
                        val=(min_random_aspect_ratio,
                             max_random_aspect_ratio))
        mx_resnet_print(key=mlperf_log.INPUT_DISTORTED_CROP_AREA_RANGE,
                        val=(min_random_area, max_random_area))
        mx_resnet_print(key=mlperf_log.INPUT_MEAN_SUBTRACTION,
                        val=_mean_pixel)
        mx_resnet_print(key=mlperf_log.INPUT_RANDOM_FLIP)
if name.startswith("fc"): mx.ndarray.random.normal(0, 0.01, out=arg) else: return super()._init_weight(name, arg) class BNZeroInit(mx.init.Xavier): def _init_gamma(self, name, arg): if name.endswith("bn3_gamma"): arg[:] = 0.0 else: arg[:] = 1.0 if __name__ == '__main__': mx_resnet_print(key=mlperf_constants.INIT_START, uniq=False) # parse args parser = argparse.ArgumentParser( description="MLPerf RN50v1.5 training script", formatter_class=argparse.ArgumentDefaultsHelpFormatter) add_general_args(parser) fit.add_fit_args(parser) dali.add_dali_args(parser) parser.set_defaults( # network network='resnet-v1b', num_layers=50, # data resize=256,
def resnet(units,
           num_stages,
           filter_list,
           num_classes,
           image_shape,
           bottle_neck=True,
           workspace=256,
           dtype='float32',
           memonger=False,
           input_layout='NCHW',
           conv_layout='NCHW',
           batchnorm_layout='NCHW',
           pooling_layout='NCHW',
           verbose=False,
           cudnn_bn_off=False,
           bn_eps=2e-5,
           bn_mom=0.9,
           conv_algo=-1,
           fuse_bn_relu=False,
           fuse_bn_add_relu=False,
           force_tensor_core=False,
           use_dali=True,
           smooth_alpha=0.0):
    """Return a ResNet symbol ending in a SoftmaxOutput.

    Parameters
    ----------
    units : list
        Number of units in each stage
    num_stages : int
        Number of stages
    filter_list : list
        Channel size of each stage
    num_classes : int
        Output size of symbol
    image_shape : tuple
        (channels, height, width) of the input
    workspace : int
        Workspace used in convolution operator
    dtype : str
        Precision (float32 or float16)
    memonger : boolean
        Activates "memory monger" to reduce the model's memory footprint
    input_layout : str
        interpretation (e.g. NCHW vs NHWC) of data provided by the i/o
        pipeline (may introduce transposes if in conflict with
        'conv_layout' below)
    conv_layout : str
        interpretation (e.g. NCHW vs NHWC) of data for convolution operation.
    batchnorm_layout : str
        directs which kernel performs the batchnorm (may introduce
        transposes if in conflict with 'conv_layout' above)
    pooling_layout : str
        directs which kernel performs the pooling (may introduce transposes
        if in conflict with 'conv_layout' above)
    """
    # With fused BN+ReLU, the activation happens inside the BN kernel.
    act = 'relu' if fuse_bn_relu else None
    num_unit = len(units)
    assert (num_unit == num_stages)
    data = mx.sym.Variable(name='data')
    if not use_dali:
        # double buffering of data
        if dtype == 'float32':
            data = mx.sym.identity(data=data, name='id')
        else:
            if dtype == 'float16':
                data = mx.sym.Cast(data=data, dtype=np.float16)

    (nchannel, height, width) = image_shape

    # Insert transpose as needed to get the input layout to match the desired processing layout
    data = transform_layout(data, input_layout, conv_layout)

    if height <= 32:  # such as cifar10
        # NOTE(review): `shape` is never initialized on this branch, so the
        # shape-logging calls in the stage loop below would fail for
        # small-image inputs — confirm this path is actually used.
        body = mx.sym.Convolution(data=data,
                                  num_filter=filter_list[0],
                                  kernel=(3, 3),
                                  stride=(1, 1),
                                  pad=(1, 1),
                                  no_bias=True,
                                  name="conv0",
                                  workspace=workspace,
                                  layout=conv_layout,
                                  cudnn_algo_verbose=verbose,
                                  cudnn_algo_fwd=conv_algo,
                                  cudnn_algo_bwd_data=conv_algo,
                                  cudnn_algo_bwd_filter=conv_algo,
                                  cudnn_tensor_core_only=force_tensor_core)
        # Is this BatchNorm supposed to be here?
        body = batchnorm(data=body,
                         io_layout=conv_layout,
                         batchnorm_layout=batchnorm_layout,
                         fix_gamma=False,
                         eps=bn_eps,
                         momentum=bn_mom,
                         name='bn0',
                         cudnn_off=cudnn_bn_off)
    else:  # often expected to be 224 such as imagenet
        shape = image_shape
        mx_resnet_print(key=mlperf_log.MODEL_HP_INITIAL_SHAPE, val=shape)
        body = mx.sym.Convolution(data=data,
                                  num_filter=filter_list[0],
                                  kernel=(7, 7),
                                  stride=(2, 2),
                                  pad=(3, 3),
                                  no_bias=True,
                                  name="conv0",
                                  workspace=workspace,
                                  layout=conv_layout,
                                  cudnn_algo_verbose=verbose,
                                  cudnn_algo_fwd=conv_algo,
                                  cudnn_algo_bwd_data=conv_algo,
                                  cudnn_algo_bwd_filter=conv_algo,
                                  cudnn_tensor_core_only=force_tensor_core)
        shape = resnet_conv2d_log(shape, 2, filter_list[0],
                                  mlperf_log.TRUNCATED_NORMAL, False)
        body = batchnorm(data=body,
                         io_layout=conv_layout,
                         batchnorm_layout=batchnorm_layout,
                         fix_gamma=False,
                         eps=bn_eps,
                         momentum=bn_mom,
                         name='bn0',
                         cudnn_off=cudnn_bn_off,
                         act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)
        if not fuse_bn_relu:
            body = mx.sym.Activation(data=body, act_type='relu',
                                     name='relu0')
        body = pooling(data=body,
                       io_layout=conv_layout,
                       pooling_layout=pooling_layout,
                       kernel=(3, 3),
                       stride=(2, 2),
                       pad=(1, 1),
                       pool_type='max')
        shape = resnet_max_pool_log(shape, 2)

    # Stages: first unit of each stage downsamples (stride 2) except stage 0.
    for i in range(num_stages):
        body, shape = residual_unit(body,
                                    shape,
                                    filter_list[i + 1],
                                    (1 if i == 0 else 2,
                                     1 if i == 0 else 2),
                                    False,
                                    name='stage%d_unit%d' % (i + 1, 1),
                                    bottle_neck=bottle_neck,
                                    workspace=workspace,
                                    memonger=memonger,
                                    conv_layout=conv_layout,
                                    batchnorm_layout=batchnorm_layout,
                                    verbose=verbose,
                                    cudnn_bn_off=cudnn_bn_off,
                                    bn_eps=bn_eps,
                                    bn_mom=bn_mom,
                                    conv_algo=conv_algo,
                                    fuse_bn_relu=fuse_bn_relu,
                                    fuse_bn_add_relu=fuse_bn_add_relu,
                                    cudnn_tensor_core_only=force_tensor_core)
        for j in range(units[i] - 1):
            body, shape = residual_unit(
                body,
                shape,
                filter_list[i + 1], (1, 1),
                True,
                name='stage%d_unit%d' % (i + 1, j + 2),
                bottle_neck=bottle_neck,
                workspace=workspace,
                memonger=memonger,
                conv_layout=conv_layout,
                batchnorm_layout=batchnorm_layout,
                verbose=verbose,
                cudnn_bn_off=cudnn_bn_off,
                bn_eps=bn_eps,
                bn_mom=bn_mom,
                conv_algo=conv_algo,
                fuse_bn_relu=fuse_bn_relu,
                fuse_bn_add_relu=fuse_bn_add_relu,
                cudnn_tensor_core_only=force_tensor_core)

    # Although kernel is not used here when global_pool=True, we should put one
    pool1 = pooling(data=body,
                    io_layout=conv_layout,
                    pooling_layout=pooling_layout,
                    global_pool=True,
                    kernel=(7, 7),
                    pool_type='avg',
                    name='pool1')
    flat = mx.sym.Flatten(data=pool1)
    # NOTE(review): `(shape[0])` is just `shape[0]`, not a 1-tuple — if
    # resnet_dense_log expects a tuple this should be `(shape[0], )`;
    # confirm against the logging helper's signature.
    shape = (shape[0])
    fc1 = mx.sym.FullyConnected(data=flat,
                                num_hidden=num_classes,
                                name='fc1',
                                cublas_algo_verbose=verbose)
    shape = resnet_dense_log(shape, num_classes)

    mx_resnet_print(key=mlperf_log.MODEL_HP_FINAL_SHAPE, val=shape)

    if dtype == 'float16':
        fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)

    ##########################################################################
    # MXNet computes Cross Entropy loss gradients without explicitly computing
    # the value of loss function.
    # Take a look here:
    # https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.SoftmaxOutput
    # for further details
    ##########################################################################
    mx_resnet_print(key=mlperf_log.MODEL_HP_LOSS_FN, val=mlperf_log.CCE)
    return mx.sym.SoftmaxOutput(data=fc1,
                                name='softmax',
                                smooth_alpha=smooth_alpha)