def init(sequence_length, eval_init_points_file_name, worlds):
  """Initializes the common operations between train and test."""
  modality_types = create_modality_types()
  logging.info('modality types: %r', modality_types)
  # A negative reward_goal_range prevents the env from terminating early when
  # the agent is close to the goal. The policy should keep the agent going
  # until the end of the 100 steps, either by choosing the stop action or by
  # oscillating around the target.
  env = active_vision_dataset_env.ActiveVisionDatasetEnv(
      modality_types=modality_types +
      [task_env.ModalityTypes.GOAL, task_env.ModalityTypes.PREV_ACTION],
      reward_goal_range=-1,
      eval_init_points_file_name=eval_init_points_file_name,
      worlds=worlds,
      output_size=FLAGS.obs_resolution,
  )

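  # I/O spec shared by the task and the policy: input modalities, goal
  # category count, action space size, and unroll length.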
  config = create_task_io_config(
      modality_types=modality_types,
      goal_category_count=FLAGS.goal_category_count,
      action_size=FLAGS.action_size,
      sequence_length=sequence_length,
  )
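  # Goto-static-target navigation task (no exploration) built on top of the env.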
  task = tasks.GotoStaticXNoExplorationTask(env=env, config=config)
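  # One embedder per input modality declared in the config.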
  embedders_dict = {mtype: map_to_embedder(mtype) for mtype in config.inputs}
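  # Hyperparameters for the LSTM policy network.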
  policy_params = tf.contrib.training.HParams(
      lstm_state_size=FLAGS.lstm_cell_size,
      fc_channels=FLAGS.policy_fc_size,
      weight_decay=FLAGS.weight_decay,
      target_embedding_size=FLAGS.embedding_fc_size,
  )
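  # Recurrent policy that consumes the embedded observations and unrolls an
  # LSTM for up to max_episode_length steps.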
  policy = policies.LSTMPolicy(
      modality_names=config.inputs.keys(),
      embedders_dict=embedders_dict,
      action_size=FLAGS.action_size,
      params=policy_params,
      max_episode_length=sequence_length)
  return env, config, task, policy


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  logging.info('mode: %s', FLAGS.mode)
  logging.info('gin config: %s', FLAGS.gin_config)
  logging.info('gin params: %s', FLAGS.gin_params)

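  # Instantiate the env with all perceptual modalities used by the modes below.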
  env = active_vision_dataset_env.ActiveVisionDatasetEnv(modality_types=[
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION,
      task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE,
  ])

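  # Dispatch to the requested run mode.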
  if FLAGS.mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif FLAGS.mode == GRAPH_MODE:
    for loc in env.worlds:
      env.check_scene_graph(loc, 'fridge')
  elif FLAGS.mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif FLAGS.mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif FLAGS.mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder)
  else:
    raise ValueError('Unknown mode: %s' % FLAGS.mode)