Example #1
    def train(self):
        try:
            self.explorer["train"].reset()

            while (self.state["i_epoch"] < self.params["runner"]["n_epochs"]) and not self.termination_check():
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        with KeepTime("train"):
                            chunk = self.explorer["train"].update()
                            with KeepTime("store"):
                                self.memory["train"].store(chunk)

                    self.state["i_cycle"] += 1
                # End of Cycle
                self.state["i_epoch"] += 1
                self.monitor_epoch()
                
                # 3. Log
                self.log()
                gc.collect() # Garbage Collection

        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            logger.fatal('End of operation ...')
            self.finalize()
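
Every example on this page wraps its work in nested KeepTime contexts. A minimal illustrative sketch of the idea, assuming KeepTime behaves roughly like a context manager that accumulates wall-clock time under slash-separated keys (the real digideep.utility.profiling implementation may differ):

    import time
    from contextlib import contextmanager

    _stack, _totals = [], {}

    @contextmanager
    def keep_time(name):
        # Hypothetical stand-in for KeepTime: build a hierarchical key such as "/train/store".
        _stack.append(name)
        key = "/" + "/".join(s for s in _stack if s != "/")
        start = time.time()
        try:
            yield
        finally:
            _totals[key] = _totals.get(key, 0.0) + time.time() - start
            _stack.pop()

    with keep_time("train"):
        with keep_time("store"):
            time.sleep(0.01)
    print(_totals)  # {'/train/store': ~0.01, '/train': ~0.01}
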
Example #2
    def enjoy(self):  # i.e. eval
        """This function evaluates the current policy in the environment. It only runs the explorer in a loop.

        .. code-block:: python

            # Do a cycle
            while not done:
                # Explore
                explorer["eval"].update()

            log()
        """
        # TODO: We need a more elegant mechanism to handle this import.
        import glfw
        glfw.init()

        try:
            self._sync_normalizations(source_explorer="train",
                                      target_explorer="eval")
            self.explorer["eval"].reset()
            while True:
                # Cycles
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        # 1. Do Experiment
                        with KeepTime("eval"):
                            self.explorer["eval"].update()
                    self.log()
                    self.state["i_cycle"] += 1
                # Log
        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            self.finalize(save=False)
Example #3
    def train(self):
        try:
            while (self.state["i_epoch"] < self.params["runner"]["n_epochs"]) and not self.termination_check():
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        
                        with KeepTime("demo"):
                            chunk = self.explorer["demo"].update()
                            with KeepTime("store"):
                                self.memory["demo"].store(chunk)
                        
                        # if self.memory["demo"].full:
                        #     # Memory full. Time to leave
                        #     self.ready_for_termination
                        #     # Make sure major checkpoints work.

                    self.state["i_cycle"] += 1
                # End of Cycle
                self.state["i_epoch"] += 1
                self.monitor_epoch()
                
                # 3. Log
                self.log()
                gc.collect() # Garbage Collection

        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            logger.fatal('End of operation ...')
            self.finalize()
Example #4
    def prestep(self, final_step=False):
        """
        Function to produce actions for all of the agents. This function does not execute the actions in the environment.
        
        Args:
            final_step (bool): A flag indicating whether this is the last call of this function.
        
        Returns:
            dict: The pre-transition dictionary containing observations, masks, and agent information. The format is:
            ``{"observations":..., "masks":..., "agents":...}``
        """

        with KeepTime("to_numpy"):
            # TODO: Is the conversion of obs necessary?
            # NOTE: The np conversion will not work if observation is a dictionary.
            # observations = np.array(self.state["observations"], dtype=np.float32)
            observations = self.state["observations"]
            masks = self.state["masks"]
            hidden_state = self.state["hidden_state"]

        with KeepTime("gen_action"):
            publish_agents = True
            agents = {}
            # TODO: We are assuming a one-level action space.
            if (not final_step) or (self.params["final_action"]):
                if self.state["steps"] < self.params["warm_start"]:
                    # Take RANDOM actions if warm-starting
                    for agent_name in self.agents:
                        agents[agent_name] = self.agents[
                            agent_name].random_action_generator(
                                self.envs, self.params["num_workers"])
                else:
                    # Take REAL actions if not warm-starting
                    for agent_name in self.agents:
                        action_generator = self.agents[
                            agent_name].action_generator
                        agents[agent_name] = action_generator(
                            observations,
                            hidden_state[agent_name],
                            masks,
                            deterministic=self.params["deterministic"])
            else:
                publish_agents = False
            # We are saving the "new" hidden_state now.

            # for agent_name in self.agents:
            #     if (not final_step) or (self.params["final_action"]):
            #         action_generator = self.agents[agent_name].action_generator
            #         agents[agent_name] = action_generator(observations, hidden_state[agent_name], masks, deterministic=self.params["deterministic"])
            #     else:
            #         publish_agents = False

        with KeepTime("form_dictionary"):
            if publish_agents:
                pre_transition = dict(observations=observations,
                                      masks=masks,
                                      agents=agents)
            else:
                pre_transition = dict(observations=observations, masks=masks)
        return pre_transition
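
Example #23 below shows how this pre-transition dictionary is consumed: the per-agent actions are gathered and passed to the vectorized environment. A minimal sketch, assuming extract_keywise simply collects one field across all agents (hypothetical stand-in, not the digideep helper itself):

    def extract_keywise_sketch(agents_dict, key):
        # Gather one field (e.g. "actions") from every agent's sub-dictionary.
        return {name: agent[key] for name, agent in agents_dict.items()}

    pre_transition = {"observations": [[0.0]], "masks": [[1.0]],
                      "agents": {"agent": {"actions": [[0.1]], "hidden_state": [[0.0]]}}}
    actions = extract_keywise_sketch(pre_transition["agents"], "actions")
    # actions == {"agent": [[0.1]]}
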
Example #5
    def calculate_loss(self, state, action, reward, next_state, masks):
        with KeepTime("loss"):
            state_repr_softq = self.policy.model["image"](state)
            expected_q_value = self.policy.model["softq"](state_repr_softq,
                                                          action)

            state_repr_value = self.policy.model["image"](state)
            expected_value = self.policy.model["value"](state_repr_value)
            # new_action, log_prob, z, mean, log_std = self.policy.evaluate_actions(state_repr.detach())
            state_repr_actor = self.policy.model["image"](state)
            new_action, log_prob, z, mean, log_std = self.policy.evaluate_actions(
                state_repr_actor)

            next_state_repr = self.policy.model["image"](next_state)
            target_value = self.policy.model["value_target"](next_state_repr)
            next_q_value = reward + masks * float(
                self.params["methodargs"]["gamma"]) * target_value
            softq_loss = self.criterion["softq"](expected_q_value,
                                                 next_q_value.detach())

            expected_new_q_value = self.policy.model["softq"](
                state_repr_softq.detach(), new_action)
            next_value = expected_new_q_value - log_prob
            value_loss = self.criterion["value"](expected_value,
                                                 next_value.detach())

            log_prob_target = expected_new_q_value - expected_value
            # TODO: Apparently the calculation of actor_loss is problematic: none of its ingredients have gradients! So backprop does nothing.
            actor_loss = (log_prob *
                          (log_prob - log_prob_target).detach()).mean()

            mean_loss = float(
                self.params["methodargs"]["mean_lambda"]) * mean.pow(2).mean()
            std_loss = float(self.params["methodargs"]
                             ["std_lambda"]) * log_std.pow(2).mean()
            z_loss = float(self.params["methodargs"]["z_lambda"]) * z.pow(
                2).sum(1).mean()

            actor_loss += mean_loss + std_loss + z_loss

        with KeepTime("optimization"):
            self.optimizer["softq"].zero_grad()
            softq_loss.backward()
            self.optimizer["softq"].step()

            self.optimizer["value"].zero_grad()
            value_loss.backward()
            self.optimizer["value"].step()

            # self.optimizer["image"].zero_grad()
            # self.optimizer["image"].step()

            self.optimizer["actor"].zero_grad()
            actor_loss.backward()
            self.optimizer["actor"].step()

        monitor("/update/loss/actor", actor_loss.item())
        monitor("/update/loss/softq", softq_loss.item())
        monitor("/update/loss/value", value_loss.item())
Example #6
    def update(self):
        """Runs :func:`step` for ``n_steps`` times.

        Returns:
            dict: A dictionary of unix-style file-system-like keys including all information generated by the simulation.
        
        See Also:
            :ref:`ref-data-structure`
        """

        # trajectory is a dictionary of lists
        trajectory = {}

        if not self.state["was_reset"] and self.params["do_reset"]:
            self.reset()

        self.state["was_reset"] = False

        # Run T (n-step) steps.
        self.local["steps"] = 0
        self.local["n_episode"] = 0

        while (self.params["n_steps"]    and self.local["steps"]     < self.params["n_steps"]) or \
              (self.params["n_episodes"] and self.local["n_episode"] < self.params["n_episodes"]):
            with KeepTime("step"):
                # print("one exploration step ...")
                transition = self.step()

            with KeepTime("append"):
                # Data is flattened in the explorer itself.
                transition = flatten_dict(transition)
                # Update the trajectory with the current list of data.
                # Put nones if the key is absent.
                update_dict_of_lists(trajectory,
                                     transition,
                                     index=self.local["steps"])

            self.local["steps"] += 1

        with KeepTime("poststep"):
            # Take one prestep so we have the next observation/hidden_state/masks/action/value/ ...
            transition = self.prestep(final_step=True)
            transition = flatten_dict(transition)
            update_dict_of_lists(trajectory,
                                 transition,
                                 index=self.local["steps"])

            # Complete the trajectory for keys that appeared in some transitions but did not occur in later
            # ones. "length=n_steps+1" accounts for the final out-of-loop prestep.

            # complete_dict_of_list(trajectory, length=self.params["n_steps"]+1)
            complete_dict_of_list(trajectory, length=self.local["steps"] + 1)
            result = convert_time_to_batch_major(trajectory)

        # We discard the remaining monitored episodes in test mode to prevent them from affecting the next test.
        monitor.discard_key("/reward/test/episodic")
        return result
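
The transitions gathered above are flattened into the unix-style keys seen throughout these examples before being appended to the trajectory. A minimal sketch of that flattening, assuming flatten_dict joins nested keys with "/" (the actual digideep helper may differ in details):

    def flatten_dict_sketch(d, prefix=""):
        # Hypothetical stand-in: turn nested dicts into slash-separated keys.
        out = {}
        for k, v in d.items():
            key = prefix + "/" + k
            if isinstance(v, dict):
                out.update(flatten_dict_sketch(v, key))
            else:
                out[key] = v
        return out

    transition = {"observations": [0.0], "masks": [1.0],
                  "agents": {"agent": {"actions": [0.1]}}}
    print(flatten_dict_sketch(transition))
    # {'/observations': [0.0], '/masks': [1.0], '/agents/agent/actions': [0.1]}
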
Example #7
    def update(self):
        # Update the networks for n times
        for i in range(self.params["methodargs"]["n_update"]):
            with KeepTime("step"):
                self.step()

        with KeepTime("targets"):
            # Update value target
            self.policy.averager["value"].update_target()
Example #8
    def update(self):
        # Update the networks for n times
        for i in range(self.params["methodargs"]["n_update"]):
            # Step
            with KeepTime("step"):
                self.step()

            with KeepTime("targets"):
                # Update actor/critic targets
                self.policy.averager["actor"].update_target()
                self.policy.averager["critic"].update_target()
Example #9
 def train_cycle(self):
     # 1. Do Experiment
     with KeepTime("train"):
         chunk = self.explorer["train"].update()
     # 2. Store Result
     with KeepTime("store"):
         self.memory["train"].store(chunk)
     # 3. Update Agent
     with KeepTime("update"):
         for agent_name in self.agents:
             with KeepTime(agent_name):
                 self.agents[agent_name].update()
Example #10
 def test(self):
     # Make the states of the train/test explorers (i.e. the states of their environments) exactly the same.
     if self.params["runner"]["test_act"]:
         if self.state["i_epoch"] % self.params["runner"]["test_int"] == 0:
             with KeepTime("/"):
                 with KeepTime("test"):
                     self._sync_normalizations(source_explorer="train", target_explorer="test")
                     # self.explorer["test"].load_state_dict(self.explorer["train"].state_dict())
                     self.explorer["test"].reset()
                     # TODO: Do update until "win_size" episodes get executed.
                     # That is in: self.explorer["test"].state["n_episode"]
                     # Make sure that n_steps is 1.
                     # If num_worker>1 it is possible that we get more than required test episodes.
                     # The rest will be reported with the next test run.
                     self.explorer["test"].update()
Example #11
    def train(self):
        try:
            while self.state["i_epoch"] < self.params["runner"][
                    "n_epochs"] and not self.termination_check():
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        self.explorer["demo"].update()

                        # # Update Agent
                        # for agent_name in self.agents:
                        #     with KeepTime(agent_name):
                        #         self.agents[agent_name].update()

                    self.state["i_cycle"] += 1
                # End of Cycle
                self.state["i_epoch"] += 1
                self.monitor_epoch()

                # 3. Log
                self.log()
                gc.collect()  # Garbage Collection

        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            logger.fatal('End of operation ...')
            self.finalize()
Example #12
    def train(self):
        """
        The function that runs the training loop.

        See Also:
            :ref:`ref-how-runner-works`
        """
        try:
            # while self.state["i_epoch"] < self.state["n_epochs"]:
            while (self.state["i_epoch"] < self.params["runner"]["n_epochs"]
                   ) and not self.termination_check():
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        self.train_cycle()
                    self.state["i_cycle"] += 1
                    # End of Cycle
                self.state["i_epoch"] += 1
                self.monitor_epoch()
                self.iterations += 1

                # NOTE: We may save/test after each cycle or at intervals.
                # 1. Perform the test
                self.test()
                # 2. Log
                self.log()
                # 3. Save
                self.save()
                # Free up memory from garbage.
                gc.collect()  # Garbage Collection

        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            self.finalize()
Example #13
 def fetch(self, batch):
     with KeepTime("fetch/to_torch"):
         state = torch.from_numpy(
             batch["/observations" + self.params["observation_path"]]).to(
                 self.device)
         action = torch.from_numpy(batch["/agents/" + self.params["name"] +
                                         "/actions"]).to(self.device)
         reward = torch.from_numpy(batch["/rewards"]).to(self.device)
         next_state = torch.from_numpy(
             batch["/observations" + self.params["observation_path"] +
                   "_2"]).to(self.device)
         masks = torch.from_numpy(batch["/masks"]).to(self.device)
     return state, action, reward, next_state, masks
Example #14
    def action_generator(self,
                         observations,
                         hidden_state,
                         masks,
                         deterministic=False):
        """The function that is called by :class:`~digideep.environment.explorer.Explorer`.

        Args:
            deterministic (bool): If ``True``, the best (deterministic) action will be computed from the action distribution. If ``False``,
                the action will be sampled from the action probability distribution.

        Returns:
            dict: ``{"actions":...,"hidden_state":...,"artifacts":{"values":...,"action_log_p":...}}``
        
        """

        observation_path = self.params.get("observation_path", "/agent")
        observations_ = observations[observation_path].astype(np.float32)

        with KeepTime("to_torch"):
            observations_ = torch.from_numpy(observations_).to(self.device)
            hidden_state_ = torch.from_numpy(hidden_state).to(self.device)
            masks_ = torch.from_numpy(masks).to(self.device)

        with KeepTime("compute_func"):
            values, action, action_log_p, hidden_state_ = \
                self.policy.generate_actions(observations_, hidden_state_, masks_, deterministic=deterministic)

        with KeepTime("to_numpy"):
            artifacts = dict(values=values.cpu().data.numpy(),
                             action_log_p=action_log_p.cpu().data.numpy())

            # actions and hidden_state are things every agent should produce.
            results = dict(actions=action.cpu().data.numpy(),
                           hidden_state=hidden_state_.cpu().data.numpy(),
                           artifacts=artifacts)
        return results
Example #15
    def sample(self):
        with KeepTime("sampler"):
            info = deepcopy(self.params["sampler_args"])

            batch_size = info["batch_size"]
            b = self.scheduler.value

            demo_batch_size = int(b * batch_size)
            train_batch_size = batch_size - demo_batch_size

            info["batch_size_dict"] = {
                "train": train_batch_size,
                "demo": demo_batch_size
            }

            batch = self.sampler(data=self.memory, info=info)
            return batch
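
The scheduler value b is the fraction of every batch drawn from the demonstration memory; for instance, batch_size = 128 with b = 0.25 gives demo_batch_size = 32 and train_batch_size = 96. A tiny sketch of that split with concrete numbers:

    batch_size = 128
    b = 0.25                                           # scheduler value
    demo_batch_size = int(b * batch_size)              # 32
    train_batch_size = batch_size - demo_batch_size    # 96
    batch_size_dict = {"train": train_batch_size, "demo": demo_batch_size}
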
Example #16
    def step(self):
        """This function is inspired by `pytorch-a2c-ppo-acktr <https://github.com/ikostrikov/pytorch-a2c-ppo-acktr>`_.

        This function needs the following keys to be in the input batch:

        * ``/observations``
        * ``/masks``
        * ``/agents/<agent_name>/hidden_state``
        * ``/agents/<agent_name>/actions``
        * ``/agents/<agent_name>/artifacts/action_log_p``
        * ``/agents/<agent_name>/artifacts/values``
        * ``/agents/<agent_name>/artifacts/advantages``
        * ``/agents/<agent_name>/artifacts/returns``

        The last two keys are added by the :mod:`digideep.agent.samplers`, while the rest are added by the
        :class:`~digideep.environment.explorer.Explorer`.

        """
        with KeepTime("samples"):
            info = deepcopy(self.params["sampler"])
            if self.policy.is_recurrent:
                data_sampler = sampler_rn(data=self.memory, info=info)
            else:
                data_sampler = sampler_ff(data=self.memory, info=info)

        with KeepTime("batches"):
            for batch in data_sampler:
                with KeepTime("to_torch"):
                    # Environment
                    observations = torch.from_numpy(
                        batch["/observations" +
                              self.params["observation_path"]]).to(self.device)
                    masks = torch.from_numpy(batch["/masks"]).to(self.device)
                    # Agent
                    hidden_state = torch.from_numpy(
                        batch["/agents/" + self.params["name"] +
                              "/hidden_state"]).to(self.device)
                    actions = torch.from_numpy(
                        batch["/agents/" + self.params["name"] +
                              "/actions"]).to(self.device)
                    # Agent Artifacts
                    old_action_log_p = torch.from_numpy(
                        batch["/agents/" + self.params["name"] +
                              "/artifacts/action_log_p"]).to(
                                  device=self.device)
                    value_preds = torch.from_numpy(
                        batch["/agents/" + self.params["name"] +
                              "/artifacts/values"]).to(self.device)
                    advantages = torch.from_numpy(
                        batch["/agents/" + self.params["name"] +
                              "/artifacts/advantages"]).to(self.device)
                    returns = torch.from_numpy(
                        batch["/agents/" + self.params["name"] +
                              "/artifacts/returns"]).to(self.device)

                with KeepTime("eval_action"):
                    values, action_log_p, dist_entropy, _ = \
                        self.policy.evaluate_actions(observations,
                                                        hidden_state,
                                                        masks,
                                                        actions)

                with KeepTime("loss_function"):
                    # This ratio is the new/old policy probability ratio at the state s.
                    ratio = torch.exp(action_log_p - old_action_log_p)
                    surr1 = ratio * advantages
                    surr2 = torch.clamp(
                        ratio, 1.0 - self.params["methodargs"]["clip_param"],
                        1.0 +
                        self.params["methodargs"]["clip_param"]) * advantages
                    action_loss = -torch.min(surr1, surr2).mean()

                    if self.params["methodargs"]["use_clipped_value_loss"]:
                        value_pred_clipped = value_preds + \
                            (values - value_preds).clamp(-self.params["methodargs"]["clip_param"], self.params["methodargs"]["clip_param"])
                        value_losses = (values - returns).pow(2)
                        value_losses_clipped = (value_pred_clipped -
                                                returns).pow(2)
                        value_loss = 0.5 * torch.max(
                            value_losses, value_losses_clipped).mean()
                    else:
                        # value_loss = 0.5 * F.mse_loss(returns, values)
                        value_loss = 0.5 * (returns - values).pow(2).mean()

                    Loss = value_loss * self.params["methodargs"]["value_loss_coef"] \
                        + action_loss \
                        - dist_entropy * self.params["methodargs"]["entropy_coef"]

                with KeepTime("backprop"):
                    self.optimizer.zero_grad()
                    Loss.backward()
                    nn.utils.clip_grad_norm_(
                        self.policy.model.parameters(),
                        self.params["methodargs"]["max_grad_norm"])

                with KeepTime("optimstep"):
                    self.optimizer.step()

                # Monitoring values
                monitor("/update/loss", Loss.item())
                monitor("/update/value_loss", value_loss.item())
                monitor("/update/action_loss", action_loss.item())
                monitor("/update/dist_entropy", dist_entropy.item())

                self.session.writer.add_scalar('loss/overall', Loss.item(),
                                               self.state["i_step"])
                self.session.writer.add_scalar('loss/value', value_loss.item(),
                                               self.state["i_step"])
                self.session.writer.add_scalar('loss/action',
                                               action_loss.item(),
                                               self.state["i_step"])
                self.session.writer.add_scalar('loss/dist_entropy',
                                               dist_entropy.item(),
                                               self.state["i_step"])

                ## Candidates for monitoring
                # ratio.item()
        self.state["i_step"] += 1
Example #17
    def step(self):
        """This function needs the following key values in the batch of memory:

        * ``/observations``
        * ``/rewards``
        * ``/agents/<agent_name>/actions``
        * ``/observations_2``

        The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
        and the last key is added by the sampler.
        """
        alpha = self.params["methodargs"]["alpha"]
        gamma = self.params["methodargs"]["gamma"]


        with KeepTime("sampler"):
            info = deepcopy(self.params["sampler_args"])
            batch = self.sampler(data=self.memory, info=info)
            if batch is None:
                return

        with KeepTime("loss"):
            with KeepTime("to_torch"):
                # ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]
                state      = torch.from_numpy(batch["/observations"+ self.params["observation_path"]]).to(self.device).float()
                action     = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device).float()
                reward     = torch.from_numpy(batch["/rewards"]).to(self.device).float()
                next_state = torch.from_numpy(batch["/observations"+self.params["observation_path"]+"_2"]).to(self.device).float()
                masks      = torch.from_numpy(batch["/masks"]).to(self.device)
                # masks      = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)

            #     reward = torch.FloatTensor(reward).to(self.device).unsqueeze(1)
            #     masks = torch.FloatTensor(masks).to(self.device).unsqueeze(1)

            ## Critic loss
            with torch.no_grad():
                next_state_action, next_state_log_prob = self.policy.evaluate_actions(next_state)
                qf1_next_target = self.policy.model["critic1_target"](next_state, next_state_action)
                qf2_next_target = self.policy.model["critic2_target"](next_state, next_state_action)
            
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_prob
                next_q_value = reward + masks * gamma * (min_qf_next_target)


            # Two Q-functions to mitigate positive bias in the policy improvement step
            qf1 = self.policy.model["critic1"](state, action)
            qf2 = self.policy.model["critic2"](state, action)

            # # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - γ(Est+1~p[V(st+1)]))^2]
            qf1_loss = F.mse_loss(qf1, next_q_value)
            qf2_loss = F.mse_loss(qf2, next_q_value)

            # NOTE: Since we use self.policy.model["critic1"] & self.policy.model["critic2"] below for the
            #       actor loss, it is critical to compute and step qf1_loss and qf2_loss here. Otherwise
            #       PyTorch cannot separate these different gradients, which are otherwise mixed, and will
            #       complain about in-place operations that mix the gradients.
            self.optimizer["critic1"].zero_grad()
            qf1_loss.backward()
            self.optimizer["critic1"].step()

            self.optimizer["critic2"].zero_grad()
            qf2_loss.backward()
            self.optimizer["critic2"].step()


            ## Policy loss            
            pi, log_pi = self.policy.evaluate_actions(state)

            qf1_pi = self.policy.model["critic1"](state, pi)
            qf2_pi = self.policy.model["critic2"](state, pi)

            min_qf_pi = torch.min(qf1_pi, qf2_pi)

            # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
            actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

            self.optimizer["actor"].zero_grad()
            actor_loss.backward()
            self.optimizer["actor"].step()

        #     if self.automatic_entropy_tuning:
        #         alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        #         self.alpha_optim.zero_grad()
        #         alpha_loss.backward()
        #         self.alpha_optim.step()

        #         self.alpha = self.log_alpha.exp()
        #         alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        #     else:
        #         alpha_loss = torch.tensor(0.).to(self.device)
        #         alpha_tlogs = torch.tensor(alpha) # For TensorboardX logs


        monitor("/update/loss/actor", actor_loss.item())
        monitor("/update/loss/critic1", qf1_loss.item())
        monitor("/update/loss/critic2", qf2_loss.item())

        self.session.writer.add_scalar('loss/actor', actor_loss.item())
        self.session.writer.add_scalar('loss/critic1', qf1_loss.item())
        self.session.writer.add_scalar('loss/critic2', qf2_loss.item())

        # 'loss/entropy_loss', ent_loss:        alpha_loss.item()
        # 'entropy_temperature/alpha', alpha:    alpha_tlogs.item()
        self.state["i_step"] += 1
Example #18
def multi_memory_sample(memory, infos):
    batch_size = 0
    list_of_buffers = []
    for m in memory:
        with KeepTime("info_deepcopy"):
            # 1. Make a copy of infos
            info = deepcopy(infos)
        with KeepTime("info_update"):
            # 2. Set the necessary changes in the corresponding info: set the batch_size for sampling.
            info["batch_size"] = infos["batch_size_dict"][m]
            batch_size += info["batch_size"]
            #### print("For {} we have bs = {}".format(m, info["batch_size"]))
        # with KeepTime("get_memory_params"):
        #     # 3. Get the memory parameters of the specified key
        #     mem = get_memory_params(memory[m], info)
        with KeepTime("get_sample_memory"):
            # 4. Do the actual sampling from the memory
            buf = get_sample_memory(memory[m], info)
        if buf is None:
            # print("Not enough data in [{}].".format(m))
            return None
        
        # 5. Use the demonstrator's actions, not the agent's.
        # TODO: The correct thing is to use the demonstrator's actions. However, we want to test
        #       whether it works with the agent's actions (this will definitely cause erroneous state
        #       trajectories, which can affect learning the system dynamics, because s' is no longer
        #       the successor of s under action a.)
        if m == "replay":
            # If we are in replaying mode, we want the demo's actions to be learnt.
            buf["/agents/agent/actions"] = buf["/agents/demonstrator/actions"]

        # 6. Append the sampled buffers to a single buffer
        with KeepTime("append_buffer"):
            list_of_buffers += [buf]
    
    # print("{} vs. {}".format(np.mean(list_of_buffers[0]["/obs_with_key"]), np.mean(list_of_buffers[1]["/obs_with_key"])))

    with KeepTime("append_common_keys"):
        buffer = append_common_keys(list_of_buffers)
    
    # with KeepTime("shuffle_inside"):
    #     # Shuffle inside the buffer:
    #     # NOTE: It does not matter most probably. Just in case ...
    #     p = np.random.permutation(batch_size)
    #     for key in buffer:
    #         buffer[key] = buffer[key][p, ...]

    # print(buffer.keys())
    # exit()

    # indices = (buffer["/observations/status/is_training"] == 0)
    # print(indices)
    # print(indices.shape)
    # exit()
    
    #### for k in buffer:
    ####     print("{:50s}: {} = {} + {}".format(k, buffer[k].shape, list_of_buffers[0][k].shape, list_of_buffers[1][k].shape))
    #### exit()
    return buffer
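
append_common_keys presumably concatenates the per-memory buffers along the batch dimension on their shared keys; a minimal numpy sketch of that assumption (hypothetical stand-in, not the digideep helper itself):

    import numpy as np

    def append_common_keys_sketch(list_of_buffers):
        # Keep only the keys present in every buffer and concatenate them batch-wise.
        common = set(list_of_buffers[0]).intersection(*list_of_buffers[1:])
        return {k: np.concatenate([b[k] for b in list_of_buffers], axis=0) for k in common}

    train = {"/rewards": np.zeros((96, 1), dtype=np.float32)}
    demo  = {"/rewards": np.ones((32, 1), dtype=np.float32)}
    merged = append_common_keys_sketch([train, demo])
    # merged["/rewards"].shape == (128, 1)
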



# def post_sampler(chunk, info):
#     # Use the demonstrator action when we are teaching!
#     if chunk:
#         assert chunk["/agents/agent/actions"].shape == chunk["/agents/demonstrator/actions"].shape, \
#             "The actions of interchangeable agents should have equal shape."
        
#         indices = (chunk["/observations/status/is_training"] == 0)
        
#         # There shouldn't be any indices with is_training==0 because we are not saving the results of the demo
#         # print(indices)
#         chunk["/agents/agent/actions"][indices] = chunk["/agents/demonstrator/actions"][indices]
#     return chunk

    
Example #19
    def step(self):
        """This function needs the following key values in the batch of memory:

        * ``/observations``
        * ``/rewards``
        * ``/agents/<agent_name>/actions``
        * ``/observations_2``

        The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
        and the last key is added by the sampler.
        """

        with KeepTime("sample"):
            batch = self.sample()
        if batch is None:
            return

        ## Sequence of images as stacking
        #
        # shape = batch["/observations/camera"][0].shape
        # self.session.writer.add_images(tag=self.params["name"]+"_images",
        #                                img_tensor=batch["/observations/camera"][0].reshape(shape[0],1,shape[1],shape[2]),
        #                                global_step=self.state['i_step'],
        #                                dataformats='NCHW')
        #

        # self.session.writer.add_images(tag=self.params["name"]+"_images",
        #                                img_tensor=batch["/observations/camera"][:,1:,:,:],
        #                                global_step=self.state['i_step'],
        #                                dataformats='NCHW')

        # print("1. RAW AVERAGE OF OUR FIRST INSTANCE:", np.mean(batch["/observations/camera"][0]), "|     STD:", np.std(batch["/observations/camera"][0]))

        # print("2. RAW AVERAGE OF OUR FIRST INSTANCE:", np.mean(batch["/observations/camera"]/255.), "|     STD:", np.std(batch["/observations/camera"]/255.))

        # print(batch["/observations/camera"][0])
        # print("\n\n\n")

        # batch["/observations/camera"][:,1:,:,:] vs batch["/observations/camera"][:,:3,:,:]
        # The second shows complete black at the very first frame since frame-stacking stacks with zero frames.
        # The first one should always show something.
        #
        ## Sequence of images as channels
        # a1 = batch["/observations/camera"][0].reshape(1, *shape)
        # a2 = batch["/observations/camera_2"][0].reshape(1, *shape)
        # c = np.concatenate([a1,a2])
        # self.session.writer.add_images(tag=self.params["name"]+"_images_0",
        #                               img_tensor=c[:,:3,:,:],
        #                               global_step=self.state['i_step'],
        #                               dataformats='NCHW')
        #
        # self.session.writer.add_image(tag=self.params["name"]+"_images_1",
        #                               img_tensor=batch["/observations/camera_2"][1].reshape(1, *shape),
        #                               global_step=self.state['i_step'],
        #                               dataformats='CHW')

        # batch["/observations/camera"]   = (batch["/observations/camera"]   - 16.4) / 17.0
        # batch["/observations/camera_2"] = (batch["/observations/camera_2"] - 16.4) / 17.0

        with KeepTime("fetch"):
            state, action, reward, next_state, masks = self.fetch(batch)

        self.calculate_loss(state, action, reward, next_state, masks)
        self.state["i_step"] += 1
Example #20
 def initProlog(self):
     if not self.dry_run:
         profiler.set_output_file(self.state['file_prolog'])
     KeepTime.set_level(self.args["profiler_level"])
Example #21
    def step(self):
        """This function needs the following key values in the batch of memory:

        * ``/observations``
        * ``/rewards``
        * ``/agents/<agent_name>/actions``
        * ``/observations_2``

        The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
        and the last key is added by the sampler.
        """
        with KeepTime("sampler"):
            info = deepcopy(self.params["sampler_args"])
            batch = self.sampler(data=self.memory, info=info)
            if batch is None:
                return

        with KeepTime("to_torch"):
            # ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]
            state = torch.from_numpy(
                batch["/observations" + self.params["observation_path"]]).to(
                    self.device).float()
            action = torch.from_numpy(batch["/agents/" + self.params["name"] +
                                            "/actions"]).to(
                                                self.device).float()
            reward = torch.from_numpy(batch["/rewards"]).to(
                self.device).float()
            next_state = torch.from_numpy(
                batch["/observations" + self.params["observation_path"] +
                      "_2"]).to(self.device).float()
            masks = torch.from_numpy(batch["/masks"]).to(self.device)

            # state      = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
            # action     = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device)
            # reward     = torch.from_numpy(batch["/rewards"]).to(self.device)
            # next_state = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)
            # # masks      = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)
            # masks      = torch.from_numpy(batch["/masks"]).to(self.device)

        with KeepTime("loss"):
            expected_q_value = self.policy.model["softq"](state, action)
            expected_value = self.policy.model["value"](state)
            new_action, log_prob, z, mean, log_std = self.policy.evaluate_actions(
                state)

            target_value = self.policy.model["value_target"](next_state)
            next_q_value = reward + masks * float(
                self.params["methodargs"]["gamma"]) * target_value
            softq_loss = self.criterion["softq"](expected_q_value,
                                                 next_q_value.detach())

            expected_new_q_value = self.policy.model["softq"](state,
                                                              new_action)
            next_value = expected_new_q_value - log_prob
            value_loss = self.criterion["value"](expected_value,
                                                 next_value.detach())

            log_prob_target = expected_new_q_value - expected_value
            # TODO: Apparently the calculation of actor_loss is problematic: none of its ingredients have gradients! So backprop does nothing.
            actor_loss_base = (log_prob *
                               (log_prob - log_prob_target).detach()).mean()

            mean_loss = float(
                self.params["methodargs"]["mean_lambda"]) * mean.pow(2).mean()
            std_loss = float(self.params["methodargs"]
                             ["std_lambda"]) * log_std.pow(2).mean()
            z_loss = float(self.params["methodargs"]["z_lambda"]) * z.pow(
                2).sum(1).mean()

            actor_loss = actor_loss_base + mean_loss + std_loss + z_loss

            self.optimizer["softq"].zero_grad()
            softq_loss.backward()
            self.optimizer["softq"].step()

            self.optimizer["value"].zero_grad()
            value_loss.backward()
            self.optimizer["value"].step()

            self.optimizer["actor"].zero_grad()
            actor_loss.backward()
            self.optimizer["actor"].step()

        monitor("/update/loss/actor", actor_loss.item())
        monitor("/update/loss/softq", softq_loss.item())
        monitor("/update/loss/value", value_loss.item())

        self.session.writer.add_scalar('loss/actor', actor_loss.item(),
                                       self.state["i_step"])
        self.session.writer.add_scalar('loss/softq', softq_loss.item(),
                                       self.state["i_step"])
        self.session.writer.add_scalar('loss/value', value_loss.item(),
                                       self.state["i_step"])

        # for key,item in locals().items():
        #     if isinstance(item, torch.Tensor):
        #         # print("item =", type(item))
        #         print(key, ":", item.shape)
        # print("-----------------------------")

        self.state["i_step"] += 1
Example #22
    def step(self):
        """This function needs the following key values in the batch of memory:

        * ``/obs_with_key``
        * ``/rewards``
        * ``/agents/<agent_name>/actions``
        * ``/obs_with_key_2``

        The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
        and the last key is added by the sampler.
        """

        with KeepTime("sampler"):
            info = deepcopy(self.params["sampler_args"])
            batch = self.sampler(data=self.memory, info=info)
            if batch is None:
                return

        with KeepTime("to_torch"):
            # ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]
            o1 = torch.from_numpy(batch["/observations" +
                                        self.params["observation_path"]]).to(
                                            self.device).float()
            a1 = torch.from_numpy(batch["/agents/" + self.params["name"] +
                                        "/actions"]).to(self.device).float()
            r1 = torch.from_numpy(batch["/rewards"]).to(self.device).float()
            o2 = torch.from_numpy(
                batch["/observations" + self.params["observation_path"] +
                      "_2"]).to(self.device).float()
            masks = torch.from_numpy(batch["/masks"]).to(self.device)
            # .view(-1).float()

        # with KeepTime("to_torch"):
        #     # ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2']
        #     o1 = torch.from_numpy(batch["/obs_with_key"]).to(self.device).float()
        #     r1 = torch.from_numpy(batch["/rewards"]).to(self.device).float()
        #     a1 = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device).float()
        #     o2 = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device).float()
        #     masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1).float()
        #     # o1.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])
        #     # o2.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])

        with KeepTime("loss/critic"):
            # ---------------------- optimize critic ----------------------
            # Use target actor exploitation policy here for loss evaluation
            a2 = self.policy.model["actor_target"](o2).detach()
            next_val = self.policy.model["critic_target"](o2, a2).detach()
            # next_val = torch.squeeze(self.policy.model["critic_target"](o2, a2).detach())

            # y_target = r + gamma * Q'( s2, pi'(s2))
            # NOTE: THIS SENTENCE IS VERY IMPORTANT!
            # r1 = torch.squeeze(r1)
            r1 = r1
            y_target = r1 + masks * next_val * float(
                self.params["methodargs"]["gamma"])

            # TODO: This wasn't in the original implementation, but it is in the HER implementation.
            # y_target.clamp_(min=-self.params["methodargs"]["clamp_return"], max=0)

            # y_pred = Q( s1, a1)
            y_predicted = self.policy.model["critic"](o1, a1)
            # y_predicted = torch.squeeze(self.policy.model["critic"](o1, a1))
            # compute critic loss, and update the critic
            # smooth_l1_loss: behaves like an L2 loss near zero and an L1 loss elsewhere

            # NOTE: The following is in DDPG+HER implementation.
            # loss_critic = F.mse_loss(y_predicted, y_target, reduction='sum')
            # NOTE: The following was used in the original!
            loss_critic = F.smooth_l1_loss(y_predicted, y_target)

            self.optimizer["critic"].zero_grad()
            loss_critic.backward()
            self.optimizer["critic"].step()

        with KeepTime("loss/actor"):
            # ---------------------- optimize actor ----------------------
            pred_a1 = self.policy.model["actor"](o1)

            loss_actor = -1 * torch.sum(self.policy.model["critic"](o1,
                                                                    pred_a1))
            self.optimizer["actor"].zero_grad()
            loss_actor.backward()
            self.optimizer["actor"].step()

        monitor("/update/loss_actor", loss_actor.item())
        monitor("/update/loss_critic", loss_critic.item())

        self.session.writer.add_scalar('loss/actor', loss_actor.item(),
                                       self.state["i_step"])
        self.session.writer.add_scalar('loss/critic', loss_critic.item(),
                                       self.state["i_step"])

        self.state["i_step"] += 1
Example #23
    def step(self):
        """Function that runs the ``prestep`` and the actual ``env.step`` functions.
        It will also manipulate the transition data to be in the appropriate format.

        Returns:
            dict: The full transition information, including the pre-transition (actions, last observations, etc) and the
            results of executing actions on the environments, i.e. rewards and infos. The format is like:
            ``{"observations":..., "masks":..., "rewards":..., "infos":..., "agents":...}``
        
        See Also:
            :ref:`ref-data-structure`
        """

        # We are saving old versions of observations, hidden_state, and masks.
        with KeepTime("prestep"):
            pre_transition = self.prestep()

        # TODO: For true multi-agent systems, rewards must be a dictionary as well,
        #       i.e. one reward for each agent. However, if the agents are pursuing
        #       a single goal, the reward can still be a single scalar!

        # Updating observations and masks: These two are one step old in the trajectory.
        # hidden_state is the newest.

        with KeepTime("envstep"):
            # Prepare actions
            actions = extract_keywise(pre_transition["agents"], "actions")

            # Step
            self.state["observations"], rewards, dones, infos = self.envs.step(
                actions)
            # Post-step
            self.state["hidden_state"] = extract_keywise(
                pre_transition["agents"], "hidden_state")
            self.state["masks"] = np.array(
                [0.0 if done_ else 1.0 for done_ in dones],
                dtype=np.float32).reshape((-1, 1))

            # NOTE: Uncomment if you find useful information in the continuous rewards ...
            # monitor("/reward/"+self.params["mode"]+"/continuous", np.mean(rewards))

        with KeepTime("render"):
            if self.params["render"]:
                self.envs.render()
                if self.params["render_delay"] > 0:
                    time.sleep(self.params["render_delay"])
        # except MujocoException as e:
        #     logger.error("We got a MuJoCo exception!")
        #     raise
        #     ## Retry??
        #     # return self.run()

        with KeepTime("poststep"):
            # TODO: Sometimes the type of observations is "dict", which it shouldn't be. Investigate the reason.
            if isinstance(self.state["observations"],
                          OrderedDict) or isinstance(
                              self.state["observations"], dict):
                for key in self.state["observations"]:
                    if np.isnan(self.state["observations"][key]).any():
                        logger.warn(
                            'NaN caught in observations during rollout generation.',
                            'step =', self.state["steps"])
                        raise ValueError
            else:
                if np.isnan(self.state["observations"]).any():
                    logger.warn(
                        'NaN caught in observations during rollout generation.',
                        'step =', self.state["steps"])
                    raise ValueError
                ## Retry??
                # return self.run()

            self.state["steps"] += 1
            self.state["timesteps"] += self.params["num_workers"]
            self.monitor_timesteps()
            # TODO: Adapt with the new dict_of_lists data structure.
            with KeepTime("report_reward"):
                self.report_rewards(infos)

            transition = dict(**pre_transition, rewards=rewards, infos=infos)
        return transition
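
The masks computed above encode episode boundaries: a worker whose episode just ended gets mask 0.0 and everything else gets 1.0, so multiplying a bootstrapped value by the mask cancels it across resets. For example:

    import numpy as np

    dones = [False, True, False]
    masks = np.array([0.0 if done_ else 1.0 for done_ in dones],
                     dtype=np.float32).reshape((-1, 1))
    # masks == [[1.], [0.], [1.]]
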
Example #24
    def step(self):
        """This function needs the following key values in the batch of memory:

        * ``/observations``
        * ``/rewards``
        * ``/agents/<agent_name>/actions``
        * ``/observations_2``

        The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
        and the last key is added by the sampler.
        """
        with KeepTime("sampler"):
            info = deepcopy(self.params["sampler_args"])

            batch_size = info["batch_size"]
            b = self.scheduler.value

            demo_batch_size = int(b * batch_size)
            train_batch_size = batch_size - demo_batch_size

            info["batch_size_dict"] = {
                "train": train_batch_size,
                "demo": demo_batch_size
            }

            batch = self.sampler(data=self.memory, info=info)
            if batch is None:
                return

        with KeepTime("to_torch"):
            # ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]

            # for k in batch:
            #     print(k, "dtype:", batch[k].dtype)
            # exit()

            # Keys:
            #   /observations/agent dtype: float32
            #   /observations/demonstrator/distance dtype: float32
            #   /observations/demonstrator/hand_closure dtype: float32
            #   /observations/status/is_training dtype: uint8
            #   /masks dtype: float32
            #   /agents/agent/actions dtype: float32
            #   /agents/agent/hidden_state dtype: float32
            #   /agents/demonstrator/actions dtype: float32
            #   /agents/demonstrator/hidden_state/time_step dtype: float32
            #   /agents/demonstrator/hidden_state/initial_distance dtype: float32
            #   /rewards dtype: float32
            #   /obs_with_key dtype: float32
            #   /obs_with_key_2 dtype: float32

            #   /observations/camera_2 dtype: float32
            #   /observations/camera dtype: uint8

            # state      = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
            # next_state = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)

            state = torch.from_numpy(batch["/observations/camera"]).to(
                self.device)
            action = torch.from_numpy(batch["/agents/" + self.params["name"] +
                                            "/actions"]).to(self.device)
            reward = torch.from_numpy(batch["/rewards"]).to(self.device)
            next_state = torch.from_numpy(batch["/observations/camera_2"]).to(
                self.device)
            # masks      = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)
            masks = torch.from_numpy(batch["/masks"]).to(self.device)

            # print("state =", state.shape)
            # print("action =", action.shape)
            # print("reward =", reward.shape)
            # print("next_state =", next_state.shape)
            # print("masks =", masks.shape)
            # exit()

            ## print(">>>>>> state shape =", state.shape)
            ## print(">>>>>> next_state shape =", next_state.shape)

        with KeepTime("loss"):
            expected_q_value = self.policy.model["softq"](state, action)
            expected_value = self.policy.model["value"](state)
            new_action, log_prob, z, mean, log_std = self.policy.evaluate_actions(
                state)

            target_value = self.policy.model["value_target"](next_state)

            next_q_value = reward + masks * float(
                self.params["methodargs"]["gamma"]) * target_value
            softq_loss = self.criterion["softq"](expected_q_value,
                                                 next_q_value.detach())

            expected_new_q_value = self.policy.model["softq"](state,
                                                              new_action)
            next_value = expected_new_q_value - log_prob
            value_loss = self.criterion["value"](expected_value,
                                                 next_value.detach())

            log_prob_target = expected_new_q_value - expected_value
            # TODO: Apparently the calculation of actor_loss is problematic: none of its ingredients have gradients! So backprop does nothing.
            actor_loss = (log_prob *
                          (log_prob - log_prob_target).detach()).mean()

            mean_loss = float(
                self.params["methodargs"]["mean_lambda"]) * mean.pow(2).mean()
            std_loss = float(self.params["methodargs"]
                             ["std_lambda"]) * log_std.pow(2).mean()
            z_loss = float(self.params["methodargs"]["z_lambda"]) * z.pow(
                2).sum(1).mean()

            actor_loss += mean_loss + std_loss + z_loss

        with KeepTime("optimization"):
            self.optimizer["softq"].zero_grad()
            softq_loss.backward()
            self.optimizer["softq"].step()

            self.optimizer["value"].zero_grad()
            value_loss.backward()
            self.optimizer["value"].step()

            self.optimizer["actor"].zero_grad()
            actor_loss.backward()
            self.optimizer["actor"].step()

        monitor("/update/loss/actor", actor_loss.item())
        monitor("/update/loss/softq", softq_loss.item())
        monitor("/update/loss/value", value_loss.item())

        # for key,item in locals().items():
        #     if isinstance(item, torch.Tensor):
        #         # print("item =", type(item))
        #         print(key, ":", item.shape)
        # print("-----------------------------")

        self.state["i_step"] += 1
Example #25
def get_sample_memory(memory, info):
    """Sampler function for DDPG-like algorithms where we want to sample data from an experience replay buffer.

    This function adds the following key to the buffer:
    
    * ``/observations_2``

    Returns:
        dict: One sampled batch to be used in the DDPG algorithm for one step of training. The shape of each
        key in the output batch will be: ``(batch_size, *key_shape[2:])``

    """
    # Get information from info
    batch_size = info["batch_size"]
    observation_path = info["observation_path"]
    # Whether to use CER or not:
    # use_cer = info.get("use_cer", False)

    # Get the main data from the memory
    buffer = memory.get_buffer()

    # Get some constants from the memory
    num_workers = memory.get_num_batches()
    # num_records = memory.length * num_workers
    N = memory.get_last_trans_index(
    ) - 1  # We don't want to consider the last "incomplete" record, hence "-1"

    record_arr = memory.get_index_valid_elements()
    worker_arr = np.arange(0, num_workers)

    num_records = len(record_arr) * num_workers

    # with KeepTime("mask_array"):
    #     masks_arr = buffer["/masks"][:,record_arr]
    #     masks_arr = masks_arr.reshape(-1)

    if batch_size >= num_records:
        # We don't have enough data in the memory yet.
        logger.debug(
            "batch_size ({}) should be smaller than total number of records (~ {}={}x{})."
            .format(batch_size, num_records, num_workers, len(record_arr)))
        return None

    with KeepTime("sampling_by_choice"):
        # if use_cer:
        #     last_chunk_indices = memory.get_index_valid_last_chunk()
        #     available_batch_size = len(last_chunk_indices) * num_workers
        #     if available_batch_size <= batch_size:
        #         # We have selected a few transitions from previous step.
        #         # Now, we should sample the rest from the replay buffer.
        #         sample_record_recent = np.repeat(last_chunk_indices, num_workers)   # 10 10 10 10 11 11 11 11 ...
        #         sample_worker_recent = np.tile(worker_arr, len(last_chunk_indices)) #  0  1  2  3  0  1  2  3 ...
        #
        #         batch_size_prime = batch_size - available_batch_size
        #
        #         # Select the rest ...
        #         sample_record_prime = np.random.choice(record_arr, batch_size_prime, replace=True)
        #         sample_worker_prime = np.random.choice(worker_arr, batch_size_prime, replace=True)
        #
        #         # Combine
        #         sample_record = np.concatenate([sample_record_recent, sample_record_prime])
        #         sample_worker = np.concatenate([sample_worker_recent, sample_worker_prime])
        #     else:
        #
        #         # OK, we have enough data, so no sampling!
        #         logger.warn("CER: Latest transitions greater than batch size. Sample from last transitions.")
        #
        #         sample_record = np.random.choice(last_chunk_indices, batch_size, replace=True)
        #         sample_worker = np.random.choice(worker_arr,         batch_size, replace=True)
        #
        # else:

        # NOTE: Never use sampling WITHOUT replacement: its running time scales up with the array size.
        # Sampling with replacement:
        sample_record = np.random.choice(record_arr, batch_size, replace=True)
        sample_worker = np.random.choice(worker_arr, batch_size, replace=True)

        # Move the next step samples
        sample_record_2 = memory.get_index_move_n_steps(sample_record, 1)
        # Make a table of indices to extract transitions
        sample_tabular = [[sample_worker], [sample_record]]
        sample_tabular_2 = [[sample_worker], [sample_record_2]]

    with KeepTime("tabular_index_extraction"):
        # Extracting the indices
        batch = {}
        for key in buffer:
            batch[key] = buffer[key][sample_tabular[0], sample_tabular[1]]
    with KeepTime("post_key_generation"):
        observation_path = "/observations" + observation_path
        # Adding predictive keys
        batch[observation_path +
              "_2"] = buffer[observation_path][sample_tabular_2[0],
                                               sample_tabular_2[1]]

    with KeepTime("flatten_first_two"):
        batch = flatten_first_two(batch)
    return batch
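
A compact illustration of the index-pair sampling used above, with a toy buffer of shape (num_workers, num_records, feature). The real memory layout and index handling in digideep may differ (e.g. get_index_move_n_steps handles wrap-around), so treat this only as a sketch of the indexing:

    import numpy as np

    num_workers, num_records, feat = 4, 10, 3
    buffer = {"/observations": np.random.rand(num_workers, num_records, feat).astype(np.float32)}

    batch_size = 5
    sample_worker = np.random.choice(np.arange(num_workers), batch_size, replace=True)
    sample_record = np.random.choice(np.arange(num_records - 1), batch_size, replace=True)
    sample_record_2 = sample_record + 1   # simplified stand-in for get_index_move_n_steps(sample_record, 1)

    batch = {key: buffer[key][sample_worker, sample_record] for key in buffer}
    batch["/observations_2"] = buffer["/observations"][sample_worker, sample_record_2]
    # Every value in `batch` now has shape (batch_size, feat).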