Example #1
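    # Excerpt from an agent-training class; it assumes numpy imported as np and
    # Keras optimizers imported as ko (e.g. from tensorflow.keras import optimizers as ko).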
    def train(self, env, config, batch_size=128, updates=500, max_seconds=30):
        models = config.get()
        for model in models:
            model.compile(optimizer=ko.RMSprop(learning_rate=self.lr),
                          loss=[self._logits_loss, self._value_loss])
        # Storage helpers for a single batch of data.
        actions = np.empty((batch_size, config.num), dtype=np.int32)
        rewards, dones, values = np.empty((3, batch_size, config.num))
        observations = np.empty(
            (batch_size, config.window_size, env.observations_size))

        # Training loop: collect samples, send to optimizer, repeat updates times.
        deaths = {}
        for model in models:
            deaths[model.label] = 0
        obs_window = env.reset()
        episodes = []
        steps = 0
        pb = ProgressBar(f'{config.label}')
        total_progress = updates * batch_size
        progress = 0
        pb.reset()
        for _ in range(updates):
            for step in range(batch_size):
                steps += 1
                progress += 1
                observations[step] = obs_window
                for m_i, model in enumerate(models):
                    actions[step, m_i], values[step, m_i] = \
                        model.action_value(obs_window)
                obs_window, rewards[step], dones[step] = env.step(
                    actions[step])
                if any(dones[step]) or max_seconds < steps * env.dt:
                    obs_window = env.reset()
                    episodes.append(steps * env.dt)
                    steps = 0
                    for dead, model in zip(dones[step], models):
                        if dead:
                            deaths[model.label] += 1
            losses = []
            for m_i, model in enumerate(models):
                _, next_value = model.action_value(obs_window)
                returns, advs = self._returns_advantages(
                    rewards[:, m_i], dones[:, m_i], values[:, m_i], next_value)
                # A trick to input actions and advantages through same API.
                acts_and_advs = np.concatenate(
                    [actions[:, m_i, None], advs[:, None]], axis=-1)
                loss = model.train_on_batch(
                    observations[:, -model.input_size:, :],
                    [acts_and_advs, returns])
                losses.append(loss[0])
            pb(progress / total_progress,
               f' loss: {sum(losses)/len(losses):6.3f}')
        return episodes, deaths
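
The `acts_and_advs` trick above only works if the custom `_logits_loss` splits the packed tensor back into actions and advantages inside the loss function. Below is a minimal sketch of such a loss, assuming the standard advantage actor-critic weighting (sparse cross-entropy of the taken actions, weighted by their advantages); it is an illustration of the pattern, not necessarily this project's exact implementation. In the original class it is referenced as a bound method (`self._logits_loss`).

import tensorflow as tf

def _logits_loss(acts_and_advs, logits):
    # Undo the concatenation done in train(): column 0 holds the sampled
    # actions, column 1 holds the computed advantages.
    actions, advantages = tf.split(acts_and_advs, 2, axis=-1)
    actions = tf.cast(actions, tf.int32)
    # Policy-gradient loss: cross-entropy of the taken actions, weighted by
    # their advantages.
    ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return ce(actions, logits, sample_weight=advantages)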
Example #2
import time
from typing import Any, Dict

import tensorflow as tf

# ProgressBar, get_dict_string, _train_step and _valid_step are project-local
# helpers assumed to be defined elsewhere in the same module.


def train(model: tf.keras.Model,
          checkpoint: tf.train.CheckpointManager,
          batch_size: int,
          epochs: int,
          train_dataset: Any,
          valid_dataset: Any = None,
          max_train_steps: int = -1,
          checkpoint_save_freq: int = 2,
          *args,
          **kwargs) -> Dict:
    """ 训练器

    :param model: 训练模型
    :param checkpoint: 检查点管理器
    :param batch_size: batch 大小
    :param epochs: 训练周期
    :param train_dataset: 训练数据集
    :param valid_dataset: 验证数据集
    :param max_train_steps: 最大训练数据量,-1为全部
    :param checkpoint_save_freq: 检查点保存频率
    :return:
    """
    print("训练开始,正在准备数据中")
    # learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=2e-5,
                                   beta_1=0.9,
                                   beta_2=0.999,
                                   name="optimizer")

    # 40000 and 3944 are hard-coded dataset sizes (training / validation samples).
    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (
        40000 // batch_size)
    valid_steps_per_epoch = 3944 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch,
             (train_enc, train_dec, month_enc, month_dec,
              labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(model=model,
                                                   optimizer=optimizer,
                                                   loss_metric=loss_metric,
                                                   train_enc=train_enc,
                                                   train_dec=train_dec,
                                                   month_enc=month_enc,
                                                   month_dec=month_dec,
                                                   labels=labels)

            progress_bar(current=batch + 1,
                         metrics=get_dict_string(data=train_metric))
        progress_bar(current=progress_bar.total,
                     metrics=get_dict_string(data=train_metric))

        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("验证数据量过小,小于batch_size,已跳过验证轮次")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model,
                                            dataset=valid_dataset,
                                            progress_bar=progress_bar,
                                            loss_metric=loss_metric,
                                            **kwargs)
    print("训练结束")
    return {}
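
The batch loop above delegates the actual gradient update to the project-local `_train_step` helper. As a rough sketch of what a compatible step could look like, assuming a model that takes the four tensors as one input list and a mean-squared-error objective (both are assumptions, not this project's actual code):

@tf.function
def _train_step(model, optimizer, loss_metric, train_enc, train_dec,
                month_enc, month_dec, labels):
    with tf.GradientTape() as tape:
        # Forward pass; the real model's call signature is assumed here.
        prediction = model([train_enc, train_dec, month_enc, month_dec],
                           training=True)
        loss = tf.keras.losses.MeanSquaredError()(labels, prediction)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    loss_metric(loss)  # feed the shared Mean metric used by train()
    # Return a dict so train() can format it with get_dict_string().
    return {"train_loss": loss_metric.result()}, prediction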