Example No. 1
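    # decide which ranks save checkpoints: only rank 0 when is_save_on_master is set, otherwise every rank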
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    if args.dataset == "cifar10":
        dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size)
    else:
        dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size,
                                         args.rank, args.group_size)

    batch_num = dataset.get_dataset_size()
    args.steps_per_epoch = batch_num
    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')

    # get network and init
    network = vgg16(args.num_classes, args)

    # pre_trained
    if args.pre_trained:
        load_param_into_net(network, load_checkpoint(args.pre_trained))
Example No. 2
def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=args.group_size,
                                          gradients_mean=True)
    # dataloader
    de_dataset = classification_dataset(args.data_dir,
                                        args.image_size,
                                        args.per_batch_size,
                                        1,  # max_epoch
                                        args.rank,
                                        args.group_size,
                                        num_parallel_workers=8)
    de_dataset.map_model = 4  # !!!important
    args.steps_per_epoch = de_dataset.get_dataset_size()

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    # get network and init
    network = get_network(args.backbone,
                          num_classes=args.num_classes,
                          platform=args.platform)
    if network is None:
        raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))

    load_pretrain_model(args.pretrained, network, args)

    # lr scheduler
    lr = get_lr(args)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    # loss
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=args.label_smooth_factor,
                        num_classes=args.num_classes)

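    # dynamic loss scaling adapts the scale on overflow; otherwise a fixed scale of args.loss_scale is used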
    if args.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                     scale_factor=2,
                                                     scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(args.loss_scale,
                                                   drop_overflow_update=False)

    model = Model(network,
                  loss_fn=loss,
                  optimizer=opt,
                  loss_scale_manager=loss_scale_manager,
                  metrics={'acc'},
                  amp_level="O3")

    # checkpoint save
    progress_cb = ProgressMonitor(args)
    callbacks = [
        progress_cb,
    ]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
            keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = os.path.join(args.outputs_dir,
                                      'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch,
                de_dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)
Example No. 3
def test(cloud_args=None):
    """test"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=args.group_size,
                                          gradients_mean=True)

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    if os.path.isdir(args.pretrained):
        models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
        print(models)
        if args.graph_ckpt:
            f = lambda x: -1 * int(
                os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split(
                    '_')[0])
        else:
            f = lambda x: -1 * int(
                os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
        args.models = sorted(models, key=f)
    else:
        args.models = [
            args.pretrained,
        ]

    for model in args.models:
        de_dataset = classification_dataset(args.data_dir,
                                            image_size=args.image_size,
                                            per_batch_size=args.per_batch_size,
                                            max_epoch=1,
                                            rank=args.rank,
                                            group_size=args.group_size,
                                            mode='eval')
        eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
        network = get_network(args.backbone,
                              args.num_classes,
                              platform=args.platform)
        if network is None:
            raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))

        load_pretrain_model(model, network, args)

        img_tot = 0
        top1_correct = 0
        top5_correct = 0
        if args.platform == "Ascend":
            network.to_float(mstype.float16)
        else:
            auto_mixed_precision(network)
        network.set_train(False)
        t_end = time.time()
        it = 0
        for data, gt_classes in eval_dataloader:
            output = network(Tensor(data, mstype.float32))
            output = output.asnumpy()

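            # top-1 prediction per sample, plus the indices of the five highest logits for top-5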
            top1_output = np.argmax(output, axis=-1)
            top5_output = np.argsort(output)[:, -5:]

            t1_correct = np.equal(top1_output, gt_classes).sum()
            top1_correct += t1_correct
            top5_correct += get_top5_acc(top5_output, gt_classes)
            img_tot += args.per_batch_size

            if args.rank == 0 and it == 0:
                t_end = time.time()
                it = 1
        if args.rank == 0:
            time_used = time.time() - t_end
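            # the first batch is excluded from the throughput figure (the timer was reset after it)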
            fps = (img_tot - args.per_batch_size) * args.group_size / time_used
            args.logger.info(
                'Inference Performance: {:.2f} img/sec'.format(fps))
        results = get_result(args, model, top1_correct, top5_correct, img_tot)
        top1_correct = results[0, 0]
        top5_correct = results[1, 0]
        img_tot = results[2, 0]
        acc1 = 100.0 * top1_correct / img_tot
        acc5 = 100.0 * top5_correct / img_tot
        args.logger.info('after allreduce eval: top1_correct={}, tot={},'
                         'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot,
                                                    acc1))
        args.logger.info('after allreduce eval: top5_correct={}, tot={},'
                         'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot,
                                                    acc5))
    if args.is_distributed:
        release()
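
The evaluation loop above relies on a helper, get_top5_acc, whose definition is not included in these listings. A minimal sketch of what it is assumed to do (count the samples whose ground-truth label appears among the five candidate indices taken from np.argsort) is:

def get_top5_acc(top5_arg, gt_class):
    """Assumed behaviour: number of samples whose label is among their top-5 predictions."""
    sub_count = 0
    for top5, gt in zip(top5_arg, gt_class):
        if gt in top5:
            sub_count += 1
    return sub_count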
Example No. 4
def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1  # with dynamic loss scaling, the Momentum optimizer must not apply its own loss scale

    # choose whether only the master rank or every rank saves checkpoints; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    # dataloader
    de_dataset = classification_dataset(args.data_dir, args.image_size,
                                        args.per_batch_size, args.max_epoch,
                                        args.rank, args.group_size)
    de_dataset.map_model = 4  # !!!important
    args.steps_per_epoch = de_dataset.get_dataset_size()

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    # get network and init
    network = get_network(args.backbone, args.num_classes)
    if network is None:
        raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))
    network.add_flags_recursive(fp16=True)
    # loss
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
                             num_classes=args.num_classes)

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
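        # drop the optimizer's 'moments.' parameters and strip the 'network.' prefix so keys match the bare backbone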
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('loaded model {} successfully'.format(args.pretrained))

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)


    criterion.add_flags_recursive(fp32=True)

    # wrap the training process: lr adjustment + forward + backward + optimizer step
    train_net = BuildTrainNetwork(network, criterion)
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    if args.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

    # Model api changed since TR5_branch 2020/03/09
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
                                      parameter_broadcast=True, mirror_mean=True)
    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager)

    # checkpoint save
    progress_cb = ProgressMonitor(args)
    callbacks = [progress_cb,]
    if args.rank_save_ckpt_flag:
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
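
The learning-rate helpers used above (warmup_step_lr, warmup_cosine_annealing_lr) are repository utilities whose bodies are not shown in these listings. A rough sketch of the cosine variant, assuming a linear warm-up followed by per-epoch cosine annealing from lr down to eta_min over T_max epochs, returning one value per training step for Tensor(lr):

import numpy as np

def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0.0):
    # sketch only: linear warm-up, then cosine decay per epoch; one lr value per step
    total_steps = max_epoch * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch
    lr_each_step = []
    for step in range(total_steps):
        if warmup_steps and step < warmup_steps:
            cur_lr = lr * (step + 1) / warmup_steps
        else:
            cur_epoch = step // steps_per_epoch
            cur_lr = eta_min + (lr - eta_min) * (1.0 + np.cos(np.pi * cur_epoch / t_max)) / 2.0
        lr_each_step.append(cur_lr)
    return np.array(lr_each_step, dtype=np.float32)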
Example No. 5
def test(cloud_args=None):
    """test"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=args.group_size,
                                          parameter_broadcast=True,
                                          mirror_mean=True)
    else:
        args.rank = 0
        args.group_size = 1

    args.outputs_dir = os.path.join(
        args.log_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    if os.path.isdir(args.pretrained):
        models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
        print(models)
        if args.graph_ckpt:
            f = lambda x: -1 * int(
                os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split(
                    '_')[0])
        else:
            f = lambda x: -1 * int(
                os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
        args.models = sorted(models, key=f)
    else:
        args.models = [
            args.pretrained,
        ]

    for model in args.models:
        de_dataset = classification_dataset(args.data_dir,
                                            image_size=args.image_size,
                                            per_batch_size=args.per_batch_size,
                                            max_epoch=1,
                                            rank=args.rank,
                                            group_size=args.group_size,
                                            mode='eval')
        eval_dataloader = de_dataset.create_tuple_iterator()
        network = get_network(args.backbone,
                              args.num_classes,
                              platform=args.platform)
        if network is None:
            raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))

        param_dict = load_checkpoint(model)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values

        load_param_into_net(network, param_dict_new)
        args.logger.info('loaded model {} successfully'.format(model))

        img_tot = 0
        top1_correct = 0
        top5_correct = 0
        if args.platform == "Ascend":
            network.to_float(mstype.float16)
        else:
            auto_mixed_precision(network)
        network.set_train(False)
        t_end = time.time()
        it = 0
        for data, gt_classes in eval_dataloader:
            output = network(Tensor(data, mstype.float32))
            output = output.asnumpy()

            top1_output = np.argmax(output, axis=-1)
            top5_output = np.argsort(output)[:, -5:]

            t1_correct = np.equal(top1_output, gt_classes).sum()
            top1_correct += t1_correct
            top5_correct += get_top5_acc(top5_output, gt_classes)
            img_tot += args.per_batch_size

            if args.rank == 0 and it == 0:
                t_end = time.time()
                it = 1
        if args.rank == 0:
            time_used = time.time() - t_end
            fps = (img_tot - args.per_batch_size) * args.group_size / time_used
            args.logger.info(
                'Inference Performance: {:.2f} img/sec'.format(fps))
        results = [[top1_correct], [top5_correct], [img_tot]]
        args.logger.info('before results={}'.format(results))
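        # file-based aggregation in place of an allreduce: every rank dumps its counters to /cache and waits for the others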
        if args.is_distributed:
            model_md5 = model.replace('/', '')
            tmp_dir = '/cache'
            if not os.path.exists(tmp_dir):
                os.mkdir(tmp_dir)
            top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(
                args.rank, model_md5)
            top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(
                args.rank, model_md5)
            img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(
                args.rank, model_md5)
            np.save(top1_correct_npy, top1_correct)
            np.save(top5_correct_npy, top5_correct)
            np.save(img_tot_npy, img_tot)
            while True:
                rank_ok = True
                for other_rank in range(args.group_size):
                    top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(
                        other_rank, model_md5)
                    top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(
                        other_rank, model_md5)
                    img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(
                        other_rank, model_md5)
                    if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \
                       not os.path.exists(img_tot_npy):
                        rank_ok = False
                if rank_ok:
                    break
                time.sleep(1)  # avoid busy-polling while waiting for the other ranks' files

            top1_correct_all = 0
            top5_correct_all = 0
            img_tot_all = 0
            for other_rank in range(args.group_size):
                top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(
                    other_rank, model_md5)
                top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(
                    other_rank, model_md5)
                img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(
                    other_rank, model_md5)
                top1_correct_all += np.load(top1_correct_npy)
                top5_correct_all += np.load(top5_correct_npy)
                img_tot_all += np.load(img_tot_npy)
            results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
            results = np.array(results)
        else:
            results = np.array(results)

        args.logger.info('after results={}'.format(results))
        top1_correct = results[0, 0]
        top5_correct = results[1, 0]
        img_tot = results[2, 0]
        acc1 = 100.0 * top1_correct / img_tot
        acc5 = 100.0 * top5_correct / img_tot
        args.logger.info('after allreduce eval: top1_correct={}, tot={},'
                         'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot,
                                                    acc1))
        args.logger.info('after allreduce eval: top5_correct={}, tot={},'
                         'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot,
                                                    acc5))
    if args.is_distributed:
        release()
Example No. 6
def test(cloud_args=None):
    """test"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.device_target,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    args.outputs_dir = os.path.join(
        args.log_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

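    # CIFAR-10 is evaluated through Model.eval; other datasets run the manual top-1/top-5 loop below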
    if args.dataset == "cifar10":
        net = vgg16(num_classes=args.num_classes)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                       0.01,
                       cfg.momentum,
                       weight_decay=args.weight_decay)
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True,
                                                reduction='mean',
                                                is_grad=False)
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

        param_dict = load_checkpoint(args.checkpoint_path)
        load_param_into_net(net, param_dict)
        net.set_train(False)
        dataset = vgg_create_dataset(args.data_path, 1, False)
        res = model.eval(dataset)
        print("result: ", res)
    else:
        # network
        args.logger.important_info('start create network')
        if os.path.isdir(args.pretrained):
            models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
            print(models)
            if args.graph_ckpt:
                f = lambda x: -1 * int(
                    os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].
                    split('_')[0])
            else:
                f = lambda x: -1 * int(
                    os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
            args.models = sorted(models, key=f)
        else:
            args.models = [
                args.pretrained,
            ]

        for model in args.models:
            if args.dataset == "cifar10":
                dataset = vgg_create_dataset(args.data_path,
                                             args.image_size,
                                             args.per_batch_size,
                                             training=False)
            else:
                dataset = classification_dataset(args.data_path,
                                                 args.image_size,
                                                 args.per_batch_size)

            eval_dataloader = dataset.create_tuple_iterator()
            network = vgg16(args.num_classes, args, phase="test")

            # pre_trained
            load_param_into_net(network, load_checkpoint(model))
            network.add_flags_recursive(fp16=True)

            img_tot = 0
            top1_correct = 0
            top5_correct = 0

            network.set_train(False)
            t_end = time.time()
            it = 0
            for data, gt_classes in eval_dataloader:
                output = network(Tensor(data, mstype.float32))
                output = output.asnumpy()

                top1_output = np.argmax(output, axis=-1)
                top5_output = np.argsort(output)[:, -5:]

                t1_correct = np.equal(top1_output, gt_classes).sum()
                top1_correct += t1_correct
                top5_correct += get_top5_acc(top5_output, gt_classes)
                img_tot += args.per_batch_size

                if args.rank == 0 and it == 0:
                    t_end = time.time()
                    it = 1
            if args.rank == 0:
                time_used = time.time() - t_end
                fps = (img_tot -
                       args.per_batch_size) * args.group_size / time_used
                args.logger.info(
                    'Inference Performance: {:.2f} img/sec'.format(fps))
            results = [[top1_correct], [top5_correct], [img_tot]]
            args.logger.info('before results={}'.format(results))
            results = np.array(results)

            args.logger.info('after results={}'.format(results))
            top1_correct = results[0, 0]
            top5_correct = results[1, 0]
            img_tot = results[2, 0]
            acc1 = 100.0 * top1_correct / img_tot
            acc5 = 100.0 * top5_correct / img_tot
            args.logger.info('after allreduce eval: top1_correct={}, tot={},'
                             'acc={:.2f}%(TOP1)'.format(
                                 top1_correct, img_tot, acc1))
            args.logger.info('after allreduce eval: top5_correct={}, tot={},'
                             'acc={:.2f}%(TOP5)'.format(
                                 top5_correct, img_tot, acc5))
Example No. 7
def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.platform,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
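        # init() defaults to the HCCL backend on Ascend; GPU uses NCCL explicitly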
        if args.platform == "Ascend":
            init()
        else:
            init("nccl")
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=args.group_size,
                                          parameter_broadcast=True,
                                          mirror_mean=True)
    else:
        args.rank = 0
        args.group_size = 1

    if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1  # with dynamic loss scaling, the Momentum optimizer must not apply its own loss scale

    # choose whether only the master rank or every rank saves checkpoints; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    # dataloader
    de_dataset = classification_dataset(args.data_dir,
                                        args.image_size,
                                        args.per_batch_size,
                                        1,
                                        args.rank,
                                        args.group_size,
                                        num_parallel_workers=8)
    de_dataset.map_model = 4  # !!!important
    args.steps_per_epoch = de_dataset.get_dataset_size()

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    # get network and init
    network = get_network(args.backbone,
                          args.num_classes,
                          platform=args.platform)
    if network is None:
        raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('loaded model {} successfully'.format(args.pretrained))

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(
            args.lr,
            args.lr_epochs,
            args.steps_per_epoch,
            args.warmup_epochs,
            args.max_epoch,
            gamma=args.lr_gamma,
        )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr, args.steps_per_epoch,
                                        args.warmup_epochs, args.max_epoch,
                                        args.T_max, args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    # loss
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=args.label_smooth_factor,
                        num_classes=args.num_classes)

    if args.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                     scale_factor=2,
                                                     scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(args.loss_scale,
                                                   drop_overflow_update=False)

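    # on Ascend the whole network runs in float16 via amp_level "O3"; on GPU auto mixed precision is applied instead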
    if args.platform == "Ascend":
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      loss_scale_manager=loss_scale_manager,
                      metrics={'acc'},
                      amp_level="O3")
    else:
        auto_mixed_precision(network)
        model = Model(network,
                      loss_fn=loss,
                      optimizer=opt,
                      loss_scale_manager=loss_scale_manager,
                      metrics={'acc'})

    # checkpoint save
    progress_cb = ProgressMonitor(args)
    callbacks = [
        progress_cb,
    ]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
            keep_checkpoint_max=args.ckpt_save_max)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch,
                de_dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)