def run_predistill():
    """
    Run phase 1 of task distillation ("predistill").

    Builds a ``BertNetworkWithLoss_td`` teacher/student network with
    ``is_predistill=True``. The teacher weights come from
    ``args_opt.load_teacher_ckpt_path`` and the student weights from
    ``args_opt.load_gd_ckpt_path``. The network is trained on
    ``args_opt.train_data_dir`` using ``AdamWeightDecay`` driven by a
    ``BertLearningRate`` schedule. The ``netwithloss.bert`` sub-network is
    checkpointed into ``td_phase1_save_ckpt_dir`` via ``ModelSaveCkpt``.

    All settings are read from module-level globals (``args_opt``,
    ``phase1_cfg``, ``td_teacher_net_cfg``, ``td_student_net_cfg``,
    ``enable_loss_scale``, ``td_phase1_save_ckpt_dir``). Returns ``None``.
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    # NOTE(review): run_task_distill uses args_opt.task_type / task.num_labels;
    # here task_type is hard-coded to 'classification' -- confirm this is intended.
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=load_student_checkpoint_path,
                                         is_training=True,
                                         task_type='classification',
                                         num_labels=args_opt.num_labels,
                                         is_predistill=True)

    # Single-device run: rank 0 of 1 device.
    rank = 0
    device_num = 1
    # NOTE(review): batch size comes from the teacher net config here, whereas
    # phase 2 uses phase2_cfg.batch_size and passes data_type -- confirm.
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.train_data_dir,
                                      args_opt.schema_dir)
    dataset_size = dataset.get_dataset_size()
    print('td1 dataset size: ', dataset_size)

    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # With sinking, each "epoch" given to Model.train covers sink_size
        # (= data_sink_steps) steps, so rescale the requested epoch count.
        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg
    adam_cfg = optimizer_cfg.AdamWeightDecay
    # NOTE(review): warmup is 10% of ONE epoch here but 10% of all epochs in
    # run_task_distill -- confirm which is intended.
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=adam_cfg.power)

    params = netwithloss.trainable_params()
    # Parameters accepted by decay_filter get weight decay; the complement gets
    # none. Both sides must use the SAME filter object (optimizer_cfg's), not
    # cfg.AdamWeightDecay.
    decay_params = list(filter(adam_cfg.decay_filter, params))
    other_params = list(filter(lambda x: not adam_cfg.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    # keep the optimizer's parameter order equal to trainable_params()
                    {'order_params': params}]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                              args_opt.max_ckpt_num, td_phase1_save_ckpt_dir)]

    # Only the loss-scale cell accepts a scale_update_cell (same selection as
    # run_task_distill); BertEvaluationCell is built without one.
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                       scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
def run_task_distill(ckpt_file):
    """
    Run phase 2 of task distillation, optionally evaluating during training.

    Builds a ``BertNetworkWithLoss_td`` teacher/student network with
    ``is_predistill=False``. The student is initialised from ``ckpt_file`` and
    the teacher from ``args_opt.load_teacher_ckpt_path``. The network is
    trained on ``args_opt.train_data_dir`` with ``AdamWeightDecay`` and a
    ``BertLearningRate`` schedule.

    The callback set depends on ``args_opt.do_eval``:
      * When it is ``"true"`` (case-insensitive), an ``EvalCallBack`` runs on
        ``args_opt.eval_data_dir``.
      * Otherwise ``netwithloss.bert`` is checkpointed into
        ``td_phase2_save_ckpt_dir``.

    Args:
        ckpt_file (str): path of the student checkpoint to start from.

    Raises:
        ValueError: if ``ckpt_file`` is ``None`` or an empty string.
    """
    # Reject both None and '' -- the error message already promises a None check.
    if not ckpt_file:
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg

    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = ckpt_file
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=load_student_checkpoint_path,
                                         is_training=True,
                                         task_type=args_opt.task_type,
                                         num_labels=task.num_labels,
                                         is_predistill=False)

    # Single-device run: rank 0 of 1 device.
    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td', cfg.batch_size, device_num, rank,
                                            args_opt.do_shuffle, args_opt.train_data_dir,
                                            args_opt.schema_dir, data_type=dataset_type)
    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())

    sink_mode = args_opt.enable_data_sink == 'true'
    if sink_mode:
        # With sinking, each "epoch" given to Model.train covers sink_size
        # (= data_sink_steps) steps, so rescale the requested epoch count.
        repeat_count = args_opt.td_phase2_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg
    adam_cfg = optimizer_cfg.AdamWeightDecay
    total_steps = dataset_size * args_opt.td_phase2_epoch_size
    # Warm up over the first 10% of all training steps, then decay over all of them.
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(total_steps / 10),
                                   decay_steps=int(total_steps),
                                   power=adam_cfg.power)

    params = netwithloss.trainable_params()
    # Parameters accepted by decay_filter get weight decay; the rest get none.
    decay_params = list(filter(adam_cfg.decay_filter, params))
    other_params = list(filter(lambda x: not adam_cfg.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    # keep the optimizer's parameter order equal to trainable_params()
                    {'order_params': params}]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    if args_opt.do_eval.lower() == "true":
        # Build the eval pipeline only when it is actually used, so an unset
        # eval_data_dir does not break training-only runs.
        eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size, device_num, rank,
                                               args_opt.do_shuffle, args_opt.eval_data_dir,
                                               args_opt.schema_dir, data_type=dataset_type)
        print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
        # NOTE(review): no ModelSaveCkpt in this branch; presumably EvalCallBack
        # handles saving -- confirm.
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    EvalCallBack(netwithloss.bert, eval_dataset)]
    else:
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                                  args_opt.max_ckpt_num, td_phase2_save_ckpt_dir)]

    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                       scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)