예제 #1
0
        loss_scale = FixedLossScaleManager(config.loss_scale,
                                           drop_overflow_update=False)
        lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))
        opt = THOR(
            filter(lambda x: x.requires_grad,
                   net.get_parameters()), lr, config.momentum,
            filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
            filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
            filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
            filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
            config.weight_decay, config.loss_scale)

        model = Model(net,
                      loss_fn=loss,
                      optimizer=opt,
                      amp_level='O2',
                      loss_scale_manager=loss_scale,
                      keep_batchnorm_fp32=False,
                      metrics={'acc'},
                      frequency=config.frequency)

        time_cb = TimeMonitor(data_size=step_size)
        loss_cb = LossMonitor()
        cb = [time_cb, loss_cb]
        if config.save_checkpoint:
            config_ck = CheckpointConfig(
                save_checkpoint_steps=config.save_checkpoint_steps,
                keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix="resnet",
                                      directory=config.save_checkpoint_path,
                                      config=config_ck)
            cb += [ckpt_cb]
예제 #2
0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
                        num_classes=config.class_num)
    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), config.momentum,
               filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
               filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
               filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
               filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
               config.weight_decay, config.loss_scale)
    loss_scale = FixedLossScaleManager(config.loss_scale,
                                       drop_overflow_update=False)
    if target == "Ascend":
        model = Model(net,
                      loss_fn=loss,
                      optimizer=opt,
                      amp_level='O2',
                      loss_scale_manager=loss_scale,
                      keep_batchnorm_fp32=False,
                      metrics={'acc'},
                      frequency=config.frequency)
    else:
        model = Model(net,
                      loss_fn=loss,
                      optimizer=opt,
                      loss_scale_manager=loss_scale,
                      metrics={'acc'},
                      amp_level="O2",
                      keep_batchnorm_fp32=True,
                      frequency=config.frequency)

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
예제 #3
0
def run_pretrain():
    """Pre-train BERT (bert_clue) from command-line options.

    Steps:
      1. Parse the command line.
      2. Configure the MindSpore execution context. With ``--distribute true``
         it also initialises communication and data-parallel mode.
      3. Build the pre-training dataset and the loss network.
      4. Attach timing/loss (and, on rank 0, checkpoint) callbacks.
      5. Optionally restore a checkpoint.
      6. Run ``Model.train``.

    Boolean-like flags (``--distribute``, ``--enable_save_ckpt``,
    ``--enable_lossscale``, ``--enable_data_sink``) are strings. Only the
    exact value ``"true"`` enables them.
    """
    import os  # local import: only used to build the per-rank checkpoint directory

    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=1,
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=4,
                        help="Device id, default is 4.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="false",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="false",
                        help="Enable shuffle for dataset, default is false.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=100,
                        help="Sink steps for each epoch, default is 100.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id,
                        save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    context.set_context(max_call_depth=3000)

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        # One checkpoint sub-directory per rank. os.path.join supplies the
        # separator that plain concatenation dropped when the base path had no
        # trailing '/'; the final '' keeps the trailing separator.
        ckpt_save_dir = os.path.join(args_opt.save_checkpoint_path,
                                     'ckpt_' + str(rank), '')
        context.reset_auto_parallel_context()
        _set_bert_all_reduce_split()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # Each "epoch" passed to Model.train consumes data_sink_steps batches, so
    # the repeat count is the total number of batches divided by the sink size.
    dataset_size = ds.get_dataset_size()
    new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        # Record the implied step count; _get_optimizer reads args_opt afterwards.
        args_opt.train_steps = args_opt.epoch_size * dataset_size
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    # Only rank 0 writes checkpoints.
    if args_opt.enable_save_ckpt == "true" and rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        net_with_grads = BertTrainOneStepWithLossScaleCell(
            net_with_loss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads, frequency=cfg.Thor.frequency)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
예제 #4
0
def _pretrain_arg_parser():
    """Build the command-line parser for BERT pre-training.

    Boolean-like options are plain strings; callers compare them to "true".
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=1,
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=4,
                        help="Device id, default is 4.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="false",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="false",
                        help="Enable shuffle for dataset, default is false.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=100,
                        help="Sink steps for each epoch, default is 100.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    return parser


def _set_pretrain_all_reduce_split(num_hidden_layers, use_relative_positions):
    """Register all-reduce fusion split indices for the supported BERT depths.

    The same index list is set for both the "hccl_world_groupsum1" and
    "hccl_world_groupsum3" groups. Depths other than 12 or 24 hidden layers
    leave the fusion configuration untouched.
    """
    from mindspore.parallel._auto_parallel_context import auto_parallel_context
    # Keyed by (num_hidden_layers, use_relative_positions).
    split_table = {
        (12, True): [29, 58, 87, 116, 145, 174, 203, 217],
        (12, False): [28, 55, 82, 109, 136, 163, 190, 205],
        (24, True): [30, 90, 150, 210, 270, 330, 390, 421],
        (24, False): [38, 93, 148, 203, 258, 313, 368, 397],
    }
    indices = split_table.get(
        (num_hidden_layers, bool(use_relative_positions)))
    if indices is None:
        return
    for group in ("hccl_world_groupsum1", "hccl_world_groupsum3"):
        # Pass a fresh list per call, as the original literal-per-call code did.
        auto_parallel_context().set_all_reduce_fusion_split_indices(
            list(indices), group)


def _create_pretrain_optimizer(net_with_loss, train_steps):
    """Create the optimizer selected by ``cfg.optimizer``.

    Args:
        net_with_loss: the network whose parameters are optimized.
        train_steps: total training steps. It is the ``decay_steps`` of the
            Lamb and AdamWeightDecay learning-rate schedules.

    Returns:
        A Lamb, Momentum, AdamWeightDecay or THOR optimizer instance.

    Raises:
        ValueError: if ``cfg.optimizer`` names an unsupported optimizer.
    """
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=cfg.Lamb.warmup_steps,
            decay_steps=train_steps,
            power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.Lamb.weight_decay
        }, {
            'params': other_params
        }, {
            'order_params': params
        }]
        return Lamb(group_params,
                    learning_rate=lr_schedule,
                    eps=cfg.Lamb.eps)
    if cfg.optimizer == 'Momentum':
        return Momentum(net_with_loss.trainable_params(),
                        learning_rate=cfg.Momentum.learning_rate,
                        momentum=cfg.Momentum.momentum)
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=cfg.AdamWeightDecay.warmup_steps,
            decay_steps=train_steps,
            power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }, {
            'order_params': params
        }]
        return AdamWeightDecay(group_params,
                               learning_rate=lr_schedule,
                               eps=cfg.AdamWeightDecay.eps)
    if cfg.optimizer == "Thor":
        lr = get_bert_lr()
        damping = get_bert_damping()
        # THOR receives the trainable parameters plus the 'matrix_A' /
        # 'matrix_G' parameters, selected by name.
        return THOR(
            filter(lambda x: x.requires_grad,
                   net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
            filter(lambda x: 'matrix_A' in x.name,
                   net_with_loss.get_parameters()),
            filter(lambda x: 'matrix_G' in x.name,
                   net_with_loss.get_parameters()), cfg.Thor.weight_decay,
            cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
            bert_net_cfg.batch_size, damping)
    raise ValueError(
        "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]"
        .format(cfg.optimizer))


def run_pretrain():
    """Pre-train BERT (bert_clue) from command-line options.

    Steps:
      1. Parse the command line.
      2. Configure the MindSpore execution context. With ``--distribute true``
         it also initialises communication, data-parallel mode and all-reduce
         fusion splits.
      3. Build the dataset, network and ``cfg.optimizer``.
      4. Attach callbacks (checkpointing only on rank 0).
      5. Optionally restore a checkpoint.
      6. Run ``Model.train``.
    """
    import os  # local import: only used to build the per-rank checkpoint directory

    args_opt = _pretrain_arg_parser().parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id,
                        save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        D.init()
        if args_opt.device_target == 'Ascend':
            # On Ascend the group size comes from --device_num and the rank
            # is derived from the device id.
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            device_num = D.get_group_size()
            rank = D.get_rank()
            # os.path.join supplies the separator that plain concatenation
            # dropped when the base path had no trailing '/'; the final ''
            # keeps the trailing separator.
            ckpt_save_dir = os.path.join(args_opt.save_checkpoint_path,
                                         'ckpt_' + str(rank), '')

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        _set_pretrain_all_reduce_split(bert_net_cfg.num_hidden_layers,
                                       bert_net_cfg.use_relative_positions)
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # Each "epoch" passed to Model.train consumes data_sink_steps batches.
    dataset_size = ds.get_dataset_size()
    new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * dataset_size
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _create_pretrain_optimizer(net_with_loss, args_opt.train_steps)

    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    # Only rank 0 writes checkpoints.
    if args_opt.enable_save_ckpt == "true" and rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        net_with_grads = BertTrainOneStepWithLossScaleCell(
            net_with_loss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads, frequency=cfg.Thor.frequency)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
예제 #5
0
파일: train.py 프로젝트: yrpang/mindspore
    net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
                   frequency=config.frequency, batch_size=config.batch_size)

    # define loss, model
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum,
               filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
               filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
               filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
               filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
               config.weight_decay, config.loss_scale)
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
                  keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency,
                  use_dynamic_frequency=config.use_dynamic_frequency,
                  first_stage_steps=config.first_stage_steps)

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
        cb += [ckpt_cb]

    # train model
    model.train(config.epoch_size, dataset, callbacks=cb)