def _action(self, time_step, policy_state, seed):
  """Chooses actions for the observations carried by `time_step`.

  The observation is optionally split into (observation, mask), shape-checked,
  cast to `self._dtype` and passed through the encoding network. Actions are
  then produced either by the reward layer or by LinUCB, selected with
  `tf.cond` on `self._actions_from_reward_layer`.

  Args:
    time_step: The time step whose `observation` field is consumed.
    policy_state: Returned unchanged inside the resulting `PolicyStep`.
    seed: Not read by this method.

  Returns:
    A `policy_step.PolicyStep` holding the chosen actions, the unchanged
    `policy_state` and the policy info built by
    `policy_utilities.populate_policy_info`.
  """
  raw_observation = time_step.observation
  splitter = self.observation_and_action_constraint_splitter
  if splitter is None:
    observation, mask = raw_observation, None
  else:
    observation, mask = splitter(raw_observation)

  self._check_observation_shape(observation)

  def _to_policy_dtype(tensor):
    return tf.cast(tensor, dtype=self._dtype)

  observation = tf.nest.map_structure(_to_policy_dtype, observation)

  # Pass the observations through the encoding network.
  encoded_observation, _ = self._encoding_network(observation)

  def _reward_layer_branch():
    return self._get_actions_from_reward_layer(encoded_observation, mask)

  def _linucb_branch():
    return self._get_actions_from_linucb(encoded_observation, mask)

  chosen_actions, est_mean_rewards = tf.cond(
      self._actions_from_reward_layer, _reward_layer_branch, _linucb_branch)

  # Per-arm features are only forwarded to the policy info when the policy
  # accepts them; otherwise an empty tuple is passed.
  if self._accepts_per_arm_features:
    arm_observations = observation[bandit_spec_utils.PER_ARM_FEATURE_KEY]
  else:
    arm_observations = ()

  policy_info = policy_utilities.populate_policy_info(
      arm_observations,
      chosen_actions,
      (),  # This variant does not report optimistic reward estimates.
      est_mean_rewards,
      self._emit_policy_info,
      self._accepts_per_arm_features)
  return policy_step.PolicyStep(chosen_actions, policy_state, policy_info)
def _action(self, time_step, policy_state, seed):
  """Chooses actions for the observations carried by `time_step`.

  The observation is optionally split into (observation, mask), encoded by
  the encoding network, and the encoding is cast to `self._dtype`. Actions
  come either from the reward layer or from LinUCB:

  * when a `tf.distribute` strategy is active, the branch is picked in Python
    from `self._distributed_use_reward_layer`;
  * otherwise `tf.cond` selects on `self._actions_from_reward_layer`.

  Args:
    time_step: The time step whose `observation` field is consumed.
    policy_state: Returned unchanged inside the resulting `PolicyStep`.
    seed: Not read by this method.

  Returns:
    A `policy_step.PolicyStep` holding the chosen actions, the unchanged
    `policy_state` and the policy info built by
    `policy_utilities.populate_policy_info`.
  """
  raw_observation = time_step.observation
  splitter = self.observation_and_action_constraint_splitter
  if splitter is None:
    observation, mask = raw_observation, None
  else:
    observation, mask = splitter(raw_observation)

  # Pass the observations through the encoding network.
  encoded_observation, _ = self._encoding_network(observation)
  encoded_observation = tf.cast(encoded_observation, dtype=self._dtype)

  def _reward_layer_branch():
    return self._get_actions_from_reward_layer(encoded_observation, mask)

  def _linucb_branch():
    return self._get_actions_from_linucb(encoded_observation, mask)

  if tf.distribute.has_strategy():
    if self._distributed_use_reward_layer:
      branch_outputs = _reward_layer_branch()
    else:
      branch_outputs = _linucb_branch()
  else:
    branch_outputs = tf.cond(
        self._actions_from_reward_layer, _reward_layer_branch, _linucb_branch)
  chosen_actions, est_mean_rewards, est_rewards_optimistic = branch_outputs

  # Per-arm features are only forwarded to the policy info when the policy
  # accepts them; otherwise an empty tuple is passed.
  if self._accepts_per_arm_features:
    arm_observations = observation[bandit_spec_utils.PER_ARM_FEATURE_KEY]
  else:
    arm_observations = ()

  policy_info = policy_utilities.populate_policy_info(
      arm_observations,
      chosen_actions,
      est_rewards_optimistic,
      est_mean_rewards,
      self._emit_policy_info,
      self._accepts_per_arm_features)
  return policy_step.PolicyStep(chosen_actions, policy_state, policy_info)
def _distribution(self, time_step, policy_state):
  """Builds a deterministic action distribution from per-arm estimates.

  For every action `k` the method solves for `a_inv_x`, derives an estimated
  mean reward and a confidence term, and then scores the arms according to
  `self._exploration_strategy` (optimistic bonus or Normal sampling). The
  arg-max of the scores (respecting the action mask, if any) becomes the
  location of a `Deterministic` distribution.

  Args:
    time_step: The time step whose `observation` field is consumed.
    policy_state: Returned unchanged inside the resulting `PolicyStep`.

  Returns:
    A `policy_step.PolicyStep` holding the `Deterministic` action
    distribution, the unchanged `policy_state` and the policy info built by
    `policy_utilities.populate_policy_info`.

  Raises:
    ValueError: If the global observation is not compatible with shape
      `[None, self._global_context_dim]`, or if the exploration strategy is
      neither `optimistic` nor `sampling`.
  """
  observation = time_step.observation
  if self.observation_and_action_constraint_splitter is not None:
    observation, _ = self.observation_and_action_constraint_splitter(
        observation)
  observation = tf.nest.map_structure(
      lambda o: tf.cast(o, dtype=self._dtype), observation)
  global_observation, arm_observations = self._split_observation(observation)

  if self._add_bias:
    # The bias is added via a constant 1 feature.
    global_observation = tf.concat([
        global_observation,
        tf.ones([tf.shape(global_observation)[0], 1], dtype=self._dtype)
    ], axis=1)

  # Check the shape of the observation matrix. The observations can be
  # batched.
  expected_shape = [None, self._global_context_dim]
  if not global_observation.shape.is_compatible_with(expected_shape):
    raise ValueError(
        'Global observation shape is expected to be {}. Got {}.'.format(
            expected_shape, global_observation.shape.as_list()))
  global_observation = tf.reshape(global_observation,
                                  [-1, self._global_context_dim])

  # The Tikhonov term does not depend on the arm, so build it once instead of
  # once per loop iteration. It is only needed by the conjugate-gradient path.
  regularizer = None
  if not self._use_eigendecomp:
    regularizer = self._tikhonov_weight * tf.eye(
        self._overall_context_dim, dtype=self._dtype)

  est_rewards = []
  confidence_intervals = []
  for k in range(self._num_actions):
    current_observation = self._get_current_observation(
        global_observation, arm_observations, k)
    model_index = policy_utilities.get_model_index(
        k, self._accepts_per_arm_features)
    observation_t = tf.linalg.matrix_transpose(current_observation)

    if self._use_eigendecomp:
      # a_inv_x = Q * diag(1 / (eig_vals + tikhonov_weight)) * Q^T * x^T.
      eig_matrix = self._eig_matrix[model_index]
      eig_vals = self._eig_vals[model_index]
      q_t_b = tf.matmul(eig_matrix, observation_t, transpose_a=True)
      lambda_inv = tf.divide(
          tf.ones_like(eig_vals), eig_vals + self._tikhonov_weight)
      a_inv_x = tf.matmul(eig_matrix,
                          tf.einsum('j,jk->jk', lambda_inv, q_t_b))
    else:
      a_inv_x = linalg.conjugate_gradient_solve(
          self._cov_matrix[model_index] + regularizer, observation_t)

    est_mean_reward = tf.einsum('j,jk->k', self._data_vector[model_index],
                                a_inv_x)
    est_rewards.append(est_mean_reward)

    # Only the diagonal of `current_observation @ a_inv_x` is needed, i.e. the
    # dot product of each observation row with its own solution column.
    # Computing those row-wise avoids materializing the full [batch, batch]
    # product just to discard its off-diagonal entries.
    ci = tf.reshape(
        tf.reduce_sum(
            current_observation * tf.linalg.matrix_transpose(a_inv_x),
            axis=1),
        [-1, 1])
    confidence_intervals.append(ci)

  if self._exploration_strategy == ExplorationStrategy.optimistic:
    optimistic_estimates = [
        tf.reshape(mean_reward, [-1, 1]) + self._alpha * tf.sqrt(confidence)
        for mean_reward, confidence in zip(est_rewards, confidence_intervals)
    ]
    # Keeping the batch dimension during the squeeze, even if batch_size == 1.
    rewards_for_argmax = tf.squeeze(
        tf.stack(optimistic_estimates, axis=-1), axis=[1])
  elif self._exploration_strategy == ExplorationStrategy.sampling:
    mu_sampler = tfd.Normal(
        loc=tf.stack(est_rewards, axis=-1),
        scale=self._alpha * tf.sqrt(
            tf.squeeze(tf.stack(confidence_intervals, axis=-1), axis=1)))
    rewards_for_argmax = mu_sampler.sample()
  else:
    raise ValueError('Exploration strategy %s not implemented.' %
                     self._exploration_strategy)

  mask = constraints.construct_mask_from_multiple_sources(
      time_step.observation, self._observation_and_action_constraint_splitter,
      (), self._num_actions)
  # Both arg-max variants emit actions in the dtype of the first action spec.
  action_dtype = tf.nest.flatten(self._action_spec)[0].dtype
  if mask is not None:
    chosen_actions = policy_utilities.masked_argmax(
        rewards_for_argmax, mask, output_type=action_dtype)
  else:
    chosen_actions = tf.argmax(
        rewards_for_argmax, axis=-1, output_type=action_dtype)

  action_distributions = tfp.distributions.Deterministic(loc=chosen_actions)

  policy_info = policy_utilities.populate_policy_info(
      arm_observations, chosen_actions, rewards_for_argmax,
      tf.stack(est_rewards, axis=-1), self._emit_policy_info,
      self._accepts_per_arm_features)

  return policy_step.PolicyStep(action_distributions, policy_state,
                                policy_info)