def test_get_qval_sym(self, filter_dims, num_filters, strides):
    with mock.patch(('metarl.tf.q_functions.'
                     'discrete_cnn_q_function.CNNModel'),
                    new=SimpleCNNModel):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.MLPModel'),
                        new=SimpleMLPModel):
            qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                      filter_dims=filter_dims,
                                      num_filters=num_filters,
                                      strides=strides,
                                      dueling=False)
    output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})

    obs_dim = self.env.observation_space.shape
    action_dim = self.env.action_space.n

    input_var = tf.compat.v1.placeholder(tf.float32,
                                         shape=(None, ) + obs_dim)
    q_vals = qf.get_qval_sym(input_var, 'another')
    output2 = self.sess.run(q_vals, feed_dict={input_var: [self.obs]})

    expected_output = np.full(action_dim, 0.5)

    assert np.array_equal(output1, output2)
    assert np.array_equal(output2[0], expected_output)
def test_obs_not_image(self):
    env = self.env
    with mock.patch(('metarl.tf.models.'
                     'categorical_cnn_model.CNNModel._build'),
                    autospec=True,
                    side_effect=CNNModel._build) as build:
        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filters=((5, (3, 3)), ),
                                  strides=(2, ),
                                  dueling=False)
        # observation is not an image, so it should not be normalized
        normalized_obs = build.call_args_list[0][0][1]
        input_ph = qf.input
        assert input_ph == normalized_obs

        fake_obs = [np.full(env.spec.observation_space.shape, 255)]
        assert (self.sess.run(normalized_obs,
                              feed_dict={input_ph: fake_obs}) == 255.).all()

        obs_dim = env.spec.observation_space.shape
        state_input = tf.compat.v1.placeholder(tf.float32,
                                               shape=(None, ) + obs_dim)
        qf.get_qval_sym(state_input, name='another')
        normalized_obs = build.call_args_list[1][0][1]

        fake_obs = [np.full(env.spec.observation_space.shape, 255)]
        assert (self.sess.run(normalized_obs,
                              feed_dict={state_input: fake_obs}) == 255).all()
def test_clone(self, filters, strides):
    with mock.patch(('metarl.tf.q_functions.'
                     'discrete_cnn_q_function.CNNModel'),
                    new=SimpleCNNModel):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.MLPModel'),
                        new=SimpleMLPModel):
            qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                      filters=filters,
                                      strides=strides,
                                      dueling=False)
    qf_clone = qf.clone('another_qf')
    assert qf_clone._filters == qf._filters
    assert qf_clone._strides == qf._strides
def test_get_action_max_pooling(self, filter_dims, num_filters, strides,
                                pool_strides, pool_shapes):
    with mock.patch(('metarl.tf.q_functions.'
                     'discrete_cnn_q_function.CNNModelWithMaxPooling'),
                    new=SimpleCNNModelWithMaxPooling):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.MLPModel'),
                        new=SimpleMLPModel):
            qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                      filter_dims=filter_dims,
                                      num_filters=num_filters,
                                      strides=strides,
                                      max_pooling=True,
                                      pool_strides=pool_strides,
                                      pool_shapes=pool_shapes,
                                      dueling=False)
    action_dim = self.env.action_space.n
    expected_output = np.full(action_dim, 0.5)

    outputs = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})
    assert np.array_equal(outputs[0], expected_output)

    outputs = self.sess.run(
        qf.q_vals, feed_dict={qf.input: [self.obs, self.obs, self.obs]})
    for output in outputs:
        assert np.array_equal(output, expected_output)
def test_is_pickleable(self, filter_dims, num_filters, strides):
    with mock.patch(('metarl.tf.q_functions.'
                     'discrete_cnn_q_function.CNNModel'),
                    new=SimpleCNNModel):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.MLPModel'),
                        new=SimpleMLPModel):
            qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                      filter_dims=filter_dims,
                                      num_filters=num_filters,
                                      strides=strides,
                                      dueling=False)
    with tf.compat.v1.variable_scope(
            'DiscreteCNNQFunction/Sequential/SimpleMLPModel', reuse=True):
        return_var = tf.compat.v1.get_variable('return_var')
    # assign the variable to all ones
    return_var.load(tf.ones_like(return_var).eval())

    output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})

    h_data = pickle.dumps(qf)
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        qf_pickled = pickle.loads(h_data)
        output2 = sess.run(qf_pickled.q_vals,
                           feed_dict={qf_pickled.input: [self.obs]})

    assert np.array_equal(output1, output2)
def test_invalid_obs_shape(self, obs_dim):
    boxEnv = MetaRLEnv(DummyDiscreteEnv(obs_dim=obs_dim))
    with pytest.raises(ValueError):
        DiscreteCNNQFunction(env_spec=boxEnv.spec,
                             filters=((5, (3, 3)), ),
                             strides=(2, ),
                             dueling=False)
def run_task(snapshot_config, variant_data, *_):
    """Run task.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
        variant_data (dict): Custom arguments for the task.
        *_ (object): Ignored by this function.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4)

        env = TfEnv(env)

        replay_buffer = SimpleReplayBuffer(
            env_spec=env.spec,
            size_in_transitions=variant_data['buffer_size'],
            time_horizon=1)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filter_dims=(8, 4, 3),
                                  num_filters=(32, 64, 64),
                                  strides=(4, 2, 1),
                                  dueling=False)

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   double_q=False,
                   n_train_steps=500,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
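# Hypothetical launcher for run_task (an assumption for illustration, not part
# of the code above): older metarl/garage-style examples hand the task
# function to run_experiment and forward custom arguments through the
# `variant` dict, which arrives in run_task as `variant_data`.
from metarl.experiment import run_experiment

run_experiment(
    run_task,
    snapshot_mode='last',
    seed=1,
    variant={'buffer_size': int(5e4)},  # read as variant_data['buffer_size']
)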
def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_path_length=None):
    """Train DQN on PongNoFrameskip-v4 environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        buffer_size (int): Number of timesteps to store in replay buffer.
        max_path_length (int): Maximum length of a path after which a path is
            considered complete. This is used during testing to minimize the
            memory required to store a single path.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4)

        env = MetaRLEnv(env, is_image=True)

        replay_buffer = PathBuffer(capacity_in_transitions=buffer_size)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filters=(
                                              (32, (8, 8)),
                                              (64, (4, 4)),
                                              (64, (3, 3)),
                                          ),
                                  strides=(4, 2, 1),
                                  dueling=False)  # yapf: disable

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                                 policy=policy,
                                                 total_timesteps=num_timesteps,
                                                 max_epsilon=1.0,
                                                 min_epsilon=0.02,
                                                 decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=exploration_policy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   max_path_length=max_path_length,
                   double_q=False,
                   n_train_steps=500,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
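# Hypothetical entry point (an assumption based on the signature above, not
# taken from the file): newer metarl/garage-style examples wrap the function
# with wrap_experiment, which creates the ExperimentContext and injects it as
# the `ctxt` argument when the experiment is launched.
from metarl import wrap_experiment

if __name__ == '__main__':
    wrap_experiment(dqn_pong)(seed=1, buffer_size=int(5e4))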