Example #1
0
def run_predistill():
    """
    Run the pre-distill phase (phase 1) of task distillation.

    All settings are read from the module-level ``args_opt`` and
    ``phase1_cfg``. The function builds the teacher/student loss network,
    the 'td' dataset, an AdamWeightDecay optimizer driven by a
    BertLearningRate schedule, and then trains. ModelSaveCkpt is attached so
    that ``netwithloss.bert`` is checkpointed into ``td_phase1_save_ckpt_dir``.
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=True)

    # This phase is hard-wired to a single device.
    rank = 0
    device_num = 1
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('td1 dataset size: ', dataset_size)
    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # Total steps (epochs * dataset_size) expressed in chunks of
        # data_sink_steps, matching the sink_size handed to model.train below.
        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg

    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    # Split the parameters with the SAME filter from optimizer_cfg: matches get
    # weight decay, everything else gets none. (The filter lives under
    # cfg.optimizer_cfg, not directly on cfg.)
    decay_filter = optimizer_cfg.AdamWeightDecay.decay_filter
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    callback = [TimeMonitor(time_monitor_steps), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
                                                                               args_opt.save_ckpt_step,
                                                                               args_opt.max_ckpt_num,
                                                                               td_phase1_save_ckpt_dir)]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
def run_general_distill():
    """
    Run general distillation, configured from the command line.

    Parses the command-line options, configures the MindSpore context
    (optionally data-parallel over 'hccl'), builds the teacher/student loss
    network and the 'gd' dataset, and trains with AdamWeightDecay under
    dynamic loss scaling. ModelSaveCkpt is attached so that
    ``netwithloss.bert`` is checkpointed into a timestamped directory below
    ``--save_ckpt_path``.
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=3,
                        help="Epoch size, default is 3.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Save checkpoint step, default is 100.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Max checkpoint number, default is 1.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    # One checkpoint directory per run, named after the start time.
    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    os.makedirs(save_ckpt_dir, exist_ok=True)

    if args_opt.distribute == "true":
        D.init('hccl')
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    sink_mode = args_opt.enable_data_sink == "true"
    if sink_mode:
        # Total steps (epochs * dataset_size) expressed in chunks of
        # data_sink_steps, matching the sink_size handed to model.train below.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    # Split the parameters with the SAME filter from common_cfg: matches get
    # weight decay, everything else gets none.
    decay_filter = common_cfg.AdamWeightDecay.decay_filter
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]

    update_cell = DynamicLossScaleUpdateCell(
        loss_scale_value=common_cfg.loss_scale_value,
        scale_factor=common_cfg.scale_factor,
        scale_window=common_cfg.scale_window)

    netwithgrads = BertTrainWithLossScaleCell(netwithloss,
                                              optimizer=optimizer,
                                              scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
Example #3
0
def run_task_distill(ckpt_file):
    """
    Execute the task-distill stage (phase 2) from a given student checkpoint.

    Settings other than the checkpoint path are taken from module-level
    names (``args_opt``, ``phase2_cfg``, ``task``, ``dataset_type``,
    ``enable_loss_scale`` ...).

    Args:
        ckpt_file: student checkpoint to load; the empty string is rejected.

    Raises:
        ValueError: if ``ckpt_file`` is the empty string.
    """
    if ckpt_file == '':
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg

    distill_net = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=ckpt_file,
                                         is_training=True,
                                         task_type=args_opt.task_type,
                                         num_labels=task.num_labels,
                                         is_predistill=False)

    # Single-device run: one shard, shard index 0.
    rank, device_num = 0, 1
    train_dataset = create_tinybert_dataset('td', cfg.batch_size, device_num, rank,
                                            args_opt.do_shuffle, args_opt.train_data_dir,
                                            args_opt.schema_dir, data_type=dataset_type)

    steps_per_epoch = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', steps_per_epoch)
    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())

    # In sink mode the count passed to model.train is the total number of steps
    # divided into chunks of data_sink_steps; otherwise it is the epoch count.
    use_sink = args_opt.enable_data_sink == 'true'
    if not use_sink:
        repeat_count = args_opt.td_phase2_epoch_size
        monitor_interval = steps_per_epoch
    else:
        repeat_count = (args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size()
                        // args_opt.data_sink_steps)
        monitor_interval = args_opt.data_sink_steps

    adam_cfg = cfg.optimizer_cfg.AdamWeightDecay

    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(steps_per_epoch * args_opt.td_phase2_epoch_size / 10),
                                   decay_steps=int(steps_per_epoch * args_opt.td_phase2_epoch_size),
                                   power=adam_cfg.power)

    # Parameters accepted by decay_filter are decayed; the rest are not.
    trainable = distill_net.trainable_params()
    with_decay = [p for p in trainable if adam_cfg.decay_filter(p)]
    without_decay = [p for p in trainable if not adam_cfg.decay_filter(p)]
    param_groups = [
        {'params': with_decay, 'weight_decay': adam_cfg.weight_decay},
        {'params': without_decay, 'weight_decay': 0.0},
        {'order_params': trainable},
    ]
    optimizer = AdamWeightDecay(param_groups, learning_rate=lr_schedule, eps=adam_cfg.eps)

    eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size, device_num, rank,
                                           args_opt.do_shuffle, args_opt.eval_data_dir,
                                           args_opt.schema_dir, data_type=dataset_type)
    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

    # Either evaluate during training or save periodic checkpoints, never both.
    evaluate = args_opt.do_eval.lower() == "true"
    callbacks = [TimeMonitor(monitor_interval), LossCallBack()]
    if evaluate:
        callbacks.append(EvalCallBack(distill_net.bert, eval_dataset))
    else:
        callbacks.append(ModelSaveCkpt(distill_net.bert, args_opt.save_ckpt_step,
                                       args_opt.max_ckpt_num, td_phase2_save_ckpt_dir))

    if not enable_loss_scale:
        train_cell = BertEvaluationCell(distill_net, optimizer=optimizer)
    else:
        scale_manager = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                   scale_factor=cfg.scale_factor,
                                                   scale_window=cfg.scale_window)
        train_cell = BertEvaluationWithLossScaleCell(distill_net,
                                                     optimizer=optimizer,
                                                     scale_update_cell=scale_manager)

    Model(train_cell).train(repeat_count,
                            train_dataset,
                            callbacks=callbacks,
                            dataset_sink_mode=use_sink,
                            sink_size=args_opt.data_sink_steps)
Example #4
0
def run_general_distill():
    """
    Run general distillation, configured from the command line.

    Parses the command-line options, configures the MindSpore context
    (optionally data-parallel), builds the teacher/student loss network and
    the 'gd' dataset in the requested record format, and trains with
    AdamWeightDecay. Dynamic loss scaling is used unless the target is GPU.
    ModelSaveCkpt is attached so that ``netwithloss.bert`` is checkpointed
    into a timestamped directory below ``--save_ckpt_path``.

    Raises:
        Exception: if ``--dataset_type`` is neither "tfrecord" nor "mindrecord".
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument(
        '--device_target',
        type=str,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute",
                        type=str,
                        default="false",
                        choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size",
                        type=int,
                        default=3,
                        help="Epoch size, default is 3.")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step",
                        type=int,
                        default=100,
                        help="Save checkpoint step, default is 100.")
    parser.add_argument("--max_ckpt_num",
                        type=int,
                        default=1,
                        help="Max checkpoint number, default is 1.")
    parser.add_argument("--do_shuffle",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps",
                        type=int,
                        default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path",
                        type=str,
                        default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path",
                        type=str,
                        default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir",
                        type=str,
                        default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir",
                        type=str,
                        default="",
                        help="Schema path, it is better to use absolute path")
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="tfrecord",
        help="dataset type tfrecord/mindrecord, default is tfrecord")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    # One checkpoint directory per run, named after the start time.
    save_ckpt_dir = os.path.join(
        args_opt.save_ckpt_path,
        datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        D.init()
        if args_opt.device_target == 'Ascend':
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            device_num = D.get_group_size()
            rank = D.get_rank()
            # On this path every rank writes to its own checkpoint directory.
            save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning(
                'Compute about the student only support float32 temporarily, run with float32.'
            )
            bert_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(
        teacher_config=bert_teacher_net_cfg,
        teacher_ckpt=args_opt.load_teacher_ckpt_path,
        student_config=bert_student_net_cfg,
        is_training=True,
        use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd',
                                      common_cfg.batch_size,
                                      device_num,
                                      rank,
                                      args_opt.do_shuffle,
                                      args_opt.data_dir,
                                      args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    sink_mode = args_opt.enable_data_sink == "true"
    if sink_mode:
        # Total steps (epochs * dataset_size) expressed in chunks of
        # data_sink_steps, matching the sink_size handed to model.train below.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(
        learning_rate=common_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    # Parameters accepted by decay_filter are decayed; the rest are not.
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter,
                               params))
    other_params = list(
        filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': common_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=common_cfg.AdamWeightDecay.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir)
    ]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=common_cfg.loss_scale_value,
            scale_factor=common_cfg.scale_factor,
            scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
Example #5
0
def run_general_distill():
    """
    Entry point for general distillation; options come from ``get_argument()``.

    Sets up the MindSpore context (optionally data-parallel), forces float32
    where the GPU/CPU branches below require it, builds the teacher/student
    loss network and the 'gd' dataset, then trains with AdamWeightDecay.
    Loss scaling is switched off for GPU and CPU targets.

    Raises:
        Exception: if ``args_opt.dataset_type`` is neither "tfrecord" nor
            "mindrecord".
    """
    args_opt = get_argument()
    target = args_opt.device_target
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=target,
                        reserve_class_name_in_scope=False)
    if target == "Ascend":
        context.set_context(device_id=args_opt.device_id)

    # Checkpoints land in a directory named after the start time of the run.
    run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path, run_stamp)

    if args_opt.distribute != "true":
        rank, device_num = 0, 1
    else:
        D.init()
        if target == 'Ascend':
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            device_num = D.get_group_size()
            rank = D.get_rank()
            # On this path every rank writes to its own checkpoint directory.
            save_ckpt_dir += '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          device_num=device_num)

    if not os.path.exists(save_ckpt_dir):
        os.makedirs(save_ckpt_dir)

    enable_loss_scale = True
    if target == "GPU":
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning(
                'Compute about the student only support float32 temporarily, run with float32.'
            )
            bert_student_net_cfg.compute_type = mstype.float32
    elif target == "CPU":
        enable_loss_scale = False
        logger.warning(
            'CPU only support float32 temporarily, run with float32.')
        # Both networks are switched to float32 for storage and compute.
        for net_cfg in (bert_teacher_net_cfg, bert_student_net_cfg):
            net_cfg.dtype = mstype.float32
            net_cfg.compute_type = mstype.float32

    distill_net = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=bert_student_net_cfg,
                                         is_training=True,
                                         use_one_hot_embeddings=False)

    requested_format = args_opt.dataset_type
    if requested_format == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif requested_format == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        raise Exception("dataset format is not supported yet")

    dataset = create_tinybert_dataset('gd', common_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.data_dir,
                                      args_opt.schema_dir, data_type=dataset_type)
    steps_per_epoch = dataset.get_dataset_size()
    print('dataset size: ', steps_per_epoch)
    print("dataset repeatcount: ", dataset.get_repeat_count())

    # In sink mode the count passed to model.train is the total number of steps
    # divided into chunks of data_sink_steps; otherwise it is the epoch count.
    use_sink = args_opt.enable_data_sink == "true"
    if not use_sink:
        repeat_count = args_opt.epoch_size
        monitor_interval = steps_per_epoch
    else:
        repeat_count = args_opt.epoch_size * steps_per_epoch // args_opt.data_sink_steps
        monitor_interval = args_opt.data_sink_steps

    adam_cfg = common_cfg.AdamWeightDecay
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(steps_per_epoch * args_opt.epoch_size / 10),
                                   decay_steps=int(steps_per_epoch * args_opt.epoch_size),
                                   power=adam_cfg.power)

    # Parameters accepted by decay_filter are decayed; the rest are not.
    trainable = distill_net.trainable_params()
    with_decay = [p for p in trainable if adam_cfg.decay_filter(p)]
    without_decay = [p for p in trainable if not adam_cfg.decay_filter(p)]
    param_groups = [
        {'params': with_decay, 'weight_decay': adam_cfg.weight_decay},
        {'params': without_decay, 'weight_decay': 0.0},
        {'order_params': trainable},
    ]
    optimizer = AdamWeightDecay(param_groups, learning_rate=lr_schedule, eps=adam_cfg.eps)

    callbacks = [
        TimeMonitor(monitor_interval),
        LossCallBack(),
        ModelSaveCkpt(distill_net.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir),
    ]

    if not enable_loss_scale:
        train_cell = BertTrainCell(distill_net, optimizer=optimizer)
    else:
        scale_manager = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                                   scale_factor=common_cfg.scale_factor,
                                                   scale_window=common_cfg.scale_window)
        train_cell = BertTrainWithLossScaleCell(distill_net,
                                                optimizer=optimizer,
                                                scale_update_cell=scale_manager)

    Model(train_cell).train(repeat_count,
                            dataset,
                            callbacks=callbacks,
                            dataset_sink_mode=use_sink,
                            sink_size=args_opt.data_sink_steps)
def run_task_distill(args_opt):
    """
    Run task distillation for the task named by ``args_opt.task_name``.

    Note that this updates module-level configuration in place: the sequence
    length of ``teacher_net_cfg`` / ``student_net_cfg`` and the batch sizes
    of ``train_cfg`` / ``eval_cfg``.

    Args:
        args_opt: parsed options. The attributes read here are task_name,
            train_batch_size, eval_batch_size, teacher_model_dir,
            student_model_dir, data_dir, output_dir, device_target,
            device_id, do_shuffle, dataset_type, enable_data_sink,
            data_sink_steps, epoch_size, do_eval, eval_ckpt_step,
            save_ckpt_step and max_ckpt_num.
    """
    task = task_cfg[args_opt.task_name]
    teacher_net_cfg.seq_length = task.seq_length
    student_net_cfg.seq_length = task.seq_length
    train_cfg.batch_size = args_opt.train_batch_size
    eval_cfg.batch_size = args_opt.eval_batch_size
    teacher_ckpt = os.path.join(args_opt.teacher_model_dir, args_opt.task_name,
                                WEIGHTS_NAME)
    student_ckpt = os.path.join(args_opt.student_model_dir, args_opt.task_name,
                                WEIGHTS_NAME)
    train_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name,
                                  TRAIN_DATA_NAME)
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name,
                                 EVAL_DATA_NAME)
    save_ckpt_dir = os.path.join(args_opt.output_dir, args_opt.task_name)

    # Take the device id from the options passed in, like every other setting,
    # rather than from an unrelated global.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)

    # Single-device run: one shard, shard index 0.
    rank = 0
    device_num = 1
    train_dataset = create_dataset(batch_size=train_cfg.batch_size,
                                   device_num=device_num,
                                   rank=rank,
                                   do_shuffle=args_opt.do_shuffle,
                                   data_dir=train_data_dir,
                                   data_type=args_opt.dataset_type,
                                   seq_length=task.seq_length,
                                   task_type=task.task_type,
                                   drop_remainder=True)
    dataset_size = train_dataset.get_dataset_size()
    print('train dataset size:', dataset_size)
    eval_dataset = create_dataset(batch_size=eval_cfg.batch_size,
                                  device_num=device_num,
                                  rank=rank,
                                  do_shuffle=args_opt.do_shuffle,
                                  data_dir=eval_data_dir,
                                  data_type=args_opt.dataset_type,
                                  seq_length=task.seq_length,
                                  task_type=task.task_type,
                                  drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())

    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # Total steps (epochs * dataset_size) expressed in chunks of
        # data_sink_steps, matching the sink_size handed to model.train below.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size

    netwithloss = BertNetworkWithLoss(teacher_config=teacher_net_cfg,
                                      teacher_ckpt=teacher_ckpt,
                                      student_config=student_net_cfg,
                                      student_ckpt=student_ckpt,
                                      is_training=True,
                                      task_type=task.task_type,
                                      num_labels=task.num_labels)
    params = netwithloss.trainable_params()
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_schedule = BertLearningRate(
        learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
        end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
        warmup_steps=int(dataset_size * args_opt.epoch_size *
                         optimizer_cfg.AdamWeightDecay.warmup_ratio),
        decay_steps=int(dataset_size * args_opt.epoch_size),
        power=optimizer_cfg.AdamWeightDecay.power)
    # Parameters accepted by decay_filter are decayed; the rest are not.
    decay_params = list(
        filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(
        filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x),
               params))
    group_params = [{
        'params': decay_params,
        'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]

    optimizer = AdamWeightDecay(group_params,
                                learning_rate=lr_schedule,
                                eps=optimizer_cfg.AdamWeightDecay.eps)

    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)

    if args_opt.do_eval == 'true':
        # Materialize the evaluation batches once and hand the list to the callback.
        eval_dataset = list(eval_dataset.create_dict_iterator())
        callback = [
            EvalCallBack(network=netwithloss.bert,
                         dataset=eval_dataset,
                         eval_ckpt_step=args_opt.eval_ckpt_step,
                         save_ckpt_dir=save_ckpt_dir,
                         embedding_bits=student_net_cfg.embedding_bits,
                         weight_bits=student_net_cfg.weight_bits,
                         clip_value=student_net_cfg.weight_clip_value,
                         metrics=task.metrics)
        ]
    else:
        callback = [
            StepCallBack(),
            ModelSaveCkpt(network=netwithloss.bert,
                          save_ckpt_step=args_opt.save_ckpt_step,
                          max_ckpt_num=args_opt.max_ckpt_num,
                          output_dir=save_ckpt_dir,
                          embedding_bits=student_net_cfg.embedding_bits,
                          weight_bits=student_net_cfg.weight_bits,
                          clip_value=student_net_cfg.weight_clip_value)
        ]
    model = Model(netwithgrads)
    model.train(repeat_count,
                train_dataset,
                callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)