def testNeuralLinUCBUpdateNumTrainSteps10(self, batch_size=1, context_dim=10):
    """Verify that one training call on NeuralLinUCBAgent gives a positive loss.

    The agent is configured with `encoding_network_num_train_steps=10` and is
    trained once on a single experience built from a randomly drawn action.

    NOTE(review): the previous docstring described this as the agent
    "behaving like eps-greedy"; that presumably follows from the
    `encoding_network_num_train_steps` setting -- confirm against the agent
    implementation.

    Args:
        batch_size: Number of samples in the generated experience.
        context_dim: Size of the observation (context) vector.
    """
    num_actions = 5
    encoding_dim = 10

    # Build the experience: initial/final steps plus a random action.
    first_step, last_step = _get_initial_and_final_steps(
        batch_size, context_dim)
    sampled_action = np.random.randint(
        num_actions, size=batch_size, dtype=np.int32)
    experience = _get_experience(
        first_step, _get_action_step(sampled_action), last_step)

    # Specs describing what the agent observes and which actions it may take.
    obs_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
    ts_spec = time_step.time_step_spec(obs_spec)
    bounded_action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int32, shape=(), minimum=0, maximum=num_actions - 1)

    # Build the agent itself.
    encoding_net = DummyNet(obs_spec)
    linucb_variables = neural_linucb_agent.NeuralLinUCBVariableCollection(
        num_actions, encoding_dim)
    agent = neural_linucb_agent.NeuralLinUCBAgent(
        time_step_spec=ts_spec,
        action_spec=bounded_action_spec,
        encoding_network=encoding_net,
        encoding_network_num_train_steps=10,
        encoding_dim=encoding_dim,
        variable_collection=linucb_variables,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001))

    # The train op is created first; initializers are run afterwards, and only
    # then is the loss evaluated.
    loss, _ = agent.train(experience)
    self.evaluate(agent.initialize())
    self.evaluate(tf.compat.v1.global_variables_initializer())
    evaluated_loss = self.evaluate(loss)
    self.assertGreater(evaluated_loss, 0.0)
def testInitializeRestoreVariableCollection(self):
    """Check that restoring a checkpoint reverts a modified variable.

    A `NeuralLinUCBVariableCollection` is saved to a checkpoint, its
    `actions_from_reward_layer` variable is then overwritten with False, and
    the latest checkpoint is restored. The test asserts that the variable
    evaluates to True after the restore. Skipped outside eager mode.
    """
    if not tf.executing_eagerly():
        self.skipTest('Test only works in eager mode.')

    collection = neural_linucb_agent.NeuralLinUCBVariableCollection(
        num_actions=5, encoding_dim=7)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(collection.num_samples_list)

    # Save the current state of the collection.
    save_dir = self.get_temp_dir()
    ckpt = tf.train.Checkpoint(variable_collection=collection)
    ckpt.save(file_prefix=os.path.join(save_dir, 'checkpoint'))

    # Change the variable so that a successful restore is observable.
    collection.actions_from_reward_layer.assign(False)

    # Restore from the most recent checkpoint in the directory.
    newest_ckpt_path = tf.train.latest_checkpoint(save_dir)
    load_status = ckpt.restore(newest_ckpt_path)
    self.evaluate(load_status.initialize_or_restore())

    restored_flag = self.evaluate(collection.actions_from_reward_layer)
    self.assertEqual(restored_flag, True)