    def test_get_qval_sym(self, filter_dims, num_filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filter_dims=filter_dims,
                                          num_filters=num_filters,
                                          strides=strides,
                                          dueling=False)
        output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})

        obs_dim = self.env.observation_space.shape
        action_dim = self.env.action_space.n

        input_var = tf.compat.v1.placeholder(tf.float32,
                                             shape=(None, ) + obs_dim)
        q_vals = qf.get_qval_sym(input_var, 'another')
        output2 = self.sess.run(q_vals, feed_dict={input_var: [self.obs]})

        expected_output = np.full(action_dim, 0.5)

        assert np.array_equal(output1, output2)
        assert np.array_equal(output2[0], expected_output)

    def test_obs_not_image(self):
        env = self.env
        with mock.patch(('metarl.tf.models.'
                         'cnn_model.CNNModel._build'),
                        autospec=True,
                        side_effect=CNNModel._build) as build:

            qf = DiscreteCNNQFunction(env_spec=env.spec,
                                      filters=((5, (3, 3)), ),
                                      strides=(2, ),
                                      dueling=False)
            # with autospec the mock records `self` as args[0], so args[1]
            # is the observation tensor passed to CNNModel._build
            normalized_obs = build.call_args_list[0][0][1]

            input_ph = qf.input
            assert input_ph == normalized_obs

            fake_obs = [np.full(env.spec.observation_space.shape, 255)]
            assert (self.sess.run(normalized_obs,
                                  feed_dict={input_ph:
                                             fake_obs}) == 255.).all()

            obs_dim = env.spec.observation_space.shape
            state_input = tf.compat.v1.placeholder(tf.float32,
                                                   shape=(None, ) + obs_dim)

            qf.get_qval_sym(state_input, name='another')
            normalized_obs = build.call_args_list[1][0][1]

            fake_obs = [np.full(env.spec.observation_space.shape, 255)]
            assert (self.sess.run(normalized_obs,
                                  feed_dict={state_input:
                                             fake_obs}) == 255.).all()

    def test_clone(self, filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filters=filters,
                                          strides=strides,
                                          dueling=False)
        qf_clone = qf.clone('another_qf')
        assert qf_clone._filters == qf._filters
        assert qf_clone._strides == qf._strides

    def test_get_action_max_pooling(self, filter_dims, num_filters, strides,
                                    pool_strides, pool_shapes):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModelWithMaxPooling'),
                        new=SimpleCNNModelWithMaxPooling):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filter_dims=filter_dims,
                                          num_filters=num_filters,
                                          strides=strides,
                                          max_pooling=True,
                                          pool_strides=pool_strides,
                                          pool_shapes=pool_shapes,
                                          dueling=False)

        action_dim = self.env.action_space.n
        expected_output = np.full(action_dim, 0.5)
        outputs = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})
        assert np.array_equal(outputs[0], expected_output)
        outputs = self.sess.run(
            qf.q_vals, feed_dict={qf.input: [self.obs, self.obs, self.obs]})
        for output in outputs:
            assert np.array_equal(output, expected_output)

    def test_is_pickleable(self, filter_dims, num_filters, strides):
        with mock.patch(('metarl.tf.q_functions.'
                         'discrete_cnn_q_function.CNNModel'),
                        new=SimpleCNNModel):
            with mock.patch(('metarl.tf.q_functions.'
                             'discrete_cnn_q_function.MLPModel'),
                            new=SimpleMLPModel):
                qf = DiscreteCNNQFunction(env_spec=self.env.spec,
                                          filter_dims=filter_dims,
                                          num_filters=num_filters,
                                          strides=strides,
                                          dueling=False)
        with tf.compat.v1.variable_scope(
                'DiscreteCNNQFunction/Sequential/SimpleMLPModel', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # set the model's return_var to all ones
        return_var.load(tf.ones_like(return_var).eval())

        output1 = self.sess.run(qf.q_vals, feed_dict={qf.input: [self.obs]})

        h_data = pickle.dumps(qf)
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            qf_pickled = pickle.loads(h_data)
            output2 = sess.run(qf_pickled.q_vals,
                               feed_dict={qf_pickled.input: [self.obs]})

        assert np.array_equal(output1, output2)

    def test_invalid_obs_shape(self, obs_dim):
        box_env = MetaRLEnv(DummyDiscreteEnv(obs_dim=obs_dim))
        with pytest.raises(ValueError):
            DiscreteCNNQFunction(env_spec=box_env.spec,
                                 filters=((5, (3, 3)), ),
                                 strides=(2, ),
                                 dueling=False)
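
# The methods above reference self.env, self.sess and self.obs, but the
# snippet omits the class scaffolding and the @pytest.mark.parametrize
# decorators that supply filter_dims/num_filters/filters/strides/obs_dim.
# A minimal sketch of the fixture they appear to assume; TfGraphTestCase,
# MetaRLEnv and DummyDiscretePixelEnv are taken from metarl's test
# utilities, and this reconstruction is an assumption, not part of the
# original file:
class TestDiscreteCNNQFunction(TfGraphTestCase):

    def setup_method(self):
        super().setup_method()  # TfGraphTestCase creates self.graph/self.sess
        self.env = MetaRLEnv(DummyDiscretePixelEnv())
        self.obs = self.env.reset()
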
Example #7

def run_task(snapshot_config, variant_data, *_):
    """Run task.

    Args:
        snapshot_config (metarl.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
        variant_data (dict): Custom arguments for the task.
        *_ (object): Ignored by this function.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)  # random number of no-ops at reset
        env = MaxAndSkip(env, skip=4)  # repeat each action over 4 frames
        env = EpisodicLife(env)  # treat loss of a life as end of episode
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)  # press FIRE at reset where required
        env = Grayscale(env)
        env = Resize(env, 84, 84)  # downscale frames to 84x84
        env = ClipReward(env)  # clip rewards to {-1, 0, 1}
        env = StackFrames(env, 4)  # stack the last 4 frames

        env = TfEnv(env)

        replay_buffer = SimpleReplayBuffer(
            env_spec=env.spec,
            size_in_transitions=variant_data['buffer_size'],
            time_horizon=1)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filter_dims=(8, 4, 3),
                                  num_filters=(32, 64, 64),
                                  strides=(4, 2, 1),
                                  dueling=False)
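        # Note: (filter_dims, num_filters) is the older metarl API for this
        # network; newer versions express the same architecture as
        # filters=((32, (8, 8)), (64, (4, 4)), (64, (3, 3))), as in the
        # next example.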

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        epsilon_greedy_strategy = EpsilonGreedyStrategy(
            env_spec=env.spec,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_strategy=epsilon_greedy_strategy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   double_q=False,
                   n_train_steps=500,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
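
# run_task is written to be handed to metarl's run_experiment launcher,
# which builds the snapshot_config and forwards the variant dict. A sketch
# of a typical invocation; the argument values and the variant contents
# are illustrative assumptions:
run_experiment(
    run_task,
    snapshot_mode='last',
    seed=1,
    variant={'buffer_size': int(5e4)},
)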
Example #8

def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_path_length=None):
    """Train DQN on PongNoFrameskip-v4 environment.

    Args:
        ctxt (metarl.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        buffer_size (int): Number of timesteps to store in replay buffer.
        max_path_length (int): Maximum length of a path, after which it is
            considered complete. This is used during testing to minimize the
            memory required to store a single path.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        env = gym.make('PongNoFrameskip-v4')
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4)

        env = MetaRLEnv(env, is_image=True)

        replay_buffer = PathBuffer(capacity_in_transitions=buffer_size)

        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filters=((32, (8, 8)), (64, (4, 4)),
                                           (64, (3, 3))),
                                  strides=(4, 2, 1),
                                  dueling=False)

        policy = DiscreteQfDerivedPolicy(env_spec=env.spec, qf=qf)
        exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                                 policy=policy,
                                                 total_timesteps=num_timesteps,
                                                 max_epsilon=1.0,
                                                 min_epsilon=0.02,
                                                 decay_ratio=0.1)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=exploration_policy,
                   replay_buffer=replay_buffer,
                   qf_lr=1e-4,
                   discount=0.99,
                   min_buffer_size=int(1e4),
                   max_path_length=max_path_length,
                   double_q=False,
                   n_train_steps=500,
                   steps_per_epoch=steps_per_epoch,
                   target_network_update_freq=2,
                   buffer_batch_size=32)

        runner.setup(algo, env)
        runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
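
# dqn_pong expects an ExperimentContext as its first argument, which
# metarl's wrap_experiment decorator supplies when the wrapped function is
# called. A sketch of a launch, assuming the decorator was applied (it is
# not shown in the original snippet):
dqn_pong = wrap_experiment(dqn_pong)
dqn_pong(seed=1, buffer_size=int(5e4))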