Example #1
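A training loop in the style of MuZero: each epoch first runs `nb_self_play` games of self-play to fill a replay buffer, then draws `nb_train_update` batches of unrolled trajectories from it to update the network, logging the value/reward/policy/entropy losses and checkpointing every 10 epochs.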
    def train(
        self,
        env: gym.Env,
        agent: Agent,
        network: Network,
        optimizer,
        window_size: int,
        nb_self_play: int,
        num_unroll_steps: int,
        td_steps: int,
        discount: float,
        batch_size: int,
        nb_train_update: int,
        nb_train_epochs: int,
        max_grad_norm: float,
        filename: str,
        ent_c: float,
    ):
        # Fixed-size replay buffer holding the most recent self-play games.
        replay_buffer = ReplayBuffer(window_size, batch_size)

        for epoch in range(nb_train_epochs):
            # Self-play phase: generate games with the current network in
            # evaluation mode and collect them in the replay buffer.
            network.eval()
            rewards = []
            for _ in range(nb_self_play):
                game_buffer = self._play_one_game(env, agent)
                replay_buffer.append(game_buffer)
                rewards.append(np.sum(game_buffer.rewards))

            # Training phase: sample unrolled trajectories from the buffer
            # and take nb_train_update gradient steps.
            network.train()
            losses = []
            for _ in range(nb_train_update):
                batch = replay_buffer.sample_batch(num_unroll_steps, td_steps,
                                                   discount)
                losses.append(
                    self._update_weights(network, optimizer, batch,
                                         max_grad_norm, ent_c))

            # Average the per-update (value, reward, policy, entropy) losses.
            v_loss, r_loss, p_loss, entropy = np.mean(losses, axis=0)
            print(
                f"Epoch[{epoch+1}]: Reward[{np.mean(rewards):.2f}], "
                f"Loss: V[{v_loss:.6f}]/R[{r_loss:.6f}]/"
                f"P[{p_loss:.6f}]/E[{entropy:.6f}]"
            )

            # Checkpoint the agent every 10 epochs.
            if (epoch + 1) % 10 == 0:
                agent.save_model(filename)
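The `ReplayBuffer` this example relies on is not shown. Below is a minimal sketch of the interface Example #1 calls (a constructor taking `window_size` and `batch_size`, `append`, and `sample_batch`); the internals, and the per-game `make_target` helper, are assumptions made for illustration, not the original implementation.

    import random
    from collections import deque

    class ReplayBuffer:
        # Sliding-window store of finished self-play games (sketch).

        def __init__(self, window_size: int, batch_size: int):
            # deque(maxlen=...) silently drops the oldest game once full.
            self.buffer = deque(maxlen=window_size)
            self.batch_size = batch_size

        def append(self, game_buffer):
            self.buffer.append(game_buffer)

        def sample_batch(self, num_unroll_steps, td_steps, discount):
            # Uniformly sample games; make_target is a hypothetical helper
            # that builds the unrolled training targets for one position.
            games = random.choices(list(self.buffer), k=self.batch_size)
            return [g.make_target(num_unroll_steps, td_steps, discount)
                    for g in games]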
Example #2
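A fragment of a learner process that receives instructions over a ZeroMQ socket. The snippet begins partway through the dispatch if/elif chain (the first two lines belong to a preceding, truncated branch); the `ADD_EXPERIENCE` branch acknowledges the sender, buffers the incoming experience tuple on the training device, and, once the buffer holds at least `REPLAY_MIN` entries, samples a mini-batch for a critic update.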
                    recent_stats[i].append(instruction_data[3][i])
                    recent_stats[i] = recent_stats[i][-RECENT_HISTORY_LENGTH:]

        elif instruction == 'ADD_EXPERIENCE':

            # Acknowledge immediately so the sending worker is not blocked
            # while the learner trains.
            socket.send_pyobj(0)

            # Keep a sliding window of recently received values.
            recent_values.append(instruction_data[3])
            recent_values = recent_values[-RECENT_HISTORY_LENGTH:]

            # Move the experience tuple to the training device and buffer it.
            instruction_data_cuda = [
                torch.tensor(t, dtype=torch.float, device=device)
                for t in instruction_data
            ]
            replay_buffer.append(instruction_data_cuda)

            # Check for minimum replay size.
            if len(replay_buffer) < REPLAY_MIN:
                print('Waiting for minimum buffer size ... {}/{}'.format(
                    len(replay_buffer), REPLAY_MIN))
                continue

            # Sample a training mini-batch and unpack the stored tuples
            # (context, state, params, value) into stacked tensors.
            sampled_evaluations = replay_buffer.sample(REPLAY_SAMPLE_SIZE)
            sampled_contexts = torch.stack([t[0] for t in sampled_evaluations])
            sampled_states = torch.stack([t[1] for t in sampled_evaluations])
            sampled_params = torch.stack([t[2] for t in sampled_evaluations])
            sampled_values = torch.stack([t[3] for t in sampled_evaluations])

            # Update critic.
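            # (The original example is truncated here. The lines below are a
            # hypothetical sketch of a generic critic regression step, not
            # the source's code: `critic` and `critic_optimizer` are assumed
            # names, and the critic is assumed to return one value per
            # sample.)
            predicted_values = critic(sampled_contexts, sampled_states,
                                      sampled_params)
            critic_loss = torch.nn.functional.mse_loss(predicted_values,
                                                       sampled_values)
            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()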