예제 #1
0
            action, action_log_prob = actor.softmax_action(state)
        # Apply the chosen action to the wrapped environment and log the
        # per-step reward.
        next_state, reward, terminal, _ = env.env.step(action)
        wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
        episode_step += 1
        global_step += 1
        # Store the transition together with the (already incremented) step
        # index and the action_log_prob returned by the actor.
        buffer.add(state, action, reward, next_state, terminal, episode_step,
                   action_log_prob)
        state = next_state
        episode_reward += reward

        # Run an update every `sample_collection` steps, or when `terminal`
        # is True, provided the buffer holds at least one minibatch.
        if (episode_step % params['sample_collection'] == 0 or terminal is True) and\
                len(buffer) >= params['minibatch_size']:

            minibatch = buffer.ordered_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)
            # One-step bootstrapped target; the (1 - t.terminals) factor
            # removes the bootstrap term where t.terminals == 1.
            # NOTE(review): the target is built outside torch.no_grad(), so
            # critic_loss may also backpropagate through
            # critic.net(t.next_states) -- confirm this is intended.
            target = t.rewards + gamma * (1 - t.terminals) * critic.net(
                t.next_states)
            current_v = critic.net(t.states)
            critic_loss = critic_loss_fnc(target, current_v)
            wandb.log(
                {
                    "value_loss": critic_loss,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
            critic.optimise(critic_loss)

            # The advantage reuses the target and value computed before the
            # critic step; detach() keeps it out of the actor's graph.
            advantage = (target - current_v).detach()
            # gamma raised to t.steps -- presumably the per-transition step
            # index stored via buffer.add above; TODO confirm.
            discounted_gamma = gamma**t.steps
            # Log of the actor output gathered at the taken actions along
            # dim 1.
            log_prob = torch.gather(actor.net(t.states), 1, t.actions).log()
            actor_loss = (discounted_gamma * advantage * log_prob).mean()
예제 #2
0
        # Apply the chosen action to the wrapped environment and log the
        # per-step reward.
        next_state, reward, terminal, _ = env.env.step(action)
        wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
        episode_step += 1
        global_step += 1
        # Store the transition together with the step index and the
        # action_log_prob recorded when the action was chosen; it is read
        # back below as the "old" log-probability for the ratio.
        buffer.add(state, action, reward, next_state, terminal, episode_step,
                   action_log_prob)
        state = next_state
        episode_reward += reward

        # The update runs once `terminal` is True, on the last
        # `episode_step` transitions in buffer order.
        if terminal is True:

            minibatch = buffer.ordered_sample(episode_step)
            t = ProcessMinibatch(minibatch)

            # One-step TD error, computed without tracking gradients.
            with torch.no_grad():
                td_error = t.rewards + gamma * (1 - t.terminals) * critic.net(
                    t.next_states) - critic.net(t.states)
            discounted_gamma = gamma**t.steps
            advantage = discounted_cumsum(td_error, discounted_gamma)
            # Standardise the advantages.
            # NOTE(review): std() is NaN for a single-transition episode and
            # 0 when all advantages are equal -- consider guarding this.
            advantage = (advantage - advantage.mean()) / advantage.std()
            rewards_to_go = discounted_cumsum(t.rewards, discounted_gamma)
            old_action_probs = t.action_log_prob.reshape(-1, 1).detach()

            for _ in range(params['actor_grad_steps']):
                # Log of the current actor output gathered at the taken
                # actions along dim 1.
                action_prob = torch.gather(actor.net(t.states), 1,
                                           t.actions).log()
                # Ratio of new to old probability, via the log difference.
                ratio = torch.exp(action_prob - old_action_probs)
                clipped_ratio = torch.clamp(ratio, 1 - clip_ratio,
                                            1 + clip_ratio)
                # Clipped surrogate objective: the advantage must multiply
                # both ratios *before* the element-wise minimum. Taking
                # min(ratio, clipped_ratio) first and multiplying afterwards
                # picks the wrong branch whenever the advantage is negative.
                actor_loss = torch.min(ratio * advantage,
                                       clipped_ratio * advantage).mean()
                wandb.log(
예제 #3
0
                len(buffer) >= params['minibatch_size']:

            # Train dynamics model
            # ~~~~~~~~~~~~~~~~~~~~
            # Arguments are the configured epoch count and minibatch size,
            # plus a fixed noise_std of 0.001.
            dynamics.train_model(params['training_epoch'],
                                 params['minibatch_size'],
                                 noise_std=0.001)

            # Train value and policy networks
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # One iteration per sampled transition; each iteration draws a
            # single random transition from the buffer.
            for _ in range(params['sampled_transitions']):

                minibatch = buffer.random_sample(1)
                t = ProcessMinibatch(minibatch)

                # Critic's value for the action the current actor proposes
                # in the sampled state. No reduction is applied here; the
                # sample holds a single transition.
                actor_loss = critic.net(t.states, actor.net(t.states))
                wandb.log(
                    {
                        "policy_loss": actor_loss,
                        'step': global_step,
                        'episode': episode
                    },
                    commit=False)
                # The negated value is handed to the optimiser -- presumably
                # so that minimising it maximises the critic's estimate;
                # confirm against actor.optimise.
                actor.optimise(-actor_loss)

                imagine_state = t.next_states
                with torch.no_grad():
                    for j in range(params['imagination_steps']):
                        imagine_action = actor.target_net(imagine_state)
                        imagine_action_scaled = env.action_low + (
                            env.action_high -
예제 #4
0
        state = next_state

        # Run an update every `sample_collection` steps, or when `terminal`
        # is True, provided the buffer holds at least one minibatch.
        if (episode_step % params['sample_collection'] == 0 or terminal is True) \
                and len(buffer) >= params['minibatch_size']:

            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            # Next-state value, computed without tracking gradients: the
            # element-wise minimum of the two target critics at the actor's
            # action for the next state, minus alpha times the second value
            # returned by actor.action_selection (presumably that action's
            # log-probability -- TODO confirm).
            with torch.no_grad():
                target_action, target_action_log = actor.action_selection(
                    t.next_states)
                target_Q = torch.min(critic1.target_net(t.next_states, target_action),
                                     critic2.target_net(t.next_states, target_action)) \
                           - alpha * target_action_log
            # Bootstrapped target; the (1 - t.terminals) factor removes the
            # bootstrap term where t.terminals == 1.
            target = t.rewards + gamma * (1 - t.terminals) * target_Q
            current_v1 = critic1.net(t.states, t.actions)
            current_v2 = critic2.net(t.states, t.actions)
            # Both critics are fitted to the same shared target.
            critic_loss1 = critic_loss_fnc(target, current_v1)
            critic_loss2 = critic_loss_fnc(target, current_v2)
            # The logged value is the mean of the two critic losses.
            wandb.log(
                {
                    "value_loss": (critic_loss1 + critic_loss2) / 2,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
            critic1.optimise(critic_loss1)
            critic2.optimise(critic_loss2)
            # Target networks are updated after each critic optimisation.
            critic1.soft_target_update()
            critic2.soft_target_update()
예제 #5
0
        global_step += 1
        # Store the transition; the final buffer.add argument is None here.
        buffer.add(state, action, reward, next_state, terminal, episode_step,
                   None)
        state = next_state

        # Run an update every `sample_collection` steps, or when `terminal`
        # is True, provided the buffer holds at least one minibatch.
        if (episode_step % params['sample_collection'] == 0 or terminal is True) \
                and len(buffer) >= params['minibatch_size']:
            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            # Bootstrapped target built from the target actor and target
            # critic, without tracking gradients. The (1 - t.terminals)
            # factor removes the bootstrap term where t.terminals == 1.
            with torch.no_grad():
                target_action = actor.target_net(t.next_states)
                target = t.rewards + gamma * (1 -
                                              t.terminals) * critic.target_net(
                                                  t.next_states, target_action)
            current_v = critic.net(t.states, t.actions)
            critic_loss = critic_loss_fnc(target, current_v)
            wandb.log(
                {
                    "value_loss": critic_loss,
                    'step': global_step,
                    'episode': episode
                },
                commit=False)
            critic.optimise(critic_loss)

            # Mean critic value of the actions the current actor proposes
            # for the sampled states. This is evaluated after the critic
            # step above.
            actor_loss = critic.net(t.states, actor.net(t.states)).mean()
            wandb.log(
                {
                    "policy_loss": actor_loss,
                    'step': global_step,
예제 #6
0
        # Run an update every `sample_collection` steps, or when `terminal`
        # is True, provided the buffer holds at least one minibatch.
        if (episode_step % params['sample_collection'] == 0 or terminal is True) \
                and len(buffer) >= params['minibatch_size']:
            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            with torch.no_grad():
                # Target action: target actor output plus clipped noise.
                target_action = actor.target_net(t.next_states) + torch.clamp(target_noise.sample(),
                                                                              -noise_clip, noise_clip)
                # Clamp each action dimension to its own bounds. The clamped
                # columns are stacked along dim=1 so the result keeps the
                # (batch, action_size) layout. Stacking on dim 0 and then
                # reshaping would interleave values from different samples
                # whenever action_size > 1.
                target_action = torch.stack(
                    [torch.clamp(target_action[:, i], env.action_low[i].item(),
                                 env.action_high[i].item())
                     for i in range(env.action_size)],
                    dim=1)
                # Element-wise minimum of the two target critics.
                q_target = torch.min(critic1.target_net(t.next_states, target_action),
                                     critic2.target_net(t.next_states, target_action))
                # The (1 - t.terminals) factor removes the bootstrap term
                # where t.terminals == 1.
                target = t.rewards + gamma * (1 - t.terminals) * q_target

            current_v1 = critic1.net(t.states, t.actions)
            current_v2 = critic2.net(t.states, t.actions)
            # Both critics are fitted to the same shared target.
            critic_loss1 = critic_loss_fnc(target, current_v1)
            critic_loss2 = critic_loss_fnc(target, current_v2)
            wandb.log({"value_loss": (critic_loss1+critic_loss2)/2, 'step': global_step, 'episode': episode},
                      commit=False)
            critic1.optimise(critic_loss1)
            critic2.optimise(critic_loss2)
            critic1.soft_target_update()
            critic2.soft_target_update()

            # The actor and its target are updated only when episode_step is
            # a multiple of policy_delay, using critic1 alone.
            if episode_step % policy_delay == 0:
                actor_loss = critic1.net(t.states, actor.net(t.states)).mean()
                wandb.log({"policy_loss": actor_loss, 'step': global_step, 'episode': episode}, commit=False)
                # The negated value is handed to the optimiser.
                actor.optimise(-actor_loss)
                actor.soft_target_update()
예제 #7
0
    # Run one episode: act, store the transition, periodically fit the
    # Q-network, and periodically refresh the target network.
    state = env.env.reset()
    terminal = False
    # Truthiness is used instead of `is False` / `is True`: an identity
    # comparison misbehaves if env.step returns a non-`bool` flag (for
    # example numpy.bool_), ending the loop after a single step.
    while not terminal:
        action = Qnet.epsilon_greedy_action(state, episode)
        next_state, reward, terminal, _ = env.env.step(action)
        wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
        episode_step += 1
        global_step += 1
        buffer.add(state, action, reward, next_state, terminal, episode_step, None)
        state = next_state
        episode_reward += reward

        # Run an update every `sample_collection` steps, or at the end of
        # the episode, provided the buffer holds at least one minibatch.
        if (episode_step % params['sample_collection'] == 0 or terminal) and \
                len(buffer) >= params['minibatch_size']:

            minibatch = buffer.random_sample(params['minibatch_size'])
            t = ProcessMinibatch(minibatch)

            # Bootstrapped target from the target network's maximum value
            # over dim 1, reshaped to a column. The (1 - t.terminals) factor
            # removes the bootstrap term where t.terminals == 1.
            with torch.no_grad():
                target = t.rewards + (1-t.terminals) * gamma * Qnet.target_net(t.next_states).max(dim=1).values\
                    .reshape(-1, 1)
            # Online network output gathered at the stored actions.
            current_v = torch.gather(Qnet.net(t.states), 1, t.actions)
            loss = loss_function(target, current_v)
            wandb.log({"value_loss": loss, 'step': global_step, 'episode': episode}, commit=False)
            Qnet.optimise(loss)

        # The target refresh is keyed on the global step count and runs
        # whether or not an update happened on this step.
        if global_step % params['target_steps_update'] == 0:
            Qnet.hard_target_update()

    wandb.log({"episode_reward": episode_reward, "episode": episode})