コード例 #1
0
ファイル: mve.py プロジェクト: sradicwebster/RL_framework
                            imagine_state.squeeze().numpy(),
                            imagine_action_scaled.squeeze().numpy())
                        imagine_next_state = imagine_state + dynamics.model(
                            torch.cat((imagine_state, imagine_action), dim=1))

                        t.states = torch.cat((t.states, imagine_state))
                        t.actions = torch.cat((t.actions, imagine_action))
                        t.rewards = torch.cat(
                            (t.rewards, torch.Tensor([gamma**(j + 1) * reward
                                                      ]).reshape(1, -1)))
                        imagine_state = imagine_next_state

                    imagine_action = actor.target_net(imagine_state).reshape(
                        1, -1)
                    bootstrap_Q = gamma**(params['imagination_steps'] +
                                          1) * critic.target_net(
                                              imagine_state, imagine_action)

                target = torch.stack([
                    t.rewards[i:].sum() + bootstrap_Q
                    for i in range(len(t.rewards))
                ]).reshape(-1, 1)
                current = critic.net(t.states, t.actions)
                critic_loss = critic_loss_fnc(target, current)
                wandb.log(
                    {
                        "value_loss": critic_loss,
                        'step': global_step,
                        'episode': episode
                    },
                    commit=False)
                critic.optimise(critic_loss)
コード例 #2
0
        episode_reward += reward
        episode_step += 1
        global_step += 1
        buffer.add(state, action, reward, next_state, terminal, episode_step,
                   None)
        state = next_state

        if (episode_step % params['sample_collection'] == 0 or terminal is True) \
                and len(buffer) >= params['minibatch_size']:
            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            with torch.no_grad():
                target_action = actor.target_net(t.next_states)
                target = t.rewards + gamma * (1 -
                                              t.terminals) * critic.target_net(
                                                  t.next_states, target_action)
            current_v = critic.net(t.states, t.actions)
            critic_loss = critic_loss_fnc(target, current_v)
            wandb.log(
                {
                    "value_loss": critic_loss,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
            critic.optimise(critic_loss)

            actor_loss = critic.net(t.states, actor.net(t.states)).mean()
            wandb.log(
                {
                    "policy_loss": actor_loss,
コード例 #3
0
ファイル: sac.py プロジェクト: sradicwebster/RL_framework
        episode_step += 1
        global_step += 1
        buffer.add(state, action, reward, next_state, terminal, episode_step,
                   action_log_prob)
        state = next_state

        if (episode_step % params['sample_collection'] == 0 or terminal is True) \
                and len(buffer) >= params['minibatch_size']:

            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            with torch.no_grad():
                target_action, target_action_log = actor.action_selection(
                    t.next_states)
                target_Q = torch.min(critic1.target_net(t.next_states, target_action),
                                     critic2.target_net(t.next_states, target_action)) \
                           - alpha * target_action_log
            target = t.rewards + gamma * (1 - t.terminals) * target_Q
            current_v1 = critic1.net(t.states, t.actions)
            current_v2 = critic2.net(t.states, t.actions)
            critic_loss1 = critic_loss_fnc(target, current_v1)
            critic_loss2 = critic_loss_fnc(target, current_v2)
            wandb.log(
                {
                    "value_loss": (critic_loss1 + critic_loss2) / 2,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
            critic1.optimise(critic_loss1)
コード例 #4
0
ファイル: dqn.py プロジェクト: sradicwebster/RL_framework
    state = env.env.reset()
    terminal = False
    while terminal is False:
        action = Qnet.epsilon_greedy_action(state, episode)
        next_state, reward, terminal, _ = env.env.step(action)
        wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
        episode_step += 1
        global_step += 1
        buffer.add(state, action, reward, next_state, terminal, episode_step, None)
        state = next_state
        episode_reward += reward

        if (episode_step % params['sample_collection'] == 0 or terminal is True) and\
                len(buffer) >= params['minibatch_size']:

            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            with torch.no_grad():
                target = t.rewards + (1-t.terminals) * gamma * Qnet.target_net(t.next_states).max(dim=1).values\
                    .reshape(-1, 1)
            current_v = torch.gather(Qnet.net(t.states), 1, t.actions)
            loss = loss_function(target, current_v)
            wandb.log({"value_loss": loss, 'step': global_step, 'episode': episode}, commit=False)
            Qnet.optimise(loss)

        if global_step % params['target_steps_update'] == 0:
            Qnet.hard_target_update()

    wandb.log({"episode_reward": episode_reward, "episode": episode})