Example #1
    def __init__(self,
                 config_env,
                 config_baseline,
                 dataset,
                 target_dim=7,
                 map_kwargs={},
                 reward_kwargs={},
                 loop_episodes=True,
                 scenario_kwargs={}):
        if scenario_kwargs['use_depth']:
            config_env.SIMULATOR.AGENT_0.SENSORS.append("DEPTH_SENSOR")
        super().__init__(config_env, config_baseline, dataset)
        self.target_dim = target_dim
        self.image_dim = 256

        self.use_map = map_kwargs['map_building_size'] > 0
        self.map_dim = 84 if self.use_map else None
        self.map_kwargs = map_kwargs
        self.reward_kwargs = reward_kwargs
        self.scenario_kwargs = scenario_kwargs
        self.last_map = None  # TODO unused

        self.observation_space = get_obs_space(self.image_dim, self.target_dim,
                                               self.map_dim,
                                               scenario_kwargs['use_depth'])

        self.omap = None
        if self.use_map:
            self.omap = OccupancyMap(
                map_kwargs=map_kwargs)  # this one is not used

        self.loop_episodes = loop_episodes
        self.n_episodes_completed = 0
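
A note on the defaults above: map_kwargs={} and scenario_kwargs={} are mutable default arguments, and the constructor indexes them directly (scenario_kwargs['use_depth'], map_kwargs['map_building_size']), so calling it without explicit kwargs raises KeyError. A small self-contained illustration (not from the repo) of why the shared-default pattern is fragile and the usual fix:

def fragile(kwargs={}):              # one dict object shared by every call
    kwargs.setdefault('use_depth', False)
    return kwargs

def safer(kwargs=None):              # fresh dict per call
    kwargs = {} if kwargs is None else dict(kwargs)
    kwargs.setdefault('use_depth', False)
    return kwargs

first = fragile()
first['use_depth'] = True
print(fragile())                     # {'use_depth': True} -- the earlier call leaked into the default
print(safer())                       # {'use_depth': False}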
Example #2
    def _reset(self):
        self.info = None
        self.obs = super().reset()
        if self.use_map:
            self.omap = OccupancyMap(initial_pg=self.obs['pointgoal'], map_kwargs=self.map_kwargs)
        self.obs = transform_observations(self.obs, target_dim=self.target_dim, omap=self.omap)
        if 'map' in self.obs:
            self.last_map = self.obs['map']
        return self.obs

    def act(self, observations):
        # tick
        self.t += 1

        # collect raw observations
        self.episode_rgbs.append(copy.deepcopy(observations['rgb']))
        self.episode_pgs.append(copy.deepcopy(observations['pointgoal']))

        # initialize or step occupancy map
        if self.map_kwargs['map_building_size'] > 0:
            if self.t == 1:
                self.omap = OccupancyMap(initial_pg=observations['pointgoal'], map_kwargs=self.map_kwargs)
            else:
                assert self.last_action is not None, 'This is not the first timestep, there must have been at least one action'
                self.omap.add_pointgoal(observations['pointgoal'])
                self.omap.step(self.last_action)

        # hard-coded STOP
        dist = observations['pointgoal'][0]
        if dist <= 0.2:
            return STOP_VALUE

        # preprocess and get observation
        observations = transform_observations(observations, target_dim=self.target_dim, omap=self.omap)
        observations = self.transform_pre_agg(observations)
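        # add a leading batch dimension of 1 (single process) to every sensor before
        # the post-aggregation transform and frame stacking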
        for k, v in observations.items():
            observations[k] = np.expand_dims(v, axis=0)
        observations = self.transform_post_agg(observations)
        self.current_obs.insert(observations)
        self.obs_stacked = {k: v.peek().cuda() for k, v in self.current_obs.peek().items()}

        # log a few of the agent's map inputs (every 4th step while 50 < t < 60)
        if self.t % 4 == 0 and 50 < self.t < 60:
            map_output = split_and_cat(self.obs_stacked['map']) * 0.5 + 0.5
            self.mlog.add_meter(f'diagnostics/map_{self.t}', tnt.meter.SingletonMeter(), ptype='image')
            self.mlog.update_meter(map_output, meters={f'diagnostics/map_{self.t}'}, phase='val')

        # act
        with torch.no_grad():
            value, action, act_log_prob, self.test_recurrent_hidden_states = self.actor_critic.act(
                self.obs_stacked,
                self.test_recurrent_hidden_states,
                self.not_done_masks,
            )
            action = action.item()
            self.not_done_masks = torch.ones(1, 1).cuda()  # mask says not done

        # log agent outputs
        assert self.action_space.contains(action), 'action from model does not fit our action space'
        self.last_action = action
        return action
Example #4
def transform_observations(observations,
                           target_dim=16,
                           omap: OccupancyMap = None):
    new_obs = observations
    new_obs["rgb_filled"] = observations["rgb"]
    new_obs["taskonomy"] = observations["rgb"]
    new_obs["target"] = np.moveaxis(
        np.tile(transform_target(observations["pointgoal"]),
                (target_dim, target_dim, 1)), -1, 0)
    if omap is not None:
        new_obs['map'] = omap.construct_occupancy_map()
        new_obs['global_pos'] = omap.get_current_global_pos()
    del new_obs['rgb']
    return new_obs
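
A minimal usage sketch for transform_observations with dummy inputs (shapes assumed; transform_target is presumed to return a short vector that gets tiled over a target_dim x target_dim grid):

import numpy as np

obs = {
    'rgb': np.zeros((256, 256, 3), dtype=np.uint8),       # raw simulator frame
    'pointgoal': np.array([3.2, 0.7], dtype=np.float32),  # (distance, angle) to the goal
}
out = transform_observations(obs, target_dim=7, omap=None)
# 'rgb' is removed; 'rgb_filled' and 'taskonomy' alias the original frame; 'target'
# is the transformed pointgoal tiled to (C, 7, 7) so a conv encoder can consume it
print(sorted(out.keys()))   # ['rgb_filled', 'target', 'taskonomy']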
Example #5
    def __init__(self,
                 config_env,
                 config_baseline,
                 dataset,
                 target_dim=7,
                 map_kwargs={},
                 reward_kwargs={}):
        super().__init__(config_env, config_baseline, dataset)
        self.target_dim = target_dim
        self.image_dim = 256

        self.use_map = map_kwargs['map_building_size'] > 0
        self.map_dim = 84 if self.use_map else None
        self.map_kwargs = map_kwargs
        self.reward_kwargs = reward_kwargs
        self.last_map = None  # TODO unused

        self.observation_space = get_obs_space(self.image_dim, self.target_dim,
                                               self.map_dim)

        self.omap = None
        if self.use_map:
            self.omap = OccupancyMap(
                map_kwargs=map_kwargs)  # this one is not used
Example #6
class MidlevelNavRLEnv(NavRLEnv):
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self,
                 config_env,
                 config_baseline,
                 dataset,
                 target_dim=7,
                 map_kwargs={},
                 reward_kwargs={}):
        super().__init__(config_env, config_baseline, dataset)
        self.target_dim = target_dim
        self.image_dim = 256

        self.use_map = map_kwargs['map_building_size'] > 0
        self.map_dim = 84 if self.use_map else None
        self.map_kwargs = map_kwargs
        self.reward_kwargs = reward_kwargs
        self.last_map = None  # TODO unused

        self.observation_space = get_obs_space(self.image_dim, self.target_dim,
                                               self.map_dim)

        self.omap = None
        if self.use_map:
            self.omap = OccupancyMap(
                map_kwargs=map_kwargs)  # this one is not used

    def get_reward(self, observations):
        reward = self.reward_kwargs['slack_reward']

        current_target_distance = self._distance_target()
        reward += self._previous_target_distance - current_target_distance
        self._previous_target_distance = current_target_distance

        if self.reward_kwargs['use_visit_penalty'] and len(
                self.omap.history) > 5:
            reward += self.reward_kwargs[
                'visit_penalty_coef'] * self.omap.compute_eps_ball_ratio(
                    self.reward_kwargs['penalty_eps'])

        if self._episode_success():
            reward += self.reward_kwargs['success_reward']

        return reward

    def reset(self):
        self.info = None
        self.obs = super().reset()
        if self.use_map:
            self.omap = OccupancyMap(initial_pg=self.obs['pointgoal'],
                                     map_kwargs=self.map_kwargs)
        self.obs = transform_observations(self.obs,
                                          target_dim=self.target_dim,
                                          omap=self.omap)
        if 'map' in self.obs:
            self.last_map = self.obs['map']
        return self.obs

    def step(self, action):
        self.obs, reward, done, self.info = super().step(action)
        if self.use_map:
            self.omap.add_pointgoal(self.obs['pointgoal'])
            self.omap.step(
                action
            )  # our forward model needs to see how the env changed due to the action (via the pg)
        self.obs = transform_observations(self.obs,
                                          target_dim=self.target_dim,
                                          omap=self.omap)
        if 'map' in self.obs:
            self.last_map = self.obs['map']
        return self.obs, reward, done, self.info

    def render(self, mode='human'):
        if mode == 'rgb_array':
            im = self.obs["rgb_filled"]

            # Get the birds eye view of the agent
            if self.info is None:
                top_down_map = np.zeros((256, 256, 3), dtype=np.uint8)
            else:
                top_down_map = draw_top_down_map(self.info,
                                                 self.obs["heading"],
                                                 im.shape[0])
                top_down_map = np.array(
                    Image.fromarray(top_down_map).resize((256, 256)))

            if 'map' in self.obs:
                occupancy_map = self.obs['map']
                h, w, _ = occupancy_map.shape
                occupancy_map[int(h // 2), int(w // 2),
                              2] = 255  # for debugging
                occupancy_map = np.array(
                    Image.fromarray(occupancy_map).resize((256, 256)))
                output_im = np.concatenate((im, top_down_map, occupancy_map),
                                           axis=1)
            else:
                output_im = np.concatenate((im, top_down_map), axis=1)

            # Pad to make dimensions even (not needed: the dimensions are already even)
            # npad = ((output_im.shape[0] % 2, 0), (output_im.shape[1] %2, 0), (0, 0))
            # output_im = np.pad(output_im, pad_width=npad, mode='constant', constant_values=0)
            return output_im
        else:
            super().render(mode=mode)
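
For intuition about get_reward in this example, a worked example of the dense shaping with assumed coefficients (the real values come from reward_kwargs in the training config):

reward_kwargs = {'slack_reward': -0.01, 'success_reward': 10.0,
                 'use_visit_penalty': False}
prev_dist, cur_dist = 4.0, 3.7                    # the agent moved 0.3 m closer to the goal
reward = reward_kwargs['slack_reward'] + (prev_dist - cur_dist)
print(round(reward, 2))                           # 0.29, plus success_reward on the successful final step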
Example #7
class MidlevelNavRLEnv(NavRLEnv):
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self,
                 config_env,
                 config_baseline,
                 dataset,
                 target_dim=7,
                 map_kwargs={},
                 reward_kwargs={},
                 loop_episodes=True,
                 scenario_kwargs={}):
        if scenario_kwargs['use_depth']:
            config_env.SIMULATOR.AGENT_0.SENSORS.append("DEPTH_SENSOR")
        super().__init__(config_env, config_baseline, dataset)
        self.target_dim = target_dim
        self.image_dim = 256

        self.use_map = map_kwargs['map_building_size'] > 0
        self.map_dim = 84 if self.use_map else None
        self.map_kwargs = map_kwargs
        self.reward_kwargs = reward_kwargs
        self.scenario_kwargs = scenario_kwargs
        self.last_map = None  # TODO unused

        self.observation_space = get_obs_space(self.image_dim, self.target_dim,
                                               self.map_dim,
                                               scenario_kwargs['use_depth'])

        self.omap = None
        if self.use_map:
            self.omap = OccupancyMap(
                map_kwargs=map_kwargs)  # this one is not used

        self.loop_episodes = loop_episodes
        self.n_episodes_completed = 0

    def get_reward(self, observations):
        reward = self.reward_kwargs['slack_reward']

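        # reward_kwargs['sparse'] disables the shaping below, leaving only slack_reward
        # plus success_reward on the successful step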
        if not self.reward_kwargs['sparse']:
            current_target_distance = self._distance_target()
            reward += (self._previous_target_distance - current_target_distance
                       ) * self.reward_kwargs['dist_coef']
            self._previous_target_distance = current_target_distance

            if self.reward_kwargs['use_visit_penalty'] and len(
                    self.omap.history) > 5:
                reward += self.reward_kwargs[
                    'visit_penalty_coef'] * self.omap.compute_eps_ball_ratio(
                        self.reward_kwargs['penalty_eps'])

        if self._episode_success():
            reward += self.reward_kwargs['success_reward']

        return reward

    def reset(self):
        self.obs = self._reset()
        return self.obs

    def _reset(self):
        self.info = None
        self.obs = super().reset()
        if self.use_map:
            self.omap = OccupancyMap(initial_pg=self.obs['pointgoal'],
                                     map_kwargs=self.map_kwargs)
        self.obs = transform_observations(self.obs,
                                          target_dim=self.target_dim,
                                          omap=self.omap)
        if 'map' in self.obs:
            self.last_map = self.obs['map']
        return self.obs

    def step(self, action):

        if self.n_episodes_completed >= len(
                self.episodes) and not self.loop_episodes:
            return self.obs, 0.0, False, self.info  # noop forever

        self.obs, reward, done, self.info = super().step(action)
        if self.use_map:
            self.omap.add_pointgoal(self.obs['pointgoal'])  # s_{t+1}
            self.omap.step(
                action
            )  # a_t: our forward model needs to see how the env changed due to the action (via the pg)
        self.obs = transform_observations(self.obs,
                                          target_dim=self.target_dim,
                                          omap=self.omap)
        if 'map' in self.obs:
            self.last_map = self.obs['map']

        if done:
            self.n_episodes_completed += 1

        return self.obs, reward, done, self.info

    def render(self, mode='human'):
        if mode == 'rgb_array':
            im = self.obs["rgb_filled"]
            to_concat = [im]

            if 'depth' in self.obs:
                depth_im = gray_to_rgb(self.obs['depth'] * 255).astype(
                    np.uint8)
                to_concat.append(depth_im)

            # Get the birds eye view of the agent
            if self.info is not None:
                top_down_map = draw_top_down_map(self.info,
                                                 self.obs["heading"],
                                                 im.shape[0])
                top_down_map = np.array(
                    Image.fromarray(top_down_map).resize((256, 256)))
            else:
                top_down_map = np.zeros((256, 256, 3), dtype=np.uint8)
            to_concat.append(top_down_map)

            if 'map' in self.obs:
                occupancy_map = np.copy(
                    self.obs['map']
                )  # copy so the debug pixel below does not overwrite self.obs['map'] in place
                h, w, _ = occupancy_map.shape
                occupancy_map[int(h // 2), int(w // 2),
                              2] = 255  # for debugging
                occupancy_map = np.array(
                    Image.fromarray(occupancy_map).resize((256, 256)))
                to_concat.append(occupancy_map)

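            # with RGB, depth, top-down map, and occupancy map all present this yields a
            # (256, 1024, 3) frame of four 256x256 panels side by side (assuming im is 256x256)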
            output_im = np.concatenate(to_concat, axis=1)
            return output_im
        else:
            super().render(mode=mode)
Example #8
IMAGE_DIR = os.path.join(EXPERT_DIR, 'videos')
TRAJ_DIR = os.path.join(EXPERT_DIR, DATA_SPLIT)

if not os.path.exists(IMAGE_DIR) and SAVE_VIDEO:
    os.makedirs(IMAGE_DIR)

if not os.path.exists(TRAJ_DIR):
    os.makedirs(TRAJ_DIR)

# main loop: collect data
for episode in tqdm(range(len(env.habitat_env.episodes))):
    observations = env.reset()

    images = []
    traj = []
    omap = OccupancyMap(initial_pg=observations['pointgoal'], map_kwargs=DEFAULT_MAP_KWARGS)
    while not env.habitat_env.episode_over:
        # postprocess and log (state, action) pairs
        observations['rgb_filled'] = observations['rgb']
        observations['target'] = np.moveaxis(np.tile(transform_target(observations['pointgoal']), (target_dim, target_dim, 1)), -1, 0)
        observations['map'] = omap.construct_occupancy_map()
        del observations['rgb']

        # agent step
        best_action = follower.get_next_action(env.habitat_env.current_episode.goals[0].position).value
        traj.append([observations, best_action])

        # env step
        observations, reward, done, info = env.step(best_action)
        omap.add_pointgoal(observations['pointgoal'])  # s_{t+1}
        omap.step(best_action)  # a_t
class HabitatAgent(Agent):
    def __init__(self, ckpt_path, config_data):
        # Load agent
        self.action_space = spaces.Discrete(3)
        if ckpt_path is not None:
            checkpoint_obj = torch.load(ckpt_path)
            start_epoch = checkpoint_obj["epoch"]
            print("Loaded learner (epoch {}) from {}".format(start_epoch, ckpt_path), flush=True)
            agent = checkpoint_obj["agent"]
        else:
            cfg = config_data['cfg']
            perception_model = eval(cfg['learner']['perception_network'])(
                cfg['learner']['num_stack'],
                **cfg['learner']['perception_network_kwargs'])
            base = NaivelyRecurrentACModule(
                perception_unit=perception_model,
                use_gru=cfg['learner']['recurrent_policy'],
                internal_state_size=cfg['learner']['internal_state_size'])
            actor_critic = PolicyWithBase(
                base, self.action_space,
                num_stack=cfg['learner']['num_stack'],
                takeover=None)
            if cfg['learner']['use_replay']:
                agent = PPOReplay(actor_critic,
                                  cfg['learner']['clip_param'],
                                  cfg['learner']['ppo_epoch'],
                                  cfg['learner']['num_mini_batch'],
                                  cfg['learner']['value_loss_coef'],
                                  cfg['learner']['entropy_coef'],
                                  cfg['learner']['on_policy_epoch'],
                                  cfg['learner']['off_policy_epoch'],
                                  lr=cfg['learner']['lr'],
                                  eps=cfg['learner']['eps'],
                                  max_grad_norm=cfg['learner']['max_grad_norm'])
            else:
                agent = PPO(actor_critic,
                            cfg['learner']['clip_param'],
                            cfg['learner']['ppo_epoch'],
                            cfg['learner']['num_mini_batch'],
                            cfg['learner']['value_loss_coef'],
                            cfg['learner']['entropy_coef'],
                            lr=cfg['learner']['lr'],
                            eps=cfg['learner']['eps'],
                            max_grad_norm=cfg['learner']['max_grad_norm'])
            weights_path = cfg['eval_kwargs']['weights_only_path']
            ckpt = torch.load(weights_path)
            agent.actor_critic.load_state_dict(ckpt['state_dict'])
            agent.optimizer = ckpt['optimizer']
        self.actor_critic = agent.actor_critic

        self.takeover_policy = None
        if config_data['cfg']['learner']['backout']['use_backout']:
            backout_type = config_data['cfg']['learner']['backout']['backout_type']
            if backout_type == 'hardcoded':
                self.takeover_policy = BackoutPolicy(
                    patience=config_data['cfg']['learner']['backout']['patience'],
                    num_processes=1,
                    unstuck_dist=config_data['cfg']['learner']['backout']['unstuck_dist'],
                    randomize_actions=config_data['cfg']['learner']['backout']['randomize_actions'],
                )
            elif backout_type == 'trained':
                backout_ckpt = config_data['cfg']['learner']['backout']['backout_ckpt_path']
                assert backout_ckpt is not None, 'need a checkpoint to use a trained backout'
                backout_checkpoint_obj = torch.load(backout_ckpt)
                backout_start_epoch = backout_checkpoint_obj["epoch"]
                print("Loaded takeover policy at (epoch {}) from {}".format(backout_start_epoch, backout_ckpt), flush=True)
                backout_policy = backout_checkpoint_obj["agent"].actor_critic

                self.takeover_policy = TrainedBackoutPolicy(
                    patience=config_data['cfg']['learner']['backout']['patience'],
                    num_processes=1,
                    policy=backout_policy,
                    unstuck_dist=config_data['cfg']['learner']['backout']['unstuck_dist'],
                    num_takeover_steps=config_data['cfg']['learner']['backout']['num_takeover_steps'],
                )
            else:
                assert False, f'do not recognize backout type {backout_type}'
        self.actor_critic.takeover = self.takeover_policy

        self.validator = None
        if config_data['cfg']['learner']['validator']['use_validator']:
            validator_type = config_data['cfg']['learner']['validator']['validator_type']
            if validator_type == 'jerk':
                self.validator = JerkAvoidanceValidator()
            else:
                assert False, f'do not recognize validator {validator_type}'
        self.actor_critic.action_validator = self.validator

        # Set up spaces
        self.target_dim = config_data['cfg']['env']['env_specific_kwargs']['target_dim']

        map_dim = None
        self.omap = None
        if config_data['cfg']['env']['use_map']:
            self.map_kwargs = config_data['cfg']['env']['habitat_map_kwargs']
            map_dim = 84
            assert self.map_kwargs['map_building_size'] > 0, 'If we are using map in habitat, please set building size to be positive!'

        obs_space = get_obs_space(image_dim=256, target_dim=self.target_dim, map_dim=map_dim)

        preprocessing_fn_pre_agg = eval(config_data['cfg']['env']['transform_fn_pre_aggregation'])
        self.transform_pre_agg, obs_space = preprocessing_fn_pre_agg(obs_space)

        preprocessing_fn_post_agg = eval(config_data['cfg']['env']['transform_fn_post_aggregation'])
        self.transform_post_agg, obs_space = preprocessing_fn_post_agg(obs_space)

        self.current_obs = StackedSensorDictStorage(1,
                                               config_data['cfg']['learner']['num_stack'],
                                               {k: v.shape for k, v in obs_space.spaces.items()
                                                if k in config_data['cfg']['env']['sensors']})
        print(f'Stacked obs shape {self.current_obs.obs_shape}')

        self.current_obs = self.current_obs.cuda()
        self.actor_critic.cuda()

        self.hidden_size = config_data['cfg']['learner']['internal_state_size']
        self.test_recurrent_hidden_states = None
        self.not_done_masks = None

        self.episode_rgbs = []
        self.episode_pgs = []
        self.episode_entropy = []
        self.episode_num = 0
        self.t = 0
        self.episode_lengths = []
        self.episode_values = []
        self.last_action = None

        # Set up logging
        if config_data['cfg']['saving']['logging_type'] == 'visdom':
            self.mlog = tnt.logger.VisdomMeterLogger(
                title=config_data['uuid'], env=config_data['uuid'], server=config_data['cfg']['saving']['visdom_server'],
                port=config_data['cfg']['saving']['visdom_port'],
                log_to_filename=config_data['cfg']['saving']['visdom_log_file']
            )
            self.use_visdom = True
        elif config_data['cfg']['saving']['logging_type'] == 'tensorboard':
            self.mlog = tnt.logger.TensorboardMeterLogger(
                env=config_data['uuid'],
                log_dir=config_data['cfg']['saving']['log_dir'],
                plotstylecombined=True
            )
            self.use_visdom = False
        else:
            assert False, 'no proper logger!'

        self.log_dir = config_data['cfg']['saving']['log_dir']
        self.save_eval_videos = config_data['cfg']['saving']['save_eval_videos']
        self.mlog.add_meter('config', tnt.meter.SingletonMeter(), ptype='text')
        self.mlog.update_meter(cfg_to_md(config_data['cfg'], config_data['uuid']), meters={'config'}, phase='val')

    def reset(self):
        # reset hidden state and set done
        self.test_recurrent_hidden_states = torch.zeros(
            1, self.hidden_size
        ).cuda()
        self.not_done_masks = torch.zeros(1, 1).cuda()

        # reset observation storage (and verify)
        z = torch.zeros(1, 2).cuda()
        mask_out_done = { name: z for name in self.current_obs.sensor_names }
        if 'global_pos' in self.current_obs.sensor_names:
            mask_out_done['global_pos'] = torch.zeros(1,1).cuda()
        self.current_obs.clear_done(mask_out_done)
        for value in self.current_obs.peek().values():
            assert torch.sum(value.peek()).item() < 1e-6, 'did not clear the current_obs properly'

        # log everything
        if len(self.episode_pgs) != 0:
            # log video (and save to log_dir)
            if self.save_eval_videos:
                images_to_video(images=self.episode_rgbs, output_dir=self.log_dir, video_name=f'test_{self.episode_num}')
                self.mlog.add_meter(f'diagnostics/rollout_{self.episode_num}', tnt.meter.SingletonMeter(), ptype='video')
                if self.use_visdom:
                    vid_path = os.path.join(self.log_dir, f'test_{self.episode_num}.mp4')
                    self.mlog.update_meter(vid_path, meters={f'diagnostics/rollout_{self.episode_num}'}, phase='val')
                else:
                    print('video support for TensorBoard is weak; not recommended')
                    rgb_tensor = torch.Tensor(self.episode_rgbs).unsqueeze(dim=0)
                    self.mlog.update_meter(rgb_tensor, meters={f'diagnostics/rollout_{self.episode_num}'}, phase='val')

            # reset log
            self.mlog.reset_meter(self.episode_num, mode='val')

            # reset episode logs
            self.episode_rgbs = []
            self.episode_pgs = []
            self.episode_values = []
            self.episode_entropy = []
            self.episode_lengths.append(self.t)
            self.episode_num += 1
            self.t = 0
            self.last_action = None

    def act(self, observations):
        # tick
        self.t += 1

        # collect raw observations
        self.episode_rgbs.append(copy.deepcopy(observations['rgb']))
        self.episode_pgs.append(copy.deepcopy(observations['pointgoal']))

        # initialize or step occupancy map
        if self.map_kwargs['map_building_size'] > 0:
            if self.t == 1:
                self.omap = OccupancyMap(initial_pg=observations['pointgoal'], map_kwargs=self.map_kwargs)
            else:
                assert self.last_action is not None, 'This is not the first timestep, there must have been at least one action'
                self.omap.add_pointgoal(observations['pointgoal'])
                self.omap.step(self.last_action)

        # hard-coded STOP
        dist = observations['pointgoal'][0]
        if dist <= 0.2:
            return STOP_VALUE

        # preprocess and get observation
        observations = transform_observations(observations, target_dim=self.target_dim, omap=self.omap)
        observations = self.transform_pre_agg(observations)
        for k, v in observations.items():
            observations[k] = np.expand_dims(v, axis=0)
        observations = self.transform_post_agg(observations)
        self.current_obs.insert(observations)
        self.obs_stacked = {k: v.peek().cuda() for k, v in self.current_obs.peek().items()}

        # log a few of the agent's map inputs (every 4th step while 50 < t < 60)
        if self.t % 4 == 0 and 50 < self.t < 60:
            map_output = split_and_cat(self.obs_stacked['map']) * 0.5 + 0.5
            self.mlog.add_meter(f'diagnostics/map_{self.t}', tnt.meter.SingletonMeter(), ptype='image')
            self.mlog.update_meter(map_output, meters={f'diagnostics/map_{self.t}'}, phase='val')

        # act
        with torch.no_grad():
            value, action, act_log_prob, self.test_recurrent_hidden_states = self.actor_critic.act(
                self.obs_stacked,
                self.test_recurrent_hidden_states,
                self.not_done_masks,
            )
            action = action.item()
            self.not_done_masks = torch.ones(1, 1).cuda()  # mask says not done

        # log agent outputs
        assert self.action_space.contains(action), 'action from model does not fit our action space'
        self.last_action = action
        return action

    def finish_benchmark(self, metrics):
        self.mlog.add_meter('diagnostics/length_hist', tnt.meter.ValueSummaryMeter(), ptype='histogram')
        self.mlog.update_meter(self.episode_lengths, meters={'diagnostics/length_hist'}, phase='val')

        for k, v in metrics.items():
            print(k, v)
            self.mlog.add_meter(f'metrics/{k}',  tnt.meter.ValueSummaryMeter())
            self.mlog.update_meter(v, meters={f'metrics/{k}'}, phase='val')

        self.mlog.reset_meter(self.episode_num + 1, mode='val')
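
For reference, a minimal hypothetical evaluation loop showing how HabitatAgent is exercised; in practice a habitat benchmark runner drives it, and env, ckpt_path, and config_data are assumed to exist in the surrounding script:

agent = HabitatAgent(ckpt_path=ckpt_path, config_data=config_data)
for _ in range(len(env.habitat_env.episodes)):
    agent.reset()
    observations = env.reset()
    while not env.habitat_env.episode_over:
        action = agent.act(observations)          # may return STOP_VALUE once the goal is within 0.2 m
        observations, reward, done, info = env.step(action)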