def run_pretrain():
    """Pre-train BERT (bert_clue) with settings taken from the command line.

    Parses CLI options, configures the MindSpore context (optionally for
    data-parallel distributed training), builds the pre-training dataset,
    the network with loss, the optimizer selected by ``cfg.optimizer`` and
    the training callbacks, then runs ``Model.train``.

    Raises:
        ValueError: if ``cfg.optimizer`` is not Lamb, Momentum or
            AdamWeightDecay.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                        help="Run distribute, default is false.")
    # Integer options get integer defaults; argparse would otherwise run a
    # string default through ``type`` anyway, so the parsed value is the same.
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", choices=["true", "false"],
                        help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--accumulation_steps", type=int, default=1,
                        help="Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000,
                        help="Save checkpoint steps, default is 1000.")
    parser.add_argument("--train_steps", type=int, default=-1,
                        help="Training Steps, default is -1, "
                             "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        # Each rank writes its checkpoints into its own sub-directory.
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        # Hard-coded all-reduce fusion split indices, selected by model depth
        # and the relative-position setting; other depths set nothing.
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps))
        # With gradient accumulation, the sink size and the checkpoint interval
        # are multiplied by the accumulation factor.
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    dataset_size = ds.get_dataset_size()
    new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        # An explicit step budget can only shorten training, never extend it.
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * dataset_size // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]
        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                         format(cfg.optimizer))

    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                               scale_update_cell=update_cell)
        else:
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(
                net_with_loss, optimizer=optimizer, scale_update_cell=update_cell,
                accumulation_steps=args_opt.accumulation_steps)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
def __init__(self, network):
    """Wrap ``network`` with an AdamWeightDecay optimizer over all of its parameters.

    Args:
        network: the cell to train; every parameter returned by its
            ``get_parameters()`` is handed to the optimizer.
    """
    super(TrainStepWrapForAdam, self).__init__()
    # The assignment order below (network, weights, optimizer, clipper) is the
    # same as before, in case attribute registration order matters to Cell.
    self.network = network
    weights = ParameterTuple(network.get_parameters())
    self.weights = weights
    self.optimizer = AdamWeightDecay(weights)
    self.clip_gradients = ClipGradients()
def test_adam_mindspore_with_empty_params():
    """AdamWeightDecay rejects the parameters of a weightless cell (nn.Flatten)."""
    flatten_cell = nn.Flatten()
    expected_msg = r"Optimizer got an empty parameter list"
    with pytest.raises(ValueError, match=expected_msg):
        empty_params = flatten_cell.get_parameters()
        AdamWeightDecay(empty_params)
def test_adamwithoutparam():
    """AdamWeightDecay raises when given a network with no trainable parameters."""
    weightless_net = NetWithoutWeight()
    weightless_net.set_train()
    expected_msg = r"Optimizer got an empty parameter list"
    with pytest.raises(ValueError, match=expected_msg):
        trainable = weightless_net.trainable_params()
        AdamWeightDecay(trainable, learning_rate=0.1)
def test_AdamWeightDecay_beta2():
    """AdamWeightDecay rejects beta2=1.0 with a ValueError."""
    model = Net()
    with pytest.raises(ValueError):
        AdamWeightDecay(model.get_parameters(), learning_rate=0.1, beta2=1.0)
def test_AdamWeightDecay_e():
    """AdamWeightDecay rejects a negative eps with a ValueError."""
    model = Net()
    with pytest.raises(ValueError):
        AdamWeightDecay(model.get_parameters(), learning_rate=0.1, eps=-0.1)
def run_task_distill(ckpt_file):
    """Run phase-2 TinyBERT task distillation starting from a student checkpoint.

    Options are read from the module-level ``args_opt``; the optimizer and
    loss-scale settings come from ``phase2_cfg``.

    Args:
        ckpt_file: path of the student checkpoint to start from.

    Raises:
        ValueError: if ``ckpt_file`` is empty or None.
    """
    # ``not ckpt_file`` also catches None, which the old ``== ''`` check let through.
    if not ckpt_file:
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = ckpt_file
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=False)
    # Task distillation runs on a single device.
    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    if args_opt.enable_data_sink == 'true':
        repeat_count = args_opt.td_phase2_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg
    adam_cfg = optimizer_cfg.AdamWeightDecay
    total_steps = dataset_size * args_opt.td_phase2_epoch_size
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(total_steps / 10),
                                   decay_steps=int(total_steps),
                                   power=adam_cfg.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(adam_cfg.decay_filter, params))
    # Complement of decay_params using the same predicate: O(n) and avoids
    # comparing Parameter objects with ``in``/``==``.
    other_params = list(filter(lambda x: not adam_cfg.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    if args_opt.do_eval.lower() == "true":
        # The eval dataset is only needed (and only built) when evaluating.
        eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                               device_num, rank, args_opt.do_shuffle,
                                               args_opt.eval_data_dir, args_opt.schema_dir)
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    EvalCallBack(netwithloss.bert, eval_dataset)]
    else:
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                                  args_opt.max_ckpt_num, td_phase2_save_ckpt_dir)]

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
def run_general_distill():
    """Run TinyBERT general distillation with settings taken from the command line.

    Parses CLI options, configures the MindSpore context (optionally for
    data-parallel distributed training), creates a timestamped checkpoint
    directory, builds the teacher/student distillation network, dataset,
    AdamWeightDecay optimizer and callbacks, then runs ``Model.train``.
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=3, help="Epoch size, default is 3.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step", type=int, default=100,
                        help="Save checkpoint steps, default is 100.")
    parser.add_argument("--max_ckpt_num", type=int, default=1,
                        help="Max checkpoint numbers, default is 1.")
    parser.add_argument("--do_shuffle", type=str, default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                 datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
        # Each rank writes its checkpoints into its own directory.
        save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
    else:
        rank = 0
        device_num = 1

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_ckpt_dir, exist_ok=True)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        for net_cfg in (bert_teacher_net_cfg, bert_student_net_cfg):
            if net_cfg.compute_type != mstype.float32:
                logger.warning('GPU only support fp32 temporarily, run with fp32.')
                net_cfg.compute_type = mstype.float32
        # Both the forward and backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=bert_student_net_cfg,
                                         is_training=True, use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    adam_cfg = common_cfg.AdamWeightDecay
    total_steps = dataset_size * args_opt.epoch_size
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(total_steps / 10),
                                   decay_steps=int(total_steps),
                                   power=adam_cfg.power)

    params = netwithloss.trainable_params()
    decay_params = list(filter(adam_cfg.decay_filter, params))
    other_params = list(filter(lambda x: not adam_cfg.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                              args_opt.max_ckpt_num, save_ckpt_dir)]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                                 scale_factor=common_cfg.scale_factor,
                                                 scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                  scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """Fine-tune ``network`` on ``dataset`` starting from a pre-trained checkpoint.

    Args:
        dataset: training dataset; ``get_dataset_size()`` is used as the
            number of steps per epoch.
        network: network with loss whose trainable parameters are optimized.
        load_checkpoint_path: pre-trained checkpoint loaded into ``network``
            before training (required).
        save_checkpoint_path: directory passed to ``ModelCheckpoint``.
        epoch_num: number of training epochs.

    Raises:
        ValueError: if no checkpoint path, dataset or network is given, or if
            ``optimizer_cfg.optimizer`` is not a supported optimizer.
    """
    if not load_checkpoint_path:
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    if dataset is None or network is None:
        raise ValueError("do_train requires both a dataset and a network.")
    steps_per_epoch = dataset.get_dataset_size()
    total_steps = steps_per_epoch * epoch_num

    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecay':
        adam_cfg = optimizer_cfg.AdamWeightDecay
        lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                       end_learning_rate=adam_cfg.end_learning_rate,
                                       warmup_steps=int(total_steps * 0.1),
                                       decay_steps=total_steps,
                                       power=adam_cfg.power)
        # Previously referenced the undefined name ``net_with_loss``.
        params = network.trainable_params()
        decay_params = list(filter(adam_cfg.decay_filter, params))
        other_params = list(filter(lambda x: not adam_cfg.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=adam_cfg.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
                                       end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(total_steps * 0.1),
                                       decay_steps=total_steps,
                                       power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        # ValueError subclasses Exception, so existing ``except Exception`` callers still match.
        raise ValueError("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # checkpoint saved once per epoch, keeping only the latest file
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
def test_AdamWeightDecay_init_non_iterable_params(self):
    """AdamWeightDecay rejects a non-iterable ``params`` argument with TypeError.

    Renamed from ``test_AdamWightDecay_init``: that name was redefined by the
    following test, so this one was shadowed and never collected.
    """
    with pytest.raises(TypeError):
        AdamWeightDecay(9)
def test_AdamWightDecay_init(self):
    """Passing ``None`` as the parameter list makes AdamWeightDecay raise ValueError."""
    missing_params = None
    with pytest.raises(ValueError):
        AdamWeightDecay(missing_params)