def train(self, env, config, batch_size=128, updates=500, max_seconds=30):
    # `ko` is assumed to alias tensorflow.keras.optimizers.
    models = config.get()
    for model in models:
        model.compile(optimizer=ko.RMSprop(lr=self.lr),
                      loss=[self._logits_loss, self._value_loss])

    # Storage helpers for a single batch of data.
    actions = np.empty((batch_size, config.num), dtype=np.int32)
    rewards, dones, values = np.empty((3, batch_size, config.num))
    observations = np.empty(
        (batch_size, config.window_size, env.observations_size))

    # Training loop: collect samples, send to optimizer, repeat `updates` times.
    deaths = {model.label: 0 for model in models}
    obs_window = env.reset()
    episodes = []
    steps = 0
    pb = ProgressBar(f'{config.label}')
    total_progress = updates * batch_size
    progress = 0
    pb.reset()
    for _ in range(updates):
        for step in range(batch_size):
            steps += 1
            progress += 1
            observations[step] = obs_window
            # Each model picks an action and estimates the state value.
            for m_i, model in enumerate(models):
                actions[step, m_i], values[step, m_i] = \
                    model.action_value(obs_window)
            obs_window, rewards[step], dones[step] = env.step(actions[step])

            # Reset the episode when any agent dies or the time limit is hit.
            if any(dones[step]) or max_seconds < steps * env.dt:
                obs_window = env.reset()
                episodes.append(steps * env.dt)
                steps = 0
                for dead, model in zip(dones[step], models):
                    if dead:
                        deaths[model.label] += 1

        # One optimizer step per model on the collected batch.
        losses = []
        for m_i, model in enumerate(models):
            _, next_value = model.action_value(obs_window)
            returns, advs = self._returns_advantages(
                rewards[:, m_i], dones[:, m_i], values[:, m_i], next_value)
            # A trick to input actions and advantages through the same API.
            acts_and_advs = np.concatenate(
                [actions[:, m_i, None], advs[:, None]], axis=-1)
            loss = model.train_on_batch(
                observations[:, -model.input_size:, :],
                [acts_and_advs, returns])
            losses.append(loss[0])
        pb(progress / total_progress,
           f' loss: {sum(losses)/len(losses):6.3f}')
    return episodes, deaths
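
# The helper `self._returns_advantages` is called above but not defined in
# this section. Below is a minimal sketch of the usual n-step A2C computation
# it could perform, assumed to be a method of the same trainer class with a
# discount factor stored as `self.gamma` (the attribute name and the helper's
# exact contract are assumptions, not taken from the source):
def _returns_advantages(self, rewards, dones, values, next_value):
    # Append the bootstrap value estimate for the state after the batch.
    returns = np.append(np.zeros_like(rewards), next_value)
    # Walk the batch backwards, cutting the bootstrap at episode boundaries.
    for t in reversed(range(rewards.shape[0])):
        returns[t] = rewards[t] + self.gamma * returns[t + 1] * (1 - dones[t])
    returns = returns[:-1]
    # Advantages = empirical returns minus the critic's value baseline.
    advantages = returns - values
    return returns, advantages
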
import time
from typing import Dict

import tensorflow as tf

# ProgressBar, get_dict_string, _train_step and _valid_step are helpers from
# the surrounding module.


def train(model: tf.keras.Model,
          checkpoint: tf.train.CheckpointManager,
          batch_size: int,
          epochs: int,
          train_dataset: tf.data.Dataset,
          valid_dataset: tf.data.Dataset = None,
          max_train_steps: int = -1,
          checkpoint_save_freq: int = 2,
          *args, **kwargs) -> Dict:
    """Trainer.

    :param model: model to train
    :param checkpoint: checkpoint manager
    :param batch_size: batch size
    :param epochs: number of training epochs
    :param train_dataset: training dataset
    :param valid_dataset: validation dataset
    :param max_train_steps: maximum number of training steps, -1 for all
    :param checkpoint_save_freq: checkpoint save frequency, in epochs
    :return: training history (currently an empty dict)
    """
    print("Training started, preparing data")
    # learning_rate = CustomSchedule(d_model=embedding_dim)
    loss_metric = tf.keras.metrics.Mean(name="train_loss_metric")
    optimizer = tf.optimizers.Adam(learning_rate=2e-5, beta_1=0.9,
                                   beta_2=0.999, name="optimizer")

    train_steps_per_epoch = max_train_steps if max_train_steps != -1 else (
        40000 // batch_size)
    valid_steps_per_epoch = 3944 // batch_size

    progress_bar = ProgressBar()
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        start_time = time.time()
        loss_metric.reset_states()
        progress_bar.reset(total=train_steps_per_epoch, num=batch_size)

        train_metric = None
        for (batch, (train_enc, train_dec, month_enc, month_dec,
                     labels)) in enumerate(train_dataset.take(max_train_steps)):
            train_metric, prediction = _train_step(
                model=model, optimizer=optimizer, loss_metric=loss_metric,
                train_enc=train_enc, train_dec=train_dec,
                month_enc=month_enc, month_dec=month_dec, labels=labels)

            progress_bar(current=batch + 1,
                         metrics=get_dict_string(data=train_metric))

        progress_bar(current=progress_bar.total,
                     metrics=get_dict_string(data=train_metric))
        progress_bar.done(step_time=time.time() - start_time)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0 or valid_dataset is None:
                print("Validation set smaller than batch_size; "
                      "skipping validation for this epoch")
            else:
                progress_bar.reset(total=valid_steps_per_epoch, num=batch_size)
                valid_metrics = _valid_step(model=model, dataset=valid_dataset,
                                            progress_bar=progress_bar,
                                            loss_metric=loss_metric, **kwargs)
    print("Training finished")
    return {}
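
# `_train_step` is referenced above but not defined in this section. Below is
# a minimal sketch of what such a step could look like with tf.GradientTape;
# the positional model inputs, the mean-squared-error loss, and the metric
# dict keys are assumptions, not taken from the source:
@tf.function
def _train_step(model, optimizer, loss_metric,
                train_enc, train_dec, month_enc, month_dec, labels):
    with tf.GradientTape() as tape:
        # Forward pass through the model on the four input tensors.
        prediction = model([train_enc, train_dec, month_enc, month_dec])
        loss = tf.reduce_mean(
            tf.keras.losses.mean_squared_error(labels, prediction))
    # Backward pass and parameter update.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Track the running mean loss for progress reporting.
    loss_metric(loss)
    return {"train_loss": loss_metric.result()}, prediction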