def bert_withlossscale_manager_train_feed():
    """Build a BERT training graph with dynamic loss scaling and execute it.

    The BERT version and batch size come from the ``VERSION`` (default
    ``'base'``) and ``BATCH_SIZE`` (default ``'1'``) environment variables.
    A float32 ones tensor of shape ``[1]`` is appended to the loaded test
    data as the last network input, and the graph is built with
    ``execute=True``.
    """

    class ModelBert(nn.Cell):
        """Wrap a network in ``BertTrainOneStepWithLossScaleCell``."""

        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            loss_scale_manager = DynamicLossScaleManager()
            scale_updater = LossScaleUpdateCell(loss_scale_manager)
            self.train_network = BertTrainOneStepWithLossScaleCell(
                network, self.optimizer, scale_update_cell=scale_updater)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3,
                                      arg4, arg5, arg6, arg7)

    bert_version = os.getenv('VERSION', 'base')
    batch = int(os.getenv('BATCH_SIZE', '1'))
    sens = Tensor(np.ones([1]).astype(np.float32))
    feed = load_test_data(batch) + (sens,)
    bert_config = get_config(version=bert_version, batch_size=batch)
    loss_net = BertNetworkWithLoss(bert_config, True)
    adam = AdamWeightDecayDynamicLR(loss_net.trainable_params(), 10)
    train_cell = ModelBert(loss_net, optimizer=adam)
    train_cell.set_train()
    build_construct_graph(train_cell, *feed, execute=True)
def test_bert_train():
    """Build (without executing) a BERT one-step training graph.

    The BERT version and batch size come from the ``VERSION`` (default
    ``'large'``) and ``BATCH_SIZE`` (default ``'1'``) environment variables.
    The graph is built with ``execute=False``.
    """

    class ModelBert(nn.Cell):
        """Wrap a network in ``BertTrainOneStepCell``."""

        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3,
                                      arg4, arg5, arg6)

    bert_version = os.getenv('VERSION', 'large')
    batch = int(os.getenv('BATCH_SIZE', '1'))
    feed = load_test_data(batch)
    bert_config = get_config(version=bert_version, batch_size=batch)
    loss_net = BertNetworkWithLoss(bert_config, True)
    adam = AdamWeightDecayDynamicLR(loss_net.trainable_params(), 10)
    train_cell = ModelBert(loss_net, optimizer=adam)
    train_cell.set_train()
    build_construct_graph(train_cell, *feed, execute=False)
def test_AdamWeightDecayDynamicLR():
    """Compile a one-step training graph driven by AdamWeightDecayDynamicLR.

    Uses a ones input of shape [1, 64] and a zeros label of shape [1, 10];
    the graph is only compiled, not run.
    """
    features = Tensor(np.ones([1, 64]).astype(np.float32))
    targets = Tensor(np.zeros([1, 10]).astype(np.float32))

    model = Net()
    model.set_train()

    criterion = nn.SoftmaxCrossEntropyWithLogits()
    adam = AdamWeightDecayDynamicLR(
        model.trainable_params(),
        decay_steps=20,
        learning_rate=0.1,
    )

    loss_net = WithLossCell(model, criterion)
    step_cell = TrainOneStepCell(loss_net, adam)
    _executor.compile(step_cell, features, targets)
def test_AdamWeightDecayDynamicLR():
    """Compile AdamWeightDecayDynamicLR training under the parallel optimizer.

    The auto-parallel context is set to data-parallel mode with two devices
    and ``enable_parallel_optimizer=True``, then a one-step training graph
    is compiled for a [32, 128] ones input and a [32, 768] zeros label.
    """
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      device_num=2,
                                      enable_parallel_optimizer=True)
    try:
        inputs = Tensor(np.ones([32, 128]).astype(np.float32))
        label = Tensor(np.zeros([32, 768]).astype(np.float32))
        net = Net()
        net.set_train()
        loss = nn.SoftmaxCrossEntropyWithLogits()
        optimizer = AdamWeightDecayDynamicLR(net.trainable_params(),
                                             decay_steps=20,
                                             learning_rate=0.1)
        net_with_loss = WithLossCell(net, loss)
        train_network = TrainOneStepCell(net_with_loss, optimizer)
        _executor.compile(train_network, inputs, label)
    finally:
        # The auto-parallel context is process-global: reset it even when
        # compilation fails so the settings do not leak into other tests.
        context.reset_auto_parallel_context()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """Fine-tune ``network`` on ``dataset`` starting from a checkpoint.

    Args:
        dataset: Dataset providing ``get_dataset_size()`` (steps per epoch)
            and ``get_repeat_count()`` (number of epochs).
        network: Network whose trainable parameters are optimized and into
            which the checkpoint is loaded.
        load_checkpoint_path: Checkpoint to load; must not be empty.
        save_checkpoint_path: Directory handed to ``ModelCheckpoint``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is the empty string.
        Exception: If ``optimizer_cfg.optimizer`` is not one of
            AdamWeightDecayDynamicLR, Lamb or Momentum.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()

    # Build the optimizer selected in the configuration. Warm-up covers the
    # first 10% of the total step count for the schedules that support it.
    chosen = optimizer_cfg.optimizer
    if chosen == 'AdamWeightDecayDynamicLR':
        adam_cfg = optimizer_cfg.AdamWeightDecayDynamicLR
        total_steps = steps_per_epoch * epoch_num
        optimizer = AdamWeightDecayDynamicLR(
            network.trainable_params(),
            decay_steps=total_steps,
            learning_rate=adam_cfg.learning_rate,
            end_learning_rate=adam_cfg.end_learning_rate,
            power=adam_cfg.power,
            warmup_steps=int(total_steps * 0.1),
            weight_decay=adam_cfg.weight_decay,
            eps=adam_cfg.eps)
    elif chosen == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        total_steps = steps_per_epoch * epoch_num
        optimizer = Lamb(
            network.trainable_params(),
            decay_steps=total_steps,
            start_learning_rate=lamb_cfg.start_learning_rate,
            end_learning_rate=lamb_cfg.end_learning_rate,
            power=lamb_cfg.power,
            weight_decay=lamb_cfg.weight_decay,
            warmup_steps=int(total_steps * 0.1),
            decay_filter=lamb_cfg.decay_filter)
    elif chosen == 'Momentum':
        momentum_cfg = optimizer_cfg.Momentum
        optimizer = Momentum(
            network.trainable_params(),
            learning_rate=momentum_cfg.learning_rate,
            momentum=momentum_cfg.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")

    # Save one checkpoint per epoch, keeping only the most recent file.
    save_policy = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    checkpoint_saver = ModelCheckpoint(prefix="classifier",
                                       directory=save_checkpoint_path,
                                       config=save_policy)

    # Load the pretrained weights into the network.
    pretrained = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, pretrained)

    scale_updater = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                               scale_factor=2,
                                               scale_window=1000)
    finetune_cell = BertFinetuneCell(network,
                                     optimizer=optimizer,
                                     scale_update_cell=scale_updater)
    trainer = Model(finetune_cell)
    monitors = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(),
        checkpoint_saver,
    ]
    trainer.train(epoch_num, dataset, callbacks=monitors)
def test_train():
    """Fine-tune a BERT task network as configured by ``cfg`` and ``args_opt``.

    Steps, in order:
      1. Configure graph mode for the device in ``args_opt.device_target``
         ("Ascend" or "GPU"). On Ascend the device id is read from the
         ``DEVICE_ID`` environment variable, defaulting to 0 when unset.
      2. Build the task network: ``BertNER`` for task 'NER' (with or without
         CRF), ``BertSquad`` for 'SQUAD', ``BertCLS`` otherwise.
      3. Load the matching dataset and build the optimizer named by
         ``cfg.optimizer``; warm-up is 10% of the total step count.
      4. Load ``cfg.pre_training_ckpt`` into the network and train for
         ``cfg.epoch_num`` epochs with dynamic loss scaling, saving one
         checkpoint per epoch.

    Raises:
        Exception: If the device target or the optimizer is not supported.
    """
    target = args_opt.device_target
    if target == "Ascend":
        # Fall back to device 0 when DEVICE_ID is unset; int(None) would
        # otherwise fail with an opaque TypeError.
        devid = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    else:
        raise Exception("Target error, GPU or Ascend is supported.")

    # BertCLS for classification, BertNER for sequence labeling,
    # BertSquad for question answering.
    if cfg.task == 'NER':
        if cfg.use_crf:
            netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
                                  tag_to_index=tag_to_index, dropout_prob=0.1)
        else:
            netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
    elif cfg.task == 'SQUAD':
        netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
    else:
        netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)

    if cfg.task == 'SQUAD':
        dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    else:
        dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)

    # optimizer
    steps_per_epoch = dataset.get_dataset_size()
    total_steps = steps_per_epoch * cfg.epoch_num
    if cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=total_steps,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            warmup_steps=int(total_steps * 0.1),
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    elif cfg.optimizer == 'Lamb':
        optimizer = Lamb(
            netwithloss.trainable_params(),
            decay_steps=total_steps,
            start_learning_rate=cfg.Lamb.start_learning_rate,
            end_learning_rate=cfg.Lamb.end_learning_rate,
            power=cfg.Lamb.power,
            weight_decay=cfg.Lamb.weight_decay,
            warmup_steps=int(total_steps * 0.1),
            decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            netwithloss.trainable_params(),
            learning_rate=cfg.Momentum.learning_rate,
            momentum=cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # Save one checkpoint per epoch, keeping only the most recent file.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    if cfg.task == 'SQUAD':
        netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
def run_pretrain(args_opt):
    """Pre-train BERT using the options in ``args_opt``.

    The module-level ``device_id`` and ``device_num`` are copied onto
    ``args_opt`` before the dataset is synced and the context configured.
    ``device_num`` is declared global, so the assignments below rebind the
    module-level value.

    Args:
        args_opt: Option namespace. This function reads ``data_url``,
            ``device_target``, ``distribute``, ``save_checkpoint_path``,
            ``load_checkpoint_path``, ``epoch_size``, ``do_shuffle``,
            ``enable_data_sink``, ``data_sink_steps``, ``data_dir``,
            ``schema_dir``, ``train_steps``, ``enable_save_ckpt``,
            ``save_checkpoint_steps``, ``save_checkpoint_num`` and
            ``enable_lossscale``.

    Raises:
        ValueError: If ``cfg.optimizer`` is not Lamb, Momentum or
            AdamWeightDecayDynamicLR.
    """
    global device_id
    global device_num
    global rank_id
    global job_id

    # Local import: only needed to build the per-rank checkpoint directory.
    import os

    args_opt.device_id = device_id
    args_opt.device_num = device_num
    sync_dataset(args_opt.data_url)

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            # Join path components instead of concatenating strings so a
            # save path without a trailing separator still yields
            # "<path>/ckpt_<rank>/" rather than "<path>ckpt_<rank>/". The
            # empty last component keeps the trailing separator.
            ckpt_save_dir = os.path.join(args_opt.save_checkpoint_path,
                                         'ckpt_' + str(rank), '')

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        # All-reduce fusion split points depend on the layer count and on
        # whether relative position embeddings are used.
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size,
                                               device_num, rank,
                                               args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    if args_opt.train_steps > 0:
        # Cap the repeat count so at most train_steps steps are run.
        new_repeat_count = min(
            new_repeat_count,
            args_opt.train_steps // args_opt.data_sink_steps)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
            .format(cfg.optimizer))

    callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
    print("Enable save checkpoint: ", args_opt.enable_save_ckpt)
    print("Rank ID: ", rank_id)
    # Only ranks whose id is a multiple of device_num save checkpoints.
    if args_opt.enable_save_ckpt == "true" and rank_id % device_num == 0:
        print("Enable save checkpoint")
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
def run_pretrain():
    """Parse command-line options and pre-train BERT on Ascend.

    Boolean-like options are passed as the strings "true"/"false" and are
    compared against "true". With ``--distribute true`` the run is set up
    as data-parallel over ``--device_num`` devices, with the rank derived
    from ``--device_id``; otherwise a single device with rank 0 is used.

    Raises:
        ValueError: If ``cfg.optimizer`` is not Lamb, Momentum or
            AdamWeightDecayDynamicLR.
    """
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    # The help text used to claim "default is not" although the default
    # value is "true".
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank,
                             args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir,
                             args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError(
            "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
            .format(cfg.optimizer))

    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(), ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
def run_pretrain():
    """Parse command-line options and pre-train BERT on Ascend or GPU.

    Boolean-like options are passed as the strings "true"/"false" and are
    compared against "true". With ``--distribute true`` the communication
    backend is initialised ('hccl' on Ascend, 'nccl' otherwise) and the run
    is configured as data-parallel; in the 'nccl' branch checkpoints go to
    a per-rank ``ckpt_<rank>`` subdirectory of ``--save_checkpoint_path``.
    ``--train_steps`` greater than 0 caps the number of dataset repeats.

    Raises:
        ValueError: If ``cfg.optimizer`` is not Lamb, Momentum or
            AdamWeightDecayDynamicLR.
    """
    # Local import: only needed to build the per-rank checkpoint directory.
    import os

    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    # The help text used to claim "default is not" although the default
    # value is "true".
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
                        "meaning run all steps according to epoch number.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target,
                        device_id=args_opt.device_id)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(variable_memory_max_size="30GB")

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        if args_opt.device_target == 'Ascend':
            D.init('hccl')
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = D.get_rank()
            # Join path components instead of concatenating strings so a
            # save path without a trailing separator still yields
            # "<path>/ckpt_<rank>/" rather than "<path>ckpt_<rank>/". The
            # empty last component keeps the trailing separator.
            ckpt_save_dir = os.path.join(args_opt.save_checkpoint_path,
                                         'ckpt_' + str(rank), '')

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            device_num=device_num)
        from mindspore.parallel._auto_parallel_context import auto_parallel_context
        # All-reduce fusion split points depend on the layer count and on
        # whether relative position embeddings are used.
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [29, 58, 87, 116, 145, 174, 203, 217])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [28, 55, 82, 109, 136, 163, 190, 205])
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [30, 90, 150, 210, 270, 330, 390, 421])
            else:
                auto_parallel_context().set_all_reduce_fusion_split_indices(
                    [38, 93, 148, 203, 258, 313, 368, 397])
    else:
        rank = 0
        device_num = 1

    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
        bert_net_cfg.compute_type = mstype.float32

    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size,
                                               device_num, rank,
                                               args_opt.do_shuffle,
                                               args_opt.enable_data_sink,
                                               args_opt.data_sink_steps,
                                               args_opt.data_dir,
                                               args_opt.schema_dir)
    data_epoch_size = new_repeat_count // args_opt.epoch_size  # Epoch nums in one dataset.
    if args_opt.train_steps > 0:
        # Cap the repeat count so at most train_steps steps are run.
        new_repeat_count = min(
            new_repeat_count,
            args_opt.train_steps // args_opt.data_sink_steps)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * new_repeat_count,
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(
            netwithloss.trainable_params(),
            decay_steps=ds.get_dataset_size() * new_repeat_count,
            learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
            end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
            power=cfg.AdamWeightDecayDynamicLR.power,
            weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
            eps=cfg.AdamWeightDecayDynamicLR.eps,
            warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps)
    else:
        raise ValueError(
            "Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
            .format(cfg.optimizer))

    callback = [
        TimeMonitor(ds.get_dataset_size()),
        LossCallBack(data_epoch_size)
    ]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args_opt.save_checkpoint_steps,
            keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
                                     directory=ckpt_save_dir,
                                     config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(
            loss_scale_value=cfg.loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(new_repeat_count, ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
def test_AdamWeightDecayDynamicLR_init(self):
    """Passing the float 0.5 as the first argument must raise TypeError."""
    bad_first_arg = 0.5
    second_arg = 10
    with pytest.raises(TypeError):
        AdamWeightDecayDynamicLR(bad_first_arg, second_arg)
def test_AdamWeightDecayDynamicLR_init(self):
    """Passing None as the first argument must raise ValueError."""
    bad_first_arg = None
    second_arg = 10
    with pytest.raises(ValueError):
        AdamWeightDecayDynamicLR(bad_first_arg, second_arg)
def __init__(self, network):
    """Store the network, its parameters, an optimizer and a sens tensor.

    Args:
        network: Network to wrap; its ``get_parameters()`` result is
            collected into a ``ParameterTuple`` and handed to
            ``AdamWeightDecayDynamicLR`` together with the value 10.
    """
    super(TrainStepWrapForAdamDynamicLr, self).__init__()
    self.network = network
    net_params = network.get_parameters()
    self.weights = ParameterTuple(net_params)
    self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10)
    # Float32 ones tensor of shape (1, 10), stored as ``sens``.
    ones_array = np.ones(shape=(1, 10)).astype(np.float32)
    self.sens = Tensor(ones_array)