Example #1
    def save_checkpoint(self,
                        checkpoint_dir: str = "./training_checkpoint",
                        checkpoint: tf.train.Checkpoint = None) -> None:

        checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
        if checkpoint is None:
            checkpoint = tf.train.Checkpoint(
                generator=self.generator,
                discriminator=self.discriminator,
                discriminator_optimizer=self.discriminator_optimizer,
                generator_optimizer=self.generator_optimizer)

        checkpoint.save(checkpoint_prefix)
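A hedged sketch of the matching restore step; restore_checkpoint itself is hypothetical, but it rebuilds the same tracked-object graph as save_checkpoint above and uses the standard tf.train.latest_checkpoint lookup:

    def restore_checkpoint(self,
                           checkpoint_dir: str = "./training_checkpoint") -> None:
        # Hypothetical counterpart to save_checkpoint: rebuild the same object
        # graph, then restore from the most recent save in checkpoint_dir.
        checkpoint = tf.train.Checkpoint(
            generator=self.generator,
            discriminator=self.discriminator,
            discriminator_optimizer=self.discriminator_optimizer,
            generator_optimizer=self.generator_optimizer)
        latest = tf.train.latest_checkpoint(checkpoint_dir)
        if latest is not None:
            checkpoint.restore(latest)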
Example #2
def _save_checkpoint(checkpoint: tf.train.Checkpoint, checkpoint_prefix: str) -> None:
    checkpoint.save(file_prefix=checkpoint_prefix)
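A minimal usage sketch for the helper above; the toy model, optimizer, and directory are placeholders, not from the original source:

import os
import tensorflow as tf

# Placeholder objects standing in for the real model and optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

os.makedirs("./checkpoints", exist_ok=True)
_save_checkpoint(checkpoint, os.path.join("./checkpoints", "ckpt"))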
Example #3
    def train(self,
              checkpoint: tf.train.Checkpoint,
              dict_fn: str,
              data_fn: str,
              batch_size: int,
              buffer_size: int,
              max_train_data_size: int,
              epochs: int,
              max_valid_data_size: int,
              checkpoint_save_freq: int,
              checkpoint_save_size: int,
              save_dir: str,
              valid_data_split: float = 0.0,
              valid_data_fn: str = "",
              valid_freq: int = 1):
        """
        对模型进行训练,验证数据集优先级为:预设验证文本>训练划分文本>无验证
        Args:
            checkpoint: 模型的检查点
            dict_fn: 字典路径
            data_fn: 数据文本路径
            buffer_size: Dataset加载缓存大小
            batch_size: Dataset加载批大小
            max_train_data_size: 最大训练数据量
            epochs: 执行训练轮数
            checkpoint_save_freq: 检查点保存频率
            checkpoint_save_size: 检查点最大保存数
            save_dir: 历史指标显示图片保存位置
            max_valid_data_size: 最大验证数据量
            valid_data_split: 用于从训练数据中划分验证数据,默认0.1
            valid_data_fn: 验证数据文本路径
            valid_freq: 验证频率
        Returns:
        """
        print('Training started, preparing data...')
        train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch, checkpoint_prefix = \
            data_utils.load_data(dict_fn=dict_fn, data_fn=data_fn, start_sign=self.start_sign,
                                 buffer_size=buffer_size, batch_size=batch_size,
                                 end_sign=self.end_sign, checkpoint_dir=self.checkpoint_dir,
                                 max_length=self.max_length, valid_data_split=valid_data_split,
                                 valid_data_fn=valid_data_fn, max_train_data_size=max_train_data_size,
                                 max_valid_data_size=max_valid_data_size)

        valid_epochs_count = 0  # counts epochs for validation/checkpoint scheduling
        # Names of checkpoints produced by this run; one spare slot so the
        # oldest entry can be deleted from disk before it drops off the queue.
        checkpoint_queue = deque(maxlen=checkpoint_save_size + 1)
        history = {
            'accuracy': [],
            'loss': [],
            'val_accuracy': [],
            'val_loss': []
        }

        for epoch in range(epochs):
            valid_epochs_count += 1
            print('Epoch {}/{}'.format(epoch + 1, epochs))
            start_time = time.time()
            self._init_loss_accuracy()

            step_loss = 0
            step_accuracy = 0
            batch_sum = 0
            sample_sum = 0

            for (batch,
                 (inp, tar,
                  weight)) in enumerate(train_dataset.take(steps_per_epoch)):
                step_loss, step_accuracy = self._train_step(inp, tar, weight)
                batch_sum = batch_sum + len(inp)  # samples processed so far
                sample_sum = steps_per_epoch * len(inp)  # samples per epoch
                print('\r',
                      '{}/{} [==================================]'.format(
                          batch_sum, sample_sum),
                      end='',
                      flush=True)

            epoch_time = (time.time() - start_time)
            history['accuracy'].append(step_accuracy.numpy())
            history['loss'].append(step_loss.numpy())

            sys.stdout.write(
                ' - {:.4f}s/epoch - train_loss: {:.4f} - train_accuracy: {:.4f}\n'
                .format(epoch_time, step_loss, step_accuracy))
            sys.stdout.flush()

            if valid_epochs_count % checkpoint_save_freq == 0:
                checkpoint.save(file_prefix=checkpoint_prefix)
                checkpoint_queue.append(
                    tf.train.latest_checkpoint(
                        checkpoint_dir=self.checkpoint_dir))
                if len(checkpoint_queue) == checkpoint_queue.maxlen:
                    # Queue full: drop the oldest checkpoint and delete its files.
                    checkpoint_name = checkpoint_queue.popleft()
                    os.remove(checkpoint_name + '.index')
                    os.remove(checkpoint_name + '.data-00000-of-00001')

            if valid_dataset is not None and valid_epochs_count % valid_freq == 0:
                valid_loss, valid_accuracy = self._valid_step(
                    valid_dataset=valid_dataset,
                    steps_per_epoch=valid_steps_per_epoch)
                history['val_accuracy'].append(valid_accuracy.numpy())
                history['val_loss'].append(valid_loss.numpy())

        self._show_history(history=history,
                           save_dir=save_dir,
                           valid_freq=valid_freq)
        print('Training finished')
        return history
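The deque bookkeeping above reimplements retention that tf.train.CheckpointManager provides out of the box; a sketch of the equivalent in-method setup, reusing the names from this example:

        # Equivalent retention via CheckpointManager: checkpoints beyond
        # max_to_keep are deleted automatically, replacing the deque logic.
        manager = tf.train.CheckpointManager(checkpoint,
                                             directory=self.checkpoint_dir,
                                             max_to_keep=checkpoint_save_size)
        if valid_epochs_count % checkpoint_save_freq == 0:
            manager.save()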
Example #4
def save_checkpoint(ckpt: tf.train.Checkpoint, ckpt_path: str):
    dir_name = os.path.dirname(ckpt_path)
    os.makedirs(dir_name, exist_ok=True)
    ckpt.save(ckpt_path)
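Note that ckpt.save() treats ckpt_path as a prefix and appends an incrementing counter (ckpt_path-1, ckpt_path-2, ...). When the exact path must be written verbatim, tf.train.Checkpoint.write is the lower-level alternative; a sketch under the same names, with write_checkpoint being hypothetical:

def write_checkpoint(ckpt: tf.train.Checkpoint, ckpt_path: str):
    # Unlike ckpt.save(), ckpt.write() uses ckpt_path verbatim as the file
    # prefix and maintains no save counter or "checkpoint" metadata file,
    # so tf.train.latest_checkpoint will not pick these saves up.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    ckpt.write(ckpt_path)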
Example #5
def fitTrainData(model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer,
                 metrics: List[tf.keras.metrics.Mean], lossFunc, PSNRFunc,
                 X: np.ma.array, y: np.ma.array, batchSize: int, epochs: int,
                 bufferSize: int, valData: List[np.ma.array], valSteps: int,
                 checkpoint: tf.train.Checkpoint,
                 checkpointManager: tf.train.CheckpointManager, logDir: str,
                 ckptDir: str, saveBestOnly: bool):

    trainSet = loadTrainDataAsTFDataSet(X, y[0], y[1], epochs, batchSize,
                                        bufferSize)
    valSet = loadValDataAsTFDataSet(valData[0], valData[1], valData[2],
                                    valSteps, batchSize, bufferSize)

    # TensorBoard summary writer
    w = tf.summary.create_file_writer(logDir)

    dataSetLength = len(X)
    totalSteps = tf.cast(dataSetLength / batchSize, tf.int64)
    globalStep = tf.cast(checkpoint.step, tf.int64)
    step = globalStep % totalSteps
    epoch = 0

    # Metrics
    trainLoss, trainPSNR, testLoss, testPSNR = metrics

    with w.as_default():
        for x_batch_train, y_batch_train, y_mask_batch_train in trainSet:
            if (totalSteps - step) == 0:
                epoch += 1
                step = globalStep % totalSteps
                logger.info('Start of epoch %d', epoch)
                # Reset metrics
                trainLoss.reset_states()
                trainPSNR.reset_states()
                testLoss.reset_states()
                testPSNR.reset_states()

            step += 1
            globalStep += 1
            trainStep(x_batch_train, y_batch_train, y_mask_batch_train,
                      checkpoint, lossFunc, PSNRFunc, trainLoss, trainPSNR)
            checkpoint.step.assign_add(1)

            t = f"step {step}/{int(totalSteps)}, loss: {trainLoss.result():.3f}, psnr: {trainPSNR.result():.3f}"
            logger.info(t)

            tf.summary.scalar('Train PSNR',
                              trainPSNR.result(),
                              step=globalStep)

            tf.summary.scalar('Train loss',
                              trainLoss.result(),
                              step=globalStep)

            # `opt` is assumed to be a module-level config (e.g. parsed CLI flags).
            if step != 0 and (step % opt.evalTestStep) == 0:
                # Reset states for test
                testLoss.reset_states()
                testPSNR.reset_states()
                for x_batch_val, y_batch_val, y_mask_batch_val in valSet:
                    testStep(x_batch_val, y_batch_val, y_mask_batch_val,
                             checkpoint, lossFunc, PSNRFunc, testLoss,
                             testPSNR)
                tf.summary.scalar('Test loss',
                                  testLoss.result(),
                                  step=globalStep)
                tf.summary.scalar('Test PSNR',
                                  testPSNR.result(),
                                  step=globalStep)
                t = f"Validation results... val_loss: {testLoss.result():.3f}, val_psnr: {testPSNR.result():.3f}"
                logger.info(t)
                w.flush()

                if saveBestOnly and (testPSNR.result() <= checkpoint.psnr):
                    continue

                checkpoint.psnr = testPSNR.result()
                checkpointManager.save()
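fitTrainData reads checkpoint.step and checkpoint.psnr on entry, so the caller is expected to build and restore these objects beforehand; a hedged setup sketch in which the model, optimizer, and directory are placeholders:

import tensorflow as tf

# Placeholder network and optimizer standing in for the real ones.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64),
                                 psnr=tf.Variable(0.0),
                                 model=model,
                                 optimizer=optimizer)
checkpointManager = tf.train.CheckpointManager(checkpoint, './ckpts',
                                               max_to_keep=3)
# Resume from the most recent checkpoint, if one exists.
if checkpointManager.latest_checkpoint:
    checkpoint.restore(checkpointManager.latest_checkpoint)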