def test_middle_of_batches_are_all_zero_onehots(self):
    """Expect non-target steps of a batch to have only channel 0 set.

    X is inspected from the second step onward and Y up to (but excluding)
    the last step. In both, every channel other than 0 must be constant 0
    and channel 0 must be constant 1.
    """
    memory_env = OneCharMemory(n=3, num_steps=4)
    inputs, targets = memory_env.get_batch(batch_size=100)
    # First pass covers channels 1.. (expected 0), second pass channel 0
    # (expected 1); X is always checked before Y.
    for channels, expected in ((slice(1, None), 0), (0, 1)):
        self.assertNpArrayConstant(inputs[:, 1:, channels], expected)
        self.assertNpArrayConstant(targets[:, :-1, channels], expected)
def test_episode_length_one(self):
    """A single-step episode must terminate on its first step.

    The observation returned by ``reset`` is fed back as the action; the
    step must be terminal, return the observation ``[1, 0, 0, 0]`` and a
    reward of approximately 0.5 (the configured reward_for_remembering).
    """
    env = OneCharMemory(n=3, num_steps=1, reward_for_remembering=0.5)
    echoed_action = env.reset()
    observation, payoff, done, _ = env.step(echoed_action)

    self.assertTrue(done)
    expected_observation = np.array([1, 0, 0, 0])
    self.assertNpArraysEqual(observation, expected_observation)
    self.assertAlmostEqual(payoff, 0.5)
def test_episode_length_is_right(self):
    """With num_steps=4, only the fourth step may report terminal."""
    total_steps = 4
    env = OneCharMemory(num_steps=total_steps)
    # The reset observation is reused as the action for every step.
    repeated_action = env.reset()
    for step_index in range(total_steps):
        _, _, done, _ = env.step(repeated_action)
        if step_index < total_steps - 1:
            self.assertFalse(done)
        else:
            self.assertTrue(done)
def test_reward_for_optimal_input_is_correct(self):
    """Echoing the initial observation on the last step earns the reward.

    The first three steps use an action with only index 0 set and must give
    a reward of ~0; the fourth step replays the reset observation and must
    give ~0.25 (the configured reward_for_remembering).
    """
    env = OneCharMemory(n=3, num_steps=4, reward_for_remembering=.25)
    first_observation = env.reset()

    filler_action = np.zeros((4, 1))
    filler_action[0] = 1

    for _ in range(3):
        _, step_reward, _, _ = env.step(filler_action)
        self.assertAlmostEqual(step_reward, 0.)

    _, final_reward, _, _ = env.step(first_observation)
    self.assertAlmostEqual(final_reward, 0.25)
def test_reward_for_wrong_input_is_correct(self):
    """Wrong actions must yield a strictly negative reward.

    For the first three steps the reset observation is replayed as the
    action; on the fourth step an all-zero action is sent. Every one of
    those steps is expected to return a reward below zero.
    """
    env = OneCharMemory(n=3, num_steps=4)
    init_obs = env.reset()

    action = init_obs
    for _ in range(3):
        _, reward, _, _ = env.step(action)
        # assertLess reports the offending value on failure, unlike
        # assertTrue(reward < 0).
        self.assertLess(reward, 0)

    action = np.zeros((4,))
    _, reward, _, _ = env.step(action)
    self.assertLess(reward, 0)
def test_output_target(self):
    """With output_target_number=True the target index is appended.

    The flat observation dimension must be 4. The reset observation must
    equal its first three entries followed by the argmax of those entries,
    and the observations of the next two steps must equal
    ``[1, 0, 0, target]``.
    """
    env = OneCharMemory(n=2, num_steps=5, output_target_number=True)
    self.assertEqual(env.observation_space.flat_dim, 4)

    first_obs = env.reset()
    target = np.argmax(first_obs[:3])
    noop_action = np.zeros(3)

    first_next, _, _, _ = env.step(noop_action)
    # Expectations are built after the first step, as in the original flow.
    expected_first = np.hstack((first_obs[:3], [target]))
    expected_later = np.array([1, 0, 0, target])
    self.assertNpArraysEqual(first_obs, expected_first)
    self.assertNpArraysEqual(first_next, expected_later)

    second_next, _, _, _ = env.step(noop_action)
    self.assertNpArraysEqual(second_next, expected_later)
def test_reward_for_wrong_input_is_correct(self):
    """A wrong final answer must give exactly zero reward.

    The reset observation is replayed for three steps (reward 0 expected
    each time). The last action sets index 3 if the reset observation has
    index 2 set, otherwise index 2, and must also give reward 0.
    """
    env = OneCharMemory(n=3, num_steps=4, reward_for_remembering=10)
    first_observation = env.reset()

    for _ in range(3):
        _, step_reward, _, _ = env.step(first_observation)
        self.assertEqual(step_reward, 0)

    # Pick between indices 2 and 3 so the action differs from the
    # reset observation at index 2.
    chosen_index = 3 if first_observation[2] else 2
    final_action = np.zeros((4, ))
    final_action[chosen_index] = 1

    _, final_reward, _, _ = env.step(final_action)
    self.assertEqual(final_reward, 0)
def test_output_time(self):
    """With output_time=True the last 4 entries one-hot encode the step.

    The flat observation dimension must be 3 + 4. After reset the time
    slice has index 0 set; each of the three steps advances the set index
    by one, the third step is terminal, and a new reset returns to index 0.
    """
    env = OneCharMemory(n=2, num_steps=3, output_time=True)
    self.assertEqual(env.observation_space.flat_dim, 3 + 4)

    def time_one_hot(step_index):
        # Length-4 vector with a single 1 at the given step index.
        encoded = np.zeros(4)
        encoded[step_index] = 1
        return encoded

    obs = env.reset()
    noop_action = np.zeros(3)
    self.assertNpEqual(obs[-4:], time_one_hot(0))

    for step_index in (1, 2, 3):
        next_ob, _, terminal, _ = env.step(noop_action)
        self.assertNpEqual(next_ob[-4:], time_one_hot(step_index))
    self.assertTrue(terminal)

    obs = env.reset()
    self.assertNpEqual(obs[-4:], time_one_hot(0))
def test_memory_action_saved(self):
    """The memory part of an action must come back in the observation.

    A step is taken with a (env action, memory) pair; the first element of
    the step result is unpacked into two parts and its second part must
    equal the memory vector that was written.
    """
    wrapped = OneCharMemory(n=5, num_steps=100)
    env = ContinuousMemoryAugmented(wrapped, num_memory_states=10)
    env.reset()

    env_part = np.zeros(6)
    env_part[0] = 1
    memory_part = np.random.rand(10)

    step_result = env.step([env_part, memory_part])
    observation = step_result[0]
    _, saved_memory = observation
    self.assertNpArraysEqual(memory_part, saved_memory)
def get_env_settings(env_id="", normalize_env=True, gym_name="", env_params=None):
    """Build the environment selected by ``env_id`` and describe it.

    Args:
        env_id: One of 'cart', 'cheetah', 'ant', 'point', 'reacher', 'idp',
            'ocm' or 'gym'.
        normalize_env: When true, the environment is passed through
            ``normalize`` and "-normalized" is appended to the name.
        gym_name: Gym environment name; required when ``env_id`` is 'gym'.
        env_params: Keyword arguments forwarded to ``OneCharMemory`` when
            ``env_id`` is 'ocm'. ``None`` means no extra arguments.

    Returns:
        A dict with keys ``env``, ``name`` and ``was_env_normalized``.

    Raises:
        ValueError: If ``env_id`` is unknown, or if ``env_id`` is 'gym' and
            no ``gym_name`` is given. ValueError is a subclass of Exception,
            so callers catching Exception keep working.
    """
    if env_params is None:
        env_params = {}

    if env_id == 'cart':
        env = CartpoleEnv()
        name = "Cartpole"
    elif env_id == 'cheetah':
        env = HalfCheetahEnv()
        name = "HalfCheetah"
    elif env_id == 'ant':
        env = AntEnv()
        name = "Ant"
    elif env_id == 'point':
        env = gym_env("OneDPoint-v0")
        name = "OneDPoint"
    elif env_id == 'reacher':
        env = gym_env("Reacher-v1")
        name = "Reacher"
    elif env_id == 'idp':
        env = InvertedDoublePendulumEnv()
        name = "InvertedDoublePendulum"
    elif env_id == 'ocm':
        env = OneCharMemory(**env_params)
        name = "OneCharMemory"
    elif env_id == 'gym':
        # Reject None as well as "" so a missing name fails here with a
        # clear message instead of inside gym_env.
        if not gym_name:
            raise ValueError("Must provide a gym name")
        env = gym_env(gym_name)
        name = gym_name
    else:
        raise ValueError("Unknown env: {0}".format(env_id))

    if normalize_env:
        env = normalize(env)
        name += "-normalized"

    return dict(
        env=env,
        name=name,
        was_env_normalized=normalize_env,
    )
def test_target_is_never_zero_one_hot(self):
    """Channel 0 of the first input step must sum to 0 over the batch."""
    env = OneCharMemory(n=3, num_steps=4)
    inputs, _ = env.get_batch(batch_size=100)
    first_step_channel_zero = inputs[:, 0, 0]
    channel_zero_total = np.sum(first_step_channel_zero)
    self.assertNpArrayConstant(channel_zero_total, 0)
def test_dim_correct(self):
    """The memory-augmented env must expose a flat action dim of 16."""
    base_env = OneCharMemory(n=5, num_steps=100)
    augmented = ContinuousMemoryAugmented(base_env, num_memory_states=10)
    # NOTE(review): 16 is presumably the base env's action dim plus the
    # 10 memory states — confirm against ContinuousMemoryAugmented.
    flat_action_dim = augmented.action_space.flat_dim
    self.assertEqual(flat_action_dim, 16)
def test_get_batch_shape(self):
    """Both batch arrays must have shape (batch_size, num_steps, n + 1)."""
    env = OneCharMemory(n=5, num_steps=100)
    batch = env.get_batch(batch_size=3)
    inputs, targets = batch
    expected_shape = (3, 100, 6)
    for batch_array in (inputs, targets):
        self.assertEqual(batch_array.shape, expected_shape)
def test_dim_correct(self):
    """Check feature_dim, target_dim and sequence_length for n=5, 100 steps."""
    env = OneCharMemory(n=5, num_steps=100)
    expected_attributes = (
        ("feature_dim", 6),
        ("target_dim", 6),
        ("sequence_length", 100),
    )
    for attribute_name, expected_value in expected_attributes:
        self.assertEqual(getattr(env, attribute_name), expected_value)
def test_batch_x_first_and_y_last_are_equal(self):
    """The first input step must equal the last target step per sample."""
    env = OneCharMemory(n=3, num_steps=4)
    inputs, targets = env.get_batch(batch_size=100)
    first_inputs = inputs[:, 0, :]
    last_targets = targets[:, -1, :]
    self.assertNpEqual(first_inputs, last_targets)
def test_first_x_is_one_hot(self):
    """Each sample's first input step must contain exactly one entry == 1."""
    env = OneCharMemory(n=3, num_steps=4)
    inputs, _ = env.get_batch(batch_size=100)
    first_inputs = inputs[:, 0, :]
    is_one = np.array(first_inputs == 1)
    ones_per_sample = np.sum(is_one, axis=1)
    self.assertNpArrayConstant(ones_per_sample, 1)
def test_init_state_is_one_hot(self):
    """The reset observation for n=3 must have shape (4,).

    NOTE(review): despite the test name, only the shape is asserted here;
    the one-hot property itself is not checked.
    """
    env = OneCharMemory(n=3, num_steps=4)
    initial_observation = env.reset()
    expected_shape = (4,)
    self.assertEqual(initial_observation.shape, expected_shape)