Example #1
def do_train(dataset=None,
             network=None,
             load_checkpoint_path="",
             save_checkpoint_path="",
             epoch_num=1):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError(
            "Pretrain model is missing; finetune task must load a pretrained model!")
    steps_per_epoch = dataset.get_dataset_size()
    # optimizer
    optimizer = Adam(network.trainable_params(),
                     learning_rate=optimizer_cfg.learning_rate)
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(
        prefix="classifier",
        directory=None if save_checkpoint_path == "" else save_checkpoint_path,
        config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

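    # Dynamic loss scaling: start at 2**32, halve the scale on overflow, and
    # double it again after 1000 consecutive overflow-free steps.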
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertFinetuneCell(network,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(dataset.get_dataset_size()), ckpoint_cb
    ]
    model.train(epoch_num, dataset, callbacks=callbacks)
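# Usage sketch (hypothetical, not part of the original script): assumes the
# script's imports plus a prepared MindSpore dataset and a BERT
# network-with-loss; the paths are placeholders.
# do_train(dataset=train_ds, network=bert_cls_net,
#          load_checkpoint_path="pretrain.ckpt",
#          save_checkpoint_path="./ckpt", epoch_num=3)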
Example #2
def run_predistill():
    """
    run predistill
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=True)

    rank = 0
    device_num = 1
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('td1 dataset size: ', dataset_size)
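    # With data sink enabled, model.train counts sink cycles of
    # data_sink_steps steps each, so convert epochs into a cycle count.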
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg

    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
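    # 'order_params' keeps the optimizer applying updates in the network's
    # original parameter order despite the grouping above.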
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                               args_opt.save_ckpt_step,
                                                                               args_opt.max_ckpt_num,
                                                                               td_phase1_save_ckpt_dir)]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
Example #3
def compile_net(net, grad_accumulation_step):
    context.set_context(save_graphs=True)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell,
                                                     accumulation_steps=grad_accumulation_step)
    model = Model(net_wrap)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()
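# Usage sketch (hypothetical): `Net` and the module-level `_x`, `_b` tensors
# are assumed to be defined elsewhere in this parallel test file.
# compile_net(Net(), grad_accumulation_step=4)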
Example #4
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=optimizer_cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
                                       end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="ner",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
    train_begin = time.time()
    model.train(epoch_num, dataset, callbacks=callbacks)
    train_end = time.time()
    print("latency: {:.6f} s".format(train_end - train_begin))
Example #5
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=steps_per_epoch * epoch_num,
                                             learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
                                             warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                             weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
                         start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
                         end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                         power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                         decay_filter=optimizer_cfg.Lamb.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
Example #6
def run_general_distill():
    """
    run general distill
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="3",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Enable data sink, default is true.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Enable data sink, default is true.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="tfrecord",
        help="dataset type tfrecord/mindrecord, default is tfrecord")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning(
                'Compute about the student only support float32 temporarily, run with float32.'
            )
            bert_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd',
                                      common_cfg.batch_size,
                                      device_num,
                                      rank,
                                      args_opt.do_shuffle,
                                      args_opt.data_dir,
                                      args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=common_cfg.loss_scale_value,
            scale_factor=common_cfg.scale_factor,
            scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #7
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
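        # Fuse gradient all-reduce into buckets split at these parameter
        # indices so communication overlaps with backpropagation.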
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num,
                                               rank, args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    data_epoch_size = new_repeat_count // args_opt.epoch_size  # Epoch nums in one dataset.
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Unsupported optimizer {}, only [Lamb, Momentum, AdamWeightDecayDynamicLR] are supported."
            .format(cfg.optimizer))
    callback = [
        TimeMonitor(ds.get_dataset_size()),
        LossCallBack(data_epoch_size)
    ]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Example #8
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1, resume=False):
    """
    Do train
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path of the saved pretrained model checkpoint.
        save_checkpoint_path: the file path where the finetuned model checkpoint will be saved.
        epoch_num: the number of epochs.
        resume: whether to resume from a finetuned checkpoint; if False, remap pretrained parameter names.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()  # samples / batch_size

    # print training info
    print("=" * 30, "TRAIN INFO", "=" * 30)
    print("optimizer: {}".format(cfg.optimizer))

    # select optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params() # return a list of all trainable parameters of the network

        # Use parameter groups and set different values
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) # without layernorm and bias
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) # with layernorm and bias
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        # print info
        print("lr: {}".format(cfg.Lamb.learning_rate))
        print("end_learning_rate: {}".format(cfg.Lamb.end_learning_rate))
        print("power: {}".format(cfg.Lamb.power))

        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="gpt2_summarization",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    reorganized_param_dict = dict()
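    # When not resuming, remap the pretrained GPT-2 parameter names into this
    # finetune network's namespace and tie the LM head to the embedding table.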
    if not resume:
        print("Do not resume.\nRESUME STATE: {}".format(resume))
        for netName in param_dict:
            reorganized_param_dict['gpt2.gpt2.'+netName] = param_dict[netName]
        reorganized_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
    else:
        print("Start to resume training.\nRESUME STATE: {}".format(resume))
        reorganized_param_dict = param_dict
    load_param_into_net(network, reorganized_param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
    print("============== Starting Training For Summrization Task ==============")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("============== Summrization Training Success ==============")
Example #9
def train():
    '''
    finetune function
    '''
    # BertCLS train for classification
    # BertNER train for sequence labeling

    if cfg.task == 'NER':
        tag_to_index = None
        if cfg.use_crf:
            with open(cfg.label2id_file) as f:
                tag_to_index = json.loads(f.read())
            print(tag_to_index)
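            # CRF decoding needs explicit <START>/<STOP> transition tags
            # appended after the dataset labels.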
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels

        netwithloss = BertNER(bert_net_cfg,
                              cfg.batch_size,
                              True,
                              num_labels=number_labels,
                              use_crf=cfg.use_crf,
                              tag_to_index=tag_to_index,
                              dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg,
                              True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    dataset = get_dataset(data_file=cfg.data_file, batch_size=cfg.batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:', steps_per_epoch)

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            power=optimizer_cfg.AdamWeightDecay.power)
        params = netwithloss.trainable_params()
        decay_params = list(
            filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
                   params))
        group_params = [{
            'params':
            decay_params,
            'weight_decay':
            optimizer_cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]
        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=optimizer_cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.Lamb.learning_rate,
            end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(netwithloss.trainable_params(),
                         learning_rate=lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            netwithloss.trainable_params(),
            learning_rate=optimizer_cfg.Momentum.learning_rate,
            momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(dataset.get_dataset_size()), ckpoint_cb
    ]
    model.train(cfg.epoch_num,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)
Example #10
def run_general_distill():
    """
    run general distill
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="3",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Enable data sink, default is true.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Enable data sink, default is true.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    if args_opt.distribute == "true":
        D.init('hccl')
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    if args_opt.enable_data_sink == "true":
        repeat_count = args_opt.epoch_size * dataset.get_dataset_size(
        ) // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]

    update_cell = DynamicLossScaleUpdateCell(
        loss_scale_value=common_cfg.loss_scale_value,
        scale_factor=common_cfg.scale_factor,
        scale_window=common_cfg.scale_window)

    netwithgrads = BertTrainWithLossScaleCell(netwithloss,
                                              optimizer=optimizer,
                                              scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #11
def run_task_distill(ckpt_file):
    """
    run task distill
    """
    if ckpt_file == '':
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg

    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = ckpt_file
    netwithloss = BertNetworkWithLoss_td(
        teacher_config=td_teacher_net_cfg,
        teacher_ckpt=load_teacher_checkpoint_path,
        student_config=td_student_net_cfg,
        student_ckpt=load_student_checkpoint_path,
        is_training=True,
        task_type=args_opt.task_type,
        num_labels=task.num_labels,
        is_predistill=False)

    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td',
                                            cfg.batch_size,
                                            device_num,
                                            rank,
                                            args_opt.do_shuffle,
                                            args_opt.train_data_dir,
                                            args_opt.schema_dir,
                                            data_type=dataset_type)

    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase2_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg

    lr_schedule = BertLearningRate(
        learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
        power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(
        filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(
        filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=optimizer_cfg.AdamWeightDecay.eps)

    eval_dataset = create_tinybert_dataset('td',
                                           eval_cfg.batch_size,
                                           device_num,
                                           rank,
                                           args_opt.do_shuffle,
                                           args_opt.eval_data_dir,
                                           args_opt.schema_dir,
                                           data_type=dataset_type)
    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

    if args_opt.do_eval.lower() == "true":
        callback = [
            TimeMonitor(time_monitor_steps),
            LossCallBack(),
            EvalCallBack(netwithloss.bert, eval_dataset)
        ]
    else:
        callback = [
            TimeMonitor(time_monitor_steps),
            LossCallBack(),
            ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                          args_opt.max_ckpt_num, td_phase2_save_ckpt_dir)
        ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)

        netwithgrads = BertEvaluationWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                train_dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
Example #12
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink",
                        type=str,
                        default="true",
                        help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink",
                        type=str,
                        default="true",
                        help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse",
                        type=str,
                        default="true",
                        help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path",
                        type=str,
                        default="",
                        help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank,
                             args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir,
                             args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() *
                         ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError(
            "Unsupported optimizer {}, only [Lamb, Momentum, AdamWeightDecayDynamicLR] are supported."
            .format(cfg.optimizer))
    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(),
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Example #13
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument(
        "--accumulation_steps",
        type=int,
        default="1",
        help=
        "Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument(
        "--allreduce_post_accumulation",
        type=str,
        default="true",
        choices=["true", "false"],
        help=
        "Whether to allreduce after accumulation of N steps or after each step, default is true."
    )
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    parser.add_argument("--enable_graph_kernel",
                        type=str,
                        default="auto",
                        choices=["auto", "true", "false"],
                        help="Accelerate by graph kernel, default is auto.")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
            get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        if args_opt.device_target == 'Ascend':
            _set_bert_all_reduce_split()
    else:
        rank = 0
        device_num = 1

    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(
        args_opt.device_target, args_opt.enable_graph_kernel)

    if args_opt.enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
        not is_auto_enable_graph_kernel:
        logger.warning('GPU only supports fp32 temporarily; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            cfg.batch_size * args_opt.accumulation_steps))
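        # One optimizer update consumes accumulation_steps mini-batches, so
        # the sink and checkpoint step counters are scaled to keep their
        # cadence in terms of optimizer updates.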
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
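    # In sink mode model.train() treats sink_size steps as one "epoch", so
    # the repeat count is the total number of steps divided by
    # data_sink_steps.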
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size(
        ) // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        accumulation_steps = args_opt.accumulation_steps
        enable_global_norm = cfg.enable_global_norm
        if accumulation_steps <= 1:
            if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
            else:
                net_with_grads = BertTrainOneStepWithLossScaleCell(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
        else:
            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
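            # "Post" all-reduces gradients once per accumulation window;
            # "Each" all-reduces after every accumulation step.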
            net_with_accumulation = (
                BertTrainAccumulationAllReducePostWithLossScaleCell
                if allreduce_post else
                BertTrainAccumulationAllReduceEachWithLossScaleCell)
            net_with_grads = net_with_accumulation(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #14
def do_train(dataset=None,
             network=None,
             load_checkpoint_path="",
             save_checkpoint_path="",
             epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network:  the network with loss
        load_checkpoint_path: the file path which saved pretrain model checkpoint.
        save_checkpoint_path:  the file path which will save finetune model checkpoint.
        epoch_num: the number of epoch
    """
    if load_checkpoint_path == "":
        raise ValueError(
            "Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()  # number of batches per epoch
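    # The LR schedules below warm up over the first 10% of total training
    # steps (steps_per_epoch * epoch_num) and then decay polynomially to
    # end_learning_rate.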

    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
            decay_steps=steps_per_epoch * epoch_num,
            power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()  # all trainable parameters of the network

        # Use parameter groups: apply weight decay to all parameters except
        # layernorm and bias.
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]
        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)

    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
            decay_steps=steps_per_epoch * epoch_num,
            power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception(
            "Optimizer not supported. Supported: [AdamWeightDecay, Lamb, Momentum]."
        )

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(
        prefix="gpt2_language_model_wiki2",
        directory=None if save_checkpoint_path == "" else save_checkpoint_path,
        config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, value in param_dict.items():
        final_param_dict['gpt2_loss.gpt2.gpt2.' + name] = value
    # tie the final dense layer weights to the GPT-2 token embedding table
    final_param_dict['gpt2_loss.gpt2.dense1.weight'] = param_dict[
        'gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load new parameter successfully!\n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    netwithgrads.set_train(True)

    loss_cb = LossMonitor()

    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("============== Starting Training ==============")
    model.train(epoch_num, dataset, callbacks=callbacks)
    print("============== Training Success ==============")
Example #15
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=4,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="false",
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="false",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="100",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id,
                        save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    context.set_context(max_call_depth=3000)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
            rank) + '/'
        context.reset_auto_parallel_context()
        _set_bert_all_reduce_split()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)

    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 temporarily; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
    if args_opt.enable_save_ckpt == "true" and rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        net_with_grads = BertTrainOneStepWithLossScaleCell(
            net_with_loss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads, frequency=cfg.Thor.frequency)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #16
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse_init()
    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(
        args_opt.device_target, args_opt.enable_graph_kernel)
    _set_graph_kernel_context(args_opt.device_target,
                              args_opt.enable_graph_kernel,
                              is_auto_enable_graph_kernel)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
            get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        _set_bert_all_reduce_split()
    else:
        rank = 0
        device_num = 1

    _check_compute_type(args_opt, is_auto_enable_graph_kernel)

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size(
        ) // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        accumulation_steps = args_opt.accumulation_steps
        enable_global_norm = cfg.enable_global_norm
        if accumulation_steps <= 1:
            if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
            else:
                net_with_grads = BertTrainOneStepWithLossScaleCell(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
        else:
            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
            net_with_accumulation = (
                BertTrainAccumulationAllReducePostWithLossScaleCell
                if allreduce_post else
                BertTrainAccumulationAllReduceEachWithLossScaleCell)
            net_with_grads = net_with_accumulation(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model = ConvertModelUtils().convert_to_thor_model(
        model,
        network=net_with_grads,
        optimizer=optimizer,
        frequency=cfg.Thor.frequency)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #17
def do_train(dataset=None,
             network=None,
             load_checkpoint_path="",
             save_checkpoint_path="",
             epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network:  the network with loss
        load_checkpoint_path: the file path which saved pretrained model checkpoint.
        save_checkpoint_path:  the file path which will save finetuned model checkpoint.
        epoch_num: the number of epoch.
    """
    if load_checkpoint_path == "":
        raise ValueError(
            "Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
            decay_steps=steps_per_epoch * epoch_num,
            power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]
        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
            decay_steps=steps_per_epoch * epoch_num,
            power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception(
            "Optimizer not supported. Supported: [AdamWeightDecay, Lamb, Momentum]."
        )

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    prefix_name = "gpt2_translation_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(
        prefix=prefix_name,
        directory=None if save_checkpoint_path == "" else save_checkpoint_path,
        config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

    final_param_dict = {}
    for name, value in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = value
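    # Tie the final dense (output projection) weights to the token embedding
    # table, mirroring GPT-2's weight tying.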
    final_param_dict['gpt2.dense1.weight'] = param_dict[
        'gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Load the pretrained parameter successfully! \n")

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)

    model = Model(netwithgrads)

    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print(
        "=================== Starting Training For Translation Task ===================="
    )
    model.train(epoch_num,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=False)
    print(
        "===================      Translation Training Success      ===================="
    )
Example #18
from tests.dataset_mock import MindData

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536,
                                         scale_factor=2,
                                         scale_window=1000)


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(
            grad, F.cast(F.tuple_to_array((-clip_value, )), dt),
            F.cast(F.tuple_to_array((clip_value, )), dt))
    else:
        new_grad = nn.ClipByNorm()(grad,
                                   F.cast(F.tuple_to_array((clip_value, )),
                                          dt))
    return new_grad
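
A minimal usage sketch (hedged: `scaling_sens` and `grads` are assumed names for the loss-scale tensor and the gradient tuple, and C/F are the same composite and functional modules this snippet already relies on) showing how these registered graphs are typically applied inside a training cell:

# Scale gradients back by the loss scale, then clip each gradient.
hyper_map = C.HyperMap()
# Undo the loss scale on each gradient: grad * reciprocal(scale).
grads = hyper_map(F.partial(grad_scale, scaling_sens), grads)
# Clip each gradient by value (clip_type 0) or by norm (clip_type 1).
grads = hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
                  grads)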
Example #19
def run_general_distill():
    """
    run general distill
    """
    args_opt = get_argument()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        reserve_class_name_in_scope=False)
    if args_opt.device_target == "Ascend":
        context.set_context(device_id=args_opt.device_id)

    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning(
                'The student network only supports float32 temporarily; '
                'running with float32.')
            bert_student_net_cfg.compute_type = mstype.float32
        # The backward pass is computed in fp32, so loss scaling is not
        # necessary.
        enable_loss_scale = False

    if args_opt.device_target == "CPU":
        logger.warning(
            'CPU only supports float32 temporarily; running with float32.')
        bert_teacher_net_cfg.dtype = mstype.float32
        bert_teacher_net_cfg.compute_type = mstype.float32
        bert_student_net_cfg.dtype = mstype.float32
        bert_student_net_cfg.compute_type = mstype.float32
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd',
                                      common_cfg.batch_size,
                                      device_num,
                                      rank,
                                      args_opt.do_shuffle,
                                      args_opt.data_dir,
                                      args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
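    # 'order_params' keeps the update order of the grouped parameters
    # consistent with the original parameter list.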

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=common_cfg.loss_scale_value,
            scale_factor=common_cfg.scale_factor,
            scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #20
def run_pretrain(args_opt):
    """pre-train bert"""
    global device_id
    global device_num
    global rank_id
    global job_id
    args_opt.device_id = device_id
    args_opt.device_num = device_num
    sync_dataset(args_opt.data_url)

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
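        # The split indices below partition the gradient all-reduce into
        # fused chunks, tuned per network depth and for the
        # relative-position variant.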
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 temporarily; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num,
                                               rank, args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    if args_opt.train_steps > 0:
        new_repeat_count = min(
            new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Optimizer {} is not supported; only [Lamb, Momentum, "
            "AdamWeightDecayDynamicLR] are supported.".format(cfg.optimizer))
    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
    print("Enable save checkpoint: ", args_opt.enable_save_ckpt)
    print("Rank ID: ", rank_id)
    if args_opt.enable_save_ckpt == "true" and rank_id % device_num == 0:
        print("Enable save checkpoint")
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
Example #21
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default="1",
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default="1",
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument(
        "--accumulation_steps",
        type=int,
        default="1",
        help=
        "Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
                rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    29, 58, 87, 116, 145, 174, 203, 217
                ])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    28, 55, 82, 109, 136, 163, 190, 205
                ])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    30, 90, 150, 210, 270, 330, 390, 421
                ])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    38, 93, 148, 203, 258, 313, 368, 397
                ])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('GPU only supports fp32 temporarily; running with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            bert_net_cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size(
        ) // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=cfg.Lamb.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.Lamb.weight_decay
        }, {
            'params': other_params
        }, {
            'order_params': params
        }]
        optimizer = Lamb(group_params,
                         learning_rate=lr_schedule,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=cfg.AdamWeightDecay.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }, {
            'order_params': params
        }]

        optimizer = AdamWeightDecay(group_params,
                                    learning_rate=lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError(
            "Optimizer {} is not supported; only [Lamb, Momentum, "
            "AdamWeightDecay] are supported.".format(cfg.optimizer))
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)

        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=cfg.enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
Example #22
def test_train():
    '''
    finetune function
    '''
    target = args_opt.device_target
    if target == "Ascend":
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE,
                            device_target="Ascend",
                            device_id=devid)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    # BertCLSTrain for classification
    # BertNERTrain for sequence labeling
    if cfg.task == 'NER':
        if cfg.use_crf:
            netwithloss = BertNER(bert_net_cfg,
                                  True,
                                  num_labels=len(tag_to_index),
                                  use_crf=True,
                                  tag_to_index=tag_to_index,
                                  dropout_prob=0.1)
        else:
            netwithloss = BertNER(bert_net_cfg,
                                  True,
                                  num_labels=cfg.num_labels,
                                  dropout_prob=0.1)
    elif cfg.task == 'SQUAD':
        netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
    else:
        netwithloss = BertCLS(bert_net_cfg,
                              True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1)
    if cfg.task == 'SQUAD':
        dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    else:
        dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    # optimizer
    steps_per_epoch = dataset.get_dataset_size()
    if cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=steps_per_epoch * cfg.epoch_num,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    elif cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=steps_per_epoch * cfg.epoch_num,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         weight_decay=cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * cfg.epoch_num *
                                          0.1),
                         decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

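    # Dynamic loss scaling: start at 2**32, halve the scale on overflow, and
    # double it again after 1000 consecutive overflow-free steps.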
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    if cfg.task == 'SQUAD':
        netwithgrads = BertSquadCell(netwithloss,
                                     optimizer=optimizer,
                                     scale_update_cell=update_cell)
    else:
        netwithgrads = BertFinetuneCell(netwithloss,
                                        optimizer=optimizer,
                                        scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])

def run_train_pipeline(args_opt):
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id),
          flush=True)
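    # Graph mode on Ascend; cap the device memory reserved for variables at 31GB.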
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    strategy_ckpt_save_file = "/cache/" + "strategy" + str(
        local_rank) + ".ckpt"
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
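        # Semi-auto parallel with pipeline stages. full_batch=True means every
        # rank loads the full global batch and the framework shards it;
        # enable_parallel_optimizer shards optimizer states across the
        # data-parallel ranks when optimizer_shard is set.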
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num,
            strategy_ckpt_save_file=strategy_ckpt_save_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1

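    # Derive the parallel layout: each pipeline stage owns device_num / stage_num
    # devices, split between model parallelism and data parallelism; the global
    # batch size also folds in micro_size micro-batches for pipelining.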
    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = device_num // args_opt.stage_num
    data_parallel_num = stage_device_num // model_parallel_num
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.1,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              word_emb_dp=False)
    print("===config is: ", config, flush=True)
    pangu_alpha = PANGUALPHAPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PANGUALPHAWithLossPipeline(config, pangu_alpha,
                                                       loss)
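    # Mark the dataset input with a virtual-dataset cell so the auto-parallel
    # engine knows how to distribute the input tensors across devices.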
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)

    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)

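    # Pipeline parallelism: each rank optimizes only the layers of its own
    # stage. Stage 0 additionally owns the word/position embeddings, and the
    # last stage owns the final layernorm and top-query embedding; the shared
    # embedding table appears on both ends because PanGu-Alpha ties it to the
    # output head.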
    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    range_min = self_stage * per_stage_layers
    range_max = range_min + per_stage_layers
    if self_stage == 0:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.pangu_alpha_embedding.
                      position_embedding.trainable_params())
    elif self_stage == config.stage_num - 1:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.layernorm.trainable_params())
        params.extend(
            pangu_alpha.backbone.top_query_embedding.trainable_params())
    else:
        params = []
    for i in range(range_min, range_max):
        params.extend(pangu_alpha.backbone.blocks[i].trainable_params())

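    # Apply weight decay only to non-layernorm, non-bias parameters; the
    # 'order_params' entry keeps the parameters in network order, which
    # grouped MindSpore optimizers use to align with the parallel layout.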
    decay_filter = lambda x: ('layernorm' not in x.name.lower()
                              and 'bias' not in x.name.lower())

    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [
        {'params': decay_params, 'weight_decay': args_opt.weight_decay},
        {'params': other_params, 'weight_decay': 0.0},
        {'order_params': params},
    ]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params,
                                       learning_rate=lr,
                                       beta1=0.9,
                                       beta2=0.95,
                                       eps=1e-8)

    save_steps = args_opt.save_steps
    ckpt_dir = os.path.join(args_opt.ckpt_save_dir, f"rank_{local_rank}")
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    ds = create_dataset(config.batch_size,
                        data_path=args_opt.data_url,
                        data_start_index=0)

    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
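    # With dataset sinking, model.train() counts sink cycles rather than
    # epochs, so convert epoch_size * steps_per_epoch steps into cycles of
    # sink_size steps each.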
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, local_rank, config.stage_num)
    ]
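    # integrated_save=False keeps checkpoints sharded per rank (no merging of
    # parallel slices); filter_prefix drops the "accu_grads" accumulation
    # buffers from the saved files.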
    config_ck = CheckpointConfig(save_checkpoint_steps=save_steps,
                                 keep_checkpoint_max=1,
                                 integrated_save=False,
                                 filter_prefix="accu_grads")
    ckpoint_cb = ModelCheckpoint(prefix="PanguAlpha",
                                 directory=ckpt_dir,
                                 config=config_ck)
    callback.append(ckpoint_cb)
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)

    pangu_alpha_with_grads = PANGUALPHATrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss,
        optimizer=optimizer,
        config=config,
        scale_update_cell=update_cell)

    model = Model(pangu_alpha_with_grads)
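    # Throttle host-to-device sending so the data queue stays at most
    # 2 * sink_size batches ahead of training.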
    de.config.set_sending_batches(2 * args_opt.sink_size)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)