def train_model(self, epochs, minibatch_size, grad_steps=1, standardise=False, noise_std=None):
    for i in range(epochs):
        minibatch = self.buffer.random_sample(minibatch_size)
        t = ProcessMinibatch(minibatch)

        if standardise:
            t.standardise(self.env.obs_high)

        # A 'forward' model predicts the next state directly; otherwise the
        # model is trained on the state delta (next_state - state).
        if self.type == 'forward':
            target = t.next_states
        else:
            target = t.next_states - t.states

        # Optional Gaussian noise on both the targets and the inputs,
        # as a simple regulariser.
        if noise_std is not None:
            target += torch.normal(0, noise_std, size=t.states.shape)
            t.states += torch.normal(0, noise_std, size=t.states.shape)
            t.actions += torch.normal(0, noise_std, size=t.actions.shape)

        for _ in range(grad_steps):
            current = self.model(torch.cat((t.states, t.actions), dim=1))
            loss = self.loss_func(current, target)
            wandb.log({"model_loss": loss}, commit=False)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
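# For reference, a minimal reconstruction of the ProcessMinibatch helper the
# snippets here rely on (hypothetical -- the real class may differ in detail;
# e.g. it must guard the None step/log-prob fields stored by the random
# dataset, and the PPO gather below needs integer action indices). It stacks
# the sampled transition tuples into batched float tensors.
import torch

class ProcessMinibatch:
    def __init__(self, minibatch):
        states, actions, rewards, next_states, terminals, steps, log_probs = zip(*minibatch)
        self.states = torch.as_tensor(states, dtype=torch.float32)
        self.actions = torch.as_tensor(actions, dtype=torch.float32).reshape(-1, 1)
        self.rewards = torch.as_tensor(rewards, dtype=torch.float32).reshape(-1, 1)
        self.next_states = torch.as_tensor(next_states, dtype=torch.float32)
        self.terminals = torch.as_tensor(terminals, dtype=torch.float32).reshape(-1, 1)

    def standardise(self, obs_high):
        # Scale observations by the environment's observation bound.
        bound = torch.as_tensor(obs_high, dtype=torch.float32)
        self.states = self.states / bound
        self.next_states = self.next_states / bound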
def train_reward_fnc(self, epochs, minibatch_size):
    for i in range(epochs):
        minibatch = self.buffer.random_sample(minibatch_size)
        t = ProcessMinibatch(minibatch)

        # Supervised regression of r_hat(s, a) onto the observed rewards.
        target = t.rewards
        current = self.reward(torch.cat((t.states, t.actions), dim=1))
        loss = self.loss_func(current, target)
        wandb.log({"reward_loss": loss}, commit=False)
        self.rew_opt.zero_grad()
        loss.backward()
        self.rew_opt.step()
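# A minimal reward-network sketch consistent with the call above (sizes and
# depth are illustrative assumptions): self.reward maps the concatenated
# state-action vector to a scalar predicted reward.
import torch.nn as nn

def make_reward_net(obs_size, action_size, hidden=64):
    return nn.Sequential(
        nn.Linear(obs_size + action_size, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 1),
    )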
terminal = False
while not terminal:
    action, action_log_prob = actor.softmax_action(state)
    next_state, reward, terminal, _ = env.env.step(action)
    wandb.log({'reward': reward, 'step': global_step, 'episode': episode})
    episode_step += 1
    global_step += 1
    buffer.add(state, action, reward, next_state, terminal, episode_step, action_log_prob)
    state = next_state
    episode_reward += reward

    if terminal:
        minibatch = buffer.ordered_sample(episode_step)
        t = ProcessMinibatch(minibatch)

        # One-step TD errors, discounted back along the episode to form the
        # (normalised) advantage estimates and the rewards-to-go targets.
        with torch.no_grad():
            td_error = t.rewards + gamma * (1 - t.terminals) * critic.net(t.next_states) \
                       - critic.net(t.states)
        discounted_gamma = gamma ** t.steps
        advantage = discounted_cumsum(td_error, discounted_gamma)
        advantage = (advantage - advantage.mean()) / advantage.std()
        rewards_to_go = discounted_cumsum(t.rewards, discounted_gamma)
        old_action_probs = t.action_log_prob.reshape(-1, 1).detach()

        for _ in range(params['actor_grad_steps']):
            action_prob = torch.gather(actor.net(t.states), 1, t.actions).log()
            ratio = torch.exp(action_prob - old_action_probs)
            clipped_ratio = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
            # The source excerpt is truncated at the clamp; the completion
            # below assumes the standard PPO clipped surrogate objective.
            actor_loss = -torch.min(ratio * advantage, clipped_ratio * advantage).mean()
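# Sketch of the discounted_cumsum helper used above (assumed semantics: a
# reverse cumulative sum with a per-step discount vector, so that
# out[t] = x[t] + gamma[t] * out[t+1]; the real helper may differ).
import torch

def discounted_cumsum(x, gamma):
    out = torch.zeros_like(x)
    running = torch.zeros_like(x[-1])
    for t in reversed(range(len(x))):
        running = x[t] + gamma[t] * running
        out[t] = running
    return out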
if (episode_step % params['sample_collection'] == 0 or terminal) \
        and len(buffer) >= params['minibatch_size']:

    # Train dynamics model
    # ~~~~~~~~~~~~~~~~~~~~
    dynamics.train_model(params['training_epoch'], params['minibatch_size'], noise_std=0.001)

    # Train value and policy networks
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    for _ in range(params['sampled_transitions']):
        minibatch = buffer.random_sample(1)
        t = ProcessMinibatch(minibatch)

        # Deterministic policy-gradient style update: ascend the critic's
        # value of the actor's action (hence the sign flip below).
        actor_loss = critic.net(t.states, actor.net(t.states))
        wandb.log({"policy_loss": actor_loss, 'step': global_step, 'episode': episode},
                  commit=False)
        actor.optimise(-actor_loss)

        # Roll imagined trajectories forward from the sampled transition.
        imagine_state = t.next_states
        with torch.no_grad():
            for j in range(params['imagination_steps']):
                imagine_action = actor.target_net(imagine_state)
                # The excerpt is truncated here; the rollout would next query
                # the learned dynamics model for the imagined next state.
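# The actor keeps a target network for the imagined rollouts above. A common
# choice, assumed here rather than taken from the source, is a Polyak soft
# update of the target parameters after each optimisation step.
import torch

def soft_update(target, source, tau=0.005):
    # Blend each target parameter towards the online network: tp <- (1-tau)*tp + tau*sp.
    with torch.no_grad():
        for tp, sp in zip(target.parameters(), source.parameters()):
            tp.mul_(1 - tau).add_(tau * sp)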
# Gather random data and train dynamics models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
while len(dataset_random) < params['random_buffer_size']:
    state = env.env.reset() + torch.normal(0, 0.001, size=(env.obs_size,)).numpy()
    terminal = False
    while not terminal:
        action = torch.randint(env.action_size, size=(1,)).item()
        next_state, reward, terminal, _ = env.env.step(action)
        dataset_random.add(state, action, reward, next_state, terminal, None, None)
        state = next_state

for i in range(params['training_epoch']):
    minibatch = dataset_random.random_sample(params['minibatch_size'])
    t = ProcessMinibatch(minibatch)
    t.standardise(env.obs_high)

    # Train on state deltas, with small Gaussian noise added to both the
    # targets and the model inputs.
    target = t.next_states - t.states + torch.normal(0, 0.001, size=t.states.shape)
    state_actions = torch.cat((t.states, t.actions), dim=1)
    current = model_net(state_actions + torch.normal(0, 0.001, size=state_actions.shape))
    loss = loss_fnc(target, current)
    wandb.log({"model_loss": loss})
    opt.zero_grad()
    loss.backward()
    opt.step()

# Model based controller loop
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
global_step = 0
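# A hedged sketch of one way the controller loop can act through the learned
# delta model: random-shooting MPC. This is an illustration, not the source's
# controller; `plan_horizon`, `n_candidates` and the use of a learned
# `reward_net` are assumptions. It samples candidate action sequences, rolls
# each through the model, and returns the first action of the best sequence.
import torch

def random_shooting_action(state, model_net, reward_net, env,
                           plan_horizon=10, n_candidates=100):
    best_return, best_first_action = -float('inf'), None
    for _ in range(n_candidates):
        s = state.clone()  # assumes state is already a float tensor
        total, first = 0.0, None
        for h in range(plan_horizon):
            a = torch.randint(env.action_size, size=(1,)).float()
            if first is None:
                first = int(a.item())
            sa = torch.cat((s, a))
            total += reward_net(sa).item()
            s = s + model_net(sa)  # delta model predicts next_state - state
        if total > best_return:
            best_return, best_first_action = total, first
    return best_first_action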