def sample_minibatch(self):
    """
        Sample a batch of transitions from memory.
    :return: a batch of 64 sampled transitions (the commented-out line samples the whole memory instead)
    """
    transitions = self.memory.sample(64)
    # transitions = self.memory.sample(len(self.memory))
    return Transition(*zip(*transitions))
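
Every variant of sample_minibatch on this page relies on the same transposition trick: Transition(*zip(*transitions)) turns a list of per-step Transition tuples into one Transition whose fields each hold the whole batch. The sketch below illustrates that; the Transition definition shown here is an assumption inferred from the field names used in the later examples, not the project's actual declaration.

from collections import namedtuple

# Hypothetical Transition definition, inferred from the fields used in the
# compute_bellman_residual and collect_samples examples further down.
Transition = namedtuple("Transition",
                        ("state", "action", "reward", "next_state", "terminal", "info"))

# A list of per-step transitions...
transitions = [Transition(0, 1, 1.0, 1, False, {}),
               Transition(1, 0, 0.0, 2, True, {})]
# ...is transposed into a single Transition whose fields are tuples of batched values.
batch = Transition(*zip(*transitions))
assert batch.reward == (1.0, 0.0)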
Example no. 2
def sample_minibatch(self):
    """
        Sample a batch of transitions from memory.
        This only happens
            - when the memory is full
            - when the memory length is a multiple of the batch size
        Otherwise, None is returned instead of a batch.
    :return: a batch of the whole memory, or None
    """
    if self.memory.is_full():
        logger.info("Memory is full, switching to evaluation mode.")
        self.eval()
        transitions = self.memory.sample(len(self.memory))
        return Transition(*zip(*transitions))
    elif len(self.memory) % self.config["batch_size"] == 0:
        transitions = self.memory.sample(len(self.memory))
        return Transition(*zip(*transitions))
    else:
        return None
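
Both variants above assume a memory object that exposes sample(n), is_full() and __len__(). The buffer below is a minimal illustrative sketch of that interface only; the class name, capacity handling and eviction policy are assumptions, and the project's real memory class may differ.

import random

class ReplayMemory:
    """Minimal illustrative buffer exposing the interface used above."""
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.data = []

    def push(self, transition):
        # Drop the oldest transition once the capacity is reached.
        if len(self.data) >= self.capacity:
            self.data.pop(0)
        self.data.append(transition)

    def sample(self, batch_size):
        # Sample without replacement, as memory.sample(len(memory)) suggests.
        return random.sample(self.data, batch_size)

    def is_full(self):
        return len(self.data) == self.capacity

    def __len__(self):
        return len(self.data)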
Example no. 3
    def compute_bellman_residual(self, batch, target_state_action_value=None):
        # Concatenate the batch elements
        if not isinstance(batch.state, torch.Tensor):
            # logger.info("Casting the batch to torch.tensor")
            # state = torch.cat(tuple(torch.tensor([batch.state], dtype=torch.float))).to(self.device)
            # action = torch.tensor(batch.action, dtype=torch.long).to(self.device)
            # reward = torch.tensor(batch.reward, dtype=torch.float).to(self.device)
            # next_state = torch.cat(tuple(torch.tensor([batch.next_state], dtype=torch.float))).to(self.device)
            # terminal = torch.tensor(batch.terminal, dtype=torch.bool).to(self.device)
            # batch = Transition(state, action, reward, next_state, terminal, batch.info)
            # Wrapping the batch fields in np.array before torch.tensor speeds up the conversion
            state = torch.cat(tuple(torch.tensor(np.array([batch.state]), dtype=torch.float))).to(self.device)
            action = torch.tensor(np.array(batch.action), dtype=torch.long).to(self.device)
            reward = torch.tensor(np.array(batch.reward), dtype=torch.float).to(self.device)
            next_state = torch.cat(tuple(torch.tensor(np.array([batch.next_state]), dtype=torch.float))).to(self.device)
            # TODO: test without converting to int
            # if self.env.config["observation"]["observation_config"]["type"] == "HeatmapObservation":
            #     scale = 255.0
            #     state = torch.div(state, scale)
            #     next_state = torch.div(next_state, scale)
            terminal = torch.tensor(np.array(batch.terminal), dtype=torch.bool).to(self.device)
            batch = Transition(state, action, reward, next_state, terminal, batch.info)
        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken
        state_action_values = self.value_net(batch.state)
        state_action_values = state_action_values.gather(1, batch.action.unsqueeze(1)).squeeze(1)

        if target_state_action_value is None:
            with torch.no_grad():
                # Compute V(s_{t+1}) for all next states.
                next_state_values = torch.zeros(batch.reward.shape).to(self.device)
                if self.config["double"]:
                    # Double Q-learning: pick best actions from policy network
                    _, best_actions = self.value_net(batch.next_state).max(1)
                    # Double Q-learning: estimate action values from target network
                    best_values = self.target_net(batch.next_state).gather(1, best_actions.unsqueeze(1)).squeeze(1)
                else:
                    best_values, _ = self.target_net(batch.next_state).max(1)
                next_state_values[~batch.terminal] = best_values[~batch.terminal]
                # Compute the expected Q values
                target_state_action_value = batch.reward + self.config["gamma"] * next_state_values

        # Compute loss
        loss = self.loss_function(state_action_values, target_state_action_value)
        return loss, target_state_action_value, batch
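
For context, the loss returned above is typically consumed by an ordinary gradient step. The sketch below is an assumption about the surrounding agent: the step_optimizer name, the self.optimizer attribute and the clipping range are not shown in the example, so treat this as illustrative rather than the project's actual update code.

def step_optimizer(self, loss):
    # Hypothetical companion method: a standard DQN update driven by the loss
    # returned from compute_bellman_residual.
    self.optimizer.zero_grad()
    loss.backward()
    # Clip gradients to keep the update stable; the [-1, 1] range is an assumed setting.
    for param in self.value_net.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    self.optimizer.step()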
Example no. 4
    def compute_bellman_residual(self, batch, target_state_action_value=None):
        # Concatenate the batch elements
        if not isinstance(batch.state, torch.Tensor):
            # logger.info("Casting the batch to torch.tensor")
            state = torch.cat(
                tuple(torch.tensor([batch.state],
                                   dtype=torch.float))).to(self.device)
            action = torch.tensor(batch.action,
                                  dtype=torch.long).to(self.device)
            reward = torch.tensor(batch.reward,
                                  dtype=torch.float).to(self.device)
            next_state = torch.cat(
                tuple(torch.tensor([batch.next_state],
                                   dtype=torch.float))).to(self.device)
            terminal = torch.tensor(batch.terminal,
                                    dtype=torch.bool).to(self.device)
            batch = Transition(state, action, reward, next_state, terminal,
                               batch.info)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken
        state_action_values = self.value_net(batch.state)
        state_action_values = state_action_values.gather(
            1, batch.action.unsqueeze(1)).squeeze(1)

        if target_state_action_value is None:
            with torch.no_grad():
                # Compute V(s_{t+1}) for all next states.
                next_state_values = torch.zeros(batch.reward.shape).to(
                    self.device)
                if self.config["double"]:
                    # Double Q-learning: pick best actions from policy network
                    _, best_actions = self.value_net(batch.next_state).max(1)
                    # Double Q-learning: estimate action values from target network
                    best_values = self.target_net(batch.next_state).gather(
                        1, best_actions.unsqueeze(1)).squeeze(1)
                else:
                    best_values, _ = self.target_net(batch.next_state).max(1)
                next_state_values[~batch.terminal] = best_values[~batch.terminal]
                # Compute the expected Q values
                target_state_action_value = batch.reward + self.config[
                    "gamma"] * next_state_values

        # Compute loss
        loss = self.loss_function(state_action_values,
                                  target_state_action_value)
        return loss, target_state_action_value, batch
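
The target_net queried above has to be kept in sync with value_net somewhere else in the agent. One common scheme, shown here purely as an illustration (the method name and the update rule are assumptions, not taken from this page), is a periodic hard copy:

def update_target_network(self):
    # Hypothetical periodic hard update; a soft (Polyak-averaged) update is another common choice.
    self.target_net.load_state_dict(self.value_net.state_dict())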
Example no. 5
    def collect_samples(environment_config, agent_config, count, start_time,
                        seed, model_path, batch):
        """
            Collect interaction samples of an agent / environment pair.

            Note that the last episode may not terminate: collection stops as soon as enough samples have been gathered.

        :param dict environment_config: the environment configuration
        :param dict agent_config: the agent configuration
        :param int count: number of samples to collect
        :param start_time: the initial local time of the agent
        :param seed: the env/agent seed
        :param model_path: the path to load the agent model from
        :param batch: index of the current batch
        :return: a list of trajectories, i.e. lists of Transitions
        """
        env = load_environment(environment_config)
        env.seed(seed)

        if batch == 0:  # Force pure exploration during first batch
            agent_config["exploration"]["final_temperature"] = 1
        agent_config["device"] = "cpu"
        agent = load_agent(agent_config, env)
        agent.load(model_path)
        agent.seed(seed)
        agent.set_time(start_time)

        state = env.reset()
        episodes = []
        trajectory = []
        for _ in range(count):
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            trajectory.append(
                Transition(state, action, reward, next_state, done, info))
            if done:
                state = env.reset()
                episodes.append(trajectory)
                trajectory = []
            else:
                state = next_state
        if trajectory:  # Unfinished episode
            episodes.append(trajectory)
        env.close()
        return episodes
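
Because collect_samples rebuilds the environment and agent from plain configuration dictionaries, it lends itself to being fanned out over worker processes. The sketch below is an assumption about how it might be driven (the collect_in_parallel helper, the seed-per-worker scheme and start_time=0 are placeholders, and collect_samples is assumed to be importable as a plain function), not part of the original code.

from multiprocessing import Pool

def collect_in_parallel(environment_config, agent_config, model_path,
                        samples_per_worker=1000, workers=4, batch=0):
    # One job per worker, each with its own seed so the trajectories differ.
    jobs = [(environment_config, agent_config, samples_per_worker, 0,
             seed, model_path, batch) for seed in range(workers)]
    with Pool(workers) as pool:
        results = pool.starmap(collect_samples, jobs)
    # Flatten the per-worker lists of trajectories into one list.
    return [trajectory for episodes in results for trajectory in episodes]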
Example no. 6
def sample_minibatch(self):
    if len(self.memory) < self.config["batch_size"]:
        return None
    transitions = self.memory.sample(self.config["batch_size"])
    return Transition(*zip(*transitions))
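
A typical call site for this last variant might look like the following. This is a sketch under the assumption that the agent also exposes a memory.push method and the compute_bellman_residual / step_optimizer pair from the earlier examples; the record method itself is illustrative glue code, not taken from this page.

def record(self, state, action, reward, next_state, done, info):
    # Hypothetical training hook: store the transition, then fit on a minibatch when one is available.
    self.memory.push(Transition(state, action, reward, next_state, done, info))
    batch = self.sample_minibatch()
    if batch is not None:
        loss, _, _ = self.compute_bellman_residual(batch)
        self.step_optimizer(loss)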