Example #1
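# Standard-library/PyTorch imports needed by this snippet; the project-specific pieces
# (ActorNet, CriticNet, LearnerReplayMemory, calc_priority, time_check and the visdom
# handle `vis`) are assumed to be imported elsewhere in the original repository.
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
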
class Learner:
    def __init__(self, learner_id, config, dev, shared_state, shared_queue):

        self.action_size = config['action_space']
        self.obs_size = config['obs_space']

        self.shared_queue = shared_queue
        self.shared_state = shared_state

        self.dev = dev
        self.id = learner_id
        self.burn_in_length = config['burn_in_length']  # 40-80
        self.learning_length = config['learning_length']
        self.sequence_length = self.burn_in_length + self.learning_length
        self.n_step = config['n_step']
        self.sequence = []
        self.recurrent_state = []
        self.priority = []
        self.td_loss = deque(maxlen=self.learning_length)

        self.gamma = config['gamma']
        #        self.actor_parameter_update_interval = config['actor_parameter_update_interval']

        self.actor = ActorNet(dev, config).to(self.dev)
        self.target_actor = ActorNet(dev, config).to(self.dev)
        self.critic = CriticNet(dev, config).to(self.dev)
        self.target_critic = CriticNet(dev, config).to(self.dev)

        self.actor.load_state_dict(self.shared_state["actor"].state_dict())
        self.target_actor.load_state_dict(
            self.shared_state["target_actor"].state_dict())
        self.critic.load_state_dict(self.shared_state["critic"].state_dict())
        self.target_critic.load_state_dict(
            self.shared_state["target_critic"].state_dict())

        self.learner_actor_rate = config['learner_actor_rate']

        self.num_actors = learner_id
        self.n_actions = 1
        self.max_frame = config['learner_max_frame']

        self.memory_sequence_size = config['memory_sequence_size']
        self.batch_size = config['batch_size']
        self.memory = LearnerReplayMemory(self.memory_sequence_size, config,
                                          dev)

        self.model_path = './'
        #        self.memory_path = './memory_data/'
        #        self.model_save_interval = 10 # 50
        self.learner_parameter_update_interval = config[
            'learner_parameter_update_interval']  # 50
        self.target_update_interval = config['target_update_interval']  # 100

        self.actor_lr = config['actor_lr']
        self.critic_lr = config['critic_lr']
        self.actor_optimizer = optim.Adam(self.actor.parameters(),
                                          lr=self.actor_lr)
        self.actor_criterion = nn.MSELoss()
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=self.critic_lr)
        self.critic_criterion = nn.MSELoss()

    def __del__(self):
        self.shared_queue.close()
        self.shared_state.close()


    def save_model(self):
        model_dict = {
            'actor': self.actor.state_dict(),
            'target_actor': self.target_actor.state_dict(),
            'critic': self.critic.state_dict(),
            'target_critic': self.target_critic.state_dict()
        }
        torch.save(model_dict, self.model_path + 'model.pt')

    def update_target_model(self):
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())

    def run(self):
        time_check(-1)
        while self.memory.size() < self.batch_size:
            self.memory.append(self.shared_queue.get(block=True))
            #            self.memory.append(self.shared_queue.get())
            print('\rmem size: ', self.memory.size(), end='\r')
        time_check(1)
        count_mem = 0
        frame = 0
        win_v = vis.line(Y=torch.Tensor([0]), opts=dict(title='V_loss'))
        win_p = vis.line(Y=torch.Tensor([0]), opts=dict(title='P_loss'))

        while frame < self.max_frame:
            if self.shared_queue.qsize() != 0:
                self.memory.append(self.shared_queue.get(block=True))

            frame += 1

            count_mem -= 1

            (episode_index, sequence_index, obs_seq, action_seq, reward_seq,
             gamma_seq, a_state, ta_state, c_state, tc_state) = self.memory.sample()

            self.actor.set_state(a_state[0], a_state[1])
            self.target_actor.set_state(ta_state[0], ta_state[1])
            self.critic.set_state(c_state[0], c_state[1])
            self.target_critic.set_state(tc_state[0], tc_state[1])

            ### burn-in step ###
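            # Replay the stored observations through the recurrent networks purely to warm
            # up their hidden states (R2D2-style burn-in): the online nets see burn_in_length
            # steps, the target nets n_step more so they line up with the n-step targets,
            # and all outputs are discarded.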
            _ = [self.actor(obs_seq[i]) for i in range(self.burn_in_length)]
            _ = [
                self.critic(obs_seq[i], action_seq[i])
                for i in range(self.burn_in_length)
            ]
            _ = [
                self.target_actor(obs_seq[i])
                for i in range(self.burn_in_length + self.n_step)
            ]
            _ = [
                self.target_critic(obs_seq[i], action_seq[i])
                for i in range(self.burn_in_length + self.n_step)
            ]
            ### learning steps ###

            # update critic
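            # n-step TD target: Q_target = r_t + gamma_seq[t+n]**n_step * Q'(s_{t+n}, mu'(s_{t+n})),
            # with the target networks evaluated under torch.no_grad().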
            q_value = torch.zeros(self.learning_length * self.batch_size,
                                  self.n_actions).to(self.dev)

            target_q_value = torch.zeros(
                self.learning_length * self.batch_size,
                self.n_actions).to(self.dev)
            for i in range(self.learning_length):
                obs_i = self.burn_in_length + i
                next_obs_i = self.burn_in_length + i + self.n_step
                q_value[i * self.batch_size:(i + 1) *
                        self.batch_size] = self.critic(obs_seq[obs_i],
                                                       action_seq[obs_i])
                with torch.no_grad():
                    next_q_value = self.target_critic(
                        obs_seq[next_obs_i],
                        self.target_actor(obs_seq[next_obs_i]))
                    target_q_val = reward_seq[obs_i] + (
                        gamma_seq[next_obs_i]**self.n_step) * next_q_value
                    #                target_q_val = invertical_vf(target_q_val)
                    target_q_value[i * self.batch_size:(i + 1) *
                                   self.batch_size] = target_q_val

            critic_loss = self.critic_criterion(q_value,
                                                target_q_value.detach())
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # update actor
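            # Deterministic policy gradient step: reset the online actor/critic recurrent
            # state, then minimize -Q(s, mu(s)) over the learning window so the actor
            # ascends the critic's value estimate.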
            self.actor.reset_state()
            self.critic.reset_state()
            actor_loss = torch.zeros(self.learning_length * self.batch_size,
                                     self.n_actions).to(self.dev)
            for i in range(self.learning_length):
                obs_i = i + self.burn_in_length
                action = self.actor(obs_seq[obs_i])
                actor_loss[i * self.batch_size:(i + 1) *
                           self.batch_size] = -self.critic(
                               obs_seq[obs_i], action)
            actor_loss = actor_loss.mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # update target networks
            if frame % self.target_update_interval == 0:
                self.update_target_model()

            print('#', frame, 'critic_loss:', critic_loss.item(),
                  '  actor_loss:', actor_loss.item(), '  count:', count_mem)
            win_p = vis.line(X=torch.Tensor([frame]),
                             Y=torch.Tensor([actor_loss.item()]),
                             win=win_p,
                             update='append')
            win_v = vis.line(X=torch.Tensor([frame]),
                             Y=torch.Tensor([critic_loss.item()]),
                             win=win_v,
                             update='append')

            # calc priority
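            # Turn each sample's squared TD errors over the learning window into a replay
            # priority via calc_priority (defined elsewhere in the repo) and refresh the
            # episode's total priority used for prioritized sampling.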
            td_error_sq = ((q_value - target_q_value)**2).detach().to(self.dev)
            for i in range(len(episode_index)):
                td = td_error_sq[i::self.batch_size]
                self.memory.priority[episode_index[i]][
                    sequence_index[i]] = calc_priority(td).cpu().view(1, -1)
                self.memory.total_priority[episode_index[i]] = torch.cat(
                    self.memory.priority[episode_index[i]]).sum(0).view(1, -1)


            if frame % self.learner_parameter_update_interval == 0:

                self.shared_state["actor"].load_state_dict(
                    self.actor.state_dict())
                self.shared_state["critic"].load_state_dict(
                    self.critic.state_dict())
                self.shared_state["target_actor"].load_state_dict(
                    self.target_actor.state_dict())
                self.shared_state["target_critic"].load_state_dict(
                    self.target_critic.state_dict())
                for i in range(self.num_actors):
                    self.shared_state["update"][i] = True

                print('learner_update', self.actor.policy_l0.weight.data[0][0])

            self.actor.reset_state()
            self.target_actor.reset_state()
            self.critic.reset_state()
            self.target_critic.reset_state()
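
# For reference, a minimal sketch of the config dictionary this Learner reads.
# The key names are taken directly from the constructor above; the values are
# illustrative placeholders, shared_queue / shared_state must be the project's
# own multiprocessing queue and shared-model container, and the network and
# replay-memory classes may read additional keys.
#
# config = {
#     'action_space': 1, 'obs_space': 3,
#     'burn_in_length': 40, 'learning_length': 40,
#     'n_step': 5, 'gamma': 0.997,
#     'learner_actor_rate': 10,
#     'learner_max_frame': 100000,
#     'memory_sequence_size': 500000, 'batch_size': 32,
#     'learner_parameter_update_interval': 50,
#     'target_update_interval': 100,
#     'actor_lr': 1e-4, 'critic_lr': 1e-3,
# }
# learner = Learner(learner_id=0, config=config, dev=torch.device('cpu'),
#                   shared_state=shared_state, shared_queue=shared_queue)
# learner.run()
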
Example #2
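# Imports needed by this snippet (standard libraries, PyTorch, NumPy and dm_control);
# project-specific pieces (ActorNet, CriticNet, LearnerReplayMemory, get_obs,
# calc_priority, invertical_vf) are assumed to be defined elsewhere in the original repo.
import os
from copy import deepcopy
from time import sleep, time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dm_control import suite
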
class Learner:
    def __init__(self, n_actors):
        self.env = suite.load(domain_name="walker", task_name="run")
        self.n_actions = self.env.action_spec().shape[0]
        self.obs_size = get_obs(self.env.reset().observation).shape[1]

        self.n_actors = n_actors
        self.burn_in_length = 20  # 40-80
        self.learning_length = 40
        self.sequence_length = self.burn_in_length + self.learning_length
        self.n_step = 5
        self.memory_sequence_size = 5000000
        self.batch_size = 32
        self.memory = LearnerReplayMemory(
            memory_sequence_size=self.memory_sequence_size,
            batch_size=self.batch_size)

        self.model_path = './model_data/'
        self.memory_path = './memory_data/'
        self.actor = ActorNet(self.obs_size, self.n_actions, 0).cuda()
        self.target_actor = deepcopy(self.actor).eval()
        self.critic = CriticNet(self.obs_size, self.n_actions, 0).cuda()
        self.target_critic = deepcopy(self.critic).eval()
        self.model_save_interval = 50  # 50
        self.memory_update_interval = 50  # 50
        self.target_update_interval = 500  # 100

        self.gamma = 0.997
        self.actor_lr = 1e-4
        self.critic_lr = 1e-3
        self.actor_optimizer = optim.Adam(self.actor.parameters(),
                                          lr=self.actor_lr)
        self.actor_criterion = nn.MSELoss()
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=self.critic_lr)
        self.critic_criterion = nn.MSELoss()
        self.save_model()

    def save_model(self):
        model_dict = {
            'actor': self.actor.state_dict(),
            'target_actor': self.target_actor.state_dict(),
            'critic': self.critic.state_dict(),
            'target_critic': self.target_critic.state_dict()
        }
        torch.save(model_dict, self.model_path + 'model.pt')

    def update_target_model(self):
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())

    def run(self):
        # memory not enough
        while self.memory.sequence_counter < self.batch_size * 100:
            for i in range(self.n_actors):
                is_memory = os.path.isfile(self.memory_path +
                                           '/memory{}.pt'.format(i))
                if is_memory:
                    self.memory.load(i)
                sleep(0.1)
            print('learner memory sequence size:',
                  self.memory.sequence_counter)

        step = 0
        while True:
            if step % 100 == 0:
                print('learning step:', step)
            start = time()
            step += 1

            (episode_index, sequence_index, obs_seq, action_seq, reward_seq,
             terminal_seq, a_state, ta_state, c_state,
             tc_state) = self.memory.sample()

            self.actor.set_state(a_state[0], a_state[1])
            self.target_actor.set_state(ta_state[0], ta_state[1])
            self.critic.set_state(c_state[0], c_state[1])
            self.target_critic.set_state(tc_state[0], tc_state[1])

            ### burn-in step ###
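            # Warm up the recurrent hidden states: the online nets replay burn_in_length
            # steps, the target nets burn_in_length + n_step so they line up with the
            # n-step targets computed below.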
            _ = [self.actor(obs) for obs in obs_seq[0:self.burn_in_length]]
            _ = [
                self.critic(obs, action)
                for obs, action in zip(obs_seq[0:self.burn_in_length],
                                       action_seq[0:self.burn_in_length])
            ]
            _ = [
                self.target_actor(obs)
                for obs in obs_seq[0:self.burn_in_length + self.n_step]
            ]
            _ = [
                self.target_critic(obs, action) for obs, action in zip(
                    obs_seq[0:self.burn_in_length +
                            self.n_step], action_seq[0:self.burn_in_length +
                                                     self.n_step])
            ]

            ### learning steps ###

            # update critic
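            # n-step TD target: r_t + gamma**n_step * (1 - terminal) * Q'(s_{t+n}, mu'(s_{t+n})),
            # then passed through invertical_vf, a value-rescaling helper defined
            # elsewhere in the repo.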
            q_value = torch.zeros(self.learning_length * self.batch_size,
                                  self.n_actions).cuda()
            target_q_value = torch.zeros(
                self.learning_length * self.batch_size, self.n_actions).cuda()
            for i in range(self.learning_length):
                obs_i = self.burn_in_length + i
                next_obs_i = self.burn_in_length + i + self.n_step
                q_value[i * self.batch_size:(i + 1) *
                        self.batch_size] = self.critic(obs_seq[obs_i],
                                                       action_seq[obs_i])
                next_q_value = self.target_critic(
                    obs_seq[next_obs_i],
                    self.target_actor(obs_seq[next_obs_i]))
                target_q_val = reward_seq[obs_i] + (
                    self.gamma**self.n_step) * (
                        1. - terminal_seq[next_obs_i - 1]) * next_q_value
                target_q_val = invertical_vf(target_q_val)
                target_q_value[i * self.batch_size:(i + 1) *
                               self.batch_size] = target_q_val

            critic_loss = self.critic_criterion(q_value,
                                                target_q_value.detach())
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # update actor
            self.actor.reset_state()
            self.critic.reset_state()
            actor_loss = torch.zeros(self.learning_length * self.batch_size,
                                     self.n_actions).cuda()
            for i in range(self.learning_length):
                obs_i = i + self.burn_in_length
                action = self.actor(obs_seq[obs_i])
                actor_loss[i * self.batch_size:(i + 1) *
                           self.batch_size] = -self.critic(
                               obs_seq[obs_i], action)
            actor_loss = actor_loss.mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # update target networks
            if step % self.target_update_interval == 0:
                self.update_target_model()

            # calc priority
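            # Mean squared TD error over action dimensions becomes the new replay priority
            # for each sampled sequence (via calc_priority, defined elsewhere in the repo),
            # and the episode's total priority is refreshed.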
            average_td_loss = np.mean(
                (q_value - target_q_value).detach().cpu().numpy()**2., axis=1)
            for i in range(len(episode_index)):
                td = average_td_loss[i::self.batch_size]
                self.memory.priority[episode_index[i]][
                    sequence_index[i]] = calc_priority(td)
                self.memory.total_priority[episode_index[i]] = sum(
                    self.memory.priority[episode_index[i]])

            if step % self.model_save_interval == 0:
                self.save_model()

            if step % self.memory_update_interval == 0:
                for i in range(self.n_actors):
                    is_memory = os.path.isfile(self.memory_path +
                                               '/memory{}.pt'.format(i))
                    if is_memory:
                        self.memory.load(i)
                    sleep(0.1)

            self.actor.reset_state()
            self.target_actor.reset_state()
            self.critic.reset_state()
            self.target_critic.reset_state()
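
# A minimal way to drive this Learner, assuming actor processes write
# ./memory_data/memory{i}.pt files the way run() expects:
#
# if __name__ == '__main__':
#     learner = Learner(n_actors=4)
#     learner.run()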