Example No. 1
def run_predistill():
    """
    run predistill
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=True)

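    # Single-device run: rank 0 of 1 device, so the dataset is not sharded across workers.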
    rank = 0
    device_num = 1
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('td1 dataset size: ', dataset_size)
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg

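    # Polynomial-decay LR schedule: warm up for roughly 10% of one epoch, then decay over all phase-1 steps.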
    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
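    # Apply weight decay only to the parameters selected by decay_filter; all other
    # parameters use a decay of 0.0. 'order_params' keeps the optimizer's parameter
    # ordering aligned with the network's parameter order.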
    params = netwithloss.trainable_params()
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                               args_opt.save_ckpt_step,
                                                                               args_opt.max_ckpt_num,
                                                                               td_phase1_save_ckpt_dir)]
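    # Dynamic loss scaling guards mixed-precision training against gradient overflow/underflow.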
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
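The decay_filter predicate used above lives in the optimizer config rather than in this example. In BERT-style setups it is usually a name-based filter that exempts LayerNorm parameters and biases from weight decay; a minimal sketch of such a filter (the exact rule in this config is an assumption) could look like this:

def decay_filter(param):
    # Typical BERT convention (assumed): decay everything except LayerNorm
    # parameters and biases.
    name = param.name.lower()
    return 'layernorm' not in name and 'bias' not in name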
Example No. 2
def run_task_distill(ckpt_file):
    """
    run task distill
    """
    if ckpt_file == '':
        raise ValueError("Student checkpoint file path must not be empty")
    cfg = phase2_cfg

    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = ckpt_file
    netwithloss = BertNetworkWithLoss_td(
        teacher_config=td_teacher_net_cfg,
        teacher_ckpt=load_teacher_checkpoint_path,
        student_config=td_student_net_cfg,
        student_ckpt=load_student_checkpoint_path,
        is_training=True,
        task_type=args_opt.task_type,
        num_labels=task.num_labels,
        is_predistill=False)

    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td',
                                            cfg.batch_size,
                                            device_num,
                                            rank,
                                            args_opt.do_shuffle,
                                            args_opt.train_data_dir,
                                            args_opt.schema_dir,
                                            data_type=dataset_type)

    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    print('td2 train dataset repeat count: ', train_dataset.get_repeat_count())
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase2_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg

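    # Warm up over the first 10% of the total phase-2 steps, then decay polynomially
    # to end_learning_rate.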
    lr_schedule = BertLearningRate(
        learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.td_phase2_epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.td_phase2_epoch_size),
        power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(
        filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(
        filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=optimizer_cfg.AdamWeightDecay.eps)

    eval_dataset = create_tinybert_dataset('td',
                                           eval_cfg.batch_size,
                                           device_num,
                                           rank,
                                           args_opt.do_shuffle,
                                           args_opt.eval_data_dir,
                                           args_opt.schema_dir,
                                           data_type=dataset_type)
    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

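    # With do_eval enabled, evaluate the student on eval_dataset during training;
    # otherwise periodically save student checkpoints instead.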
    if args_opt.do_eval.lower() == "true":
        callback = [
            TimeMonitor(time_monitor_steps),
            LossCallBack(),
            EvalCallBack(netwithloss.bert, eval_dataset)
        ]
    else:
        callback = [
            TimeMonitor(time_monitor_steps),
            LossCallBack(),
            ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                          args_opt.max_ckpt_num, td_phase2_save_ckpt_dir)
        ]
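    # Optionally wrap training with dynamic loss scaling (typically used for fp16 training).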
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)

        netwithgrads = BertEvaluationWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                train_dataset,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
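For context, a minimal driver that chains the two phases could look like the sketch below. It assumes the module-level args_opt, configs, and td_phase1_save_ckpt_dir used in the examples are already defined; the checkpoint-selection logic is illustrative only, not taken from the original script.

import os

if __name__ == '__main__':
    # Phase 1: pre-distillation of the student on the task data.
    run_predistill()
    # Pick the most recently written phase-1 checkpoint (file naming/layout assumed).
    ckpts = sorted(
        (os.path.join(td_phase1_save_ckpt_dir, f)
         for f in os.listdir(td_phase1_save_ckpt_dir) if f.endswith('.ckpt')),
        key=os.path.getmtime)
    if not ckpts:
        raise ValueError("No phase-1 checkpoint found in " + td_phase1_save_ckpt_dir)
    # Phase 2: task-specific distillation starting from the phase-1 student.
    run_task_distill(ckpts[-1])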