Example 1
    def sample(self, policy, env_name, params=None, gamma=0.95, batch_size=None):
        if batch_size is None:
            batch_size = self.args.fast_batch_size

        episodes = BatchEpisodes(batch_size=batch_size, gamma=gamma, device=self.args.device)
        for i in range(batch_size):
            self.queue.put(i)
        for _ in range(self.args.num_workers):
            self.queue.put(None)
        observations, batch_ids = self.envs[env_name].reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                # Process observation and pad zeros if needed
                if observations.shape[-1] < policy.input_size:
                    target = np.zeros((observations.shape[0], policy.input_size), dtype=observations.dtype)
                    target[:, :int(observations.shape[-1])] = observations
                    observations = target
                observations_tensor = torch.from_numpy(observations).to(device=self.args.device)

                actions_tensor = policy(observations_tensor, params=params).sample()
                actions = actions_tensor.cpu().numpy()
                # Process actions to fit into action_space
                # TODO: may need to apply masking later
                actions_ = actions[:, :int(np.prod(self.envs[env_name].action_space.shape))]
            new_observations, rewards, dones, new_batch_ids, _ = self.envs[env_name].step(actions_)
            episodes.append(observations, actions, rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

        return episodes
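The padding branch above widens each observation batch to the policy's expected input size. A minimal, self-contained sketch of the same idea; the helper name pad_observations and the sizes are chosen for illustration only:

import numpy as np

def pad_observations(observations, input_size):
    # Right-pad a (batch, obs_dim) array with zeros up to input_size.
    obs_dim = observations.shape[-1]
    if obs_dim >= input_size:
        return observations
    padded = np.zeros((observations.shape[0], input_size), dtype=observations.dtype)
    padded[:, :obs_dim] = observations
    return padded

# A batch of 4 observations of dimension 3, padded to a policy input size of 5.
print(pad_observations(np.ones((4, 3), dtype=np.float32), 5).shape)  # (4, 5)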
Example 2
    def sample(self, policy, params=None, gamma=0.95):
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            observations_tensor = observations
            actions_tensor = policy(observations_tensor, params=params).sample()
            with tf.device('/CPU:0'):
                actions = actions_tensor.numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions)
            episodes.append(observations, actions, rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

        return episodes
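Both samplers coordinate episode collection through a task queue: one index per episode in the batch, then one None sentinel per worker so each worker knows when to stop. A small, self-contained sketch of that producer/sentinel pattern (the worker logic and queue names here are illustrative, not taken from the samplers):

import multiprocessing as mp

def worker(task_queue, result_queue):
    # Consume task indices until a None sentinel arrives.
    while True:
        task = task_queue.get()
        if task is None:
            break
        result_queue.put(task * task)

if __name__ == '__main__':
    task_queue, result_queue = mp.Queue(), mp.Queue()
    workers = [mp.Process(target=worker, args=(task_queue, result_queue))
               for _ in range(2)]
    for w in workers:
        w.start()
    for i in range(8):       # one entry per episode in the batch
        task_queue.put(i)
    for _ in workers:        # one sentinel per worker
        task_queue.put(None)
    for w in workers:
        w.join()
    print(sorted(result_queue.get() for _ in range(8)))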
Example 3
 def sample_maml(self, policy, task=None, batch_id=None, params=None):
     for i in range(self.batch_size):
         self.queue.put(i)
     for _ in range(self.num_workers):
         self.queue.put(None)
     episodes = BatchEpisodes(dic_agent_conf=self.dic_agent_conf)
     observations, batch_ids = self.envs.reset()
     dones = [False]
     if params:  # TODO: refine the parameter-loading logic
         policy.load_params(params)
     while (not all(dones)) or (not self.queue.empty()):
         actions = policy.choose_action(observations)
         # reshape to one action per intersection (multi-intersection setting)
         actions = np.reshape(actions, (-1, 1))
         new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
             actions)
         episodes.append(observations, actions, new_observations, rewards,
                         batch_ids)
         observations, batch_ids = new_observations, new_batch_ids
     #self.envs.bulk_log()
     return episodes
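The reshape to (-1, 1) turns the flat vector of per-intersection actions into a column with one row per intersection, which is the layout the vectorized environment's step call expects here. A short illustrative sketch:

import numpy as np

actions = np.array([0, 2, 1, 3])        # one discrete action per intersection
actions_column = np.reshape(actions, (-1, 1))
print(actions_column.shape)             # (4, 1)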
Example 4
 def sample(self, policy, params=None, gamma=0.95, batch_size=None):
     if batch_size is None:
         batch_size = self.batch_size
     episodes = BatchEpisodes(batch_size=batch_size,
                              gamma=gamma,
                              device=self.device)
     for i in range(batch_size):
         self.queue.put(i)
     for _ in range(self.num_workers):
         self.queue.put(None)
     observations, batch_ids = self.envs.reset()
     dones = [False]
     while (not all(dones)) or (not self.queue.empty()):
         with torch.no_grad():
             observations_tensor = torch.from_numpy(observations).to(
                 device=self.device)
             actions_tensor = policy(observations_tensor,
                                     params=params).sample()
             actions = actions_tensor.cpu().numpy()
         new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
             actions)
         episodes.append(observations, actions, rewards, batch_ids)
         observations, batch_ids = new_observations, new_batch_ids
     return episodes
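Wrapping action selection in torch.no_grad() keeps the rollout free of autograd graphs; gradients are only built later, when the stored episodes are replayed through the policy. A minimal, self-contained sketch of the same pattern with a toy categorical policy (the ToyPolicy class is illustrative, not the project's policy):

import torch
import torch.nn as nn
from torch.distributions import Categorical

class ToyPolicy(nn.Module):
    def __init__(self, obs_dim=4, n_actions=3):
        super().__init__()
        self.fc = nn.Linear(obs_dim, n_actions)

    def forward(self, obs):
        # Return an action distribution, as the sampled policy above does.
        return Categorical(logits=self.fc(obs))

policy = ToyPolicy()
observations = torch.randn(5, 4)        # batch of 5 observations
with torch.no_grad():                   # no autograd graph during the rollout
    actions = policy(observations).sample()
print(actions.shape)                    # torch.Size([5])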
Example 5
    def sample_meta_test(self,
                         policy,
                         task,
                         batch_id,
                         params=None,
                         target_params=None,
                         old_episodes=None):
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        episodes = BatchEpisodes(dic_agent_conf=self.dic_agent_conf,
                                 old_episodes=old_episodes)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        if params:  # TODO: refine the parameter-loading logic
            policy.load_params(params)

        while (not all(dones)) or (not self.queue.empty()):
            actions = policy.choose_action(observations)
            # reshape to one action per intersection (multi-intersection setting)
            actions = np.reshape(actions, (-1, 1))
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, new_observations, rewards,
                            batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'UPDATE_PERIOD'] == 0:
                if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                    episodes.forget()

                policy.fit(episodes,
                           params=params,
                           target_params=target_params)
                sample_size = min(self.dic_agent_conf['SAMPLE_SIZE'],
                                  len(episodes))
                slice_index = random.sample(range(len(episodes)), sample_size)
                params = policy.update_params(episodes,
                                              params=copy.deepcopy(params),
                                              lr_step=self.lr_step,
                                              slice_index=slice_index)

                policy.load_params(params)

                self.lr_step += 1
                self.target_step += 1
                if self.target_step == self.dic_agent_conf[
                        'UPDATE_Q_BAR_FREQ']:
                    target_params = params
                    self.target_step = 0

            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'TEST_PERIOD'] == 0:
                self.single_test_sample(policy,
                                        task,
                                        self.test_step,
                                        params=params)
                pickle.dump(
                    params,
                    open(
                        os.path.join(
                            self.dic_path['PATH_TO_MODEL'],
                            'params' + "_" + str(self.test_step) + ".pkl"),
                        'wb'))
                write_summary(self.dic_path, task,
                              self.dic_traffic_env_conf["EPISODE_LEN"],
                              batch_id)

                self.test_step += 1
            self.step += 1

        policy.decay_epsilon(batch_id)
        self.envs.bulk_log()
        return params, target_params, episodes
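The update schedule in sample_meta_test interleaves three counters: step gates when fitting starts (UPDATE_START) and how often it runs (UPDATE_PERIOD), target_step refreshes the target parameters every UPDATE_Q_BAR_FREQ updates, and lr_step feeds the learning-rate schedule. A stripped-down sketch of that bookkeeping, with illustrative config values and integer stand-ins for the parameter sets:

conf = {'UPDATE_START': 10, 'UPDATE_PERIOD': 5, 'UPDATE_Q_BAR_FREQ': 3}
params, target_params = 0, 0             # stand-ins for real parameter sets
lr_step, target_step = 0, 0

for step in range(40):
    if step > conf['UPDATE_START'] and step % conf['UPDATE_PERIOD'] == 0:
        params += 1                      # stands in for policy.update_params(...)
        lr_step += 1
        target_step += 1
        if target_step == conf['UPDATE_Q_BAR_FREQ']:
            target_params = params       # sync the target parameters
            target_step = 0

print(params, target_params)             # prints "5 3" with these illustrative values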
Example 6
    def sample_period(self,
                      policy,
                      task,
                      batch_id,
                      params=None,
                      target_params=None,
                      old_episodes=None):
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        episodes = BatchEpisodes(
            batch_size=self.batch_size,
            dic_traffic_env_conf=self.dic_traffic_env_conf,
            dic_agent_conf=self.dic_agent_conf,
            old_episodes=old_episodes)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        if params:  # TODO: refine the parameter-loading logic
            policy.load_params(params)

        while (not all(dones)) or (not self.queue.empty()):

            if self.dic_traffic_env_conf['MODEL_NAME'] == 'MetaDQN':
                # NOTE: the original passed an undefined `task_type`; here the
                # `task` argument is assumed to identify the task type.
                actions = policy.choose_action(observations, task_type=task)
            else:
                actions = policy.choose_action(observations)
            # reshape to one action per intersection (multi-intersection setting)
            actions = np.reshape(actions, (-1, 1))
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, new_observations, rewards,
                            batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

            # if update
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'UPDATE_PERIOD'] == 0:
                if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                    #TODO
                    episodes.forget()

                policy.fit(episodes,
                           params=params,
                           target_params=target_params)
                params = policy.update_params(episodes,
                                              params=copy.deepcopy(params),
                                              lr_step=self.lr_step)
                policy.load_params(params)

                self.lr_step += 1
                self.target_step += 1
                if self.target_step == self.dic_agent_conf[
                        'UPDATE_Q_BAR_FREQ']:
                    target_params = params
                    self.target_step = 0

            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'TEST_PERIOD'] == 0:
                self.test(policy, task, self.test_step, params=params)
                pickle.dump(
                    params,
                    open(
                        os.path.join(
                            self.dic_path['PATH_TO_MODEL'],
                            'params' + "_" + str(self.test_step) + ".pkl"),
                        'wb'))

                self.test_step += 1
            self.step += 1

        return params, target_params, episodes
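Both sample_meta_test and sample_period checkpoint the current parameters by pickling them under PATH_TO_MODEL with the test step in the file name. A minimal sketch of that pattern, using a temporary directory and a dummy parameter dict in place of the real ones:

import os
import pickle
import tempfile

params = {'w': [0.1, 0.2], 'b': 0.0}    # stand-in for the real parameter object
test_step = 3
model_dir = tempfile.mkdtemp()          # stands in for dic_path['PATH_TO_MODEL']

path = os.path.join(model_dir, 'params_{}.pkl'.format(test_step))
with open(path, 'wb') as f:             # a context manager also closes the file,
    pickle.dump(params, f)              # which the inline open() above does not

with open(path, 'rb') as f:
    print(pickle.load(f))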
Example 7
def train(args, scorer, summary_writer=None):
    device = args.device
    env = create_crop_env(args, scorer)

    model = ActorCritic(args).to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    training_log_file = open(os.path.join(
        args.model_save_path, 'training.log'), 'w')
    validation_log_file = open(os.path.join(
        args.model_save_path, 'validation.log'), 'w')

    training_log_file.write('Epoch,Cost\n')
    validation_log_file.write('Epoch,Cost\n')

    for train_iter in range(args.n_epochs):
        episode = BatchEpisodes(batch_size=args.batch_size, gamma=args.gamma, device=device)

        for _ in range(args.batch_size):
            done = True
            observation_np = env.reset()

            observations_np, rewards_np, actions_np, hs_ts, cs_ts = [], [], [], [], []
            cx = torch.zeros(1, args.hidden_dim).to(device)
            hx = torch.zeros(1, args.hidden_dim).to(device)
            
            for step in range(args.num_steps):
                observations_np.append(observation_np[0])
                hs_ts.append(hx)
                cs_ts.append(cx)

                with torch.no_grad():
                    observation_ts = torch.from_numpy(observation_np).to(device)
                    value_ts, logit_ts, (hx, cx) = model((observation_ts, (hx, cx)))
                    prob = F.softmax(logit_ts, dim=-1)
                    action_ts = prob.multinomial(num_samples=1).detach()
                
                action_np = action_ts.cpu().numpy()
                actions_np.append(action_np[0][0])
                observation_np, reward_num, done, _ = env.step(action_np)
                if step == args.num_steps - 1:
                    # bootstrap a truncated episode with the critic's value estimate
                    reward_num = 0 if done else value_ts.item()
                rewards_np.append(reward_num)

                if done:
                    break

            observations_np, actions_np, rewards_np = \
                map(lambda x: np.array(x).astype(np.float32), [observations_np, actions_np, rewards_np])
            episode.append(observations_np, actions_np, rewards_np, hs_ts, cs_ts)

        log_probs = []
        values = []
        entropys = []
        for i in range(len(episode)):
            (hs_ts, cs_ts) = episode.hiddens[0][i], episode.hiddens[1][i]
            value_ts, logit_ts, (_, _) = model((episode.observations[i], (hs_ts, cs_ts)))
            prob = F.softmax(logit_ts, dim=-1)
            log_prob = F.log_softmax(logit_ts, dim=-1)
            entropy = -(log_prob * prob).sum(1)
            log_prob = log_prob.gather(1, episode.actions[i].unsqueeze(1).long())
            log_probs.append(log_prob)
            entropys.append(entropy)
            values.append(value_ts)

        log_probs_ts = torch.stack(log_probs).squeeze(2)
        values_ts = torch.stack(values).squeeze(2)
        entropys_ts = torch.stack(entropys)

        advantages_ts = episode.gae(values_ts)
        advantages_ts = weighted_normalize(advantages_ts, weights=episode.mask)
        policy_loss = - weighted_mean(log_probs_ts * advantages_ts, dim=0,
                weights=episode.mask)
        value_loss = weighted_mean((values_ts - episode.returns).pow(2), dim=0,
                weights = episode.mask)
        entropy_loss = - weighted_mean(entropys_ts, dim=0,
                weights = episode.mask)
        
        optimizer.zero_grad()
        tot_loss = policy_loss + entropy_loss + args.value_loss_coef * value_loss
        tot_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()

        print("Epoch [%2d/%2d] : Tot Loss: %5.5f, Policy Loss: %5.5f, Value Loss: %5.5f, Entropy Loss: %5.5f" %
              (train_iter, args.n_epochs, tot_loss.item(), policy_loss.item(), value_loss.item(), entropy_loss.item()))
        # print("Train_iter: ", train_iter, " Total Loss: ", tot_loss.item(), " Value Loss: ", value_loss.item(), " Policy Loss: ", policy_loss.item(), "Entropy Loss: ", entropy_loss.item())
        if summary_writer:
            summary_writer.add_scalar('loss_policy', policy_loss.item(), train_iter)
            summary_writer.add_scalar('loss_value', value_loss.item(), train_iter)
            summary_writer.add_scalar('loss_entropy', entropy_loss.item(), train_iter)
            summary_writer.add_scalar('loss_tot', tot_loss.item(), train_iter)
        train_iter += 1  # only shifts the epoch index used by the save/log lines below; the for loop reassigns it

        if (train_iter + 1) % args.save_per_epoch == 0:
            torch.save(model.state_dict(),
                       os.path.join(args.model_save_path,
                                    'model_{}_{}.pth'.format(train_iter, tot_loss.item())))

        training_log_file.write('{},{}\n'.format(train_iter, tot_loss.item()))
        validation_log_file.write('{},{}\n'.format(train_iter, 0))
        training_log_file.flush()
        validation_log_file.flush()

    training_log_file.close()
    validation_log_file.close()
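The total loss in train combines the policy-gradient term, an entropy bonus (added with a negative sign so that higher entropy lowers the loss), and a weighted value-regression term. A compact sketch of that composition on dummy tensors, with illustrative shapes and a 0.5 value-loss coefficient:

import torch

log_probs = torch.randn(10, 4, requires_grad=True)    # (steps, batch)
advantages = torch.randn(10, 4)
values = torch.randn(10, 4, requires_grad=True)
returns = torch.randn(10, 4)
entropies = torch.rand(10, 4)

policy_loss = -(log_probs * advantages).mean()
value_loss = (values - returns).pow(2).mean()
entropy_loss = -entropies.mean()

value_loss_coef = 0.5
tot_loss = policy_loss + entropy_loss + value_loss_coef * value_loss
tot_loss.backward()                                    # gradients flow into log_probs and values
print(tot_loss.item())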