Ejemplo n.º 1
0
                target_Q = torch.min(critic1.target_net(t.next_states, target_action),
                                     critic2.target_net(t.next_states, target_action)) \
                           - alpha * target_action_log
            target = t.rewards + gamma * (1 - t.terminals) * target_Q
            current_v1 = critic1.net(t.states, t.actions)
            current_v2 = critic2.net(t.states, t.actions)
            critic_loss1 = critic_loss_fnc(target, current_v1)
            critic_loss2 = critic_loss_fnc(target, current_v2)
            wandb.log(
                {
                    "value_loss": (critic_loss1 + critic_loss2) / 2,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
            critic1.optimise(critic_loss1)
            critic2.optimise(critic_loss2)
            critic1.soft_target_update()
            critic2.soft_target_update()

            policy_action, policy_action_log = actor.action_selection(t.states)
            Q_min = torch.min(critic1.net(t.states, policy_action),
                              critic2.net(t.states, policy_action))
            actor_loss = (Q_min - alpha * policy_action_log).mean()
            wandb.log(
                {
                    "policy_loss": actor_loss,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
Ejemplo n.º 2
0
                                           t.actions).log()
                ratio = torch.exp(action_prob - old_action_probs)
                clipped_ratio = torch.clamp(ratio, 1 - clip_ratio,
                                            1 + clip_ratio)
                actor_loss = (torch.min(ratio, clipped_ratio) *
                              advantage).mean()
                wandb.log(
                    {
                        "policy_loss": actor_loss,
                        'step': global_step,
                        'episode': episode
                    },
                    commit=False)
                actor.optimise(-actor_loss)

            for _ in range(params['critic_grad_steps']):
                current_v = critic.net(t.states)
                critic_loss = critic_loss_fnc(rewards_to_go, current_v)
                wandb.log(
                    {
                        "value_loss": critic_loss,
                        'step': global_step,
                        'episode': episode
                    },
                    commit=False)
                critic.optimise(critic_loss)

            buffer.empty()

    wandb.log({"episode_reward": episode_reward, "episode": episode})
Ejemplo n.º 3
0
    state = env.env.reset()
    terminal = False
    while terminal is False:
        action = Qnet.epsilon_greedy_action(state, episode)
        next_state, reward, terminal, _ = env.env.step(action)
        wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
        episode_step += 1
        global_step += 1
        buffer.add(state, action, reward, next_state, terminal, episode_step, None)
        state = next_state
        episode_reward += reward

        if (episode_step % params['sample_collection'] == 0 or terminal is True) and\
                len(buffer) >= params['minibatch_size']:

            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            with torch.no_grad():
                target = t.rewards + (1-t.terminals) * gamma * Qnet.target_net(t.next_states).max(dim=1).values\
                    .reshape(-1, 1)
            current_v = torch.gather(Qnet.net(t.states), 1, t.actions)
            loss = loss_function(target, current_v)
            wandb.log({"value_loss": loss, 'step': global_step, 'episode': episode}, commit=False)
            Qnet.optimise(loss)

        if global_step % params['target_steps_update'] == 0:
            Qnet.hard_target_update()

    wandb.log({"episode_reward": episode_reward, "episode": episode})