def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """Fine-tune ``network`` on ``dataset`` starting from a pretrained checkpoint.

    Args:
        dataset: Training dataset; ``get_dataset_size()`` supplies the step
            count used for checkpoint frequency and the monitor callbacks.
        network: Network-with-loss whose trainable parameters are optimized.
        load_checkpoint_path (str): Pretrained checkpoint to load (required).
        save_checkpoint_path (str): Checkpoint output directory; an empty
            string is forwarded to ``ModelCheckpoint`` as ``directory=None``.
        epoch_num (int): Number of epochs handed to ``Model.train``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is the empty string.
    """
    if load_checkpoint_path == "":
        raise ValueError(
            "Pretrain model missed, finetune task must load pretrain model!")

    steps = dataset.get_dataset_size()

    # Adam over every trainable parameter; learning rate comes from optimizer_cfg.
    adam = Adam(network.trainable_params(), learning_rate=optimizer_cfg.learning_rate)

    # Save every `steps` steps (once per epoch) and keep at most one checkpoint file.
    checkpoint_dir = save_checkpoint_path if save_checkpoint_path != "" else None
    checkpoint_cb = ModelCheckpoint(
        prefix="classifier",
        directory=checkpoint_dir,
        config=CheckpointConfig(save_checkpoint_steps=steps, keep_checkpoint_max=1))

    # Initialise the network from the pretrained weights.
    load_param_into_net(network, load_checkpoint(load_checkpoint_path))

    loss_scaler = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    trainer = Model(BertFinetuneCell(network, optimizer=adam, scale_update_cell=loss_scaler))
    trainer.train(
        epoch_num,
        dataset,
        callbacks=[TimeMonitor(steps), LossCallBack(steps), checkpoint_cb])
def run_predistill():
    """Run phase 1 of TinyBERT task distillation (``is_predistill=True``).

    Settings are read from the free names ``args_opt`` and ``phase1_cfg``.
    The function builds a teacher/student ``BertNetworkWithLoss_td``, trains it
    with AdamWeightDecay under dynamic loss scaling, and saves
    ``netwithloss.bert`` through ``ModelSaveCkpt`` into
    ``td_phase1_save_ckpt_dir``.
    """
    cfg = phase1_cfg
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = args_opt.load_gd_ckpt_path
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg,
                                         teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg,
                                         student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type='classification',
                                         num_labels=args_opt.num_labels, is_predistill=True)

    # This phase always runs as a single, non-distributed device.
    rank = 0
    device_num = 1
    dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                      device_num, rank, args_opt.do_shuffle,
                                      args_opt.train_data_dir, args_opt.schema_dir)

    dataset_size = dataset.get_dataset_size()
    print('td1 dataset size: ', dataset_size)
    if args_opt.enable_data_sink == 'true':
        # Express the step budget in chunks of data_sink_steps (the sink_size passed below).
        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase1_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg
    adam_cfg = optimizer_cfg.AdamWeightDecay

    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(dataset_size / 10),
                                   decay_steps=int(dataset_size * args_opt.td_phase1_epoch_size),
                                   power=adam_cfg.power)
    params = netwithloss.trainable_params()
    # Partition with ONE filter (optimizer_cfg's) so the decay / no-decay groups are
    # exact complements of each other.
    decay_filter = adam_cfg.decay_filter
    decay_params = list(filter(decay_filter, params))
    other_params = [param for param in params if not decay_filter(param)]
    group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)
    callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                              args_opt.max_ckpt_num, td_phase1_save_ckpt_dir)]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                             scale_factor=cfg.scale_factor,
                                             scale_window=cfg.scale_window)
    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """Fine-tune an NER network from a pretrained checkpoint and print the training latency.

    ``optimizer_cfg.optimizer`` selects the optimizer:

    * AdamWeightDecay and Lamb use a ``BertLearningRate`` schedule with
      ``warmup_steps`` = 10% and ``decay_steps`` = 100% of
      ``steps_per_epoch * epoch_num``.
    * Momentum uses the learning rate and momentum from ``optimizer_cfg.Momentum``.

    Args:
        dataset: Training dataset; ``get_dataset_size()`` is the steps per epoch.
        network: Network-with-loss to fine-tune.
        load_checkpoint_path (str): Pretrained checkpoint to load (required).
        save_checkpoint_path (str): Checkpoint directory; "" is passed as ``directory=None``.
        epoch_num (int): Number of training epochs.

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty or the configured
            optimizer is not one of AdamWeightDecay, Lamb, Momentum.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    total_steps = steps_per_epoch * epoch_num
    warmup_steps = int(total_steps * 0.1)
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecay':
        adam_cfg = optimizer_cfg.AdamWeightDecay
        lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                       end_learning_rate=adam_cfg.end_learning_rate,
                                       warmup_steps=warmup_steps,
                                       decay_steps=total_steps,
                                       power=adam_cfg.power)
        params = network.trainable_params()
        decay_params = list(filter(adam_cfg.decay_filter, params))
        other_params = list(filter(lambda x: not adam_cfg.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=adam_cfg.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        lr_schedule = BertLearningRate(learning_rate=lamb_cfg.learning_rate,
                                       end_learning_rate=lamb_cfg.end_learning_rate,
                                       warmup_steps=warmup_steps,
                                       decay_steps=total_steps,
                                       power=lamb_cfg.power)
        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
        raise ValueError("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # Checkpoint once per epoch and keep at most one file.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="ner",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    # load checkpoint into network
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(steps_per_epoch), ckpoint_cb]
    # perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
    train_begin = time.perf_counter()
    model.train(epoch_num, dataset, callbacks=callbacks)
    train_end = time.perf_counter()
    print("latency: {:.6f} s".format(train_end - train_begin))
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """Fine-tune a classifier from a pretrained checkpoint (older optimizer API).

    The epoch count is taken from ``dataset.get_repeat_count()``. Optimizers
    with a schedule use warmup over 10% and decay over 100% of
    ``steps_per_epoch * epoch_num`` steps.

    Args:
        dataset: Training dataset providing ``get_dataset_size()`` and
            ``get_repeat_count()``.
        network: Network-with-loss to fine-tune.
        load_checkpoint_path (str): Pretrained checkpoint to load (required).
        save_checkpoint_path (str): Checkpoint directory; "" is passed as ``directory=None``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty or the configured
            optimizer is not one of AdamWeightDecayDynamicLR, Lamb, Momentum.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    total_steps = steps_per_epoch * epoch_num
    warmup_steps = int(total_steps * 0.1)
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        adam_cfg = optimizer_cfg.AdamWeightDecayDynamicLR
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=total_steps,
                                             learning_rate=adam_cfg.learning_rate,
                                             end_learning_rate=adam_cfg.end_learning_rate,
                                             power=adam_cfg.power,
                                             warmup_steps=warmup_steps,
                                             weight_decay=adam_cfg.weight_decay,
                                             eps=adam_cfg.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        optimizer = Lamb(network.trainable_params(),
                         decay_steps=total_steps,
                         start_learning_rate=lamb_cfg.start_learning_rate,
                         end_learning_rate=lamb_cfg.end_learning_rate,
                         power=lamb_cfg.power,
                         weight_decay=lamb_cfg.weight_decay,
                         warmup_steps=warmup_steps,
                         decay_filter=lamb_cfg.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
        raise ValueError("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")

    # Checkpoint once per epoch and keep at most one file.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    # Map the default "" to None, as the other do_train variants do.
    # Passing "" straight through is not a usable directory path for
    # ModelCheckpoint (verify against the pinned MindSpore version).
    ckpoint_cb = ModelCheckpoint(prefix="classifier",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    # load checkpoint into network
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
def _build_training_pipeline(config: GNMTConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build the GNMT training pipeline and start training.

    Args:
        config (GNMTConfig): Config of the GNMT model.
        pre_training_dataset (Dataset): Pre-training dataset; used in preference
            to ``fine_tune_dataset`` when both are given.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.

    Raises:
        ValueError: If neither a pre-training nor a fine-tuning dataset is given.
    """
    net_with_loss = GNMTNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)
    net_with_loss.init_parameters_data()
    _load_checkpoint_to_net(config, net_with_loss)

    dataset = pre_training_dataset if pre_training_dataset is not None else fine_tune_dataset
    if dataset is None:
        raise ValueError(
            "pre-training dataset or fine-tuning dataset must be provided one."
        )

    dataset_size = dataset.get_dataset_size()
    update_steps = config.epochs * dataset_size

    lr = _get_lr(config, update_steps)
    optimizer = _get_optimizer(config, net_with_loss, lr)

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(
        init_loss_scale=config.init_loss_scale,
        scale_factor=config.loss_scale_factor,
        scale_window=config.scale_window)
    net_with_grads = GNMTTrainOneStepWithLossScaleCell(
        network=net_with_loss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    time_cb = TimeMonitor(data_size=dataset_size)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.save_ckpt_steps,
        keep_checkpoint_max=config.keep_ckpt_max)

    callbacks = [time_cb, loss_monitor]

    # Checkpoint and summary callbacks are attached in two cases: when RANK_SIZE
    # is unset or 1, and, for multi-device runs, only on ranks where
    # MultiAscend.get_rank() % 8 == 0.
    rank_size = os.getenv('RANK_SIZE')
    if rank_size is None:
        write_outputs = True
    else:
        device_count = int(rank_size)
        write_outputs = device_count == 1 or (device_count > 1 and MultiAscend.get_rank() % 8 == 0)

    if write_outputs:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)
        summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
        callbacks.append(summary_callback)

    print(" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model, config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
def run_pretrain():
    """pre-train bert_clue

    Steps:
    * Parse options via ``argparse_init()``.
    * Configure the MindSpore context: graph mode, graph kernel, and optional
      data-parallel distribution.
    * Build the BERT pre-training dataset and ``BertNetworkWithLoss``.
    * Wrap the network in a training cell chosen by the loss-scale and
      gradient-accumulation options.
    * Pass the model through ``ConvertModelUtils().convert_to_thor_model``.
    * Run ``Model.train``.

    Note: ``args_opt.data_sink_steps``, ``args_opt.save_checkpoint_steps`` and
    ``args_opt.train_steps`` may be rewritten in place before they are used.
    """
    parser = argparse_init()
    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(
        args_opt.device_target, args_opt.enable_graph_kernel)
    _set_graph_kernel_context(args_opt.device_target,
                              args_opt.enable_graph_kernel,
                              is_auto_enable_graph_kernel)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            # Ascend: device count comes from the CLI, rank from device_id.
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            # Other targets: size and rank come from the communication group.
            device_num = D.get_group_size()
            rank = D.get_rank()
        # Per-rank checkpoint subdirectory "ckpt_<n>/", named from get_rank()
        # (not from the `rank` computed above).
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(
            get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        _set_bert_all_reduce_split()
    else:
        rank = 0
        device_num = 1

    _check_compute_type(args_opt, is_auto_enable_graph_kernel)

    if args_opt.accumulation_steps > 1:
        # With gradient accumulation, the sink size and checkpoint interval are
        # multiplied by accumulation_steps (mutating args_opt in place).
        logger.info("accumulation steps: {}".format(
            args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(
            cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(
                args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle,
                             args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # Epoch count handed to Model.train, in chunks of data_sink_steps (the sink_size).
    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size(
    ) // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        # An explicit train_steps (scaled by accumulation_steps) caps the repeat count.
        train_steps = args_opt.train_steps * \
            args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count,
                               train_steps // args_opt.data_sink_steps)
    else:
        # Otherwise derive train_steps from the epoch count and store it back on args_opt.
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size(
        ) // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    # _get_optimizer receives args_opt after train_steps has been finalized above.
    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [
        TimeMonitor(args_opt.data_sink_steps),
        LossCallBack(ds.get_dataset_size())
    ]
    # Attach the checkpoint callback only when enabled and only on device ids
    # that are multiples of min(8, device_num).
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_bert',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    # Optionally warm-start the network from an existing checkpoint.
    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        accumulation_steps = args_opt.accumulation_steps
        enable_global_norm = cfg.enable_global_norm
        if accumulation_steps <= 1:
            # GPU + AdamWeightDecay uses a dedicated loss-scale training cell.
            if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
            else:
                net_with_grads = BertTrainOneStepWithLossScaleCell(
                    net_with_loss,
                    optimizer=optimizer,
                    scale_update_cell=update_cell)
        else:
            # Cell choice depends on the run mode:
            # - non-distributed, or allreduce_post_accumulation == "true":
            #   the "AllReducePost" accumulation cell;
            # - otherwise: the "AllReduceEach" accumulation cell.
            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
            net_with_accumulation = (
                BertTrainAccumulationAllReducePostWithLossScaleCell
                if allreduce_post else
                BertTrainAccumulationAllReduceEachWithLossScaleCell)
            net_with_grads = net_with_accumulation(
                net_with_loss,
                optimizer=optimizer,
                scale_update_cell=update_cell,
                accumulation_steps=accumulation_steps,
                enable_global_norm=enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss,
                                              optimizer=optimizer)

    model = Model(net_with_grads)
    # NOTE(review): applied unconditionally. Presumably convert_to_thor_model
    # returns the model unchanged for non-THOR optimizers; confirm against
    # ConvertModelUtils.
    model = ConvertModelUtils().convert_to_thor_model(
        model,
        network=net_with_grads,
        optimizer=optimizer,
        frequency=cfg.Thor.frequency)
    model.train(new_repeat_count,
                ds,
                callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
def run_general_distill():
    """Run TinyBERT general distillation from a teacher checkpoint.

    Steps:
    * Parse this script's command-line options.
    * Configure the MindSpore context (optionally data-parallel).
    * Build the teacher/student ``BertNetworkWithLoss_gd`` and the dataset.
    * Train with AdamWeightDecay.

    On GPU the student is switched to float32 and loss scaling is disabled.
    ``netwithloss.bert`` is saved by ``ModelSaveCkpt`` into a timestamped
    directory under ``--save_ckpt_path``, suffixed with the rank when
    distributed.

    Raises:
        ValueError: If ``--dataset_type`` is neither "tfrecord" nor "mindrecord".
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=3, help="Epoch size, default is 3.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step", type=int, default=100,
                        help="Save checkpoint steps, default is 100.")
    parser.add_argument("--max_ckpt_num", type=int, default=1,
                        help="Max checkpoint number, default is 1.")
    parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path", type=str, default="",
                        help="Load teacher checkpoint file path")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    parser.add_argument("--dataset_type", type=str, default="tfrecord",
                        help="dataset type tfrecord/mindrecord, default is tfrecord")
    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                 datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        # Each rank gets its own checkpoint directory.
        save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          device_num=device_num)
    else:
        rank = 0
        device_num = 1

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_ckpt_dir, exist_ok=True)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning('Compute about the student only support float32 temporarily, run with float32.')
            bert_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=bert_student_net_cfg,
                                         is_training=True, use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
        raise ValueError("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd', common_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        # Express the step budget in chunks of data_sink_steps (the sink_size passed below).
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
                                   decay_steps=int(dataset_size * args_opt.epoch_size),
                                   power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)

    callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                              args_opt.max_ckpt_num, save_ckpt_dir)]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                                 scale_factor=common_cfg.scale_factor,
                                                 scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                  scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
def run_pretrain():
    """Pre-train BERT (bert_clue).

    Steps:
    * Parse command-line options.
    * Configure the MindSpore context. When distributed this is data-parallel,
      with an all-reduce fusion split chosen from ``bert_net_cfg`` depth and
      ``use_relative_positions``.
    * Build the dataset, network and optimizer (``cfg.optimizer``: Lamb,
      Momentum or AdamWeightDecay).
    * Build the training cell and run ``Model.train``.

    Raises:
        ValueError: If ``cfg.optimizer`` is not one of the supported optimizers.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", choices=["true", "false"],
                        help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--accumulation_steps", type=int, default=1,
                        help="Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000,
                        help="Save checkpoint steps, default is 1000.")
    parser.add_argument("--train_steps", type=int, default=-1,
                        help="Training Steps, default is -1, meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1,
                        help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="",
                        help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        # Per-rank checkpoint subdirectory.
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          device_num=device_num)
        # All-reduce fusion split points for 12- and 24-layer models.
        # No split is configured for other depths.
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(
                    all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
            else:
                context.set_auto_parallel_context(
                    all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                context.set_auto_parallel_context(
                    all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
            else:
                context.set_auto_parallel_context(
                    all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        # With gradient accumulation, the sink size and checkpoint interval are
        # multiplied by accumulation_steps.
        logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps))
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        # An explicit train_steps (scaled by accumulation_steps) caps the repeat count.
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
    else:
        # Derive train_steps from the epoch count; the LR schedules below use it as decay_steps.
        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
        logger.info("train steps: {}".format(args_opt.train_steps))

    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = net_with_loss.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]
        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]"
                         .format(cfg.optimizer))

    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(ds.get_dataset_size())]
    # Attach the checkpoint callback only on device ids that are multiples of min(8, device_num).
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        if args_opt.accumulation_steps <= 1:
            net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                               scale_update_cell=update_cell)
        else:
            accumulation_steps = args_opt.accumulation_steps
            net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                       scale_update_cell=update_cell,
                                                                       accumulation_steps=accumulation_steps,
                                                                       enable_global_norm=cfg.enable_global_norm)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
def run_general_distill():
    """Run TinyBERT general distillation from a teacher checkpoint.

    Options come from ``get_argument()``; ``device_id`` is applied to the
    context only on Ascend. Precision handling by target:

    * GPU: the student is switched to float32.
    * CPU: both teacher and student ``dtype``/``compute_type`` are switched
      to float32.
    * Loss scaling is disabled on both GPU and CPU.

    ``netwithloss.bert`` is saved by ``ModelSaveCkpt`` into a timestamped
    directory under ``save_ckpt_path``, suffixed with the rank when distributed.

    Raises:
        ValueError: If ``dataset_type`` is neither "tfrecord" nor "mindrecord".
    """
    args_opt = get_argument()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        reserve_class_name_in_scope=False)
    if args_opt.device_target == "Ascend":
        context.set_context(device_id=args_opt.device_id)

    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                 datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        # Each rank gets its own checkpoint directory.
        save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          device_num=device_num)
    else:
        rank = 0
        device_num = 1

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_ckpt_dir, exist_ok=True)

    enable_loss_scale = True
    if args_opt.device_target == "GPU":
        if bert_student_net_cfg.compute_type != mstype.float32:
            logger.warning('Compute about the student only support float32 temporarily, run with float32.')
            bert_student_net_cfg.compute_type = mstype.float32
        # Backward of the network are calculated using fp32,
        # and the loss scale is not necessary
        enable_loss_scale = False

    if args_opt.device_target == "CPU":
        logger.warning('CPU only support float32 temporarily, run with float32.')
        bert_teacher_net_cfg.dtype = mstype.float32
        bert_teacher_net_cfg.compute_type = mstype.float32
        bert_student_net_cfg.dtype = mstype.float32
        bert_student_net_cfg.compute_type = mstype.float32
        enable_loss_scale = False

    netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=bert_student_net_cfg,
                                         is_training=True, use_one_hot_embeddings=False)

    if args_opt.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif args_opt.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
        raise ValueError("dataset format is not supported yet")
    dataset = create_tinybert_dataset('gd', common_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
                                      data_type=dataset_type)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)
    print("dataset repeatcount: ", dataset.get_repeat_count())
    if args_opt.enable_data_sink == "true":
        # Express the step budget in chunks of data_sink_steps (the sink_size passed below).
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    lr_schedule = BertLearningRate(learning_rate=common_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=common_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(dataset_size * args_opt.epoch_size / 10),
                                   decay_steps=int(dataset_size * args_opt.epoch_size),
                                   power=common_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]

    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=common_cfg.AdamWeightDecay.eps)

    callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                              args_opt.max_ckpt_num, save_ckpt_dir)]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                                 scale_factor=common_cfg.scale_factor,
                                                 scale_window=common_cfg.scale_window)
        netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                  scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
def train():
    """Train CenterNet (multi-person pose) using the options in ``args_opt``.

    Configures the MindSpore context (data-parallel when ``--distribute`` is
    "true" on Ascend), builds the COCOHP training dataset, wraps
    ``CenterNetMultiPoseLossCell`` in a gradient cell, attaches timing, loss
    and (optionally) checkpoint callbacks, optionally restores
    ``--load_checkpoint_path`` and runs ``Model.train``.

    On any target other than Ascend, distribution, profiling and data sinking
    are forced off.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(save_graphs=False)

    ckpt_save_dir = args_opt.save_checkpoint_path
    rank = 0
    device_num = 1
    num_workers = 8
    if args_opt.device_target == "Ascend":
        context.set_context(enable_auto_mixed_precision=False)
        context.set_context(device_id=args_opt.device_id)
        if args_opt.distribute == "true":
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True,
                                              device_num=device_num)
            _set_parallel_all_reduce_split()
    else:
        args_opt.distribute = "false"
        args_opt.need_profiler = "false"
        args_opt.enable_data_sink = "false"

    # Start create dataset!
    # mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
    logger.info("Begin creating dataset for CenterNet")
    coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config,
                  enable_visual_image=(args_opt.visual_image == "true"),
                  save_path=args_opt.save_result_dir)
    dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix,
                                        batch_size=train_config.batch_size,
                                        device_num=device_num, rank=rank,
                                        num_parallel_workers=num_workers,
                                        do_shuffle=args_opt.do_shuffle == 'true')
    dataset_size = dataset.get_dataset_size()
    logger.info("Create dataset done!")

    net_with_loss = CenterNetMultiPoseLossCell(net_config)

    # The first argument of Model.train means different things per mode:
    # with data sinking every "epoch" runs `sink_size` steps, without it every
    # epoch is a full pass over the dataset.
    sink_mode = args_opt.enable_data_sink == "true"
    if sink_mode:
        new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        if args_opt.train_steps > 0:
            new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
        steps_per_repeat = args_opt.data_sink_steps
    else:
        new_repeat_count = args_opt.epoch_size
        if args_opt.train_steps > 0 and dataset_size > 0:
            # Only whole epochs are possible without sinking; run at least one.
            new_repeat_count = min(new_repeat_count, max(1, args_opt.train_steps // dataset_size))
        steps_per_repeat = dataset_size
    if args_opt.train_steps <= 0:
        args_opt.train_steps = args_opt.epoch_size * dataset_size
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(net_with_loss, dataset_size)

    enable_static_time = args_opt.device_target == "CPU"
    # TimeMonitor divides each Model.train epoch's time by this step count.
    callback = [TimeMonitor(steps_per_repeat), LossCallBack(dataset_size, enable_static_time)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)
    if args_opt.device_target == "Ascend":
        net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                    sens=train_config.loss_scale_value)
    else:
        net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=sink_mode,
                sink_size=args_opt.data_sink_steps)
def run_pretrain():
    """pre-train bert_clue

    Parses command-line options and configures the MindSpore context. When
    ``--distribute`` is "true", the context is data-parallel over Ascend or GPU
    devices; graph-kernel fusion is enabled when requested. Then builds the
    BERT pre-training dataset and network and runs ``Model.train``, with
    optional dynamic loss scaling and gradient accumulation.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
                        help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", choices=["true", "false"],
                        help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--accumulation_steps", type=int, default=1,
                        help="Accumulating gradients N times before weight update, default is 1.")
    parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
                        help="Whether to allreduce after accumulation of N steps or after each step, "
                             "default is true.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                                                                                "default is 1000.")
    parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
                                                                    "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    parser.add_argument("--enable_graph_kernel", type=str, default="auto", choices=["auto", "true", "false"],
                        help="Accelerate by graph kernel, default is auto.")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        if args_opt.device_target == 'Ascend':
            _set_bert_all_reduce_split()
    else:
        rank = 0
        device_num = 1

    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
    if args_opt.enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
            not is_auto_enable_graph_kernel:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    if args_opt.accumulation_steps > 1:
        logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
        logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps))
        # Step-based settings are expressed in optimizer updates; scale them to micro-steps.
        if args_opt.enable_data_sink == "true":
            args_opt.data_sink_steps *= args_opt.accumulation_steps
            logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
        if args_opt.enable_save_ckpt == "true":
            args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
            logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # With data sinking, each Model.train "epoch" runs `sink_size` steps.
    # Without it, each epoch is a full dataset pass.
    dataset_size = ds.get_dataset_size()
    sink_mode = args_opt.enable_data_sink == "true"
    if sink_mode:
        new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        steps_per_repeat = args_opt.data_sink_steps
    else:
        new_repeat_count = args_opt.epoch_size
        steps_per_repeat = dataset_size
    if args_opt.train_steps > 0:
        train_steps = args_opt.train_steps * args_opt.accumulation_steps
        if sink_mode:
            new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
        elif dataset_size > 0:
            # Only whole epochs are possible without sinking; run at least one.
            new_repeat_count = min(new_repeat_count, max(1, train_steps // dataset_size))
    else:
        args_opt.train_steps = args_opt.epoch_size * dataset_size // args_opt.accumulation_steps
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [TimeMonitor(steps_per_repeat), LossCallBack(dataset_size)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        accumulation_steps = args_opt.accumulation_steps
        enable_global_norm = cfg.enable_global_norm
        if accumulation_steps <= 1:
            if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
                net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
                                                                          scale_update_cell=update_cell)
            else:
                net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                   scale_update_cell=update_cell)
        else:
            # Single-device runs always reduce after accumulation; distributed
            # runs follow --allreduce_post_accumulation.
            allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
            net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
                                     BertTrainAccumulationAllReduceEachWithLossScaleCell)
            net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
                                                   scale_update_cell=update_cell,
                                                   accumulation_steps=accumulation_steps,
                                                   enable_global_norm=enable_global_norm)
    else:
        if args_opt.accumulation_steps > 1:
            # No accumulation cell exists without loss scaling; make that visible.
            logger.warning("Gradient accumulation is only implemented with loss scale enabled; "
                           "training without accumulation although accumulation_steps={}."
                           .format(args_opt.accumulation_steps))
        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=sink_mode, sink_size=args_opt.data_sink_steps)
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build training pipeline.

    Creates the transformer network with loss and initializes its trainable
    parameters. They come from ``config.existed_ckpt`` (``.npz`` or MindSpore
    checkpoint) when given; otherwise gamma is set to ones, beta and bias to
    zeros, and everything else via ``weight_variable``. Then builds the
    learning-rate schedule and optimizer selected by ``config`` and wraps the
    network in a dynamic-loss-scale training cell. Finally it hands the model
    and callbacks to ``_train``.

    Args:
        config (TransformerConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.

    Raises:
        ValueError: If a trainable parameter is missing from the checkpoint,
            if neither a pre-training nor a fine-tuning dataset is given, or
            if ``config.optimizer`` is not one of adam / lamb / momentum.
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()

    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            # np.load returns an NpzFile that keeps the archive open and
            # re-reads an array on every item access: load once, then close.
            with np.load(config.existed_ckpt) as npz:
                weights = {name: npz[name] for name in npz.files}
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in net_with_loss.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(f"Param {weights_name} is not found in ckpt file.")
            value = weights[weights_name]
            if isinstance(value, Parameter):
                param.default_input = value.default_input
            elif isinstance(value, Tensor):
                param.default_input = Tensor(value.asnumpy(), config.dtype)
            elif isinstance(value, np.ndarray):
                param.default_input = Tensor(value, config.dtype)
            else:
                param.default_input = value
    else:
        for param in net_with_loss.trainable_params():
            name = param.name
            value = param.default_input
            if isinstance(value, Tensor):
                shape = value.asnumpy().shape
                if name.endswith(".gamma"):
                    param.default_input = one_weight(shape)
                elif name.endswith(".beta") or name.endswith(".bias"):
                    param.default_input = zero_weight(shape)
                else:
                    param.default_input = weight_variable(shape)

    dataset = pre_training_dataset if pre_training_dataset is not None else fine_tune_dataset
    if dataset is None:
        raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")

    update_steps = dataset.get_repeat_count() * dataset.get_dataset_size()
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(lr=config.lr,
                                         update_num=update_steps,
                                         decay_start_step=config.decay_start_step,
                                         warmup_steps=config.warmup_steps,
                                         min_lr=config.min_lr), dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
                                               min_lr=config.min_lr,
                                               decay_steps=config.decay_steps,
                                               total_update_num=update_steps,
                                               warmup_steps=config.warmup_steps,
                                               power=config.poly_lr_scheduler_power), dtype=mstype.float32)
    else:
        lr = config.lr

    optimizer_name = config.optimizer.lower()
    if optimizer_name == "adam":
        optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif optimizer_name == "lamb":
        # NOTE(review): Lamb ignores the schedule above and uses its own
        # BertLearningRate with a hard-coded decay_steps=12000.
        lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
                              power=10.0, warmup_steps=config.warmup_steps)

        def _no_decay(param):
            # LayerNorm and bias parameters are excluded from weight decay.
            lowered = param.name.lower()
            return 'layernorm' in lowered or 'bias' in lowered

        trainable = net_with_loss.trainable_params()
        decay_params = [p for p in trainable if not _no_decay(p)]
        other_params = [p for p in trainable if _no_decay(p)]
        group_params = [{'params': decay_params, 'weight_decay': 0.01},
                        {'params': other_params}]
        optimizer = Lamb(group_params, lr, eps=1e-6)
    elif optimizer_name == "momentum":
        optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError(f"optimizer only support `adam`, `lamb` and `momentum` now, "
                         f"but got `{config.optimizer}`.")

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                            scale_factor=config.loss_scale_factor,
                                            scale_window=config.scale_window)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss,
                                                              optimizer=optimizer,
                                                              scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                                   keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    callbacks = [loss_monitor]
    single_device = rank_size is None or int(rank_size) == 1
    # Single-device runs always save; multi-device runs save only on ranks
    # that are a multiple of 8.
    if single_device or (int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0):
        ckpt_callback = ModelCheckpoint(prefix=config.ckpt_prefix,
                                        directory=os.path.join(config.ckpt_path,
                                                               'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
                                        config=ckpt_config)
        callbacks.append(ckpt_callback)

    print(" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model, config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
def run_pretrain():
    """pre-train bert_clue

    Variant trained through ``Model(..., frequency=cfg.Thor.frequency)``. It
    parses command-line options and configures the MindSpore context
    (data-parallel via ``D.init()`` when ``--distribute`` is "true"). Then it
    builds the BERT pre-training dataset and network and runs ``Model.train``
    with optional dynamic loss scaling.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=4, help="Device id, default is 4.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="false", help="Use lossscale or not, default is not.")
    parser.add_argument("--do_shuffle", type=str, default="false", help="Enable shuffle for dataset, default is false.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=100, help="Sink steps for each epoch, default is 100.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                                                                                "default is 1000.")
    parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
                                                                    "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id, save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    context.set_context(max_call_depth=3000)
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'

        context.reset_auto_parallel_context()
        _set_bert_all_reduce_split()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # With data sinking, each Model.train "epoch" runs `sink_size` steps.
    # Without it, each epoch is a full dataset pass.
    dataset_size = ds.get_dataset_size()
    sink_mode = args_opt.enable_data_sink == "true"
    if sink_mode:
        new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        if args_opt.train_steps > 0:
            new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
        steps_per_repeat = args_opt.data_sink_steps
    else:
        new_repeat_count = args_opt.epoch_size
        if args_opt.train_steps > 0 and dataset_size > 0:
            # Only whole epochs are possible without sinking; run at least one.
            new_repeat_count = min(new_repeat_count, max(1, args_opt.train_steps // dataset_size))
        steps_per_repeat = dataset_size
    if args_opt.train_steps <= 0:
        args_opt.train_steps = args_opt.epoch_size * dataset_size
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(args_opt, net_with_loss)
    callback = [TimeMonitor(steps_per_repeat), LossCallBack()]
    if args_opt.enable_save_ckpt == "true" and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        # An empty directory string is not a valid path for ModelCheckpoint;
        # None selects its default directory (same guard as the other entry points).
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                           scale_update_cell=update_cell)
    else:
        net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads, frequency=cfg.Thor.frequency)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=sink_mode, sink_size=args_opt.data_sink_steps)
def run_pretrain():
    """pre-train bert_clue

    Legacy variant. It parses command-line options and configures the
    MindSpore context; with ``--distribute`` "true" this means a data-parallel
    context (HCCL on Ascend, NCCL otherwise) with fixed all-reduce fusion
    split indices for 12/24-layer models. The dataset and repeat count come
    from ``create_bert_dataset``, and the optimizer named by
    ``cfg.optimizer`` is built before ``Model.train`` runs.

    Raises:
        ValueError: If ``cfg.optimizer`` is not Lamb, Momentum or
            AdamWeightDecayDynamicLR.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                                                                                "default is 1000.")
    parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
                                                                    "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")
    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        # Fixed gradient all-reduce fusion boundaries for the supported model depths.
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                                               args_opt.enable_data_sink, args_opt.data_sink_steps,
                                               args_opt.data_dir, args_opt.schema_dir)
    if args_opt.train_steps > 0:
        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    dataset_size = ds.get_dataset_size()
    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=dataset_size * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=dataset_size * new_repeat_count,
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps,
                                             warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                         format(cfg.optimizer))

    callback = [TimeMonitor(dataset_size), LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        # An empty directory string is not a valid path for ModelCheckpoint;
        # None selects its default directory (same guard as the other entry points).
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                         scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
def train():
    '''
    finetune function

    Builds a BertNER (sequence labelling) or BertCLS (classification) network
    according to ``cfg.task`` and loads ``cfg.pre_training_ckpt`` into it.
    It then fine-tunes for ``cfg.epoch_num`` epochs with the optimizer named
    by ``cfg.optimizer`` and a dynamic loss scale. Checkpoints are written
    every ``steps_per_epoch`` steps, keeping only the latest one.

    Raises:
        Exception: For an unsupported ``cfg.task`` or ``cfg.optimizer``.
    '''
    # BertCLS train for classification
    # BertNER train for sequence labeling
    if cfg.task == 'NER':
        tag_to_index = None
        if cfg.use_crf:
            # JSON files are UTF-8 by specification; close the handle promptly.
            with open(cfg.label2id_file, encoding="utf-8") as label_file:
                tag_to_index = json.load(label_file)
            print(tag_to_index)
            # The CRF layer needs two extra tags for sequence start/stop.
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels
        netwithloss = BertNER(bert_net_cfg, cfg.batch_size, True, num_labels=number_labels,
                              use_crf=cfg.use_crf, tag_to_index=tag_to_index, dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    dataset = get_dataset(data_file=cfg.data_file, batch_size=cfg.batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:', steps_per_epoch)

    # optimizer
    total_steps = steps_per_epoch * cfg.epoch_num
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(total_steps * 0.1),
                                       decay_steps=total_steps,
                                       power=optimizer_cfg.AdamWeightDecay.power)
        params = netwithloss.trainable_params()
        decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
                                       end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(total_steps * 0.1),
                                       decay_steps=total_steps,
                                       power=optimizer_cfg.Lamb.power)
        optimizer = Lamb(netwithloss.trainable_params(), learning_rate=lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(steps_per_epoch), ckpoint_cb]
    model.train(cfg.epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=True)
def run_task_distill(ckpt_file):
    """
    run task distill

    Second task-distillation phase. It trains the student initialized from
    ``ckpt_file`` against the teacher at ``args_opt.load_teacher_ckpt_path``
    on the task training data. Depending on ``--do_eval``, it either
    evaluates with ``EvalCallBack`` or saves student checkpoints with
    ``ModelSaveCkpt``.

    Args:
        ckpt_file (str): Path of the student checkpoint to start from.

    Raises:
        ValueError: If ``ckpt_file`` is empty or None.
    """
    if not ckpt_file:
        raise ValueError("Student ckpt file should not be None")
    cfg = phase2_cfg

    load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
    load_student_checkpoint_path = ckpt_file
    netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                         student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                         is_training=True, task_type=args_opt.task_type,
                                         num_labels=task.num_labels, is_predistill=False)
    rank = 0
    device_num = 1
    train_dataset = create_tinybert_dataset('td', cfg.batch_size, device_num, rank, args_opt.do_shuffle,
                                            args_opt.train_data_dir, args_opt.schema_dir, data_type=dataset_type)

    dataset_size = train_dataset.get_dataset_size()
    print('td2 train dataset size: ', dataset_size)
    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
    if args_opt.enable_data_sink == 'true':
        # Each sunk "epoch" of Model.train runs data_sink_steps steps.
        repeat_count = args_opt.td_phase2_epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.td_phase2_epoch_size
        time_monitor_steps = dataset_size

    optimizer_cfg = cfg.optimizer_cfg
    total_steps = dataset_size * args_opt.td_phase2_epoch_size
    lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
                                   end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
                                   warmup_steps=int(total_steps / 10),
                                   decay_steps=int(total_steps),
                                   power=optimizer_cfg.AdamWeightDecay.power)
    params = netwithloss.trainable_params()
    decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

    if args_opt.do_eval.lower() == "true":
        # The eval set is only consumed by EvalCallBack, so build it only when evaluating.
        eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size, device_num, rank, args_opt.do_shuffle,
                                               args_opt.eval_data_dir, args_opt.schema_dir, data_type=dataset_type)
        print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    EvalCallBack(netwithloss.bert, eval_dataset)]
    else:
        callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
                    ModelSaveCkpt(netwithloss.bert,
                                  args_opt.save_ckpt_step,
                                  args_opt.max_ckpt_num,
                                  td_phase2_save_ckpt_dir)]
    if enable_loss_scale:
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                       scale_update_cell=update_cell)
    else:
        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
    model = Model(netwithgrads)
    model.train(repeat_count, train_dataset, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                sink_size=args_opt.data_sink_steps)
def run_general_distill():
    """
    Run TinyBERT general distillation from the command line.

    Parses command-line options and configures the MindSpore context,
    optionally data-parallel across devices. It then builds the
    teacher/student distillation network and the 'gd' dataset, and trains
    with AdamWeightDecay under a dynamic loss scale. Student checkpoints are
    written by ModelSaveCkpt into a timestamped sub-directory of
    ``--save_ckpt_path``.
    """
    parser = argparse.ArgumentParser(description='tinybert general distill')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false",
                        help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=3,
                        help="Epoch size, default is 3.")
    parser.add_argument("--device_id", type=int, default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1,
                        help="Use device nums, default is 1.")
    parser.add_argument("--save_ckpt_step", type=int, default=100,
                        help="Checkpoint saving interval in steps, default is 100.")
    parser.add_argument("--max_ckpt_num", type=int, default=1,
                        help="Maximum number of checkpoints to keep, default is 1.")
    parser.add_argument("--do_shuffle", type=str, default="true",
                        help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true",
                        help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1,
                        help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_ckpt_path", type=str, default="",
                        help="Save checkpoint path")
    parser.add_argument("--load_teacher_ckpt_path", type=str, default="",
                        help="Load checkpoint file path")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                 datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    # exist_ok avoids a check-then-create race when several ranks start in the
    # same second and share the same timestamped directory.
    os.makedirs(save_ckpt_dir, exist_ok=True)

    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            # GPU collectives use NCCL; rank and group size come from the
            # communication backend rather than the command line.
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True, device_num=device_num)
    else:
        rank = 0
        device_num = 1

    netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                         teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                         student_config=bert_student_net_cfg,
                                         is_training=True, use_one_hot_embeddings=False)

    dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
    dataset_size = dataset.get_dataset_size()
    print('dataset size: ', dataset_size)

    data_sink = args_opt.enable_data_sink == "true"
    if data_sink:
        # With sink_size=data_sink_steps each Model.train "epoch" covers
        # data_sink_steps steps; rescale so the total step count is unchanged.
        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
        time_monitor_steps = args_opt.data_sink_steps
    else:
        repeat_count = args_opt.epoch_size
        time_monitor_steps = dataset_size

    adam_cfg = common_cfg.AdamWeightDecay
    total_steps = dataset_size * args_opt.epoch_size
    lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                   end_learning_rate=adam_cfg.end_learning_rate,
                                   warmup_steps=int(total_steps / 10),
                                   decay_steps=int(total_steps),
                                   power=adam_cfg.power)

    # Both filters use common_cfg: there is no `cfg` in this function's scope.
    params = netwithloss.trainable_params()
    decay_params = list(filter(adam_cfg.decay_filter, params))
    other_params = [p for p in params if not adam_cfg.decay_filter(p)]
    group_params = [
        {'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
        {'params': other_params, 'weight_decay': 0.0},
        {'order_params': params},
    ]
    optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=adam_cfg.eps)

    callback = [
        TimeMonitor(time_monitor_steps),
        LossCallBack(),
        ModelSaveCkpt(netwithloss.bert, args_opt.save_ckpt_step,
                      args_opt.max_ckpt_num, save_ckpt_dir),
    ]

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
                                             scale_factor=common_cfg.scale_factor,
                                             scale_window=common_cfg.scale_window)
    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer,
                                              scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(repeat_count, dataset, callbacks=callback,
                dataset_sink_mode=data_sink,
                sink_size=args_opt.data_sink_steps)
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None,
                             platform="Ascend"):
    """
    Build training pipeline.

    Creates the loss network, loads any configured checkpoint, and builds the
    optimizer and loss-scaled training cell. It then assembles the callbacks
    (timing, loss logging, checkpointing) and hands everything to ``_train``.

    Args:
        config (TransformerConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
        platform (str): Target platform name; not used by this function.

    Raises:
        ValueError: If neither ``pre_training_dataset`` nor
            ``fine_tune_dataset`` is provided.
    """
    net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
    net_with_loss.init_parameters_data()
    _load_checkpoint_to_net(config, net_with_loss)

    # The pre-training dataset takes precedence when both are supplied.
    dataset = pre_training_dataset if pre_training_dataset is not None else fine_tune_dataset
    if dataset is None:
        raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")

    update_steps = config.epochs * dataset.get_dataset_size()
    lr = _get_lr(config, update_steps)
    optimizer = _get_optimizer(config, net_with_loss, lr)

    # Loss scale: dynamic when configured, otherwise a fixed scale that drops
    # updates which overflow.
    if config.loss_scale_mode == "dynamic":
        scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                                scale_factor=config.loss_scale_factor,
                                                scale_window=config.scale_window)
    else:
        scale_manager = FixedLossScaleManager(loss_scale=config.init_loss_scale,
                                              drop_overflow_update=True)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(
        network=net_with_loss, optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)

    time_cb = TimeMonitor(data_size=dataset.get_dataset_size())
    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                                   keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    # Anything that is not a multi-device run (unset, 1, or a non-positive
    # value) takes the single-device path, so loss logging and checkpointing
    # are always configured.
    is_distributed = rank_size is not None and int(rank_size) > 1

    callbacks = [time_cb]
    if is_distributed:
        rank_id = MultiAscend.get_rank()
        callbacks.append(LossCallBack(config, rank_id=rank_id))
        # Only every 8th rank writes checkpoints (presumably one per 8-device
        # host -- confirm against the launch setup).
        if rank_id % 8 == 0:
            ckpt_callback = ModelCheckpoint(
                prefix=config.ckpt_prefix,
                directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(rank_id)),
                config=ckpt_config)
            callbacks.append(ckpt_callback)
    else:
        device_id = os.getenv('DEVICE_ID')
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(device_id)),
            config=ckpt_config)
        callbacks.append(LossCallBack(config, rank_id=device_id))
        callbacks.append(ckpt_callback)

    print(" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model, config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)