Example #1
    def train(self, env, time_remaining=1e9, noise_type=None, noise_value=None):
        time_start = time.time()

        sd = 1 if env.has_discrete_state_space() else self.state_dim
        ad = 1 if env.has_discrete_action_space() else self.action_dim
        replay_buffer = ReplayBuffer(state_dim=sd, action_dim=ad, device=self.device, max_size=self.rb_size)

        avg_meter_reward = AverageMeter(print_str="Average reward: ")

        # training loop
        for episode in range(self.train_episodes):
            # early out if timeout
            if self.time_is_up(avg_meter_reward=avg_meter_reward,
                               max_episodes=self.train_episodes,
                               time_elapsed=time.time() - time_start,
                               time_remaining=time_remaining):
                break

            self.update_parameters_per_episode(episode=episode)

            state = env.reset()
            episode_reward = 0

            for t in range(0, env.max_episode_steps()):
                action = self.select_train_action(state=state, env=env)

                # live view
                if self.render_env and episode % 10 == 0:
                    env.render()

                # state-action transition
                next_state, reward, done = env.step(action=action)
                next_state, reward = self.apply_noise(next_state=next_state, reward=reward, noise_type=noise_type, noise_value=noise_value)
                replay_buffer.add(state=state, action=action, next_state=next_state, reward=reward, done=done)
                state = next_state
                episode_reward += reward

                # train
                if episode >= self.init_episodes:
                    self.learn(replay_buffer=replay_buffer, env=env)

                if done > 0.5:
                    break

            # logging
            avg_meter_reward.update(episode_reward, print_rate=self.print_rate)

            # quit training if environment is solved
            if self.env_solved(env=env, avg_meter_reward=avg_meter_reward, episode=episode):
                break

        env.close()

        return avg_meter_reward.get_raw_data(), replay_buffer
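
# --------------------------------------------------------------------------
# The ReplayBuffer and AverageMeter helpers used by train() above are not part
# of this listing. The sketch below is only an assumption about the minimal
# interface the loop relies on (constructor arguments, add(), update(),
# get_raw_data()); it is not the repository's actual implementation.
# --------------------------------------------------------------------------
import numpy as np
import torch


class ReplayBuffer:
    def __init__(self, state_dim, action_dim, device, max_size=int(1e6)):
        self.device = device
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.reward = np.zeros((max_size, 1), dtype=np.float32)
        self.done = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, action, next_state, reward, done):
        # store one transition, overwriting the oldest entry once full
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.done[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        # uniform random minibatch, returned as torch tensors on the target device
        idx = np.random.randint(0, self.size, size=batch_size)
        batch = (self.state[idx], self.action[idx], self.next_state[idx],
                 self.reward[idx], self.done[idx])
        return tuple(torch.as_tensor(x, device=self.device) for x in batch)


class AverageMeter:
    def __init__(self, print_str=""):
        self.print_str = print_str
        self.values = []

    def update(self, value, print_rate=10):
        # record one episode statistic; print a running mean every print_rate updates
        self.values.append(float(value))
        if len(self.values) % int(print_rate) == 0:
            print(self.print_str + str(np.mean(self.values[-int(print_rate):])))

    def get_raw_data(self):
        return list(self.values)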

    def test(self, env, time_remaining=1e9):
        discretize_action = False

        sd = 1 if env.has_discrete_state_space() else self.state_dim

        if env.has_discrete_action_space():
            ad = 1
            # in the td3_discrete case, action_dim=1 does not reflect the action_dim required by the Gumbel-softmax distribution
            if self.agent_name == "td3_discrete_vary":
                ad = env.get_action_dim()
                discretize_action = True
        else:
            ad = self.action_dim

        replay_buffer = ReplayBuffer(state_dim=sd,
                                     action_dim=ad,
                                     device=self.device,
                                     max_size=int(1e6))

        env.set_agent_params(same_action_num=self.same_action_num,
                             gamma=self.gamma)

        with torch.no_grad():
            time_start = time.time()

            avg_meter_reward = AverageMeter(print_str="Average reward: ")
            avg_meter_episode_length = AverageMeter(
                print_str="Average episode length: ")

            # test loop
            for episode in range(self.test_episodes):
                # early out if timeout
                if self.time_is_up(
                        avg_meter_reward=avg_meter_reward,
                        avg_meter_episode_length=avg_meter_episode_length,
                        max_episodes=self.test_episodes,
                        time_elapsed=time.time() - time_start,
                        time_remaining=time_remaining):
                    break

                state = env.reset()
                episode_reward = 0
                episode_length = 0

                for t in range(0, env.max_episode_steps(),
                               self.same_action_num):
                    action = self.select_test_action(state, env)

                    # live view
                    if self.render_env:
                        env.render()

                    # state-action transition
                    # required due to the Gumbel softmax in td3_discrete
                    if discretize_action:
                        next_state, reward, done = env.step(
                            action=action.argmax().unsqueeze(0))
                    else:
                        next_state, reward, done = env.step(action=action)
                    replay_buffer.add(state=state,
                                      action=action,
                                      next_state=next_state,
                                      reward=reward,
                                      done=done)

                    state = next_state
                    episode_reward += reward
                    episode_length += 1

                    if done > 0.5:
                        break

                # logging
                avg_meter_episode_length.update(episode_length, print_rate=1e9)
                avg_meter_reward.update(episode_reward.item(),
                                        print_rate=self.print_rate)

            env.close()

        # todo: use a dict to reduce confusion and bugs
        return (avg_meter_reward.get_raw_data(),
                avg_meter_episode_length.get_raw_data(),
                replay_buffer)
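
# Why test() above converts the action with argmax()/unsqueeze(0) (sketch,
# assumption): with td3_discrete the actor emits a relaxed one-hot vector
# produced via the Gumbel-softmax trick, while the environment expects a
# single action index.
import torch
import torch.nn.functional as F

logits = torch.tensor([0.2, 1.5, -0.3])          # per-action logits from the actor (illustrative values)
soft_action = F.gumbel_softmax(logits, tau=1.0)  # relaxed one-hot sample, shape (3,)
env_action = soft_action.argmax().unsqueeze(0)   # index tensor of shape (1,) the env can consume
print(env_action)                                # e.g. tensor([1])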

    def train(self, env, test_env=None, time_remaining=1e9):
        time_start = time.time()

        discretize_action = False

        sd = 1 if env.has_discrete_state_space() else self.state_dim

        # todo: @fabio use "hasattr" and custom function in derived class (see below)
        if env.has_discrete_action_space():
            ad = 1
            # in the td3_discrete case, action_dim=1 does not reflect the action_dim required by the Gumbel-softmax distribution
            if "td3_discrete" in self.agent_name:
                ad = env.get_action_dim()
                discretize_action = True
        else:
            ad = self.action_dim

        replay_buffer = ReplayBuffer(state_dim=sd,
                                     action_dim=ad,
                                     device=self.device,
                                     max_size=self.rb_size)

        avg_meter_reward = AverageMeter(print_str="Average reward: ")
        avg_meter_episode_length = AverageMeter(
            print_str="Average episode length: ")

        env.set_agent_params(same_action_num=self.same_action_num,
                             gamma=self.gamma)

        # training loop
        for episode in range(self.train_episodes):
            # early out if timeout
            if self.time_is_up(
                    avg_meter_reward=avg_meter_reward,
                    avg_meter_episode_length=avg_meter_episode_length,
                    max_episodes=self.train_episodes,
                    time_elapsed=time.time() - time_start,
                    time_remaining=time_remaining):
                break

            if hasattr(self, 'update_parameters_per_episode'):
                self.update_parameters_per_episode(episode=episode)

            state = env.reset()
            episode_reward = 0
            episode_length = 0
            for t in range(0, env.max_episode_steps(), self.same_action_num):
                action = self.select_train_action(state=state,
                                                  env=env,
                                                  episode=episode)

                # live view
                if self.render_env:
                    env.render()

                # state-action transition
                # required due to the Gumbel softmax in td3_discrete
                # todo @fabio: move into agent-specific select_train_action, do the same for test
                if discretize_action:
                    next_state, reward, done = env.step(
                        action=action.argmax().unsqueeze(0))
                else:
                    next_state, reward, done = env.step(action=action)
                replay_buffer.add(state=state,
                                  action=action,
                                  next_state=next_state,
                                  reward=reward,
                                  done=done)

                state = next_state
                episode_reward += reward
                episode_length += self.same_action_num

                # train
                if episode >= self.init_episodes:
                    self.learn(replay_buffer=replay_buffer,
                               env=env,
                               episode=episode)

                if done > 0.5:
                    break

            # logging
            avg_meter_episode_length.update(episode_length, print_rate=1e9)

            if test_env is not None:
                avg_reward_test_raw, _, _ = self.test(test_env)
                avg_meter_reward.update(statistics.mean(avg_reward_test_raw),
                                        print_rate=self.print_rate)
            else:
                avg_meter_reward.update(episode_reward,
                                        print_rate=self.print_rate)

            # quit training if environment is solved
            if episode >= self.init_episodes:
                if test_env is not None:
                    break_env = test_env
                else:
                    break_env = env
                if self.env_solved(env=break_env,
                                   avg_meter_reward=avg_meter_reward,
                                   episode=episode):
                    print(f'early out after {episode} episodes')
                    break

        env.close()

        # todo: use a dict to reduce confusion and bugs
        return (avg_meter_reward.get_raw_data(),
                avg_meter_episode_length.get_raw_data(),
                replay_buffer)
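
# env_solved() is not shown in this listing. A plausible sketch (assumption,
# not the repository's implementation): a moving average of recent episode
# rewards is compared against the environment's solved threshold, which is
# what triggers the "early out" print above.
import statistics


def env_solved_sketch(recent_rewards, solved_reward, window=100):
    # illustrative names; solved once the mean over the last `window` episodes
    # reaches the threshold
    if len(recent_rewards) < window:
        return False
    return statistics.mean(recent_rewards[-window:]) >= solved_reward


print(env_solved_sketch([200.0] * 100, solved_reward=195.0))  # True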

Example #4
    def train(self, env, time_remaining=1e9, test_env=None):

        sd = 1 if env.has_discrete_state_space() else self.state_dim
        ad = 1 if env.has_discrete_action_space() else self.action_dim
        replay_buffer = ReplayBuffer(state_dim=sd, action_dim=ad, device=self.device, max_size=self.rb_size)

        avg_meter_reward = AverageMeter(print_str="Average reward: ")
        avg_meter_episode_length = AverageMeter(print_str="Average episode length: ")

        env.set_agent_params(same_action_num=self.same_action_num, gamma=self.gamma)

        time_step = 0

        # training loop
        for episode in range(self.train_episodes):
            state = env.reset()
            episode_reward = 0
            episode_length = 0

            for t in range(0, env.max_episode_steps(), self.same_action_num):
                time_step += self.same_action_num

                # run old policy
                action = self.actor_old(state.to(self.device)).cpu()
                next_state, reward, done = env.step(action=action)

                # live view
                if self.render_env and episode % 100 == 0:
                    env.render()

                replay_buffer.add(state=state, action=action, next_state=next_state, reward=reward, done=done)
                state = next_state
                episode_reward += reward
                episode_length += self.same_action_num

                # train after certain amount of timesteps
                if time_step / env.max_episode_steps() > self.update_episodes:
                    self.learn(replay_buffer)
                    replay_buffer.clear()
                    time_step = 0
                if done > 0.5:
                    break

            # logging
            avg_meter_episode_length.update(episode_length, print_rate=1e9)

            if test_env is not None:
                avg_reward_test_raw, _, _ = self.test(test_env)
                avg_meter_reward.update(statistics.mean(avg_reward_test_raw), print_rate=self.print_rate)
            else:
                avg_meter_reward.update(episode_reward, print_rate=self.print_rate)

            # quit training if environment is solved
            if episode >= self.init_episodes:
                if test_env is not None:
                    break_env = test_env
                else:
                    break_env = env
                if self.env_solved(env=break_env, avg_meter_reward=avg_meter_reward, episode=episode):
                    print(f'early out after {episode} episodes')
                    break

        env.close()

        return avg_meter_reward.get_raw_data(), avg_meter_episode_length.get_raw_data(), {}
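
# Unlike the off-policy loops above, this train() gathers rollouts with
# actor_old and only calls learn() after roughly `update_episodes` episodes
# worth of steps, then clears the buffer. A standalone sketch of that update
# cadence with illustrative numbers (assumptions, not values from the
# repository):
max_episode_steps = 200   # stand-in for env.max_episode_steps()
update_episodes = 4       # stand-in for self.update_episodes
same_action_num = 1       # stand-in for self.same_action_num

time_step = 0
updates = 0
for step in range(0, 10 * max_episode_steps, same_action_num):
    time_step += same_action_num
    if time_step / max_episode_steps > update_episodes:
        updates += 1      # this is where learn() and replay_buffer.clear() would run
        time_step = 0

print(f"{updates} policy updates over 10 episodes of collected data")  # -> 2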