Example #1
    def test_ReplayBuffer(self):
        # A buffer with capacity 2 should keep only the two most recent items.
        mem = ReplayBuffer(2)
        mem.push(1)
        mem.push(2)
        [sample] = mem.sample(2)
        self.assertEqual(sorted(sample), [1, 2])
        mem.push(3)
        [sample] = mem.sample(2)
        self.assertEqual(sorted(sample), [2, 3])
        mem.push(4)
        [sample] = mem.sample(2)
        self.assertEqual(sorted(sample), [3, 4])
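The ReplayBuffer class itself is not part of this example (the test reads like a unittest.TestCase method). A minimal deque-backed sketch that would satisfy the assertions above; anything beyond the tested behavior is an assumption:

import random
from collections import deque


class ReplayBuffer:
    """Hypothetical fixed-capacity buffer covering only the behavior tested above."""

    def __init__(self, capacity):
        # deque(maxlen=...) silently evicts the oldest item once capacity is reached
        self.memory = deque(maxlen=capacity)

    def push(self, item):
        self.memory.append(item)

    def sample(self, batch_size):
        # Return a list holding a single batch so `[sample] = mem.sample(2)` unpacks it.
        return [random.sample(list(self.memory), batch_size)]

    def __len__(self):
        return len(self.memory)

The buffers used in Examples #2 and #3 expose richer interfaces (torch_samples, state_dict, step, reset), so this sketch is only meant to make the test self-contained.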
Example #2
# Assumed imports for this excerpt; a ReplayBuffer like the one exercised in
# Example #1 (extended with __len__ and torch_samples) is expected to be
# available in the same module.
import logging
import random

import gym
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm

logger = logging.getLogger(__name__)


class Trainer:
    def __init__(self, policyclass, config):
        self.config = config
        self.env = gym.make(config.env)
        self.device = (torch.cuda.current_device()
                       if torch.cuda.is_available() else "cpu")
        self.a_space = self.env.action_space.n
        self.obs_space = self.env.observation_space.shape[0]
        self.policy_net = policyclass(self.obs_space,
                                      self.a_space).to(self.device)
        self.target_net = policyclass(self.obs_space,
                                      self.a_space).to(self.device)
        self.target_net.eval()
        self.buf = ReplayBuffer(config.capacity)
        self.lossfn = nn.MSELoss()
        self.optimizer = AdamW(self.policy_net.parameters(), lr=config.lr)
        self.eps = config.eps_start
        # Linear epsilon decay; the factor of 2 makes epsilon reach eps_end
        # about halfway through training, after which the train loop stops
        # decaying it.
        self.eps_interval = (config.eps_start -
                             config.eps_end) / config.num_epochs
        self.eps_interval *= 2
        if config.render:
            self.env.render()
        if config.monitor:
            self.env = gym.wrappers.Monitor(
                self.env, config.vid_save_path,
                video_callable=lambda ep: ep % config.vid_interval == 0,
                force=True)

    def zero_grad(self):
        for param in self.policy_net.parameters():
            param.grad = None

    def train(self):
        config, env = self.config, self.env

        def update_target_net():
            self.target_net.load_state_dict(self.policy_net.state_dict())

        def run_epoch():
            loss = None
            curr_state = env.reset()
            done, next_state, reward_list = False, None, []
            while not done:
                action = self.get_eps_act(
                    torch.tensor(curr_state,
                                 device=self.device,
                                 dtype=torch.float32).unsqueeze(0))
                next_state, reward, done, _ = env.step(action)
                reward_list.append(reward)
                self.buf.push(curr_state, action, reward, next_state, done)
                curr_state = next_state
                if len(self.buf) >= self.config.batch_size:
                    loss = self.optimize_model()
            return sum(reward_list), loss

        pbar = tqdm(range(config.num_epochs))
        for epoch in pbar:
            rewards, loss = run_epoch()
            if epoch % config.target_update == 0:
                update_target_net()
            if loss is not None:
                strprint = f"epoch {epoch+1}: loss {loss:.5f}. eps {self.eps} reward {rewards}"
            else:
                strprint = f"epoch {epoch+1}: eps {self.eps} reward {rewards}"
            pbar.set_description(strprint)
            # Anneal epsilon linearly until it reaches eps_end
            if self.eps > self.config.eps_end:
                self.eps = self.eps - self.eps_interval
        self.save_model()

    def optimize_model(self):
        S, A, R, S_, done = self.buf.torch_samples(self.config.batch_size,
                                                   device=self.device)
        # Bellman target: r + gamma * max_a' Q_target(s', a'), with the
        # bootstrap term zeroed out for terminal transitions.
        target = R + (self.config.gamma * (1 - done) *
                      self.target_net(S_).max(1)[0].detach().view(-1, 1))
        estimate = self.policy_net(S).gather(1, A)
        loss = self.lossfn(estimate, target)
        self.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(self.policy_net.parameters(),
                                  self.config.grad_norm_clip)
        self.optimizer.step()
        return loss.item()

    def get_eps_act(self, state):
        """Epsilon-greedy action selection; expects a tensor already on self.device."""
        if random.random() > self.eps:
            action = self.policy_net(state).max(1)[1].item()
        else:
            action = random.randrange(self.a_space)
        return action

    def save_model(self):
        logger.info(f"Saving Model to {self.config.save_path}")
        torch.save(self.policy_net.state_dict(), self.config.save_path)
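The Trainer accepts any policyclass constructed as policyclass(obs_dim, n_actions). A minimal MLP Q-network that would fit this interface is sketched below; the name MLPQNet, the hidden size, and the architecture are illustrative assumptions, not part of the example:

import torch.nn as nn


class MLPQNet(nn.Module):
    """Hypothetical Q-network: maps an observation to one Q-value per discrete action."""

    def __init__(self, obs_dim, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        # x: (batch, obs_dim) -> (batch, n_actions)
        return self.net(x)

Trainer(MLPQNet, config) would then drive the DQN-style loop above; the config object is expected to supply env, capacity, lr, batch_size, gamma, eps_start, eps_end, num_epochs, target_update, grad_norm_clip, render, monitor, vid_save_path, vid_interval, and save_path.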
Example #3
# Assumed imports for this excerpt; project-level helpers (load_toml_config,
# create_save_dir, make_env, ReplayBuffer, agents, lineplot, MOVING_AVG_COEF)
# are expected to come from the surrounding package.
import os
import shutil

import numpy as np
import torch
import tqdm


def train(config_filepath, save_dir, device, visualize_interval):
    conf = load_toml_config(config_filepath)
    data_dir, log_dir = create_save_dir(save_dir)
    # Save config file
    shutil.copyfile(config_filepath,
                    os.path.join(save_dir, os.path.basename(config_filepath)))
    device = torch.device(device)

    # Set up log metrics
    metrics = {
        'episode': [],
        'episodic_step': [],
        'collected_total_samples': [],
        'reward': [],
        'q_loss': [],
        'policy_loss': [],
        'alpha_loss': [],
        'alpha': [],
        'policy_switch_epoch': [],
        'policy_switch_sample': [],
        'test_episode': [],
        'test_reward': [],
    }

    policy_switch_samples = getattr(conf, "policy_switch_samples", None)
    total_collected_samples = 0

    # Create environment
    env = make_env(conf.environment, render=False)

    # Instantiate modules
    memory = ReplayBuffer(int(conf.replay_buffer_capacity),
                          env.observation_space.shape, env.action_space.shape)
    agent = getattr(agents, conf.agent_type)(env.observation_space,
                                             env.action_space,
                                             device=device,
                                             **conf.agent)

    # Load checkpoint if specified in config
    if conf.checkpoint != '':
        ckpt = torch.load(conf.checkpoint, map_location=device)
        metrics = ckpt['metrics']
        agent.load_state_dict(ckpt['agent'])
        memory.load_state_dict(ckpt['memory'])
        policy_switch_samples = ckpt['policy_switch_samples']
        total_collected_samples = ckpt['total_collected_samples']

    def save_checkpoint():
        # Save checkpoint
        ckpt = {
            'metrics': metrics,
            'agent': agent.state_dict(),
            'memory': memory.state_dict(),
            'policy_switch_samples': policy_switch_samples,
            'total_collected_samples': total_collected_samples
        }
        path = os.path.join(data_dir, 'checkpoint.pth')
        torch.save(ckpt, path)

        # Save agent model only
        model_ckpt = {'agent': agent.state_dict()}
        model_path = os.path.join(data_dir, 'model.pth')
        torch.save(model_ckpt, model_path)

        # Save metrics only
        metrics_ckpt = {'metrics': metrics}
        metrics_path = os.path.join(data_dir, 'metrics.pth')
        torch.save(metrics_ckpt, metrics_path)

    # Train agent
    # Resume episode numbering if metrics were restored from a checkpoint.
    init_episode = 0 if len(metrics['episode']) == 0 else metrics['episode'][-1] + 1
    pbar = tqdm.tqdm(range(init_episode, conf.episodes))
    reward_moving_avg = None
    agent_update_count = 0
    for episode in pbar:
        episodic_reward = 0
        o = env.reset()
        q1_loss, q2_loss, policy_loss, alpha_loss, alpha = None, None, None, None, None

        for t in range(conf.horizon):
            if total_collected_samples <= conf.random_sample_num:  # Select random actions at the beginning of training.
                h = env.action_space.sample()
            elif memory.step <= conf.random_sample_num:  # Select actions from a random latent variable soon after inserting a new subpolicy.
                h = agent.select_action(o, random=True)
            else:
                h = agent.select_action(o)

            # Convert the abstract action h into an actual environment action a.
            a = agent.post_process_action(o, h)

            o_next, r, done, _ = env.step(a)
            total_collected_samples += 1
            episodic_reward += r
            memory.push(o, h, r, o_next, done)
            o = o_next

            if memory.step > conf.random_sample_num:
                # Update agent
                batch_data = memory.sample(conf.agent_update_batch_size)
                q1_loss, q2_loss, policy_loss, alpha_loss, alpha = agent.update_parameters(
                    batch_data, agent_update_count)
                agent_update_count += 1

            if done:
                break

        # Describe and save episodic metrics
        # Exponential moving average of episodic reward (initialized on the first episode).
        reward_moving_avg = episodic_reward if reward_moving_avg is None else (
            (1. - MOVING_AVG_COEF) * reward_moving_avg +
            MOVING_AVG_COEF * episodic_reward)
        pbar.set_description(
            "EPISODE {} (total samples {}, subpolicy samples {}) --- Step {}, Reward {:.1f} (avg {:.1f})"
            .format(episode, total_collected_samples, memory.step, t,
                    episodic_reward, reward_moving_avg))
        metrics['episode'].append(episode)
        metrics['reward'].append(episodic_reward)
        metrics['episodic_step'].append(t)
        metrics['collected_total_samples'].append(total_collected_samples)
        if episode % visualize_interval == 0:
            # Visualize metrics
            lineplot(metrics['episode'][-len(metrics['reward']):],
                     metrics['reward'], 'REWARD', log_dir)
            reward_avg = np.array(metrics['reward']) / np.array(
                metrics['episodic_step'])
            lineplot(metrics['episode'][-len(reward_avg):], reward_avg,
                     'AVG_REWARD', log_dir)
            lineplot(
                metrics['collected_total_samples'][-len(metrics['reward']):],
                metrics['reward'],
                'SAMPLE-REWARD',
                log_dir,
                xaxis='sample')

        # Save metrics for agent update
        if q1_loss is not None:
            metrics['q_loss'].append(np.mean([q1_loss, q2_loss]))
            metrics['policy_loss'].append(policy_loss)
            metrics['alpha_loss'].append(alpha_loss)
            metrics['alpha'].append(alpha)
            if episode % visualize_interval == 0:
                lineplot(metrics['episode'][-len(metrics['q_loss']):],
                         metrics['q_loss'], 'Q_LOSS', log_dir)
                lineplot(metrics['episode'][-len(metrics['policy_loss']):],
                         metrics['policy_loss'], 'POLICY_LOSS', log_dir)
                lineplot(metrics['episode'][-len(metrics['alpha_loss']):],
                         metrics['alpha_loss'], 'ALPHA_LOSS', log_dir)
                lineplot(metrics['episode'][-len(metrics['alpha']):],
                         metrics['alpha'], 'ALPHA', log_dir)

        # Insert new subpolicy layer and reset memory if a specific amount of samples is collected
        if policy_switch_samples and len(
                policy_switch_samples
        ) > 0 and total_collected_samples >= policy_switch_samples[0]:
            print(
                "----------------------\nInsert new policy\n----------------------"
            )
            agent.insert_subpolicy()
            memory.reset()
            metrics['policy_switch_epoch'].append(episode)
            metrics['policy_switch_sample'].append(total_collected_samples)
            policy_switch_samples = policy_switch_samples[1:]

        # Test a policy
        if episode % conf.test_interval == 0:
            test_rewards = []
            for _ in range(conf.test_times):
                episodic_reward = 0
                obs = env.reset()
                for t in range(conf.horizon):
                    h = agent.select_action(obs, eval=True)
                    a = agent.post_process_action(obs, h)
                    obs_next, r, done, _ = env.step(a)
                    episodic_reward += r
                    obs = obs_next

                    if done:
                        break

                test_rewards.append(episodic_reward)

            test_reward_avg, test_reward_std = np.mean(test_rewards), np.std(
                test_rewards)
            print("   TEST --- ({} episodes) Reward {:.1f} (pm {:.1f})".format(
                conf.test_times, test_reward_avg, test_reward_std))
            metrics['test_episode'].append(episode)
            metrics['test_reward'].append(test_rewards)
            lineplot(metrics['test_episode'][-len(metrics['test_reward']):],
                     metrics['test_reward'], "TEST_REWARD", log_dir)

        # Save checkpoint
        if episode % conf.checkpoint_interval == 0:
            save_checkpoint()

    # Save the final model
    torch.save({'agent': agent.state_dict()},
               os.path.join(data_dir, 'final_model.pth'))
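The train() entry point above takes a TOML config path, a save directory, a device string, and a visualization interval. A minimal command-line wrapper might look like the following sketch; the flag names and defaults are assumptions for illustration:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train an agent from a TOML config.')
    parser.add_argument('--config', required=True, help='path to the TOML config file')
    parser.add_argument('--save-dir', required=True, help='directory for checkpoints, logs, and plots')
    parser.add_argument('--device', default='cpu', help='torch device string, e.g. "cpu" or "cuda:0"')
    parser.add_argument('--visualize-interval', type=int, default=10,
                        help='plot metrics every N episodes')
    args = parser.parse_args()
    train(args.config, args.save_dir, args.device, args.visualize_interval)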