class ModelBert(nn.Cell):
    def __init__(self, network, optimizer=None):
        super(ModelBert, self).__init__()
        self.optimizer = optimizer
        self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer)
        self.train_network.set_train()

    def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
        return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)
class ModelBert(nn.Cell):
    def __init__(self, network, optimizer=None):
        super(ModelBert, self).__init__()
        self.optimizer = optimizer
        manager = DynamicLossScaleManager()
        # Take the update cell from the manager (same pattern as
        # scale_manager.get_update_cell() in the tests below).
        update_cell = manager.get_update_cell()
        self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                               scale_update_cell=update_cell)
        self.train_network.set_train()

    def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
        return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)
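# Usage sketch (an assumption, not taken from the original sources): ModelBert
# simply forwards a batch to the wrapped one-step training cell, so a training
# loop might drive it as in the commented lines below. `network`, `optimizer`
# and `batch` are placeholders supplied by the surrounding repo.
#
#   wrapper = ModelBert(network, optimizer)
#   loss = wrapper(*batch)   # batch: the seven input tensors expected by construct()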
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()

    # Graph-mode execution on Ascend with the requested sink/memory options.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    # Data-parallel setup when running distributed, otherwise a single device.
    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                             args_opt.enable_data_sink, args_opt.data_sink_steps,
                             args_opt.data_dir, args_opt.schema_dir)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    # Build the optimizer selected in the config.
    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(),
                         decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate,
                         end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power,
                         warmup_steps=cfg.Lamb.warmup_steps,
                         weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(),
                             learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]"
                         .format(cfg.optimizer))

    # Callbacks: loss logging and (optionally) periodic checkpointing.
    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
        callback.append(ckpoint_cb)

    # Optionally warm-start from an existing checkpoint.
    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    # Wrap the network for one-step training, with or without dynamic loss scaling.
    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                         scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(), ds, callbacks=callback,
                dataset_sink_mode=(args_opt.enable_data_sink == "true"))
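# Hypothetical launch snippet (the file name run_pretrain.py is an assumption
# and the paths below are placeholders): all configuration comes from the
# command-line flags defined above, e.g.
#   python run_pretrain.py --distribute=false --epoch_size=1 --device_id=0 \
#       --enable_lossscale=true --do_shuffle=true --enable_data_sink=true \
#       --data_dir=/path/to/data --schema_dir=/path/to/schema.json
if __name__ == "__main__":
    run_pretrain()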
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()

    # Re-initialize the parameters so the run is reproducible: 'weight' tensors
    # outside 'cls2' are drawn on the transposed shape and transposed back,
    # everything else is re-drawn with its own shape.
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)

    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)

    # The test fails if the loss values, overflow states or loss-scale values
    # deviate from the recorded expectations.
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092,
                         12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
    print("loss value: {}".format(loss_value))
    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [True, True, False, False, False, True, False, False, False, True]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0,
                         16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
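# Standalone sketch (an assumption, not part of the test) of the update rule
# DynamicLossScaleManager is expected to follow: halve the scale on overflow,
# multiply it by scale_factor after `scale_window` consecutive clean steps.
# Replaying the overflow pattern asserted above reproduces expect_loss_scale.
def simulate_dynamic_loss_scale(init_scale, scale_factor, scale_window, overflows):
    scale, good_steps, history = init_scale, 0, []
    for overflow in overflows:
        if overflow:
            # Overflow: shrink the scale and restart the window counter.
            scale /= scale_factor
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == scale_window:
                # A full window of clean steps: grow the scale again.
                scale *= scale_factor
                good_steps = 0
        history.append(scale)
    return history


assert simulate_dynamic_loss_scale(
    2**16, 2, 3,
    [True, True, False, False, False, True, False, False, False, True]) == [
        32768.0, 16384.0, 16384.0, 16384.0, 32768.0,
        16384.0, 16384.0, 16384.0, 32768.0, 16384.0]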
def test_bert_tdt():
    """test bert tdt"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = me_de_train_dataset()
    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()

    # Re-initialize the parameters so the run is reproducible: 'weight' tensors
    # outside 'cls2' are drawn on the transposed shape and transposed back,
    # everything else is re-drawn with its own shape.
    params = netwithloss.trainable_params()
    for param in params:
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)

    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)

    # The test fails if the loss-scale values are wrong: the scale must be halved
    # after every overflow and doubled after `scale_window` consecutive clean steps.
    count = 0
    for i in range(len(callback.overflow_list)):
        if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0:
            count = 0
            assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(0.5, mstype.float32)
        if callback.overflow_list[i] == Tensor(False, mstype.bool_):
            count = count + 1
            if count == scale_window:
                count = 0
                assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32)
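# ModelCallback is defined elsewhere in the test package; the class below is a
# minimal sketch of what these tests assume it does, namely record the three
# step outputs of BertTrainOneStepWithLossScaleCell (loss, overflow flag,
# current loss scale). The output ordering and the attribute names are
# assumptions mirrored from the usage above, not the actual implementation.
from mindspore.train.callback import Callback


class _ModelCallbackSketch(Callback):
    def __init__(self):
        super(_ModelCallbackSketch, self).__init__()
        self.loss_list = []
        self.overflow_list = []
        self.lossscale_list = []

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # Assumed step output: (loss, overflow, loss_scale); the real callback
        # may convert these Tensors to numpy before storing them.
        loss, overflow, loss_scale = cb_params.net_outputs
        self.loss_list.append(loss)
        self.overflow_list.append(overflow)
        self.lossscale_list.append(loss_scale)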