Example #1
def run(self):
    for i, trajectory in enumerate(self.trajectory_generator):
        self.saved = False
        self.is_counterfactual = False
        self.navigator = TrajectoryNavigator(trajectory)
        self.window = Window(f'Trajectory {i}')
        self.window.reg_key_handler(self.key_handler)
        self.reset()
        self.window.show(block=True)
        if not self.saved:
            raise Exception('Continued without saving the trajectory!')
Example #2
def __init__(self, display=False, agent_view=5, map_size=20, roads=1, max_step=100):
    super().__init__()
    self.display = display
    self.map = Simple2Dv2(map_size, map_size, agent_view=agent_view, roads=roads, max_step=max_step)
    self.window = None
    if self.display:
        self.window = Window('GYM_MiniGrid')
        self.window.reg_key_handler(self.key_handler)
        self.window.show(True)
    self.detect_rate = []
    self.rewards = []
    self.step_count = []
    self.old = None
    self.new = None
    self._rewards = []
Example #3
def __init__(self,
             env,
             detector=None,
             render_mode=None,
             learners=[],
             **kwargs):
    self.env = env
    self.core_actions = self.env.actions
    self.detector = detector
    self._learners = learners
    self.executors = {
        "gotoobj1": (["agent", "graspable"], self.go_to_obj),
        "gotoobj2": (["agent", "notgraspable"], self.go_to_obj),
        "gotoobj3": (["agent", "graspable", "thing"], self.go_to_obj),
        "gotoobj4": (["agent", "notgraspable", "thing"], self.go_to_obj),
        "gotodoor1": (["agent", "door"], self.go_to_obj),
        "gotodoor2": (["agent", "door", "thing"], self.go_to_obj),
        "stepinto": (["agent", "goal"], self.step_into),
        "enterroomof": (["agent", "door",
                         "physobject"], self.enter_room_of),
        "usekey": (["agent", "key", "door"], self.usekey),
        "opendoor": (["agent", "door"], self.usekey),
        "pickup": (["agent", "graspable"], self.pickup),
        "putdown": (["agent", "graspable"], self.putdown),
        "core_drop": ([], self.core_drop),
        "core_pickup": ([], self.core_pickup),
        "core_forward": ([], self.core_forward),
        "core_right": ([], self.core_right),
        "core_left": ([], self.core_left),
        "core_toggle": ([], self.core_toggle),
        "core_done": ([], self.core_done)
    }
    self._window = None
    self._record_path = None
    self.set_record_path(0)
    if render_mode and render_mode.lower() == "human":
        self._window = Window("minigrid_executor")
        self._window.show(block=False)
    self.accumulator = None
Example #4
def main():
    """
    Main method for running a game manually

    """

    # Parse arguments
    args = parser.parse_args()

    # Make environments
    env = gym.make(args.env_name)
    args.env = env

    #Build window and register key_handler
    window = Window('Cameleon - ' + args.env_name)
    args.window = window
    window.fig.canvas.mpl_connect("key_press_event",
                                  lambda event: args.key_handler(event, args))

    # Reset the game
    reset(args)

    # Blocking event loop
    window.show(block=True)
Example #5
                    default=32)
parser.add_argument('--agent_view',
                    default=False,
                    help="draw the agent sees (partially observable view)",
                    action='store_true')

if __name__ == "__main__":
    args = parser.parse_args()

    env = gym.make(args.env)

    if args.agent_view:
        env = RGBImgPartialObsWrapper(env)
        env = ImgObsWrapper(env)

    window = Window('gym_minigrid - ' + args.env)

    agent = DeepSARSAgent()

    global_step = 0
    scores, episodes = [], []

    for e in range(30000):
        done = False
        score = 0
        state = env.reset()
        state = np.reshape(state['image'], [1, 147])
        reset()

        while not done:
            # fresh env
Example #6
class Interface:
    """
    User interface for generating counterfactual states and actions.
    For each trajectory:
        The user can use the `a` and `d` keys to go backward and forward in time to a specific state they wish to
        generate a counterfactual explanation from. Once a specific state is found, the user presses `w` to launch the
        counterfactual mode.
        In the counterfactual mode, the user can control the agent using the keyboard. Once satisfied with their
        progress, the user can give up control to the bot to finish the episode by pressing `w`.

        The keys for controlling the agent are summarized here:

        escape: quit the program

        view mode:
            d: next time step
            a: previous time step
            w: select time step and go to counterfactual mode

        counterfactual mode:
            left:          turn counter-clockwise
            right:         turn clockwise
            up:            go forward
            space:         toggle
            pageup or x:   pickup
            pagedown or z: drop
            enter or q:    done (should not use)
            a:             undo action
            w:             roll out to the end of the episode and move to next episode
    """
    def __init__(self,
                 original_dataset: TrajectoryDataset,
                 counterfactual_dataset: TrajectoryDataset,
                 policy_factory=None):
        self.policy_factory = policy_factory
        self.dataset = original_dataset
        self.counterfactual_dataset = counterfactual_dataset
        self.trajectory_generator = self.dataset.trajectory_generator()
        self.navigator: TrajectoryNavigator = None
        self.window = None
        self.is_counterfactual = False
        self.run()

    def run(self):
        for i, trajectory in enumerate(self.trajectory_generator):
            self.saved = False
            self.is_counterfactual = False
            self.navigator = TrajectoryNavigator(trajectory)
            self.window = Window(f'Trajectory {i}')
            self.window.reg_key_handler(self.key_handler)
            self.reset()
            self.window.show(block=True)
            if not self.saved:
                raise Exception('Continued without saving the trajectory!')

    def redraw(self):
        step: TrajectoryStep = self.navigator.step()
        # if not self.agent_view:
        env = step.state
        img = env.render('rgb_array', tile_size=32)
        # else:
        # img = step.observation['image']
        # TODO later: figure out when to use the observation instead.

        self.window.show_img(img)

    def step(self, action=None):
        if action is None:
            self.navigator.forward()
        else:
            assert isinstance(self.navigator, CounterfactualNavigator)
            self.navigator.forward(action)
        self.redraw()

    def backward(self):
        self.navigator.backward()
        self.redraw()

    def reset(self):
        env = self.navigator.step().state

        if hasattr(env, 'mission'):
            print('Mission: %s' % env.mission)
            self.window.set_caption(env.mission)

        self.redraw()

    def select(self):
        new_navigator = CounterfactualNavigator(
            self.navigator.episode,
            self.navigator.index,
            self.navigator.step(),
            policy_factory=self.policy_factory)
        self.navigator = new_navigator
        self.is_counterfactual = True
        print(
            f'Starting counterfactual trajectory from {self.navigator.index}')
        self.redraw()

    def save_trajectory(self):
        assert isinstance(self.navigator, CounterfactualNavigator)
        self.navigator.store(self.counterfactual_dataset)
        self.saved = True

    def key_handler(self, event):
        print('pressed', event.key)

        if event.key == 'escape':
            self.window.close()
            exit()
            return

        # if event.key == 'backspace':
        #     self.reset()
        #     return
        if self.is_counterfactual:
            if event.key == 'left':
                self.step('left')
                return
            if event.key == 'right':
                self.step('right')
                return
            if event.key == 'up':
                self.step('forward')
                return

            # Spacebar
            if event.key == ' ':
                self.step('toggle')
                return
            if event.key == 'pageup' or event.key == 'x':
                self.step('pickup')
                return
            if event.key == 'pagedown' or event.key == 'z':
                self.step('drop')
                return

            if event.key == 'enter' or event.key == 'q':
                self.step('done')
                return

            if event.key == 'w':
                if self.policy_factory is not None:
                    self.navigator.rollout()
                self.save_trajectory()
                self.window.close()

            if event.key == 'a':
                self.backward()
                return

        if not self.is_counterfactual:
            if event.key == 'd':
                self.step()
                return

            if event.key == 'a':
                self.backward()
                return

            if event.key == 'w':
                self.select()
                return
Example #7
class MiniGridExecutor(Executor):
    def _remove_learner(self, learner):
        self._learners.remove(learner)

    def __init__(self,
                 env,
                 detector=None,
                 render_mode=None,
                 learners=[],
                 **kwargs):
        self.env = env
        self.core_actions = self.env.actions
        self.detector = detector
        self._learners = learners
        self.executors = {
            "gotoobj1": (["agent", "graspable"], self.go_to_obj),
            "gotoobj2": (["agent", "notgraspable"], self.go_to_obj),
            "gotoobj3": (["agent", "graspable", "thing"], self.go_to_obj),
            "gotoobj4": (["agent", "notgraspable", "thing"], self.go_to_obj),
            "gotodoor1": (["agent", "door"], self.go_to_obj),
            "gotodoor2": (["agent", "door", "thing"], self.go_to_obj),
            "stepinto": (["agent", "goal"], self.step_into),
            "enterroomof": (["agent", "door",
                             "physobject"], self.enter_room_of),
            "usekey": (["agent", "key", "door"], self.usekey),
            "opendoor": (["agent", "door"], self.usekey),
            "pickup": (["agent", "graspable"], self.pickup),
            "putdown": (["agent", "graspable"], self.putdown),
            "core_drop": ([], self.core_drop),
            "core_pickup": ([], self.core_pickup),
            "core_forward": ([], self.core_forward),
            "core_right": ([], self.core_right),
            "core_left": ([], self.core_left),
            "core_toggle": ([], self.core_toggle),
            "core_done": ([], self.core_done)
        }
        self._window = None
        self._record_path = None
        self.set_record_path(0)
        if render_mode and render_mode.lower() == "human":
            self._window = Window("minigrid_executor")
            self._window.show(block=False)
        self.accumulator = None

    def _attach_learner(self, learner):
        self._learners.append(learner)

    def _clear_learners(self):
        self._learners = []

    def executor_map(self):
        return {
            name: (self, self.executors[name][0], self.executors[name][1])
            for name in self.executors
        }

    def _all_actions(self):
        action_list = []
        type_to_objects = self.detector.get_objects(inherit=True)
        for executor_name in self.executors:
            possible_args = tuple(
                [type_to_objects[t] for t in self.executors[executor_name][0]])
            for items in product(*possible_args):
                if len(items) == 0:
                    action_list.append("(" + executor_name + ")")
                else:
                    action_list.append("(" + executor_name + " " +
                                       " ".join(items) + ")")
        return action_list

    def _act(self, operator_name):
        self.accumulator = None
        executor_name, items = operator_parts(operator_name)
        param_types, executor = self.executors[executor_name]
        if len(items) == len(param_types):
            if items:
                # Doesn't currently ensure actions are typesafe. Could if we wanted it to.
                # param operator call  E.g. "(stack block cup)"
                obs, reward, done, info = executor(items)
            else:
                # primitive action call E.g., "right"
                obs, reward, done, info = executor()
        else:
            raise Exception("Wrong number of arguments for action " +
                            executor_name)
        return obs, reward, done, info

    ####################################
    ## OPERATORS
    ######################

    def _check_agent_nextto(self, item, obs):
        return self.detector.check_nextto("agent", item, obs)

    def _check_agent_facing(self, item, obs):
        return self.detector.check_facing("agent", item, obs)

    def _check_agent_obstructed(self, obs):
        return self.detector.check_formula(["(obstructed agent)"], obs)

    def get_line_between(self, x1, y1, x2, y2):
        # Assumes a straight line, so only one of x2-x1, y2-y1 is nonzero
        if y1 == y2:
            return [(i, y1) for i in range(min(x1, x2) + 1, max(x1, x2))]
        else:
            return [(x1, i) for i in range(min(y1, y2) + 1, max(y1, y2))]

    def enter_room_of(self, items):
        self.accumulator = RewardAccumulator(self.env)
        try:
            if self.detector.check_formula(
                ["(inroom {} {})".format(items[0], items[2])],
                    self.env.last_observation()):
                self.execute_core_action(self.env.actions.left,
                                         self.accumulator)
                self.execute_core_action(self.env.actions.left,
                                         self.accumulator)
            else:
                self.execute_core_action(self.env.actions.forward,
                                         self.accumulator)
                self.execute_core_action(self.env.actions.forward,
                                         self.accumulator)
        finally:
            return self.accumulator.combine_steps()

    def go_to_obj(self, items):
        self.accumulator = RewardAccumulator(self.env, action_timeout=100)
        agent = items[0]
        obj = items[1]
        obs = self.env.last_observation()
        agent_cur_direction = obs['objects']['agent']['encoding'][2]
        agent_x = obs['objects'][agent]['x']
        agent_y = obs['objects'][agent]['y']
        object_x = obs['objects'][obj]['x']
        object_y = obs['objects'][obj]['y']
        goal = (object_x, object_y)
        initial_state = (agent_x, agent_y, agent_cur_direction)
        image = self.env.last_observation()['image']
        path = self.a_star(image, initial_state, goal, self.manhattan_distance)
        try:
            for action in path[:-1]:
                self.execute_core_action(action, self.accumulator)
        finally:
            return self.accumulator.combine_steps()

    def manhattan_distance(self, agent_orientation, goal_position):
        return abs(agent_orientation[0] -
                   goal_position[0]) + abs(agent_orientation[1] -
                                           goal_position[1])

    def get_path(self, parent, current):
        plan = []
        while current in parent.keys():
            current, action = parent[current]
            plan.insert(0, action)
        return plan

    def a_star(self, image, initial_state, goal, heuristic):
        open_set = []
        parent = {}
        g_score = defaultdict(lambda: math.inf)
        g_score[initial_state] = 0.
        f_score = defaultdict(lambda: math.inf)
        f_score[initial_state] = heuristic(initial_state, goal)
        heapq.heappush(open_set, (f_score[initial_state], initial_state))

        while len(open_set) > 0:
            f, current = heapq.heappop(open_set)
            if (current[0], current[1]) == goal:
                return self.get_path(parent, current)
            neighbors = [(current[0], current[1], (current[2] - 1) % 4),
                         (current[0], current[1], (current[2] + 1) % 4)]
            fwd_x = current[0] + DIR_TO_VEC[current[2]][0]
            fwd_y = current[1] + DIR_TO_VEC[current[2]][1]
            if image[fwd_x][fwd_y][0] in [1, 3, 8] or (fwd_x, fwd_y) == goal:
                neighbors.append((fwd_x, fwd_y, current[2]))
            for action, neighbor in enumerate(neighbors):
                tentative_g_score = g_score[current] + 1
                if tentative_g_score < g_score[neighbor]:
                    parent[neighbor] = (current, action)
                    g_score[neighbor] = tentative_g_score
                    f_score[neighbor] = g_score[neighbor] + heuristic(
                        neighbor, goal)
                    if neighbor not in [x[1] for x in open_set]:
                        heapq.heappush(open_set, (f_score[neighbor], neighbor))
        return []

    def set_record_path(self, episode_number):
        self._record_path = "results/videos/{}/episode_{}".format(
            self.env.spec.id, episode_number)
        if SAVE_VIDEO:
            import os
            os.makedirs(self._record_path, exist_ok=True)

    def pickup(self, items):
        return self.core_pickup()

    def putdown(self, items):
        return self.core_drop()

    def usekey(self, items):
        return self.core_toggle()

    def step_into(self, items):
        return self.core_forward()

    def _execute_core_action(self, action, accumulator=None):
        prev_obs = self.env.last_observation()
        obs, reward, done, info = self.env.step(action)
        for learner in self._learners:
            learner.train(prev_obs, action, reward, obs)
        if accumulator:
            accumulator.accumulate(action, obs, reward, done, info)
        if self._window:
            img = self.env.render('rgb_array', tile_size=32, highlight=False)
            self._window.show_img(img)
            if SAVE_VIDEO:
                self._window.fig.savefig(self._record_path +
                                         "/{}.png".format(self.env.step_count))
        return obs, reward, done, action

    def core_drop(self):
        return self._execute_core_action(self.env.actions.drop)

    def core_pickup(self):
        return self._execute_core_action(self.env.actions.pickup)

    def core_forward(self):
        return self._execute_core_action(self.env.actions.forward)

    def core_left(self):
        return self._execute_core_action(self.env.actions.left)

    def core_right(self):
        return self._execute_core_action(self.env.actions.right)

    def core_toggle(self):
        return self._execute_core_action(self.env.actions.toggle)

    def core_done(self):
        return self._execute_core_action(self.env.actions.done)
Example #8
    help="draw the agent sees (partially observable view)",
    action='store_true'
)

args = parser.parse_args()

env = gym.make(args.env)

from pathlib import Path

RESOURCES_DIR = (Path(__file__).parent).resolve()

# env = gym_minigrid.envs.minimap.MinimapForSparky(
#         raw_map_path=Path(RESOURCES_DIR, 'gym_minigrid/envs/resources/map_set_'+str(args.map_set_number)+'.npy')
#     )
env = HumanFOVWrapper(env, agent_pos=(23, 14))
# env = InvertColorsWrapper(env)

# env = gym.wrappers.Monitor(env, "recording")
# env = gym.wrappers.Monitor(env, "./vid", video_callable=lambda episode_id: True,force=True)
if args.agent_view:
    # env = RGBImgPartialObsWrapper(env)
    env = ImgObsWrapper(env)

window = Window('gym_minigrid - ' + args.env + ' - Map set ' + str(args.map_set_number))
window.reg_key_handler(key_handler)

reset()

# Blocking event loop
window.show(block=True)
                    default=32)
parser.add_argument('--agent_view',
                    default=False,
                    help="draw the agent sees (partially observable view)",
                    action='store_true')
parser.add_argument('--agent_num',
                    type=int,
                    help="number of agents",
                    default=2)

args = parser.parse_args()

env = gym.make(args.env, agent_pos=[(1, 1), (17, 1)])

if args.agent_view:
    env = RGBImgPartialObsWrapper(env)
    env = ImgObsWrapper(env)

window = Window('gym_minigrid - ' + args.env)
window_explored = Window('explored_area')
window_obstacle = Window('obstacle_area')
#window.reg_key_handler(key_handler)

reset()

while (1):
    actions = env.get_short_term_action([[1, 17], [17, 17]])
    step(actions)
# Blocking event loop
# window.show(block=True)
    # Train
    update_start_time = time.time()
    images, label, seq_lens = mgmt.collect_episode()
    loss, correct = mgmt.update_parameters(images, label, seq_lens)

    # Log
    losses.append(loss)
    accuracy = torch.cat((accuracy, correct), 0)
    update += 1

    # Print logs
    if update % args.log_interval == 0:
        if args.visualize:
            # Visualize last frame of last sample
            from gym_minigrid.window import Window
            window = Window('gym_minigrid - ' + args.env)
            images = images.transpose(1,2)
            images = images.transpose(2,3)
            print(images[-1].shape)
            print(label)
            window.show_img(images[-1])
            input()
            window.close()

        duration = int(time.time() - start_time)

        header = ["Update", "Time", "Loss", "Accuracy"]
        acc = torch.mean(accuracy)
        data = [update, duration, sum(losses) / len(losses), acc]
        losses = []
        over = (acc >= 0.9999)
    parser.add_argument("--tile_size",
                        type=int,
                        help="size at which to render tiles",
                        default=32)
    parser.add_argument('--agent_view',
                        default=False,
                        help="draw the agent sees (partially observable view)",
                        action='store_true')

    args = parser.parse_args()
    env = gym.make(args.env)
    if args.agent_view:
        env = RGBImgPartialObsWrapper(env)
        env = ImgObsWrapper(env)

    window = Window('gym_minigrid - ' + args.env)
    reset()

    settings = termios.tcgetattr(sys.stdin)
    pub = rospy.Publisher('initial_pose', PoseStamped, queue_size=1)
    try:
        while (1):
            key = getKey()
            if key == 'w':
                step(env.actions.forward)

            elif key == 'a':
                step(env.actions.left)

            elif key == 'd':
                step(env.actions.right)
    # Spacebar
    if event.key == ' ':
        step(env.actions.toggle)
        return
    if event.key == 'pageup':
        step(env.actions.pickup)
        return
    if event.key == 'pagedown':
        step(env.actions.drop)
        return

    if event.key == 'enter':
        step(env.actions.done)
        return



env = gym.make('MiniGrid-Empty-5x5-v0')

env = RGBImgPartialObsWrapper(env)
env = ImgObsWrapper(env)

window = Window('gym_minigrid')
window.reg_key_handler(random_solve)

reset()

# Blocking event loop
window.show(block=True)
Example #13
class SimpleEnv(object):
    def __init__(self, display=False, agent_view=5, map_size=20, roads=1, max_step=100):
        super().__init__()
        self.display = display
        self.map = Simple2Dv2(map_size, map_size, agent_view=agent_view, roads=roads, max_step=max_step)
        self.window = None
        if self.display:
            self.window = Window('GYM_MiniGrid')
            self.window.reg_key_handler(self.key_handler)
            self.window.show(True)
        self.detect_rate = []
        self.rewards = []
        self.step_count = []
        self.old = None
        self.new = None
        self._rewards = []

    def short_term_reward(self):
        # (- manhattan distance / 100) + ( - stay time / 100)
        return self.new["reward"] / 100 - self.map.check_history() / 100

    def long_term_reward(self):
        _extrinsic_reward = self.new["l_reward"]
        _extrinsic_reward = sum(_extrinsic_reward) / len(_extrinsic_reward)
        return _extrinsic_reward

    def step(self, action):
        # Turn left, turn right, move forward
        # forward = 0
        # left = 1
        # right = 2
        self.old = self.map.state()
        self.new, done = self.map.step(action)
        reward = self.short_term_reward()
        if self.display is True:
            self.redraw()
        if done != 0:
            self.detect_rate.append(self.new["l_reward"])
            self.step_count.append(self.map.step_count)
            reward += self.long_term_reward()
            self._rewards.append(reward)
            self.rewards.append(np.mean(self._rewards))
        else:
            self._rewards.append(reward)

        return self.old, self.new, reward, done

    def key_handler(self, event):
        print('pressed', event.key)
        if event.key == 'left':
            self.step(0)
            return
        if event.key == 'right':
            self.step(1)
            return
        if event.key == 'up':
            self.step(2)
            return

    def redraw(self):
        if self.window is not None:
            self.map.render('human')

    def reset_env(self):
        """
        reset environment to the start point
        :return:
        """
        self.map.reset()
        self._rewards = []
        if self.display:
            self.redraw()
Example #14
0 reward if the agent runs into lava, 1 - 0.9 * (self.step_count / self.max_steps) if the end goal is reached.
max_steps = 144
min_steps = 7 (if a good env with a gap on top) / 8
--> max_reward = 0.95625 / 0.94375

MiniGrid-SimpleCrossingS9N1-v0:
1 - 0.9 * (self.step_count / self.max_steps) if the end goal is reached, 0 if not.
max_steps = 324
min_steps = 14
--> max_reward = 0.961
"""
envs = ['CartPole-v1', 'Acrobot-v1', 'MountainCar-v0', 'LunarLander-v2']

#env = gym.make(envs[0])
#env = (gym.make("Breakout-MinAtar-v0"))
#env = MinAtarObsWrapper(gym.make("Space_invaders-MinAtar-v0"))

env = FlatImgObsWrapper(gym.make('MiniGrid-Dynamic-Obstacles-8x8-v0'))
#env = FlatImgObsWrapper(gym.make('MiniGrid-LavaGapS7-v0'))
obs = env.reset()
print(obs.shape)
print(env.action_space.n)
print(obs.shape)
window = Window(title="MiniGrid")
for _ in range(1):
    img = env.render('rgb_array')
    window.show_img(img)
    obs, _, _, _ = env.step(env.action_space.sample()) # take a random action
    time.sleep(1000)
    print(obs)
env.close()
Example #15
def init_env(env_name):
    env = gym.make(env_name)
    env.max_steps = 256
    window = Window('gym_minigrid - ' + env_name)
    return env, window
Example #16
                    help="gym environment to load",
                    default='MiniGrid-MultiRoom-N6-v0')
parser.add_argument("--seed",
                    type=int,
                    help="random seed to generate the environment with",
                    default=-1)
parser.add_argument("--tile_size",
                    type=int,
                    help="size at which to render tiles",
                    default=32)
parser.add_argument('--agent_view',
                    default=False,
                    help="draw the agent sees (partially observable view)",
                    action='store_true')

args = parser.parse_args()

env = gym.make(args.env)

if args.agent_view:
    env = RGBImgPartialObsWrapper(env)
    env = ImgObsWrapper(env)

window = Window('gym_minigrid - ' + args.env)
window.reg_key_handler(key_handler)

reset()

# Blocking event loop
window.show(block=True)
if 'state_embedding_model_state_dict' in checkpoint:
    embedder_model.load_state_dict(
        checkpoint['state_embedding_model_state_dict'])

print(env)
print(env.unwrapped.grid)
env = Environment(env, fix_seed=args.fix_seed, env_seed=args.env_seed)
env_output = env.initial()
print(env.gym_env)

agent_state = model.initial_state(batch_size=1)
state_embedding = embedder_model(env_output['frame'])

if not args.stop_visu and is_minigrid:
    from gym_minigrid.window import Window
    w = Window(checkpoint['flags']['model'])
    arr = env.gym_env.render('rgb_array')
    #print("Arr", arr)
    w.show_img(arr)

while True:
    model_output, agent_state = model(env_output, agent_state)

    # action = model_output["action"]
    logits = model_output["policy_logits"]
    #print(logits)
    m = Categorical(logits=logits)
    action = m.sample()

    # action = torch.randint(low=0, high=env.gym_env.action_space.n, size=(1,))
    # action = torch.tensor([0])