Example #1
def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
                                          parameter_broadcast=True, mirror_mean=True)
    else:
        args.rank = 0
        args.group_size = 1

    if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1  # dynamic loss scaling: do not set loss_scale in the Momentum optimizer

    # select whether only the master rank or all ranks save checkpoints; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    # dataloader
    de_dataset = classification_dataset(args.data_dir, args.image_size,
                                        args.per_batch_size, 1,
                                        args.rank, args.group_size, num_parallel_workers=8)
    de_dataset.map_model = 4  # !!!important
    args.steps_per_epoch = de_dataset.get_dataset_size()

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    # get network and init
    network = get_network(args.backbone, args.num_classes, platform=args.platform)
    if network is None:
        raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)


    # loss
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)

    if args.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

    if args.platform == "Ascend":
        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
                      metrics={'acc'}, amp_level="O3")
    else:
        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
                      metrics={'acc'}, amp_level="O2")

    # checkpoint save
    progress_cb = ProgressMonitor(args)
    callbacks = [progress_cb,]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
                                       keep_checkpoint_max=args.ckpt_save_max)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
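# Aside: the pretrained-checkpoint block above drops optimizer state ('moments.*') and strips the
# 'network.' wrapper prefix before loading. A minimal stand-alone sketch of that key remapping,
# using a plain dict with hypothetical keys in place of a real MindSpore param_dict:
def remap_checkpoint_keys(param_dict):
    """Drop optimizer moments and strip the 'network.' prefix added by the training wrapper."""
    param_dict_new = {}
    for key, value in param_dict.items():
        if key.startswith('moments.'):    # optimizer state, not needed when loading weights
            continue
        if key.startswith('network.'):    # weights saved from a wrapped training cell
            param_dict_new[key[len('network.'):]] = value
        else:
            param_dict_new[key] = value
    return param_dict_new

ckpt = {'moments.conv1.weight': 0, 'network.conv1.weight': 1, 'fc.bias': 2}  # hypothetical keys
assert remap_checkpoint_keys(ckpt) == {'conv1.weight': 1, 'fc.bias': 2}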
Example #2
def create_dataset1(dataset_path,
                    do_train,
                    repeat_num=1,
                    batch_size=32,
                    target="Ascend"):
    """
    create a train or eval cifar10 dataset for mobilenet
    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif target == "GPU":
        init()
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        device_num = 1

    if device_num == 1:
        data_set = ds.Cifar10Dataset(dataset_path,
                                     num_parallel_workers=THREAD_NUM,
                                     shuffle=True)
    else:
        data_set = ds.Cifar10Dataset(dataset_path,
                                     num_parallel_workers=THREAD_NUM,
                                     shuffle=True,
                                     num_shards=device_num,
                                     shard_id=rank_id)

    # define map operations
    trans = []
    if do_train:
        trans += [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        C.Resize((224, 224)),
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=THREAD_NUM)
    data_set = data_set.map(operations=trans,
                            input_columns="image",
                            num_parallel_workers=THREAD_NUM)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set
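# Aside: with device_num > 1 the Cifar10Dataset is built with num_shards/shard_id so each rank
# reads a disjoint slice of the data. A rough pure-Python sketch of one common sharding policy
# (round-robin by index); the real MindSpore sampler may partition differently:
def shard_indices(num_samples, num_shards, shard_id):
    """Round-robin split of sample indices across shards."""
    return [i for i in range(num_samples) if i % num_shards == shard_id]

parts = [shard_indices(10, num_shards=4, shard_id=r) for r in range(4)]
assert sorted(i for p in parts for i in p) == list(range(10))  # disjoint and complete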
Example #3
def create_dataset2(dataset_path,
                    do_train,
                    repeat_num=1,
                    batch_size=32,
                    target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for mobilenet

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        init()
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        data_set = ds.ImageFolderDataset(dataset_path,
                                         num_parallel_workers=THREAD_NUM,
                                         shuffle=True)
    else:
        data_set = ds.ImageFolderDataset(dataset_path,
                                         num_parallel_workers=THREAD_NUM,
                                         shuffle=True,
                                         num_shards=device_num,
                                         shard_id=rank_id)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size,
                                     scale=(0.08, 1.0),
                                     ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=trans,
                            input_columns="image",
                            num_parallel_workers=THREAD_NUM)
    data_set = data_set.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=THREAD_NUM)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set
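# Aside: unlike Example #2, this pipeline keeps pixels in the 0-255 range and folds the 1/255
# rescale into the Normalize constants (mean * 255, std * 255). A quick numpy check that the two
# formulations are equivalent:
import numpy as np

pixels = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

a = (pixels / 255.0 - mean) / std            # rescale to 0-1 first, normalize with 0-1 stats
b = (pixels - mean * 255.0) / (std * 255.0)  # normalize directly with 0-255 stats
assert np.allclose(a, b, atol=1e-5)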
Example #4
def train():
    """training CenterNet"""
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(save_graphs=False)

    ckpt_save_dir = args_opt.save_checkpoint_path
    rank = 0
    device_num = 1
    num_workers = 8
    if args_opt.device_target == "Ascend":
        context.set_context(enable_auto_mixed_precision=False)
        context.set_context(device_id=args_opt.device_id)
        if args_opt.distribute == "true":
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            _set_parallel_all_reduce_split()
    else:
        args_opt.distribute = "false"
        args_opt.need_profiler = "false"
        args_opt.enable_data_sink = "false"

    # Start creating the dataset.
    # MindRecord files (centernet.mindrecord0, 1, ..., file_num) will be generated in args_opt.mindrecord_dir.
    logger.info("Begin creating dataset for CenterNet")
    coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config, save_path=args_opt.save_result_dir)
    dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix,
                                        batch_size=train_config.batch_size, device_num=device_num, rank=rank,
                                        num_parallel_workers=num_workers, do_shuffle=args_opt.do_shuffle == 'true')
    dataset_size = dataset.get_dataset_size()
    logger.info("Create dataset done!")

    net_with_loss = CenterNetLossCell(net_config)

    args_opt.train_steps = args_opt.epoch_size * dataset_size
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(net_with_loss, dataset_size)

    enable_static_time = args_opt.device_target == "CPU"
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)
    if args_opt.device_target == "Ascend":
        net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                    sens=train_config.loss_scale_value)
    else:
        net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(args_opt.epoch_size, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
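# Aside: each distributed worker writes its checkpoints into its own sub-directory (ckpt_<rank>/)
# so parallel ranks do not clobber each other's files. A tiny sketch of that naming scheme with
# hypothetical values, where `rank` stands in for get_rank():
save_checkpoint_path = '/tmp/centernet/'   # hypothetical base path
device_num, device_id = 8, 3
rank = device_id % device_num
ckpt_save_dir = save_checkpoint_path + 'ckpt_' + str(rank) + '/'
assert ckpt_save_dir == '/tmp/centernet/ckpt_3/'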
Example #5
def train_net(data_dir,
              cross_valid_ind=1,
              epochs=400,
              batch_size=16,
              lr=0.0001,
              run_distribute=False,
              cfg=None):
    rank = 0
    group_size = 1
    if run_distribute:
        init()
        group_size = get_group_size()
        rank = get_rank()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=group_size,
                                          gradients_mean=False)

    if cfg['model'] == 'unet_medical':
        net = UNetMedical(n_channels=cfg['num_channels'],
                          n_classes=cfg['num_classes'])
    elif cfg['model'] == 'unet_nested':
        net = NestedUNet(in_channel=cfg['num_channels'],
                         n_class=cfg['num_classes'],
                         use_deconv=cfg['use_deconv'],
                         use_bn=cfg['use_bn'],
                         use_ds=cfg['use_ds'])
    elif cfg['model'] == 'unet_simple':
        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
    else:
        raise ValueError("Unsupported model: {}".format(cfg['model']))

    if cfg['resume']:
        param_dict = load_checkpoint(cfg['resume_ckpt'])
        if cfg['transfer_training']:
            filter_checkpoint_parameter_by_list(param_dict,
                                                cfg['filter_weight'])
        load_param_into_net(net, param_dict)

    if 'use_ds' in cfg and cfg['use_ds']:
        criterion = MultiCrossEntropyWithLogits()
    else:
        criterion = CrossEntropyWithLogits()
    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        repeat = 10
        dataset_sink_mode = True
        per_print_times = 0
        train_dataset = create_cell_nuclei_dataset(data_dir,
                                                   cfg['img_size'],
                                                   repeat,
                                                   batch_size,
                                                   is_train=True,
                                                   augment=True,
                                                   split=0.8,
                                                   rank=rank,
                                                   group_size=group_size)
    else:
        repeat = epochs
        dataset_sink_mode = False
        per_print_times = 1
        train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True,
                                          cross_valid_ind, run_distribute,
                                          cfg["crop"], cfg['img_size'])
    train_data_size = train_dataset.get_dataset_size()
    print("dataset length is:", train_data_size)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=train_data_size,
        keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                                 directory='./ckpt_{}/'.format(device_id),
                                 config=ckpt_config)

    optimizer = nn.Adam(params=net.trainable_params(),
                        learning_rate=lr,
                        weight_decay=cfg['weight_decay'],
                        loss_scale=cfg['loss_scale'])

    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        cfg['FixedLossScaleManager'], False)

    model = Model(net,
                  loss_fn=criterion,
                  loss_scale_manager=loss_scale_manager,
                  optimizer=optimizer,
                  amp_level="O3")

    print("============== Starting Training ==============")
    callbacks = [
        StepLossTimeMonitor(batch_size=batch_size,
                            per_print_times=per_print_times), ckpoint_cb
    ]
    model.train(int(epochs / repeat),
                train_dataset,
                callbacks=callbacks,
                dataset_sink_mode=dataset_sink_mode)
    print("============== End Training ==============")
Example #6
def train(cloud_args=None):
    """training process"""
    args = parse_args(cloud_args)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_auto_mixed_precision=True,
                        device_target=args.device_target,
                        save_graphs=False)

    if args.device_target == 'Ascend':
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=devid)

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1  # dynamic loss scaling: do not set loss_scale in the Momentum optimizer

    # select whether only the master rank or all ranks save checkpoints; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(
        args.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    # dataloader
    de_dataset = classification_dataset(args.data_dir, args.image_size,
                                        args.per_batch_size, args.max_epoch,
                                        args.rank, args.group_size)
    de_dataset.map_model = 4
    args.steps_per_epoch = de_dataset.get_dataset_size()

    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    # get network and init
    network = DenseNet121(args.num_classes)
    # loss
    if not args.label_smooth:
        args.label_smooth_factor = 0.0
    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
                             num_classes=args.num_classes)

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr_scheduler = MultiStepLR(args.lr,
                                   args.lr_epochs,
                                   args.lr_gamma,
                                   args.steps_per_epoch,
                                   args.max_epoch,
                                   warmup_epochs=args.warmup_epochs)
    elif args.lr_scheduler == 'cosine_annealing':
        lr_scheduler = CosineAnnealingLR(args.lr,
                                         args.T_max,
                                         args.steps_per_epoch,
                                         args.max_epoch,
                                         warmup_epochs=args.warmup_epochs,
                                         eta_min=args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    lr_schedule = lr_scheduler.get_lr()

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr_schedule),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    # mixed precision training
    criterion.add_flags_recursive(fp32=True)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = BuildTrainNetwork(network, criterion)
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    if args.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
                                                     scale_factor=2,
                                                     scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(args.loss_scale,
                                                   drop_overflow_update=False)

    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=args.group_size,
                                      gradients_mean=True)

    if args.device_target == 'Ascend':
        model = Model(train_net,
                      optimizer=opt,
                      metrics=None,
                      loss_scale_manager=loss_scale_manager,
                      amp_level="O3")
    elif args.device_target == 'GPU':
        model = Model(train_net,
                      optimizer=opt,
                      metrics=None,
                      loss_scale_manager=loss_scale_manager,
                      amp_level="O0")
    else:
        raise ValueError("Unsupported device target.")

    # checkpoint save
    progress_cb = ProgressMonitor(args)
    callbacks = [
        progress_cb,
    ]
    if args.rank_save_ckpt_flag:
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch, de_dataset, callbacks=callbacks)
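# Aside: MultiStepLR / CosineAnnealingLR here come from the model's own src package. A minimal
# numpy sketch of the kind of per-step schedule such helpers typically return (linear warmup,
# then per-epoch cosine annealing); the names and formulas below are illustrative, not the
# repo's exact implementation:
import numpy as np

def cosine_annealing_schedule(base_lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0.0):
    """Per-step learning rates: linear warmup followed by cosine annealing over t_max epochs."""
    total_steps = steps_per_epoch * max_epoch
    warmup_steps = steps_per_epoch * warmup_epochs
    lr = np.empty(total_steps, dtype=np.float32)
    for step in range(total_steps):
        if step < warmup_steps:
            lr[step] = base_lr * (step + 1) / warmup_steps
        else:
            cur_epoch = step // steps_per_epoch
            lr[step] = eta_min + (base_lr - eta_min) * (1 + np.cos(np.pi * cur_epoch / t_max)) / 2
    return lr

sched = cosine_annealing_schedule(0.1, steps_per_epoch=100, warmup_epochs=5, max_epoch=120, t_max=120)
assert sched[0] < sched[499] and sched[-1] < sched[500]  # ramps up during warmup, decays afterwards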
Example #7
                        help='run platform')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.platform,
                        save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args_opt.is_distributed:
        if args_opt.platform == "Ascend":
            init()
        else:
            init("nccl")
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=cfg.group_size,
                                          gradients_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

    # dataloader
    dataset = create_dataset(args_opt.dataset_path, cfg, True)
    batches_per_epoch = dataset.get_dataset_size()

    # network
    net = NASNetAMobile(cfg.num_classes)
Example #8
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None,
                             platform="Ascend"):
    """
    Build training pipeline.

    Args:
        config (TransformerConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()
    _load_checkpoint_to_net(config, net_with_loss)

    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

    if dataset is None:
        raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")

    update_steps = config.epochs * dataset.get_dataset_size()

    lr = _get_lr(config, update_steps)

    optimizer = _get_optimizer(config, net_with_loss, lr)

    # loss scale.
    if config.loss_scale_mode == "dynamic":
        scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                                scale_factor=config.loss_scale_factor,
                                                scale_window=config.scale_window)
    else:
        scale_manager = FixedLossScaleManager(loss_scale=config.init_loss_scale, drop_overflow_update=True)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
                                                              scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    time_cb = TimeMonitor(data_size=dataset.get_dataset_size())
    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                                   keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    callbacks = []
    callbacks.append(time_cb)
    if rank_size is not None and int(rank_size) > 1:
        loss_monitor = LossCallBack(config, rank_id=MultiAscend.get_rank())
        callbacks.append(loss_monitor)
        if MultiAscend.get_rank() % 8 == 0:
            ckpt_callback = ModelCheckpoint(
                prefix=config.ckpt_prefix,
                directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(MultiAscend.get_rank())),
                config=ckpt_config)
            callbacks.append(ckpt_callback)

    if rank_size is None or int(rank_size) == 1:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        loss_monitor = LossCallBack(config, rank_id=os.getenv('DEVICE_ID'))
        callbacks.append(loss_monitor)
        callbacks.append(ckpt_callback)

    print(f" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model, config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
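# Aside: in the multi-card branch only one rank per 8-device server writes checkpoints
# (MultiAscend.get_rank() % 8 == 0), which avoids every rank dumping identical files. A tiny
# sketch of that selection rule, mirroring the RANK_SIZE env-var handling above:
def should_save_ckpt(rank_size_env, rank):
    """Single-card runs always save; multi-card runs save on one rank per 8-card node."""
    if rank_size_env is None or int(rank_size_env) == 1:
        return True
    return rank % 8 == 0

assert should_save_ckpt(None, 0)
assert should_save_ckpt("16", 8)
assert not should_save_ckpt("16", 3)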
Example #9
    args.device_id = int(device_id)
    rank_size = os.getenv("RANK_SIZE", default=None)
    if rank_size is None:
        raise ValueError("Unsupported rank size.")
    if args.device_num > int(rank_size) or args.device_num == 1:
        args.device_num = int(rank_size)
    context.set_context(device_id=args.device_id)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target,
                        save_graphs=args.save_graphs)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True,
                                      device_num=args.device_num)
    init()
    args.rank = get_rank()
    local_data_url = os.path.join(local_data_url, str(args.device_id))
    local_train_url = os.path.join(local_train_url, str(args.device_id))
    args.train_output_path = os.path.join(args.train_output_path,
                                          str(args.device_id))
else:
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target,
                        save_graphs=args.save_graphs,
                        device_id=args.device_id)
    args.rank = 0
    args.device_num = 1

if args.run_cloudbrain:
    import moxing as mox
    args.train_dataset_path = os.path.join(local_data_url, "train")
Example #10
def train():
    """Train function."""
    args = parse_args()

    args.outputs_dir = params['save_model_path']

    if args.group_size > 1:
        init()
        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_{}/".format(str(get_rank())))
        args.rank = get_rank()
    else:
        args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/")
        args.rank = 0

    # loss scale and lr steps (distributed vs. single-card)
    if args.group_size > 1:
        args.loss_scale = params['loss_scale'] / 2
        args.lr_steps = list(map(int, params["lr_steps_NP"].split(',')))
    else:
        args.loss_scale = params['loss_scale']
        args.lr_steps = list(map(int, params["lr_steps"].split(',')))

    # create network
    print('start create network')
    criterion = openpose_loss()
    criterion.add_flags_recursive(fp32=True)
    network = OpenPoseNet(vggpath=params['vgg_path'])
    # network.add_flags_recursive(fp32=True)

    if params["load_pretrain"]:
        print("load pretrain model:", params["pretrained_model_path"])
        load_model(network, params["pretrained_model_path"])
    train_net = BuildTrainNetwork(network, criterion)

    # create dataset
    if os.path.exists(args.jsonpath_train) and os.path.exists(args.imgpath_train) \
            and os.path.exists(args.maskpath_train):
        print('start create dataset')
    else:
        raise ValueError('wrong data path: check jsonpath_train, imgpath_train and maskpath_train')


    num_worker = 20 if args.group_size > 1 else 48
    de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train,
                                      batch_size=params['batch_size'],
                                      rank=args.rank,
                                      group_size=args.group_size,
                                      num_worker=num_worker,
                                      multiprocessing=True,
                                      shuffle=True,
                                      repeat_num=1)
    steps_per_epoch = de_dataset_train.get_dataset_size()
    print("steps_per_epoch: ", steps_per_epoch)

    # lr scheduler
    lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size,
                                       params['lr_gamma'],
                                       steps_per_epoch,
                                       params["max_epoch_train"],
                                       args.lr_steps,
                                       args.group_size)
    vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params()))
    base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params()))
    stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params()))

    group_params = [{'params': vgg19_base_params, 'lr': lr_vgg},
                    {'params': base_params, 'lr': lr_base},
                    {'params': stages_params, 'lr': lr_stage}]

    opt = Adam(group_params, loss_scale=args.loss_scale)

    train_net.set_train(True)
    loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

    model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager)

    params['ckpt_interval'] = max(steps_per_epoch, params['ckpt_interval'])
    config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'],
                                 keep_checkpoint_max=params["keep_checkpoint_max"])
    ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck)
    time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size())
    callback_list = [MyLossMonitor(), time_cb, ckpoint_cb]
    print("============== Starting Training ==============")
    model.train(params["max_epoch_train"], de_dataset_train, callbacks=callback_list,
                dataset_sink_mode=False)
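# Aside: the optimizer above applies three learning-rate schedules by splitting the trainable
# parameters on name substrings. A stand-alone sketch of that grouping, with SimpleNamespace
# objects and made-up parameter names standing in for MindSpore Parameters:
from types import SimpleNamespace

params = [SimpleNamespace(name=n) for n in
          ['base.vgg_base.conv1.weight', 'base.conv2.weight', 'stage1.conv.weight']]  # hypothetical

vgg19_base_params = [p for p in params if 'base.vgg_base' in p.name]
base_params = [p for p in params if 'base.conv' in p.name]
stages_params = [p for p in params if 'base' not in p.name]

group_params = [{'params': vgg19_base_params, 'lr': 1e-5},
                {'params': base_params, 'lr': 1e-4},
                {'params': stages_params, 'lr': 4e-4}]
assert [len(g['params']) for g in group_params] == [1, 1, 1]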
Example #11
def create_dataset(dataset_path,
                   do_train,
                   config,
                   platform,
                   repeat_num=1,
                   batch_size=100,
                   model='ghostnet'):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        config: dataset config providing image_height.
        platform(string): the device target, "Ascend" or "GPU".
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 100
        model(string): model name used to choose eval resize/crop sizes. Default: 'ghostnet'

    Returns:
        dataset
    """
    if platform == "Ascend":
        rank_size = int(os.getenv("RANK_SIZE"))
        rank_id = int(os.getenv("RANK_ID"))
        if rank_size == 1:
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True)
        else:
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True,
                                      num_shards=rank_size,
                                      shard_id=rank_id)
    elif platform == "GPU":
        if do_train:
            from mindspore.communication.management import get_rank, get_group_size
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True,
                                      num_shards=get_group_size(),
                                      shard_id=get_rank())
        else:
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True)
    else:
        raise ValueError("Unsupport platform.")

    resize_height = config.image_height
    buffer_size = 1000

    # define map operations
    resize_crop_op = C.RandomCropDecodeResize(resize_height,
                                              scale=(0.08, 1.0),
                                              ratio=(0.75, 1.333))
    horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)

    color_op = C.RandomColorAdjust(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4)
    rescale_op = C.Rescale(1 / 255.0, 0)
    normalize_op = C.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
    change_swap_op = C.HWC2CHW()

    # define python operations
    decode_p = P.Decode()
    if model == 'ghostnet-600':
        s = 274
        c = 240
    else:
        s = 256
        c = 224
    resize_p = P.Resize(s, interpolation=Inter.BICUBIC)
    center_crop_p = P.CenterCrop(c)
    totensor = P.ToTensor()
    normalize_p = P.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    composeop = P.ComposeOp(
        [decode_p, resize_p, center_crop_p, totensor, normalize_p])
    if do_train:
        trans = [
            resize_crop_op, horizontal_flip_op, color_op, rescale_op,
            normalize_op, change_swap_op
        ]
    else:
        trans = composeop()
    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(input_columns="image",
                            operations=trans,
                            num_parallel_workers=8)
    data_set = data_set.map(input_columns="label_list",
                            operations=type_cast_op,
                            num_parallel_workers=8)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=buffer_size)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set
Example #12
            config_gpu.loss_scale, drop_overflow_update=False)
        lr = Tensor(get_lr(global_step=0,
                           lr_init=0,
                           lr_end=0,
                           lr_max=config_gpu.lr,
                           warmup_epochs=config_gpu.warmup_epochs,
                           total_epochs=epoch_size,
                           steps_per_epoch=step_size))
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
                       config_gpu.weight_decay, config_gpu.loss_scale)
        # define model
        model = Model(net, loss_fn=loss, optimizer=opt,
                      loss_scale_manager=loss_scale)

        cb = [Monitor(lr_init=lr.asnumpy())]
        ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
        if config_gpu.save_checkpoint:
            config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config_gpu.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
            cb += [ckpt_cb]
        # begin training
        model.train(epoch_size, dataset, callbacks=cb)
    elif args_opt.platform == "Ascend":
        # train on ascend
        print("train args: ", args_opt, "\ncfg: ", config_ascend,
              "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))

        if run_distribute:
            context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              parameter_broadcast=True, mirror_mean=True)
Example #13
def train():
    args = parse_args()

    if args.device_target == "CPU":
        context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
    else:
        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
                            device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))

    # init multi-card training
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.group_size)

    # dataset
    dataset = data_generator.SegDataset(image_mean=args.image_mean,
                                        image_std=args.image_std,
                                        data_file=args.data_file,
                                        batch_size=args.batch_size,
                                        crop_size=args.crop_size,
                                        max_scale=args.max_scale,
                                        min_scale=args.min_scale,
                                        ignore_label=args.ignore_label,
                                        num_classes=args.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(args.model))

    # loss
    loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label)
    loss_.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(network, loss_)

    # load pretrained model
    if args.ckpt_pre_trained:
        param_dict = load_checkpoint(args.ckpt_pre_trained)
        load_param_into_net(train_net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    total_train_steps = iters_per_epoch * args.train_epochs
    if args.lr_type == 'cos':
        lr_iter = learning_rates.cosine_lr(args.base_lr, total_train_steps, total_train_steps)
    elif args.lr_type == 'poly':
        lr_iter = learning_rates.poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
    elif args.lr_type == 'exp':
        lr_iter = learning_rates.exponential_lr(args.base_lr, args.lr_decay_step, args.lr_decay_rate,
                                                total_train_steps, staircase=True)
    else:
        raise ValueError('unknown learning rate type')
    opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=lr_iter, momentum=0.9, weight_decay=0.0001,
                      loss_scale=args.loss_scale)

    # loss scale
    manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
    amp_level = "O0" if args.device_target == "CPU" else "O3"
    model = Model(train_net, optimizer=opt, amp_level=amp_level, loss_scale_manager=manager_loss_scale)

    # callback for saving ckpts
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args.save_steps,
                                     keep_checkpoint_max=args.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))
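# Aside: learning_rates.poly_lr belongs to the DeepLab repo; a minimal numpy sketch of a
# polynomial-decay schedule with power 0.9 (the standard DeepLab policy the call above
# parameterizes), assuming the decay spans the full training run:
import numpy as np

def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0, power=0.9):
    """Per-step polynomial decay from base_lr toward end_lr."""
    steps = np.arange(total_steps, dtype=np.float32)
    frac = np.minimum(steps / decay_steps, 1.0)
    return (base_lr - end_lr) * (1.0 - frac) ** power + end_lr

lrs = poly_lr(0.015, decay_steps=30000, total_steps=30000)
assert abs(lrs[0] - 0.015) < 1e-7 and lrs[-1] < 1e-4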
Example #14
def train_on_gpu():
    config = config_gpu_quant
    print("training args: {}".format(args_opt))
    print("training configure: {}".format(config))

    # define network
    network = mobilenetV2(num_classes=config.num_classes)
    # define loss
    if config.label_smooth > 0:
        loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth,
                                           num_classes=config.num_classes)
    else:
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define dataset
    epoch_size = config.epoch_size
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             config=config,
                             device_target=args_opt.device_target,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()
    # resume
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
    quantizer = QuantizationAwareTraining(bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[False, False],
                                          freeze_bn=1000000,
                                          quant_delay=step_size * 2)
    network = quantizer.quantize(network)

    # get learning rate
    loss_scale = FixedLossScaleManager(config.loss_scale,
                                       drop_overflow_update=False)
    lr = Tensor(
        get_lr(global_step=config.start_epoch * step_size,
               lr_init=0,
               lr_end=0,
               lr_max=config.lr,
               warmup_epochs=config.warmup_epochs,
               total_epochs=epoch_size + config.start_epoch,
               steps_per_epoch=step_size))

    # define optimization
    opt = nn.Momentum(
        filter(lambda x: x.requires_grad, network.get_parameters()), lr,
        config.momentum, config.weight_decay, config.loss_scale)
    # define model
    model = Model(network,
                  loss_fn=loss,
                  optimizer=opt,
                  loss_scale_manager=loss_scale)

    print("============== Starting Training ==============")
    callback = [Monitor(lr_init=lr.asnumpy())]
    ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
        get_rank()) + "/"
    if config.save_checkpoint:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
            keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="mobilenetV2",
                                  directory=ckpt_save_dir,
                                  config=config_ck)
        callback += [ckpt_cb]
    model.train(epoch_size, dataset, callbacks=callback)
    print("============== End Training ==============")
Example #15
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None,
                             platform="Ascend"):
    """
    Build training pipeline.

    Args:
        config (TransformerConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()

    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            weights = np.load(config.existed_ckpt)
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in net_with_loss.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(
                    f"Param {weights_name} is not found in ckpt file.")

            if isinstance(weights[weights_name], Parameter):
                param.default_input = weights[weights_name].default_input
            elif isinstance(weights[weights_name], Tensor):
                param.default_input = Tensor(weights[weights_name].asnumpy(),
                                             config.dtype)
            elif isinstance(weights[weights_name], np.ndarray):
                param.default_input = Tensor(weights[weights_name],
                                             config.dtype)
            else:
                param.default_input = weights[weights_name]
    else:
        for param in net_with_loss.trainable_params():
            name = param.name
            value = param.default_input
            if isinstance(value, Tensor):
                if name.endswith(".gamma"):
                    param.default_input = one_weight(value.asnumpy().shape)
                elif name.endswith(".beta") or name.endswith(".bias"):
                    param.default_input = zero_weight(value.asnumpy().shape)
                else:
                    param.default_input = weight_variable(
                        value.asnumpy().shape)

    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

    if dataset is None:
        raise ValueError(
            "either a pre-training dataset or a fine-tuning dataset must be provided."
        )

    update_steps = dataset.get_repeat_count() * dataset.get_dataset_size()
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(
            lr=config.lr,
            update_num=update_steps,
            decay_start_step=config.decay_start_step,
            warmup_steps=config.warmup_steps,
            min_lr=config.min_lr),
                    dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(
            lr=config.lr,
            min_lr=config.min_lr,
            decay_steps=config.decay_steps,
            total_update_num=update_steps,
            warmup_steps=config.warmup_steps,
            power=config.poly_lr_scheduler_power),
                    dtype=mstype.float32)
    else:
        lr = config.lr

    if config.optimizer.lower() == "adam":
        optimizer = Adam(net_with_loss.trainable_params(),
                         lr,
                         beta1=0.9,
                         beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        lr = BertLearningRate(decay_steps=12000,
                              learning_rate=config.lr,
                              end_learning_rate=config.min_lr,
                              power=10.0,
                              warmup_steps=config.warmup_steps)
        decay_params = list(
            filter(
                lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x
                .name.lower(), net_with_loss.trainable_params()))
        other_params = list(
            filter(
                lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.
                lower(), net_with_loss.trainable_params()))
        group_params = [{
            'params': decay_params,
            'weight_decay': 0.01
        }, {
            'params': other_params
        }]

        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(net_with_loss.trainable_params(),
                             lr,
                             momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")

    # loss scale.
    if config.loss_scale_mode == "dynamic":
        scale_manager = DynamicLossScaleManager(
            init_loss_scale=config.init_loss_scale,
            scale_factor=config.loss_scale_factor,
            scale_window=config.scale_window)
    else:
        scale_manager = FixedLossScaleManager(
            loss_scale=config.init_loss_scale, drop_overflow_update=True)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(
        network=net_with_loss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.save_ckpt_steps,
        keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    callbacks = [loss_monitor]
    if rank_size is not None and int(
            rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)

    if rank_size is None or int(rank_size) == 1:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)

    print(f" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model,
           config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
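# Aside: DynamicLossScaleManager grows the scale after scale_window consecutive overflow-free
# steps and shrinks it whenever an overflow is detected. A rough pure-Python simulation of that
# policy (the idea only, not MindSpore's exact implementation):
def simulate_dynamic_loss_scale(overflow_flags, init_scale=65536, factor=2, window=2000):
    """Return the loss scale after processing a sequence of per-step overflow flags."""
    scale, good_steps = init_scale, 0
    for overflow in overflow_flags:
        if overflow:
            scale = max(scale / factor, 1)   # back off on overflow
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == window:         # stable for a full window: try a larger scale
                scale *= factor
                good_steps = 0
    return scale

assert simulate_dynamic_loss_scale([True] * 3) == 65536 / 8
assert simulate_dynamic_loss_scale([False] * 2000) == 65536 * 2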
Example #16
def main(args):

    if args.is_distributed == 0:
        cfg = faceqa_1p_cfg
    else:
        cfg = faceqa_8p_cfg

    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast or mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=cfg.world_size,
                                      gradients_mean=True)

    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(
        cfg.ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst,
                                local_rank=cfg.local_rank,
                                world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = FaceQABackbone()
    criterion = CriterionsFaceQA()

    # load pretrain model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model %s success.', cfg.pretrained)

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay,
                   loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = BuildTrainNetwork(network, criterion)
    train_net = TrainOneStepCell(
        train_net,
        opt,
        sens=cfg.loss_scale,
    )

    # checkpoint save
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(
            save_checkpoint_steps=cfg.ckpt_interval,
            keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config,
                                  directory=cfg.outputs_dir,
                                  prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss
        data = data.astype(np.float32)
        gt = gt.astype(np.float32)
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(
                cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (
                i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(
                epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i
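            # Illustrative throughput arithmetic (numbers assumed, not from the
            # source): per_batch_size=32, world_size=8, log_interval=100 and
            # time_used=10s give fps = 32 * 100 * 8 / 10 = 2560 imgs/sec.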

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info(
                '=================================================')
            cfg.logger.info(
                'epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(
                    epoch, i, fps))
            cfg.logger.info(
                '=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')


def run_general_distill():
    """
    run general distill
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="3",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Enable data sink, default is true.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Enable data sink, default is true.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_teacher_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_teacher_net_cfg.compute_type = mstype.float32
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_student_net_cfg.compute_type = mstype.float32
        # Both the forward and backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.data_dir, args_opt.schema_dir)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size
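    # Illustrative sink-mode arithmetic (values assumed, not from the source):
    # dataset_size=1000, epoch_size=3 and data_sink_steps=100 give
    # repeat_count = 3 * 1000 // 100 = 30 sink loops of 100 steps each.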

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
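    # warmup_steps is 10% of the total training steps; decay_steps spans the
    # full run, so the polynomial decay covers all remaining steps.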
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
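    # Weight decay is applied only to parameters passing decay_filter; the
    # trailing 'order_params' entry keeps the optimizer's parameter order
    # aligned with the network's.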

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=common_cfg.loss_scale_value,
            scale_factor=common_cfg.scale_factor,
            scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
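# A minimal launch sketch for the distill entry point above (the script name
# and all paths are assumptions, not from the source):
#   python run_general_distill.py --device_target Ascend --device_id 0 \
#       --data_dir /path/to/gd_data --schema_dir /path/to/schema.json \
#       --load_teacher_ckpt_path teacher.ckpt --save_ckpt_path ./ckpt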
Example #18
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1,
                 input_dims='2d',
                 data_format='NCHW'):
        super(_BatchNorm, self).__init__()
        if num_features < 1:
            raise ValueError("num_features must be at least 1")

        if momentum < 0 or momentum > 1:
            raise ValueError(
                "momentum should be a number in range [0, 1], but got {}".
                format(momentum))
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'],
                                             'format', self.cls_name)
        if context.get_context(
                "device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format only support in GPU target.")
        self.use_batch_statistics = use_batch_statistics
        self.num_features = num_features
        self.eps = eps
        self.input_dims = input_dims
        self.moving_mean = Parameter(initializer(moving_mean_init,
                                                 num_features),
                                     name="mean",
                                     requires_grad=False)
        self.moving_variance = Parameter(initializer(moving_var_init,
                                                     num_features),
                                         name="variance",
                                         requires_grad=False)
        self.gamma = Parameter(initializer(gamma_init, num_features),
                               name="gamma",
                               requires_grad=affine)
        self.beta = Parameter(initializer(beta_init, num_features),
                              name="beta",
                              requires_grad=affine)
        self.group = validator.check_positive_int(device_num_each_group)
        self.is_global = False
        if self.group != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i] and self.group != 1:
                    self.is_global = True
                    management.create_group('group' + str(i),
                                            self.rank_list[i])
                    self.all_reduce = P.AllReduce(
                        P.ReduceOp.SUM,
                        'group' + str(i)).add_prim_attr('fusion', 1)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
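        # Convert the user-facing momentum (e.g. 0.9) into the complementary
        # moving-average factor (0.1) passed to the batchnorm primitives below.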
        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False

        if self._target == "Ascend":
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps,
                                        momentum=self.momentum)
        if self._target == "GPU":
            self.bn_train = P.FusedBatchNormEx(mode=1,
                                               epsilon=self.eps,
                                               momentum=self.momentum,
                                               data_format=self.format)
        if self._target == "CPU":
            self.bn_train = P.FusedBatchNorm(mode=1,
                                             epsilon=self.eps,
                                             momentum=self.momentum)
        self.bn_infer = P.BatchNorm(is_training=False,
                                    epsilon=self.eps,
                                    data_format=self.format)
        self.enable_global_sync = self.is_global and (self.is_ge_backend or
                                                      (self.is_graph_mode and self._target == "Ascend"))
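        # Cross-device statistics sync is only enabled for grouped batchnorm on
        # the GE backend, or on Ascend in graph mode.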

        data_parallel_strategy = ((1, ), (1, ))
        data_parallel_strategy_one = ((1, ), ())
        self.sub_mean = P.Sub().shard(data_parallel_strategy)
        self.sub_var = P.Sub().shard(data_parallel_strategy)
        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
Example #19
def train_and_eval(config):
    """
    Train and evaluate the Wide&Deep model.
    """
    set_seed(1000)
    data_path = config.data_path
    batch_size = config.batch_size
    epochs = config.epochs
    if config.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif config.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    parameter_server = bool(config.parameter_server)
    cache_enable = bool(config.vocab_cache_size > 0)
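    # parameter_server and cache_enable jointly decide dataset sink mode in
    # model.train below.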
    print("epochs is {}".format(epochs))
    ds_train = create_dataset(data_path,
                              train_mode=True,
                              epochs=1,
                              batch_size=batch_size,
                              rank_id=get_rank(),
                              rank_size=get_group_size(),
                              data_type=dataset_type)
    ds_eval = create_dataset(data_path,
                             train_mode=False,
                             epochs=1,
                             batch_size=batch_size,
                             rank_id=get_rank(),
                             rank_size=get_group_size(),
                             data_type=dataset_type)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

    net_builder = ModelBuilder()

    train_net, eval_net = net_builder.get_net(config)
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net,
                  eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=ds_train.get_dataset_size(),
        keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path + '/ckpt_' +
                                 str(get_rank()) + '/',
                                 config=ckptconfig)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=(parameter_server and cache_enable))
Example #20
            context.set_context(device_id=device_id,
                                enable_auto_mixed_precision=True)
            context.set_auto_parallel_context(
                device_num=args_opt.device_num,
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True)
            init()
        # GPU target
        else:
            init()
            context.set_auto_parallel_context(
                device_num=get_group_size(),
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True)
        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
            get_rank()) + "/"

    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             repeat_num=1,
                             batch_size=config.batch_size,
                             target=target)
    step_size = dataset.get_dataset_size()

    # define net
    net = mobilenet(class_num=config.class_num)
    if args_opt.parameter_server:
        net.set_param_ps()

    # init weight
Example #21
def parse_args():
    """Parse train arguments."""
    parser = argparse.ArgumentParser('mindspore coco training')

    # device related
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')

    # dataset related
    parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
    parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.')

    # network related
    parser.add_argument('--pretrained_backbone', default='', type=str,
                        help='The ckpt file of DarkNet53. Default: "".')
    parser.add_argument('--resume_yolov3', default='', type=str,
                        help='The ckpt file of YOLOv3, which used to fine tune. Default: ""')

    # optimizer and lr related
    parser.add_argument('--lr_scheduler', default='exponential', type=str,
                        help='Learning rate scheduler, options: exponential, cosine_annealing. Default: exponential')
    parser.add_argument('--lr', default=0.001, type=float, help='Learning rate. Default: 0.001')
    parser.add_argument('--lr_epochs', type=str, default='220,250',
                        help='Epochs at which the lr changes, separated by ",". Default: 220,250')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='Decay factor of the exponential lr_scheduler. Default: 0.1')
    parser.add_argument('--eta_min', type=float, default=0., help='Eta_min in cosine_annealing scheduler. Default: 0')
    parser.add_argument('--T_max', type=int, default=320, help='T-max in cosine_annealing scheduler. Default: 320')
    parser.add_argument('--max_epoch', type=int, default=320, help='Max epoch num to train the model. Default: 320')
    parser.add_argument('--warmup_epochs', default=0, type=float, help='Warmup epochs. Default: 0')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay factor. Default: 0.0005')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum. Default: 0.9')

    # loss related
    parser.add_argument('--loss_scale', type=int, default=1024, help='Static loss scale. Default: 1024')
    parser.add_argument('--label_smooth', type=int, default=0, help='Whether to use label smooth in CE. Default:0')
    parser.add_argument('--label_smooth_factor', type=float, default=0.1,
                        help='Smooth strength of original one-hot. Default: 0.1')

    # logging related
    parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
    parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')

    parser.add_argument('--is_save_on_master', type=int, default=1,
                        help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 1')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=1,
                        help='Distribute train or not, 1 for yes, 0 for no. Default: 1')
    parser.add_argument('--rank', type=int, default=0, help='Local rank of distributed. Default: 0')
    parser.add_argument('--group_size', type=int, default=1, help='World size of device. Default: 1')

    # profiler init
    parser.add_argument('--need_profiler', type=int, default=0,
                        help='Whether use profiler. 0 for no, 1 for yes. Default: 0')

    # reset default config
    parser.add_argument('--training_shape', type=str, default="", help='Fix training shape. Default: ""')
    parser.add_argument('--resize_rate', type=int, default=None,
                        help='Resize rate for multi-scale training. Default: None')

    args, _ = parser.parse_known_args()
    if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max:
        args.T_max = args.max_epoch
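    # For cosine_annealing, T_max is clamped up to max_epoch so the schedule
    # spans the whole training run.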

    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    args.data_root = os.path.join(args.data_dir, 'train2014')
    args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')

    # init distributed
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
        else:
            init("nccl")
        args.rank = get_rank()
        args.group_size = get_group_size()

    # choose whether only the master rank or all ranks save ckpt, compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    return args
Example #22
                    device_id=args_opt.device_id)

if __name__ == '__main__':
    if args_opt.run_distribute:
        if args_opt.device_target == "Ascend":
            rank = args_opt.rank_id
            device_num = args_opt.device_num
            context.set_auto_parallel_context(
                device_num=device_num,
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True)
            init()
        else:
            init("nccl")
            context.reset_auto_parallel_context()
            rank = get_rank()
            device_num = get_group_size()
            context.set_auto_parallel_context(
                device_num=device_num,
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True)
    else:
        rank = 0
        device_num = 1

    print("Start create dataset!")

    # It will generate mindrecord file in args_opt.mindrecord_dir,
    # and the file name is FasterRcnn.mindrecord0, 1, ... file_num.
    prefix = "FasterRcnn.mindrecord"
    mindrecord_dir = config.mindrecord_dir
Example #23
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
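        # The fusion split indices below bucket gradient AllReduce operations
        # into fused groups tuned for the 12- and 24-layer BERT configurations.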
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 temporarily, running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
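    # With data sink enabled, Model.train's first argument counts sink loops of
    # data_sink_steps steps rather than dataset epochs.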
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
        logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=cfg.Lamb.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.Lamb.weight_decay
        }, {
            'params': other_params
        }, {
            'order_params': params
        }]
        optimizer = Lamb(group_params,
                         learning_rate=lr_schedule,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=cfg.AdamWeightDecay.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }, {
            'order_params': params
        }]

        optimizer = AdamWeightDecay(group_params,
                                    learning_rate=lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError(
            "Unsupported optimizer {}, only [Lamb, Momentum, AdamWeightDecay] are supported"
            .format(cfg.optimizer))
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
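        # Only device_id 0 within each group of up to eight devices saves
        # checkpoints, avoiding duplicate checkpoint files.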
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        net_with_grads = BertTrainOneStepWithLossScaleCell(
            net_with_loss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #24
        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
        if cfg.is_dynamic_loss_scale == 1:
            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
        else:
            loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)

    else:
        raise ValueError("Unsupported dataset.")

    if device_target == "Ascend":
        model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False,
                      loss_scale_manager=loss_scale_manager)
    elif device_target == "GPU":
        model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, loss_scale_manager=loss_scale_manager)
    else:
        raise ValueError("Unsupported platform.")

    if device_num > 1:
        ckpt_save_dir = args.ckpt_path + "_" + str(get_rank())
    else:
        ckpt_save_dir = args.ckpt_path

    time_cb = TimeMonitor(data_size=step_per_epoch)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=ckpt_save_dir, config=config_ck)

    print("============== Starting Training ==============")
    model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode, sink_size=args.sink_size)