コード例 #1
0
    def testExp3Update(self, observation_shape, num_actions, action, log_prob,
                       reward, learning_rate):
        """Verify that one `train` call shifts the agent weights as expected.

        The expected per-action shift is built by summing the values returned
        by `exp3_agent.exp3_update_value(reward, log_prob)` into the slot of
        the action each value is paired with. It is then compared against the
        difference between the agent's weights after and before training.

        Args:
          observation_shape: Shape used for the observation spec and passed to
            `_get_initial_and_final_steps`.
          num_actions: Number of actions; sizes the expected-update vector and
            fixes the action spec's upper bound at `num_actions - 1`.
          action: Iterable of action indices, paired element-wise with the
            update values.
          log_prob: Log probabilities forwarded to `exp3_update_value` and
            `_get_action_step`.
          reward: Rewards forwarded to `exp3_update_value` and
            `_get_initial_and_final_steps`.
          learning_rate: Learning rate handed to the `Exp3Agent` constructor.
        """

        # Accumulate the expected weight change, one slot per action. The same
        # action may occur more than once, hence the in-place addition.
        update_values = self.evaluate(
            exp3_agent.exp3_update_value(reward, log_prob))
        expected_update = np.zeros(num_actions)
        for action_index, value in zip(action, update_values):
            expected_update[action_index] += value

        # Specs describing float observations and scalar integer actions in
        # the range [0, num_actions - 1].
        observation_spec = tensor_spec.TensorSpec(
            observation_shape, tf.float32)
        time_step_spec = time_step.time_step_spec(observation_spec)
        action_spec = tensor_spec.BoundedTensorSpec(
            shape=(), dtype=tf.int32, minimum=0, maximum=num_actions - 1)

        # Assemble the experience from the given action, log prob and reward.
        first_step, last_step = _get_initial_and_final_steps(
            observation_shape, reward)
        experience = _get_experience(
            first_step, _get_action_step(action, log_prob), last_step)

        # Build the agent and snapshot its weights on either side of a single
        # training step.
        agent = exp3_agent.Exp3Agent(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            learning_rate=learning_rate)
        self.evaluate(agent.initialize())
        weights_before = self.evaluate(agent.weights)
        self.evaluate(agent.train(experience))
        weights_after = self.evaluate(agent.weights)

        # The observed weight change must match the expected one.
        self.assertAllClose(expected_update, weights_after - weights_before)
コード例 #2
0
 def testExp3UpdateValueShape(self, shape, seed):
     """Check that `exp3_update_value` returns a tensor of the input shape.

     Args:
       shape: Sample shape drawn for both the reward and the log probability;
         also the shape the result is expected to have.
       seed: Value passed to `tf.compat.v1.set_random_seed` before sampling.
     """
     tf.compat.v1.set_random_seed(seed)

     # Rewards are drawn from Uniform(0, 1) and log probs from Normal(0, 1);
     # only the shape of the result is asserted below.
     reward_distribution = tfd.Uniform(0., 1.)
     reward = reward_distribution.sample(shape)
     log_prob_distribution = tfd.Normal(0., 1.)
     log_prob = log_prob_distribution.sample(shape)

     result = exp3_agent.exp3_update_value(reward, log_prob)
     self.assertAllEqual(shape, result.shape)