def training_tf_loop(model,
                     training_loss,
                     epochs: int = 1,
                     num_batches_per_epoch: int = 1,
                     logging_epoch_freq: int = 100,
                     manager: tf.train.CheckpointManager = None):
    """Runs the Adam optimizer on model with training_loss (no monitoring).

    :param model: The model to be trained.
    :param training_loss: A function that returns the training objective.
    :param epochs: The number of full data passes (epochs).
    :param num_batches_per_epoch: The number of batches per epoch.
    :param logging_epoch_freq: The epoch frequency at which the training loss is printed.
    :param manager: Optional checkpoint manager; if given, a checkpoint is saved every epoch.
    """
    optimizer = tf.optimizers.Adam()

    @tf.function
    def tf_optimization_step():
        optimizer.minimize(training_loss, model.trainable_variables)

    for epoch in range(epochs):
        for _ in range(num_batches_per_epoch):
            tf_optimization_step()

        epoch_id = epoch + 1
        if epoch_id % logging_epoch_freq == 0:
            tf.print(f"Epoch {epoch_id}: ELBO (train) {training_loss()}")
        if manager is not None:
            manager.save()
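# A minimal usage sketch for the loop above (an assumption, not part of the
# original code): `model` and `training_loss` are placeholders for a GPflow
# model and its training-loss closure; only standard tf.train APIs are used.
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, directory="./ckpts", max_to_keep=3)
training_tf_loop(model, training_loss, epochs=500, num_batches_per_epoch=10,
                 logging_epoch_freq=50, manager=manager)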
def train(data, class_weights, flags, net: Net, framework: Framework,
          manager: tf.train.CheckpointManager):
    log = get_logger()
    io = SharedFlagIO(flags, subprogram=True)
    flags = io.read_flags() if io.read_flags() is not None else flags
    log.info('Building {} train op'.format(flags.model))
    goal = len(data) * flags.epoch
    first = True
    for i, (x_batch, loss_feed) in enumerate(framework.shuffle(data, class_weights)):
        loss = net(x_batch, training=True, **loss_feed)
        step = net.step.numpy()
        lr = net.optimizer.learning_rate.numpy()
        line = 'step: {} loss: {:f} lr: {:.2e} progress: {:.2f}%'
        if not first:
            flags.progress = i * flags.batch / goal * 100
            log.info(line.format(step, loss, lr, flags.progress))
        else:
            log.info(f"Following gradient from step {step}...")
        io.send_flags()
        flags = io.read_flags()
        ckpt = bool(not step % flags.save)
        if ckpt and not first:
            save = manager.save()
            log.info(f"Saved checkpoint: {save}")
        first = False
    if not ckpt:
        save = manager.save()
        log.info(f"Finished training at checkpoint: {save}")
def save_checkpoint(strategy, step, manager: tf.train.CheckpointManager):
    """Saves the model with the provided checkpoint prefix."""
    if should_export_checkpoint(strategy):
        saved_path = manager.save(step)
        logging.info('Saving model as TF checkpoint: %s', saved_path)
    else:
        # In multi-worker training every worker needs to save the checkpoint,
        # because reading variables can trigger synchronization and
        # synchronization needs all workers to participate. To avoid workers
        # overriding each other, non-chief workers save to a temporary
        # directory that is deleted afterwards.
        tmp_dir = tempfile.mkdtemp()
        tf.train.CheckpointManager(
            manager.checkpoint, directory=tmp_dir, max_to_keep=1).save(step)
        tf.io.gfile.rmtree(tmp_dir)
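# A minimal sketch of the restore-side counterpart these save helpers assume
# (an assumption, not part of the original code); `checkpoint` and `model_dir`
# are placeholders.
def maybe_restore_checkpoint(checkpoint: tf.train.Checkpoint,
                             model_dir: str) -> tf.train.CheckpointManager:
    manager = tf.train.CheckpointManager(checkpoint, directory=model_dir, max_to_keep=5)
    # restore_or_initialize() returns the restored path, or None on a fresh start.
    latest = manager.restore_or_initialize()
    if latest:
        logging.info('Restored from %s', latest)
    return manager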
def checkpointing_train_SGPR(
    model: gpflow.models.SGPR,
    X: tf.Tensor,
    Y: tf.Tensor,
    epochs: int,
    manager: tf.train.CheckpointManager,
    optimizer: tf.optimizers.Optimizer = tf.optimizers.Adam(learning_rate=0.1),
    logging_epoch_freq: int = 10,
    epoch_var: Optional[tf.Variable] = None,
    exp_tag: str = 'test',
):
    """Training loop for a sparse GP (SGPR) with checkpointing."""
    set_trainable(model.mean_function, False)
    tf_optimization_step = tf.function(optimization_exact)

    loss = list()
    for epoch in range(epochs):
        tf_optimization_step(model)
        if epoch_var is not None:
            epoch_var.assign(epoch + 1)

        epoch_id = epoch + 1
        loss.append(model.training_loss())
        if epoch_id % logging_epoch_freq == 0:
            ckpt_path = manager.save()
            tf.print(
                f"Epoch {epoch_id}: LOSS (train) {model.training_loss()}, saved at {ckpt_path}"
            )
            tf.print(f"MSE: {mean_squared_error(Y, model.predict_y(X)[0])}")

    plt.plot(range(epochs), loss)
    plt.xlabel('Epoch', fontsize=25)
    plt.ylabel('Loss', fontsize=25)
    plt.tight_layout()
def checkpointing_training_loop(
    model: gpflow.models.SVGP,
    batch_size: int,
    epochs: int,
    manager: tf.train.CheckpointManager,
    logging_epoch_freq: int = 100,
    epoch_var: Optional[tf.Variable] = None,
    step_var: Optional[tf.Variable] = None,
):
    tf_optimization_step = tf.function(optimization_step)

    batches = iter(train_dataset)
    for epoch in range(epochs):
        for step in range(ci_niter(num_batches_per_epoch)):
            tf_optimization_step(model, next(batches))
            if step_var is not None:
                step_var.assign(epoch * num_batches_per_epoch + step + 1)
        if epoch_var is not None:
            epoch_var.assign(epoch + 1)

        epoch_id = epoch + 1
        if epoch_id % logging_epoch_freq == 0:
            ckpt_path = manager.save()
            tf.print(
                f"Epoch {epoch_id}: ELBO (train) {model.elbo(data)}, saved at {ckpt_path}"
            )
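# A minimal sketch of the `optimization_step` this loop wraps in tf.function,
# assuming the usual GPflow SVGP setup; the loop also relies on module-level
# `train_dataset`, `num_batches_per_epoch` and `data`, which are not defined here.
adam = tf.optimizers.Adam()

def optimization_step(model: gpflow.models.SVGP, batch):
    # One Adam step on the negative ELBO for a single minibatch.
    adam.minimize(model.training_loss_closure(batch), model.trainable_variables)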
def monitored_training_loop(model,
                            training_loss,
                            epochs: int = 1,
                            num_batches_per_epoch: int = 1,
                            fast_tasks: gpf.monitor.MonitorTaskGroup = None,
                            slow_tasks: gpf.monitor.MonitorTaskGroup = None,
                            logging_epoch_freq: int = 100,
                            manager: tf.train.CheckpointManager = None):
    """Monitors (with images) the Adam optimizer on model with training_loss.

    Monitoring runs outside tf.function, so this method is slower than
    monitored_training_tf_loop.

    :param model: The model to be trained.
    :param training_loss: A function that returns the training objective.
    :param epochs: The number of full data passes (epochs).
    :param num_batches_per_epoch: The number of batches per epoch.
    :param fast_tasks: GPflow monitor fast tasks, e.g.
        MonitorTaskGroup([ScalarToTensorBoard(log_dir, training_loss, "elbo")])
    :param slow_tasks: GPflow monitor slow tasks, e.g. plotting images.
    :param logging_epoch_freq: The epoch frequency at which the training loss is printed.
    :param manager: Optional checkpoint manager; if given, a checkpoint is saved every epoch.
    """
    optimizer = tf.optimizers.Adam()

    @tf.function
    def tf_optimization_step():
        optimizer.minimize(training_loss, model.trainable_variables)

    monitor = Monitor(fast_tasks, slow_tasks)

    for epoch in range(epochs):
        for _ in range(num_batches_per_epoch):
            tf_optimization_step()
        monitor(epoch)

        epoch_id = epoch + 1
        if epoch_id % logging_epoch_freq == 0:
            tf.print(f"Epoch {epoch_id}: ELBO (train) {training_loss()}")
        if manager is not None:
            manager.save()
def monitored_training_tf_loop(model,
                               training_loss,
                               epochs: int = 1,
                               num_batches_per_epoch: int = 1,
                               fast_tasks: gpf.monitor.MonitorTaskGroup = None,
                               logging_epoch_freq: int = 100,
                               manager: tf.train.CheckpointManager = None):
    """Monitors the Adam optimizer on model with training_loss.

    Both training and monitoring run inside tf.function (no image monitoring).
    Only fast tasks are monitored, as matplotlib code cannot be built into a
    TF graph.

    :param model: The model to be trained.
    :param training_loss: A function that returns the training objective.
    :param epochs: The number of full data passes (epochs).
    :param num_batches_per_epoch: The number of batches per epoch.
    :param fast_tasks: GPflow monitor fast tasks, e.g.
        MonitorTaskGroup([ScalarToTensorBoard(log_dir, training_loss, "elbo")])
    :param logging_epoch_freq: The epoch frequency at which the training loss is printed.
    :param manager: Optional checkpoint manager; if given, a checkpoint is saved every epoch.
    """
    optimizer = tf.optimizers.Adam()
    monitor = Monitor(fast_tasks)

    @tf.function
    def monitored_tf_opt_step(epoch):
        optimizer.minimize(training_loss, model.trainable_variables)
        monitor(epoch)

    epochs = tf.constant(epochs)  # needs to be a tf.constant for tf.range
    for epoch in tf.range(epochs):
        for _ in range(num_batches_per_epoch):
            monitored_tf_opt_step(epoch)

        epoch_id = epoch + 1
        if epoch_id % logging_epoch_freq == 0:
            tf.print(f"Epoch {epoch_id}: ELBO (train) {training_loss()}")
        if manager is not None:
            manager.save()
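# A minimal sketch (assumed, not from the original code) of how the fast tasks
# mentioned in the docstrings might be built; `log_dir`, `model` and
# `training_loss` are placeholders.
fast_tasks = gpf.monitor.MonitorTaskGroup([
    gpf.monitor.ScalarToTensorBoard(log_dir, training_loss, "elbo"),
    gpf.monitor.ModelToTensorBoard(log_dir, model),
])
monitored_training_tf_loop(model, training_loss, epochs=1000,
                           num_batches_per_epoch=10, fast_tasks=fast_tasks)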
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any,
          buffer_size: Any, epochs: Any, train_data_path: AnyStr, test_data_path: AnyStr,
          dict_path: AnyStr, max_sentence_length: Any, max_train_steps: Any = -1,
          checkpoint_save_freq: Any = 2, *args, **kwargs) -> Dict:
    """Trainer.

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param epochs: number of training epochs
    :param train_data_path: path to the training data
    :param test_data_path: path to the test data
    :param dict_path: vocabulary file
    :param max_sentence_length: maximum sentence-pair length
    :param max_train_steps: maximum number of training samples, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=1e-5)

    train_steps_per_epoch = 125000 // batch_size
    valid_steps_per_epoch = 10000 // batch_size
    warmup_steps = train_steps_per_epoch // 3
    total_steps = train_steps_per_epoch * epochs - warmup_steps
    # learning_rate = CustomSchedule(d_model=768, warmup_steps=warmup_steps)
    # optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, name="optimizer")
    # optimizer, _ = create_optimizer(init_lr=2e-5, num_train_steps=total_steps, num_warmup_steps=warmup_steps)

    train_dataset, valid_dataset = load_raw_dataset(
        train_data_path=train_data_path, max_sentence_length=max_sentence_length,
        batch_size=batch_size, buffer_size=buffer_size, dict_path=dict_path,
        test_data_path=test_data_path)

    print("Training started")
    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch, (queries, segments, labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                queries=queries, segments=segments, targets=labels
            )
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation data is too small (smaller than batch_size); skipping validation")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset,
                                            progress_bar=progress_bar,
                                            loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
def train(
    model: onmt.models.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
    source_file: str,
    target_file: str,
    checkpoint_manager: tf.train.CheckpointManager,
    maximum_length=100,
    shuffle_buffer_size=-1,  # Uniform shuffle.
    train_steps=100000,
    save_every=1000,
    report_every=100,
    validation_source_file=None,
    validation_target_file=None,
    validate_every=2000,
    validate_now=False,
    bpe=False,
    bpe_combined=False,
):
    """Train an OpenNMT model.

    Arguments:
        model {onmt.models.Model} -- Model to train.
        optimizer {tf.keras.optimizers.Optimizer} -- Optimizer to use.
        learning_rate {tf.keras.optimizers.schedules.LearningRateSchedule} -- Learning rate schedule to use.
        source_file {str} -- Aligned source language file.
        target_file {str} -- Aligned target language file.
        checkpoint_manager {tf.train.CheckpointManager} -- Checkpoint manager.

    Keyword Arguments:
        maximum_length {int} -- Maximum source/target sequence length (default: {100})
        shuffle_buffer_size {int} -- Shuffle buffer size, -1 for a uniform shuffle (default: {-1})
        train_steps {int} -- Total number of training steps (default: {100000})
        save_every {int} -- Save a checkpoint every N steps (default: {1000})
        report_every {int} -- Log progress every N steps (default: {100})
        validation_source_file {str} -- Source file used for validation (default: {None})
        validation_target_file {str} -- Target file used for validation (default: {None})
        validate_every {int} -- Run validation every N steps (default: {2000})
        validate_now {bool} -- Run validation at every step regardless of validate_every (default: {False})
        bpe {bool} -- Decode BPE in the validation predictions before computing BLEU (default: {False})
        bpe_combined {bool} -- Passed to decode_bpe_file as `combined` (default: {False})
    """
    # Create the training dataset.
    dataset = model.examples_inputter.make_training_dataset(
        source_file,
        target_file,
        batch_size=3072,
        batch_type="tokens",
        shuffle_buffer_size=shuffle_buffer_size,
        length_bucket_width=1,  # Bucketize sequences of similar length for efficiency.
        maximum_features_length=maximum_length,
        maximum_labels_length=maximum_length,
    )

    @tf.function(input_signature=dataset.element_spec)
    def training_step(source, target):
        # Run the encoder.
        source_inputs = model.features_inputter(source, training=True)
        encoder_outputs, _, _ = model.encoder(source_inputs, source["length"], training=True)

        # Run the decoder.
        target_inputs = model.labels_inputter(target, training=True)
        decoder_state = model.decoder.initial_state(
            memory=encoder_outputs, memory_sequence_length=source["length"])
        logits, _, _ = model.decoder(target_inputs, target["length"],
                                     state=decoder_state, training=True)

        # Compute the cross entropy loss.
        loss_num, loss_den, _ = onmt.utils.cross_entropy_sequence_loss(
            logits,
            target["ids_out"],
            target["length"],
            label_smoothing=0.1,
            average_in_time=True,
            training=True,
        )
        loss = loss_num / loss_den

        # Compute and apply the gradients.
        variables = model.trainable_variables
        gradients = optimizer.get_gradients(loss, variables)
        optimizer.apply_gradients(list(zip(gradients, variables)))
        return loss

    # Runs the training loop.
    for source, target in dataset:
        loss = training_step(source, target)
        step = optimizer.iterations.numpy()

        if step % validate_every == 0 or validate_now:
            output_file_name = f"predictions.{step}.txt"
            if validation_source_file is not None:
                tf.get_logger().info(
                    f"Saving validation predictions from {validation_source_file} to {output_file_name}"
                )
                translate(model, validation_source_file, output_file=output_file_name)
                if bpe:
                    output_file_name = decode_bpe_file(output_file_name, combined=bpe_combined)
                tf.get_logger().info(
                    f"Computing BLEU between {output_file_name} and {validation_target_file}"
                )
                per_sentence_score, mean_score = compute_bleu(output_file_name, validation_target_file)
                tf.get_logger().info(f"BLEU score {mean_score}")

        if step % report_every == 0:
            tf.get_logger().info(
                "Step = %d ; Learning rate = %f ; Loss = %f",
                step,
                learning_rate(step),
                loss,
            )
        if step % save_every == 0:
            tf.get_logger().info("Saving checkpoint for step %d", step)
            checkpoint_manager.save(checkpoint_number=step)
            tf.get_logger().info("Checkpoint saved.")
        if step == train_steps:
            break
def train(encoder: tf.keras.Model, decoder: tf.keras.Model, optimizer: tf.keras.optimizers.Adam,
          epochs: int, checkpoint: tf.train.CheckpointManager, train_data_path: str, max_len: int,
          vocab_size: int, batch_size: int, buffer_size: int, checkpoint_save_freq: int,
          num_mel: int, tokenized_type: str = "phoneme", dict_path: str = "",
          valid_data_split: float = 0.0, valid_data_path: str = "",
          max_train_data_size: int = 0, max_valid_data_size: int = 0):
    """Training module.

    :param encoder: encoder of the model
    :param decoder: decoder of the model
    :param optimizer: optimizer
    :param checkpoint: checkpoint manager
    :param epochs: number of training epochs
    :param train_data_path: path to the text data
    :param max_len: maximum text sequence length
    :param vocab_size: vocabulary size
    :param num_mel: number of mel bands to generate
    :param buffer_size: Dataset shuffle buffer size
    :param batch_size: Dataset batch size
    :param tokenized_type: tokenization type, phoneme by default; one of phoneme/word/char
    :param dict_path: dictionary path, not needed when using phoneme
    :param valid_data_split: fraction of the training data used for validation
    :param valid_data_path: path to the validation text data
    :param max_train_data_size: maximum amount of training data
    :param max_valid_data_size: maximum amount of validation data
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    """
    train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        _dataset.load_data(train_data_path=train_data_path, max_len=max_len, vocab_size=vocab_size,
                           batch_size=batch_size, buffer_size=buffer_size, tokenized_type=tokenized_type,
                           dict_path=dict_path, valid_data_split=valid_data_split,
                           valid_data_path=valid_data_path, max_train_data_size=max_train_data_size,
                           max_valid_data_size=max_valid_data_size)

    if steps_per_epoch == 0:
        print("The training data is too small (smaller than batch_size); add more data and try again")
        exit(0)

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        start_time = time.time()
        total_loss = 0

        for (batch, (mel, stop_token, sentence)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()
            mel = tf.transpose(mel, [0, 2, 1])
            mel_input = tf.concat([
                tf.zeros(shape=(mel.shape[0], 1, num_mel), dtype=tf.float32),
                mel[:, :-1, :]
            ], axis=1)

            batch_loss, mel_outputs = _train_step(encoder, decoder, optimizer, sentence,
                                                  mel, mel_input, stop_token)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)), end="")

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch, total_loss / steps_per_epoch))

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0:
                print("The validation data is too small (smaller than batch_size); add more data and try again")
                exit(0)

            _valid_step(encoder=encoder, decoder=decoder, dataset=valid_dataset,
                        num_mel=num_mel, steps_per_epoch=valid_steps_per_epoch)

    return mel_outputs
def train(encoder: tf.keras.Model, decoder: tf.keras.Model, optimizer: tf.keras.optimizers.Adam,
          epochs: int, checkpoint: tf.train.CheckpointManager, train_data_path: str, max_len: int,
          vocab_size: int, batch_size: int, buffer_size: int, checkpoint_save_freq: int,
          dict_path: str = "", valid_data_split: float = 0.0, valid_data_path: str = "",
          max_train_data_size: int = 0, max_valid_data_size: int = 0):
    """Training module.

    :param encoder: encoder of the model
    :param decoder: decoder of the model
    :param optimizer: optimizer
    :param checkpoint: checkpoint manager
    :param epochs: number of training epochs
    :param train_data_path: path to the text data
    :param max_len: maximum text sequence length
    :param vocab_size: vocabulary size
    :param buffer_size: Dataset shuffle buffer size
    :param batch_size: Dataset batch size
    :param dict_path: dictionary path, not needed when using phoneme
    :param valid_data_split: fraction of the training data used for validation
    :param valid_data_path: path to the validation text data
    :param max_train_data_size: maximum amount of training data
    :param max_valid_data_size: maximum amount of validation data
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    """
    _, train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        load_data(train_data_path=train_data_path, max_len=max_len, vocab_size=vocab_size,
                  batch_size=batch_size, buffer_size=buffer_size, dict_path=dict_path,
                  valid_data_split=valid_data_split, valid_data_path=valid_data_path,
                  max_train_data_size=max_train_data_size, max_valid_data_size=max_valid_data_size)

    if steps_per_epoch == 0:
        print("The training data is too small (smaller than batch_size); add more data and try again")
        exit(0)

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        start_time = time.time()
        total_loss = 0

        for (batch, (audio_feature, sentence)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()
            sentence_input = sentence[:, :-1]
            sentence_real = sentence[:, 1:]

            batch_loss, sentence_predictions = _train_step(
                encoder, decoder, optimizer, sentence_input, sentence_real, audio_feature)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)), end="")

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch, total_loss / steps_per_epoch))

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0:
                print("The validation data is too small (smaller than batch_size); add more data and try again")
                exit(0)

            _valid_step(encoder=encoder, decoder=decoder, dataset=valid_dataset,
                        steps_per_epoch=valid_steps_per_epoch)
def train(self, optimizer: tf.optimizers.Adam, checkpoint: tf.train.CheckpointManager,
          train_data_path: str, epochs: int, checkpoint_save_freq: int,
          valid_data_split: float = 0.0, max_train_data_size: int = 0,
          valid_data_path: str = "", max_valid_data_size: int = 0,
          history: dict = {}, **kwargs) -> Dict:
    """Training module.

    :param optimizer: optimizer
    :param checkpoint: checkpoint manager
    :param train_data_path: path to the text data
    :param epochs: number of training epochs
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :param valid_data_split: fraction of the training data used for validation
    :param max_train_data_size: maximum amount of training data
    :param valid_data_path: path to the validation text data
    :param max_valid_data_size: maximum amount of validation data
    :param history: dict used to collect metric histories during training
    :return: the history of metrics
    """
    print("Training started, preparing data")
    train_dataset, valid_dataset, train_steps_per_epoch, valid_steps_per_epoch = load_data(
        dict_path=self.dict_path, train_data_path=train_data_path, buffer_size=self.buffer_size,
        batch_size=self.batch_size, max_sentence=self.max_sentence, valid_data_split=valid_data_split,
        valid_data_path=valid_data_path, max_train_data_size=max_train_data_size,
        valid_data_type=self.valid_data_type, max_valid_data_size=max_valid_data_size,
        train_data_type=self.train_data_type, **kwargs)

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        self.loss_metric.reset_states()
        self.accuracy_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=self.batch_size)

        for (batch, batch_dataset) in enumerate(train_dataset.take(train_steps_per_epoch)):
            train_metrics = self._train_step(batch_dataset=batch_dataset, optimizer=optimizer, **kwargs)
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metrics))

        progress_bar.done(step_time=time.time() - start_time)

        for key, value in train_metrics.items():
            history[key].append(value)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation data is too small (smaller than batch_size); skipping validation")
            else:
                valid_metrics = self._valid_step(dataset=valid_dataset, progress_bar=progress_bar,
                                                 steps_per_epoch=valid_steps_per_epoch, **kwargs)
                for key, value in valid_metrics.items():
                    history[key].append(value)

    print("Training finished")
    self._save_model(**kwargs)
    return history
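# The loop above appends to history[key] without creating the list first, so
# every metric key must already map to a list. A small sketch of one way to
# call it (an assumption, not part of the original code): pass a defaultdict
# instead of relying on the mutable default argument.
from collections import defaultdict

history = defaultdict(list)  # then pass history=history to train(...)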
def train(model: tf.keras.Model, optimizer: tf.keras.optimizers.Adam, epochs: int,
          checkpoint: tf.train.CheckpointManager, train_data_path: str, batch_size: int,
          buffer_size: int, checkpoint_save_freq: int, dict_path: str = "",
          valid_data_split: float = 0.0, valid_data_path: str = "",
          train_length_path: str = "", valid_length_path: str = "",
          stop_early_limits: int = 0, max_train_data_size: int = 0,
          max_valid_data_size: int = 0, history_img_path: str = ""):
    """Training module.

    :param model: model
    :param optimizer: optimizer
    :param checkpoint: checkpoint manager
    :param epochs: number of training epochs
    :param train_data_path: path to the text data
    :param buffer_size: Dataset shuffle buffer size
    :param batch_size: Dataset batch size
    :param dict_path: dictionary path, not needed when using phoneme
    :param valid_data_split: fraction of the training data used for validation
    :param valid_data_path: path to the validation text data
    :param max_train_data_size: maximum amount of training data
    :param train_length_path: path where training sample lengths are saved
    :param valid_length_path: path where validation sample lengths are saved
    :param stop_early_limits: number of non-improving epochs before early stopping
    :param max_valid_data_size: maximum amount of validation data
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :param history_img_path: path where the metric history plot is saved
    :return: the history of metrics
    """
    train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        load_data(train_data_path=train_data_path, batch_size=batch_size, buffer_size=buffer_size,
                  valid_data_split=valid_data_split, valid_data_path=valid_data_path,
                  train_length_path=train_length_path, valid_length_path=valid_length_path,
                  max_train_data_size=max_train_data_size, max_valid_data_size=max_valid_data_size)
    tokenizer = load_tokenizer(dict_path=dict_path)

    history = {"loss": [], "wers": [], "norm_lers": []}

    if steps_per_epoch == 0:
        print("The training data is too small (smaller than batch_size); add more data and try again")
        exit(0)

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        start_time = time.time()
        total_loss = 0

        for (batch, (audio_feature, sentence, length)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()
            batch_loss = _train_step(model, optimizer, sentence, length, audio_feature)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)), end="")

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch, total_loss / steps_per_epoch))
        history["loss"].append(total_loss / steps_per_epoch)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0:
                print("The validation data is too small (smaller than batch_size); add more data and try again")
                exit(0)

            valid_loss, valid_wer, valid_ler = _valid_step(
                model=model, dataset=valid_dataset,
                steps_per_epoch=valid_steps_per_epoch, tokenizer=tokenizer)
            history["wers"].append(valid_wer)
            history["norm_lers"].append(valid_ler)

            if stop_early_limits != 0 and len(history["wers"]) >= stop_early_limits:
                if can_stop(history["wers"][-stop_early_limits:]) \
                        or can_stop(history["norm_lers"][-stop_early_limits:]):
                    print("Metrics regressed, stopping training early!")
                    break

    plot_history(history=history, valid_epoch_freq=checkpoint_save_freq,
                 history_img_path=history_img_path)
    return history
def train(self, epochs, hpars_optimizer, lr_scheduler, kl_scheduler, monitor: Monitor,
          ckpt_manager: tf.train.CheckpointManager, ckpt_period=1, use_natural=True,
          clip_natgrad_value=None, first_epoch=1):
    bar = tqdm(total=self.experiment.len_train_data)
    avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf_floatx())
    avg_breaked_loss = tf.keras.metrics.MeanTensor(name='breaked_loss', dtype=tf_floatx())
    optimizer = None if use_natural else self.default_optimizer()
    self._optimizer_ = optimizer

    # TODO: initial monitoring
    if monitor:
        try:
            epoch = 0
            kl_weight = kl_scheduler(epoch)
            for y in self.experiment.train_dataset:
                loss, breaked_loss = self._initial_monitor_on_batch(
                    y, epoch, hpars_optimizer, kl_weight, lr_scheduler)
                avg_loss.update_state(loss)
                avg_breaked_loss.update_state(breaked_loss)
                bar.update()
            monitor(epoch, epoch=epoch, kl_scheduler=kl_scheduler,
                    train_loss=avg_loss.result(), train_bloss=avg_breaked_loss.result())
            avg_loss.reset_states()
            avg_breaked_loss.reset_states()
        except:
            pass

    for epoch in tf.range(first_epoch, epochs + 1, dtype=tf_floatx()):
        bar.reset()
        bar.set_description(f'Epoch {epoch}')
        kl_weight = kl_scheduler(epoch)
        for y in self.experiment.train_dataset:
            loss, breaked_loss = self._train_on_batch(
                y, epoch, hpars_optimizer, kl_weight, lr_scheduler, use_natural,
                optimizer, clip_value=clip_natgrad_value)
            avg_loss.update_state(loss)
            avg_breaked_loss.update_state(breaked_loss)
            bar.update()

        if monitor:
            try:
                monitor(epoch, epoch=epoch, kl_scheduler=kl_scheduler,
                        train_loss=avg_loss.result(), train_bloss=avg_breaked_loss.result())
            except:
                pass

        avg_loss.reset_states()
        avg_breaked_loss.reset_states()

        if ckpt_manager and epoch % ckpt_period == 0:
            ckpt_manager.save()
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any,
          buffer_size: Any, epochs: Any, embedding_dim: Any, train_data_path: AnyStr,
          valid_data_path: AnyStr, max_sentence_length: Any, max_train_steps: Any = -1,
          checkpoint_save_freq: Any = 2, *args, **kwargs) -> Dict:
    """Trainer.

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param epochs: number of training epochs
    :param embedding_dim: embedding dimension
    :param train_data_path: path to the training data
    :param valid_data_path: path to the validation data
    :param max_sentence_length: maximum sentence-pair length
    :param max_train_steps: maximum number of training samples, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")
    learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.98,
                                   name="optimizer")

    train_dataset = load_raw_dataset(data_path=train_data_path, max_sentence_length=max_sentence_length,
                                     batch_size=batch_size, buffer_size=buffer_size, data_type="train")
    valid_dataset = load_raw_dataset(data_path=valid_data_path, max_sentence_length=max_sentence_length,
                                     batch_size=batch_size, buffer_size=buffer_size, data_type="valid")

    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (90000 // batch_size)
    valid_steps_per_epoch = 10000 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        result, targets = tf.convert_to_tensor([], dtype=tf.float32), tf.convert_to_tensor([], dtype=tf.int32)
        for (batch, (queries, _, outputs, labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(model=model, optimizer=optimizer,
                                                   batch_size=batch_size, loss_metric=loss_metric,
                                                   queries=queries, targets=outputs)
            result = tf.concat([result, prediction[:, 1]], axis=0)
            targets = tf.concat([targets, labels], axis=0)
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        auc_score = roc_auc_score(y_true=targets, y_score=result)
        train_metric["train_auc"] = auc_score

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation data is too small (smaller than batch_size); skipping validation")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset, batch_size=batch_size,
                                            progress_bar=progress_bar, loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any,
          buffer_size: Any, epochs: Any, embedding_dim: Any, train_data_path: AnyStr,
          valid_data_path: AnyStr, max_sentence_length: Any, max_train_steps: Any = -1,
          checkpoint_save_freq: Any = 2, *args, **kwargs) -> Dict:
    """Trainer.

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param buffer_size: shuffle buffer size
    :param epochs: number of training epochs
    :param embedding_dim: embedding dimension
    :param train_data_path: path to the training data file
    :param valid_data_path: path to the validation data file
    :param max_sentence_length: maximum sentence length
    :param max_train_steps: maximum number of training samples, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")
    # learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.98, name="optimizer")

    train_dataset = load_pair_dataset(data_path=train_data_path, max_sentence_length=max_sentence_length,
                                      batch_size=batch_size, buffer_size=buffer_size, data_type="train")
    valid_dataset = load_pair_dataset(data_path=valid_data_path, max_sentence_length=max_sentence_length,
                                      batch_size=batch_size, buffer_size=buffer_size, data_type="valid")

    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (206588 // batch_size)
    valid_steps_per_epoch = 10000 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch, (first_queries, second_queries, labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                first_queries=first_queries, second_queries=second_queries, labels=labels
            )
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation data is too small (smaller than batch_size); skipping validation")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset, progress_bar=progress_bar,
                                            loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
def train_model(batcher, model: Seq2Seq, ckpt_manager: tf.train.CheckpointManager) -> float:
    # Training parameters
    epochs = config.epochs
    batch_size = config.batch_size
    learning_rate = config.lr
    max_enc_len = config.max_enc_steps
    max_dec_len = config.max_dec_steps

    # Optimizer
    optimizer = tf.keras.optimizers.Adam(name='Adam', learning_rate=learning_rate)

    # Loss function
    # loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    # Define the loss function
    # @tf.function
    # def loss_func(real: tf.Tensor, pred: tf.Tensor) -> float:
    #     # Compute the loss
    #     loss = loss_object(real, pred)
    #     # Ignore pad and unk
    #     pad_mask = tf.math.equal(real, pad_idx)
    #     # unk_mask = tf.math.equal(real, unk_idx)
    #     # mask = tf.math.logical_not(tf.math.logical_or(pad_mask, unk_mask))
    #     # The unk mask removes UNK tokens from the predictions, so it is dropped and only the pad mask is kept
    #     mask = tf.math.logical_not(pad_mask)
    #     mask = tf.cast(mask, dtype=loss.dtype)
    #     loss *= mask
    #     # Return the averaged loss
    #     return tf.reduce_mean(loss)

    # Training step
    # @tf.function
    def train_step(batcher) -> float:
        with tf.GradientTape() as tape:
            # Fetch the training inputs
            enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
                get_input_from_batch(batcher)
            dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
                get_output_from_batch(batcher)
            # # First decoder input = start-of-sentence token
            # dec_input = tf.expand_dims([bos_idx] * dec_target.shape[0], 1)

            # Compute the outputs
            step_losses, coverage = model.teacher_forcing_decode(
                enc_batch, dec_batch, enc_padding_mask, dec_padding_mask,
                extra_zeros, enc_batch_extend_vocab, coverage)

            # Compute the batch loss (start-of-sentence tokens are not counted)
            sum_losses = tf.reduce_sum(tf.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / dec_lens_var
            batch_loss = tf.reduce_mean(batch_avg_loss)

        # Trainable variables
        variables = model.trainable_variables
        # Compute the gradients
        gradients = tape.gradient(batch_loss, variables)
        # Apply the gradients
        optimizer.apply_gradients(zip(gradients, variables))
        # Return the batch loss
        return batch_loss

    # Run training
    epoch_loss = -1
    train_loss = []
    eval_loss = []
    for epoch in tqdm(range(epochs)):
        start = time.time()
        total_loss = 0.
        # Iterate over all batches
        with tqdm(range(config.step_per_epoch), unit='batch', desc='training progress') as pbar:
            for step in pbar:
                batch = batcher.next_batch()
                batch_loss = train_step(batch)
                total_loss += batch_loss
                pbar.set_postfix({
                    'size': '{:d}'.format(batch_size),
                    'loss': '{:.6f}'.format(batch_loss),
                    'average': '{:.6f}'.format(total_loss / (step + 1))
                })
                pbar.update(1)

        # Save the model periodically: every epoch
        path_ckpt_save = ckpt_manager.save()
        # Compute the average loss
        epoch_loss = total_loss / config.step_per_epoch
        logger.info('Epoch {:3d}: new checkpoint saved at {}'.format(epoch + 1, path_ckpt_save))
        logger.info('Epoch {:3d}: training took {:.2f} minutes, Loss = {:.6f}'.format(
            epoch + 1, (time.time() - start) / 60., epoch_loss))

        print('Epoch:', epoch + 1, '\ntrain_loss:', epoch_loss.numpy())
        train_loss.append(epoch_loss.numpy())
        eval_epoch_loss = evaluate()
        eval_loss.append(eval_epoch_loss.numpy())

    print('train_loss:', train_loss)
    print('eval_loss:', eval_loss)
    # Return the loss
    return epoch_loss
def train(epochs: int, train_data_path: str, batch_size: int, buffer_size: int,
          checkpoint_save_freq: int, checkpoint: tf.train.CheckpointManager,
          model: tf.keras.Model, optimizer: tf.keras.optimizers.Adam, dict_path: str = "",
          valid_data_split: float = 0.0, valid_data_path: str = "",
          train_length_path: str = "", valid_length_path: str = "",
          max_train_data_size: int = 0, max_valid_data_size: int = 0,
          history_img_path: str = ""):
    """Training module.

    :param epochs: number of training epochs
    :param train_data_path: path to the text data
    :param dict_path: dictionary path, not needed when using phoneme
    :param buffer_size: Dataset shuffle buffer size
    :param batch_size: Dataset batch size
    :param checkpoint: checkpoint manager
    :param model: model
    :param optimizer: optimizer
    :param valid_data_split: fraction of the training data used for validation
    :param valid_data_path: path to the validation text data
    :param max_train_data_size: maximum amount of training data
    :param train_length_path: path where training sample lengths are saved
    :param valid_length_path: path where validation sample lengths are saved
    :param max_valid_data_size: maximum amount of validation data
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :param history_img_path: path where the metric history plot is saved
    :return: the history of metrics
    """
    train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        load_data(train_data_path=train_data_path, batch_size=batch_size, buffer_size=buffer_size,
                  valid_data_split=valid_data_split, valid_data_path=valid_data_path,
                  train_length_path=train_length_path, valid_length_path=valid_length_path,
                  max_train_data_size=max_train_data_size, max_valid_data_size=max_valid_data_size)
    tokenizer = load_tokenizer(dict_path=dict_path)

    history = {"loss": [], "wers": [], "norm_lers": []}

    if steps_per_epoch == 0:
        print("The training data is too small (smaller than batch_size); add more data and try again")
        exit(0)

    for epoch in range(epochs):
        total_loss = 0
        start_time = time.time()
        enc_hidden = model.initialize_hidden_state()
        dec_input = tf.cast(tf.expand_dims([tokenizer.word_index.get('<start>')] * batch_size, 1),
                            dtype=tf.int64)
        print("Epoch {}/{}".format(epoch + 1, epochs))

        for (batch, (audio_feature, sentence, length)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()
            batch_loss = _train_step(model, optimizer, audio_feature, sentence, enc_hidden, dec_input)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)), end="")

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch, total_loss / steps_per_epoch))

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0:
                print("The validation data is too small (smaller than batch_size); add more data and try again")
                exit(0)

            valid_loss, valid_wer, valid_ler = _valid_step(
                model=model, dataset=valid_dataset, enc_hidden=enc_hidden, dec_input=dec_input,
                steps_per_epoch=valid_steps_per_epoch, tokenizer=tokenizer)
            history["wers"].append(valid_wer)
            history["norm_lers"].append(valid_ler)

    plot_history(history=history, valid_epoch_freq=checkpoint_save_freq,
                 history_img_path=history_img_path)
    return history
def train(model: tf.keras.Model, checkpoint: tf.train.CheckpointManager, batch_size: Any,
          epochs: Any, train_dataset: Any, valid_dataset: Any = None,
          max_train_steps: Any = -1, checkpoint_save_freq: Any = 2, *args, **kwargs) -> Dict:
    """Trainer.

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param epochs: number of training epochs
    :param train_dataset: training dataset
    :param valid_dataset: validation dataset
    :param max_train_steps: maximum number of training samples, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    :return:
    """
    print("Training started, preparing data")
    # learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=2e-5, beta_1=0.9, beta_2=0.999, name="optimizer")

    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (40000 // batch_size)
    valid_steps_per_epoch = 3944 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch, (train_enc, train_dec, month_enc, month_dec, labels)) in enumerate(
                train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(model=model, optimizer=optimizer,
                                                   loss_metric=loss_metric, train_enc=train_enc,
                                                   train_dec=train_dec, month_enc=month_enc,
                                                   month_dec=month_dec, labels=labels)
            progress_bar(current=batch + 1, metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total, metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation data is too small (smaller than batch_size); skipping validation")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset, progress_bar=progress_bar,
                                            loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
def fitTrainData(model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer,
                 metrics: List[tf.keras.metrics.Mean], lossFunc, PSNRFunc,
                 X: np.ma.array, y: np.ma.array,
                 batchSize: int, epochs: int, bufferSize: int,
                 valData: List[np.ma.array], valSteps: int,
                 checkpoint: tf.train.Checkpoint, checkpointManager: tf.train.CheckpointManager,
                 logDir: str, ckptDir: str, saveBestOnly: bool):
    trainSet = loadTrainDataAsTFDataSet(X, y[0], y[1], epochs, batchSize, bufferSize)
    valSet = loadValDataAsTFDataSet(valData[0], valData[1], valData[2], valSteps, batchSize, bufferSize)

    # Logger
    w = tf.summary.create_file_writer(logDir)

    dataSetLength = len(X)
    totalSteps = tf.cast(dataSetLength / batchSize, tf.int64)
    globalStep = tf.cast(checkpoint.step, tf.int64)
    step = globalStep % totalSteps
    epoch = 0

    # Metrics
    trainLoss, trainPSNR, testLoss, testPSNR = metrics

    with w.as_default():
        for x_batch_train, y_batch_train, y_mask_batch_train in trainSet:
            if (totalSteps - step) == 0:
                epoch += 1
                step = globalStep % totalSteps
                logger.info('Start of epoch %d' % (epoch))
                # Reset metrics
                trainLoss.reset_states()
                trainPSNR.reset_states()
                testLoss.reset_states()
                testPSNR.reset_states()

            step += 1
            globalStep += 1
            trainStep(x_batch_train, y_batch_train, y_mask_batch_train,
                      checkpoint, lossFunc, PSNRFunc, trainLoss, trainPSNR)
            checkpoint.step.assign_add(1)

            t = f"step {step}/{int(totalSteps)}, loss: {trainLoss.result():.3f}, psnr: {trainPSNR.result():.3f}"
            logger.info(t)
            tf.summary.scalar('Train PSNR', trainPSNR.result(), step=globalStep)
            tf.summary.scalar('Train loss', trainLoss.result(), step=globalStep)

            if step != 0 and (step % opt.evalTestStep) == 0:
                # Reset states for test
                testLoss.reset_states()
                testPSNR.reset_states()
                for x_batch_val, y_batch_val, y_mask_batch_val in valSet:
                    testStep(x_batch_val, y_batch_val, y_mask_batch_val,
                             checkpoint, lossFunc, PSNRFunc, testLoss, testPSNR)
                tf.summary.scalar('Test loss', testLoss.result(), step=globalStep)
                tf.summary.scalar('Test PSNR', testPSNR.result(), step=globalStep)
                t = f"Validation results... val_loss: {testLoss.result():.3f}, val_psnr: {testPSNR.result():.3f}"
                logger.info(t)
                w.flush()

                if saveBestOnly and (testPSNR.result() <= checkpoint.psnr):
                    continue

                checkpoint.psnr = testPSNR.result()
                checkpointManager.save()
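# A minimal sketch (assumed, not from the original code) of the checkpoint
# object fitTrainData expects: it reads and writes `checkpoint.step` and
# `checkpoint.psnr`, so both are tracked next to the model and optimizer.
# `model`, `optimizer` and `ckptDir` are placeholders.
checkpoint = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64),
                                 psnr=tf.Variable(0.0),
                                 model=model,
                                 optimizer=optimizer)
checkpointManager = tf.train.CheckpointManager(checkpoint, directory=ckptDir, max_to_keep=3)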
def train(epochs: int, train_data_path: str, max_len: int, vocab_size: int, batch_size: int,
          buffer_size: int, checkpoint_save_freq: int, checkpoint: tf.train.CheckpointManager,
          model: tf.keras.Model, optimizer: tf.keras.optimizers.Adam,
          tokenized_type: str = "phoneme", dict_path: str = "", valid_data_split: float = 0.0,
          valid_data_path: str = "", max_train_data_size: int = 0, max_valid_data_size: int = 0):
    """Training module.

    :param epochs: number of training epochs
    :param train_data_path: path to the text data
    :param max_len: maximum text sequence length
    :param vocab_size: vocabulary size
    :param tokenized_type: tokenization type, phoneme by default; one of phoneme/word/char
    :param dict_path: dictionary path, not needed when using phoneme
    :param buffer_size: Dataset shuffle buffer size
    :param batch_size: Dataset batch size
    :param checkpoint: checkpoint manager
    :param model: model
    :param optimizer: optimizer
    :param valid_data_split: fraction of the training data used for validation
    :param valid_data_path: path to the validation text data
    :param max_train_data_size: maximum amount of training data
    :param max_valid_data_size: maximum amount of validation data
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    """
    train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        _dataset.load_data(train_data_path=train_data_path, max_len=max_len, vocab_size=vocab_size,
                           batch_size=batch_size, buffer_size=buffer_size, tokenized_type=tokenized_type,
                           dict_path=dict_path, valid_data_split=valid_data_split,
                           valid_data_path=valid_data_path, max_train_data_size=max_train_data_size,
                           max_valid_data_size=max_valid_data_size)

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        start_time = time.time()
        total_loss = 0

        for (batch, (mel, stop_token, sentence)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()
            # Train one batch and return the batch loss
            batch_loss, mel_outputs = _train_step(model, optimizer, sentence, mel, stop_token)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)), end='')

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch, total_loss / steps_per_epoch))

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            _valid_step(model=model, dataset=valid_dataset, steps_per_epoch=valid_steps_per_epoch)

    return mel_outputs
def train(epochs: int, train_data_path: str, max_len: int, vocab_size: int, batch_size: int,
          buffer_size: int, checkpoint_save_freq: int, checkpoint: tf.train.CheckpointManager,
          model: tf.keras.Model, optimizer: tf.keras.optimizers.Adam, dict_path: str = "",
          valid_data_split: float = 0.0, valid_data_path: str = "",
          max_train_data_size: int = 0, max_valid_data_size: int = 0):
    """Training module.

    :param epochs: number of training epochs
    :param train_data_path: path to the text data
    :param max_len: maximum text sequence length
    :param vocab_size: vocabulary size
    :param dict_path: dictionary path, not needed when using phoneme
    :param buffer_size: Dataset shuffle buffer size
    :param batch_size: Dataset batch size
    :param checkpoint: checkpoint manager
    :param model: model
    :param optimizer: optimizer
    :param valid_data_split: fraction of the training data used for validation
    :param valid_data_path: path to the validation text data
    :param max_train_data_size: maximum amount of training data
    :param max_valid_data_size: maximum amount of validation data
    :param checkpoint_save_freq: checkpoint save frequency (in epochs)
    """
    tokenizer, train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        load_data(train_data_path=train_data_path, max_len=max_len, vocab_size=vocab_size,
                  batch_size=batch_size, buffer_size=buffer_size, dict_path=dict_path,
                  valid_data_split=valid_data_split, valid_data_path=valid_data_path,
                  max_train_data_size=max_train_data_size, max_valid_data_size=max_valid_data_size)

    for epoch in range(epochs):
        start = time.time()
        enc_hidden = model.initialize_hidden_state()
        total_loss = 0
        batch_start = time.time()
        print("Epoch {}/{}".format(epoch + 1, epochs))

        for (batch, (audio_feature, sentence)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_loss = _train_step(audio_feature, sentence, enc_hidden, tokenizer,
                                     model, optimizer, batch_size)
            total_loss += batch_loss

            print('Epoch {} Batch {} Loss {:.4f} - {:.4f} sec'.format(
                epoch + 1, batch, batch_loss.numpy(), time.time() - batch_start))
            batch_start = time.time()

        print('Epoch {} Loss {:.4f} - {:.4f} sec'.format(
            epoch + 1, total_loss / steps_per_epoch, time.time() - start))

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()