def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Fine-tune ``network`` on ``dataset`` starting from a pretrained checkpoint.

    Args:
        dataset: Training dataset; ``get_dataset_size()`` is called on it.
        network: Network whose ``trainable_params()`` are handed to Adam.
        load_checkpoint_path: Checkpoint loaded into ``network`` before training.
        save_checkpoint_path: Directory given to ``ModelCheckpoint``; an empty
            string is forwarded as ``None``.
        epoch_num: Epoch count forwarded to ``Model.train``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is the empty string.
    """
    if load_checkpoint_path == "":
        raise ValueError(
            "Pretrain model missed, finetune task must load pretrain model!")

    batches_per_epoch = dataset.get_dataset_size()

    # optimizer: learning rate is read from the module-level optimizer_cfg
    adam = Adam(network.trainable_params(),
                learning_rate=optimizer_cfg.learning_rate)

    # checkpoint callback: save every `batches_per_epoch` steps, keep_checkpoint_max=1
    save_config = CheckpointConfig(save_checkpoint_steps=batches_per_epoch,
                                   keep_checkpoint_max=1)
    save_dir = save_checkpoint_path if save_checkpoint_path != "" else None
    checkpoint_cb = ModelCheckpoint(prefix="classifier",
                                    directory=save_dir,
                                    config=save_config)

    # load the pretrained weights into the network
    pretrained_params = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, pretrained_params)

    # wrap the network with the optimizer and a dynamic loss-scale update cell
    loss_scale_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                                 scale_factor=2,
                                                 scale_window=1000)
    train_cell = BertFinetuneCell(network,
                                  optimizer=adam,
                                  scale_update_cell=loss_scale_cell)
    model = Model(train_cell)

    time_cb = TimeMonitor(dataset.get_dataset_size())
    loss_cb = LossCallBack(dataset.get_dataset_size())
    model.train(epoch_num, dataset, callbacks=[time_cb, loss_cb, checkpoint_cb])
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Fine-tune ``network`` on ``dataset`` from a pretrained checkpoint and print the training latency.

    The optimizer is selected by ``optimizer_cfg.optimizer`` (module-level
    config): ``'AdamWeightDecay'``, ``'Lamb'`` or ``'Momentum'``.

    Args:
        dataset: Training dataset; ``get_dataset_size()`` is called on it.
        network: Network to train; its ``trainable_params()`` feed the optimizer.
        load_checkpoint_path: Checkpoint loaded into ``network`` before training.
        save_checkpoint_path: Directory given to ``ModelCheckpoint``; an empty
            string is forwarded as ``None``.
        epoch_num: Epoch count; also used to size the learning-rate schedule.

    Raises:
        ValueError: If ``load_checkpoint_path`` is the empty string.
        Exception: If ``optimizer_cfg.optimizer`` names an unsupported optimizer.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()
    # Schedule lengths shared by the AdamWeightDecay and Lamb branches:
    # decay over every training step, warm up over the first 10% of them.
    total_steps = steps_per_epoch * epoch_num
    warmup_steps = int(total_steps * 0.1)

    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecay':
        adam_cfg = optimizer_cfg.AdamWeightDecay
        lr_schedule = BertLearningRate(learning_rate=adam_cfg.learning_rate,
                                       end_learning_rate=adam_cfg.end_learning_rate,
                                       warmup_steps=warmup_steps,
                                       decay_steps=total_steps,
                                       power=adam_cfg.power)
        params = network.trainable_params()
        # Parameters accepted by decay_filter get the configured weight decay;
        # all remaining parameters are trained with weight decay 0.
        decay_params = [p for p in params if adam_cfg.decay_filter(p)]
        other_params = [p for p in params if not adam_cfg.decay_filter(p)]
        group_params = [{'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=adam_cfg.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        lr_schedule = BertLearningRate(learning_rate=lamb_cfg.learning_rate,
                                       end_learning_rate=lamb_cfg.end_learning_rate,
                                       warmup_steps=warmup_steps,
                                       decay_steps=total_steps,
                                       power=lamb_cfg.power)
        optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        # The message is kept on a single line: a plain string literal may not
        # span physical lines.
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # checkpoint callback: save every `steps_per_epoch` steps, keep_checkpoint_max=1
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="ner",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(steps_per_epoch), ckpoint_cb]

    # perf_counter is a monotonic clock, so the reported interval is not
    # distorted by system clock adjustments during a long training run.
    train_begin = time.perf_counter()
    model.train(epoch_num, dataset, callbacks=callbacks)
    train_end = time.perf_counter()
    print("latency: {:.6f} s".format(train_end - train_begin))
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """
    Fine-tune ``network`` on ``dataset`` starting from a pretrained checkpoint.

    The epoch count is taken from ``dataset.get_repeat_count()`` and the
    optimizer is selected by ``optimizer_cfg.optimizer`` (module-level config):
    ``'AdamWeightDecayDynamicLR'``, ``'Lamb'`` or ``'Momentum'``.

    Args:
        dataset: Training dataset; ``get_dataset_size()`` and
            ``get_repeat_count()`` are called on it.
        network: Network to train; its ``trainable_params()`` feed the optimizer.
        load_checkpoint_path: Checkpoint loaded into ``network`` before training.
        save_checkpoint_path: Directory given to ``ModelCheckpoint``; an empty
            string is forwarded as ``None``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is the empty string.
        Exception: If ``optimizer_cfg.optimizer`` names an unsupported optimizer.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # Schedule lengths shared by the optimizer branches: decay over every
    # training step, warm up over the first 10% of them.
    total_steps = steps_per_epoch * epoch_num
    warmup_steps = int(total_steps * 0.1)

    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        adam_cfg = optimizer_cfg.AdamWeightDecayDynamicLR
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=total_steps,
                                             learning_rate=adam_cfg.learning_rate,
                                             end_learning_rate=adam_cfg.end_learning_rate,
                                             power=adam_cfg.power,
                                             warmup_steps=warmup_steps,
                                             weight_decay=adam_cfg.weight_decay,
                                             eps=adam_cfg.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        optimizer = Lamb(network.trainable_params(),
                         decay_steps=total_steps,
                         start_learning_rate=lamb_cfg.start_learning_rate,
                         end_learning_rate=lamb_cfg.end_learning_rate,
                         power=lamb_cfg.power,
                         weight_decay=lamb_cfg.weight_decay,
                         warmup_steps=warmup_steps,
                         decay_filter=lamb_cfg.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        # The message is kept on a single line: a plain string literal may not
        # span physical lines.
        raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")

    # checkpoint callback: save every `steps_per_epoch` steps, keep_checkpoint_max=1
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    # An empty save path (the default) is forwarded as None instead of "".
    # NOTE(review): presumably ModelCheckpoint then falls back to its own
    # default directory - confirm against the installed MindSpore version.
    ckpoint_cb = ModelCheckpoint(prefix="classifier",
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
def train():
    """
    Fine-tune a BERT task network using the module-level ``cfg`` settings.

    ``cfg.task`` selects the network: ``'NER'`` builds ``BertNER`` (sequence
    labeling) and ``'Classification'`` builds ``BertCLS``. ``cfg.optimizer``
    selects ``'AdamWeightDecay'``, ``'Lamb'`` or ``'Momentum'``, with
    hyper-parameters read from ``optimizer_cfg``. The checkpoint at
    ``cfg.pre_training_ckpt`` is loaded into the network before training.

    Raises:
        Exception: If ``cfg.task`` or ``cfg.optimizer`` is not one of the
            supported values.
    """
    # BertCLS train for classification
    # BertNER train for sequence labeling
    if cfg.task == 'NER':
        tag_to_index = None
        if cfg.use_crf:
            # Read the label map inside a context manager so the file handle
            # is closed even if JSON parsing fails.
            with open(cfg.label2id_file) as label_file:
                tag_to_index = json.load(label_file)
            print(tag_to_index)
            # Append the two extra tags right after the existing label ids.
            max_val = len(tag_to_index)
            tag_to_index["<START>"] = max_val
            tag_to_index["<STOP>"] = max_val + 1
            number_labels = len(tag_to_index)
        else:
            number_labels = cfg.num_labels
        netwithloss = BertNER(bert_net_cfg, cfg.batch_size, True,
                              num_labels=number_labels,
                              use_crf=cfg.use_crf,
                              tag_to_index=tag_to_index,
                              dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg, True,
                              num_labels=cfg.num_labels,
                              dropout_prob=0.1,
                              assessment_method=cfg.assessment_method)
    else:
        raise Exception("task error, NER or Classification is supported.")

    dataset = get_dataset(data_file=cfg.data_file, batch_size=cfg.batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    print('steps_per_epoch:', steps_per_epoch)

    # Schedule lengths shared by the AdamWeightDecay and Lamb branches:
    # decay over every training step, warm up over the first 10% of them.
    total_steps = steps_per_epoch * cfg.epoch_num
    warmup_steps = int(total_steps * 0.1)

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        adam_cfg = optimizer_cfg.AdamWeightDecay
        lr_schedule = BertLearningRate(
            learning_rate=adam_cfg.learning_rate,
            end_learning_rate=adam_cfg.end_learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            power=adam_cfg.power)
        params = netwithloss.trainable_params()
        # Parameters accepted by decay_filter get the configured weight decay;
        # all remaining parameters are trained with weight decay 0.
        decay_params = [p for p in params if adam_cfg.decay_filter(p)]
        other_params = [p for p in params if not adam_cfg.decay_filter(p)]
        group_params = [
            {'params': decay_params, 'weight_decay': adam_cfg.weight_decay},
            {'params': other_params, 'weight_decay': 0.0},
        ]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=adam_cfg.eps)
    elif cfg.optimizer == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        lr_schedule = BertLearningRate(
            learning_rate=lamb_cfg.learning_rate,
            end_learning_rate=lamb_cfg.end_learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=total_steps,
            power=lamb_cfg.power)
        optimizer = Lamb(netwithloss.trainable_params(), learning_rate=lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(
            netwithloss.trainable_params(),
            learning_rate=optimizer_cfg.Momentum.learning_rate,
            momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")

    # checkpoint callback: save every `steps_per_epoch` steps, keep_checkpoint_max=1
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix,
                                 directory=cfg.ckpt_dir,
                                 config=ckpt_config)

    # load checkpoint into network
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32,
                                             scale_factor=2,
                                             scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss,
                                    optimizer=optimizer,
                                    scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [
        TimeMonitor(steps_per_epoch),
        LossCallBack(steps_per_epoch),
        ckpoint_cb,
    ]
    model.train(cfg.epoch_num,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=True)