Exemple #1
0
    def save_checkpoint(self,
                        checkpoint_dir: str = "./training_checkpoint",
                        checkpoint: tf.train.Checkpoint = None) -> str:
        """
        Save a training checkpoint under ``checkpoint_dir`` using the prefix ``ckpt``.

        Args:
            checkpoint_dir: directory the checkpoint is written into; the file
                prefix passed to ``save`` is ``<checkpoint_dir>/ckpt``.
            checkpoint: checkpoint object to save. When omitted (``None``), a new
                ``tf.train.Checkpoint`` tracking ``self.generator``,
                ``self.discriminator`` and both optimizers is built.
        Returns:
            The value returned by ``checkpoint.save`` (the path prefix of the
            checkpoint that was written). Callers that ignore it are unaffected.
        """
        checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
        # Test for None explicitly instead of relying on truthiness: only a
        # missing argument should trigger building the default checkpoint.
        if checkpoint is None:
            # NOTE(review): a freshly built Checkpoint starts its own save
            # counter, so repeated calls without ``checkpoint`` presumably write
            # the same numbered prefix each time — pass a long-lived checkpoint
            # if distinct numbered saves are wanted.
            checkpoint = tf.train.Checkpoint(
                generator=self.generator,
                discriminator=self.discriminator,
                discriminator_optimizer=self.discriminator_optimizer,
                generator_optimizer=self.generator_optimizer)

        return checkpoint.save(checkpoint_prefix)
Exemple #2
0
def _save_checkpoint(checkpoint: tf.train.Checkpoint, checkpoint_prefix) -> None:
    """
    Persist ``checkpoint`` using ``checkpoint_prefix`` as its file prefix.

    Thin wrapper around ``checkpoint.save``; its return value is discarded.
    """
    write = checkpoint.save
    write(file_prefix=checkpoint_prefix)
Exemple #3
0
    def train(self,
              checkpoint: tf.train.Checkpoint,
              dict_fn: str,
              data_fn: str,
              batch_size: int,
              buffer_size: int,
              max_train_data_size: int,
              epochs: int,
              max_valid_data_size: int,
              checkpoint_save_freq: int,
              checkpoint_save_size: int,
              save_dir: str,
              valid_data_split: float = 0.0,
              valid_data_fn: str = "",
              valid_freq: int = 1):
        """
        Train the model, optionally validating along the way.

        The validation source is chosen by priority: a dedicated validation
        text > a split of the training text > no validation (the selection
        itself is delegated to ``data_utils.load_data``).

        Args:
            checkpoint: checkpoint object that is saved periodically
            dict_fn: path to the dictionary file
            data_fn: path to the training text
            batch_size: Dataset loading batch size
            buffer_size: Dataset loading buffer size
            max_train_data_size: maximum amount of training data
            epochs: number of training epochs to run
            max_valid_data_size: maximum amount of validation data
            checkpoint_save_freq: save a checkpoint every this many epochs
            checkpoint_save_size: maximum number of checkpoints produced by this
                run that are kept on disk; older ones are deleted. A value
                <= 0 keeps every checkpoint.
            save_dir: where the training-history plot is saved
            valid_data_split: fraction of the training data split off for
                validation (default 0.0)
            valid_data_fn: path to a dedicated validation text
            valid_freq: run validation every this many epochs
        Returns:
            dict with keys 'accuracy', 'loss', 'val_accuracy', 'val_loss'.
            Training lists get one entry per epoch; validation lists only on
            epochs where validation ran.
        """
        import glob  # used only to remove the files of rotated-out checkpoints

        print('训练开始,正在准备数据中...')
        train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch, checkpoint_prefix = \
            data_utils.load_data(dict_fn=dict_fn, data_fn=data_fn, start_sign=self.start_sign,
                                 buffer_size=buffer_size, batch_size=batch_size,
                                 end_sign=self.end_sign, checkpoint_dir=self.checkpoint_dir,
                                 max_length=self.max_length, valid_data_split=valid_data_split,
                                 valid_data_fn=valid_data_fn, max_train_data_size=max_train_data_size,
                                 max_valid_data_size=max_valid_data_size)

        valid_epochs_count = 0  # 1-based count of epochs run, drives save/validation frequency
        # Prefixes of the checkpoints written by this run, oldest first.
        checkpoint_queue = deque()
        history = {
            'accuracy': [],
            'loss': [],
            'val_accuracy': [],
            'val_loss': []
        }

        for epoch in range(epochs):
            valid_epochs_count += 1
            print('Epoch {}/{}'.format(epoch + 1, epochs))
            start_time = time.time()
            self._init_loss_accuracy()

            step_loss = 0
            step_accuracy = 0
            steps_run = 0
            batch_sum = 0
            sample_sum = 0

            for (batch,
                 (inp, tar,
                  weight)) in enumerate(train_dataset.take(steps_per_epoch)):
                step_loss, step_accuracy = self._train_step(inp, tar, weight)
                steps_run = batch + 1
                batch_sum = batch_sum + len(inp)
                sample_sum = steps_per_epoch * len(inp)
                print('\r',
                      '{}/{} [==================================]'.format(
                          batch_sum, sample_sum),
                      end='',
                      flush=True)

            # The log line labels this value "s/step", so report the mean time
            # per training step rather than the whole epoch's wall-clock time.
            step_time = (time.time() - start_time) / max(steps_run, 1)
            history['accuracy'].append(step_accuracy.numpy())
            history['loss'].append(step_loss.numpy())

            sys.stdout.write(
                ' - {:.4f}s/step - train_loss: {:.4f} - train_accuracy: {:.4f}\n'
                .format(step_time, step_loss, step_accuracy))
            sys.stdout.flush()

            if valid_epochs_count % checkpoint_save_freq == 0:
                # save() returns the prefix of the checkpoint it just wrote, so
                # there is no need to re-read the directory's state file.
                saved_prefix = checkpoint.save(file_prefix=checkpoint_prefix)
                if checkpoint_save_size > 0:
                    checkpoint_queue.append(saved_prefix)
                    # Rotate: keep at most checkpoint_save_size checkpoints from
                    # this run, deleting the oldest ones first.
                    while len(checkpoint_queue) > checkpoint_save_size:
                        stale_prefix = checkpoint_queue.popleft()
                        # Removes the .index file and every .data-* shard.
                        for stale_file in glob.glob(glob.escape(stale_prefix) + '.*'):
                            os.remove(stale_file)
                # NOTE(review): deleted checkpoints stay listed in the directory's
                # ``checkpoint`` state file; tf.train.CheckpointManager(max_to_keep=...)
                # handles rotation and that bookkeeping together — consider switching.

            if valid_dataset is not None and valid_epochs_count % valid_freq == 0:
                valid_loss, valid_accuracy = self._valid_step(
                    valid_dataset=valid_dataset,
                    steps_per_epoch=valid_steps_per_epoch)
                history['val_accuracy'].append(valid_accuracy.numpy())
                history['val_loss'].append(valid_loss.numpy())

        self._show_history(history=history,
                           save_dir=save_dir,
                           valid_freq=valid_freq)
        print('训练结束')
        return history
Exemple #4
0
def save_checkpoint(ckpt: "tf.train.Checkpoint", ckpt_path: str):
    """Save ``ckpt`` with ``ckpt_path`` as its file prefix, creating parent dirs.

    Args:
        ckpt: the checkpoint object to save.
        ckpt_path: file prefix handed to ``ckpt.save``. Its parent directory is
            created first when the path has one and it does not exist yet.
    Returns:
        Whatever ``ckpt.save`` returns (for ``tf.train.Checkpoint``, the path
        prefix of the written checkpoint). Callers ignoring it are unaffected.
    """
    dir_name = os.path.dirname(ckpt_path)
    # A bare prefix such as "ckpt" has no directory component, and
    # os.makedirs("") raises FileNotFoundError — only create a real directory.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    return ckpt.save(ckpt_path)