def test_no_reset(self):
    with LocalRunner(sess=self.sess) as runner:
        # This tests that the off-policy sampler respects batch_size
        # when no_reset is set to True
        env = TfEnv(normalize(gym.make('InvertedDoublePendulum-v2')))
        action_noise = OUStrategy(env.spec, sigma=0.2)
        policy = ContinuousMLPPolicyWithModel(
            env_spec=env.spec,
            hidden_sizes=[64, 64],
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=tf.nn.tanh)
        qf = ContinuousMLPQFunction(
            env_spec=env.spec,
            hidden_sizes=[64, 64],
            hidden_nonlinearity=tf.nn.relu)
        replay_buffer = SimpleReplayBuffer(
            env_spec=env.spec,
            size_in_transitions=int(1e6),
            time_horizon=100)
        algo = DDPG(
            env_spec=env.spec,
            policy=policy,
            policy_lr=1e-4,
            qf_lr=1e-3,
            qf=qf,
            replay_buffer=replay_buffer,
            target_update_tau=1e-2,
            n_train_steps=50,
            discount=0.9,
            min_buffer_size=int(1e4),
            exploration_strategy=action_noise,
        )

        sampler = OffPolicyVectorizedSampler(algo, env, 1, no_reset=True)
        sampler.start_worker()

        runner.initialize_tf_vars()

        # Sample twice without resetting; each call should return exactly
        # batch_size transitions and the second call should continue the
        # same rollout.
        paths1 = sampler.obtain_samples(0, 5)
        paths2 = sampler.obtain_samples(0, 5)

        len1 = sum([len(path['rewards']) for path in paths1])
        len2 = sum([len(path['rewards']) for path in paths2])

        assert len1 == 5 and len2 == 5, 'Sampler should respect batch_size'
        # yapf: disable
        assert (len(paths1[0]['rewards']) + len(paths2[0]['rewards']) ==
                paths2[0]['running_length']), (
            'Running length should be the length of the full path')
        # yapf: enable
        assert np.isclose(
            paths1[0]['rewards'].sum() + paths2[0]['rewards'].sum(),
            paths2[0]['undiscounted_return']), (
            'undiscounted_return should be the sum of rewards of the '
            'full path')
def test_algo_with_goal_without_es(self):
    # This tests that the sampler works properly when the algorithm
    # includes a goal but has no exploration strategy
    env = DummyDictEnv()
    policy = DummyPolicy(env)
    replay_buffer = SimpleReplayBuffer(
        env_spec=env, size_in_transitions=int(1e6), time_horizon=100)
    algo = DummyOffPolicyAlgo(
        env_spec=env,
        qf=None,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_strategy=None,
        input_include_goal=True)

    sampler = OffPolicyVectorizedSampler(algo, env, 1, no_reset=True)
    sampler.start_worker()
    sampler.obtain_samples(0, 30)