def _train(self, experience, weights=None):
  """Computes the training loss for a batch of previously unseen experience.

  `experience` should only contain data points that this agent has not
  trained on before. If it is sampled from a replay buffer, that buffer
  should be cleared between consecutive calls to `train`.

  The `actions_from_reward_layer` variable is refreshed from the train step
  counter, and the loss is then computed by
  `compute_loss_using_reward_layer` when that flag is true and by
  `compute_loss_using_linucb` otherwise.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
    weights: (optional) sample weights, forwarded to the loss computation.

  Returns:
    A `LossInfo` containing the loss *before* the training step is taken.
    Each agent chooses its own method of applying `weights`.
  """
  # Data coming from a replay buffer carries two outer dimensions,
  # [batch_size, time_steps]; they are merged here so that the leading
  # dimension reflects the effective batch size.
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_action, _ = flatten(experience.action, self._action_spec)
  flat_observation, _ = flatten(
      experience.observation, self._time_step_spec.observation)

  splitter = self._observation_and_action_constraint_splitter
  if splitter is not None:
    # Only the observation part is kept; the second output is dropped.
    flat_observation, _ = splitter(flat_observation)
  flat_observation = tf.cast(flat_observation, self._dtype)
  flat_reward = tf.cast(flat_reward, self._dtype)

  # The flag stays true while the step counter is below the number of
  # encoding-network training steps.
  encoder_still_training = tf.less(
      self._train_step_counter, self._encoding_network_num_train_steps)
  tf.compat.v1.assign(self.actions_from_reward_layer, encoder_still_training)

  def _loss_from_reward_layer():
    return self.compute_loss_using_reward_layer(
        flat_observation, flat_action, flat_reward, weights, training=True)

  def _loss_from_linucb():
    return self.compute_loss_using_linucb(
        flat_observation, flat_action, flat_reward, weights, training=True)

  return tf.cond(
      self.actions_from_reward_layer,
      _loss_from_reward_layer,
      _loss_from_linucb)
def _update_mixture_distribution(self, experience):
  """Updates the mixture weights from the rewards found in `experience`.

  The current mixture probabilities are obtained by exponentiating and
  normalizing `self._mixture_weights`. For every sample, `(1 - reward)` is
  divided by the probability of the sub-policy that was chosen for it, the
  results are summed per sub-policy, and one minus that sum is handed to
  `self._update_aggregates`. The inverse temperature is then updated with the
  batch size and the mixture weights are reassigned to
  `reward_aggregates / inverse_temperature`.

  Args:
    experience: A batch of experience whose `policy_info` holds the chosen
      sub-policy index under `mixture_policy.MIXTURE_AGENT_ID`.

  Returns:
    The result of assigning the new value to `self._mixture_weights`.
  """
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  # The policy choice is flattened against the reward spec, as in the
  # original code -- presumably both are one scalar per step; confirm.
  policy_choice, _ = flatten(
      experience.policy_info[mixture_policy.MIXTURE_AGENT_ID],
      self._time_step_spec.reward)
  # Prefer the static batch size; fall back to the dynamic one when unknown.
  batch_size = tf.compat.dimension_value(
      reward.shape[0]) or tf.shape(reward)[0]

  unnormalized_probabilities = tf.exp(self._mixture_weights)
  # The values are outputs of `tf.exp`, hence positive, so their sum equals
  # their L1 norm. The earlier `tf.norm(..., 1)` normalization was
  # immediately overwritten by this one and has been dropped as dead code.
  normalizer = tf.reduce_sum(unnormalized_probabilities)
  probabilities = unnormalized_probabilities / normalizer
  self._summarize_probabilities(probabilities)

  # Look up, for every sample, the probability of the sub-policy it used.
  repeated_probs = tf.tile(
      tf.expand_dims(probabilities, axis=0), [batch_size, 1])
  probs_per_step = tf.gather(repeated_probs, policy_choice, batch_dims=1)

  # A row vector times the one-hot choice matrix sums the per-step terms
  # separately for each sub-policy.
  per_step_update_term = tf.expand_dims(
      (1 - reward) / probs_per_step, axis=0)
  one_hot_policy_choice = tf.one_hot(policy_choice, depth=self._num_agents)
  update_term = 1 - tf.squeeze(
      tf.matmul(per_step_update_term, one_hot_policy_choice))

  self._update_aggregates(update_term)
  self._update_inverse_temperature(batch_size)
  return self._mixture_weights.assign(
      self._variable_collection.reward_aggregates /
      self._variable_collection.inverse_temperature)
def _train(self, experience, weights):
  """Accumulates per-action reward statistics from a batch of experience.

  Rewards are clipped to [0, 1] and grouped by the action that produced
  them. For each action, the sum of its rewards is added to the matching
  entry of `self._alpha` and the sum of `1 - reward` to the matching entry of
  `self._beta`. The train step counter is advanced by `self._batch_size`.

  Args:
    experience: A batch of experience, converted with `self._as_trajectory`.
    weights: Not read by this method.

  Returns:
    A `LossInfo` whose loss is the negated sum of the clipped rewards and
    whose `extra` field is an empty tuple.
  """
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  experience = self._as_trajectory(experience)

  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_reward = tf.clip_by_value(
      flat_reward, clip_value_min=0.0, clip_value_max=1.0)
  flat_action, _ = flatten(experience.action, self._action_spec)

  # `tf.dynamic_partition` yields one tensor per action, so enumerating the
  # result visits every action index exactly once.
  rewards_by_action = tf.dynamic_partition(
      flat_reward, flat_action, self._num_actions)
  for arm, arm_rewards in enumerate(rewards_by_action):
    successes = tf.cast(tf.reduce_sum(arm_rewards), dtype=self._dtype)
    tf.compat.v1.assign_add(self._alpha[arm], successes)
    failures = tf.cast(tf.reduce_sum(1.0 - arm_rewards), dtype=self._dtype)
    tf.compat.v1.assign_add(self._beta[arm], failures)

  self.train_step_counter.assign_add(self._batch_size)
  negative_total_reward = -1. * tf.reduce_sum(flat_reward)
  return tf_agent.LossInfo(loss=negative_total_reward, extra=())
def _train(self, experience, weights=None):
  """Updates the per-action statistics with a batch of unseen experience.

  `experience` should only contain data points that this agent has not
  trained on before. If it is sampled from a replay buffer, that buffer
  should be cleared between consecutive calls to `train`.

  For every action, the stored weight covariance and parameter estimator are
  scaled by `self._gamma` and incremented with the contribution of the
  samples in which that action was taken.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
    weights: Unused.

  Returns:
    A `LossInfo` whose loss is the negated sum of `experience.reward` and
    whose `extra` field is an empty tuple.
  """
  del weights  # unused

  # Data coming from a replay buffer carries two outer dimensions,
  # [batch_size, time_steps]; they are merged here so that the leading
  # dimension reflects the effective batch size.
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_action, _ = flatten(experience.action, self._action_spec)
  flat_observation, _ = flatten(
      experience.observation, self._time_step_spec.observation)

  splitter = self._observation_and_action_constraint_splitter
  if splitter is not None:
    flat_observation, _ = splitter(flat_observation)
  flat_observation = tf.cast(flat_observation, self._dtype)
  flat_reward = tf.cast(flat_reward, self._dtype)

  for arm in range(self._num_actions):
    # Diagonal 0/1 matrix: left-multiplying by it zeroes out the rows of
    # the samples in which another action was taken.
    arm_selector = tf.linalg.tensor_diag(
        tf.cast(tf.equal(flat_action, arm), self._dtype))
    arm_observations = tf.matmul(arm_selector, flat_observation)
    arm_rewards = tf.matmul(arm_selector, tf.reshape(flat_reward, [-1, 1]))

    new_covariance = (
        self._gamma * self._weight_covariances[arm] +
        tf.matmul(arm_observations, arm_observations, transpose_a=True))
    tf.compat.v1.assign(self._weight_covariances[arm], new_covariance)

    new_estimator = (
        self._gamma * self._parameter_estimators[arm] +
        bandit_utils.sum_reward_weighted_observations(
            arm_rewards, arm_observations))
    tf.compat.v1.assign(self._parameter_estimators[arm], new_estimator)

  num_samples = tf.cast(
      tf.compat.dimension_value(tf.shape(flat_reward)[0]), dtype=tf.int64)
  self._train_step_counter.assign_add(num_samples)

  negative_total_reward = -1. * tf.reduce_sum(experience.reward)
  return tf_agent.LossInfo(loss=negative_total_reward, extra=())
def _process_experience_per_arm(self, experience):
  """Prepares experience for training when per-arm features are accepted.

  Elements of experience coming from the replay buffer have two batch
  dimensions, `batch_size` and `time_steps`, the latter being the number of
  driver steps executed in each training loop. Both are merged into a single
  effective batch dimension. The action mask is split off when a splitter is
  configured, a constant-one bias feature is appended to the global
  observation when `self._add_bias` is set, and the features of the chosen
  arm (read from the policy info) are concatenated to the global
  observation. The returned action is all zeros because in the per-arm case
  there is a single reward model to update.

  Args:
    experience: An instance of trajectory. Every element in the trajectory
      has two batch dimensions.

  Returns:
    A tuple of reward, action, observation, and batch_size. All the outputs
    (except `batch_size`) have a single batch dimension of value
    `batch_size`.
  """
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_observation, _ = flatten(
      experience.observation, self.training_data_spec.observation)

  splitter = self._observation_and_action_constraint_splitter
  if splitter is not None:
    flat_observation, _ = splitter(flat_observation)

  batch_size = tf.cast(
      tf.compat.dimension_value(tf.shape(flat_reward)[0]), dtype=tf.int64)

  global_features = flat_observation[bandit_spec_utils.GLOBAL_FEATURE_KEY]
  if self._add_bias:
    # The bias is modelled as an extra feature that is always equal to one.
    bias_column = tf.ones([batch_size, 1], dtype=global_features.dtype)
    global_features = tf.concat([global_features, bias_column], axis=1)

  # Pretending there was only one action, the action field is all zeros.
  zero_action = tf.zeros(shape=[batch_size], dtype=tf.int64)

  # The arm features to train on are the ones recorded in the policy info
  # for the arm that was actually chosen.
  chosen_arm_features, _ = flatten(
      experience.policy_info.chosen_arm_features,
      self.policy.info_spec.chosen_arm_features)

  combined = tf.concat([global_features, chosen_arm_features], axis=1)
  combined = tf.reshape(tf.cast(combined, self._dtype), [batch_size, -1])
  flat_reward = tf.cast(flat_reward, self._dtype)
  return flat_reward, zero_action, combined, batch_size
def _train(self, experience, weights):
  """Runs one gradient step of the reward network on `experience`.

  The experience is flattened to a single batch dimension and the action
  mask is split off when a splitter is configured. If the agent accepts
  per-arm features, the chosen arm's features from the policy info replace
  the per-arm observation and the actions are zeroed. The loss is computed
  under a gradient tape and the gradients, optionally clipped and
  summarized, are applied with the agent's optimizer.

  Args:
    experience: A batch of experience with two outer batch dimensions.
    weights: Sample weights forwarded to `self.loss`.

  Returns:
    The `LossInfo` produced by `self.loss`. It is returned without applying
    any gradients when the reward network has no trainable weights.
  """
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_rewards, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_actions, _ = flatten(experience.action, self._action_spec)
  flat_observations, _ = flatten(
      experience.observation, self.training_data_spec.observation)

  splitter = self._observation_and_action_constraint_splitter
  if splitter is not None:
    flat_observations, _ = splitter(flat_observations)

  if self._accepts_per_arm_features:
    # The features of the arm that was actually chosen are stored in the
    # policy info; they are placed in the per-arm observation field as if
    # there had been a single action, and the action becomes all zeros.
    chosen_arm_features, _ = flatten(
        experience.policy_info.chosen_arm_features,
        self.policy.info_spec.chosen_arm_features)
    flat_observations[bandit_spec_utils.PER_ARM_FEATURE_KEY] = (
        tf.expand_dims(chosen_arm_features, axis=1))
    flat_actions = tf.zeros_like(flat_actions)

  with tf.GradientTape() as tape:
    loss_info = self.loss(
        flat_observations,
        flat_actions,
        flat_rewards,
        weights=weights,
        training=True)
  self.compute_summaries(loss_info.loss)

  trainable = self._reward_network.trainable_weights
  if not trainable:
    logging.info('No variable to train in the agent.')
    return loss_info

  gradients = tape.gradient(loss_info.loss, trainable)
  # `zip` is a one-shot iterator in py3; a tuple allows repeated passes.
  grads_and_vars = tuple(zip(gradients, trainable))
  if self._gradient_clipping is not None:
    grads_and_vars = eager_utils.clip_gradient_norms(
        grads_and_vars, self._gradient_clipping)

  if self._summarize_grads_and_vars:
    step = self.train_step_counter
    eager_utils.add_variables_summaries(grads_and_vars, step)
    eager_utils.add_gradients_summaries(grads_and_vars, step)

  training_lib.apply_gradients(
      self._optimizer, grads_and_vars, global_step=self.train_step_counter)
  return loss_info
def _train(self, experience, weights=None):
  """Trains every sub-agent on the samples its own policy produced.

  The experience is flattened to one batch dimension and split according to
  the sub-policy index stored in the policy info under
  `mixture_policy.MIXTURE_AGENT_ID`. Each sub-agent is trained on a
  single-step trajectory built from its share of the data, after which the
  mixture distribution is updated with the full experience.

  Args:
    experience: A batch of experience, converted with `self._as_trajectory`.
    weights: Unused.

  Returns:
    A `LossInfo` whose loss is the sum of the sub-agents' losses and whose
    `extra` field is an empty tuple.
  """
  del weights  # unused
  experience = self._as_trajectory(experience)

  flatten = nest_utils.flatten_multi_batched_nested_tensors
  add_batch_dim = nest_utils.batch_nested_tensors

  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_action, _ = flatten(experience.action, self._action_spec)
  flat_observation, _ = flatten(
      experience.observation, self._time_step_spec.observation)
  policy_choice, _ = flatten(
      experience.policy_info[mixture_policy.MIXTURE_AGENT_ID],
      self._time_step_spec.reward)
  original_infos, _ = flatten(
      experience.policy_info[mixture_policy.SUBPOLICY_INFO],
      self._original_info_spec)

  def _split_by_agent(nested_tensors):
    # One partition per sub-agent, selected by the recorded policy choice.
    return _dynamic_partition_of_nested_tensors(
        nested_tensors, policy_choice, self._num_agents)

  # As in the original code, the infos are batched with a single call over
  # the whole partition list, while the others are batched per partition.
  infos_by_agent = add_batch_dim(_split_by_agent(original_infos))
  rewards_by_agent = [
      add_batch_dim(part) for part in _split_by_agent(flat_reward)
  ]
  actions_by_agent = [
      add_batch_dim(part) for part in _split_by_agent(flat_action)
  ]
  observations_by_agent = [
      add_batch_dim(part) for part in _split_by_agent(flat_observation)
  ]

  total_loss = 0
  for agent_index in range(self._num_agents):
    agent_rewards = rewards_by_agent[agent_index]
    agent_experience = trajectory.single_step(
        observation=observations_by_agent[agent_index],
        action=actions_by_agent[agent_index],
        policy_info=infos_by_agent[agent_index],
        reward=agent_rewards,
        discount=tf.zeros_like(agent_rewards))
    agent_loss_info = self._agents[agent_index].train(agent_experience)
    total_loss += agent_loss_info.loss

  update_mixture = common.function_in_tf1()(
      self._update_mixture_distribution)
  update_mixture(experience)
  return tf_agent.LossInfo(loss=total_loss, extra=())
def double_batch_pred2(the_model, all_inputs, specs, is_training=False):
  """Applies `the_model` to inputs that carry several outer dimensions.

  The outer dimensions of `all_inputs` (relative to `specs`) are recorded,
  the inputs are flattened to a single batch dimension, the model is called
  on them, and the outer dimensions are restored on the model output with
  one trailing dimension holding everything else.

  Args:
    the_model: Callable accepting the flattened inputs and an `is_training`
      keyword argument.
    all_inputs: Nested tensors with multiple outer dimensions.
    specs: Specs describing the inputs without their outer dimensions.
    is_training: Forwarded to `the_model`.

  Returns:
    The model output reshaped to `(*outer_dims, -1)`.
  """
  leading_dims = nest_utils.get_outer_array_shape(all_inputs, specs)
  flat_inputs, _ = nest_utils.flatten_multi_batched_nested_tensors(
      all_inputs, specs)
  predictions = the_model(flat_inputs, is_training=is_training)
  restored_shape = tuple(leading_dims) + (-1,)
  return tf.reshape(predictions, restored_shape)
def _process_experience_global(self, experience):
  """Prepares experience for training when only global features are used.

  Elements of experience coming from the replay buffer have two batch
  dimensions, `batch_size` and `time_steps`, the latter being the number of
  driver steps executed in each training loop. Both are merged into a single
  effective batch dimension. The action mask is split off when a splitter is
  configured, and a constant-one bias feature is appended to the observation
  when `self._add_bias` is set.

  Args:
    experience: An instance of trajectory. Every element in the trajectory
      has two batch dimensions.

  Returns:
    A tuple of reward, action, observation, and batch_size. All the outputs
    (except `batch_size`) have a single batch dimension of value
    `batch_size`.
  """
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_observation, _ = flatten(
      experience.observation, self.training_data_spec.observation)
  flat_action, _ = flatten(experience.action, self._action_spec)
  batch_size = tf.cast(
      tf.compat.dimension_value(tf.shape(flat_reward)[0]), dtype=tf.int64)

  splitter = self._observation_and_action_constraint_splitter
  if splitter is not None:
    flat_observation, _ = splitter(flat_observation)

  if self._add_bias:
    # The bias is modelled as an extra feature that is always equal to one.
    bias_column = tf.ones([batch_size, 1], dtype=flat_observation.dtype)
    flat_observation = tf.concat([flat_observation, bias_column], axis=1)

  flat_observation = tf.reshape(
      tf.cast(flat_observation, self._dtype), [batch_size, -1])
  flat_reward = tf.cast(flat_reward, self._dtype)
  return flat_reward, flat_action, flat_observation, batch_size
def _train(self, experience, weights):
  """Runs one gradient step of the reward network on `experience`.

  The experience is flattened to a single batch dimension and the action
  mask is split off when a splitter is configured. The loss is computed
  under a gradient tape and checked for inf/nan values, and the gradients,
  optionally clipped and summarized, are applied with the agent's optimizer.

  Args:
    experience: A batch of experience with two outer batch dimensions.
    weights: Sample weights forwarded to `self.loss`.

  Returns:
    The `LossInfo` produced by `self.loss`. It is returned without applying
    any gradients when the reward network has no trainable weights.
  """
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_rewards, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_actions, _ = flatten(experience.action, self._action_spec)
  flat_observations, _ = flatten(
      experience.observation, self._time_step_spec.observation)

  splitter = self._observation_and_action_constraint_splitter
  if splitter is not None:
    flat_observations, _ = splitter(flat_observations)

  with tf.GradientTape() as tape:
    loss_info = self.loss(
        flat_observations,
        flat_actions,
        flat_rewards,
        weights=weights,
        training=True)
  tf.debugging.check_numerics(loss_info[0], 'Loss is inf or nan')
  self.compute_summaries(loss_info.loss)

  trainable = self._reward_network.trainable_weights
  if not trainable:
    logging.info('No variable to train in the agent.')
    return loss_info

  gradients = tape.gradient(loss_info.loss, trainable)
  # `zip` is a one-shot iterator in py3; a tuple allows repeated passes.
  grads_and_vars = tuple(zip(gradients, trainable))
  if self._gradient_clipping is not None:
    grads_and_vars = eager_utils.clip_gradient_norms(
        grads_and_vars, self._gradient_clipping)

  if self._summarize_grads_and_vars:
    step = self.train_step_counter
    eager_utils.add_variables_summaries(grads_and_vars, step)
    eager_utils.add_gradients_summaries(grads_and_vars, step)

  training_lib.apply_gradients(
      self._optimizer, grads_and_vars, global_step=self.train_step_counter)
  return loss_info
def testFlattenMultiBatchedSingleTensor(self):
  """A single [7, 5, 2, 3] tensor flattens to [35, 2, 3] with dims [7, 5]."""
  single_spec = tensor_spec.TensorSpec([2, 3], dtype=tf.float32)
  batched_tensor = self.zeros_from_spec(
      single_spec, batch_size=7, extra_sizes=[5])

  flattened, outer_dims = nest_utils.flatten_multi_batched_nested_tensors(
      batched_tensor, single_spec)

  self.assertEqual(flattened.shape.as_list(), [35, 2, 3])
  self.evaluate(tf.compat.v1.global_variables_initializer())
  outer_dims_value = self.evaluate(outer_dims)
  self.assertAllEqual(outer_dims_value, [7, 5])
def testFlattenMultiBatchedNestedTensorsWithSparseTensor(self):
  """Flattening a nest built from `nest_spec` keeps structure and shapes."""
  if tf.executing_eagerly():
    self.skipTest('Do not check nest processing of data in eager mode. '
                  'Placeholders are not compatible with eager execution.')

  nested_specs = self.nest_spec([2, 3])
  nested_tensors = self.zeros_from_spec(
      nested_specs, batch_size=7, extra_sizes=[5])

  flattened, _ = nest_utils.flatten_multi_batched_nested_tensors(
      nested_tensors, nested_specs)

  tf.nest.assert_same_structure(nested_specs, flattened)

  def _check_shape(tensor):
    self.assertEqual(tensor.shape.as_list(), [35, 2, 3])

  tf.nest.map_structure(_check_shape, flattened)
def process_experience_for_neural_agents(
    experience, observation_and_action_constraint_splitter,
    accepts_per_arm_features, training_data_spec):
  """Processes the experience and prepares it for the network of the agent.

  First the whole experience is flattened to have only one batch dimension.
  Then the action mask is removed if it is there. Finally, if the agent
  accepts per-arm features, the chosen arm features from the policy info are
  copied in place of the per-arm observation, the action is replaced by
  zeros, and the `num_actions` feature (when present) is set to ones.

  Note that in the per-arm case the observation mapping obtained after
  flattening/splitting is modified in place.

  Args:
    experience: The experience coming from the replay buffer.
    observation_and_action_constraint_splitter: If the agent accepts action
      masks, this function splits the mask from the observation.
    accepts_per_arm_features: Whether the agent accepts per-arm features.
    training_data_spec: The data spec describing what the agent expects.

  Returns:
    A tuple of (observation, action, reward) tensors, in that order, to be
    consumed by the train function of the neural agent.
  """
  flattened_experience, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience, training_data_spec)

  observation = flattened_experience.observation
  action = flattened_experience.action
  reward = flattened_experience.reward

  if observation_and_action_constraint_splitter is not None:
    observation, _ = observation_and_action_constraint_splitter(
        observation)
  if accepts_per_arm_features:
    # The arm observation we train on needs to be copied from the respective
    # policy info field to the per arm observation field. Pretending there was
    # only one action, we fill the action field with zeros.
    chosen_arm_features = flattened_experience.policy_info.chosen_arm_features
    observation[
        bandit_spec_utils.PER_ARM_FEATURE_KEY] = tf.nest.map_structure(
            lambda t: tf.expand_dims(t, axis=1), chosen_arm_features)
    action = tf.zeros_like(action)
    if bandit_spec_utils.NUM_ACTIONS_FEATURE_KEY in observation:
      # This change is not crucial but since in training there will be only
      # one action per sample, it's good to follow the convention that the
      # feature value for `num_actions` be less than or equal to the maximum
      # available number of actions.
      observation[
          bandit_spec_utils.NUM_ACTIONS_FEATURE_KEY] = tf.ones_like(
              observation[bandit_spec_utils.NUM_ACTIONS_FEATURE_KEY])

  return observation, action, reward
def _flattened_multibatch_tensor(
    self, original_tensor: types.Tensor) -> types.Tensor:
  """Merges the two leading dimensions of a tensor into one.

  Args:
    original_tensor: Input tensor of shape [batch_size, tile, dim].

  Returns:
    Flattened tensor with the outer dimension (batch_size * tile).
  """
  # Everything after the first two dimensions is treated as the element
  # shape; the two leading dimensions are the ones being merged.
  element_shape = original_tensor.shape[2:]
  element_spec = tf.TensorSpec(
      shape=element_shape, dtype=original_tensor.dtype)
  merged, _ = nest_utils.flatten_multi_batched_nested_tensors(
      original_tensor, element_spec)
  return merged
def testFlattenMultiBatchedNestedTensors(self):
  """A nest with outer dims [7, 5] flattens to batch 35, reporting [7, 5]."""
  nested_specs = self.nest_spec([2, 3])
  nested_tensors = self.zeros_from_spec(
      nested_specs, batch_size=7, extra_sizes=[5])

  flattened, outer_dims = nest_utils.flatten_multi_batched_nested_tensors(
      nested_tensors, nested_specs)

  tf.nest.assert_same_structure(nested_specs, flattened)

  def _check_shape(tensor):
    self.assertEqual(tensor.shape.as_list(), [35, 2, 3])

  tf.nest.map_structure(_check_shape, flattened)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  outer_dims_value = self.evaluate(outer_dims)
  self.assertAllEqual(outer_dims_value, [7, 5])
def testFlattenMultiBatchedNestedTensorsWithPartiallyKnownShape(self):
  """Placeholders flatten to [None, 2, 3], keeping the nest structure."""
  if tf.executing_eagerly():
    self.skipTest('Do not check nest processing of data in eager mode. '
                  'Placeholders are not compatible with eager execution.')

  nested_specs = self.nest_spec([2, 3], include_sparse=False)
  nested_placeholders = self.placeholders_from_spec(nested_specs)

  flattened, _ = nest_utils.flatten_multi_batched_nested_tensors(
      nested_placeholders, nested_specs)

  tf.nest.assert_same_structure(nested_specs, flattened)

  def _check_shape(tensor):
    self.assertEqual(tensor.shape.as_list(), [None, 2, 3])

  tf.nest.map_structure(_check_shape, flattened)
def testFlattenMultiBatchedNestedTensorsWithPartiallyKnownSparseTensor(self):
  """Placeholders with an unknown inner dim keep structure when flattened.

  Sparse tensors are only checked for rank 3; all other tensors must have
  the static shape [None, 2, None].
  """
  if tf.executing_eagerly():
    self.skipTest('Do not check nest processing of data in eager mode. '
                  'Placeholders are not compatible with eager execution.')

  nested_specs = self.nest_spec([2, None])
  nested_placeholders = self.placeholders_from_spec(nested_specs)

  flattened, _ = nest_utils.flatten_multi_batched_nested_tensors(
      nested_placeholders, nested_specs)

  tf.nest.assert_same_structure(nested_specs, flattened)

  def _check_shape(tensor):
    if not isinstance(tensor, tf.SparseTensor):
      self.assertEqual(tensor.shape.as_list(), [None, 2, None])
    else:
      self.assertEqual(tensor.shape.rank, 3)

  tf.nest.map_structure(_check_shape, flattened)
def _train(self, experience, weights=None):
  """Updates the per-action statistics with a batch of unseen experience.

  `experience` should only contain data points that this agent has not
  trained on before. If it is sampled from a replay buffer, that buffer
  should be cleared between consecutive calls to `train`.

  For every action, the sample count is incremented by the number of samples
  in which it was taken. When the resulting count is positive, the stored
  covariance matrix, data vector and eigendecomposition are replaced by the
  output of `update_a_and_b_with_forgetting`; otherwise they are written
  back unchanged.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
    weights: Unused.

  Returns:
    A `LossInfo` whose loss is the negated sum of `experience.reward` and
    whose `extra` field is an empty tuple.
  """
  del weights  # unused

  # Data coming from a replay buffer carries two outer dimensions,
  # [batch_size, time_steps]; they are merged here so that the leading
  # dimension reflects the effective batch size.
  flatten = nest_utils.flatten_multi_batched_nested_tensors
  flat_reward, _ = flatten(experience.reward, self._time_step_spec.reward)
  flat_action, _ = flatten(experience.action, self._action_spec)
  flat_observation, _ = flatten(
      experience.observation, self._time_step_spec.observation)
  flat_observation = tf.cast(flat_observation, self._dtype)
  flat_reward = tf.cast(flat_reward, self._dtype)

  for arm in range(self._num_actions):
    # Diagonal 0/1 matrix: left-multiplying by it zeroes out the rows of
    # the samples in which another action was taken.
    arm_selector = tf.linalg.tensor_diag(
        tf.cast(tf.equal(flat_action, arm), self._dtype))
    arm_observations = tf.matmul(arm_selector, flat_observation)
    arm_rewards = tf.matmul(arm_selector, tf.reshape(flat_reward, [-1, 1]))

    # The count must be incremented before it is read back below.
    samples_in_batch = tf.reduce_sum(arm_selector)
    tf.compat.v1.assign_add(self._num_samples_list[arm], samples_in_batch)
    samples_so_far = self._num_samples_list[arm].read_value()

    # Update the matrix A and b.
    # pylint: disable=cell-var-from-loop
    def _updated_statistics():
      return update_a_and_b_with_forgetting(
          self._cov_matrix_list[arm],
          self._data_vector_list[arm],
          arm_rewards,
          arm_observations,
          self._gamma,
          self._use_eigendecomp)

    def _current_statistics():
      return (self._cov_matrix_list[arm],
              self._data_vector_list[arm],
              self._eig_vals_list[arm],
              self._eig_matrix_list[arm])

    new_cov, new_data_vector, new_eig_vals, new_eig_matrix = tf.cond(
        tf.squeeze(samples_so_far) > 0,
        _updated_statistics,
        _current_statistics)

    tf.compat.v1.assign(self._cov_matrix_list[arm], new_cov)
    tf.compat.v1.assign(self._data_vector_list[arm], new_data_vector)
    tf.compat.v1.assign(self._eig_vals_list[arm], new_eig_vals)
    tf.compat.v1.assign(self._eig_matrix_list[arm], new_eig_matrix)

  num_samples = tf.cast(
      tf.compat.dimension_value(tf.shape(flat_reward)[0]), dtype=tf.int64)
  self._train_step_counter.assign_add(num_samples)

  negative_total_reward = -1. * tf.reduce_sum(experience.reward)
  return tf_agent.LossInfo(loss=negative_total_reward, extra=())
def _distributed_train_step(self, experience, weights=None):
  """Distributed train fn to be passed as input to run().

  Flattens the experience to a single batch dimension, advances the train
  step counter by the number of flattened samples, and, for every action,
  computes this replica's contribution to the covariance matrix and data
  vector. The contributions are summed across replicas inside a
  `merge_call` and added to the stored per-action variables; when
  `self._use_eigendecomp` is set, the eigendecomposition of the covariance
  matrix is stored as well.

  Args:
    experience: A batch of experience with two outer batch dimensions.
    weights: Unused.

  Returns:
    A `LossInfo` whose loss is the negated sum of `experience.reward` and
    whose `extra` field is an empty tuple.
  """
  del weights  # unused
  reward, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience.reward, self._time_step_spec.reward)
  action, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience.action, self._action_spec)
  observation, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience.observation, self._time_step_spec.observation)

  # Drop the action mask, if the agent is configured with a splitter.
  if self._observation_and_action_constraint_splitter is not None:
    observation, _ = self._observation_and_action_constraint_splitter(
        observation)
  # Force the observation into a matrix with `self._context_dim` columns.
  observation = tf.reshape(observation, [-1, self._context_dim])
  observation = tf.cast(observation, self._dtype)
  reward = tf.cast(reward, self._dtype)

  # Increase the step counter.
  # NOTE(review): `tf.compat.dimension_value` receives a tensor here rather
  # than a `Dimension`; it looks like a pass-through in that case -- confirm.
  batch_size = tf.cast(tf.compat.dimension_value(tf.shape(reward)[0]),
                       dtype=tf.int64)
  self._train_step_counter.assign_add(batch_size)

  for k in range(self._num_actions):
    # Diagonal 0/1 matrix: left-multiplying by it zeroes out the rows of
    # the samples in which an action other than `k` was taken.
    diag_mask = tf.linalg.tensor_diag(
        tf.cast(tf.equal(action, k), self._dtype))
    observations_for_arm = tf.matmul(diag_mask, observation)
    rewards_for_arm = tf.matmul(diag_mask, tf.reshape(reward, [-1, 1]))

    # Compute local updates for the matrix A and b of this arm.
    cov_matrix_local_udpate = tf.matmul(observations_for_arm,
                                        observations_for_arm,
                                        transpose_a=True)
    data_vector_local_update = bandit_utils.sum_reward_weighted_observations(
        rewards_for_arm, observations_for_arm)

    def _merge_fn(strategy, per_replica_cov_matrix_update,
                  per_replica_data_vector_update):
      """Merge the per-replica-updates."""
      # Reduce the per-replica-updates using SUM.
      reduced_cov_matrix_updates = strategy.reduce(
          tf.distribute.ReduceOp.SUM,
          per_replica_cov_matrix_update,
          axis=None)
      reduced_data_vector_updates = strategy.reduce(
          tf.distribute.ReduceOp.SUM,
          per_replica_data_vector_update,
          axis=None)

      # Adds `t` to variable `v` in place.
      def update_fn(v, t):
        v.assign(v + t)

      # Overwrites variable `v` with `t`.
      def assign_fn(v, t):
        v.assign(t)

      # Update the model variables.
      # `k` is read from the enclosing loop; this is safe because the
      # function is invoked by `merge_call` within the same iteration.
      # pylint: disable=cell-var-from-loop
      strategy.extended.update(self._cov_matrix_list[k], update_fn,
                               args=(reduced_cov_matrix_updates, ))
      strategy.extended.update(self._data_vector_list[k], update_fn,
                               args=(reduced_data_vector_updates, ))
      # Compute the eigendecomposition, if needed.
      if self._use_eigendecomp:
        eig_vals, eig_matrix = tf.linalg.eigh(
            self._cov_matrix_list[k])
        strategy.extended.update(self._eig_vals_list[k], assign_fn,
                                 args=(eig_vals, ))
        strategy.extended.update(self._eig_matrix_list[k], assign_fn,
                                 args=(eig_matrix, ))

    # Passes the local_updates to the _merge_fn() above that performs custom
    # computation on the per-replica values.
    # All replicas pause their execution until merge_call() is done and then,
    # execution is resumed.
    replica_context = tf.distribute.get_replica_context()
    replica_context.merge_call(_merge_fn,
                               args=(cov_matrix_local_udpate,
                                     data_vector_local_update))

  loss = -1. * tf.reduce_sum(experience.reward)
  return tf_agent.LossInfo(loss=(loss), extra=())