Example #1
    def train_model(self,
                    epochs,
                    minibatch_size,
                    grad_steps=1,
                    standardise=False,
                    noise_std=None):

        for i in range(epochs):
            minibatch = self.buffer.random_sample(minibatch_size)
            t = ProcessMinibatch(minibatch)

            if standardise:
                t.standardise(self.env.obs_high)

            if self.type == 'forward':
                target = t.next_states
            else:
                target = t.next_states - t.states

            # Optionally inject Gaussian noise into the inputs and the
            # regression target as a simple regulariser.
            if noise_std is not None:
                target += torch.normal(0, noise_std, size=t.states.shape)
                t.states += torch.normal(0, noise_std, size=t.states.shape)
                t.actions += torch.normal(0, noise_std, size=t.actions.shape)

            for _ in range(grad_steps):
                current = self.model(torch.cat((t.states, t.actions), dim=1))
                loss = self.loss_func(current, target)
                wandb.log({"model_loss": loss}, commit=False)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
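
These examples all construct a ProcessMinibatch object from sampled transitions, but the class itself is not part of the excerpt. A minimal sketch of what it might look like, assuming the buffer stores (state, action, reward, next_state, terminal, step, action_log_prob) tuples and that standardise simply rescales observations by the environment's bounds:

import torch


class ProcessMinibatch:
    """Hypothetical helper: stacks sampled transition tuples
    (state, action, reward, next_state, terminal, step, action_log_prob)
    into tensors with one row per transition."""

    def __init__(self, minibatch):
        states, actions, rewards, next_states, terminals, steps, log_probs = zip(*minibatch)
        self.states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in states])
        self.next_states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in next_states])
        # Examples #1/#2 concatenate actions with states (float), while
        # Example #3 uses them as gather indices (int64); the real class
        # would handle dtypes accordingly.
        self.actions = torch.as_tensor(actions, dtype=torch.float32).reshape(len(minibatch), -1)
        self.rewards = torch.as_tensor(rewards, dtype=torch.float32).reshape(-1, 1)
        self.terminals = torch.as_tensor(terminals, dtype=torch.float32).reshape(-1, 1)
        self.steps = (torch.as_tensor(steps, dtype=torch.float32).reshape(-1, 1)
                      if steps[0] is not None else None)
        self.action_log_prob = (torch.stack([torch.as_tensor(p, dtype=torch.float32) for p in log_probs])
                                if log_probs[0] is not None else None)

    def standardise(self, obs_high):
        # Rescale observations by the environment's observation bounds.
        high = torch.as_tensor(obs_high, dtype=torch.float32)
        self.states = self.states / high
        self.next_states = self.next_states / high
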
Example #2
    def train_reward_fnc(self, epochs, minibatch_size):

        for i in range(epochs):
            minibatch = self.buffer.random_sample(minibatch_size)
            t = ProcessMinibatch(minibatch)
            target = t.rewards
            current = self.reward(torch.cat((t.states, t.actions), dim=1))
            loss = self.loss_func(current, target)
            wandb.log({"reward_loss": loss}, commit=False)
            self.rew_opt.zero_grad()
            loss.backward()
            self.rew_opt.step()
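
The replay buffer objects (buffer, and dataset_random in Example #5) are likewise only implied by their call sites. A rough sketch consistent with the calls used here (add, random_sample, ordered_sample, len); the capacity default is an illustrative assumption:

import random
from collections import deque


class ReplayBuffer:
    """Hypothetical replay buffer matching how `buffer` is used in these
    examples: flat transition tuples, uniform random sampling for model and
    reward updates, ordered sampling for on-policy updates."""

    def __init__(self, capacity=100_000):
        self.memory = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, terminal, step, action_log_prob):
        self.memory.append(
            (state, action, reward, next_state, terminal, step, action_log_prob))

    def random_sample(self, minibatch_size):
        return random.sample(list(self.memory), minibatch_size)

    def ordered_sample(self, n):
        # The n most recent transitions, e.g. the episode that just finished.
        return list(self.memory)[-n:]

    def __len__(self):
        return len(self.memory)
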
Example #3
    terminal = False
    while not terminal:
        action, action_log_prob = actor.softmax_action(state)
        next_state, reward, terminal, _ = env.env.step(action)
        wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
        episode_step += 1
        global_step += 1
        buffer.add(state, action, reward, next_state, terminal, episode_step,
                   action_log_prob)
        state = next_state
        episode_reward += reward

        if terminal:

            minibatch = buffer.ordered_sample(episode_step)
            t = ProcessMinibatch(minibatch)

            # One-step TD errors over the finished episode; used below to
            # build discounted advantage estimates and rewards-to-go.
            with torch.no_grad():
                td_error = t.rewards + gamma * (1 - t.terminals) * critic.net(
                    t.next_states) - critic.net(t.states)
            discounted_gamma = gamma**t.steps
            advantage = discounted_cumsum(td_error, discounted_gamma)
            advantage = (advantage - advantage.mean()) / advantage.std()
            rewards_to_go = discounted_cumsum(t.rewards, discounted_gamma)
            old_action_probs = t.action_log_prob.reshape(-1, 1).detach()

            for _ in range(params['actor_grad_steps']):
                action_prob = torch.gather(actor.net(t.states), 1,
                                           t.actions).log()
                ratio = torch.exp(action_prob - old_action_probs)
                clipped_ratio = torch.clamp(ratio, 1 - clip_ratio,
                                            1 + clip_ratio)
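
Example #3 also relies on a discounted_cumsum helper that is not shown. One plausible implementation, assuming the second argument is the per-row vector gamma**t.steps built just above it:

import torch


def discounted_cumsum(values, discounts):
    """Hypothetical reverse discounted cumulative sum, assuming `discounts`
    holds gamma**step for each row: out[t] = sum_{k >= t} gamma**(k - t) * values[k]."""
    weighted = values * discounts                   # gamma**k * x_k
    tail_sums = torch.flip(torch.cumsum(torch.flip(weighted, dims=[0]), dim=0),
                           dims=[0])
    return tail_sums / discounts                    # divide out gamma**t

Dividing out gamma**t keeps the computation vectorised, but it can underflow on very long episodes; an explicit reverse loop over the episode is the more numerically robust alternative.
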
Example #4
        if (episode_step % params['sample_collection'] == 0 or terminal) and\
                len(buffer) >= params['minibatch_size']:

            # Train dynamics model
            # ~~~~~~~~~~~~~~~~~~~~
            dynamics.train_model(params['training_epoch'],
                                 params['minibatch_size'],
                                 noise_std=0.001)

            # Train value and policy networks
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            for _ in range(params['sampled_transitions']):

                minibatch = buffer.random_sample(1)
                t = ProcessMinibatch(minibatch)

                # Q(s, pi(s)); the policy is updated to maximise this value,
                # so its negative is passed to the optimiser below.
                actor_loss = critic.net(t.states, actor.net(t.states))
                wandb.log(
                    {
                        "policy_loss": actor_loss,
                        'step': global_step,
                        'episode': episode
                    },
                    commit=False)
                actor.optimise(-actor_loss)

                imagine_state = t.next_states
                with torch.no_grad():
                    for j in range(params['imagination_steps']):
                        imagine_action = actor.target_net(imagine_state)
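
The actor and critic wrappers in Examples #3 and #4 are also only implied by their call sites. A sketch of the kind of interface the actor calls suggest; the class name, learning rate, and optimiser are assumptions, and the continuous-action use in Example #4 (actor.net and actor.target_net returning actions directly) would need a deterministic policy head instead of softmax_action:

import copy

import torch


class Actor:
    """Hypothetical policy wrapper: `net` is the live policy, `target_net`
    a frozen copy used for imagined rollouts, `softmax_action` samples a
    discrete action, and `optimise` applies one gradient step to a loss."""

    def __init__(self, net, lr=3e-4):
        self.net = net
        self.target_net = copy.deepcopy(net)
        self.opt = torch.optim.Adam(self.net.parameters(), lr=lr)

    def softmax_action(self, state):
        # Assumes `net` ends in a softmax over discrete actions.
        probs = self.net(torch.as_tensor(state, dtype=torch.float32))
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def optimise(self, loss):
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
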
Example #5
# Gather random data and train dynamics models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
while len(dataset_random) < params['random_buffer_size']:
    state = env.env.reset() + torch.normal(0, 0.001,
                                           size=(env.obs_size, )).numpy()
    terminal = False
    while not terminal:
        action = torch.randint(env.action_size, size=(1, )).item()
        next_state, reward, terminal, _ = env.env.step(action)
        dataset_random.add(state, action, reward, next_state, terminal, None,
                           None)
        state = next_state

for i in range(params['training_epoch']):
    minibatch = dataset_random.random_sample(params['minibatch_size'])
    t = ProcessMinibatch(minibatch)
    t.standardise(env.obs_high)
    target = t.next_states - t.states + torch.normal(
        0, 0.001, size=t.states.shape)
    state_actions = torch.cat((t.states, t.actions), dim=1)
    current = model_net(state_actions +
                        torch.normal(0, 0.001, size=state_actions.shape))
    loss = loss_fnc(current, target)
    wandb.log({"model_loss": loss})
    opt.zero_grad()
    loss.backward()
    opt.step()

# Model based controller loop
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
global_step = 0
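
All of the snippets log through wandb and index into a shared params dict, neither of which is set up in the excerpts. An illustrative setup, with placeholder values rather than the ones used in the source:

import wandb

# Illustrative hyperparameters only; the real values are not shown here.
params = {
    'training_epoch': 20,
    'minibatch_size': 64,
    'random_buffer_size': 5000,
    'sample_collection': 1,
    'sampled_transitions': 10,
    'imagination_steps': 5,
    'actor_grad_steps': 10,
}

wandb.init(project='model-based-rl', config=params)
# Calls above that use wandb.log(..., commit=False) are buffered and written
# out together with the next committing wandb.log call.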