예제 #1
0
  def test_no_error_for_existing_table(self):
    """A second Logger writing into an already-existing table must not fail.

    Both loggers share one in-memory connection and the experiment name
    'test', so by the time logger_2 writes, the table that logger_1 wrote
    into already exists. Beyond "no exception", the test checks that both
    rows are present, i.e. setting up the second Logger did not discard the
    row written through the first one.
    """
    connection = sqlite3.connect(':memory:')
    # Release the in-memory database when the test finishes.
    self.addCleanup(connection.close)

    logger_1 = sqlite_logging.Logger(db_path='unused',
                                     experiment_name='test',
                                     setting_index=1,
                                     connection=connection)

    data = dict(
        steps=10,
        episode=1,
        total_return=5.0,
        episode_len=10,
        episode_return=5.0,
    )
    logger_1.write(data)

    logger_2 = sqlite_logging.Logger(db_path='unused',
                                     experiment_name='test',
                                     setting_index=1,
                                     connection=connection)

    data = dict(
        steps=20,
        episode=2,
        total_return=10.0,
        episode_len=10,
        episode_return=5.0,
    )
    logger_2.write(data)

    # Rows are stored in a table named after experiment_name ('test'); one
    # row per write is expected across both loggers.
    results = connection.execute('select count(*) from test;').fetchall()
    self.assertLen(results, 1)
    self.assertEqual(results[0][0], 2)
예제 #2
0
  def test_logger(self, custom_data):
    """Writes a run of episode records and checks one row lands per write.

    Each record carries the standard episode fields plus a fixed 'extra'
    column; `custom_data` is merged on top of every record before writing.
    """
    conn = sqlite3.connect(':memory:')
    logger = sqlite_logging.Logger(
        db_path='unused',
        experiment_name='test',
        setting_index=1,
        connection=conn,
    )

    expected_rows = 10
    episode_length = 7

    running_return = 0.0

    for episode_idx in range(expected_rows):
      this_return = random.random()
      running_return += this_return

      record = {
          'steps': episode_idx * episode_length,
          'episode': episode_idx,
          'total_return': running_return,
          'episode_len': episode_length,
          'episode_return': this_return,
          'extra': 42,
      }
      # Caller-supplied fields override the defaults above.
      record.update(custom_data)
      logger.write(record)

    cur = conn.cursor()
    rows = cur.execute('select count(*) from test;').fetchall()
    self.assertLen(rows, 1)
    self.assertEqual(rows[0][0], expected_rows)
예제 #3
0
  def test_logger_raises_malformed_sql_error(self):
    """Writing fails with OperationalError when the table name breaks SQL.

    Name validation is switched off so that the experiment name 'test--' is
    accepted at construction; it is expected to yield a malformed INSERT
    statement, which surfaces as sqlite3.OperationalError on write.
    """
    logger = sqlite_logging.Logger(
        db_path=':memory:',
        experiment_name='test--',
        setting_index=1,
        skip_name_validation=True,
    )

    record = {
        'steps': 10,
        'episode': 1,
        'total_return': 5.0,
        'episode_len': 10,
        'episode_return': 5.0,
    }
    self.assertRaises(sqlite3.OperationalError, logger.write, record)
예제 #4
0
def generate_results(experiment_name, setting_index, connection, *,
                     num_writes=None, steps_per_episode=7):
    """Write a sequence of synthetic episode records through a Logger.

    Each record uses a random per-episode return; `total_return` is the
    running sum of those returns, and every record also carries `extra=42`.

    Args:
      experiment_name: Experiment name handed to `sqlite_logging.Logger`.
      setting_index: Setting index handed to `sqlite_logging.Logger`.
      connection: Database connection the Logger writes through (`db_path`
        is passed as the placeholder 'unused').
      num_writes: Number of records to write. Defaults to `_NUM_WRITES`,
        looked up at call time.
      steps_per_episode: Value stored as `episode_len`, and the stride
        between consecutive `steps` values.
    """
    if num_writes is None:
        num_writes = _NUM_WRITES

    logger = sqlite_logging.Logger(db_path='unused',
                                   experiment_name=experiment_name,
                                   setting_index=setting_index,
                                   connection=connection)

    total_return = 0.0

    for i in range(num_writes):
        episode_return = random.random()
        total_return += episode_return

        data = dict(
            steps=i * steps_per_episode,
            episode=i,
            total_return=total_return,
            episode_len=steps_per_episode,
            episode_return=episode_return,
            extra=42,
        )
        logger.write(data)
예제 #5
0
 def test_invalid_name(self):
   """Constructing a Logger with an experiment name containing '--' fails."""
   bad_name = 'test--'
   with self.assertRaises(ValueError):
     sqlite_logging.Logger(
         db_path=':memory:',
         experiment_name=bad_name,
         setting_index=1,
     )