コード例 #1
0
ファイル: run_pretrain.py プロジェクト: brucejunlee/mindspore
def _get_optimizer(args_opt, network):
    """Create the optimizer selected by ``cfg.optimizer``.

    Handles 'Lamb', 'Momentum', 'AdamWeightDecay' and 'Thor'.

    Args:
        args_opt: Run options. ``train_steps`` is used as the decay length of
            the learning-rate schedule; ``enable_lossscale`` and
            ``device_target`` pick the AdamWeightDecay implementation.
        network: Network to optimize. Its ``trainable_params()`` feed Lamb,
            Momentum and AdamWeightDecay; THOR is given the network itself.

    Returns:
        The optimizer instance.

    Raises:
        ValueError: If ``cfg.optimizer`` names an unsupported optimizer.
    """
    chosen = cfg.optimizer

    if chosen == 'Lamb':
        lamb_cfg = cfg.Lamb
        schedule = BertLearningRate(learning_rate=lamb_cfg.learning_rate,
                                    end_learning_rate=lamb_cfg.end_learning_rate,
                                    warmup_steps=lamb_cfg.warmup_steps,
                                    decay_steps=args_opt.train_steps,
                                    power=lamb_cfg.power)
        weights = network.trainable_params()
        # Split the parameters by the configured predicate; only the first
        # group gets an explicit weight decay.
        decayed = [w for w in weights if lamb_cfg.decay_filter(w)]
        undecayed = [w for w in weights if not lamb_cfg.decay_filter(w)]
        groups = [
            {'params': decayed, 'weight_decay': lamb_cfg.weight_decay},
            {'params': undecayed},
            {'order_params': weights},
        ]
        return Lamb(groups, learning_rate=schedule, eps=lamb_cfg.eps)

    if chosen == 'Momentum':
        momentum_cfg = cfg.Momentum
        return Momentum(network.trainable_params(),
                        learning_rate=momentum_cfg.learning_rate,
                        momentum=momentum_cfg.momentum)

    if chosen == 'AdamWeightDecay':
        adam_cfg = cfg.AdamWeightDecay
        schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                    end_learning_rate=adam_cfg.end_learning_rate,
                                    warmup_steps=adam_cfg.warmup_steps,
                                    decay_steps=args_opt.train_steps,
                                    power=adam_cfg.power)
        weights = network.trainable_params()
        decayed = [w for w in weights if adam_cfg.decay_filter(w)]
        undecayed = [w for w in weights if not adam_cfg.decay_filter(w)]
        groups = [
            {'params': decayed, 'weight_decay': adam_cfg.weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
            {'order_params': weights},
        ]
        # Pick the implementation from the loss-scale flag, the execution
        # mode and the device target; all three share one call signature.
        if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
            adam_cls = AdamWeightDecayForBert
        elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU':
            adam_cls = AdamWeightDecayOp
        else:
            adam_cls = AdamWeightDecay
        return adam_cls(groups, learning_rate=schedule, eps=adam_cfg.eps)

    if chosen == "Thor":
        from src.utils import get_bert_thor_lr, get_bert_thor_damping

        thor_cfg = cfg.Thor
        thor_lr = get_bert_thor_lr(thor_cfg.lr_max, thor_cfg.lr_min,
                                   thor_cfg.lr_power, thor_cfg.lr_total_steps)
        thor_damping = get_bert_thor_damping(thor_cfg.damping_max, thor_cfg.damping_min,
                                             thor_cfg.damping_power, thor_cfg.damping_total_steps)
        # Hard-coded split_indices exist only for the 12- and 24-layer
        # networks; any other depth passes None.
        depth = bert_net_cfg.num_hidden_layers
        if depth == 12:
            split_indices = ([29, 58, 87, 116, 145, 174, 203, 217]
                             if bert_net_cfg.use_relative_positions
                             else [28, 55, 82, 109, 136, 163, 190, 205])
        elif depth == 24:
            split_indices = ([30, 90, 150, 210, 270, 330, 390, 421]
                             if bert_net_cfg.use_relative_positions
                             else [38, 93, 148, 203, 258, 313, 368, 397])
        else:
            split_indices = None
        return THOR(network, thor_lr, thor_damping, thor_cfg.momentum,
                    thor_cfg.weight_decay, thor_cfg.loss_scale, cfg.batch_size,
                    decay_filter=lambda p: not ('layernorm' in p.name.lower() or 'bias' in p.name.lower()),
                    split_indices=split_indices)

    raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                     format(chosen))
コード例 #2
0
ファイル: run_pretrain.py プロジェクト: mark14wu/bert_demo
def _get_optimizer(args_opt, network):
    """Create the optimizer selected by ``cfg.optimizer``.

    Handles 'Lamb', 'Momentum' and 'AdamWeightDecay'.

    Args:
        args_opt: Run options; ``train_steps`` is used as the decay length of
            the learning-rate schedule.
        network: Network whose ``trainable_params()`` are optimized.

    Returns:
        The optimizer instance.

    Raises:
        ValueError: If ``cfg.optimizer`` names an unsupported optimizer.
    """
    chosen = cfg.optimizer

    if chosen == 'Lamb':
        lamb_cfg = cfg.Lamb
        schedule = BertLearningRate(learning_rate=lamb_cfg.learning_rate,
                                    end_learning_rate=lamb_cfg.end_learning_rate,
                                    warmup_steps=lamb_cfg.warmup_steps,
                                    decay_steps=args_opt.train_steps,
                                    power=lamb_cfg.power)
        weights = network.trainable_params()
        # Split the parameters by the configured predicate; only the first
        # group gets an explicit weight decay.
        decayed = [w for w in weights if lamb_cfg.decay_filter(w)]
        undecayed = [w for w in weights if not lamb_cfg.decay_filter(w)]
        groups = [
            {'params': decayed, 'weight_decay': lamb_cfg.weight_decay},
            {'params': undecayed},
            {'order_params': weights},
        ]
        return Lamb(groups, learning_rate=schedule, eps=lamb_cfg.eps)

    if chosen == 'Momentum':
        momentum_cfg = cfg.Momentum
        return Momentum(network.trainable_params(),
                        learning_rate=momentum_cfg.learning_rate,
                        momentum=momentum_cfg.momentum)

    if chosen == 'AdamWeightDecay':
        adam_cfg = cfg.AdamWeightDecay
        schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                    end_learning_rate=adam_cfg.end_learning_rate,
                                    warmup_steps=adam_cfg.warmup_steps,
                                    decay_steps=args_opt.train_steps,
                                    power=adam_cfg.power)
        weights = network.trainable_params()
        decayed = [w for w in weights if adam_cfg.decay_filter(w)]
        undecayed = [w for w in weights if not adam_cfg.decay_filter(w)]
        groups = [
            {'params': decayed, 'weight_decay': adam_cfg.weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
            {'order_params': weights},
        ]
        return AdamWeightDecay(groups, learning_rate=schedule, eps=adam_cfg.eps)

    raise ValueError(
        "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]"
        .format(chosen))
コード例 #3
0
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """Fine-tune ``network`` on ``dataset`` starting from a pretrained checkpoint.

    Args:
        dataset: Training dataset; its ``get_dataset_size()`` gives the steps
            per epoch.
        network: Network to train; the checkpoint is loaded into it.
        load_checkpoint_path (str): Pretrained checkpoint to load. Required.
        save_checkpoint_path (str): Directory for saved checkpoints; an empty
            string passes ``directory=None`` to ``ModelCheckpoint``.
        epoch_num (int): Number of epochs to train.

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty.
        Exception: If ``optimizer_cfg.optimizer`` is not AdamWeightDecay,
            Lamb or Momentum.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()

    def _finetune_schedule(section):
        # Warm up over the first 10% of all steps, then decay over the full run.
        total_steps = steps_per_epoch * epoch_num
        return BertLearningRate(learning_rate=section.learning_rate,
                                end_learning_rate=section.end_learning_rate,
                                warmup_steps=int(total_steps * 0.1),
                                decay_steps=total_steps,
                                power=section.power)

    # Build the optimizer named in the configuration.
    opt_name = optimizer_cfg.optimizer
    if opt_name == 'AdamWeightDecay':
        adam_cfg = optimizer_cfg.AdamWeightDecay
        schedule = _finetune_schedule(adam_cfg)
        weights = network.trainable_params()
        decayed = [w for w in weights if adam_cfg.decay_filter(w)]
        undecayed = [w for w in weights if not adam_cfg.decay_filter(w)]
        groups = [
            {'params': decayed, 'weight_decay': adam_cfg.weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
        ]
        optimizer = AdamWeightDecay(groups, schedule, eps=adam_cfg.eps)
    elif opt_name == 'Lamb':
        schedule = _finetune_schedule(optimizer_cfg.Lamb)
        optimizer = Lamb(network.trainable_params(), learning_rate=schedule)
    elif opt_name == 'Momentum':
        momentum_cfg = optimizer_cfg.Momentum
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=momentum_cfg.learning_rate,
                             momentum=momentum_cfg.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # Checkpoint callback: save once per epoch and keep a single checkpoint.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    save_dir = save_checkpoint_path if save_checkpoint_path != "" else None
    ckpoint_cb = ModelCheckpoint(prefix="ner", directory=save_dir, config=ckpt_config)

    # Load the pretrained weights into the network.
    pretrained = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, pretrained)

    loss_scaler = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    train_cell = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=loss_scaler)
    model = Model(train_cell)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]

    # Report the wall-clock time of the whole training run.
    started = time.time()
    model.train(epoch_num, dataset, callbacks=callbacks)
    finished = time.time()
    print("latency: {:.6f} s".format(finished - started))
コード例 #4
0
def run_predistill():
    """
    Run the pre-distillation phase (phase 1) of task distillation.

    Builds the teacher/student network with loss, the 'td' dataset, an
    AdamWeightDecay optimizer driven by a BertLearningRate schedule and a
    dynamic loss-scale cell, then trains the model. All settings are read
    from the module-level ``args_opt`` and ``phase1_cfg``.
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=True)

    # This phase always runs on a single device.
    rank = 0
    device_num = 1
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('td1 dataset size: ', dataset_size)
    if args_opt.enable_data_sink == 'true':
        # In sink mode every "epoch" passed to Model.train covers
        # data_sink_steps steps, so the epoch count is rescaled accordingly.
        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg

    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    # The no-decay group must be the exact complement of decay_params, so it
    # uses the same predicate from optimizer_cfg (the original read it from
    # `cfg`, which is the phase config rather than its optimizer section).
    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                               args_opt.save_ckpt_step,
                                                                               args_opt.max_ckpt_num,
                                                                               td_phase1_save_ckpt_dir)]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
コード例 #5
0
def run_general_distill():
    """
    Run the general-distillation phase of TinyBERT.

    Parses the command-line options, configures the execution context
    (optionally data-parallel), builds the teacher/student network with loss,
    the 'gd' dataset, an AdamWeightDecay optimizer with a BertLearningRate
    schedule and a dynamic loss-scale cell, then trains the model.
    Checkpoints go to a timestamped sub-directory of ``--save_ckpt_path``.
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=3,
                        help="Epoch size, default is 3.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Checkpoint save step interval, default is 100.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Max number of checkpoints to keep, default is 1.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    # The original issued this identical set_context call twice; once is enough.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    # Each run writes into its own timestamped directory.
    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    if args_opt.distribute == "true":
        D.init('hccl')
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    if args_opt.enable_data_sink == "true":
        # In sink mode every "epoch" passed to Model.train covers
        # data_sink_steps steps, so the epoch count is rescaled accordingly.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    # The no-decay group must be the exact complement of decay_params, so it
    # uses the same predicate from common_cfg (the original referenced `cfg`,
    # which this function never defines).
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]

    update_cell = DynamicLossScaleUpdateCell(
        loss_scale_value=common_cfg.loss_scale_value,
        scale_factor=common_cfg.scale_factor,
        scale_window=common_cfg.scale_window)

    netwithgrads = BertTrainWithLossScaleCell(netwithloss,
                                              optimizer=optimizer,
                                              scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
コード例 #6
0
ファイル: run_task_distill.py プロジェクト: yrpang/mindspore
def run_task_distill(ckpt_file):
    """
    Run the task-distillation phase (phase 2).

    Builds the teacher/student network with loss, the 'td' train and eval
    datasets, an AdamWeightDecay optimizer with a BertLearningRate schedule,
    and trains the model. Settings come from the module-level ``args_opt``
    and ``phase2_cfg``.

    Args:
        ckpt_file (str): Student checkpoint to start from.

    Raises:
        ValueError: If ``ckpt_file`` is the empty string.
    """
    if ckpt_file == '':
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg

    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=ckpt_file,
                                         is_training=True,
                                         task_type=args_opt.task_type,
                                         num_labels=task.num_labels,
                                         is_predistill=False)

    # This phase always runs on a single device.
    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td', cfg.batch_size, device_num, rank,
                                            args_opt.do_shuffle, args_opt.train_data_dir,
                                            args_opt.schema_dir, data_type=dataset_type)

    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())

    sink_enabled = args_opt.enable_data_sink == 'true'
    if sink_enabled:
        # In sink mode every "epoch" passed to Model.train covers
        # data_sink_steps steps, so the epoch count is rescaled accordingly.
        total_steps = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size()
        repeat_count = total_steps // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg
    adam_cfg = optimizer_cfg.AdamWeightDecay

    # Warm up over the first tenth of all training steps.
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
                                   power=adam_cfg.power)
    weights = netwithloss.trainable_params()
    decayed = [w for w in weights if adam_cfg.decay_filter(w)]
    undecayed = [w for w in weights if not adam_cfg.decay_filter(w)]
    group_params = [
        {'params': decayed, 'weight_decay': adam_cfg.weight_decay},
        {'params': undecayed, 'weight_decay': 0.0},
        {'order_params': weights},
    ]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size, device_num, rank,
                                           args_opt.do_shuffle, args_opt.eval_data_dir,
                                           args_opt.schema_dir, data_type=dataset_type)
    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

    # The third callback is the evaluation callback when --do_eval is "true",
    # otherwise the checkpoint-saving callback.
    evaluate = args_opt.do_eval.lower() == "true"
    callback = [TimeMonitor(time_monitor_steps), LossCallBack()]
    if evaluate:
        callback.append(EvalCallBack(netwithloss.bert, eval_dataset))
    else:
        callback.append(ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                                      args_opt.max_ckpt_num, td_phase2_save_ckpt_dir))

    if not enable_loss_scale:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
    else:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss,
                                                       optimizer=optimizer,
                                                       scale_update_cell=update_cell)

    model = Model(netwithgrads)
    model.train(repeat_count,
                train_dataset,
                callbacks=callback,
                dataset_sink_mode=sink_enabled,
                sink_size=args_opt.data_sink_steps)
コード例 #7
0
def run_general_distill():
    """
    Run the general-distillation phase of TinyBERT.

    Parses the command-line options, configures the execution context
    (optionally data-parallel), builds the teacher/student network with loss,
    the 'gd' dataset and an AdamWeightDecay optimizer with a BertLearningRate
    schedule, then trains the model. Loss scaling is used except on GPU,
    where the student is forced to float32. Checkpoints go to a timestamped
    sub-directory of ``--save_ckpt_path``.

    Raises:
        Exception: If ``--dataset_type`` is neither "tfrecord" nor
            "mindrecord".
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=3,
                        help="Epoch size, default is 3.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Checkpoint save step interval, default is 100.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Max number of checkpoints to keep, default is 1.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="tfrecord",
        help="dataset type tfrecord/mindrecord, default is tfrecord")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    # Each run writes into its own timestamped directory.
    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        # Both device targets initialize the communication backend the same
        # way; they differ only in how rank and device count are obtained.
        D.init()
        if args_opt.device_target == 'Ascend':
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            device_num = D.get_group_size()
            rank = D.get_rank()
            # Give every rank its own checkpoint directory.
            save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning(
                'Compute about the student only support float32 temporarily, run with float32.'
            )
            bert_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd',
                                      common_cfg.batch_size,
                                      device_num,
                                      rank,
                                      args_opt.do_shuffle,
                                      args_opt.data_dir,
                                      args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        # In sink mode every "epoch" passed to Model.train covers
        # data_sink_steps steps, so the epoch count is rescaled accordingly.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    # Warm up over the first tenth of all training steps.
    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=common_cfg.loss_scale_value,
            scale_factor=common_cfg.scale_factor,
            scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
コード例 #8
0
def run_pretrain():
    """Parse the command line and run BERT pre-training.

    Steps, in order:
      1. parse the command-line options;
      2. configure the MindSpore context and, when ``--distribute true``,
         the data-parallel context (device count, rank, all-reduce fusion);
      3. scale sink / checkpoint step counts when gradient accumulation is on;
      4. build the dataset, the network with loss and the optimizer selected
         by ``cfg.optimizer``;
      5. assemble callbacks, optionally restore a checkpoint, wrap the
         network in the matching train cell and call ``Model.train``.

    Raises:
        ValueError: if ``cfg.optimizer`` is not one of ``Lamb``, ``Momentum``
            or ``AdamWeightDecay``.
    """
    # Function-scope import: only needed to build the per-rank checkpoint path.
    import os

    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=1,
                        help="Epoch size, default is 1.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    # The help text used to claim "default is not" although the default is
    # "true"; the message now matches the actual default.
    parser.add_argument("--enable_lossscale",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument(
        "--accumulation_steps",
        type=int,
        default=1,
        help=
        "Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--save_checkpoint_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps",
                        type=int,
                        default=1000,
                        help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps",
                        type=int,
                        default=-1,
                        help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num",
                        type=int,
                        default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)

    # Single-device defaults; overridden below for distributed runs.
    ckpt_save_dir = args_opt.save_checkpoint_path
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        D.init()
        if args_opt.device_target == 'Ascend':
            # On Ascend the topology comes from the command line.
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            # Otherwise ask the communication backend, and give every rank
            # its own checkpoint sub-directory. os.path.join keeps the path
            # correct even when --save_checkpoint_path has no trailing slash.
            device_num = D.get_group_size()
            rank = D.get_rank()
            ckpt_save_dir = os.path.join(args_opt.save_checkpoint_path,
                                         'ckpt_' + str(rank))

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        # All-reduce fusion split points are only configured for the 12- and
        # 24-layer networks, with or without relative positions; any other
        # depth keeps the framework default.
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    29, 58, 87, 116, 145, 174, 203, 217
                ])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    28, 55, 82, 109, 136, 163, 190, 205
                ])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    30, 90, 150, 210, 270, 330, 390, 421
                ])
            else:
                context.set_auto_parallel_context(all_reduce_fusion_config=[
                    38, 93, 148, 203, 258, 313, 368, 397
                ])

    # On GPU the compute type is forced to float32.
    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    # With gradient accumulation, sink steps and checkpoint interval are
    # multiplied by the accumulation factor. This must happen before the
    # repeat count below is derived from data_sink_steps.
    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            bert_net_cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # Number of sink iterations handed to Model.train.
    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        # An explicit step budget caps the repeat count.
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        # No budget given: derive train_steps from epochs and dataset size;
        # it is used as decay_steps by the learning-rate schedules below.
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size(
        ) // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.Lamb.learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            warmup_steps=cfg.Lamb.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        # Split parameters into those selected by decay_filter and the rest.
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.Lamb.weight_decay
        }, {
            'params': other_params
        }, {
            'order_params': params
        }]
        optimizer = Lamb(group_params,
                         learning_rate=lr_schedule,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=cfg.AdamWeightDecay.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{
            'params': decay_params,
            'weight_decay': cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }, {
            'order_params': params
        }]

        optimizer = AdamWeightDecay(group_params,
                                    learning_rate=lr_schedule,
                                    eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError(
            "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]"
            .format(cfg.optimizer))

    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    # Only devices whose id is a multiple of min(8, device_num) write
    # checkpoints.
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    # Optionally warm-start from an existing checkpoint.
    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)

        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=cfg.enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
コード例 #9
0
def test_train():
    '''
    Fine-tune the BERT poetry model.

    Builds the tokenizer and poetry dataset, creates the poetry network and
    its AdamWeightDecay optimizer, loads the pre-trained checkpoint named by
    ``cfg.pre_training_ckpt`` (keeping only the ``keep_words`` rows of the
    embedding lookup table), and trains for ``cfg.epoch_num`` epochs with
    dynamic loss scaling. A checkpoint is saved once per epoch.
    '''
    target = args_opt.device_target
    if target == "Ascend":
        # Fall back to device 0 when DEVICE_ID is not exported; without the
        # default, int(None) fails with an unhelpful TypeError.
        devid = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(mode=context.GRAPH_MODE,
                            device_target="Ascend",
                            device_id=devid)

    poetry, tokenizer, keep_words = create_tokenizer()
    print(len(keep_words))

    dataset = create_poetry_dataset(bert_net_cfg.batch_size, poetry, tokenizer)

    # NOTE(review): presumably the size of the reduced vocabulary and expected
    # to equal len(keep_words) -- confirm against create_tokenizer().
    num_tokens = 3191
    poetrymodel = BertPoetryModel(bert_net_cfg,
                                  True,
                                  num_tokens,
                                  dropout_prob=0.1)
    netwithloss = BertPoetry(poetrymodel, bert_net_cfg, True, dropout_prob=0.1)
    callback = LossCallBack(poetrymodel)

    # optimizer
    steps_per_epoch = dataset.get_dataset_size()
    print("============ steps_per_epoch is {}".format(steps_per_epoch))
    lr_schedule = BertLearningRate(
        learning_rate=cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=1000,
        decay_steps=cfg.epoch_num * steps_per_epoch,
        power=cfg.AdamWeightDecay.power)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr_schedule)

    # checkpoint saving: one checkpoint per epoch, keep only the latest
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)

    # load the pre-trained checkpoint into the network
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    new_dict = {}

    # Embedding lookup tables are reduced to the rows selected by keep_words;
    # every other parameter is copied through unchanged.
    for key, value in param_dict.items():
        if "bert_embedding_lookup" not in key:
            new_dict[key] = value
        else:
            np_value = value.data.asnumpy()
            np_value = np_value[keep_words]
            tensor_value = Tensor(np_value, mstype.float32)
            new_dict[key] = Parameter(tensor_value, name=key)

    load_param_into_net(netwithloss, new_dict)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertPoetryCell(netwithloss,
                                  optimizer=optimizer,
                                  scale_update_cell=update_cell)

    model = Model(netwithgrads)
    model.train(cfg.epoch_num,
                dataset,
                callbacks=[callback, ckpoint_cb],
                dataset_sink_mode=True)
コード例 #10
0
def run_general_distill():
    """
    Run the general distillation stage.

    Reads the options from ``get_argument()``, configures the MindSpore
    context (data-parallel when ``distribute`` is "true"), creates a
    timestamped checkpoint directory, builds the teacher/student network with
    loss, the dataset and an AdamWeightDecay optimizer, then trains with or
    without dynamic loss scaling depending on the device target.

    Raises:
        Exception: if ``args_opt.dataset_type`` is neither "tfrecord" nor
            "mindrecord".
    """
    args_opt = get_argument()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        reserve_class_name_in_scope=False)
    if args_opt.device_target == "Ascend":
        context.set_context(device_id=args_opt.device_id)

    # Each run saves into its own directory named after the start time.
    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    # Single-device defaults; overridden below for distributed runs.
    rank = 0
    device_num = 1
    if args_opt.distribute == "true":
        D.init()
        if args_opt.device_target == 'Ascend':
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            device_num = D.get_group_size()
            rank = D.get_rank()
            # Off Ascend every rank gets its own checkpoint directory.
            save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)

    # exist_ok avoids the check-then-create race: several processes started
    # within the same second can share the same timestamped directory, and a
    # separate os.path.exists() test could let one of them fail with
    # FileExistsError.
    os.makedirs(save_ckpt_dir, exist_ok=True)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning(
                'Compute about the student only support float32 temporarily, run with float32.'
            )
            bert_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    if args_opt.device_target == "CPU":
        logger.warning(
            'CPU only support float32 temporarily, run with float32.')
        bert_teacher_net_cfg.dtype = mstype.float32
        bert_teacher_net_cfg.compute_type = mstype.float32
        bert_student_net_cfg.dtype = mstype.float32
        bert_student_net_cfg.compute_type = mstype.float32
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd',
                                      common_cfg.batch_size,
                                      device_num,
                                      rank,
                                      args_opt.do_shuffle,
                                      args_opt.data_dir,
                                      args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        # In sink mode Model.train counts sink iterations, not epochs.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    # Warm up over the first tenth of all steps and decay over all of them.
    total_steps = dataset_size * args_opt.epoch_size
    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(total_steps / 10),
        decay_steps=int(total_steps),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    # Split parameters into those selected by decay_filter and the rest;
    # the rest are trained with zero weight decay.
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=common_cfg.loss_scale_value,
            scale_factor=common_cfg.scale_factor,
            scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
コード例 #11
0
def run_task_distill(args_opt):
    """
    Run the task distillation stage on a single device.

    Args:
        args_opt: parsed options. The attributes read here are ``task_name``,
            ``train_batch_size``, ``eval_batch_size``, ``teacher_model_dir``,
            ``student_model_dir``, ``data_dir``, ``output_dir``,
            ``device_target``, ``device_id``, ``do_shuffle``,
            ``dataset_type``, ``enable_data_sink``, ``data_sink_steps``,
            ``epoch_size``, ``do_eval``, ``eval_ckpt_step``,
            ``save_ckpt_step`` and ``max_ckpt_num``.
    """
    task = task_cfg[args_opt.task_name]
    teacher_net_cfg.seq_length = task.seq_length
    student_net_cfg.seq_length = task.seq_length
    train_cfg.batch_size = args_opt.train_batch_size
    eval_cfg.batch_size = args_opt.eval_batch_size

    # All per-task inputs and outputs live in a sub-directory named after
    # the task.
    teacher_ckpt = os.path.join(args_opt.teacher_model_dir, args_opt.task_name,
                                WEIGHTS_NAME)
    student_ckpt = os.path.join(args_opt.student_model_dir, args_opt.task_name,
                                WEIGHTS_NAME)
    train_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name,
                                  TRAIN_DATA_NAME)
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name,
                                 EVAL_DATA_NAME)
    save_ckpt_dir = os.path.join(args_opt.output_dir, args_opt.task_name)

    # Use the function's own argument for the device id; the previous code
    # read an undeclared global ``args`` and only worked when the caller
    # happened to define one.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)

    # This stage is not distributed: one device, rank 0.
    rank = 0
    device_num = 1
    train_dataset = create_dataset(batch_size=train_cfg.batch_size,
                                   device_num=device_num,
                                   rank=rank,
                                   do_shuffle=args_opt.do_shuffle,
                                   data_dir=train_data_dir,
                                   data_type=args_opt.dataset_type,
                                   seq_length=task.seq_length,
                                   task_type=task.task_type,
                                   drop_remainder=True)
    dataset_size = train_dataset.get_dataset_size()
    print('train dataset size:', dataset_size)
    eval_dataset = create_dataset(batch_size=eval_cfg.batch_size,
                                  device_num=device_num,
                                  rank=rank,
                                  do_shuffle=args_opt.do_shuffle,
                                  data_dir=eval_data_dir,
                                  data_type=args_opt.dataset_type,
                                  seq_length=task.seq_length,
                                  task_type=task.task_type,
                                  drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())

    if args_opt.enable_data_sink == 'true':
        # In sink mode Model.train counts sink iterations, not epochs.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size

    netwithloss = BertNetworkWithLoss(teacher_config=teacher_net_cfg,
                                      teacher_ckpt=teacher_ckpt,
                                      student_config=student_net_cfg,
                                      student_ckpt=student_ckpt,
                                      is_training=True,
                                      task_type=task.task_type,
                                      num_labels=task.num_labels)
    params = netwithloss.trainable_params()
    optimizer_cfg = train_cfg.optimizer_cfg
    # Warm-up length is a configured fraction of the total step count.
    lr_schedule = BertLearningRate(
        learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size *
                         optimizer_cfg.AdamWeightDecay.warmup_ratio),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=optimizer_cfg.AdamWeightDecay.power)
    # Split parameters into those selected by decay_filter and the rest;
    # the rest are trained with zero weight decay.
    decay_params = list(
        filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(
        filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=optimizer_cfg.AdamWeightDecay.eps)

    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)

    if args_opt.do_eval == 'true':
        # Materialise the evaluation set once so the callback can reuse it.
        eval_dataset = list(eval_dataset.create_dict_iterator())
        callback = [
            EvalCallBack(network=netwithloss.bert,
                         dataset=eval_dataset,
                         eval_ckpt_step=args_opt.eval_ckpt_step,
                         save_ckpt_dir=save_ckpt_dir,
                         embedding_bits=student_net_cfg.embedding_bits,
                         weight_bits=student_net_cfg.weight_bits,
                         clip_value=student_net_cfg.weight_clip_value,
                         metrics=task.metrics)
        ]
    else:
        # Without evaluation, checkpoints are saved by step count instead.
        callback = [
            StepCallBack(),
            ModelSaveCkpt(network=netwithloss.bert,
                          save_ckpt_step=args_opt.save_ckpt_step,
                          max_ckpt_num=args_opt.max_ckpt_num,
                          output_dir=save_ckpt_dir,
                          embedding_bits=student_net_cfg.embedding_bits,
                          weight_bits=student_net_cfg.weight_bits,
                          clip_value=student_net_cfg.weight_clip_value)
        ]
    model = Model(netwithgrads)
    model.train(repeat_count,
                train_dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
コード例 #12
0
def _get_optimizer(args_opt, network):
    """Build the optimizer selected by ``cfg.optimizer`` for ``network``.

    Supported names are Lamb, Momentum, AdamWeightDecay and Thor.

    Args:
        args_opt: parsed options; ``train_steps`` is used as the decay length
            of the learning-rate schedule and ``distribute`` picks the THOR
            implementation.
        network: cell whose parameters are to be optimized.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``cfg.optimizer`` names an unsupported optimizer.
    """
    opt_name = cfg.optimizer

    if opt_name == 'Lamb':
        lamb_cfg = cfg.Lamb
        schedule = BertLearningRate(
            learning_rate=lamb_cfg.learning_rate,
            end_learning_rate=lamb_cfg.end_learning_rate,
            warmup_steps=lamb_cfg.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=lamb_cfg.power)
        all_params = network.trainable_params()
        # Parameters accepted by decay_filter get weight decay; the rest
        # fall back to the optimizer's default.
        with_decay = [p for p in all_params if lamb_cfg.decay_filter(p)]
        without_decay = [p for p in all_params if not lamb_cfg.decay_filter(p)]
        groups = [
            {'params': with_decay, 'weight_decay': lamb_cfg.weight_decay},
            {'params': without_decay},
            {'order_params': all_params},
        ]
        return Lamb(groups, learning_rate=schedule, eps=lamb_cfg.eps)

    if opt_name == 'Momentum':
        return Momentum(network.trainable_params(),
                        learning_rate=cfg.Momentum.learning_rate,
                        momentum=cfg.Momentum.momentum)

    if opt_name == 'AdamWeightDecay':
        adam_cfg = cfg.AdamWeightDecay
        schedule = BertLearningRate(
            learning_rate=adam_cfg.learning_rate,
            end_learning_rate=adam_cfg.end_learning_rate,
            warmup_steps=adam_cfg.warmup_steps,
            decay_steps=args_opt.train_steps,
            power=adam_cfg.power)
        all_params = network.trainable_params()
        with_decay = [p for p in all_params if adam_cfg.decay_filter(p)]
        without_decay = [p for p in all_params if not adam_cfg.decay_filter(p)]
        groups = [
            {'params': with_decay, 'weight_decay': adam_cfg.weight_decay},
            {'params': without_decay, 'weight_decay': 0.0},
            {'order_params': all_params},
        ]
        return AdamWeightDecay(groups,
                               learning_rate=schedule,
                               eps=adam_cfg.eps)

    if opt_name == "Thor":
        # The distributed run uses a different THOR implementation.
        if args_opt.distribute == "true":
            from src.thor_for_bert_arg import THOR
        else:
            from src.thor_for_bert import THOR
        thor_lr = get_bert_lr()
        thor_damping = get_bert_damping()
        # The three parameter views are passed as lazy filter iterators.
        grad_params = filter(lambda x: x.requires_grad,
                             network.get_parameters())
        thor_momentum = cfg.Thor.momentum
        matrix_a_params = filter(lambda x: 'matrix_A' in x.name,
                                 network.get_parameters())
        matrix_g_params = filter(lambda x: 'matrix_G' in x.name,
                                 network.get_parameters())
        return THOR(grad_params, thor_lr, thor_momentum,
                    matrix_a_params, matrix_g_params,
                    cfg.Thor.weight_decay, cfg.Thor.loss_scale,
                    bert_net_cfg.num_hidden_layers, bert_net_cfg.batch_size,
                    thor_damping)

    raise ValueError(
        "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]"
        .format(cfg.optimizer))
コード例 #13
0
def train():
    '''
    Fine-tune a BERT network for the task selected by ``cfg.task``.

    Steps performed, in order:
      1. Build the loss network: ``BertNER`` when ``cfg.task == 'NER'``
         (sequence labeling) or ``BertCLS`` when
         ``cfg.task == 'Classification'``.
      2. Create the dataset from ``cfg.data_file``.
      3. Build the optimizer named by ``cfg.optimizer``
         (``AdamWeightDecay``, ``Lamb`` or ``Momentum``).
      4. Load the checkpoint at ``cfg.pre_training_ckpt`` into the network.
      5. Wrap the network in ``BertFinetuneCell`` with a dynamic loss scale
         and run ``Model.train`` for ``cfg.epoch_num`` epochs.

    Raises:
        Exception: if ``cfg.task`` or ``cfg.optimizer`` is not one of the
            supported values listed above.
    '''
    # BertCLS train for classification
    # BertNER train for sequence labeling

    if cfg.task == 'NER':
        tag_to_index = None
        if cfg.use_crf:
            # Use a context manager so the label file handle is always closed.
            with open(cfg.label2id_file) as label_file:
                tag_to_index = json.load(label_file)
            print(tag_to_index)
            # Append the two extra tags after the existing label ids.
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels

        netwithloss = BertNER(bert_net_cfg,
                              cfg.batch_size,
                              True,
                              num_labels=number_labels,
                              use_crf=cfg.use_crf,
                              tag_to_index=tag_to_index,
                              dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg,
                              True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    dataset = get_dataset(data_file=cfg.data_file, batch_size=cfg.batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:', steps_per_epoch)

    # optimizer
    # The schedule decays over the whole run; the first 10% of the steps are
    # used as warmup.
    total_steps = steps_per_epoch * cfg.epoch_num
    warmup_steps = int(steps_per_epoch * cfg.epoch_num * 0.1)
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
            end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            power=optimizer_cfg.AdamWeightDecay.power)
        params = netwithloss.trainable_params()
        # Split the parameters: those accepted by decay_filter get the
        # configured weight decay, the rest get none.
        decay_params = list(
            filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(
            filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
                   params))
        group_params = [{
            'params': decay_params,
            'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay
        }, {
            'params': other_params,
            'weight_decay': 0.0
        }]
        optimizer = AdamWeightDecay(group_params,
                                    lr_schedule,
                                    eps=optimizer_cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(
            learning_rate=optimizer_cfg.Lamb.learning_rate,
            end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(netwithloss.trainable_params(),
                         learning_rate=lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            netwithloss.trainable_params(),
            learning_rate=optimizer_cfg.Momentum.learning_rate,
            momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # checkpoint saving: one checkpoint per epoch, keep only the latest
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(steps_per_epoch),
        LossCallBack(steps_per_epoch), ckpoint_cb
    ]
    model.train(cfg.epoch_num,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)