def test_ReplayBuffer(self):
    mem = ReplayBuffer(2)
    mem.push(1)
    mem.push(2)
    [sample] = mem.sample(2)
    self.assertEqual(sorted(sample), [1, 2])
    mem.push(3)
    [sample] = mem.sample(2)
    self.assertEqual(sorted(sample), [2, 3])
    mem.push(4)
    [sample] = mem.sample(2)
    self.assertEqual(sorted(sample), [3, 4])
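# For reference, a minimal sketch of a buffer that would satisfy the test above.
# Only the class name and the push/sample/len usage are taken from the test; the
# deque-backed storage and the transposed return value of sample() are assumptions,
# not necessarily how the repository's ReplayBuffer is implemented (the Trainer
# below additionally expects a torch_samples() helper, sketched further down).
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity buffer; the oldest transitions are evicted once capacity is reached."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *transition):
        # Each call stores one transition tuple, e.g. (state, action, reward, next_state, done).
        self.memory.append(transition)

    def sample(self, batch_size):
        # Sample without replacement and transpose, so callers get one tuple per field
        # and can unpack e.g. states, actions, rewards, next_states, dones = buf.sample(n).
        batch = random.sample(self.memory, batch_size)
        return list(zip(*batch))

    def __len__(self):
        return len(self.memory)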
class Trainer:
    def __init__(self, policyclass, config):
        self.config = config
        self.env = gym.make(config.env)
        self.device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        self.a_space, self.obs_space = self.env.action_space.n, self.env.observation_space.shape[0]
        self.policy_net = policyclass(self.obs_space, self.a_space).to(self.device)
        self.target_net = policyclass(self.obs_space, self.a_space).to(self.device)
        self.target_net.eval()
        self.buf = ReplayBuffer(config.capacity)
        self.lossfn = nn.MSELoss()
        self.optimizer = AdamW(self.policy_net.parameters(), lr=config.lr)
        self.eps = config.eps_start
        self.eps_interval = (config.eps_start - config.eps_end) / config.num_epochs
        self.eps_interval *= 2
        if config.render:
            self.env.render()
        if config.monitor:
            self.env = gym.wrappers.Monitor(
                self.env, config.vid_save_path,
                video_callable=lambda ep: ep % config.vid_interval == 0,
                force=True)

    def zero_grad(self):
        for param in self.policy_net.parameters():
            param.grad = None

    def train(self):
        config, env, buf = self.config, self.env, self.buf
        lr = config.lr

        def update_target_net():
            self.target_net.load_state_dict(self.policy_net.state_dict())

        def run_epoch():
            loss = None
            curr_state = env.reset()
            done, next_state, reward_list = False, None, []
            while not done:
                action = self.get_eps_act(
                    torch.tensor(curr_state, device=self.device, dtype=torch.float32).unsqueeze(0))
                next_state, reward, done, _ = env.step(action)
                reward_list.append(reward)
                self.buf.push(curr_state, action, reward, next_state, done)
                curr_state = next_state
                if len(self.buf) >= self.config.batch_size:
                    loss = self.optimize_model()
            return sum(reward_list), loss

        pbar = tqdm(range(config.num_epochs))
        for eps in pbar:
            rewards, loss = run_epoch()
            if eps % config.target_update == 0:
                update_target_net()
            if loss is not None:
                strprint = f"epoch {eps+1}: loss {loss:.5f}. eps {self.eps} reward {rewards}"
            else:
                strprint = f"epoch {eps+1}: eps {self.eps} reward {rewards}"
            pbar.set_description(strprint)
            if self.eps > self.config.eps_end:
                self.eps = self.eps - self.eps_interval
        self.save_model()

    def optimize_model(self):
        S, A, R, S_, done = self.buf.torch_samples(self.config.batch_size, device=self.device)
        # TD target: r + gamma * max_a' Q_target(s', a'), with the bootstrap term zeroed on terminal transitions.
        target = self.config.gamma * self.target_net(S_).max(1)[0].detach().view(-1, 1) * (1 - done)
        target = target + R
        estimate = self.policy_net(S).gather(1, A)
        loss = self.lossfn(estimate, target)
        self.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(self.policy_net.parameters(), self.config.grad_norm_clip)
        self.optimizer.step()
        return loss.item()

    def get_eps_act(self, state):
        """Epsilon-greedy action selection; expects a tensor already loaded onto the device."""
        if random.random() > self.eps:
            action = self.policy_net(state).max(1)[1].item()
        else:
            action = random.randrange(self.a_space)
        return action

    def save_model(self):
        logger.info(f"Saving model to {self.config.save_path}")
        torch.save(self.policy_net.state_dict(), self.config.save_path)
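# The Trainer above assumes the replay buffer also exposes a torch_samples() helper
# that returns batched tensors with the shapes optimize_model() expects: float state
# batches, int64 action indices shaped (B, 1) for gather(), and (B, 1) reward/done
# columns so the (1 - done) mask broadcasts against the max-Q column. A minimal sketch,
# written here as a free function over a buffer (in the repository it would presumably
# be a ReplayBuffer method); an illustration, not the actual implementation.
import numpy as np
import torch


def torch_samples(buf, batch_size, device="cpu"):
    states, actions, rewards, next_states, dones = buf.sample(batch_size)
    S = torch.tensor(np.array(states), dtype=torch.float32, device=device)
    A = torch.tensor(actions, dtype=torch.int64, device=device).view(-1, 1)
    R = torch.tensor(rewards, dtype=torch.float32, device=device).view(-1, 1)
    S_ = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)
    done = torch.tensor(dones, dtype=torch.float32, device=device).view(-1, 1)
    return S, A, R, S_, done


# Hypothetical usage of Trainer: the QNet architecture, the SimpleNamespace config
# fields, and the environment name below are assumptions chosen to match the
# attributes Trainer reads, not defaults taken from the repository.
from types import SimpleNamespace

import torch.nn as nn


class QNet(nn.Module):
    """Small MLP mapping an observation vector to one Q-value per discrete action."""

    def __init__(self, obs_dim, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions))

    def forward(self, x):
        return self.net(x)


config = SimpleNamespace(
    env="CartPole-v1", capacity=10_000, lr=1e-3, batch_size=64, gamma=0.99,
    grad_norm_clip=1.0, num_epochs=500, target_update=10, eps_start=1.0,
    eps_end=0.05, render=False, monitor=False, save_path="policy.pt")

trainer = Trainer(QNet, config)
trainer.train()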
def train(config_filepath, save_dir, device, visualize_interval):
    conf = load_toml_config(config_filepath)
    data_dir, log_dir = create_save_dir(save_dir)
    # Save config file
    shutil.copyfile(config_filepath, os.path.join(save_dir, os.path.basename(config_filepath)))
    device = torch.device(device)

    # Set up log metrics
    metrics = {
        'episode': [],
        'episodic_step': [],
        'collected_total_samples': [],
        'reward': [],
        'q_loss': [],
        'policy_loss': [],
        'alpha_loss': [],
        'alpha': [],
        'policy_switch_epoch': [],
        'policy_switch_sample': [],
        'test_episode': [],
        'test_reward': [],
    }

    policy_switch_samples = conf.policy_switch_samples if hasattr(conf, "policy_switch_samples") else None
    total_collected_samples = 0

    # Create environment
    env = make_env(conf.environment, render=False)

    # Instantiate modules
    memory = ReplayBuffer(int(conf.replay_buffer_capacity), env.observation_space.shape, env.action_space.shape)
    agent = getattr(agents, conf.agent_type)(env.observation_space, env.action_space, device=device, **conf.agent)

    # Load checkpoint if specified in config
    if conf.checkpoint != '':
        ckpt = torch.load(conf.checkpoint, map_location=device)
        metrics = ckpt['metrics']
        agent.load_state_dict(ckpt['agent'])
        memory.load_state_dict(ckpt['memory'])
        policy_switch_samples = ckpt['policy_switch_samples']
        total_collected_samples = ckpt['total_collected_samples']

    def save_checkpoint():
        # Save full training checkpoint
        ckpt = {
            'metrics': metrics,
            'agent': agent.state_dict(),
            'memory': memory.state_dict(),
            'policy_switch_samples': policy_switch_samples,
            'total_collected_samples': total_collected_samples
        }
        path = os.path.join(data_dir, 'checkpoint.pth')
        torch.save(ckpt, path)
        # Save agent model only
        model_ckpt = {'agent': agent.state_dict()}
        model_path = os.path.join(data_dir, 'model.pth')
        torch.save(model_ckpt, model_path)
        # Save metrics only
        metrics_ckpt = {'metrics': metrics}
        metrics_path = os.path.join(data_dir, 'metrics.pth')
        torch.save(metrics_ckpt, metrics_path)

    # Train agent
    init_episode = 0 if len(metrics['episode']) == 0 else metrics['episode'][-1] + 1
    pbar = tqdm.tqdm(range(init_episode, conf.episodes))
    reward_moving_avg = None
    agent_update_count = 0

    for episode in pbar:
        episodic_reward = 0
        o = env.reset()
        q1_loss, q2_loss, policy_loss, alpha_loss, alpha = None, None, None, None, None
        for t in range(conf.horizon):
            if total_collected_samples <= conf.random_sample_num:
                # Select random actions at the beginning of training.
                h = env.action_space.sample()
            elif memory.step <= conf.random_sample_num:
                # Select actions from a random latent variable soon after inserting a new subpolicy.
                h = agent.select_action(o, random=True)
            else:
                h = agent.select_action(o)
            a = agent.post_process_action(o, h)  # Convert abstract action h to actual action a

            o_next, r, done, _ = env.step(a)
            total_collected_samples += 1
            episodic_reward += r
            memory.push(o, h, r, o_next, done)
            o = o_next

            if memory.step > conf.random_sample_num:
                # Update agent
                batch_data = memory.sample(conf.agent_update_batch_size)
                q1_loss, q2_loss, policy_loss, alpha_loss, alpha = agent.update_parameters(
                    batch_data, agent_update_count)
                agent_update_count += 1

            if done:
                break

        # Describe and save episodic metrics
        reward_moving_avg = (1. - MOVING_AVG_COEF) * reward_moving_avg + MOVING_AVG_COEF * episodic_reward \
            if reward_moving_avg is not None else episodic_reward
        pbar.set_description(
            "EPISODE {} (total samples {}, subpolicy samples {}) --- Step {}, Reward {:.1f} (avg {:.1f})".format(
                episode, total_collected_samples, memory.step, t, episodic_reward, reward_moving_avg))
        metrics['episode'].append(episode)
        metrics['reward'].append(episodic_reward)
        metrics['episodic_step'].append(t)
        metrics['collected_total_samples'].append(total_collected_samples)

        if episode % visualize_interval == 0:
            # Visualize metrics
            lineplot(metrics['episode'][-len(metrics['reward']):], metrics['reward'], 'REWARD', log_dir)
            reward_avg = np.array(metrics['reward']) / np.array(metrics['episodic_step'])
            lineplot(metrics['episode'][-len(reward_avg):], reward_avg, 'AVG_REWARD', log_dir)
            lineplot(metrics['collected_total_samples'][-len(metrics['reward']):], metrics['reward'],
                     'SAMPLE-REWARD', log_dir, xaxis='sample')

        # Save metrics for agent update
        if q1_loss is not None:
            metrics['q_loss'].append(np.mean([q1_loss, q2_loss]))
            metrics['policy_loss'].append(policy_loss)
            metrics['alpha_loss'].append(alpha_loss)
            metrics['alpha'].append(alpha)
            if episode % visualize_interval == 0:
                lineplot(metrics['episode'][-len(metrics['q_loss']):], metrics['q_loss'], 'Q_LOSS', log_dir)
                lineplot(metrics['episode'][-len(metrics['policy_loss']):], metrics['policy_loss'],
                         'POLICY_LOSS', log_dir)
                lineplot(metrics['episode'][-len(metrics['alpha_loss']):], metrics['alpha_loss'],
                         'ALPHA_LOSS', log_dir)
                lineplot(metrics['episode'][-len(metrics['alpha']):], metrics['alpha'], 'ALPHA', log_dir)

        # Insert new subpolicy layer and reset memory once a specified number of samples is collected
        if policy_switch_samples and len(policy_switch_samples) > 0 \
                and total_collected_samples >= policy_switch_samples[0]:
            print("----------------------\nInsert new policy\n----------------------")
            agent.insert_subpolicy()
            memory.reset()
            metrics['policy_switch_epoch'].append(episode)
            metrics['policy_switch_sample'].append(total_collected_samples)
            policy_switch_samples = policy_switch_samples[1:]

        # Test a policy
        if episode % conf.test_interval == 0:
            test_rewards = []
            for _ in range(conf.test_times):
                episodic_reward = 0
                obs = env.reset()
                for t in range(conf.horizon):
                    h = agent.select_action(obs, eval=True)
                    a = agent.post_process_action(obs, h)
                    obs_next, r, done, _ = env.step(a)
                    episodic_reward += r
                    obs = obs_next
                    if done:
                        break
                test_rewards.append(episodic_reward)
            test_reward_avg, test_reward_std = np.mean(test_rewards), np.std(test_rewards)
            print(" TEST --- ({} episodes) Reward {:.1f} (pm {:.1f})".format(
                conf.test_times, test_reward_avg, test_reward_std))
            metrics['test_episode'].append(episode)
            metrics['test_reward'].append(test_rewards)
            lineplot(metrics['test_episode'][-len(metrics['test_reward']):], metrics['test_reward'],
                     "TEST_REWARD", log_dir)

        # Save checkpoint
        if episode % conf.checkpoint_interval == 0:
            save_checkpoint()

    # Save the final model
    torch.save({'agent': agent.state_dict()}, os.path.join(data_dir, 'final_model.pth'))
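# A hypothetical command-line entry point for the training function above; the
# repository's actual CLI (if any) is not shown here, so the argument names and
# defaults below are assumptions.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train an agent from a TOML config.")
    parser.add_argument("config_filepath", help="Path to the TOML config file")
    parser.add_argument("--save-dir", default="results", help="Directory for checkpoints, plots and logs")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--visualize-interval", type=int, default=10, help="Episodes between metric plots")
    args = parser.parse_args()
    train(args.config_filepath, args.save_dir, args.device, args.visualize_interval)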