Code example #1
def test_save_checkpoint():
    """Test save checkpoint."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 0
    cb_params.batch_num = 32
    ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config)
    run_context = RunContext(cb_params)
    ckpoint_cb.begin(run_context)
    ckpoint_cb.step_end(run_context)
    if os.path.exists('./test_files/test_ckpt-model.pkl'):
        os.chmod('./test_files/test_ckpt-model.pkl', stat.S_IWRITE)
        os.remove('./test_files/test_ckpt-model.pkl')
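The tests above drive ModelCheckpoint by hand instead of going through Model.train(): build a CheckpointConfig, wrap it in a ModelCheckpoint callback, fill an _InternalCallbackParam with the counters the callback reads, and call begin()/step_end() yourself. A minimal sketch of that shared pattern, assuming the same MindSpore imports as example #1 (`net` is a placeholder for any compiled network cell):

# Minimal sketch of the manual-callback pattern shared by the examples on this page.
# Assumes the imports used in example #1, e.g.:
#   from mindspore.train.callback import (CheckpointConfig, ModelCheckpoint,
#                                         RunContext, _InternalCallbackParam)
config = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1)
ckpt_cb = ModelCheckpoint(prefix="demo", directory="./demo_ckpt", config=config)
cb_params = _InternalCallbackParam()
cb_params.train_network = net   # placeholder: any network cell to checkpoint
cb_params.epoch_num = 1
cb_params.cur_epoch_num = 1
cb_params.cur_step_num = 1
cb_params.batch_num = 1
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)      # one-time setup before the first step
ckpt_cb.step_end(run_context)   # saves once the configured step interval is reached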
Code example #2
def test_checkpoint_save_ckpt_seconds():
    """Test checkpoint save ckpt seconds."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=100,
        keep_checkpoint_max=0,
        keep_checkpoint_per_n_minutes=1)
    ckpt_cb = ModelCheckpoint(config=train_config)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 4
    cb_params.cur_step_num = 128
    cb_params.batch_num = 32
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    ckpt_cb.step_end(run_context)
    ckpt_cb2 = ModelCheckpoint(config=train_config)
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 16
    ckpt_cb2.begin(run_context)
    ckpt_cb2.step_end(run_context)
Code example #3
def test_step_end_save_graph():
    """Test save checkpoint."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    cb_params = _InternalCallbackParam()
    net = LossNet()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    input_label = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
    net(input_data, input_label)
    cb_params.train_network = net
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 0
    cb_params.batch_num = 32
    ckpoint_cb = ModelCheckpoint(prefix="test", directory='./test_files', config=train_config)
    run_context = RunContext(cb_params)
    ckpoint_cb.begin(run_context)
    ckpoint_cb.step_end(run_context)
    assert os.path.exists('./test_files/test-graph.meta')
    if os.path.exists('./test_files/test-graph.meta'):
        os.chmod('./test_files/test-graph.meta', stat.S_IWRITE)
        os.remove('./test_files/test-graph.meta')
    ckpoint_cb.step_end(run_context)
    assert not os.path.exists('./test_files/test-graph.meta')
Code example #4
def test_checkpoint_save_ckpt_with_encryption():
    """Test checkpoint save ckpt with encryption."""
    train_config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0,
        enc_key=os.urandom(16),
        enc_mode="AES-GCM")
    ckpt_cb = ModelCheckpoint(config=train_config)
    cb_params = _InternalCallbackParam()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    network_ = WithLossCell(net, loss)
    _train_network = TrainOneStepCell(network_, optim)
    cb_params.train_network = _train_network
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 160
    cb_params.batch_num = 32
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    ckpt_cb.step_end(run_context)
    ckpt_cb2 = ModelCheckpoint(config=train_config)
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 15

    if platform.system().lower() == "windows":
        with pytest.raises(NotImplementedError):
            ckpt_cb2.begin(run_context)
            ckpt_cb2.step_end(run_context)
    else:
        ckpt_cb2.begin(run_context)
        ckpt_cb2.step_end(run_context)
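Example #4 only exercises the save side of encryption. For completeness, a hedged sketch of reading such a checkpoint back; it assumes this MindSpore version's load_checkpoint accepts matching dec_key/dec_mode arguments, that the key is the same bytes passed as enc_key at save time, and that the file name below is made up:

# Hedged sketch: decrypt-on-load counterpart to the AES-GCM save above.
# Assumptions: load_checkpoint(..., dec_key=..., dec_mode=...) exists in this
# MindSpore version; "demo-1_16.ckpt" is a hypothetical file name.
key = os.urandom(16)   # in practice, persist the save-time key and reuse it here
param_dict = load_checkpoint("demo-1_16.ckpt", dec_key=key, dec_mode="AES-GCM")
load_param_into_net(net, param_dict)   # net: the network the weights belong to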
Code example #5
        batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
        batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
        batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
        batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
Code example #6
def train():
    """Train function."""
    args = parse_args()
    devid = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=True, device_id=devid)
    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)
    load_yolov3_params(args, network)

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    lr = get_lr(args)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    is_gpu = context.get_context("device_target") == "GPU"
    if is_gpu:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
        network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
                                          level="O2", keep_batchnorm_fp32=True)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt)
        network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator(output_numpy=True)

    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))

        images = Tensor.from_numpy(images)

        batch_y_true_0 = Tensor.from_numpy(data['bbox1'])
        batch_y_true_1 = Tensor.from_numpy(data['bbox2'])
        batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
        batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
        batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
        batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
Code example #7
def train():
    """Train function."""
    args = parse_args()
    devid = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.device_target,
                        save_graphs=False,
                        device_id=devid)
    loss_meter = AverageMeter('loss')

    network = YOLOV4CspDarkNet53(is_training=True)
    # default is kaiming-normal
    default_recursive_init(network)

    if args.pretrained_backbone:
        pretrained_backbone_slice = args.pretrained_backbone.split('/')
        backbone_ckpt_file = pretrained_backbone_slice[
            len(pretrained_backbone_slice) - 1]
        local_backbone_ckpt_path = '/cache/' + backbone_ckpt_file
        # download backbone checkpoint
        mox.file.copy_parallel(src_url=args.pretrained_backbone,
                               dst_url=local_backbone_ckpt_path)
        args.pretrained_backbone = local_backbone_ckpt_path
    load_yolov4_params(args, network)

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV4CspDarkNet53()

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [convert_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    # data download
    local_data_path = '/cache/data'
    local_ckpt_path = '/cache/ckpt_file'
    print('Download data.')
    mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_path)

    ds, data_size = create_yolo_dataset(
        image_dir=os.path.join(local_data_path, 'images'),
        anno_path=os.path.join(local_data_path, 'annotation.json'),
        is_training=True,
        batch_size=args.per_batch_size,
        max_epoch=args.max_epoch,
        device_num=args.group_size,
        rank=args.rank,
        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch * 10

    lr = get_lr(args)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    is_gpu = context.get_context("device_target") == "GPU"
    if is_gpu:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value,
                                           drop_overflow_update=False)
        network = amp.build_train_network(network,
                                          optimizer=opt,
                                          loss_scale_manager=loss_scale,
                                          level="O2",
                                          keep_batchnorm_fp32=False)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt)
        network.set_train()

    # checkpoint save
    ckpt_max_num = 10
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                   keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=local_ckpt_path,
                              prefix='yolov4')
    cb_params = _InternalCallbackParam()
    cb_params.train_network = network
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)

    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        images = Tensor.from_numpy(images)

        batch_y_true_0 = Tensor.from_numpy(data['bbox1'])
        batch_y_true_1 = Tensor.from_numpy(data['bbox2'])
        batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
        batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
        batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
        batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        # ckpt progress
        cb_params.cur_step_num = i + 1  # current step number
        cb_params.batch_num = i + 2
        ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0:
            cb_params.cur_epoch_num += 1

    args.logger.info('==========end training===============')

    # upload checkpoint files
    print('Upload checkpoint.')
    mox.file.copy_parallel(src_url=local_ckpt_path, dst_url=args.train_url)
Code example #8
File: train.py  Project: peixinhou/mindspore
def train(args):
    '''train'''
    print('=============yolov3 start training==================')
    devid = int(os.getenv('DEVICE_ID',
                          '0')) if args.run_platform != 'CPU' else 0
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.run_platform,
                        save_graphs=False,
                        device_id=devid)
    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            device_num=args.world_size,
            gradients_mean=True)
    args.logger = get_logger(args.outputs_dir, args.local_rank)

    # dataloader
    ds = create_dataset(args)

    args.logger.important_info('start create network')
    create_network_start = time.time()

    train_net = define_network(args)

    # checkpoint
    ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
    train_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                    keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=train_config,
                              directory=args.outputs_dir,
                              prefix='{}'.format(args.local_rank))
    cb_params = _InternalCallbackParam()
    cb_params.train_network = train_net
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1
    i = 0
    if args.use_loss_scale:
        scale_manager = DynamicLossScaleManager(init_loss_scale=2**10,
                                                scale_factor=2,
                                                scale_window=2000)
    for data in ds.create_tuple_iterator(output_numpy=True):
        batch_images = data[0]
        batch_labels = data[1]
        input_list = [Tensor(batch_images, mstype.float32)]
        for idx in range(2, 26):
            input_list.append(Tensor(data[idx], mstype.float32))
        if args.use_loss_scale:
            scaling_sens = Tensor(scale_manager.get_loss_scale(),
                                  dtype=mstype.float32)
            loss0, overflow, _ = train_net(*input_list, scaling_sens)
            overflow = np.all(overflow.asnumpy())
            if overflow:
                scale_manager.update_loss_scale(overflow)
            else:
                scale_manager.update_loss_scale(False)
            args.logger.info(
                'rank[{}], iter[{}], loss[{}], overflow:{}, loss_scale:{}, lr:{}, batch_images:{}, '
                'batch_labels:{}'.format(args.local_rank, i, loss0, overflow,
                                         scaling_sens, args.lr[i],
                                         batch_images.shape,
                                         batch_labels.shape))
        else:
            loss0 = train_net(*input_list)
            args.logger.info(
                'rank[{}], iter[{}], loss[{}], lr:{}, batch_images:{}, '
                'batch_labels:{}'.format(args.local_rank, i, loss0, args.lr[i],
                                         batch_images.shape,
                                         batch_labels.shape))
        # save ckpt
        cb_params.cur_step_num = i + 1  # current step number
        cb_params.batch_num = i + 2
        if args.local_rank == 0:
            ckpt_cb.step_end(run_context)

        # save Log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info(
                'Yolov3, graph compile time={:.2f}s'.format(
                    time_for_graph_compile))

        if i % args.steps_per_epoch == 0:
            cb_params.cur_epoch_num += 1

        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.batch_size * (
                i - old_progress) * args.world_size / time_used
            args.logger.info(
                'epoch[{}], iter[{}], loss:[{}], {:.2f} imgs/sec'.format(
                    epoch, i, loss0, fps))
            t_end = time.time()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info(
                '=================================================')
            args.logger.info(
                'epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(
                    epoch, i, fps))
            args.logger.info(
                '=================================================')
            t_epoch = time.time()

        i = i + 1

    args.logger.info('=============yolov3 training finished==================')
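In example #8 the loss-scaling branch feeds the overflow flag returned by the train cell back into DynamicLossScaleManager. Because update_loss_scale() is called on both paths of the if/else, the update collapses to a single call; a condensed sketch, assuming a TrainOneStepWithLossScaleCell-style network that returns (loss, overflow, sens) and a placeholder `inputs` tuple:

# Condensed sketch of the manual dynamic loss-scale update from example #8.
# Assumption: train_net returns (loss, overflow_flag, _) as a
# TrainOneStepWithLossScaleCell does; `inputs` is a placeholder argument tuple.
scale_manager = DynamicLossScaleManager(init_loss_scale=2**10, scale_factor=2, scale_window=2000)
scaling_sens = Tensor(scale_manager.get_loss_scale(), dtype=mstype.float32)
loss, overflow, _ = train_net(*inputs, scaling_sens)
scale_manager.update_loss_scale(bool(np.all(overflow.asnumpy())))  # shrink on overflow, else count toward growth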
Code example #9
def train():
    """Train function."""
    args = parse_args()

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # select whether only the master rank saves ckpt or all ranks save; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir,
                            is_detail=True,
                            is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      mirror_mean=True,
                                      device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            args.logger.info('ckpt param name = {}'.format(key))
            if key.startswith('moments.') or key.startswith('global_') or \
               key.startswith('learning_rate') or key.startswith('momentum'):
                continue
            elif key.startswith('yolo_network.'):
                key_new = key[13:]

                if key_new.endswith('1.beta'):
                    key_new = key_new.replace('1.beta', 'batchnorm.beta')

                if key_new.endswith('1.gamma'):
                    key_new = key_new.replace('1.gamma', 'batchnorm.gamma')

                if key_new.endswith('1.moving_mean'):
                    key_new = key_new.replace('1.moving_mean',
                                              'batchnorm.moving_mean')

                if key_new.endswith('1.moving_variance'):
                    key_new = key_new.replace('1.moving_variance',
                                              'batchnorm.moving_variance')

                if key_new.endswith('.weight'):
                    if key_new.endswith('0.weight'):
                        key_new = key_new.replace('0.weight', 'conv.weight')
                    else:
                        key_new = key_new.replace('.weight', '.conv.weight')

                if key_new.endswith('.bias'):
                    key_new = key_new.replace('.bias', '.conv.bias')
                param_dict_new[key_new] = values

                args.logger.info('in resume {}'.format(key_new))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        for _, param in network.parameters_and_names():
            args.logger.info('network param name = {}'.format(param.name))
            if param.name not in param_dict_new:
                args.logger.info('not match param name = {}'.format(
                    param.name))
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    config = ConfigYOLOV3DarkNet53()
    # convert fusion network to quantization aware network
    if config.quantization_aware:
        network = quant.convert_quant_network(network,
                                              bn_fold=True,
                                              per_channel=[True, False],
                                              symmetric=[True, False])

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]

    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root,
                                        anno_path=args.annFile,
                                        is_training=True,
                                        batch_size=args.per_batch_size,
                                        max_epoch=args.max_epoch,
                                        device_num=args.group_size,
                                        rank=args.rank,
                                        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(
            args.lr,
            args.lr_epochs,
            args.steps_per_epoch,
            args.warmup_epochs,
            args.max_epoch,
            gamma=args.lr_gamma,
        )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
                                        args.warmup_epochs, args.max_epoch,
                                        args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr, args.steps_per_epoch,
                                           args.warmup_epochs, args.max_epoch,
                                           args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr, args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch, args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)

        images = Tensor(images)
        annos = data["annotation"]
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)

        batch_y_true_0 = Tensor(batch_y_true_0)
        batch_y_true_1 = Tensor(batch_y_true_1)
        batch_y_true_2 = Tensor(batch_y_true_2)
        batch_gt_box0 = Tensor(batch_gt_box0)
        batch_gt_box1 = Tensor(batch_gt_box1)
        batch_gt_box2 = Tensor(batch_gt_box2)

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
Code example #10
def train():
    '''train function.'''
    # logger
    args = parse_args()

    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()

    args.per_batch_size = config.per_batch_size
    args.dst_h = config.dst_h
    args.dst_w = config.dst_w
    args.workers = config.workers
    args.attri_num = config.attri_num
    args.classes = config.classes
    args.backbone = config.backbone
    args.loss_scale = config.loss_scale
    args.flat_dim = config.flat_dim
    args.fc_dim = config.fc_dim
    args.lr = config.lr
    args.lr_scale = config.lr_scale
    args.lr_epochs = config.lr_epochs
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.warmup_epochs = config.warmup_epochs
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path

    if args.world_size == 1:
        args.per_batch_size = 256
    else:
        args.lr = args.lr * 4.

    if args.world_size != 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)

    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)

    # backbone and loss
    args.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)

    criterion = get_loss()

    # load pretrained model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)

    # mixed precision training
    criterion.add_flags_recursive(fp32=True)

    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()

    # checkpoint
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    i = 0
    for _, (data, gt_classes) in enumerate(de_dataloader):

        data_tensor = Tensor(data, dtype=mstype.float32)
        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)

        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])

        # save ckpt
        if args.local_rank == 0:
            cb_params.cur_step_num = i + 1
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            cb_params.cur_epoch_num += 1

        # save Log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))

        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))

            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()

        i += 1

    args.logger.info('--------- training finished ---------')
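Example #10, like examples #11 and #12 below, filters a pretrained checkpoint before load_param_into_net: optimizer state ('moments.') is dropped and the 'network.' training-wrapper prefix is stripped. A small helper capturing that shared step (the helper name is hypothetical; the logic mirrors the inline loops):

# Hypothetical helper distilling the checkpoint cleanup repeated in examples #10-#12.
def strip_train_wrapper_keys(param_dict):
    """Drop optimizer accumulators and unwrap 'network.'-prefixed weight names."""
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith('moments.'):
            continue                                   # optimizer state, not model weights
        if key.startswith('network.'):
            param_dict_new[key[len('network.'):]] = values
        else:
            param_dict_new[key] = values
    return param_dict_new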
Code example #11
def train(args):
    '''train'''
    print('=============yolov3 start training==================')


    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()

    args.batch_size = config.batch_size
    args.warmup_lr = config.warmup_lr
    args.lr_rates = config.lr_rates
    args.lr_steps = config.lr_steps
    args.gamma = config.gamma
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path
    args.ckpt_interval = config.ckpt_interval

    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    print('args.outputs_dir', args.outputs_dir)

    args.logger = get_logger(args.outputs_dir, args.local_rank)

    if args.world_size != 8:
        args.lr_steps = [i * 8 // args.world_size for i in args.lr_steps]

    if args.world_size == 1:
        args.weight_decay = 0.

    if args.world_size != 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.world_size, gradients_mean=True)
    mindrecord_path = args.mindrecord_path

    num_classes = config.num_classes
    anchors = config.anchors
    anchors_mask = config.anchors_mask
    num_anchors_list = [len(x) for x in anchors_mask]

    momentum = args.momentum
    args.logger.info('train opt momentum:{}'.format(momentum))

    weight_decay = args.weight_decay * float(args.batch_size)
    args.logger.info('real weight_decay:{}'.format(weight_decay))
    lr_scale = args.world_size / 8
    args.logger.info('lr_scale:{}'.format(lr_scale))

    # dataloader
    args.logger.info('start create dataloader')
    epoch = args.max_epoch
    ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation"], num_shards=args.world_size,
                        shard_id=args.local_rank)

    ds = ds.map(input_columns=["image", "annotation"],
                output_columns=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0',
                                'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1',
                                'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1',
                                't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2',
                                'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'],
                column_order=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0',
                              'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1',
                              'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1',
                              't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2',
                              'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'],
                operations=compose_map_func, num_parallel_workers=16, python_multiprocessing=True)

    ds = ds.batch(args.batch_size, drop_remainder=True, num_parallel_workers=8)

    args.steps_per_epoch = ds.get_dataset_size()
    lr = warmup_step_new(args, lr_scale=lr_scale)

    ds = ds.repeat(epoch)
    args.logger.info('args.steps_per_epoch:{}'.format(args.steps_per_epoch))
    args.logger.info('args.world_size:{}'.format(args.world_size))
    args.logger.info('args.local_rank:{}'.format(args.local_rank))
    args.logger.info('end create dataloader')
    args.logger.save_args(args)
    args.logger.important_info('start create network')
    create_network_start = time.time()

    # backbone and loss
    network = backbone_HwYolov3(num_classes, num_anchors_list, args)

    criterion0 = YoloLoss(num_classes, anchors, anchors_mask[0], 64, 0, head_idx=0.0)
    criterion1 = YoloLoss(num_classes, anchors, anchors_mask[1], 32, 0, head_idx=1.0)
    criterion2 = YoloLoss(num_classes, anchors, anchors_mask[2], 16, 0, head_idx=2.0)

    # load pretrained model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    train_net = BuildTrainNetworkV2(network, criterion0, criterion1, criterion2, args)

    # optimizer
    opt = Momentum(params=train_net.trainable_params(), learning_rate=Tensor(lr), momentum=momentum,
                   weight_decay=weight_decay)

    # package training process
    train_net = TrainOneStepWithLossScaleCell(train_net, opt)
    train_net.set_broadcast_flag()

    # checkpoint
    ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
    train_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
    cb_params = _InternalCallbackParam()
    cb_params.train_network = train_net
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1
    i = 0
    scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=2000)

    for data in ds.create_tuple_iterator(output_numpy=True):

        batch_images = data[0]
        batch_labels = data[1]
        coord_mask_0 = data[2]
        conf_pos_mask_0 = data[3]
        conf_neg_mask_0 = data[4]
        cls_mask_0 = data[5]
        t_coord_0 = data[6]
        t_conf_0 = data[7]
        t_cls_0 = data[8]
        gt_list_0 = data[9]
        coord_mask_1 = data[10]
        conf_pos_mask_1 = data[11]
        conf_neg_mask_1 = data[12]
        cls_mask_1 = data[13]
        t_coord_1 = data[14]
        t_conf_1 = data[15]
        t_cls_1 = data[16]
        gt_list_1 = data[17]
        coord_mask_2 = data[18]
        conf_pos_mask_2 = data[19]
        conf_neg_mask_2 = data[20]
        cls_mask_2 = data[21]
        t_coord_2 = data[22]
        t_conf_2 = data[23]
        t_cls_2 = data[24]
        gt_list_2 = data[25]

        img_tensor = Tensor(batch_images, mstype.float32)
        coord_mask_tensor_0 = Tensor(coord_mask_0.astype(np.float32))
        conf_pos_mask_tensor_0 = Tensor(conf_pos_mask_0.astype(np.float32))
        conf_neg_mask_tensor_0 = Tensor(conf_neg_mask_0.astype(np.float32))
        cls_mask_tensor_0 = Tensor(cls_mask_0.astype(np.float32))
        t_coord_tensor_0 = Tensor(t_coord_0.astype(np.float32))
        t_conf_tensor_0 = Tensor(t_conf_0.astype(np.float32))
        t_cls_tensor_0 = Tensor(t_cls_0.astype(np.float32))
        gt_list_tensor_0 = Tensor(gt_list_0.astype(np.float32))

        coord_mask_tensor_1 = Tensor(coord_mask_1.astype(np.float32))
        conf_pos_mask_tensor_1 = Tensor(conf_pos_mask_1.astype(np.float32))
        conf_neg_mask_tensor_1 = Tensor(conf_neg_mask_1.astype(np.float32))
        cls_mask_tensor_1 = Tensor(cls_mask_1.astype(np.float32))
        t_coord_tensor_1 = Tensor(t_coord_1.astype(np.float32))
        t_conf_tensor_1 = Tensor(t_conf_1.astype(np.float32))
        t_cls_tensor_1 = Tensor(t_cls_1.astype(np.float32))
        gt_list_tensor_1 = Tensor(gt_list_1.astype(np.float32))

        coord_mask_tensor_2 = Tensor(coord_mask_2.astype(np.float32))
        conf_pos_mask_tensor_2 = Tensor(conf_pos_mask_2.astype(np.float32))
        conf_neg_mask_tensor_2 = Tensor(conf_neg_mask_2.astype(np.float32))
        cls_mask_tensor_2 = Tensor(cls_mask_2.astype(np.float32))
        t_coord_tensor_2 = Tensor(t_coord_2.astype(np.float32))
        t_conf_tensor_2 = Tensor(t_conf_2.astype(np.float32))
        t_cls_tensor_2 = Tensor(t_cls_2.astype(np.float32))
        gt_list_tensor_2 = Tensor(gt_list_2.astype(np.float32))

        scaling_sens = Tensor(scale_manager.get_loss_scale(), dtype=mstype.float32)

        loss0, overflow, _ = train_net(img_tensor, coord_mask_tensor_0, conf_pos_mask_tensor_0,
                                       conf_neg_mask_tensor_0, cls_mask_tensor_0, t_coord_tensor_0,
                                       t_conf_tensor_0, t_cls_tensor_0, gt_list_tensor_0,
                                       coord_mask_tensor_1, conf_pos_mask_tensor_1, conf_neg_mask_tensor_1,
                                       cls_mask_tensor_1, t_coord_tensor_1, t_conf_tensor_1,
                                       t_cls_tensor_1, gt_list_tensor_1, coord_mask_tensor_2,
                                       conf_pos_mask_tensor_2, conf_neg_mask_tensor_2,
                                       cls_mask_tensor_2, t_coord_tensor_2, t_conf_tensor_2,
                                       t_cls_tensor_2, gt_list_tensor_2, scaling_sens)

        overflow = np.all(overflow.asnumpy())
        if overflow:
            scale_manager.update_loss_scale(overflow)
        else:
            scale_manager.update_loss_scale(False)
        args.logger.info('rank[{}], iter[{}], loss[{}], overflow:{}, loss_scale:{}, lr:{}, batch_images:{}, '
                         'batch_labels:{}'.format(args.local_rank, i, loss0, overflow, scaling_sens, lr[i],
                                                  batch_images.shape, batch_labels.shape))

        # save ckpt
        cb_params.cur_step_num = i + 1  # current step number
        cb_params.batch_num = i + 2
        if args.local_rank == 0:
            ckpt_cb.step_end(run_context)

        # save Log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('Yolov3, graph compile time={:.2f}s'.format(time_for_graph_compile))

        if i % args.steps_per_epoch == 0:
            cb_params.cur_epoch_num += 1

        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], loss:[{}], {:.2f} imgs/sec'.format(epoch, i, loss0, fps))
            t_end = time.time()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()

        i = i + 1

    args.logger.info('=============yolov3 training finished==================')
Code example #12
File: train.py  Project: xiaoxiugege/mindspore
def main(args):

    if args.is_distributed == 0:
        cfg = faceqa_1p_cfg
    else:
        cfg = faceqa_8p_cfg

    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=cfg.world_size,
                                      gradients_mean=True)

    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(
        cfg.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst,
                                local_rank=cfg.local_rank,
                                world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = FaceQABackbone()
    criterion = CriterionsFaceQA()

    # load pretrained model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model {} success'.format(cfg.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay,
                   loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = BuildTrainNetwork(network, criterion)
    train_net = TrainOneStepCell(
        train_net,
        opt,
        sens=cfg.loss_scale,
    )

    # checkpoint save
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(
            save_checkpoint_steps=cfg.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config,
                                  directory=cfg.outputs_dir,
                                  prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss
        data = data.astype(np.float32)
        gt = gt.astype(np.float32)
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(
                cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (
                i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(
                epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info(
                '=================================================')
            cfg.logger.info(
                'epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(
                    epoch, i, fps))
            cfg.logger.info(
                '=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')
Code example #13
def train():
    """Train function."""
    args = parse_args()
    args.logger.save_args(args)

    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir,
                            is_detail=True,
                            is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      gradients_mean=True,
                                      device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)
    load_yolov3_quant_params(args, network)

    config = ConfigYOLOV3DarkNet53()
    # convert fusion network to quantization aware network
    if config.quantization_aware:
        network = quant.convert_quant_network(network,
                                              bn_fold=True,
                                              per_channel=[True, False],
                                              symmetric=[True, False])

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]

    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root,
                                        anno_path=args.annFile,
                                        is_training=True,
                                        batch_size=args.per_batch_size,
                                        max_epoch=args.max_epoch,
                                        device_num=args.group_size,
                                        rank=args.rank,
                                        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    lr = get_lr(args)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir,
                                      'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
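    # output_numpy=True keeps each batch as numpy arrays, so Tensor.from_numpy
    # below can wrap them without an extra copy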
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)

        images = Tensor.from_numpy(images)
        annos = data["annotation"]
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)

        batch_y_true_0 = Tensor.from_numpy(batch_y_true_0)
        batch_y_true_1 = Tensor.from_numpy(batch_y_true_1)
        batch_y_true_2 = Tensor.from_numpy(batch_y_true_2)
        batch_gt_box0 = Tensor.from_numpy(batch_gt_box0)
        batch_gt_box1 = Tensor.from_numpy(batch_gt_box1)
        batch_gt_box2 = Tensor.from_numpy(batch_gt_box2)

        # images.shape[2:4] of an NCHW batch is (H, W); the network takes (W, H)
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
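
The lr schedulers selected above ('exponential', 'cosine_annealing', ...) each return one learning rate per global step, which is why the log line can index lr[i]. A plausible numpy sketch of warmup_cosine_annealing_lr, assuming linear warmup followed by the standard cosine decay from lr down to eta_min (the project's real implementation may differ in details):

import math
import numpy as np

def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch,
                               T_max, eta_min=0.0):
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    lr_each_step = []
    for step in range(total_steps):
        if warmup_steps and step < warmup_steps:
            # linear warmup from 0 up to the base lr
            cur_lr = lr * (step + 1) / warmup_steps
        else:
            # cosine annealing at epoch granularity over T_max epochs
            epoch = step // steps_per_epoch
            cur_lr = eta_min + (lr - eta_min) * \
                (1.0 + math.cos(math.pi * epoch / T_max)) / 2.0
        lr_each_step.append(cur_lr)
    return np.array(lr_each_step).astype(np.float32)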
Code example #14
def train():
    """Train function."""
    args = parse_args()

    devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=True, device_id=devid)

    # init distributed
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
        else:
            init("nccl")
        args.rank = get_rank()
        args.group_size = get_group_size()

    # select whether only the master rank saves ckpt or every rank saves; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree)

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Did not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
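                # len('yolo_network.') == 13, so key[13:] drops the wrapper prefix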
                param_dict_new[key[13:]] = values
                args.logger.info('in resume {}'.format(key))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    # enable automatic mixed precision only when running on GPU
    enable_amp = context.get_context("device_target") == "GPU"
    if enable_amp:
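        # O2 casts the network to fp16 while keep_batchnorm_fp32=True keeps BN
        # in fp32; a fixed loss scale of 1.0 effectively disables loss scaling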
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
        network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
                                          level="O2", keep_batchnorm_fp32=True)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt)
        network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))

        images = Tensor(images)

        batch_y_true_0 = Tensor(data['bbox1'])
        batch_y_true_1 = Tensor(data['bbox2'])
        batch_y_true_2 = Tensor(data['bbox3'])
        batch_gt_box0 = Tensor(data['gt_box1'])
        batch_gt_box1 = Tensor(data['gt_box2'])
        batch_gt_box2 = Tensor(data['gt_box3'])

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')
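
Every example drives ModelCheckpoint by hand instead of going through Model.train: build an _InternalCallbackParam, wrap it in a RunContext, call begin() once, then bump cur_step_num and call step_end() after each step. A minimal, self-contained sketch of that pattern, assuming a MindSpore 1.x import path for _InternalCallbackParam; the tiny Dense network, shapes, and hyperparameters are placeholders, not values from the examples:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.callback import (CheckpointConfig, ModelCheckpoint,
                                      RunContext, _InternalCallbackParam)

loss_net = nn.WithLossCell(nn.Dense(4, 2),
                           nn.SoftmaxCrossEntropyWithLogits(sparse=False,
                                                            reduction='mean'))
opt = nn.Momentum(loss_net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_net = nn.TrainOneStepCell(loss_net, opt)
train_net.set_train()

ckpt_config = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=3)
ckpt_cb = ModelCheckpoint(prefix='sketch', directory='./ckpt', config=ckpt_config)

cb_params = _InternalCallbackParam()
cb_params.train_network = train_net
cb_params.epoch_num = 1
cb_params.cur_epoch_num = 1
cb_params.batch_num = 100
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)

for step in range(100):
    data = Tensor(np.random.rand(8, 4).astype(np.float32))
    label = Tensor(np.eye(2)[np.random.randint(0, 2, 8)].astype(np.float32))
    train_net(data, label)
    cb_params.cur_step_num = step + 1      # ModelCheckpoint compares this
    ckpt_cb.step_end(run_context)          # against save_checkpoint_steps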
Code example #15
File: train.py  Project: peixinhou/mindspore
def main():
    cfg, args = init_argument()
    loss_meter = AverageMeter('loss')
    # dataloader
    cfg.logger.info('start create dataloader')
    de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
    cfg.steps_per_epoch = steps_per_epoch
    cfg.logger.info('step per epoch: %s', cfg.steps_per_epoch)
    de_dataloader = de_dataset.create_tuple_iterator()
    cfg.logger.info('class num original: %s', class_num)
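    # pad class_num up to a multiple of 16, presumably to align fp16 matmul
    # shapes on the device (the padding rationale is an assumption)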
    if class_num % 16 != 0:
        class_num = (class_num // 16 + 1) * 16
    cfg.class_num = class_num
    cfg.logger.info('change the class num to: %s', cfg.class_num)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = SphereNet(num_layers=cfg.net_depth,
                        feature_dim=cfg.embedding_size,
                        shape=cfg.input_size)
    if args.device_target == 'CPU':
        head = CombineMarginFC(embbeding_size=cfg.embedding_size,
                               classnum=cfg.class_num)
    else:
        head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size,
                                   classnum=cfg.class_num)
    criterion = CrossEntropy()

    # load the pretrained model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model %s success', cfg.pretrained)

    # mixed precision training
    if args.device_target == 'CPU':
        network.add_flags_recursive(fp32=True)
        head.add_flags_recursive(fp32=True)
    else:
        network.add_flags_recursive(fp16=True)
        head.add_flags_recursive(fp16=True)
    criterion.add_flags_recursive(fp32=True)

    train_net = BuildTrainNetworkWithHead(network, head, criterion)

    # optimizer and lr scheduler
    lr = step_lr(lr=cfg.lr,
                 epoch_size=cfg.epoch_size,
                 steps_per_epoch=cfg.steps_per_epoch,
                 max_epoch=cfg.max_epoch,
                 gamma=cfg.lr_gamma)
    opt = SGD(params=train_net.trainable_params(),
              learning_rate=lr,
              momentum=cfg.momentum,
              weight_decay=cfg.weight_decay,
              loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(
            save_checkpoint_steps=cfg.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config,
                                  directory=cfg.outputs_dir,
                                  prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, total_data in enumerate(de_dataloader):
        data, gt = total_data
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(
                cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (
                i - old_progress) * cfg.world_size / time_used
            cfg.logger.info(
                'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr={}'.format(
                    epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info(
                '=================================================')
            cfg.logger.info(
                'epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(
                    epoch, i, fps))
            cfg.logger.info(
                '=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')
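
Example #15's schedule is the simpler step decay. A plausible sketch of step_lr with the signature used above, assuming the rate is multiplied by gamma every epoch_size epochs (the decay-boundary rule is an assumption):

import numpy as np

def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma):
    """One lr value per global step, decayed by `gamma` every `epoch_size` epochs."""
    lr_each_step = []
    for step in range(int(max_epoch * steps_per_epoch)):
        epoch = step // steps_per_epoch
        lr_each_step.append(lr * gamma ** (epoch // epoch_size))
    return np.array(lr_each_step).astype(np.float32)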
Code example #16
def train():
    """Train function."""
    args = parse_args()

    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    loss_meter = AverageMeter('loss')

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recursive_init(network)

    if args.pretrained_backbone:
        # download the backbone checkpoint from remote storage before loading
        # it; guard the copy so an empty path is never passed to moxing
        backbone_ckpt_file = args.pretrained_backbone.split('/')[-1]
        local_backbone_ckpt_path = '/cache/' + backbone_ckpt_file
        mox.file.copy_parallel(src_url=args.pretrained_backbone,
                               dst_url=local_backbone_ckpt_path)
        network = load_backbone(network, local_backbone_ckpt_path, args)
        args.logger.info('load pre-trained backbone {} into network'.format(
            args.pretrained_backbone))
    else:
        args.logger.info('Did not load pre-trained backbone, please be careful')

    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
                param_dict_new[key[13:]] = values
                args.logger.info('in resume {}'.format(key))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))

        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [convert_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    # data download
    local_data_path = '/cache/data'
    local_ckpt_path = '/cache/ckpt_file'
    print('Download data.')
    mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_path)

    ds, data_size = create_yolo_dataset(
        image_dir=os.path.join(local_data_path, 'images'),
        anno_path=os.path.join(local_data_path, 'annotation.json'),
        is_training=True,
        batch_size=args.per_batch_size,
        max_epoch=args.epoch_size,
        device_num=args.group_size,
        rank=args.rank,
        config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size /
                               args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch * 10

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(
            args.lr,
            args.lr_epochs,
            args.steps_per_epoch,
            args.warmup_epochs,
            args.epoch_size,
            gamma=args.lr_gamma,
        )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
                                        args.warmup_epochs, args.max_epoch,
                                        args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr, args.steps_per_epoch,
                                           args.warmup_epochs, args.max_epoch,
                                           args.T_max, args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr, args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch, args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()

    # checkpoint save
    ckpt_max_num = 10
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                   keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=local_ckpt_path,
                              prefix='yolov3')
    cb_params = _InternalCallbackParam()
    cb_params.train_network = network
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        shape_record.set(input_shape)

        images = Tensor(images)
        annos = data["annotation"]
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)

        batch_y_true_0 = Tensor(batch_y_true_0)
        batch_y_true_1 = Tensor(batch_y_true_1)
        batch_y_true_2 = Tensor(batch_y_true_2)
        batch_gt_box0 = Tensor(batch_gt_box0)
        batch_gt_box1 = Tensor(batch_gt_box1)
        batch_gt_box2 = Tensor(batch_gt_box2)

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2,
                       batch_gt_box0, batch_gt_box1, batch_gt_box2,
                       input_shape)
        loss_meter.update(loss.asnumpy())

        # ckpt progress
        cb_params.cur_step_num = i + 1  # current step number
        cb_params.batch_num = i + 2
        ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (
                i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(
                        epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0:
            cb_params.cur_epoch_num += 1

    args.logger.info('==========end training===============')

    # upload checkpoint files
    print('Upload checkpoint.')
    mox.file.copy_parallel(src_url=local_ckpt_path, dst_url=args.train_url)
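
All four loops compute throughput the same way: images processed since the last log line divided by wall-clock time. A worked example with hypothetical numbers (per_batch_size=32, 100 iterations since the last log, group_size=8, 20 seconds elapsed):

per_batch_size = 32
iters_since_last_log = 100          # i - old_progress
group_size = 8                      # world size / device count
time_used = 20.0                    # seconds since the last log
fps = per_batch_size * iters_since_last_log * group_size / time_used
print(fps)                          # 1280.0 imgs/sec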