Esempio n. 1
0
def get_rec_iter(args, trainpipes, valpipes, data_paths, kv=None):
    """Wrap already-built DALI pipelines in train/val classification iterators.

    When the lazy-init sanity check is active, this is also the point where
    the real dataset files are symlinked into the temporary paths the
    pipelines were constructed against (DALIClassificationIterator performs
    the init, so real data must exist by now).  The user must clean the
    created symlinks out of /tmp afterwards.
    """
    rank, num_workers = _get_rank_and_worker_count(args, kv)

    if args.dali_cache_size > 0 and args.lazy_init_sanity:
        # Re-point each tmp path at the actual RecordIO data/index file.
        for real_path, tmp_path in (
                (args.data_train, data_paths["train_data_tmp"]),
                (args.data_train_idx, data_paths["train_idx_tmp"]),
                (args.data_val, data_paths["val_data_tmp"]),
                (args.data_val_idx, data_paths["val_idx_tmp"])):
            link_to_tmp_file(real_path, tmp_path)

    mx_resnet_print(key=mlperf_constants.TRAIN_SAMPLES, val=args.num_examples)
    dali_train_iter = DALIClassificationIterator(
        trainpipes, args.num_examples // num_workers)

    full_train_size = trainpipes[0].epoch_size("Reader")
    if args.num_examples < full_train_size:
        warnings.warn(
            "{} training examples will be used, although full training set contains {} examples"
            .format(args.num_examples, full_train_size))

    # EVAL_SAMPLES is logged for the full validation set, before sharding.
    worker_val_examples = valpipes[0].epoch_size("Reader")
    mx_resnet_print(key=mlperf_constants.EVAL_SAMPLES, val=worker_val_examples)
    if not args.separ_val:
        # Shard the validation set: floor share per worker, plus one extra
        # example for the first (total % num_workers) ranks.
        total_val = valpipes[0].epoch_size("Reader")
        worker_val_examples = total_val // num_workers
        if rank < total_val % num_workers:
            worker_val_examples += 1

    dali_val_iter = None
    if args.data_val:
        dali_val_iter = DALIClassificationIterator(valpipes,
                                                   worker_val_examples,
                                                   fill_last_batch=False)

    return dali_train_iter, dali_val_iter
Esempio n. 2
0
def _get_lr_scheduler(args, kv):
    """Build the learning-rate value and scheduler from parsed CLI args.

    Returns a ``(lr, scheduler)`` tuple; ``scheduler`` is ``None`` for a
    constant learning rate.  If ``args.lr_step_epochs`` contains ``'pow'``
    a polynomial-decay schedule is used, otherwise a multi-step schedule
    with warmup.
    """
    # lr_factor >= 1 would never decay the rate, so treat it as "no schedule".
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    epoch_size = get_epoch_size(args, kv)
    # NOTE(review): begin_epoch is hard-coded to 0, so the lr-rescaling loop
    # and the logging below never fire and step epochs are never shifted;
    # presumably a leftover from a resume-from-checkpoint variant -- confirm.
    begin_epoch = 0
    mx_resnet_print(key=mlperf_constants.OPT_BASE_LR, val=args.lr)
    mx_resnet_print(key=mlperf_constants.OPT_LR_WARMUP_EPOCHS,
                    val=args.warmup_epochs)

    # Polynomial decay, e.g. --lr-step-epochs "pow2" decays with exponent 2
    # over the full num_epochs * epoch_size updates.
    if 'pow' in args.lr_step_epochs:
        lr = args.lr
        max_up = args.num_epochs * epoch_size
        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
        return (lr, poly_sched)
    # Multi-step decay: comma-separated epoch indices at which lr is scaled
    # by lr_factor; an empty string means no steps.
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')] if len(
        args.lr_step_epochs.strip()) else []
    lr = args.lr
    # Pre-apply the decay for step epochs already passed (dead while
    # begin_epoch == 0, see note above).
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d', lr,
                     begin_epoch)

    # Remaining step epochs, converted to global update counts.
    steps = [
        epoch_size * (x - begin_epoch) for x in step_epochs
        if x - begin_epoch > 0
    ]
    if steps:
        if kv:
            num_workers = kv.num_workers
        else:
            num_workers = 1
        # NOTE(review): epoch_size is recomputed per-worker here, *after*
        # `steps` was already built from get_epoch_size()'s value, and is
        # then used only to size the warmup -- confirm this asymmetry is
        # intentional.
        epoch_size = math.ceil(
            int(args.num_examples / num_workers) / args.batch_size)
        return (lr,
                mx.lr_scheduler.MultiFactorScheduler(
                    step=steps,
                    factor=args.lr_factor,
                    base_lr=args.lr,
                    warmup_steps=epoch_size * args.warmup_epochs,
                    warmup_mode=args.warmup_strategy))
    else:
        return (lr, None)
Esempio n. 3
0
    def __init__(self,
                 batch_size,
                 num_threads,
                 device_id,
                 rec_path,
                 idx_path,
                 shard_id,
                 num_shards,
                 crop_shape,
                 nvjpeg_padding,
                 prefetch_queue=3,
                 seed=12,
                 resize_shp=None,
                 output_layout=types.NCHW,
                 pad_output=True,
                 dtype='float16',
                 mlperf_print=True):
        """Validation pipeline: MXNet RecordIO reader -> GPU JPEG decode ->
        optional shorter-edge resize -> center crop + normalize.

        Parameters
        ----------
        rec_path, idx_path : str
            Paths to the validation ``.rec`` file and its index.
        shard_id, num_shards : int
            This pipeline's shard of the validation set.
        crop_shape : tuple
            Spatial output size of the crop.
        nvjpeg_padding : int
            Device/host memory padding handed to the nvJPEG decoder.
        resize_shp : int or None
            If set, images are resized so their shorter edge equals this
            value before cropping; otherwise no resize op is created.
        pad_output : bool
            Pad the channel dimension of the output (callers set this when
            the target shape has 4 channels).
        mlperf_print : bool
            Emit MLPerf log entries; callers in this file enable it on the
            first GPU only to avoid duplicate logging.
        """

        # Seed is offset by device_id so per-GPU pipelines differ.
        super(HybridValPipe,
              self).__init__(batch_size,
                             num_threads,
                             device_id,
                             seed=seed + device_id,
                             prefetch_queue_depth=prefetch_queue)

        # Validation data is read in order (no shuffling), sharded across
        # pipelines.
        self.input = ops.MXNetReader(path=[rec_path],
                                     index_path=[idx_path],
                                     random_shuffle=False,
                                     shard_id=shard_id,
                                     num_shards=num_shards)

        # "mixed" device: host-side parsing with GPU-accelerated decode.
        self.decode = ops.nvJPEGDecoder(device="mixed",
                                        output_type=types.RGB,
                                        device_memory_padding=nvjpeg_padding,
                                        host_memory_padding=nvjpeg_padding)

        self.resize = ops.Resize(
            device="gpu", resize_shorter=resize_shp) if resize_shp else None

        # Crop + mean/std normalization; the mirror argument is left at its
        # default, so no horizontal flipping happens for validation.
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
            output_layout=output_layout,
            crop=crop_shape,
            pad_output=pad_output,
            image_type=types.RGB,
            mean=_mean_pixel,
            std=_std_pixel)

        if mlperf_print:
            mx_resnet_print(key=mlperf_log.INPUT_MEAN_SUBTRACTION,
                            val=_mean_pixel)
            mx_resnet_print(key=mlperf_log.INPUT_RESIZE_ASPECT_PRESERVING)
            mx_resnet_print(key=mlperf_log.INPUT_CENTRAL_CROP)
Esempio n. 4
0
def get_rec_iter(args, trainpipes, valpipes, cvalpipes, kv=None):
    """Wrap train, validation, and optional cross-validation DALI pipelines
    in classification iterators.

    Returns ``(train_iter, val_iter, cval_iter)``; ``cval_iter`` is ``None``
    unless the final epochs are trained without augmentation, ``val_iter``
    is ``None`` when no validation data is configured.
    """
    rank = kv.rank if kv else 0
    num_workers = kv.num_workers if kv else 1
    examples_per_worker = args.num_examples // num_workers

    dali_train_iter = DALIClassificationIterator(trainpipes,
                                                 examples_per_worker)

    # The cross-validation iterator only exists when some epochs run with
    # the alternate (no-augmentation) pipeline.
    dali_cval_iter = None
    if args.no_augument_epoch < args.num_epochs:
        dali_cval_iter = DALIClassificationIterator(cvalpipes,
                                                    examples_per_worker)

    train_size = trainpipes[0].epoch_size("Reader")
    mx_resnet_print(key=mlperf_log.INPUT_SIZE, val=train_size)
    mx_resnet_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES, val=train_size)

    if args.data_val:
        val_size = valpipes[0].epoch_size("Reader")
        mx_resnet_print(key=mlperf_log.EVAL_SIZE, val=val_size)
        mx_resnet_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                        val=val_size)

    if args.num_examples < train_size:
        warnings.warn(
            "{} training examples will be used, although full training set contains {} examples"
            .format(args.num_examples, train_size))

    # Shard the validation set unless each worker validates on the full set.
    total_val = valpipes[0].epoch_size("Reader")
    if args.separ_val:
        worker_val_examples = total_val
    else:
        # Floor share per worker, plus one extra example for the first
        # (total % num_workers) ranks.
        worker_val_examples = total_val // num_workers
        if rank < total_val % num_workers:
            worker_val_examples += 1

    dali_val_iter = None
    if args.data_val:
        dali_val_iter = DALIClassificationIterator(valpipes,
                                                   worker_val_examples,
                                                   fill_last_batch=False)
    return dali_train_iter, dali_val_iter, dali_cval_iter
Esempio n. 5
0
def get_rec_iter(args, kv=None):
    """Build per-GPU DALI train/val pipelines and wrap them in iterators.

    One HybridTrainPipe / HybridValPipe is created per GPU listed in
    ``args.gpus``; each reads its own shard of the RecordIO data.  Returns
    ``(dali_train_iter, dali_val_iter)`` where the validation iterator is
    ``None`` when ``args.data_val`` is not set.
    """
    # resize is default base length of shorter edge for dataset;
    # all images will be reshaped to this size
    resize = int(args.resize)
    # target shape is final shape of images pipelined to network;
    # all images will be cropped to this size
    target_shape = tuple([int(l) for l in args.image_shape.split(',')])

    # 4-channel target means the pipeline must pad RGB output.
    pad_output = target_shape[0] == 4
    gpus = list(map(int, filter(None, args.gpus.split(',')))) # filter to not encount eventually empty strings
    # Global batch size is split evenly across the local GPUs.
    batch_size = args.batch_size//len(gpus)

    mx_resnet_print(
            key=mlperf_log.INPUT_BATCH_SIZE,
            val=batch_size) # TODO MPI WORLD SIZE

    num_threads = args.dali_threads

    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW

    rank = kv.rank if kv else 0
    nWrk = kv.num_workers if kv else 1


    # Global shard id = local gpu index + gpus-per-node * worker rank.
    trainpipes = [HybridTrainPipe(batch_size      = batch_size,
                                  num_threads     = num_threads,
                                  device_id       = gpu_id,
                                  rec_path        = args.data_train,
                                  idx_path        = args.data_train_idx,
                                  shard_id        = gpus.index(gpu_id) + len(gpus)*rank,
                                  num_shards      = len(gpus)*nWrk,
                                  crop_shape      = target_shape[1:],
                                  min_random_area = args.min_random_area,
                                  max_random_area = args.max_random_area,
                                  min_random_aspect_ratio = args.min_random_aspect_ratio,
                                  max_random_aspect_ratio = args.max_random_aspect_ratio,
                                  nvjpeg_padding  = args.dali_nvjpeg_memory_padding * 1024 * 1024,
                                  prefetch_queue  = args.dali_prefetch_queue,
                                  seed            = args.seed,
                                  output_layout   = output_layout,
                                  pad_output      = pad_output,
                                  dtype           = args.dtype,
                                  mlperf_print    = gpu_id == gpus[0]) for gpu_id in gpus]

    # With separ_val every worker validates on the full set (single shard 0);
    # otherwise validation is sharded like training.
    valpipes = [HybridValPipe(batch_size     = batch_size,
                              num_threads    = num_threads,
                              device_id      = gpu_id,
                              rec_path       = args.data_val,
                              idx_path       = args.data_val_idx,
                              shard_id       = 0 if args.separ_val
                                                 else gpus.index(gpu_id) + len(gpus)*rank,
                              num_shards     = 1 if args.separ_val else len(gpus)*nWrk,
                              crop_shape     = target_shape[1:],
                              nvjpeg_padding = args.dali_nvjpeg_memory_padding * 1024 * 1024,
                              prefetch_queue = args.dali_prefetch_queue,
                              seed           = args.seed,
                              resize_shp     = resize,
                              output_layout  = output_layout,
                              pad_output     = pad_output,
                              dtype          = args.dtype,
                              mlperf_print   = gpu_id == gpus[0]) for gpu_id in gpus] if args.data_val else None

    # Only the first pipeline of each list is built eagerly here; presumably
    # the DALI iterator builds the rest -- confirm against the DALI version
    # in use.
    trainpipes[0].build()
    if args.data_val:
        valpipes[0].build()

    mx_resnet_print(
            key=mlperf_log.INPUT_SIZE,
            val=trainpipes[0].epoch_size("Reader"))

    mx_resnet_print(
            key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES,
            val=trainpipes[0].epoch_size("Reader"))


    if args.data_val:
        mx_resnet_print(
                key=mlperf_log.EVAL_SIZE,
                val=valpipes[0].epoch_size("Reader"))

        mx_resnet_print(
                key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES,
                val=valpipes[0].epoch_size("Reader"))


    if args.num_examples < trainpipes[0].epoch_size("Reader"):
        warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
    dali_train_iter = DALIClassificationIterator(trainpipes, args.num_examples // nWrk)

    # Per-worker validation example count: floor share plus one extra example
    # for the first (total % nWrk) ranks when validation is sharded.
    worker_val_examples = valpipes[0].epoch_size("Reader")
    if not args.separ_val:
        worker_val_examples = worker_val_examples // nWrk
        if rank < valpipes[0].epoch_size("Reader") % nWrk:
            worker_val_examples += 1

    dali_val_iter = DALIClassificationIterator(valpipes, worker_val_examples, fill_last_batch = False) if args.data_val else None
    return dali_train_iter, dali_val_iter
Esempio n. 6
0
def build_input_pipeline(args, kv=None):
    """Construct per-GPU DALI train/val pipelines and return an iterator factory.

    When ``args.dali_cache_size > 0`` and ``args.lazy_init_sanity`` are set,
    the pipelines are pointed at temporary paths (presumably fresh /tmp
    files from get_tmp_file() -- confirm) instead of the real data; the
    returned get_rec_iter() closure later symlinks the real files into
    place, demonstrating that pipeline construction itself never touched
    the dataset.

    Returns a ``lambda args, kv`` that builds the actual train/val iterators.
    """
    # resize is default base length of shorter edge for dataset;
    # all images will be reshaped to this size
    resize = int(args.resize)
    # target shape is final shape of images pipelined to network;
    # all images will be cropped to this size
    target_shape = tuple([int(l) for l in args.image_shape.split(',')])

    # 4-channel target means the pipeline must pad RGB output.
    pad_output = target_shape[0] == 4
    gpus = list(map(int, filter(None, args.gpus.split(
        ','))))  # filter to not encount eventually empty strings
    # Global batch size is split evenly across the local GPUs.
    batch_size = args.batch_size // len(gpus)

    mx_resnet_print(key=mlperf_constants.MODEL_BN_SPAN, val=batch_size)

    num_threads = args.dali_threads

    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW

    (rank, num_workers) = _get_rank_and_worker_count(args, kv)

    # Either temporary placeholder paths (lazy-init sanity mode) or the real
    # RecordIO data/index paths.
    data_paths = {}
    if args.dali_cache_size > 0 and args.lazy_init_sanity:
        data_paths["train_data_tmp"] = get_tmp_file()
        data_paths["train_idx_tmp"] = get_tmp_file()
        data_paths["val_data_tmp"] = get_tmp_file()
        data_paths["val_idx_tmp"] = get_tmp_file()
    else:
        data_paths["train_data_tmp"] = args.data_train
        data_paths["train_idx_tmp"] = args.data_train_idx
        data_paths["val_data_tmp"] = args.data_val
        data_paths["val_idx_tmp"] = args.data_val_idx

    # Global shard id = local gpu index + gpus-per-node * worker rank.
    trainpipes = [
        HybridTrainPipe(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=gpu_id,
                        rec_path=data_paths["train_data_tmp"],
                        idx_path=data_paths["train_idx_tmp"],
                        shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                        num_shards=len(gpus) * num_workers,
                        crop_shape=target_shape[1:],
                        min_random_area=args.min_random_area,
                        max_random_area=args.max_random_area,
                        min_random_aspect_ratio=args.min_random_aspect_ratio,
                        max_random_aspect_ratio=args.max_random_aspect_ratio,
                        nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 *
                        1024,
                        prefetch_queue=args.dali_prefetch_queue,
                        seed=args.seed,
                        output_layout=output_layout,
                        pad_output=pad_output,
                        dtype=args.dtype,
                        mlperf_print=gpu_id == gpus[0],
                        use_roi_decode=args.dali_roi_decode,
                        cache_size=args.dali_cache_size) for gpu_id in gpus
    ]

    # With separ_val every worker validates on the full set (single shard 0);
    # otherwise validation is sharded like training.
    valpipes = [
        HybridValPipe(batch_size=batch_size,
                      num_threads=num_threads,
                      device_id=gpu_id,
                      rec_path=data_paths["val_data_tmp"],
                      idx_path=data_paths["val_idx_tmp"],
                      shard_id=0 if args.separ_val else gpus.index(gpu_id) +
                      len(gpus) * rank,
                      num_shards=1 if args.separ_val else len(gpus) *
                      num_workers,
                      crop_shape=target_shape[1:],
                      nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 *
                      1024,
                      prefetch_queue=args.dali_prefetch_queue,
                      seed=args.seed,
                      resize_shp=resize,
                      output_layout=output_layout,
                      pad_output=pad_output,
                      dtype=args.dtype,
                      mlperf_print=gpu_id == gpus[0],
                      cache_size=args.dali_cache_size) for gpu_id in gpus
    ] if args.data_val else None

    # Unlike other variants in this file, *all* pipelines are built eagerly
    # here, not just the first one per list.
    [trainpipe.build() for trainpipe in trainpipes]

    if args.data_val:
        [valpipe.build() for valpipe in valpipes]

    return lambda args, kv: get_rec_iter(args, trainpipes, valpipes,
                                         data_paths, kv)
Esempio n. 7
0
def build_input_pipeline(args, kv=None):
    """Construct per-GPU DALI pipelines: train, optional cross-val, and val.

    ``cvalpipes`` evaluates on *training* data and is only built when the
    final epochs run without augmentation (``args.no_augument_epoch <
    args.num_epochs``); depending on ``args.use_new_cval`` it is either a
    HybridTrainPipe with a second set of random-area/aspect parameters or a
    deterministic HybridCvalPipe.

    Returns a ``lambda args, kv`` that wraps the pipelines in iterators via
    get_rec_iter().
    """
    # resize is default base length of shorter edge for dataset;
    # all images will be reshaped to this size
    resize = int(args.resize)
    # target shape is final shape of images pipelined to network;
    # all images will be cropped to this size
    target_shape = tuple([int(l) for l in args.image_shape.split(',')])

    # 4-channel target means the pipeline must pad RGB output.
    pad_output = target_shape[0] == 4
    gpus = list(map(int, filter(None, args.gpus.split(
        ','))))  # filter to not encount eventually empty strings
    # Global batch size is split evenly across the local GPUs.
    batch_size = args.batch_size // len(gpus)

    mx_resnet_print(key=mlperf_constants.MODEL_BN_SPAN, val=batch_size)

    num_threads = args.dali_threads

    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW

    rank = kv.rank if kv else 0
    nWrk = kv.num_workers if kv else 1

    # Global shard id = local gpu index + gpus-per-node * worker rank.
    trainpipes = [
        HybridTrainPipe(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=gpu_id,
                        rec_path=args.data_train,
                        idx_path=args.data_train_idx,
                        shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                        num_shards=len(gpus) * nWrk,
                        crop_shape=target_shape[1:],
                        min_random_area=args.min_random_area,
                        max_random_area=args.max_random_area,
                        min_random_aspect_ratio=args.min_random_aspect_ratio,
                        max_random_aspect_ratio=args.max_random_aspect_ratio,
                        nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 *
                        1024,
                        prefetch_queue=args.dali_prefetch_queue,
                        seed=args.seed,
                        output_layout=output_layout,
                        pad_output=pad_output,
                        dtype=args.dtype,
                        mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
    ]

    # Cross-validation pipelines read the *training* RecordIO files.
    if args.no_augument_epoch >= args.num_epochs:
        cvalpipes = None
    else:
        cvalpipes = [
            HybridTrainPipe(
                batch_size=batch_size,
                num_threads=num_threads,
                device_id=gpu_id,
                rec_path=args.data_train,
                idx_path=args.data_train_idx,
                shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                num_shards=len(gpus) * nWrk,
                crop_shape=target_shape[1:],
                # Alternate ("_2") augmentation parameters for the cval pass.
                min_random_area=args.min_random_area_2,
                max_random_area=args.max_random_area_2,
                min_random_aspect_ratio=args.min_random_aspect_ratio_2,
                max_random_aspect_ratio=args.max_random_aspect_ratio_2,
                nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,
                prefetch_queue=args.dali_prefetch_queue,
                seed=args.seed,
                output_layout=output_layout,
                pad_output=pad_output,
                dtype=args.dtype,
                mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
        ] if args.use_new_cval else [
            HybridCvalPipe(batch_size=batch_size,
                           num_threads=num_threads,
                           device_id=gpu_id,
                           rec_path=args.data_train,
                           idx_path=args.data_train_idx,
                           shard_id=gpus.index(gpu_id) + len(gpus) * rank,
                           num_shards=len(gpus) * nWrk,
                           crop_shape=target_shape[1:],
                           nvjpeg_padding=args.dali_nvjpeg_memory_padding *
                           1024 * 1024,
                           prefetch_queue=args.dali_prefetch_queue,
                           seed=args.seed,
                           resize_shp=resize,
                           output_layout=output_layout,
                           pad_output=pad_output,
                           dtype=args.dtype,
                           mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
        ]

    # With separ_val every worker validates on the full set (single shard 0);
    # otherwise validation is sharded like training.
    valpipes = [
        HybridValPipe(batch_size=batch_size,
                      num_threads=num_threads,
                      device_id=gpu_id,
                      rec_path=args.data_val,
                      idx_path=args.data_val_idx,
                      shard_id=0 if args.separ_val else gpus.index(gpu_id) +
                      len(gpus) * rank,
                      num_shards=1 if args.separ_val else len(gpus) * nWrk,
                      crop_shape=target_shape[1:],
                      nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 *
                      1024,
                      prefetch_queue=args.dali_prefetch_queue,
                      seed=args.seed,
                      resize_shp=resize,
                      output_layout=output_layout,
                      pad_output=pad_output,
                      dtype=args.dtype,
                      mlperf_print=gpu_id == gpus[0]) for gpu_id in gpus
    ] if args.data_val else None

    # Only the first pipeline of each list is built eagerly here; presumably
    # the DALI iterator builds the rest -- confirm against the DALI version
    # in use.
    trainpipes[0].build()
    if args.no_augument_epoch < args.num_epochs:
        cvalpipes[0].build()
    if args.data_val:
        valpipes[0].build()

    return lambda args, kv: get_rec_iter(args, trainpipes, valpipes, cvalpipes,
                                         kv)
def residual_unit(data,
                  shape,
                  num_filter,
                  stride,
                  dim_match,
                  name,
                  bottle_neck=True,
                  workspace=256,
                  memonger=False,
                  conv_layout='NCHW',
                  batchnorm_layout='NCHW',
                  verbose=False,
                  cudnn_bn_off=False,
                  bn_eps=2e-5,
                  bn_mom=0.9,
                  conv_algo=-1,
                  fuse_bn_relu=False,
                  fuse_bn_add_relu=False,
                  cudnn_tensor_core_only=False):
    """Return ResNet Unit symbol for building ResNet

    Only the bottleneck variant (1x1 reduce -> 3x3 -> 1x1 expand) is
    implemented; ``bottle_neck=False`` raises NotImplementedError.

    Parameters
    ----------
    data : str
        Input data
    shape : tuple
        Input shape, threaded through the resnet_*_log helpers for MLPerf
        shape logging only
    num_filter : int
        Number of output channels
    stride : tuple
        Stride used in convolution
    dim_match : Boolean
        True means channel number between input and output is the same, otherwise means differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator
    fuse_bn_relu, fuse_bn_add_relu : Boolean
        Use fused BatchNorm(+Add)+ReLU kernels instead of separate ops

    Returns
    -------
    (sym, shape) : the unit's output symbol and the logged output shape
    """
    input_shape = shape
    # With the fused kernel the ReLU is folded into batchnorm() itself.
    act = 'relu' if fuse_bn_relu else None
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        # 1x1 convolution reducing channels to num_filter / 4.
        conv1 = mx.sym.Convolution(
            data=data,
            num_filter=int(num_filter * 0.25),
            kernel=(1, 1),
            stride=(1, 1),
            pad=(0, 0),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv1',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        bn1 = batchnorm(data=conv1,
                        io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False,
                        eps=bn_eps,
                        momentum=bn_mom,
                        name=name + '_bn1',
                        cudnn_off=cudnn_bn_off,
                        act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)

        # Separate ReLU only when it was not fused into bn1 above.
        act1 = mx.sym.Activation(
            data=bn1, act_type='relu', name=name +
            '_relu1') if not fuse_bn_relu else bn1
        shape = resnet_relu_log(shape)

        # 3x3 convolution; this is the layer that carries the unit's stride.
        conv2 = mx.sym.Convolution(
            data=act1,
            num_filter=int(num_filter * 0.25),
            kernel=(3, 3),
            stride=stride,
            pad=(1, 1),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv2',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, stride, int(num_filter * 0.25),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        bn2 = batchnorm(data=conv2,
                        io_layout=conv_layout,
                        batchnorm_layout=batchnorm_layout,
                        fix_gamma=False,
                        eps=bn_eps,
                        momentum=bn_mom,
                        name=name + '_bn2',
                        cudnn_off=cudnn_bn_off,
                        act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)

        act2 = mx.sym.Activation(
            data=bn2, act_type='relu', name=name +
            '_relu2') if not fuse_bn_relu else bn2
        shape = resnet_relu_log(shape)

        # 1x1 convolution expanding back to the full num_filter channels.
        conv3 = mx.sym.Convolution(
            data=act2,
            num_filter=num_filter,
            kernel=(1, 1),
            stride=(1, 1),
            pad=(0, 0),
            no_bias=True,
            workspace=workspace,
            name=name + '_conv3',
            layout=conv_layout,
            cudnn_algo_verbose=verbose,
            cudnn_algo_fwd=conv_algo,
            cudnn_algo_bwd_data=conv_algo,
            cudnn_algo_bwd_filter=conv_algo,
            cudnn_tensor_core_only=cudnn_tensor_core_only)
        shape = resnet_conv2d_log(shape, 1, int(num_filter),
                                  mlperf_log.TRUNCATED_NORMAL, False)

        # Shortcut branch: identity when input/output dims match, otherwise a
        # strided 1x1 projection convolution followed by batchnorm.
        if dim_match:
            shortcut = data
        else:
            conv1sc = mx.sym.Convolution(
                data=data,
                num_filter=num_filter,
                kernel=(1, 1),
                stride=stride,
                no_bias=True,
                workspace=workspace,
                name=name + '_conv1sc',
                layout=conv_layout,
                cudnn_algo_verbose=verbose,
                cudnn_algo_fwd=conv_algo,
                cudnn_algo_bwd_data=conv_algo,
                cudnn_algo_bwd_filter=conv_algo,
                cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(input_shape, stride,
                                           int(num_filter),
                                           mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc,
                                 io_layout=conv_layout,
                                 batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False,
                                 eps=bn_eps,
                                 momentum=bn_mom,
                                 name=name + '_sc',
                                 cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape,
                                              momentum=bn_mom,
                                              eps=bn_eps,
                                              center=True,
                                              scale=True,
                                              training=True)
            proj_shape = resnet_projection_log(input_shape, proj_shape)
        if memonger:
            # Memory-optimization hint for the memonger graph pass.
            shortcut._set_attr(mirror_stage='True')
        if fuse_bn_add_relu:
            # Single fused BN + shortcut-add + ReLU produces the block output.
            shape = resnet_batchnorm_log(shape,
                                         momentum=bn_mom,
                                         eps=bn_eps,
                                         center=True,
                                         scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return batchnorm_add_relu(data=conv3,
                                      addend=shortcut,
                                      io_layout=conv_layout,
                                      batchnorm_layout=batchnorm_layout,
                                      fix_gamma=False,
                                      eps=bn_eps,
                                      momentum=bn_mom,
                                      name=name + '_bn3',
                                      cudnn_off=cudnn_bn_off), shape
        else:
            # Unfused path: BN, then elementwise add with the shortcut, then
            # a final ReLU activation.
            bn3 = batchnorm(data=conv3,
                            io_layout=conv_layout,
                            batchnorm_layout=batchnorm_layout,
                            fix_gamma=False,
                            eps=bn_eps,
                            momentum=bn_mom,
                            name=name + '_bn3',
                            cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape,
                                         momentum=bn_mom,
                                         eps=bn_eps,
                                         center=True,
                                         scale=True,
                                         training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            return mx.sym.Activation(data=bn3 + shortcut,
                                     act_type='relu',
                                     name=name + '_relu3'), shape

    else:
        raise NotImplementedError
def residual_unit_norm_conv(data, nchw_inshape, num_filter, stride, dim_match, name, bottle_neck=True,
                  workspace=256, memonger=False, conv_layout='NCHW', batchnorm_layout='NCHW',
                  verbose=False, cudnn_bn_off=False, bn_eps=2e-5, bn_mom=0.9, conv_algo=-1,
                  fuse_bn_relu=False, fuse_bn_add_relu=False, cudnn_tensor_core_only=False):
    """Return a ResNet residual-unit symbol built from fused NormalizedConvolution ops.

    Only the bottleneck (1x1 -> 3x3 -> 1x1) variant is implemented; calling with
    ``bottle_neck=False`` raises ``NotImplementedError``.

    Parameters
    ----------
    data : Symbol
        Input data symbol.
    nchw_inshape : tuple of int
        Input minibatch shape in (n, c, h, w) format independent of actual layout.
    num_filter : int
        Number of output channels.
    stride : tuple
        Stride used in the 3x3 convolution (and in the projection shortcut).
    dim_match : bool
        True means channel number between input and output is the same (identity
        shortcut); otherwise a strided 1x1 projection + BN shortcut is inserted.
    name : str
        Base name of the operators.
    bottle_neck : bool
        Must be True; the non-bottleneck path is not implemented.
    workspace : int
        Workspace used in convolution operator.
    memonger : bool
        If True, tag the shortcut with ``mirror_stage`` for memory optimization.
    conv_layout, batchnorm_layout : str
        I/O layouts for the convolution and batchnorm operators.
    verbose, cudnn_bn_off, bn_eps, bn_mom, conv_algo,
    fuse_bn_relu, fuse_bn_add_relu, cudnn_tensor_core_only :
        cuDNN / fusion tuning knobs forwarded to the underlying operators.

    Returns
    -------
    (sym, nchw_outshape)

    sym : the model symbol (up to this point)

    nchw_outshape : tuple
        (batch_size, features, height, width)
    """

    batch_size = nchw_inshape[0]
    # Per-sample (c, h, w) shape is tracked alongside the symbol purely for
    # the MLPerf resnet_*_log bookkeeping calls below.
    shape = nchw_inshape[1:]
    act = 'relu' if fuse_bn_relu else None  # NOTE(review): not referenced on the bottleneck path below
    if bottle_neck:
        shape = resnet_begin_block_log(shape, mlperf_log.BOTTLENECK_BLOCK)
        # 1st NormalizedConvolution: [no Stats Apply] [no Relu] Convolution Stats-Gen
        (conv1, conv1_sum, conv1_sum_squares) = \
            mx.sym.NormalizedConvolution(data, no_equiv_scale_bias=True, act_type=None,
                                         num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
                                         name=name + '_conv1', layout=conv_layout)
        shape = resnet_conv2d_log(shape, 1, int(num_filter*0.25), mlperf_log.TRUNCATED_NORMAL, False)

        # As prep for 2nd NormalizedConvolution: Finalize kernel converts sum,sum_squares to equiv_scale,equiv_bias
        elem_count = element_count((batch_size,) + shape)
        (bn1_equiv_scale, bn1_equiv_bias, bn1_saved_mean, bn1_saved_inv_std, bn1_gamma_out, bn1_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv1_sum, sum_squares=conv1_sum_squares,
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   output_mean_var=True, elem_count=elem_count, name=name + '_bn1')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

        # Second NormalizedConvolution: Stats-Apply Relu Convolution Stats-Gen
        (conv2, conv2_sum, conv2_sum_squares) = \
            mx.sym.NormalizedConvolution(conv1, no_equiv_scale_bias=False, act_type='relu',
                                         num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
                                         equiv_scale=bn1_equiv_scale, equiv_bias=bn1_equiv_bias,
                                         mean=bn1_saved_mean, var=bn1_saved_inv_std, gamma=bn1_gamma_out, beta=bn1_beta_out,
                                         name=name + '_conv2', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, stride, int(num_filter*0.25), mlperf_log.TRUNCATED_NORMAL, False)

        # As prep for 3nd NormalizedConvolution: Finalize kernel converts sum,sum_squares to equiv_scale,equiv_bias
        elem_count = element_count((batch_size,) + shape)
        (bn2_equiv_scale, bn2_equiv_bias, bn2_saved_mean, bn2_saved_inv_std, bn2_gamma_out, bn2_beta_out) = \
            mx.sym.BNStatsFinalize(sum=conv2_sum, sum_squares=conv2_sum_squares,
                                   eps=bn_eps, momentum=bn_mom, fix_gamma=False,
                                   output_mean_var=True, elem_count=elem_count, name=name + '_bn2')
        shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)

        # Third NormalizedConvolution: Stats-Apply Relu Convolution [no Stats-Gen]
        # The 'no stats-gen' mode is triggered by just not using the stats outputs for anything.
        (conv3, _, _) = \
            mx.sym.NormalizedConvolution(conv2, no_equiv_scale_bias=False, act_type='relu',
                                         num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                                         equiv_scale=bn2_equiv_scale, equiv_bias=bn2_equiv_bias,
                                         mean=bn2_saved_mean, var=bn2_saved_inv_std, gamma=bn2_gamma_out, beta=bn2_beta_out,
                                         name=name + '_conv3', layout=conv_layout)
        shape = resnet_relu_log(shape)
        shape = resnet_conv2d_log(shape, 1, int(num_filter), mlperf_log.TRUNCATED_NORMAL, False)

        # Shortcut branch: identity when shapes already match, otherwise a
        # strided 1x1 projection convolution followed by batchnorm.
        if dim_match:
            shortcut = data
        else:
            conv1sc = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                                            workspace=workspace, name=name+'_conv1sc', layout=conv_layout,
                                         cudnn_algo_verbose=verbose,
                                         cudnn_algo_fwd=conv_algo, cudnn_algo_bwd_data=conv_algo, cudnn_algo_bwd_filter=conv_algo,
                                         cudnn_tensor_core_only=cudnn_tensor_core_only)
            proj_shape = resnet_conv2d_log(nchw_inshape[1:], stride, int(num_filter), mlperf_log.TRUNCATED_NORMAL, False)
            shortcut = batchnorm(data=conv1sc, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                                 fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn_sc', cudnn_off=cudnn_bn_off)
            proj_shape = resnet_batchnorm_log(proj_shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            proj_shape = resnet_projection_log(nchw_inshape[1:], proj_shape)
        if memonger:
            shortcut._set_attr(mirror_stage='True')
        # Tail of the block: either a fused BN+add+relu, or separate
        # batchnorm followed by (bn3 + shortcut) -> relu.
        if fuse_bn_add_relu:
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            return batchnorm_add_relu(data=conv3, addend=shortcut, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn3', cudnn_off=cudnn_bn_off), nchw_shape
        else:
            bn3 = batchnorm(data=conv3, io_layout=conv_layout, batchnorm_layout=batchnorm_layout,
                            fix_gamma=False, eps=bn_eps, momentum=bn_mom, name=name + '_bn3', cudnn_off=cudnn_bn_off)
            shape = resnet_batchnorm_log(shape, momentum=bn_mom, eps=bn_eps, center=True, scale=True, training=True)
            shape = resnet_end_block_log(shape)
            mx_resnet_print(key=mlperf_log.MODEL_HP_SHORTCUT_ADD)
            shape = resnet_relu_log(shape)
            nchw_shape = (batch_size, ) + shape
            return mx.sym.Activation(data=bn3 + shortcut, act_type='relu', name=name + '_relu3'), nchw_shape

    else:
        raise NotImplementedError
Esempio n. 10
0
def fit(args, kv, model, initializer, data_loader, devs, arg_params,
        aux_params, **kwargs):
    """Train a model via ``mlperf_fit`` and save a final checkpoint.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (schedules, optimizer knobs, profiling).
    kv : kvstore
        kvstore; ``rank`` and ``num_workers`` are read for distributed setup.
    model : Module-like
        Model to train; passed as the first argument of ``mlperf_fit``.
    initializer :
        Weight initializer, forwarded to ``mlperf_fit``.
    data_loader : callable
        Function that returns the train and val data iterators.
    devs :
        Devices for training (accepted for API compatibility; not used here).
    arg_params, aux_params : dict
        Model parameters forwarded to ``mlperf_fit``.
    **kwargs :
        Accepted for API compatibility; not used here.
    """
    # Optional profiling of the parameter-server process.
    if args.profile_server_suffix:
        mx.profiler.set_config(filename=args.profile_server_suffix,
                               profile_all=True,
                               profile_process='server')
        mx.profiler.set_state(state='run', profile_process='server')

    # Optional profiling of this worker process (rank-tagged when multi-worker).
    if args.profile_worker_suffix:
        if kv.num_workers > 1:
            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
        else:
            filename = args.profile_worker_suffix
        mx.profiler.set_config(filename=filename,
                               profile_all=True,
                               profile_process='worker')
        mx.profiler.set_state(state='run', profile_process='worker')

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    epoch_size = get_epoch_size(args, kv)

    # save model (no checkpoint-per-epoch callbacks are registered here)
    epoch_end_callbacks = []

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # Total optimizer steps over the whole run, per worker.
    total_steps = math.ceil(args.num_examples * args.num_epochs /
                            kv.num_workers / args.batch_size)
    warmup_steps = get_epoch_size(args, kv) * args.warmup_epochs

    # Decay schedule: an explicit step count wins; otherwise derive it from
    # the epoch budget, optionally excluding the warmup period.
    if args.decay_steps > 0:
        decay_steps = args.decay_steps
        decay_epochs = math.ceil(args.decay_steps / get_epoch_size(args, kv))
    else:
        if args.decay_after_warmup:
            decay_steps = total_steps - warmup_steps
            decay_epochs = args.num_epochs - args.warmup_epochs
        else:
            decay_steps = total_steps
            decay_epochs = args.num_epochs

    # Hyper-parameter exploration / schedule configuration consumed by the
    # Explorer created inside mlperf_fit.
    explorer_params = {
        'burn_in_iter': args.burn_in,
        'lr_range_max': args.lr_range_max,
        'lr_range_min': args.lr_range_min,
        'wd_range_max': args.wd_range_max,
        'wd_range_min': args.wd_range_min,
        'cg_range_max': args.cg_range_max,
        'cg_range_min': args.cg_range_min,
        'start_epoch': 0,
        'end_epoch': args.num_epochs,
        'lr_decay': args.lr_decay,
        # -1 means "inherit the lr decay" for the weight-decay schedule.
        'wd_decay': args.wd_decay if args.wd_decay != -1 else args.lr_decay,
        'num_grps': args.num_grps,
        'num_cg_grps': args.num_cg_grps,
        'lr_decay_mode': args.lr_decay_mode,
        'wd_decay_mode': args.wd_decay_mode,
        'wd_step_epochs': [int(s.strip())
                           for s in args.wd_step_epochs.split(',')]
        if len(args.wd_step_epochs.strip()) else [],
        'wd_factor': args.wd_factor,
        'lr_rate': args.lr,
        'warmup_step': warmup_steps,
        'warmup_epochs': args.warmup_epochs,
        'wd_warmup': args.wd_warmup,
        'ds_upper_factor': args.ds_upper_factor,
        'ds_lower_factor': args.ds_lower_factor,
        'ds_fix_min': args.ds_fix_min,
        'ds_fix_max': args.ds_fix_max,
        'explore_freq': args.explore_freq,
        'natan_turn_epoch': args.natan_turn_epoch,
        'natan_final_ratio': args.natan_final_ratio,
        'explore_start_epoch': args.explore_start_epoch,
        'momentum': args.mom,
        'momentum_end': args.mom_end if args.mom_end else args.mom,
        'epoch_size': epoch_size,
        'smooth_decay': args.smooth_decay,
        # None means "run for the whole schedule".
        'add_one_fwd_epoch': args.add_one_fwd_epoch
        if args.add_one_fwd_epoch is not None else args.num_epochs,
        'no_augument_epoch': args.no_augument_epoch
        if args.no_augument_epoch is not None else args.num_epochs,
        'decay_after_warmup': args.decay_after_warmup,
        'end_lr_ratio': args.end_lr_ratio,
        'total_steps': total_steps,
        'decay_steps': decay_steps,
        'decay_epochs': decay_epochs,
    }

    optimizer_params = {
        'learning_rate': lr,
        # NOTE(review): wd is pre-scaled by lr here — presumably a convention
        # of the LARS optimizer used below; confirm against the optimizer code.
        'wd': args.wd * args.lr,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True
    }

    mx_resnet_print(key=mlperf_constants.OPT_NAME, val='lars')  #args.optimizer

    mx_resnet_print(key=mlperf_constants.LARS_EPSILON, val=1e-9)

    mx_resnet_print(key=mlperf_constants.LARS_OPT_WEIGHT_DECAY, val=args.wd)

    mx_resnet_print(key=mlperf_constants.LARS_OPT_LR_DECAY_POLY_POWER,
                    val=args.lr_decay)

    mx_resnet_print(key=mlperf_constants.LARS_OPT_END_LR,
                    val=args.lr * args.end_lr_ratio)

    mx_resnet_print(key=mlperf_constants.LARS_OPT_LR_DECAY_STEPS,
                    val=decay_steps)

    ##########################################################################
    # MXNet excludes BN layers from L2 Penalty by default,
    # so this won't be explicitly stated anywhere in the code
    ##########################################################################
    mx_resnet_print(key=mlperf_log.MODEL_EXCLUDE_BN_FROM_L2, val=True)

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    if args.optimizer == 'sgd':
        optimizer_params['bias_wd'] = args.bias_wd
        optimizer_params['bn_lr_decay'] = args.bn_lr_decay

    ### copy from nvidia-mxnet/3rdparty/horovod/example/mxnet/common/fit.py
    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        nworkers = kv.num_workers
        epoch_size = args.num_examples / args.batch_size / nworkers

        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs
    ###

    # evaluation metrices
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(
            mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = []
    if 'horovod' in args.kv_store:
        # if using horovod, only report on rank 0 with global batch size
        if kv.rank == 0:
            batch_end_callbacks.append(
                mx.callback.Speedometer(kv.num_workers * args.batch_size,
                                        args.disp_batches))
        mx_resnet_print(key=mlperf_constants.GLOBAL_BATCH_SIZE,
                        val=kv.num_workers * args.batch_size)
    else:
        batch_end_callbacks.append(
            mx.callback.Speedometer(args.batch_size, args.disp_batches))
        mx_resnet_print(key=mlperf_constants.GLOBAL_BATCH_SIZE,
                        val=args.batch_size)

    mx_resnet_print(key=mlperf_log.EVAL_TARGET, val=args.accuracy_threshold)

    # run
    last_epoch = mlperf_fit(
        model,
        args,
        data_loader,
        epoch_size,
        begin_epoch=0,
        num_epoch=args.num_epochs,
        eval_metric=eval_metrics,
        kvstore=kv,
        optimizer=args.optimizer,
        optimizer_params=optimizer_params,
        explorer=args.explorer,
        explorer_params=explorer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        batch_end_callback=batch_end_callbacks,
        epoch_end_callback=epoch_end_callbacks,
        allow_missing=True,
        eval_offset=args.eval_offset,
        eval_period=args.eval_period,
        accuracy_threshold=args.accuracy_threshold)

    # Save the final checkpoint (rank 0 only when running under horovod).
    if ('horovod' not in args.kv_store) or kv.rank == 0:
        arg_params, aux_params = model.get_params()
        model.set_params(arg_params, aux_params)
        model.save_checkpoint('MLPerf-RN50v15',
                              last_epoch,
                              save_optimizer_states=False)

    # When using horovod, ensure all ops scheduled by the engine complete before exiting
    if 'horovod' in args.kv_store:
        mx.ndarray.waitall()

    # BUG FIX: these previously re-set the profiler to 'run'; profiling must
    # be stopped here so the collected data is flushed to the configured files.
    if args.profile_server_suffix:
        mx.profiler.set_state(state='stop', profile_process='server')
    if args.profile_worker_suffix:
        mx.profiler.set_state(state='stop', profile_process='worker')
Esempio n. 11
0
def mlperf_fit(self,
               args,
               data_loader,
               epoch_size,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               explorer='linear',
               explorer_params=None,
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):
    """MLPerf-instrumented training loop.

    ``self`` is a Module-like object (this function is called with the model
    as its first argument). Trains for up to ``num_epoch`` epochs, emitting
    MLPerf log markers, exploring lr/wd coefficients via an ``Explorer``, and
    evaluating every ``eval_period`` epochs starting at ``eval_offset``.

    Returns
    -------
    int
        The epoch index at which validation accuracy exceeded
        ``accuracy_threshold`` (RUN_STOP status 'success'), or ``num_epoch``
        if the threshold was never reached (status 'aborted').
    """

    assert num_epoch is not None, 'please specify number of epochs'

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    # The 'explorer' name is rebound from the strategy string to the object.
    explorer = Explorer.create_explorer(name=explorer,
                                        optimizer=self._optimizer,
                                        explorer_params=explorer_params)
    #This mxnet can not use several optimizers without sgd series
    explorer.set_best_coeff(0)
    explorer.set_best_wd_coeff(0)
    explorer.set_best_cg(0)
    exp_freq = explorer_params['explore_freq']
    exp_start_epoch = explorer_params['explore_start_epoch']

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    # Derive rank/size from Open MPI environment; default to single process.
    # BUG FIX: was a bare `except:` — narrowed to the exceptions os.environ
    # lookup / int() can raise, so unrelated errors are not swallowed.
    try:
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except (KeyError, ValueError):
        world_rank = 0
        world_size = 1

    use_cval_data = explorer_params['add_one_fwd_epoch'] < num_epoch \
        or explorer_params['no_augument_epoch'] < num_epoch

    best_rank = 0
    self.prepare_states()

    mx_resnet_print(key=mlperf_constants.INIT_STOP, sync=True)
    mx_resnet_print(key=mlperf_constants.RUN_START, sync=True)

    # data iterators
    (train_data, eval_data, cval_data) = data_loader(args, kvstore)
    if 'dist' in args.kv_store and not 'async' in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # resize train iter to ensure each machine has same number of batches per epoch
        # if not, dist_sync can hang at the end with one machine waiting for other machines
        if not args.use_dali:
            # BUG FIX: result was previously assigned to an unused name
            # `train`, silently discarding the resize.
            train_data = mx.io.ResizeIter(train_data, epoch_size)

    # First MLPerf "block" spans epochs up to the first evaluation.
    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_period)
    if block_epoch_count < 0:
        block_epoch_count += eval_period
    mx_resnet_print(key=mlperf_constants.BLOCK_START,
                    metadata={
                        'first_epoch_num': block_epoch_start + 1,
                        'epoch_count': block_epoch_count
                    })
    ################################################################################
    # training loop
    ################################################################################

    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_constants.EPOCH_START,
                        metadata={'epoch_num': epoch + 1})
        tic = time.time()
        eval_metric.reset()
        nbatch = 0

        # Late epochs can switch to the non-augmented ("cval") iterator.
        use_normal_data_batch = epoch < explorer_params['no_augument_epoch']
        if not use_normal_data_batch:
            if world_rank == 0:
                self.logger.info('use non-augumented batch')

        end_of_batch = False

        if use_normal_data_batch:
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
        else:
            cval_iter = iter(cval_data)
            next_cval_batch = next(cval_iter)

        smooth_decay = explorer_params['smooth_decay']

        # Per-epoch decay is applied here; per-iteration ("smooth") decay is
        # applied inside the batch loop instead.
        if not smooth_decay:
            explorer.apply_lr_decay_epoch(epoch)
            explorer.apply_wd_decay_epoch(epoch)
        explorer.set_mom(epoch)

        while not end_of_batch:
            if use_normal_data_batch:
                data_batch = next_data_batch
            else:
                cval_batch = next_cval_batch
            if monitor is not None:
                monitor.tic()

            if use_normal_data_batch:
                self.forward_backward(data_batch)
            else:
                self.forward_backward(cval_batch)

            if smooth_decay:
                explorer.apply_lr_decay_iter()
                explorer.apply_wd_decay_iter()
            explorer.apply_wd_warmup()
            explorer.apply_burn_in()

            # Explore per-rank coefficient candidates on the very first batch
            # and then periodically once exploration has started.
            use_explorer = (epoch == 0
                            and nbatch == 0) or (epoch >= exp_start_epoch
                                                 and nbatch % exp_freq == 0)
            if use_explorer:
                explorer.set_tmp_coeff(world_rank)
                explorer.set_tmp_wd_coeff(world_rank)
                explorer.set_tmp_cg(world_rank)

            explorer.set_best_coeff(0)
            explorer.set_best_wd_coeff(0)
            explorer.set_best_cg(world_rank)
            self.update()

            if use_normal_data_batch:
                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)
            else:
                if isinstance(cval_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in cval_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, cval_batch.label)

            if use_normal_data_batch:
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True
            else:
                try:
                    # pre fetch next cval batch
                    next_cval_batch = next(cval_iter)
                except StopIteration:
                    end_of_batch = True

            if use_normal_data_batch:
                if not end_of_batch:
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
            else:
                if not end_of_batch:
                    self.prepare(next_cval_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        mx_resnet_print(key=mlperf_constants.EPOCH_STOP,
                        metadata={"epoch_num": epoch + 1})
        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        #arg_params, aux_params = self.get_params()
        #self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch >= eval_offset and (
                epoch - eval_offset) % eval_period == 0:
            mx_resnet_print(key=mlperf_constants.EVAL_START,
                            metadata={'epoch_num': epoch + 1})
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                         name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)
            res = dict(res)

            # Global accuracy = sum(correct) / sum(total) across all ranks.
            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_constants.EVAL_STOP,
                            metadata={'epoch_num': epoch + 1})

            mx_resnet_print(key=mlperf_constants.EVAL_ACCURACY,
                            val=acc,
                            metadata={'epoch_num': epoch + 1})

            mx_resnet_print(
                key=mlperf_constants.BLOCK_STOP,
                metadata={'first_epoch_num': block_epoch_start + 1})
            if acc > accuracy_threshold:
                mx_resnet_print(key=mlperf_constants.RUN_STOP,
                                metadata={'status': 'success'})

                return epoch

            # Open the next MLPerf block if there are epochs left.
            if epoch < (num_epoch - 1):
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_period:
                    block_epoch_count = eval_period
                mx_resnet_print(key=mlperf_constants.BLOCK_START,
                                metadata={
                                    'first_epoch_num': block_epoch_start + 1,
                                    'epoch_count': block_epoch_count
                                })

        # end of 1 epoch, reset the data-iter for another epoch
        if use_normal_data_batch:
            train_data.reset()
        else:
            cval_data.reset()

    # Threshold never reached within the epoch budget.
    mx_resnet_print(key=mlperf_constants.RUN_STOP,
                    metadata={'status': 'aborted'})
    return num_epoch
Esempio n. 12
0
    )
    args = parser.parse_args()


    # select gpu for horovod process
    if 'horovod' in args.kv_store:
        args.gpus = _get_gpu(args.gpus)

    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # load network
    from importlib import import_module
    net = import_module('symbols.'+args.network)

    mx_resnet_print(key=mlperf_log.EVAL_EPOCH_OFFSET,
                    val=args.eval_offset)

    mx_resnet_print(key=mlperf_log.RUN_START, sync=True)
    if args.seed is None:
        args.seed = int(random.SystemRandom().randint(0, 2**16 - 1))
    
    if 'horovod' in args.kv_store:
        all_seeds = np.random.randint(2**16, size=(int(os.environ['OMPI_COMM_WORLD_SIZE'])))
        args.seed = int(all_seeds[int(os.environ['OMPI_COMM_WORLD_RANK'])])

    mx_resnet_print(key=mlperf_log.RUN_SET_RANDOM_SEED, val=args.seed, uniq=False)
    
    random.seed(args.seed)
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
Esempio n. 13
0
def fit(args, kv, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    kv : kvstore handle used for distributed / horovod coordination
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    kwargs : may carry 'arg_params' / 'aux_params' to initialize the model from
    """
    # optionally profile the parameter-server process
    if args.profile_server_suffix:
        mx.profiler.set_config(filename=args.profile_server_suffix,
                               profile_all=True,
                               profile_process='server')
        mx.profiler.set_state(state='run', profile_process='server')

    # optionally profile this worker process (rank-tagged when distributed)
    if args.profile_worker_suffix:
        if kv.num_workers > 1:
            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
        else:
            filename = args.profile_worker_suffix
        mx.profiler.set_config(filename=filename,
                               profile_all=True,
                               profile_process='worker')
        mx.profiler.set_state(state='run', profile_process='worker')

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    epoch_size = get_epoch_size(args, kv)

    # data iterators
    (train, val) = data_loader(args, kv)
    if 'dist' in args.kv_store and 'async' not in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # resize train iter to ensure each machine has same number of batches per epoch
        # if not, dist_sync can hang at the end with one machine waiting for other machines
        if not args.use_dali:
            train = mx.io.ResizeIter(train, epoch_size)

    # load model parameters if the caller supplied them; otherwise start fresh
    # (the original code carried an unreachable `sym is not None` assert here)
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        arg_params, aux_params = None, None

    # no per-epoch checkpointing callbacks by default
    epoch_end_callbacks = []

    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # create model
    model = mx.mod.Module(context=devs, symbol=network)

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True
    }

    mx_resnet_print(key=mlperf_log.OPT_NAME, val=args.optimizer)

    mx_resnet_print(key=mlperf_log.OPT_LR, val=lr)

    mx_resnet_print(key=mlperf_log.OPT_MOMENTUM, val=args.mom)

    mx_resnet_print(key=mlperf_log.MODEL_L2_REGULARIZATION, val=args.wd)

    ##########################################################################
    # MXNet excludes BN layers from L2 Penalty by default,
    # so this won't be explicitly stated anywhere in the code
    ##########################################################################
    mx_resnet_print(key=mlperf_log.MODEL_EXCLUDE_BN_FROM_L2, val=True)

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(
            mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = []
    if 'horovod' in args.kv_store:
        # if using horovod, only report on rank 0 with global batch size
        if kv.rank == 0:
            batch_end_callbacks.append(
                mx.callback.Speedometer(kv.num_workers * args.batch_size,
                                        args.disp_batches))
    else:
        batch_end_callbacks.append(
            mx.callback.Speedometer(args.batch_size, args.disp_batches))

    mx_resnet_print(key=mlperf_log.EVAL_TARGET, val=args.accuracy_threshold)

    # run; mlperf_fit returns the epoch at which the accuracy target was hit,
    # or num_epochs if it never was
    last_epoch = mlperf_fit(
        model,
        train,
        begin_epoch=0,
        num_epoch=args.num_epochs,
        eval_data=val,
        eval_metric=eval_metrics,
        kvstore=kv,
        optimizer=args.optimizer,
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        batch_end_callback=batch_end_callbacks,
        epoch_end_callback=epoch_end_callbacks,
        allow_missing=True,
        eval_offset=args.eval_offset,
        eval_period=args.eval_period,
        accuracy_threshold=args.accuracy_threshold)

    mx_resnet_print(key=mlperf_log.RUN_STOP, sync=True)

    # only one rank writes the checkpoint when running under horovod
    if ('horovod' not in args.kv_store) or kv.rank == 0:
        model.save_checkpoint('MLPerf-RN50v15',
                              last_epoch,
                              save_optimizer_states=False)

    # When using horovod, ensure all ops scheduled by the engine complete before exiting
    if 'horovod' in args.kv_store:
        mx.ndarray.waitall()

    # BUGFIX: stop the profiler on the way out — the original code passed
    # state='run' here a second time, so profiling was never stopped.
    if args.profile_server_suffix:
        mx.profiler.set_state(state='stop', profile_process='server')
    if args.profile_worker_suffix:
        mx.profiler.set_state(state='stop', profile_process='worker')
Esempio n. 14
0
def mlperf_fit(self,
               train_data,
               eval_data=None,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):
    """Train ``self`` (an MXNet Module) with MLPerf-style logging.

    Largely mirrors ``mx.mod.BaseModule.fit`` but adds:

    * ``mx_resnet_print`` markers for the train loop, each epoch and each
      evaluation pass;
    * evaluation only on epochs where ``epoch % eval_period == eval_offset``;
    * a global top-1 accuracy computed from the CorrectCount / TotalCount
      metrics and combined across workers with ``all_reduce``;
    * early return of the epoch index as soon as that accuracy exceeds
      ``accuracy_threshold``.

    Returns ``num_epoch`` if the threshold is never reached.
    """

    assert num_epoch is not None, 'please specify number of epochs'

    # allocate executors for the training data shapes
    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    # wrap in a CompositeEvalMetric so extra metrics can be appended below
    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    # raw counts are needed so accuracy can be all-reduced across workers
    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    mx_resnet_print(key=mlperf_log.TRAIN_LOOP)
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_log.TRAIN_EPOCH, val=epoch)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        # one-batch lookahead: fetch batch N+1 while batch N is in flight
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                # current batch is the last one; finish it, then exit the loop
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            # only rank 0 logs timing to keep distributed logs readable
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set, only on the scheduled epochs
        if eval_data and epoch % eval_period == eval_offset:
            mx_resnet_print(key=mlperf_log.EVAL_START)
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                         name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)
            res = dict(res)

            # global accuracy = sum(correct) / sum(total) across all workers
            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_log.EVAL_ACCURACY,
                            val={
                                "epoch": epoch,
                                "value": acc
                            })
            mx_resnet_print(key=mlperf_log.EVAL_STOP)
            # early stopping: report the epoch that reached the target
            if acc > accuracy_threshold:
                return epoch

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    return num_epoch
Esempio n. 15
0
            mx.ndarray.random.normal(0, 0.01, out=arg)
        else:
            return super()._init_weight(name, arg)


class BNZeroInit(mx.init.Xavier):
    """Xavier initializer that zero-initializes the gamma of every
    BatchNorm whose name ends in ``bn3_gamma``; all other gammas
    start at one."""

    def _init_gamma(self, name, arg):
        # Zeroing the last BN scale in a bottleneck makes the residual
        # branch start as identity; every other gamma keeps the usual 1.
        arg[:] = 0.0 if name.endswith("bn3_gamma") else 1.0


if __name__ == '__main__':
    LOGGER.propagate = False
    mx_resnet_print(key=mlperf_constants.INIT_START, uniq=False)
    # parse args
    parser = argparse.ArgumentParser(
        description="MLPerf RN50v1.5 training script",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_general_args(parser)
    fit.add_fit_args(parser)
    dali.add_dali_args(parser)

    parser.set_defaults(
        # network
        network='resnet-v1b',
        num_layers=50,

        # data
        resize=256,
Esempio n. 16
0
    def __init__(self, batch_size, num_threads, device_id, rec_path, idx_path,
                 shard_id, num_shards, crop_shape, 
                 min_random_area, max_random_area,
                 min_random_aspect_ratio, max_random_aspect_ratio,
                 nvjpeg_padding, prefetch_queue=3,
                 seed=12,
                 output_layout=types.NCHW, pad_output=True, dtype='float16',
                 mlperf_print=True):
        """Build the DALI training pipeline:
        MXNet RecordIO reader -> GPU nvJPEG decode -> random resized crop
        -> crop/mirror/normalize (with a coin flip driving the mirror).

        Parameters
        ----------
        rec_path / idx_path : paths to the .rec data file and its index
        shard_id / num_shards : this worker's shard of the dataset
        crop_shape : output crop size fed to RandomResizedCrop and CMN
        min/max_random_area, min/max_random_aspect_ratio :
            sampling ranges for the random resized crop
        nvjpeg_padding : device/host memory padding for the nvJPEG decoder
        mlperf_print : when True, emit the MLPerf input-pipeline log keys
        """
        super(HybridTrainPipe, self).__init__(
                batch_size, num_threads, device_id, 
                seed = seed + device_id, 
                prefetch_queue_depth = prefetch_queue)

        if mlperf_print:
            # Shuffling is done inside ops.MXNetReader
            mx_resnet_print(key=mlperf_log.INPUT_ORDER)

        self.input = ops.MXNetReader(path = [rec_path], index_path=[idx_path],
                                     random_shuffle=True, shard_id=shard_id, num_shards=num_shards)

        # decode on "mixed" device: host bitstream parsing + GPU decoding
        self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB,
                                        device_memory_padding = nvjpeg_padding,
                                        host_memory_padding = nvjpeg_padding)

        self.rrc = ops.RandomResizedCrop(device = "gpu",
                                         random_area = [
                                             min_random_area,
                                             max_random_area],
                                         random_aspect_ratio = [
                                             min_random_aspect_ratio,
                                             max_random_aspect_ratio],
                                         size = crop_shape)

        # normalize with the dataset mean/std; optionally pad channels and
        # cast to fp16 depending on `dtype`/`pad_output`
        self.cmnp = ops.CropMirrorNormalize(device = "gpu",
                                            output_dtype = types.FLOAT16 if dtype == 'float16' else types.FLOAT,
                                            output_layout = output_layout,
                                            crop = crop_shape,
                                            pad_output = pad_output,
                                            image_type = types.RGB,
                                            mean = _mean_pixel,
                                            std =  _std_pixel)
        # 50/50 horizontal-flip decision consumed by cmnp's mirror input
        self.coin = ops.CoinFlip(probability = 0.5)

        if mlperf_print:
            # log the augmentation hyper-parameters for MLPerf compliance
            mx_resnet_print(
                    key=mlperf_log.INPUT_CROP_USES_BBOXES,
                    val=False)
            mx_resnet_print(
                    key=mlperf_log.INPUT_DISTORTED_CROP_RATIO_RANGE,
                    val=(min_random_aspect_ratio,
                         max_random_aspect_ratio))
            mx_resnet_print(
                    key=mlperf_log.INPUT_DISTORTED_CROP_AREA_RANGE,
                    val=(min_random_area,
                         max_random_area))
            mx_resnet_print(
                    key=mlperf_log.INPUT_MEAN_SUBTRACTION,
                    val=_mean_pixel)
            mx_resnet_print(
                    key=mlperf_log.INPUT_RANDOM_FLIP)
Esempio n. 17
0
        if name.startswith("fc"):
            mx.ndarray.random.normal(0, 0.01, out=arg)
        else:
            return super()._init_weight(name, arg)


class BNZeroInit(mx.init.Xavier):
    """Xavier-based initializer in which gammas of BatchNorm layers named
    ``*bn3_gamma`` start at zero and all remaining gammas start at one."""

    def _init_gamma(self, name, arg):
        # Only the final BN of each bottleneck ("bn3") is zeroed;
        # everything else gets the standard scale of 1.
        if not name.endswith("bn3_gamma"):
            arg[:] = 1.0
        else:
            arg[:] = 0.0


if __name__ == '__main__':
    mx_resnet_print(key=mlperf_constants.INIT_START, uniq=False)
    # parse args
    parser = argparse.ArgumentParser(
        description="MLPerf RN50v1.5 training script",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_general_args(parser)
    fit.add_fit_args(parser)
    dali.add_dali_args(parser)

    parser.set_defaults(
        # network
        network='resnet-v1b',
        num_layers=50,

        # data
        resize=256,
def resnet(units,
           num_stages,
           filter_list,
           num_classes,
           image_shape,
           bottle_neck=True,
           workspace=256,
           dtype='float32',
           memonger=False,
           input_layout='NCHW',
           conv_layout='NCHW',
           batchnorm_layout='NCHW',
           pooling_layout='NCHW',
           verbose=False,
           cudnn_bn_off=False,
           bn_eps=2e-5,
           bn_mom=0.9,
           conv_algo=-1,
           fuse_bn_relu=False,
           fuse_bn_add_relu=False,
           force_tensor_core=False,
           use_dali=True,
           smooth_alpha=0.0):
    """Return ResNet symbol of
    Parameters
    ----------
    units : list
        Number of units in each stage
    num_stages : int
        Number of stage
    filter_list : list
        Channel size of each stage
    num_classes : int
        Output size of symbol
    dataset : str
        Dataset type, only cifar10 and imagenet are supported
    workspace : int
        Workspace used in convolution operator
    dtype : str
        Precision (float32 or float16)
    memonger : boolean
        Activates "memory monger" to reduce the model's memory footprint
    input_layout : str
        interpretation (e.g. NCHW vs NHWC) of data provided by the i/o pipeline (may introduce transposes
        if in conflict with 'layout' above)
    conv_layout : str
        interpretation (e.g. NCHW vs NHWC) of data for convolution operation.
    batchnorm_layout : str
        directs which kernel performs the batchnorm (may introduce transposes if in conflict with 'conv_layout' above)
    pooling_layout : str
        directs which kernel performs the pooling (may introduce transposes if in conflict with 'conv_layout' above)
    bn_eps / bn_mom : float
        epsilon and momentum passed to every BatchNorm in the network
    fuse_bn_relu / fuse_bn_add_relu : boolean
        use fused BN+ReLU (and BN+add+ReLU) kernels instead of separate ops
    smooth_alpha : float
        label-smoothing factor forwarded to SoftmaxOutput
    """

    # when BN+ReLU is fused, the BN op itself applies the activation
    act = 'relu' if fuse_bn_relu else None
    num_unit = len(units)
    assert (num_unit == num_stages)
    data = mx.sym.Variable(name='data')
    if not use_dali:
        # double buffering of data
        if dtype == 'float32':
            data = mx.sym.identity(data=data, name='id')
        else:
            if dtype == 'float16':
                data = mx.sym.Cast(data=data, dtype=np.float16)
    (nchannel, height, width) = image_shape

    # Insert transpose as needed to get the input layout to match the desired processing layout
    data = transform_layout(data, input_layout, conv_layout)

    if height <= 32:  # such as cifar10
        # small-image stem: single 3x3 conv, no downsampling, no max pool
        body = mx.sym.Convolution(data=data,
                                  num_filter=filter_list[0],
                                  kernel=(3, 3),
                                  stride=(1, 1),
                                  pad=(1, 1),
                                  no_bias=True,
                                  name="conv0",
                                  workspace=workspace,
                                  layout=conv_layout,
                                  cudnn_algo_verbose=verbose,
                                  cudnn_algo_fwd=conv_algo,
                                  cudnn_algo_bwd_data=conv_algo,
                                  cudnn_algo_bwd_filter=conv_algo,
                                  cudnn_tensor_core_only=force_tensor_core)
        # Is this BatchNorm supposed to be here?
        body = batchnorm(data=body,
                         io_layout=conv_layout,
                         batchnorm_layout=batchnorm_layout,
                         fix_gamma=False,
                         eps=bn_eps,
                         momentum=bn_mom,
                         name='bn0',
                         cudnn_off=cudnn_bn_off)
    else:  # often expected to be 224 such as imagenet
        # `shape` tracks the activation shape purely for MLPerf logging;
        # it does not feed back into symbol construction
        shape = image_shape
        mx_resnet_print(key=mlperf_log.MODEL_HP_INITIAL_SHAPE, val=shape)
        # imagenet stem: 7x7/2 conv -> BN (-> ReLU) -> 3x3/2 max pool
        body = mx.sym.Convolution(data=data,
                                  num_filter=filter_list[0],
                                  kernel=(7, 7),
                                  stride=(2, 2),
                                  pad=(3, 3),
                                  no_bias=True,
                                  name="conv0",
                                  workspace=workspace,
                                  layout=conv_layout,
                                  cudnn_algo_verbose=verbose,
                                  cudnn_algo_fwd=conv_algo,
                                  cudnn_algo_bwd_data=conv_algo,
                                  cudnn_algo_bwd_filter=conv_algo,
                                  cudnn_tensor_core_only=force_tensor_core)
        shape = resnet_conv2d_log(shape, 2, filter_list[0],
                                  mlperf_log.TRUNCATED_NORMAL, False)
        body = batchnorm(data=body,
                         io_layout=conv_layout,
                         batchnorm_layout=batchnorm_layout,
                         fix_gamma=False,
                         eps=bn_eps,
                         momentum=bn_mom,
                         name='bn0',
                         cudnn_off=cudnn_bn_off,
                         act_type=act)
        shape = resnet_batchnorm_log(shape,
                                     momentum=bn_mom,
                                     eps=bn_eps,
                                     center=True,
                                     scale=True,
                                     training=True)
        # explicit ReLU only when not already fused into the BN above
        if not fuse_bn_relu:
            body = mx.sym.Activation(data=body, act_type='relu', name='relu0')

        body = pooling(data=body,
                       io_layout=conv_layout,
                       pooling_layout=pooling_layout,
                       kernel=(3, 3),
                       stride=(2, 2),
                       pad=(1, 1),
                       pool_type='max')
        shape = resnet_max_pool_log(shape, 2)

    # build the residual stages: the first unit of each stage (except stage 0)
    # downsamples with stride 2, the remaining units keep the resolution
    for i in range(num_stages):
        body, shape = residual_unit(body,
                                    shape,
                                    filter_list[i + 1],
                                    (1 if i == 0 else 2, 1 if i == 0 else 2),
                                    False,
                                    name='stage%d_unit%d' % (i + 1, 1),
                                    bottle_neck=bottle_neck,
                                    workspace=workspace,
                                    memonger=memonger,
                                    conv_layout=conv_layout,
                                    batchnorm_layout=batchnorm_layout,
                                    verbose=verbose,
                                    cudnn_bn_off=cudnn_bn_off,
                                    bn_eps=bn_eps,
                                    bn_mom=bn_mom,
                                    conv_algo=conv_algo,
                                    fuse_bn_relu=fuse_bn_relu,
                                    fuse_bn_add_relu=fuse_bn_add_relu,
                                    cudnn_tensor_core_only=force_tensor_core)
        for j in range(units[i] - 1):
            body, shape = residual_unit(
                body,
                shape,
                filter_list[i + 1], (1, 1),
                True,
                name='stage%d_unit%d' % (i + 1, j + 2),
                bottle_neck=bottle_neck,
                workspace=workspace,
                memonger=memonger,
                conv_layout=conv_layout,
                batchnorm_layout=batchnorm_layout,
                verbose=verbose,
                cudnn_bn_off=cudnn_bn_off,
                bn_eps=bn_eps,
                bn_mom=bn_mom,
                conv_algo=conv_algo,
                fuse_bn_relu=fuse_bn_relu,
                fuse_bn_add_relu=fuse_bn_add_relu,
                cudnn_tensor_core_only=force_tensor_core)
    # Although kernel is not used here when global_pool=True, we should put one
    pool1 = pooling(data=body,
                    io_layout=conv_layout,
                    pooling_layout=pooling_layout,
                    global_pool=True,
                    kernel=(7, 7),
                    pool_type='avg',
                    name='pool1')
    flat = mx.sym.Flatten(data=pool1)
    # NOTE(review): (shape[0]) is NOT a tuple — this leaves `shape` as the
    # scalar channel count. Presumably intentional for resnet_dense_log,
    # but confirm; a one-element tuple would be (shape[0],).
    shape = (shape[0])
    fc1 = mx.sym.FullyConnected(data=flat,
                                num_hidden=num_classes,
                                name='fc1',
                                cublas_algo_verbose=verbose)
    shape = resnet_dense_log(shape, num_classes)

    mx_resnet_print(key=mlperf_log.MODEL_HP_FINAL_SHAPE, val=shape)
    # cast logits back to fp32 so the softmax/loss runs in full precision
    if dtype == 'float16':
        fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)

    ##########################################################################
    # MXNet computes Cross Entropy loss gradients without explicitly computing
    # the value of loss function.
    # Take a look here:
    # https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.SoftmaxOutput
    # for further details
    ##########################################################################
    mx_resnet_print(key=mlperf_log.MODEL_HP_LOSS_FN, val=mlperf_log.CCE)

    return mx.sym.SoftmaxOutput(data=fc1,
                                name='softmax',
                                smooth_alpha=smooth_alpha)