def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset(cfg.bert_config.batch_size)
    config = cfg.bert_config
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=cfg.decay_steps,
                     start_learning_rate=cfg.start_learning_rate,
                     end_learning_rate=cfg.end_learning_rate,
                     power=cfg.power,
                     warmup_steps=cfg.num_warmup_steps,
                     decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck)
    model.train(ds.get_repeat_count(), ds,
                callbacks=[parallel_callback, ckpoint_cb],
                dataset_sink_mode=False)
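# Several of these cases depend on a `ModelCallback` helper that records the
# per-step training outputs (the loss, and for the loss-scale tests also the
# overflow flag and the current loss scale). Its definition is not part of this
# excerpt; the following is a minimal sketch only, assuming the loss-scale
# training cell returns a (loss, overflow, loss_scale) tuple.
from mindspore.train.callback import Callback


class ModelCallback(Callback):
    """Sketch: collect per-step outputs so the tests can compare them later."""

    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []
        self.overflow_list = []
        self.lossscale_list = []

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        outputs = cb_params.net_outputs
        if isinstance(outputs, (tuple, list)):
            # Assumed output layout of BertTrainOneStepWithLossScaleCell.
            self.loss_list.append(outputs[0].asnumpy())
            self.overflow_list.append(outputs[1].asnumpy())
            self.lossscale_list.append(outputs[2].asnumpy())
        else:
            self.loss_list.append(outputs.asnumpy())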
def bert_withlossscale_manager_train_feed():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            manager = DynamicLossScaleManager()
            update_cell = LossScaleUpdateCell(manager)
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                   scale_update_cell=update_cell)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    scaling_sens = Tensor(np.ones([1]).astype(np.float32))
    inputs = load_test_data(batch_size) + (scaling_sens,)
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)
def test_bert_train():
    """ the main function """
    class ModelBert(nn.Cell):
        """ ModelBert definition """
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    inputs = load_test_data(batch_size)
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=False)
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171,
                  12.603289, 12.165594, 12.824818, 12.38842, 12.604046]
    logger.info("expected loss value output: {}".format(expect_out))
    assert allclose(loss_value, expect_out, 0.00001, 0.00001)
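# The re-initialization loop above uses a `weight_variable` helper that is not
# shown in this excerpt. A minimal sketch, assuming it simply returns a
# deterministic float32 Tensor of the requested shape:
def weight_variable(shape):
    """Sketch: small uniform random values, seeded for reproducibility."""
    np.random.seed(1)
    values = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
    return Tensor(values)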
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    parallel_callback = ModelCallback()
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=10000,
                     start_learning_rate=1e-4,
                     end_learning_rate=0.0,
                     power=10.0,
                     warmup_steps=0,
                     decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152,
                  11.203759, 10.841681, 10.826849, 10.616718, 10.486609]
    logger.info("expected loss value output: {}".format(expect_out))
    assert allclose(loss_value, expect_out, 0.001, 0.001)
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000,
                        help="Save checkpoint steps, default is 1000.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                             args_opt.enable_data_sink, args_opt.data_sink_steps,
                             args_opt.data_dir, args_opt.schema_dir)

    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
                         .format(cfg.optimizer))

    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                         scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(), ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
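# `run_pretrain` reads its optimizer settings from a `cfg` object (and the
# network shape from `bert_net_cfg`) imported from a config module that is not
# included in this excerpt. A hypothetical sketch of that object's layout using
# easydict; the field names mirror the attributes accessed above, while the
# values are illustrative placeholders rather than the project's defaults.
from easydict import EasyDict as edict

cfg = edict({
    'optimizer': 'Lamb',
    'loss_scale_value': 2 ** 32,
    'scale_factor': 2,
    'scale_window': 1000,
    'Lamb': edict({
        'start_learning_rate': 3e-4,
        'end_learning_rate': 1e-7,
        'power': 5.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
    'AdamWeightDecayDynamicLR': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 1e-7,
        'power': 5.0,
        'weight_decay': 1e-5,
        'eps': 1e-6,
    }),
})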
def load_test_data():
    dataset = get_dataset()
    return dataset.next()


input_ids, input_mask, token_type_id, \
    next_sentence_labels, masked_lm_positions, \
    masked_lm_ids, masked_lm_weights = load_test_data()

test_sets = [
    ('BertNetworkWithLoss_1', {
        'block': BertNetworkWithLoss(BertConfig(batch_size=1), False, use_one_hot_embeddings=True),
        'desc_inputs': [input_ids, input_mask, token_type_id, next_sentence_labels,
                        masked_lm_positions, masked_lm_ids, masked_lm_weights],
        'desc_bprop': [[1]]}),
    ('BertNetworkWithLoss_2', {
        'block': BertNetworkWithLoss(BertConfig(batch_size=1), False, True),
        'desc_inputs': [input_ids, input_mask, token_type_id, next_sentence_labels,
                        masked_lm_positions, masked_lm_ids, masked_lm_weights],
        'desc_bprop': [[1]]
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)

    # The assertions below fail if the loss value, overflow state, or loss-scale value is wrong.
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092,
                         12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [True, True, False, False, False, True, False, False, False, True]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0,
                         16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
{
    'id': 'BertNetworkWithLoss',
    'group': 'bert',
    'block': BertNetworkWithLoss(config=BertConfig(batch_size=1,
                                                   seq_length=128,
                                                   vocab_size=21128,
                                                   hidden_size=1024,
                                                   num_hidden_layers=2,
                                                   num_attention_heads=16,
                                                   intermediate_size=4096,
                                                   hidden_act="gelu",
                                                   hidden_dropout_prob=0.0,
                                                   attention_probs_dropout_prob=0.0,
                                                   max_position_embeddings=512,
                                                   type_vocab_size=2,
                                                   initializer_range=0.02,
                                                   use_relative_positions=True,
                                                   input_mask_from_dataset=True,
                                                   token_type_ids_from_dataset=True,
                                                   dtype=mstype.float32,
                                                   compute_type=mstype.float32),
                                 is_training=True),
    'reduce_output': False
},
{
    'id': 'BertAttentionQueryKeyMul_CICase',
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)

    # The assertions below fail if the loss-scale value is updated incorrectly:
    # halved after an overflow, doubled after `scale_window` clean steps.
    count = 0
    for i in range(len(callback.overflow_list)):
        if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0:
            count = 0
            assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(0.5, mstype.float32)
        if callback.overflow_list[i] == Tensor(False, mstype.bool_):
            count = count + 1
            if count == scale_window:
                count = 0
                assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32)
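# The assertions above encode the dynamic loss-scale rule: the scale is halved
# whenever a step overflows, and doubled after `scale_window` consecutive
# non-overflowing steps. A standalone plain-Python sketch of that rule
# (independent of DynamicLossScaleManager), for reference:
def update_loss_scale(loss_scale, overflow, good_steps, scale_window, scale_factor=2.0):
    """Return the new (loss_scale, good_steps) pair after one training step."""
    if overflow:
        # Overflow: shrink the scale and restart the clean-step counter.
        return loss_scale / scale_factor, 0
    good_steps += 1
    if good_steps == scale_window:
        # Enough consecutive clean steps: grow the scale.
        return loss_scale * scale_factor, 0
    return loss_scale, good_steps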