Example #1
0
def test_env_semantics(spec):
  """Replay the recorded rollout for `spec` and verify a fresh run matches it.

  Loads the expected rollout hashes from ROLLOUT_FILE, regenerates them with
  generate_rollout_hash, logs every mismatch, and raises ValueError with the
  full list of mismatches if any field differs.
  """
  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  # Skipped envs return silently: the "Rollout does not exist" warning below
  # would be misleading for them (the original OR-ed this into the membership
  # test, which forced the warning even when a rollout existed).
  if should_skip_env_spec_for_tests(spec):
    return

  if spec.id not in rollout_dict:
    # Nondeterministic envs never get rollouts, so only warn for the rest.
    if not spec.nondeterministic:
      logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
    return

  logger.info("Testing rollout for {} environment...".format(spec.id))

  observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

  expected = rollout_dict[spec.id]
  errors = []
  # One uniform check per recorded field; key.capitalize() reproduces the
  # original message prefixes ('Observations', 'Actions', ...).
  for key, current in (('observations', observations_now),
                       ('actions', actions_now),
                       ('rewards', rewards_now),
                       ('dones', dones_now)):
    if expected[key] != current:
      errors.append('{} not equal for {} -- expected {} but got {}'.format(
          key.capitalize(), spec.id, expected[key], current))
  if errors:
    for error in errors:
      logger.warn(error)
    raise ValueError(errors)
Example #2
0
def test_env_semantics(spec):
    """Check that a freshly generated rollout for `spec` matches the stored hashes."""
    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    if spec.id not in rollout_dict:
        # Missing rollouts are only worth a warning when the env should have one.
        if not spec.nondeterministic or should_skip_env_spec_for_tests(spec):
            logger.warn(
                "Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs"
                .format(spec.id))
        return

    logger.info("Testing rollout for {} environment...".format(spec.id))

    observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

    # Hoist the recorded entry once instead of re-indexing in every assert.
    expected = rollout_dict[spec.id]
    assert expected['observations'] == observations_now, 'Observations not equal for {}'.format(spec.id)
    assert expected['actions'] == actions_now, 'Actions not equal for {}'.format(spec.id)
    assert expected['rewards'] == rewards_now, 'Rewards not equal for {}'.format(spec.id)
    assert expected['dones'] == dones_now, 'Dones not equal for {}'.format(spec.id)
Example #3
0
def test_env(spec):
    """Run the same seeded environment twice and assert both traces are identical.

    Seeds the shared `spaces` RNG and the env itself, samples actions and
    observations, resets and steps, then repeats the whole procedure and
    compares the two runs element by element.
    """
    if should_skip_env_spec_for_tests(spec):
        return

    # Note that this precludes running this test in multiple
    # threads. However, we probably already can't do multithreading
    # due to some environments.
    spaces.seed(0)

    env1 = spec.make()
    env1.seed(0)
    action_samples1 = [env1.action_space.sample() for i in range(4)]
    observation_samples1 = [env1.observation_space.sample() for i in range(4)]
    initial_observation1 = env1.reset()
    step_responses1 = [env1.step(action) for action in action_samples1]
    env1.close()

    spaces.seed(0)

    env2 = spec.make()
    env2.seed(0)
    action_samples2 = [env2.action_space.sample() for i in range(4)]
    observation_samples2 = [env2.observation_space.sample() for i in range(4)]
    initial_observation2 = env2.reset()
    step_responses2 = [env2.step(action) for action in action_samples2]
    env2.close()

    for i, (action_sample1, action_sample2) in enumerate(
            zip(action_samples1, action_samples2)):
        # Pass the context message INTO assert_equals (as done below for o1/o2);
        # the original wrote `assert_equals(...), 'msg'`, which builds a
        # (None, str) tuple and silently discards the message.
        assert_equals(action_sample1, action_sample2,
                      '[{}] action_sample1: {}, action_sample2: {} '.format(
                          i, action_sample1, action_sample2))

    for (observation_sample1,
         observation_sample2) in zip(observation_samples1,
                                     observation_samples2):
        assert_equals(observation_sample1, observation_sample2)

    # Don't check rollout equality if it's a nondeterministic
    # environment.
    if spec.nondeterministic:
        return

    assert_equals(initial_observation1, initial_observation2)

    for i, ((o1, r1, d1, i1),
            (o2, r2, d2,
             i2)) in enumerate(zip(step_responses1, step_responses2)):
        assert_equals(o1, o2, '[{}] '.format(i))
        assert r1 == r2, '[{}] r1: {}, r2: {}'.format(i, r1, r2)
        assert d1 == d2, '[{}] d1: {}, d2: {}'.format(i, d1, d2)

        # Go returns a Pachi game board in info, which doesn't
        # properly check equality. For now, we hack around this by
        # just skipping Go.
        if spec.id not in ['Go9x9-v0', 'Go19x19-v0']:
            assert_equals(i1, i2, '[{}] '.format(i))
Example #4
0
def create_rollout(spec):
    """Generate and persist a rollout for `spec`.

    Takes as input the environment spec for which the rollout is to be
    generated. Returns a bool which indicates whether the new rollout was
    added to the json file.
    """
    # Skip platform-dependent
    if should_skip_env_spec_for_tests(spec):
        logger.warn("Skipping tests for {}".format(spec.id))
        return False

    # Skip environments that are nondeterministic
    if spec.nondeterministic:
        logger.warn("Skipping tests for nondeterministic env {}".format(
            spec.id))
        return False

    # Skip broken environments
    # TODO: look into these environments
    if spec.id in [
            'PredictObsCartpole-v0', 'InterpretabilityCartpoleObservations-v0'
    ]:
        logger.warn("Skipping tests for {}".format(spec.id))
        return False

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    # Skip generating rollouts that already exist
    if spec.id in rollout_dict:
        logger.warn("Rollout already exists for {}".format(spec.id))
        return False

    logger.info("Generating rollout for {}".format(spec.id))

    try:
        observations_hash, actions_hash, rewards_hash, dones_hash = generate_rollout_hash(
            spec)
    except Exception:
        # Catch Exception, not bare `except:` — a bare clause would also
        # swallow KeyboardInterrupt/SystemExit. If running the env raises,
        # don't write to the rollout file.
        logger.warn(
            "Exception {} thrown while generating rollout for {}. Rollout not added."
            .format(sys.exc_info()[0], spec.id))
        return False

    rollout = {
        'observations': observations_hash,
        'actions': actions_hash,
        'rewards': rewards_hash,
        'dones': dones_hash,
    }

    rollout_dict[spec.id] = rollout

    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump(rollout_dict, outfile, indent=2)

    return True
Example #5
0
def update_rollout_dict(spec, rollout_dict):
    """
    Takes as input the environment spec for which the rollout is to be generated,
    and the existing dictionary of rollouts. Returns True iff the dictionary was
    modified.
    """
    # Skip platform-dependent
    if should_skip_env_spec_for_tests(spec):
        logger.info("Skipping tests for {}".format(spec.id))
        return False

    # Skip environments that are nondeterministic
    if spec.nondeterministic:
        logger.info("Skipping tests for nondeterministic env {}".format(
            spec.id))
        return False

    # Skip broken environments
    # TODO: look into these environments
    if spec.id in [
            'PredictObsCartpole-v0', 'InterpretabilityCartpoleObservations-v0'
    ]:
        logger.info("Skipping tests for {}".format(spec.id))
        return False

    logger.info("Generating rollout for {}".format(spec.id))

    try:
        observations_hash, actions_hash, rewards_hash, dones_hash = generate_rollout_hash(
            spec)
    except Exception:
        # Catch Exception, not bare `except:` — a bare clause would also
        # swallow KeyboardInterrupt/SystemExit. If running the env raises,
        # don't write to the rollout file.
        logger.warn(
            "Exception {} thrown while generating rollout for {}. Rollout not added."
            .format(sys.exc_info()[0], spec.id))
        return False

    rollout = {
        'observations': observations_hash,
        'actions': actions_hash,
        'rewards': rewards_hash,
        'dones': dones_hash,
    }

    existing = rollout_dict.get(spec.id)
    if existing:
        # Only overwrite when at least one hash actually changed.
        if not any(existing[key] != new_hash
                   for key, new_hash in rollout.items()):
            logger.debug("Hashes match with existing for {}".format(spec.id))
            return False
        logger.warn("Got new hash for {}. Overwriting.".format(spec.id))

    rollout_dict[spec.id] = rollout
    return True
def test_env(spec):
    """Run the same seeded environment twice and assert both traces are identical.

    Seeds the shared `spaces` RNG and the env itself, samples actions and
    observations, resets and steps, then repeats the whole procedure and
    compares the two runs element by element.
    """
    if should_skip_env_spec_for_tests(spec):
        return

    # Note that this precludes running this test in multiple
    # threads. However, we probably already can't do multithreading
    # due to some environments.
    spaces.seed(0)

    env1 = spec.make()
    env1.seed(0)
    action_samples1 = [env1.action_space.sample() for i in range(4)]
    observation_samples1 = [env1.observation_space.sample() for i in range(4)]
    initial_observation1 = env1.reset()
    step_responses1 = [env1.step(action) for action in action_samples1]
    env1.close()

    spaces.seed(0)

    env2 = spec.make()
    env2.seed(0)
    action_samples2 = [env2.action_space.sample() for i in range(4)]
    observation_samples2 = [env2.observation_space.sample() for i in range(4)]
    initial_observation2 = env2.reset()
    step_responses2 = [env2.step(action) for action in action_samples2]
    env2.close()

    for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
        # Pass the context message INTO assert_equals (as done below for o1/o2);
        # the original wrote `assert_equals(...), "msg"`, which builds a
        # (None, str) tuple and silently discards the message.
        assert_equals(action_sample1, action_sample2,
                      "[{}] action_sample1: {}, action_sample2: {} ".format(
                          i, action_sample1, action_sample2))

    for (observation_sample1, observation_sample2) in zip(observation_samples1, observation_samples2):
        assert_equals(observation_sample1, observation_sample2)

    # Don't check rollout equality if it's a nondeterministic
    # environment.
    if spec.nondeterministic:
        return

    assert_equals(initial_observation1, initial_observation2)

    for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
        assert_equals(o1, o2, "[{}] ".format(i))
        assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
        assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)

        # Go returns a Pachi game board in info, which doesn't
        # properly check equality. For now, we hack around this by
        # just skipping Go.
        if spec.id not in ["Go9x9-v0", "Go19x19-v0"]:
            assert_equals(i1, i2, "[{}] ".format(i))
Example #7
0
def create_rollout(spec):
  """Generate and persist a rollout for `spec`.

  Takes as input the environment spec for which the rollout is to be generated.
  Returns a bool which indicates whether the new rollout was added to the json
  file.
  """
  # Skip platform-dependent
  if should_skip_env_spec_for_tests(spec):
    logger.warn("Skipping tests for {}".format(spec.id))
    return False

  # Skip environments that are nondeterministic
  if spec.nondeterministic:
    logger.warn("Skipping tests for nondeterministic env {}".format(spec.id))
    return False

  # Skip broken environments
  # TODO: look into these environments
  if spec.id in ['PredictObsCartpole-v0', 'InterpretabilityCartpoleObservations-v0']:
    logger.warn("Skipping tests for {}".format(spec.id))
    return False

  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  # Skip generating rollouts that already exist
  if spec.id in rollout_dict:
    logger.warn("Rollout already exists for {}".format(spec.id))
    return False

  logger.info("Generating rollout for {}".format(spec.id))

  try:
    observations_hash, actions_hash, rewards_hash, dones_hash = generate_rollout_hash(spec)
  except Exception:
    # Catch Exception, not bare `except:` — a bare clause would also swallow
    # KeyboardInterrupt/SystemExit. If running the env raises, don't write to
    # the rollout file.
    logger.warn("Exception {} thrown while generating rollout for {}. Rollout not added.".format(sys.exc_info()[0], spec.id))
    return False

  rollout = {
      'observations': observations_hash,
      'actions': actions_hash,
      'rewards': rewards_hash,
      'dones': dones_hash,
  }

  rollout_dict[spec.id] = rollout

  with open(ROLLOUT_FILE, "w") as outfile:
    json.dump(rollout_dict, outfile, indent=2)

  return True
Example #8
0
def test_env_semantics(spec):
  """Check that a freshly generated rollout for `spec` matches the stored hashes."""
  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  if spec.id not in rollout_dict:
    # Missing rollouts are only worth a warning when the env should have one.
    if not spec.nondeterministic or should_skip_env_spec_for_tests(spec):
      logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
    return

  logger.info("Testing rollout for {} environment...".format(spec.id))

  observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

  # Hoist the recorded entry once instead of re-indexing in every assert.
  recorded = rollout_dict[spec.id]
  assert recorded['observations'] == observations_now, 'Observations not equal for {}'.format(spec.id)
  assert recorded['actions'] == actions_now, 'Actions not equal for {}'.format(spec.id)
  assert recorded['rewards'] == rewards_now, 'Rewards not equal for {}'.format(spec.id)
  assert recorded['dones'] == dones_now, 'Dones not equal for {}'.format(spec.id)
Example #9
0
def update_rollout_dict(spec, rollout_dict):
  """
  Takes as input the environment spec for which the rollout is to be generated,
  and the existing dictionary of rollouts. Returns True iff the dictionary was
  modified.
  """
  # Skip platform-dependent
  if should_skip_env_spec_for_tests(spec):
    logger.info("Skipping tests for {}".format(spec.id))
    return False

  # Skip environments that are nondeterministic
  if spec.nondeterministic:
    logger.info("Skipping tests for nondeterministic env {}".format(spec.id))
    return False

  logger.info("Generating rollout for {}".format(spec.id))

  try:
    observations_hash, actions_hash, rewards_hash, dones_hash = generate_rollout_hash(spec)
  except Exception:
    # Catch Exception, not bare `except:` — a bare clause would also swallow
    # KeyboardInterrupt/SystemExit. If running the env raises, don't write to
    # the rollout file.
    logger.warn("Exception {} thrown while generating rollout for {}. Rollout not added.".format(sys.exc_info()[0], spec.id))
    return False

  rollout = {
      'observations': observations_hash,
      'actions': actions_hash,
      'rewards': rewards_hash,
      'dones': dones_hash,
  }

  existing = rollout_dict.get(spec.id)
  if existing:
    # Only overwrite when at least one hash actually changed.
    if not any(existing[key] != new_hash for key, new_hash in rollout.items()):
      logger.debug("Hashes match with existing for {}".format(spec.id))
      return False
    logger.warn("Got new hash for {}. Overwriting.".format(spec.id))

  rollout_dict[spec.id] = rollout
  return True