Пример #1
0
def get_norm_dtw(action_list, env_output_list, environment):
  """Computes normalized Dynamic Time Warping (nDTW) for one trajectory.

  Reference: Magalhaes et al., "Effective and General Evaluation for
  Instruction Conditioned Navigation using Dynamic Time Warping" (2019),
  https://arxiv.org/abs/1907.05446

  The observed path is the PANO_ID of every env output except the last one,
  followed by the final action. Stop and invalid node ids are removed from
  both the observed path and the golden path before they are aligned.

  Args:
    action_list: List of actions; its last element is appended to the
      observed path.
    env_output_list: List of observations in the trajectory. The first entry
      supplies the golden path and the scan id.
    environment: Testing environment exposing
      `get_distance(pano1, pano2, scan_id)`.

  Returns:
    exp(-DTW / (SUCCESS_THRESHOLD * number of filtered golden panos)).
  """
  skipped = [constants.STOP_NODE_ID, constants.INVALID_NODE_ID]
  first_observation = env_output_list[0].observation
  scan_id = first_observation[constants.SCAN_ID]

  reference = [
      p for p in first_observation[constants.GOLDEN_PATH] if p not in skipped
  ]

  # The last env output is replaced by the final action, which is the pano
  # the agent ended up choosing.
  visited = [
      out.observation[constants.PANO_ID] for out in env_output_list[:-1]
  ]
  visited.append(action_list[-1])
  visited = [p for p in visited if p not in skipped]

  def distance(a, b):
    return environment.get_distance(a, b, scan_id)

  matrix = base_eval_metric.get_dtw_matrix(visited, reference, distance)
  total_cost = matrix[len(visited)][len(reference)]
  exponent = -1. * total_cost / (SUCCESS_THRESHOLD * len(reference))
  return np.exp(exponent)
Пример #2
0
def dense_dtw(path_history, next_pano, golden_path, end_of_episode, scan_info):
  """Dense reward equal to the change in DTW caused by moving to `next_pano`.

  The reward is the DTW-matrix entry for `path_history` minus the entry for
  `path_history + [next_pano]`, both taken against the full `golden_path`.
  The reward is positive when the extra step lowers that DTW value.

  Args:
    path_history: Pano ids visited so far (mapped through
      `scan_info.pano_id_to_name` before alignment).
    next_pano: Pano id the agent transitions to.
    golden_path: Reference path. Presumably pano names, since it is aligned
      against the mapped pano names; verify against callers.
    end_of_episode: Unused; accepted for interface compatibility.
    scan_info: Scan metadata providing `pano_id_to_name` and
      `graph.get_distance`.

  Returns:
    A scalar float immediate reward for the transition
    current_pano --> next_pano, or 0.0 if `next_pano` is the stop or invalid
    node id.
  """
  del end_of_episode  # Not needed to compute this reward.
  if next_pano in [constants.STOP_NODE_ID, constants.INVALID_NODE_ID]:
    return 0.0

  names = [scan_info.pano_id_to_name[p] for p in path_history + [next_pano]]
  matrix = eval_metric.get_dtw_matrix(names, golden_path,
                                      scan_info.graph.get_distance)

  last_col = len(golden_path)
  steps = len(names)
  # NOTE(review): with an empty path_history, `before` reads row 0 of the
  # DTW matrix; its value depends on get_dtw_matrix's initialization — confirm
  # callers always pass a non-empty history.
  before = matrix[steps - 1][last_col]
  after = matrix[steps][last_col]
  return before - after
Пример #3
0
def get_dtw(action_list, env_output_list, environment):
  """Computes Dynamic Time Warping (DTW) normalized by golden path length.

  Muller, Meinard. "Dynamic time warping."
  Information retrieval for music and motion (2007): 69-84.

  Stop and invalid node ids are dropped from both the predicted path and the
  golden path. The two paths are aligned with
  `base_eval_metric.get_dtw_matrix`. The raw DTW (a sum of graph distances)
  is then divided by the golden path length reported by `_get_path_length`.

  Args:
    action_list: List of actions.
    env_output_list: List of observations in the trajectories. The first entry
      supplies the golden path and the scan id.
    environment: Testing environment exposing
      `get_distance(pano1, pano2, scan_id)`.

  Returns:
    The DTW score divided by the golden path length.
  """
  skipped = [constants.STOP_NODE_ID, constants.INVALID_NODE_ID]

  def keep_valid(panos):
    return [p for p in panos if p not in skipped]

  predicted = keep_valid(_get_predicted_path(action_list, env_output_list))
  first_observation = env_output_list[0].observation
  reference = keep_valid(first_observation[constants.GOLDEN_PATH])
  scan_id = first_observation[constants.SCAN_ID]

  matrix = base_eval_metric.get_dtw_matrix(
      predicted, reference,
      lambda a, b: environment.get_distance(a, b, scan_id))
  raw_dtw = matrix[len(predicted)][len(reference)]

  # NOTE(review): if _get_path_length returns 0 (e.g. a golden path of a
  # single pano), this division fails or yields inf — confirm that cannot
  # happen.
  return raw_dtw / _get_path_length(reference, scan_id, environment)
Пример #4
0
 def _get_dtw_score(self, success_rate, golden_path, agent_path):
     """Returns (path-length-normalized DTW, nDTW, SDTW) for one episode.

     Distances between path nodes come from `self._env.shortest_path_length`.

     Args:
       success_rate: Truthy when the episode counts as a success; gates SDTW.
       golden_path: Reference path.
       agent_path: Path taken by the agent.

     Returns:
       A tuple `(pln_dtw, ndtw, sdtw)`:
         - `pln_dtw` is the raw DTW divided by the number of golden-path nodes.
         - `ndtw` is exp(-DTW / (_SUCCESS_THRESHOLD * len(golden_path))).
         - `sdtw` is `ndtw` on success and 0. otherwise.
     """
     n_ref = len(golden_path)
     matrix = base_eval_metric.get_dtw_matrix(
         agent_path, golden_path, self._env.shortest_path_length)
     dtw = matrix[len(agent_path)][n_ref]
     pln_dtw = dtw / n_ref
     ndtw = tf.math.exp(-1. * dtw / (_SUCCESS_THRESHOLD * n_ref))
     if success_rate:
         sdtw = ndtw
     else:
         sdtw = 0.
     return pln_dtw, ndtw, sdtw