Example #1
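This test exercises a single `agent.train` call on `NeuralLinUCBAgent` while the encoding network is still within its `encoding_network_num_train_steps` budget, i.e. while the agent behaves epsilon-greedily rather than reading actions from the reward layer, and asserts that the resulting loss is positive.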
  def testNeuralLinUCBUpdateNumTrainSteps10(self, batch_size=1, context_dim=10):
    """Check NeuralLinUCBAgent updates when behaving like eps-greedy."""

    # Construct a `Trajectory` for the given action, observation, reward.
    num_actions = 5
    initial_step, final_step = _get_initial_and_final_steps(
        batch_size, context_dim)
    action = np.random.randint(num_actions, size=batch_size, dtype=np.int32)
    action_step = _get_action_step(action)
    experience = _get_experience(initial_step, action_step, final_step)

    # Construct an agent and perform the update.
    observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
    time_step_spec = time_step.time_step_spec(observation_spec)
    action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int32, shape=(), minimum=0, maximum=num_actions - 1)
    encoder = DummyNet(observation_spec)
    encoding_dim = 10
    variable_collection = neural_linucb_agent.NeuralLinUCBVariableCollection(
        num_actions, encoding_dim)
    agent = neural_linucb_agent.NeuralLinUCBAgent(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        encoding_network=encoder,
        encoding_network_num_train_steps=10,
        encoding_dim=encoding_dim,
        variable_collection=variable_collection,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001))

    # Build the train op first; in graph mode, evaluating `loss_info` below
    # executes the update once the variables have been initialized.
    loss_info, _ = agent.train(experience)
    self.evaluate(agent.initialize())
    self.evaluate(tf.compat.v1.global_variables_initializer())
    loss_value = self.evaluate(loss_info)
    self.assertGreater(loss_value, 0.0)
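
Both tests call module-level helpers from the surrounding test file that are not reproduced above: `_get_initial_and_final_steps`, `_get_action_step`, `_get_experience`, and `DummyNet`. The bodies below are a minimal sketch of what such helpers can look like, not the actual TF-Agents test implementations; the random observations, the empty policy info, and the Dense(10) encoder (whose output width must match `encoding_dim`) are all assumptions made for illustration:

import numpy as np
import tensorflow as tf

from tf_agents.networks import network
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory


class DummyNet(network.Network):
  """A tiny encoder; its output width must equal `encoding_dim` (10 here)."""

  def __init__(self, observation_spec, name='DummyNet'):
    super(DummyNet, self).__init__(observation_spec, state_spec=(), name=name)
    self._layer = tf.keras.layers.Dense(10)

  def call(self, inputs, step_type=None, network_state=()):
    del step_type  # Unused.
    return self._layer(inputs), network_state


def _get_initial_and_final_steps(batch_size, context_dim):
  # Random contexts and rewards; the initial step carries the observation,
  # the final step carries the reward for the chosen action.
  observation = np.random.uniform(
      size=[batch_size, context_dim]).astype(np.float32)
  reward = np.random.uniform(size=[batch_size]).astype(np.float32)
  initial_step = ts.TimeStep(
      step_type=tf.fill([batch_size], ts.StepType.FIRST),
      reward=tf.zeros([batch_size], dtype=tf.float32),
      discount=tf.ones([batch_size], dtype=tf.float32),
      observation=tf.constant(observation))
  final_step = ts.TimeStep(
      step_type=tf.fill([batch_size], ts.StepType.LAST),
      reward=tf.constant(reward),
      discount=tf.zeros([batch_size], dtype=tf.float32),
      observation=tf.constant(observation))
  return initial_step, final_step


def _get_action_step(action):
  # Empty policy info for brevity; the real test uses a richer PolicyInfo.
  return policy_step.PolicyStep(action=tf.constant(action), state=(), info=())


def _get_experience(initial_step, action_step, final_step):
  single_experience = trajectory.Trajectory(
      step_type=initial_step.step_type,
      observation=initial_step.observation,
      action=action_step.action,
      policy_info=action_step.info,
      next_step_type=final_step.step_type,
      reward=final_step.reward,
      discount=final_step.discount)
  # Add a time dimension so the agent sees shape [batch_size, 1, ...].
  return tf.nest.map_structure(lambda x: tf.expand_dims(x, 1),
                               single_experience)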
Example #2
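This test checkpoints a `NeuralLinUCBVariableCollection`, flips one of its variables, and verifies that restoring the checkpoint brings the variable back to its saved value.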
  def testInitializeRestoreVariableCollection(self):
    if not tf.executing_eagerly():
      self.skipTest('Test only works in eager mode.')
    num_actions = 5
    encoding_dim = 7
    variable_collection = neural_linucb_agent.NeuralLinUCBVariableCollection(
        num_actions=num_actions, encoding_dim=encoding_dim)
    # Create and initialize the collection's variables before saving them.
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(variable_collection.num_samples_list)
    checkpoint = tf.train.Checkpoint(variable_collection=variable_collection)
    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, 'checkpoint')
    checkpoint.save(file_prefix=checkpoint_prefix)

    # Flip a boolean variable so that a successful restore is observable.
    variable_collection.actions_from_reward_layer.assign(False)

    # Restoring should bring `actions_from_reward_layer` back to its saved
    # value of True.
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    checkpoint_load_status = checkpoint.restore(latest_checkpoint)
    self.evaluate(checkpoint_load_status.initialize_or_restore())
    self.assertEqual(
        self.evaluate(variable_collection.actions_from_reward_layer), True)
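
`tf.train.Checkpoint` tracks the collection's variables by object identity rather than by name, so the round trip works without any variable-name bookkeeping. `restore` returns a status object whose `initialize_or_restore` runs the restore ops (and initializes anything absent from the checkpoint); under eager execution the restore happens as soon as the matching variables exist, which is why the test pins itself to eager mode. Because `actions_from_reward_layer` was flipped to False only after the save, the final assertion confirms that the restore reinstated the saved value of True.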