Exemplo n.º 1
0
def train():
    """Train YOLOv3-DarkNet53 (optionally quantization-aware) with a manual loop.

    All settings come from ``parse_args()``. The function:

    * optionally initialises distributed training and records the rank and
      group size on ``args``;
    * creates a timestamped output directory under ``args.ckpt_path`` and a
      logger via ``get_logger(outputs_dir, rank)``;
    * builds ``YOLOV3DarkNet53`` and, when ``args.resume_yolov3`` is set,
      loads that checkpoint after skipping ``moments.`` / ``global_`` /
      ``learning_rate`` / ``momentum`` entries (presumably optimizer state)
      and renaming legacy parameter names found under ``yolo_network.``;
    * converts the network to a quantization-aware network when
      ``config.quantization_aware`` is true;
    * builds the dataset, learning-rate schedule and ``Momentum`` optimizer,
      then runs one training step per dataset batch while driving a
      ``ModelCheckpoint`` callback by hand and logging loss / throughput.

    Raises:
        ValueError: if ``data_size / per_batch_size / group_size`` truncates
            to zero, i.e. the dataset cannot fill one step per epoch (the
            epoch and checkpoint arithmetic below would divide by zero).
        NotImplementedError: if ``args.lr_scheduler`` names an unsupported
            schedule.
    """
    args = parse_args()

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # Checkpoints are written only by rank 0 when is_save_on_master is set,
    # otherwise by every rank (compatible with model parallel).
    save_on_this_rank = not args.is_save_on_master or args.rank == 0
    args.rank_save_ckpt_flag = 1 if save_on_this_rank else 0

    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir,
                            is_detail=True,
                            is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      mirror_mean=True,
                                      device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.resume_yolov3:
        # Checkpoint entries with these prefixes are not loaded into the
        # network (they look like optimizer state - confirm).
        skipped_prefixes = ('moments.', 'global_', 'learning_rate', 'momentum')
        wrapper_prefix = 'yolo_network.'
        # Trailing-suffix renames for names found under ``yolo_network.``,
        # in priority order; at most one rule applies to a given name.
        legacy_suffixes = (
            ('1.beta', 'batchnorm.beta'),
            ('1.gamma', 'batchnorm.gamma'),
            ('1.moving_mean', 'batchnorm.moving_mean'),
            ('1.moving_variance', 'batchnorm.moving_variance'),
            ('0.weight', 'conv.weight'),
            ('.weight', '.conv.weight'),
            ('.bias', '.conv.bias'),
        )

        def rename_legacy_param(name):
            """Rewrite the first matching legacy suffix at the end of *name*.

            Only the trailing suffix is replaced; ``str.replace`` would also
            rewrite any earlier occurrence of the same substring.
            """
            for suffix, replacement in legacy_suffixes:
                if name.endswith(suffix):
                    return name[:-len(suffix)] + replacement
            return name

        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            args.logger.info('ckpt param name = {}'.format(key))
            if key.startswith(skipped_prefixes):
                continue
            if key.startswith(wrapper_prefix):
                key_new = rename_legacy_param(key[len(wrapper_prefix):])
                param_dict_new[key_new] = values
                args.logger.info('in resume {}'.format(key_new))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        # Log every network parameter and flag those the checkpoint lacks.
        for _, param in network.parameters_and_names():
            args.logger.info('network param name = {}'.format(param.name))
            if param.name not in param_dict_new:
                args.logger.info('not match param name = {}'.format(
                    param.name))
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    config = ConfigYOLOV3DarkNet53()
    # convert fusion network to quantization aware network
    if config.quantization_aware:
        network = quant.convert_quant_network(network,
                                              bn_fold=True,
                                              per_channel=[True, False],
                                              symmetric=[True, False])

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]

    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root,
                                        anno_path=args.annFile,
                                        is_training=True,
                                        batch_size=args.per_batch_size,
                                        max_epoch=args.max_epoch,
                                        device_num=args.group_size,
                                        rank=args.rank,
                                        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)
    if args.steps_per_epoch <= 0:
        # The epoch / checkpoint arithmetic below divides by this value.
        raise ValueError(
            'steps_per_epoch is 0: data_size={} is too small for '
            'per_batch_size={} and group_size={}'.format(
                data_size, args.per_batch_size, args.group_size))

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(
            args.lr,
            args.lr_epochs,
            args.steps_per_epoch,
            args.warmup_epochs,
            args.max_epoch,
            gamma=args.lr_gamma,
        )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
                                        args.warmup_epochs, args.max_epoch,
                                        args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr, args.steps_per_epoch,
                                           args.warmup_epochs, args.max_epoch,
                                           args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr, args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch, args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save: the callback is driven manually from the loop below
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    # The target-building function depends only on group_size, which does
    # not change during training, so choose it once.
    preprocess_true_box = (batch_preprocess_true_box if args.group_size == 1
                           else batch_preprocess_true_box_single)

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        # Presumably NCHW, so [2:4] is (H, W) - TODO confirm.
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)

        images = Tensor(images)
        annos = data["annotation"]
        targets = preprocess_true_box(annos, config, input_shape)
        (batch_y_true_0, batch_y_true_1, batch_y_true_2,
         batch_gt_box0, batch_gt_box1, batch_gt_box2) = [Tensor(t) for t in targets]

        # The network receives the spatial shape in reversed order.
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        # When profiling, analyse after 11 steps and stop training.
        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
Exemplo n.º 2
0
    args.steps_per_epoch = args.steps_per_epoch // args.max_epoch
    args.logger.info('Finish loading dataset')

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'multistep':
        lr_fun = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma,
                             args.steps_per_epoch, args.max_epoch,
                             args.warmup_epochs)
        lr = lr_fun.get_lr()
    elif args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
                                        args.warmup_epochs, args.max_epoch,
                                        args.t_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_v2(args.lr, args.steps_per_epoch,
                                           args.warmup_epochs, args.max_epoch,
                                           args.t_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr, args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch, args.t_max,
                                               args.eta_min)
Exemplo n.º 3
0
def train():
    """Train YOLOv3-DarkNet53 on Ascend or GPU with a hand-written step loop.

    All settings come from ``parse_args()``. The function:

    * sets the MindSpore context to graph mode, using the device id from the
      ``DEVICE_ID`` environment variable (0 when unset or empty);
    * optionally initialises distributed training (``init()`` on Ascend,
      ``init("nccl")`` otherwise);
    * builds the network, optionally loading a pre-trained backbone and a
      resume checkpoint (``moments.`` entries skipped, the ``yolo_network.``
      prefix stripped);
    * trains it. On GPU the train cell comes from ``amp.build_train_network``
      (level ``O2``, fixed loss scale 1.0); otherwise ``TrainingWrapper`` is
      used.

    Ground-truth targets are read directly from the dataset columns
    ``bbox1..3`` and ``gt_box1..3``.

    Raises:
        ValueError: if ``data_size / per_batch_size / group_size`` truncates
            to zero (the epoch and checkpoint arithmetic would divide by zero).
        NotImplementedError: for an unsupported ``args.lr_scheduler``.
    """
    args = parse_args()

    device_id_env = os.getenv('DEVICE_ID')
    devid = int(device_id_env) if device_id_env else 0
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=True, device_id=devid)

    # init distributed
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
        else:
            init("nccl")
        args.rank = get_rank()
        args.group_size = get_group_size()

    # Checkpoints are written only by rank 0 when is_save_on_master is set,
    # otherwise by every rank (compatible with model parallel).
    save_on_this_rank = not args.is_save_on_master or args.rank == 0
    args.rank_save_ckpt_flag = 1 if save_on_this_rank else 0

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        wrapper_prefix = 'yolo_network.'
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            # 'moments.' entries are not loaded (presumably optimizer state).
            if key.startswith('moments.'):
                continue
            if key.startswith(wrapper_prefix):
                param_dict_new[key[len(wrapper_prefix):]] = values
            else:
                param_dict_new[key] = values
            args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)
    if args.steps_per_epoch <= 0:
        # The epoch / checkpoint arithmetic below divides by this value.
        raise ValueError(
            'steps_per_epoch is 0: data_size={} is too small for '
            'per_batch_size={} and group_size={}'.format(
                data_size, args.per_batch_size, args.group_size))

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    # Automatic mixed precision is used only when the context targets GPU.
    enable_amp = context.get_context("device_target") == "GPU"
    if enable_amp:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
        network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
                                          level="O2", keep_batchnorm_fp32=True)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt)
        network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save: the callback is driven manually from the loop below
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    for i, data in enumerate(data_loader):
        images = data["image"]
        # Presumably NCHW, so [2:4] is (H, W) - TODO confirm.
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))

        images = Tensor(images)

        batch_y_true_0 = Tensor(data['bbox1'])
        batch_y_true_1 = Tensor(data['bbox2'])
        batch_y_true_2 = Tensor(data['bbox3'])
        batch_gt_box0 = Tensor(data['gt_box1'])
        batch_gt_box1 = Tensor(data['gt_box2'])
        batch_gt_box2 = Tensor(data['gt_box3'])

        # The network receives the spatial shape in reversed order.
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        # When profiling, analyse after 11 steps and stop training.
        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
Exemplo n.º 4
0
def train():
    """Train YOLOv3-DarkNet53 with data and checkpoints staged through ``mox``.

    All settings come from ``parse_args()``. The function:

    * copies the pre-trained backbone (only when ``args.pretrained_backbone``
      is set) and the dataset from remote URLs into ``/cache`` with
      ``mox.file.copy_parallel``;
    * trains with a hand-written step loop that drives a ``ModelCheckpoint``
      callback writing to ``/cache/ckpt_file``;
    * finally copies that directory to ``args.train_url``.

    The dataset is repeated for ``args.epoch_size`` epochs, and every
    learning-rate schedule is built for the same epoch count. Assuming each
    schedule yields one value per step, ``lr[i]`` then matches training step
    ``i``.

    Raises:
        ValueError: if ``data_size / per_batch_size / group_size`` truncates
            to zero (the epoch and checkpoint arithmetic would divide by zero).
        NotImplementedError: for an unsupported ``args.lr_scheduler``.
    """
    args = parse_args()

    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    loss_meter = AverageMeter('loss')

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recursive_init(network)

    if args.pretrained_backbone:
        # Download the backbone checkpoint only when one was requested, so an
        # empty/None value never reaches split() or copy_parallel().
        backbone_ckpt_file = args.pretrained_backbone.split('/')[-1]
        local_backbone_ckpt_path = '/cache/' + backbone_ckpt_file
        mox.file.copy_parallel(src_url=args.pretrained_backbone,
                               dst_url=local_backbone_ckpt_path)
        network = load_backbone(network, local_backbone_ckpt_path, args)
        args.logger.info('load pre-trained backbone {} into network'.format(
            args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        wrapper_prefix = 'yolo_network.'
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            # 'moments.' entries are not loaded (presumably optimizer state).
            if key.startswith('moments.'):
                continue
            if key.startswith(wrapper_prefix):
                param_dict_new[key[len(wrapper_prefix):]] = values
            else:
                param_dict_new[key] = values
            args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [convert_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    # data download
    local_data_path = '/cache/data'
    local_ckpt_path = '/cache/ckpt_file'
    print('Download data.')
    mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_path)

    ds, data_size = create_yolo_dataset(
        image_dir=os.path.join(local_data_path, 'images'),
        anno_path=os.path.join(local_data_path, 'annotation.json'),
        is_training=True,
        batch_size=args.per_batch_size,
        max_epoch=args.epoch_size,
        device_num=args.group_size,
        rank=args.rank,
        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)
    if args.steps_per_epoch <= 0:
        # The epoch / checkpoint arithmetic below divides by this value.
        raise ValueError(
            'steps_per_epoch is 0: data_size={} is too small for '
            'per_batch_size={} and group_size={}'.format(
                data_size, args.per_batch_size, args.group_size))

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch * 10

    # lr scheduler: every schedule uses epoch_size, the same epoch count the
    # dataset above was built with.
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(
            args.lr,
            args.lr_epochs,
            args.steps_per_epoch,
            args.warmup_epochs,
            args.epoch_size,
            gamma=args.lr_gamma,
        )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
                                        args.warmup_epochs, args.epoch_size,
                                        args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr, args.steps_per_epoch,
                                           args.warmup_epochs, args.epoch_size,
                                           args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr, args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.epoch_size, args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()

    # checkpoint save: the callback is driven manually from the loop below
    ckpt_max_num = 10
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                   keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=local_ckpt_path,
                              prefix='yolov3')
    cb_params = _InternalCallbackParam()
    cb_params.train_network = network
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    # The target-building function depends only on group_size, which does
    # not change during training, so choose it once.
    preprocess_true_box = (batch_preprocess_true_box if args.group_size == 1
                           else batch_preprocess_true_box_single)

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        # Presumably NCHW, so [2:4] is (H, W) - TODO confirm.
        input_shape = images.shape[2:4]
        shape_record.set(input_shape)

        images = Tensor(images)
        annos = data["annotation"]
        targets = preprocess_true_box(annos, config, input_shape)
        (batch_y_true_0, batch_y_true_1, batch_y_true_2,
         batch_gt_box0, batch_gt_box1, batch_gt_box2) = [Tensor(t) for t in targets]

        # The network receives the spatial shape in reversed order.
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        # ckpt progress
        cb_params.cur_step_num = i + 1  # current step number
        cb_params.batch_num = i + 2
        ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0:
            cb_params.cur_epoch_num += 1

    args.logger.info('==========end training===============')

    # upload checkpoint files
    print('Upload checkpoint.')
    mox.file.copy_parallel(src_url=local_ckpt_path, dst_url=args.train_url)