Example #1
    def init_buffer(self):
        obs_list = [preprocess_frame(self.env.reset())] * 5
        for i in range(self.buffer_start_size):
            obs_list.pop(0)

            action = self.env.action_space.sample()
            obs_p, r, done, _ = self.env.step(action)
            obs_list.append(preprocess_frame(obs_p))

            if done:
                self.env.reset()

            transition = (np.stack(obs_list, -1), action, r, done)
            self.replay_queue.append(transition)
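None of the snippets on this page include preprocess_frame itself. For the Atari-style RL examples, a minimal sketch of such a helper, assuming a grayscale conversion and an 84x84 resize (both assumptions, not taken from any of the snippets), could look like this:

import cv2
import numpy as np


def preprocess_frame(frame, size=(84, 84)):
    # Hypothetical helper: convert an RGB frame to grayscale and resize it.
    # Real projects differ in crop region, dtype and normalization.
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    return cv2.resize(gray, size, interpolation=cv2.INTER_AREA).astype(np.uint8)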
Example #2
    def get_inference_outputs(self):

        t0 = time.perf_counter()
        t_count = 0

        inputs_model = self.input_blobs()
        prepro_img_face = utils.preprocess_frame(self.frame,
                                                 self.image_input_shape)
        inputs_to_feed = {inputs_model[0]: prepro_img_face}

        t_start = time.perf_counter()

        points = self.inference(inputs_to_feed)

        t_end = time.perf_counter()
        t_count += 1
        log.info(
            "model {} processed at {:0.2f} requests/sec ({:0.2f} sec per request)"
            .format(self.model_name, 1 / (t_end - t_start), t_end - t_start))

        data_l_eye, data_r_eye, data_points_marks = self.get_box_eyes_data(
            points, self.initial_h, self.initial_w)

        left_eye_center_points, right_eye_center_points = self.get_eyes_center(
            points, self.initial_h, self.initial_w)

        return left_eye_center_points, right_eye_center_points, data_l_eye, data_r_eye, data_points_marks
Example #3
def make_env(title=None, frame_skip=0):
    env = RoadFollowingEnv(title=title,
                           encode_state_fn=lambda x: preprocess_frame(x.frame),
                           throttle_scale=0.1,
                           max_speed=30.0,
                           frame_skip=frame_skip)
    return env
Example #4
 def new_game(self):
     screen = self.env.reset()
     if self.do_render:
         self.env.render()
     is_terminal = False
     reward = 0
     self.history = deque([utils.preprocess_frame(screen) for _ in range(self.n_frame_input)])
     return self._dstack(self.history), reward, is_terminal
Example #5
    def optimize(self):

        if len(self.memory) < cons.batch_size:
            return

        state, action, new_state, reward, done = self.memory.sample(
            cons.batch_size)

        state = [preprocess_frame(frame, cons.device) for frame in state]
        state = torch.cat(state)

        new_state = [
            preprocess_frame(frame, cons.device) for frame in new_state
        ]
        new_state = torch.cat(new_state)

        action = cons.LongTensor(action).to(cons.device)
        reward = cons.Tensor(reward).to(cons.device)
        done = cons.Tensor(done).to(cons.device)

        new_state_values = self.atari_target_nn(new_state).detach()
        max_new_state_values = torch.max(new_state_values, 1)[0]
        target_value = reward + (1 - done) * cons.gamma * max_new_state_values

        predicted_value = self.atari_nn(state).gather(
            1, action.unsqueeze(1)).squeeze(1)

        loss = self.loss_func(predicted_value, target_value)

        self.optimizer.zero_grad()

        loss.backward()

        if cons.clip_error:
            for param in self.atari_nn.parameters():
                param.grad.data.clamp_(-1, 1)

        self.optimizer.step()

        if self.number_of_frames % cons.update_target_frequency == 0:
            self.atari_target_nn.load_state_dict(self.atari_nn.state_dict())

        if self.number_of_frames % cons.save_model_frequency == 0:
            save_model(self.atari_nn, cons.model_file)

        self.number_of_frames += 1
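This optimize() and the one in Example #8 below both assume a replay memory object exposing sample() and __len__(); that class is not shown anywhere on this page. A minimal sketch under that assumption (the class and method names are hypothetical):

import random
from collections import deque


class ReplayMemory:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, new_state, reward, done):
        self.buffer.append((state, action, new_state, reward, done))

    def sample(self, batch_size):
        # Transpose a list of transitions into tuples of states, actions, ...
        return zip(*random.sample(self.buffer, batch_size))

    def __len__(self):
        return len(self.buffer)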
Example #6
 def step(self, action, include_noclip=False):
     screen, reward, is_terminal, _ = self.env.step(action)
     if self.do_render:
         self.env.render()
     self.history.append(utils.preprocess_frame(screen))
     self.history.popleft()
     clipped_reward = max(-1, min(1, reward))
     if include_noclip:
         return self._dstack(self.history), clipped_reward, is_terminal, reward
     else:
         return self._dstack(self.history), clipped_reward, is_terminal
Example #7
    def select_action(self, state, epsilon):
        random_for_egreedy = torch.rand(1).item()

        if random_for_egreedy > epsilon:
            with torch.no_grad():
                state = preprocess_frame(state)
                action_from_nn = self.nn(state)
                action = torch.max(action_from_nn, 1)[1].item()
        else:
            action = self.env.action_space.sample()

        return action
Example #8
    def optimize(self):
        if len(self.memory) < self.batch_size:
            return

        state, action, new_state, reward, done = self.memory.sample(
            self.batch_size)
        state = [preprocess_frame(frame) for frame in state]
        state = torch.cat(state)  # stack tensor

        new_state = [preprocess_frame(frame) for frame in new_state]
        new_state = torch.cat(new_state)

        reward = torch.Tensor(reward).to(device)
        action = torch.LongTensor(action).to(device)
        done = torch.Tensor(done).to(device)

        # Double DQN
        max_new_state_indexes = torch.argmax(self.nn(new_state).detach(), 1)

        new_state_values = self.target_nn(new_state).detach()
        max_new_state_values = new_state_values.gather(
            1, max_new_state_indexes.unsqueeze(1)).squeeze(1)

        target_value = reward + (1 - done) * self.gamma * max_new_state_values
        predicted_value = self.nn(state).gather(1,
                                                action.unsqueeze(1)).squeeze(1)

        loss = self.loss_func(predicted_value, target_value)
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.nn.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        if self.number_of_frames % self.save_model_frequency == 0:
            save_model(self.nn)

        if self.number_of_frames % self.update_target_frequency == 0:
            self.target_nn.load_state_dict(self.nn.state_dict())
        self.number_of_frames += 1
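Side by side, the only difference from Example #5 is how the bootstrap value is chosen: vanilla DQN takes both the argmax and the value from the target network, while Double DQN picks the action with the online network and evaluates it with the target network. A runnable toy comparison (the Linear layers are stand-ins for the Q-networks, not the originals):

import torch
import torch.nn as nn_modules

gamma = 0.99
online_net = nn_modules.Linear(4, 3)   # stand-in for self.nn
target_net = nn_modules.Linear(4, 3)   # stand-in for self.target_nn
new_state = torch.randn(8, 4)
reward = torch.zeros(8)
done = torch.zeros(8)

# Vanilla DQN target (Example #5): argmax and value both from the target net
max_q_vanilla = torch.max(target_net(new_state).detach(), 1)[0]
target_vanilla = reward + (1 - done) * gamma * max_q_vanilla

# Double DQN target (Example #8): argmax from the online net,
# that action's value taken from the target net
best_actions = torch.argmax(online_net(new_state).detach(), 1)
max_q_double = target_net(new_state).detach().gather(
    1, best_actions.unsqueeze(1)).squeeze(1)
target_double = reward + (1 - done) * gamma * max_q_double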
Example #9
    def select_action_boltzmann(self, state, temperature, epsilon_boltz):

        random_for_egreedy = torch.rand(1)[0]

        if random_for_egreedy > epsilon_boltz:
            with torch.no_grad():

                state = preprocess_frame(state, cons.device)
                action_from_nn = self.atari_nn(state)
                action = torch.max(action_from_nn, 1)[1]
                action = action.item()
        else:

            state = preprocess_frame(state, cons.device)
            action_from_nn = self.atari_nn(state)
            action_from_nn = action_from_nn.cpu()
            action_from_nn = action_from_nn.detach().numpy()
            action_from_nn = action_from_nn[0]

            expected_reward_array = []

            for x in range(14):
                temp_reward, reward_frequency = get_frequency(conn, x)

                expected_reward = action_from_nn[
                    x] + cons.boltzmann_weight * reward_frequency * temp_reward

                expected_reward_array.append(expected_reward)

            exponent = np.true_divide(
                expected_reward_array - np.max(expected_reward_array),
                temperature)

            action_probs = np.exp(exponent) / np.sum(np.exp(exponent))

            action = np.random.choice(14, p=action_probs)

        return action
Example #10
    def select_action_egreedy(self, state, epsilon_egreedy):

        random_for_egreedy = torch.rand(1)[0]

        if random_for_egreedy > epsilon_egreedy:
            with torch.no_grad():

                state = preprocess_frame(state, cons.device)
                action_from_nn = self.atari_nn(state)
                action = torch.max(action_from_nn, 1)[1]
                action = action.item()
        else:
            action = random.randrange(0, 13)

        return action
Example #11
 def encode_state(state):
     """
         Function that encodes the current state of
         the environment into some feature vector.
     """
     frame = preprocess_frame(state.frame)
     encoded_state = model.encode([frame])[0]
     if with_measurements:
         encoded_state = np.append(
             encoded_state,
             [state.throttle, state.steering, state.velocity / 30.0])
     if isinstance(stack, int):
         s1 = np.array(encoded_state)
         if not hasattr(state, "stack"):
             state.stack = [
                 np.zeros_like(encoded_state) for _ in range(stack)
             ]
             state.stack_idx = 0
         state.stack[state.stack_idx % stack] = s1
         state.stack_idx += 1
         concat_state = np.concatenate(state.stack)
         return concat_state
     return np.array(encoded_state)
Example #12
def deep_qlearning(env, nframes, discount_factor, N, C, mini_batch_size,
                   replay_start_size, sgd_update_frequency,
                   initial_exploration, final_exploration,
                   final_exploration_frame, lr, alpha, m):
    """
    Input:
    - env: environment
    - nframes: # of frames to train on
    - discount_factor (gamma): how much to discount future rewards
    - N: replay memory size
    - C: number of steps before updating Q target network
    - mini_batch_size: mini batch size
    - replay_start_size: minimum size of replay memory before learning starts
    - sgd_update_frequency: number of action selections in between consecutive
      mini batch SGD updates
    - initial_exploration: initial epsilon value
    - final_exploration: final epsilon value
    - final_exploration_frame: number of frames over which the epsilon is
      annealed to its final value
    - lr: learning rate used by RMSprop
    - alpha: alpha value used by RMSprop
    - m: number of consecutive frames to stack for input to Q network

    Output:
    - Q: trained Q-network
    """
    n_actions = env.action_space.n
    Q = QNetwork(n_actions)
    Q_target = deepcopy(Q)
    Q_target.eval()

    transform = T.Compose([T.ToTensor()])
    optimizer = optim.RMSprop(Q.parameters(), lr=lr, alpha=alpha)
    criterion = nn.MSELoss()

    D = deque(maxlen=N)  # replay memory

    last_Q_target_update = 0
    frames_count = 0
    last_sgd_update = 0
    episodes_count = 0
    episode_rewards = []

    while True:
        frame_sequence = initialize_frame_sequence(env, m)
        state = transform(np.stack(frame_sequence, axis=2))

        episode_reward = 0
        done = False

        while not done:
            epsilon = annealed_epsilon(initial_exploration, final_exploration,
                                       final_exploration_frame, frames_count)

            action = get_epsilon_greedy_action(Q, state.unsqueeze(0), epsilon,
                                               n_actions)

            frame, reward, done, _ = env.step(action.item())
            reward = torch.tensor([reward])

            episode_reward += reward.item()
            if done:
                next_state = None
                episode_rewards.append(episode_reward)
            else:
                frame_sequence.append(preprocess_frame(frame))
                next_state = transform(np.stack(frame_sequence, axis=2))

            D.append((state, action, reward, next_state))

            state = next_state

            if len(D) < replay_start_size:
                continue

            last_sgd_update += 1
            if last_sgd_update < sgd_update_frequency:
                continue
            last_sgd_update = 0

            sgd_update(Q, Q_target, D, mini_batch_size, discount_factor,
                       optimizer, criterion)

            last_Q_target_update += 1
            frames_count += mini_batch_size

            if last_Q_target_update % C == 0:
                Q_target = deepcopy(Q)
                Q_target.eval()

            if frames_count >= nframes:
                return Q, episode_rewards

        episodes_count += 1
        if episodes_count % 100 == 0:
            save_stuff(Q, episode_rewards)
            print(f'episodes completed = {episodes_count},',
                  f'frames processed = {frames_count}')
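With the arguments documented above, a hypothetical call using DQN-paper-like values (the concrete numbers are illustrative assumptions, not values taken from the snippet) might look like:

import gym

env = gym.make("BreakoutDeterministic-v4")
Q, episode_rewards = deep_qlearning(
    env,
    nframes=10_000_000,              # total training frames
    discount_factor=0.99,            # gamma
    N=1_000_000,                     # replay memory size
    C=10_000,                        # target-network update period
    mini_batch_size=32,
    replay_start_size=50_000,
    sgd_update_frequency=4,
    initial_exploration=1.0,
    final_exploration=0.1,
    final_exploration_frame=1_000_000,
    lr=0.00025,
    alpha=0.95,                      # RMSprop alpha
    m=4)                             # frames stacked per state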
Example #13
EPSILON_DECAY_STEPS = float(0.9/1e5)
MIN_EPS = 0.1


env = gym.make("BreakoutDeterministic-v4")
total_reward = 0
# eps = 1

agent = DeepQAgent(BUFFER_SIZE, env, BATCH_SIZE, BUFFER_START_SIZE, MIN_EPS, EPSILON_DECAY_STEPS)

total_steps = 0
for ep in range(NUM_EPISODES):

    step = 0
    done = False
    obs_list = [preprocess_frame(env.reset())] * 5

    ep_reward = 0

    # Loop until MAX_STEPS reached or env returns done (checked at the bottom)
    while step < MAX_STEPS:
        step += 1
        env.render()
        obs_list.pop(0)

        action = agent.choose_action(obs_list)


        obs_p, r, done, _ = env.step(action)
        obs_list.append(preprocess_frame(obs_p))
Example #14
def generate_agent_episodes(args):

    full_path = ROLLOUT_DIR + '/rollout_' + args.env_name

    if not os.path.exists(full_path):
        os.umask(0o000)
        os.makedirs(full_path)

    env_name = args.env_name
    total_episodes = args.total_episodes
    time_steps = args.time_steps

    envs_to_generate = [env_name]

    for current_env_name in envs_to_generate:
        print("Generating data for env {}".format(current_env_name))

        env = gym.make(current_env_name)  # Create the environment
        env.seed(0)

        # First load the DQN agent and the predictive auto encoder with their weights
        agent = Agent(gamma=0.99,
                      epsilon=0.0,
                      alpha=0.0001,
                      input_dims=(104, 80, 4),
                      n_actions=env.action_space.n,
                      mem_size=25000,
                      eps_min=0.0,
                      batch_size=32,
                      replace=1000,
                      eps_dec=1e-5,
                      env_name=current_env_name)
        agent.load_models()

        predictor = load_predictive_model(current_env_name, env.action_space.n)

        s = 0

        while s < total_episodes:

            rollout_file = os.path.join(full_path, 'rollout-%d.npz' % s)

            observation = env.reset()
            frame_queue = deque(maxlen=4)
            dqn_queue = deque(maxlen=4)

            t = 0

            next_state_sequence = []
            correct_state_sequence = []
            total_reward = 0
            while t < time_steps:
                # preprocess frames for predictive model and dqn
                converted_obs = preprocess_frame(observation)
                converted_obs_dqn = preprocess_frame_dqn(observation)

                if t == 0:
                    for i in range(4):
                        frame_queue.append(converted_obs)
                        dqn_queue.append(converted_obs_dqn)
                else:
                    frame_queue.pop()
                    dqn_queue.pop()
                    frame_queue.appendleft(converted_obs)
                    dqn_queue.appendleft(converted_obs_dqn)

                observation_states = np.concatenate(frame_queue, axis=2)
                dqn_states = np.concatenate(dqn_queue, axis=2)
                next_states = predictor.generate_output_states(
                    np.expand_dims(observation_states, axis=0))
                next_state_sequence.append(next_states)
                action = agent.choose_action(dqn_states)
                correct_state_sequence.append(
                    encode_action(env.action_space.n, action))

                observation, reward, done, info = env.step(
                    action)  # Take the action chosen by the DQN agent
                total_reward += reward
                t = t + 1

            print(
                "Episode {} finished after {} timesteps with reward {}".format(
                    s, t, total_reward))

            np.savez_compressed(rollout_file,
                                next=next_state_sequence,
                                correct=correct_state_sequence)

            s = s + 1

        env.close()
Example #15
import cv2
import time

import gym
import numpy as np
import torch
import torchvision.transforms as T

# Project helpers assumed available: initialize_frame_sequence,
# preprocess_frame, get_greedy_action

Q = torch.load('DQN/trained_Q.pth')
Q.eval()
env = gym.make('Breakout-v0', frameskip=4)
env.reset()

m = 4
num_episodes = 2

transform = T.Compose([T.ToTensor()])

for _ in range(num_episodes):
    frame_sequence = initialize_frame_sequence(env, m)
    state = transform(np.stack(frame_sequence, axis=2))
    done = False

    while not done:
        action = get_greedy_action(Q, state.unsqueeze(0)).item()
        frame, reward, done, _ = env.step(action)

        frame_sequence.append(preprocess_frame(frame))
        state = transform(np.stack(frame_sequence, axis=2))

        env.render()
        time.sleep(.1)

        # cv2.imshow('', frame)
        # cv2.waitKey(100)
Example #16
def init_queue(queue, observation, dqn=False):
    for i in range(4):
        if dqn:
            queue.append(preprocess_frame_dqn(observation))
        else:
            queue.append(preprocess_frame(observation))
Example #17
def main(args):

    env_name = args.env_name
    total_episodes = args.total_episodes
    time_steps = args.time_steps
    informed = args.informed
    # action_refresh_rate = args.action_refresh_rate

    if informed:
        full_path = ROLLOUT_DIR + '/informed_rollout_' + args.env_name
    else:
        full_path = ROLLOUT_DIR + '/random_rollout_' + args.env_name

    if not os.path.exists(full_path):
        os.umask(0o000)
        os.makedirs(full_path)

    envs_to_generate = [env_name]

    for current_env_name in envs_to_generate:
        print("Generating data for env {}".format(current_env_name))

        env = gym.make(current_env_name)  # Create the environment
        env.seed(0)

        s = 0

        if informed:
            agent = load_dqn(env)

        while s < total_episodes:

            rollout_file = os.path.join(full_path, 'rollout-%d.npz' % s)

            observation = env.reset()
            frame_queue = deque(maxlen=4)
            dqn_queue = deque(maxlen=4)

            t = 0

            obs_sequence = []
            action_sequence = []
            next_sequence = []

            while t < time_steps:

                # convert image to greyscale, downsize
                converted_obs = preprocess_frame(observation)

                if t == 0:
                    for i in range(4):
                        frame_queue.append(converted_obs)
                else:
                    frame_queue.pop()
                    frame_queue.appendleft(converted_obs)

                stacked_state = np.concatenate(frame_queue, axis=2)
                obs_sequence.append(stacked_state)

                if informed:
                    dqn_obs = preprocess_frame_dqn(observation)
                    if t == 0:
                        for i in range(4):
                            dqn_queue.append(dqn_obs)
                    else:
                        dqn_queue.pop()
                        dqn_queue.appendleft(dqn_obs)
                    stacked = np.concatenate(dqn_queue, axis=2)
                    action = agent.choose_action(stacked)
                else:
                    action = env.action_space.sample()

                action_sequence.append(
                    encode_action(env.action_space.n, action))

                observation, _, _, _ = env.step(action)  # Step the env with the chosen action
                t = t + 1

                next_sequence.append(preprocess_frame(observation))

            print("Episode {} finished after {} timesteps".format(s, t))

            np.savez_compressed(rollout_file,
                                obs=obs_sequence,
                                actions=action_sequence,
                                next_frame=next_sequence)

            s = s + 1

        env.close()
Example #18
def main():
    encoder = EncoderCNN()
    encoder.eval()
    encoder.cuda()

    # Read the video list, sorted by video id in ascending order
    videos = sorted(os.listdir(video_root), key=video_sort_lambda)
    nvideos = len(videos)

    # Create the HDF5 file that stores the video features
    if os.path.exists(video_h5_path):
        # If the HDF5 file already exists, it was processed before
        # (perhaps only partially). Open it in r+ (read and write) mode
        # so previously saved data is not overwritten.
        h5 = h5py.File(video_h5_path, 'r+')
        dataset_feats = h5[video_h5_dataset]
    else:
        h5 = h5py.File(video_h5_path, 'w')
        dataset_feats = h5.create_dataset(video_h5_dataset,
                                          (nvideos, num_frames, frame_size),
                                          dtype='float32')
    for i, video in enumerate(videos):
        print(video, end=' ')
        video_path = os.path.join(video_root, video)
        try:
            cap = cv2.VideoCapture(video_path)
        except Exception:
            print('Cannot open %s.' % video)
            continue

        frame_count = 0
        frame_list = []

        # Sample one frame every frame_sample_rate (10) frames
        count = 0
        while True:
            ret, frame = cap.read()
            if ret is False:
                break
            if count % frame_sample_rate == 0:
                frame_list.append(frame)
                frame_count += 1
            count += 1

        print(frame_count)
        frame_list = np.array(frame_list)
        if frame_count > num_frames:
            # Take frames at evenly spaced intervals
            frame_indices = np.linspace(0,
                                        frame_count,
                                        num=num_frames,
                                        endpoint=False).astype(int)
            frame_list = frame_list[frame_indices]
            # Alternatively, simply truncate:
            # frame_list = frame_list[:num_frames]
            frame_count = num_frames

        # Preprocess the frames, then convert to (batch, channel, height, width) format
        cropped_frame_list = np.array(
            [preprocess_frame(x) for x in frame_list]).transpose((0, 3, 1, 2))
        cropped_frame_list = Variable(torch.from_numpy(cropped_frame_list),
                                      volatile=True).cuda()

        # The video feature has shape num_frames x 4096
        # If there are fewer than num_frames frames, pad the remainder with zeros
        feats = np.zeros((num_frames, frame_size), dtype='float32')
        feats[:frame_count, :] = encoder(cropped_frame_list).data.cpu().numpy()
        dataset_feats[i] = feats
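preprocess_frame here is a different helper from the one in the RL examples: it prepares a BGR OpenCV frame for a CNN encoder whose per-frame feature is frame_size-dimensional (4096 in the comments). A hypothetical sketch, assuming ImageNet-style resizing and normalization (the real crop size and constants are not shown in the snippet):

import cv2
import numpy as np


def preprocess_frame(frame, size=(224, 224)):
    # Hypothetical: convert BGR to RGB, resize, scale to [0, 1] and apply
    # ImageNet mean/std; the caller transposes the batch to NCHW afterwards.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, size).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    return (frame - mean) / std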
Example #19
 def preprocess(self, obs):
     obs = preprocess_frame(obs)
     last_obs = obs if self.last_obs is None else self.last_obs
     stacked_obs = np.stack([obs, last_obs], axis=0)
     self.last_obs = obs
     return torch.tensor([stacked_obs])
Example #20
def test_against_environment(env_name, num_runs, agent_name):
    env = gym.make(env_name)
    # env.seed(0)
    try:
        predictor = load_predictive_model(env_name, env.action_space.n)
        if agent_name == 'Next_agent':
            agent = StateAgent(env.action_space.n, env_name)
            agent.set_weights()
        elif agent_name == 'DQN':
            agent = Agent(gamma=0.99,
                          epsilon=0.00,
                          alpha=0.0001,
                          input_dims=(104, 80, 4),
                          n_actions=env.action_space.n,
                          mem_size=25000,
                          eps_min=0.00,
                          batch_size=32,
                          replace=1000,
                          eps_dec=1e-5,
                          env_name=env_name)
            agent.load_models()
    except Exception:
        print(
            "Error loading model, check environment name and action space dimensions"
        )
        raise

    rewards = []

    start = time.time()

    total_steps = 0.0
    for i in range(num_runs):
        frame_queue = deque(maxlen=4)

        observation = env.reset()
        done = False

        if agent_name == 'DQN':
            init_queue(frame_queue, observation, True)
        else:
            init_queue(frame_queue, observation)

        total_reward = 0.0
        frame_count = 0
        while not done:
            observation_states = np.concatenate(frame_queue, axis=2)

            # Human start of breakout since the next state agent just keeps moving to the left
            if agent_name == 'Next_agent':
                if env_name == 'BreakoutDeterministic-v4' and not frame_count:
                    agent_action = 1
                else:
                    next_states = predictor.generate_output_states(
                        np.expand_dims(observation_states, axis=0))
                    agent_action = agent.choose_action_from_next_states(
                        np.expand_dims(next_states, axis=0))
            elif agent_name == 'DQN':
                agent_action = agent.choose_action(observation_states)
            else:
                agent_action = env.action_space.sample()

            observation, reward, done, _ = env.step(agent_action)
            total_reward += reward
            frame_count += 1
            total_steps += 1

            frame_queue.pop()
            if agent_name == 'DQN':
                frame_queue.appendleft(preprocess_frame_dqn(observation))
            else:
                frame_queue.appendleft(preprocess_frame(observation))

        print("Completed episode {} with reward {}".format(
            i + 1, total_reward))
        rewards.append(total_reward)
    end = time.time()

    time_taken = (end - start) / total_steps

    print("Test complete - Average score: {}    Max score: {}".format(
        np.average(rewards), np.max(rewards)))
    return (rewards, time_taken)
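A hypothetical call (the environment name, run count and agent label are placeholders, and the corresponding model weights must already exist on disk):

rewards, sec_per_step = test_against_environment(
    'BreakoutDeterministic-v4', num_runs=10, agent_name='DQN')
print("Mean reward: {} | sec/step: {}".format(np.average(rewards), sec_per_step))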
Example #21
    def collect_data(
            self,
            num_steps: Optional[int] = None,  # TODO: handle episode ends?
            num_episodes: Optional[int] = None,
            deterministic: Optional[Dict[str, bool]] = None,
            disable_tqdm: bool = True,
            max_steps: int = 102,
            reset_memory: bool = True,
            include_last: bool = False,
            finish_episode: bool = True,
            divide_rewards: Optional[int] = None,
            visual: bool = False,
            preserve_channels: bool = False) -> DataBatch:
        """
        Performs a rollout of the agents in the environment, for an indicated number of steps or episodes.

        Args:
            num_steps: number of steps to take; either this or num_episodes has to be passed (not both)
            num_episodes: number of episodes to generate
            deterministic: whether each agent should use the greedy policy; False by default
            disable_tqdm: whether a live progress bar should be (not) displayed
            max_steps: maximum number of steps that can be taken in episodic mode, recommended just above env maximum
            reset_memory: whether to reset the memory before generating data
            include_last: whether to include the last observation in episodic mode - useful for visualizations
            finish_episode: in step mode, whether to finish the last episode (resulting in more steps than num_steps)

        Returns: dictionary with the gathered data in the following format:

        {
            "observations":
                {
                    "Agent0": tensor([obs1, obs2, ...]),

                    "Agent1": tensor([obs1, obs2, ...])
                },
            "actions":
                {
                    "Agent0": tensor([act1, act2, ...]),

                    "Agent1": tensor([act1, act2, ...])
                },
            ...,


        }
        """
        assert not ((num_steps is None) == (num_episodes is None)), ValueError(
            "Exactly one of num_steps, num_episodes "
            "should receive a value")

        if deterministic is None:
            deterministic = {agent_id: False for agent_id in self.agent_ids}

        if reset_memory:
            self.reset()

        # obs: Union[Tuple, Dict]
        obs = self.env.reset()

        if self.tuple_mode:  # Convert obs to dict
            obs = convert_obs_to_dict(obs, self.agent_ids)

        obs = {
            key: preprocess_frame(obs_, preserve_channels)
            for key, obs_ in obs.items()
        }

        episode = 0

        for agent_id, agent in self.agents.items():
            agent.storage["last_obs"] = obs[agent_id]

        end_flag = False
        full_steps = (num_steps + 100 * int(finish_episode)
                      ) if num_steps else max_steps * num_episodes
        for step in trange(full_steps, disable=disable_tqdm):
            # Compute the action for each agent

            stacked_obs = {}
            for agent_id, agent in self.agents.items():
                stacked_obs[agent_id] = np.concatenate(
                    [obs[agent_id],
                     agent.storage.get("last_obs")], axis=0)

            # breakpoint()
            action_info = {  # action, logprob
                agent_id: self.agents[agent_id].compute_single_action(
                    stacked_obs[agent_id], deterministic[agent_id])
                for agent_id in self.agent_ids
            }

            # Unpack the actions
            action = {
                agent_id: action_info[agent_id][0]
                for agent_id in self.agent_ids
            }
            logprob = {
                agent_id: action_info[agent_id][1]
                for agent_id in self.agent_ids
            }

            # Actual step in the environment

            if self.tuple_mode:  # Convert action to env-compatible
                env_action = convert_action_to_env(action, self.agent_ids)
            else:
                env_action = action

            next_obs, reward, done, info = self.env.step(env_action)

            if self.tuple_mode:  # Convert outputs to dicts
                next_obs = convert_obs_to_dict(next_obs, self.agent_ids)
                reward = convert_obs_to_dict(reward, self.agent_ids)
                done = {agent_id: done for agent_id in self.agent_ids}

            next_obs = {
                key: preprocess_frame(obs_, preserve_channels)
                for key, obs_ in next_obs.items()
            }
            if divide_rewards:
                reward = {
                    key: (rew / divide_rewards)
                    for key, rew in reward.items()
                }

            # Saving to memory
            self.memory.store(stacked_obs, action, reward, logprob, done)

            # Frame stacking
            for agent_id, agent in self.agents.items():
                agent.storage["last_obs"] = obs[agent_id]

            # Handle episode/loop ending
            if finish_episode and step + 1 == num_steps:
                end_flag = True

            # Update the current obs - either reset, or keep going
            if all(done.values()):  # episode is over
                if include_last:  # record the last observation along with placeholder action/reward/logprob
                    self.memory.store(next_obs, action, reward, logprob, done)
                obs = self.env.reset()
                if self.tuple_mode:
                    obs = convert_obs_to_dict(obs, self.agent_ids)

                obs = {
                    key: preprocess_frame(obs_, preserve_channels)
                    for key, obs_ in obs.items()
                }

                # Frame stacking
                for agent_id, agent in self.agents.items():
                    agent.storage["last_obs"] = obs[agent_id]

                # Episode mode handling
                episode += 1
                if episode == num_episodes:
                    break
                # Step mode with episode finish handling
                if end_flag:
                    break
            else:  # keep going
                obs = next_obs

        return self.memory.get_torch_data()
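Assuming a collector object wrapping the environment and agents (the variable name and the "Agent0" id below are hypothetical), episodic collection and access to the returned batch could look like:

batch = collector.collect_data(num_episodes=4, disable_tqdm=False)
obs0 = batch["observations"]["Agent0"]   # tensor([obs1, obs2, ...])
act0 = batch["actions"]["Agent0"]        # tensor([act1, act2, ...])
print(obs0.shape, act0.shape)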