def wrap(self, input_action):
  """
  Args:
    input_action (dict): nested tensor action produced by the neural net.
      Dictionary keys are those marked True in 'to_learn'.

  Returns:
    actions (dict): nested tensor action which includes all action components
      expected by the GKP class.
  """
  # step counter to follow the script of periodicity 'period'
  i = self._env._elapsed_steps % self.period
  out_shape = nest_utils.get_outer_shape(input_action, self._action_spec)

  action = {}
  for a in self.to_learn.keys():
    C1 = self.use_mask and self.mask[a][i] == 0
    C2 = not self.to_learn[a]
    if C1 or C2:  # if not learning: replicate the scripted action
      action[a] = common.replicate(self.script[a][i], out_shape)
    else:  # if learning: rescale the input tensor
      action[a] = input_action[a] * self.scale[a]
      if self.learn_residuals:
        action[a] += common.replicate(self.script[a][i], out_shape)
  return action
def _run(self, time_step=None, policy_state=None, num_episodes=None,
         maximum_iterations=None):
  """See `run()` docstring for details."""
  if time_step is None:
    time_step = self.env.reset()
  if policy_state is None:
    policy_state = self.policy.get_initial_state(self.env.batch_size)

  # Batch dim should be first index of tensors during data collection.
  batch_dims = nest_utils.get_outer_shape(time_step, self.env.time_step_spec())
  counter = tf.zeros(batch_dims, tf.int32)
  num_episodes = num_episodes or self._num_episodes

  [_, time_step, policy_state] = tf.nest.map_structure(
      tf.stop_gradient,
      tf.while_loop(cond=self._loop_condition_fn(num_episodes),
                    body=self._loop_body_fn(),
                    loop_vars=[counter, time_step, policy_state],
                    parallel_iterations=1,
                    maximum_iterations=maximum_iterations,
                    name='driver_loop'))
  return time_step, policy_state
def compute_value(self, time_steps):
  nest_utils.assert_same_structure(time_steps, self.time_step_spec)

  # get the batch size from the time steps and sample actions from the policy
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
  policy_state = self._train_policy.get_initial_state(batch_size)
  action_distribution = self._train_policy.distribution(
      time_steps, policy_state=policy_state).action
  actions = tf.nest.map_structure(
      lambda d: d.sample(), action_distribution)  # this is a deterministic policy

  observations = time_steps.observation
  pred_input = (observations, actions)

  # TODO(architsh): check if the minimum should be taken before or after average
  critic_pred_1, _ = self._critic_network_1(pred_input, None, training=False)
  critic_pred_2, _ = self._critic_network_2(pred_input, None, training=False)

  # final value calculation
  value = tf.minimum(critic_pred_1, critic_pred_2)
  return value
def compute_value(self, time_steps):
  nest_utils.assert_same_structure(time_steps, self.time_step_spec)

  # get the batch size from the time steps and sample actions from the policy
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
  policy_state = self._train_policy.get_initial_state(batch_size)
  action_distribution = self._train_policy.distribution(
      time_steps, policy_state=policy_state).action
  actions = tf.nest.map_structure(
      lambda d: d.sample(self._num_action_samples), action_distribution)

  # repeat the observations so they line up with the multiple action samples
  observations = tf.tile(time_steps.observation,
                         tf.constant([self._num_action_samples, 1]))
  actions = tf.reshape(actions, [-1, actions.shape[-1]])
  pred_input = (observations, actions)

  if self._critic_network_no_entropy_1 is None:
    critic_pred_1, _ = self._critic_network_1(pred_input, None, training=False)
    critic_pred_2, _ = self._critic_network_2(pred_input, None, training=False)
  else:
    critic_pred_1, _ = self._critic_network_no_entropy_1(
        pred_input, None, training=False)
    critic_pred_2, _ = self._critic_network_no_entropy_2(
        pred_input, None, training=False)

  # final value calculation: pessimistic minimum over the two critics,
  # averaged over the action samples
  critic_pred = tf.minimum(critic_pred_1, critic_pred_2)
  critic_pred = tf.reshape(critic_pred, [self._num_action_samples, -1])
  value = tf.reduce_mean(critic_pred, axis=0)
  return value
def test_planning_policy_action_shape(
    observation_space, action_space,
    optimiser_policy_trajectory_optimiser_factory):
  """
  Ensure action shape of the planning policy is correct.
  """
  population_size = 10
  number_of_particles = 1
  horizon = 7
  time_step_space = time_step_spec(observation_space)

  trajectory_optimiser, environment_model = get_optimiser_and_environment_model(
      time_step_space,
      observation_space,
      action_space,
      population_size=population_size,
      number_of_particles=number_of_particles,
      horizon=horizon,
      optimiser_policy_trajectory_optimiser_factory=(
          optimiser_policy_trajectory_optimiser_factory),
  )

  # remember the time step comes from the real environment with batch size 1
  observation = create_uniform_distribution_from_spec(
      observation_space).sample(sample_shape=(1,))
  time_step = restart(observation, batch_size=1)

  planning_policy = PlanningPolicy(environment_model, trajectory_optimiser)

  policy_step = planning_policy.action(time_step)
  action = policy_step.action

  assert get_outer_shape(action, action_space) == (1,)
  assert action_space.is_compatible_with(action[0])
def _sample_and_transpose_actions_and_log_probs(
    self,
    time_steps: ts.TimeStep,
    num_action_samples: int,
    training: Optional[bool] = False
) -> Tuple[types.Tensor, types.Tensor]:
  """Samples actions and corresponding log probabilities from the policy."""
  # Get the raw action distribution from the policy.
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
  policy_state = self._train_policy.get_initial_state(batch_size)
  if training:
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state=policy_state).action
  else:
    action_distribution = self._policy.distribution(
        time_steps, policy_state=policy_state).action

  actions = tf.nest.map_structure(
      lambda d: d.sample(num_action_samples, seed=self._action_seed_stream()),
      action_distribution)
  log_pi = common.log_probability(action_distribution, actions,
                                  self.action_spec)

  # Swap the first two axes for a [batch, self._num_cql_samples, ...] shape.
  actions = self._transpose_tile_and_batch_dims(actions)
  log_pi = self._transpose_tile_and_batch_dims(log_pi)
  return actions, log_pi
def value_estimation_loss(self, time_steps, returns, weights,
                          debug_summaries=False):
  """Computes the value estimation loss for actor-critic training.

  All tensors should have a single batch dimension.

  Args:
    time_steps: A batch of timesteps.
    returns: Per-timestep returns for the value function to predict. (Should
      come from the TD-lambda computation.)
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights. Includes a mask for invalid timesteps.
    debug_summaries: True if debug summaries should be created.

  Returns:
    value_estimation_loss: A scalar value estimation loss.
  """
  observation = time_steps.observation
  if debug_summaries:
    observation_list = tf.nest.flatten(observation)
    show_observation_index = len(observation_list) != 1
    for i, single_observation in enumerate(observation_list):
      observation_name = ('observations_{}'.format(i)
                          if show_observation_index else 'observations')
      tf.compat.v2.summary.histogram(
          name=observation_name,
          data=single_observation,
          step=self.train_step_counter)

  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
  policy_state = self._collect_policy.get_initial_state(batch_size=batch_size)

  value_preds, unused_policy_state = self._collect_policy.apply_value_network(
      time_steps.observation,
      time_steps.step_type,
      policy_state=policy_state)

  value_estimation_error = tf.math.squared_difference(returns, value_preds)
  value_estimation_error *= weights

  value_estimation_loss = (
      tf.reduce_mean(input_tensor=value_estimation_error) *
      self._value_pred_loss_coef)

  if debug_summaries:
    tf.compat.v2.summary.scalar(
        name='value_pred_avg',
        data=tf.reduce_mean(input_tensor=value_preds),
        step=self.train_step_counter)
    tf.compat.v2.summary.histogram(
        name='value_preds', data=value_preds, step=self.train_step_counter)
    tf.compat.v2.summary.histogram(
        name='value_estimation_error',
        data=value_estimation_error,
        step=self.train_step_counter)

  if self._check_numerics:
    value_estimation_loss = tf.debugging.check_numerics(
        value_estimation_loss, 'value_estimation_loss')

  return value_estimation_loss
def __call__(self, observation, actions=None):
  """Returns the probability of the input actions being feasible."""
  batch_dims = nest_utils.get_outer_shape(observation,
                                          self._time_step_spec.observation)
  shape = tf.concat(
      [batch_dims,
       tf.constant(self._num_actions, shape=[1], dtype=batch_dims.dtype)],
      axis=-1)
  return tf.ones(shape)
def _action(self, time_step, policy_state, seed):
  del seed
  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  action = tf.nest.map_structure(
      lambda t: common.replicate(t, outer_shape), self._action_value)
  return policy_step.PolicyStep(action, policy_state, self._policy_info)
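# Illustrative sketch (not part of the policy above), assuming only that
# `tf_agents.utils.common.replicate` tiles a single value to the requested
# outer (batch) shape, which is how `_action` above broadcasts its fixed
# `self._action_value` across the batch. The action value and batch size used
# here are hypothetical.
import tensorflow as tf
from tf_agents.utils import common


def _demo_replicate_fixed_action():
  fixed_action = tf.constant([0.5, -0.5], dtype=tf.float32)  # one action of shape [2]
  outer_shape = tf.constant([4], dtype=tf.int32)              # batch of 4 time steps
  batched_action = common.replicate(fixed_action, outer_shape)
  return batched_action.shape  # TensorShape([4, 2])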
def _run(self, time_step=None, policy_state=None, maximum_iterations=None):
  """See `run()` docstring for details."""
  if time_step is None:
    time_step = self.env.current_time_step()
  if policy_state is None:
    policy_state = self.policy.get_initial_state(self.env.batch_size)

  # Batch dim should be first index of tensors during data collection.
  batch_dims = nest_utils.get_outer_shape(time_step, self.env.time_step_spec())
  counter = tf.zeros(batch_dims, tf.int32)

  [_, time_step, policy_state] = tf.while_loop(
      cond=self._loop_condition_fn(),
      body=self._loop_body_fn(),
      loop_vars=[counter, time_step, policy_state],
      back_prop=False,
      parallel_iterations=1,
      maximum_iterations=maximum_iterations,
      name='driver_loop')
  return time_step, policy_state
def actor_loss(self, time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)

    target_q_values = []
    for cn in self._critic_networks:
      target_q_values1, _ = cn(target_input, time_steps.step_type)
      target_q_values.append(target_q_values1)
    # element-wise minimum over the critic ensemble
    target_q_values = tf.reduce_min(target_q_values, axis=0)

    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
    if weights is not None:
      actor_loss *= weights
    actor_loss = tf.reduce_mean(input_tensor=actor_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('actions', actions,
                                       self.train_step_counter)
      common.generate_tensor_summaries('log_pi', log_pi,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pi),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)

      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._train_policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode',
                                         action_distribution.mode(),
                                         self.train_step_counter)
      try:
        common.generate_tensor_summaries('entropy_action',
                                         action_distribution.entropy(),
                                         self.train_step_counter)
      except NotImplementedError:
        pass  # Some distributions do not have an analytic entropy.

    return actor_loss
def value_estimation_loss(self, time_steps, returns, weights):
  """Computes the value estimation loss for actor-critic training.

  All tensors should have a single batch dimension.

  Args:
    time_steps: A batch of timesteps.
    returns: Per-timestep returns for the value function to predict. (Should
      come from the TD-lambda computation.)
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights. Includes a mask for invalid timesteps.

  Returns:
    value_estimation_loss: A scalar value estimation loss.
  """
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
  value_state = self._collect_policy.get_initial_value_state(
      batch_size=batch_size)

  value_preds, _ = self._collect_policy.apply_value_network(
      time_steps.observation, time_steps.step_type, value_state=value_state)

  value_estimation_error = tf.math.squared_difference(returns, value_preds)
  value_estimation_error *= weights

  value_estimation_loss = tf.reduce_mean(input_tensor=value_estimation_error)
  tf.debugging.check_numerics(value_estimation_loss, "Value loss diverged",
                              name="Value_check")
  return value_estimation_loss
def _time_step_to_initial_observation(
    self,
    time_step: TimeStep,
    environment_model: EnvironmentModel,
):
  """
  Construct the initial observation from a time step.

  :param time_step: Initial time step from the real environment with a nominal
      batch size of 1 (because the real environment is assumed not to be
      "batchable").
  :param environment_model: An `EnvironmentModel` is a model of the MDP that
      represents the environment, consisting of a transition, reward,
      termination and initial state distribution model, of which some are
      trainable and some are fixed.
  :return: Initial observation whose first dimension is the appropriate batch
      size.
  """
  observation = time_step.observation
  batch_size = get_outer_shape(observation,
                               environment_model.observation_spec())

  # the time step comes from the real environment
  assert batch_size == (1,), (
      f"batch_size of time_step.observation = {batch_size} and it should be 1")

  initial_observation = tf.repeat(observation, repeats=self._batch_size, axis=0)
  return initial_observation
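# Minimal stand-alone sketch of the batching step at the end of the method
# above: a batch-1 observation from the real environment is repeated along the
# batch axis to match the model environment's batch size. Plain TensorFlow
# only; the observation values and batch size are hypothetical.
import tensorflow as tf


def _demo_repeat_initial_observation():
  observation = tf.constant([[0.1, 0.2, 0.3]])  # shape [1, 3]: batch size 1
  model_batch_size = 5
  initial_observation = tf.repeat(observation, repeats=model_batch_size, axis=0)
  return initial_observation.shape  # TensorShape([5, 3])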
def _action(self, time_step, policy_state, seed):
  seed_stream = tfd.SeedStream(seed=seed, salt='epsilon_boltzmann')
  greedy_action, distribution_step = self._greedy_policy.action_distribution(
      time_step, policy_state, seed_stream)
  action_dist = tf.nest.map_structure(self._apply_temperature,
                                      distribution_step.action)
  boltzmann_action = tf.nest.map_structure(
      lambda d: d.sample(seed=seed_stream()), action_dist)

  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  rng = tf.random.uniform(outer_shape, maxval=1.0, seed=seed_stream(),
                          name='epsilon_rng')
  cond = tf.greater(rng, self._get_epsilon())

  outer_ndims = int(outer_shape.shape[0])
  # TODO: remove it to allow environment multiprocessing
  if outer_ndims >= 2:
    raise ValueError(
        'Only supports batched time steps with a single batch dimension')
  action = tf.compat.v1.where(cond, greedy_action.action, boltzmann_action)

  info = ()
  state = greedy_action.state
  return policy_step.PolicyStep(action, state, info)
def natural_policy_gradient(self, time_steps, policy_steps_, gradient, weights):
  """
  Compute the natural policy gradient wrt the actor_net parameters.

  :param time_steps: batch of TimeSteps with observations for each timestep
  :param policy_steps_: policy info for the time step sampling policy
  :param gradient: vanilla policy gradient computed on the batch
  :param weights: mask for invalid timesteps
  :return: natural gradient as a single flattened vector, and the Lagrange
      coefficient for updating parameters under the KL constraint
  """
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]

  # get policy info before the update
  action_distribution_parameters = policy_steps_.info

  def _kl(params):
    """Compute the KL between the old policy and the policy with the given params."""
    unflatten_tensor(params, self._opt_policy_parameters)
    opt_policy_state = self._opt_policy.get_initial_state(batch_size)
    dists = self._opt_policy.distribution(time_steps, opt_policy_state)
    policy_distribution = dists.action
    kl = self._kl_divergence(time_steps, action_distribution_parameters,
                             policy_distribution)
    return tf.reduce_mean(kl)

  def _hv(vector: tf.Tensor) -> tf.Tensor:
    """Compute the product of a vector with the Hessian of the KL divergence."""
    return hessian_vector_product(_kl, self._opt_policy_parameters, vector)

  flat_grads = flatten_tensors(gradient)

  # sync the optimisation policy with the current policy
  common.soft_variables_update(
      self.policy.variables(),
      self._opt_policy.variables(),  # pylint: disable=not-callable
      tau=1.0,
  )

  # approximate the natural gradient by approximately solving grad = H @ nat_grad
  nat_grad = conjugate_gradient(_hv, flat_grads, max_iter=self._cg_iters)

  # Lagrange coefficient for solving the constrained maximisation
  coeff = tf.sqrt(2.0 * self._max_kl /
                  (tf.transpose(nat_grad) @ _hv(nat_grad) + EPS))

  tf.debugging.check_numerics(nat_grad, "Natural gradient",
                              name="natgrad_check")
  tf.debugging.check_numerics(coeff, "NatGrad lagrange multiplier",
                              name="multiplier_check")
  return nat_grad, coeff
def testGetOuterShapeDynamicShapeBatched(self):
  tensor = tf.placeholder(tf.float32, shape=(None, 1))
  spec = tensor_spec.TensorSpec([1], dtype=tf.float32)

  batch_size = nest_utils.get_outer_shape(tensor, spec)

  with self.test_session() as sess:
    self.assertEqual(
        sess.run(batch_size, feed_dict={tensor: [[0.0]] * 8}), [8])
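# An equivalent eager-mode check, for reference: a sketch assuming TF2 eager
# execution and an installed `tf_agents`. It shows that `get_outer_shape`
# recovers the leading batch dimension of a tensor relative to a spec that
# describes only a single element. The tensor values are illustrative.
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.utils import nest_utils


def _demo_get_outer_shape_batched():
  spec = tensor_spec.TensorSpec([1], dtype=tf.float32)  # one element of shape [1]
  batched_tensor = tf.zeros([8, 1], dtype=tf.float32)   # batch of 8 such elements
  outer_shape = nest_utils.get_outer_shape(batched_tensor, spec)
  return outer_shape  # tf.Tensor([8], ...)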
def _distribution(self, time_step, policy_state):
  """Implementation of `distribution`. Returns a `Categorical` distribution.

  The returned `Categorical` distribution has (unnormalized) probabilities
  `exp(inverse_temperature * weights)`.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: Unused in `CategoricalPolicy`. It is simply passed through.

  Returns:
    A `PolicyStep` named tuple containing:
      `action`: An (optionally nested) tfp.distributions.Distribution
        capturing the distribution of next actions.
      `state`: A policy state tensor for the next call to distribution.
      `info`: Optional side information such as action log probabilities.
  """
  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  logits = (self._inverse_temperature *
            common.replicate(self._weights, outer_shape))
  action_distribution = tfd.Independent(
      tfd.Categorical(logits=logits,
                      dtype=tf.nest.flatten(self.action_spec)[0].dtype))
  return policy_step.PolicyStep(action_distribution, policy_state)
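# Small illustrative sketch of the logits construction in the docstring above:
# scaling fixed weights by an inverse temperature before building a
# `Categorical` gives unnormalized probabilities exp(inverse_temperature *
# weights). The weights and temperature below are hypothetical; only
# TensorFlow Probability is assumed.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def _demo_boltzmann_categorical():
  weights = tf.constant([1.0, 2.0, 0.5])
  inverse_temperature = 2.0
  dist = tfd.Categorical(logits=inverse_temperature * weights)
  # A larger inverse_temperature concentrates mass on the largest weight.
  return dist.probs_parameter()  # softmax of the scaled weights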
def run(self, time_step=None, policy_state=(), maximum_iterations=None):
  """Takes steps in the environment using the policy while updating observers.

  Args:
    time_step: optional initial time_step. If None, it will use the
      current_time_step of the environment. Elements should be shape
      [batch_size, ...].
    policy_state: optional initial state for the policy.
    maximum_iterations: Optional maximum number of iterations of the while
      loop to run. If provided, the cond output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      maximum_iterations.

  Returns:
    time_step: TimeStep named tuple with the final observation, reward, etc.
    policy_state: Tensor with the final step policy state.
  """
  if time_step is None:
    time_step = self._env.current_time_step()

  # Batch dim should be first index of tensors during data collection.
  batch_dims = nest_utils.get_outer_shape(time_step,
                                          self._env.time_step_spec())
  counter = tf.zeros(batch_dims, tf.int32)

  [_, time_step, policy_state] = tf.while_loop(
      cond=self._loop_condition_fn(),
      body=self._loop_body_fn(),
      loop_vars=[counter, time_step, policy_state],
      back_prop=False,
      parallel_iterations=1,
      maximum_iterations=maximum_iterations,
      name='driver_loop')
  return time_step, policy_state
def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                target_q_values, time_steps):
  if self._debug_summaries:
    common.generate_tensor_summaries('actor_loss', actor_loss,
                                     self.train_step_counter)
    common.generate_tensor_summaries('actions', actions,
                                     self.train_step_counter)
    common.generate_tensor_summaries('log_pi', log_pi,
                                     self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='entropy_avg',
        data=-tf.reduce_mean(input_tensor=log_pi),
        step=self.train_step_counter)
    common.generate_tensor_summaries('target_q_values', target_q_values,
                                     self.train_step_counter)

    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state).action
    if isinstance(action_distribution, tfp.distributions.Normal):
      common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                       self.train_step_counter)
      common.generate_tensor_summaries('act_stddev',
                                       action_distribution.scale,
                                       self.train_step_counter)
    elif isinstance(action_distribution, tfp.distributions.Categorical):
      common.generate_tensor_summaries('act_mode',
                                       action_distribution.mode(),
                                       self.train_step_counter)
    common.generate_tensor_summaries('entropy_action',
                                     action_distribution.entropy(),
                                     self.train_step_counter)
def _action(self, time_step, policy_state, seed):
  observation_and_action_constraint_splitter = (
      self.observation_and_action_constraint_splitter)
  outer_dims = nest_utils.get_outer_shape(time_step, self._time_step_spec)

  if observation_and_action_constraint_splitter is not None:
    observation, mask = observation_and_action_constraint_splitter(
        time_step.observation)

    zero_logits = tf.cast(tf.zeros_like(mask), tf.float32)
    masked_categorical = masked.MaskedCategorical(zero_logits, mask)
    action_ = tf.cast(masked_categorical.sample() + self.action_spec.minimum,
                      self.action_spec.dtype)

    # If the action spec says each action should be shaped (1,), add another
    # dimension so the final shape is (B, 1) rather than (B,).
    if self.action_spec.shape.rank == 1:
      action_ = tf.expand_dims(action_, axis=-1)

    policy_info = tensor_spec.sample_spec_nest(self._info_spec,
                                               outer_dims=outer_dims)
  else:
    observation = time_step.observation
    action_ = tensor_spec.sample_spec_nest(self._action_spec, seed=seed,
                                           outer_dims=outer_dims)
    policy_info = tensor_spec.sample_spec_nest(self._info_spec,
                                               outer_dims=outer_dims)

  if self._accepts_per_arm_features:

    def _gather_fn(t):
      return tf.gather(params=t, indices=action_, batch_dims=1)

    chosen_arm_features = tf.nest.map_structure(_gather_fn,
                                                observation['per_arm'])
    policy_info = policy_info._replace(chosen_arm_features=chosen_arm_features)

  # TODO(b/78181147): Investigate why this control dependency is required.
  if time_step is not None:
    with tf.control_dependencies(tf.nest.flatten(time_step)):
      action_ = tf.nest.map_structure(tf.identity, action_)

  if self.emit_log_probability:
    if observation_and_action_constraint_splitter is not None:
      log_probability = masked_categorical.log_prob(
          action_ - self.action_spec.minimum)
    else:
      action_probability = tf.nest.map_structure(_uniform_probability,
                                                 self._action_spec)
      log_probability = tf.nest.map_structure(tf.math.log, action_probability)
    policy_info = policy_step.set_log_probability(policy_info, log_probability)

  step = policy_step.PolicyStep(action_, policy_state, policy_info)
  return step
def _action(self, time_step, policy_state, seed):
  seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
  greedy_action = self._greedy_policy.action(time_step, policy_state)
  random_action = self._random_policy.action(time_step, (), seed_stream())

  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  rng = tf.random.uniform(outer_shape, maxval=1.0, seed=seed_stream(),
                          name='epsilon_rng')
  cond = tf.greater(rng, self._get_epsilon())

  # Selects the action/info from the random policy with probability epsilon.
  # TODO(b/133175894): tf.compat.v1.where only supports a condition which is
  # either a scalar or a vector. Use tf.compat.v2 so that it can support any
  # condition whose leading dimensions are the same as the other operands of
  # tf.where.
  outer_ndims = int(outer_shape.shape[0])
  if outer_ndims >= 2:
    raise ValueError(
        'Only supports batched time steps with a single batch dimension')
  action = tf.nest.map_structure(
      lambda g, r: tf.compat.v1.where(cond, g, r), greedy_action.action,
      random_action.action)

  if greedy_action.info:
    if not random_action.info:
      raise ValueError('Incompatible info field')
    info = nest_utils.where(cond, greedy_action.info, random_action.info)
    # Overwrite bandit policy info type.
    if policy_utilities.has_bandit_policy_type(info, check_for_tensor=True):
      # Generate a mask of the same shape as bandit_policy_type
      # (batch_size, 1). This is the opposite of `cond`, which is a 1-D bool
      # tensor (batch_size,) that is true when the greedy policy was used,
      # otherwise `cond` is false.
      random_policy_mask = tf.reshape(
          tf.logical_not(cond), tf.shape(info.bandit_policy_type))
      bandit_policy_type = policy_utilities.bandit_policy_uniform_mask(
          info.bandit_policy_type, mask=random_policy_mask)
      info = policy_utilities.set_bandit_policy_type(info, bandit_policy_type)
  else:
    if random_action.info:
      raise ValueError('Incompatible info field')
    info = ()

  # The state of the epsilon greedy policy is the state of the underlying
  # greedy policy (the random policy carries no state). It is commonly assumed
  # that the new policy state depends only on the previous state and
  # "time_step"; the action (be it the greedy one or the random one) does not
  # influence the new policy state.
  state = greedy_action.state

  return policy_step.PolicyStep(action, state, info)
def _distribution(self, time_step, policy_state):
  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  action = common.replicate(self._action_value, outer_shape)

  def dist_fn(action):
    """Return a deterministic distribution with all density on the fixed action."""
    return tfp.distributions.Deterministic(loc=action)

  return policy_step.PolicyStep(nest.map_structure(dist_fn, action),
                                policy_state)
def _get_policy_info_and_action(self, time_step):
  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  log_probability = tf.nest.map_structure(
      lambda _: tf.zeros(outer_shape, tf.float32), self._action_spec)
  policy_info = policy_step.set_log_probability(
      self._policy_info, log_probability=log_probability)
  action = tf.nest.map_structure(
      lambda t: common.replicate(t, outer_shape), self._action_value)
  return policy_info, action
def _update(traj):
  if traj.reward.shape:
    outer_shape = nest_utils.get_outer_shape(traj.reward, reward_spec)
    batch_size = outer_shape[0]
    if len(outer_shape) > 1:
      batch_size *= outer_shape[1]
  else:
    batch_size = 1
  return batch_size
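# Stand-alone sketch of the arithmetic in `_update` above: for a trajectory
# whose reward is shaped [batch, time] relative to a scalar reward spec, the
# outer shape has two entries and their product is the number of collected
# transitions. The spec and reward tensor below are hypothetical.
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.utils import nest_utils


def _demo_transition_count():
  reward_spec = tensor_spec.TensorSpec([], tf.float32)  # scalar per-step reward
  reward = tf.zeros([3, 4])                             # 3 environments, 4 steps each
  outer_shape = nest_utils.get_outer_shape(reward, reward_spec)
  batch_size = outer_shape[0] * outer_shape[1]
  return batch_size  # tf.Tensor(12, ...)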
def _action(self, time_step, policy_state, seed):
  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  action = common.replicate(self._next_action, outer_shape)

  self._action_index += 1
  self._action_index %= self._actions.shape[0]
  self._next_action.assign(self._actions[self._action_index])

  return policy_step.PolicyStep(action, policy_state, info=())
def _action(self, time_step, policy_state, seed):
  outer_dims = nest_utils.get_outer_shape(time_step, self._time_step_spec)

  action_ = tensor_spec.sample_spec_nest(
      self._action_spec, seed=seed, outer_dims=outer_dims)

  # TODO(b/78181147): Investigate why this control dependency is required.
  if time_step is not None:
    with tf.control_dependencies(nest.flatten(time_step)):
      action_ = nest.map_structure(tf.identity, action_)
  return policy_step.PolicyStep(action_, policy_state)
def behavior_loss(self, time_steps, actions, weights=None):
  with tf.name_scope('behavior_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    policy_state = self._behavior_policy.get_initial_state(batch_size)
    action_distribution = self._behavior_policy.distribution(
        time_steps, policy_state=policy_state).action
    log_pi = common.log_probability(action_distribution, actions,
                                    self.action_spec)
    return -1.0 * tf.reduce_mean(log_pi)
def _line_search(self, time_steps, policy_steps_, advantages, natural_gradient,
                 coeff, weights):
  """Find new policy parameters by line search in the natural gradient direction."""
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]

  # old policy distribution
  action_distribution_parameters = policy_steps_.info
  actions = policy_steps_.action
  actions_distribution = distribution_spec.nested_distributions_from_specs(
      self._action_distribution_spec,
      action_distribution_parameters["dist_params"])
  act_log_probs = common.log_probability(actions_distribution, actions,
                                         self._action_spec)

  # loss for the old policy
  loss_threshold = self.policy_gradient_loss(
      time_steps,
      actions,
      tf.stop_gradient(act_log_probs),
      tf.stop_gradient(advantages),
      actions_distribution,
      weights,
  )

  policy_params = flatten_tensors(self._actor_net.trainable_variables)

  # try different step sizes; accept the first one that improves the loss and
  # satisfies the KL constraint
  for it in range(self._backtrack_iters):
    new_params = (policy_params -
                  self._backtrack_coeff**it * coeff * natural_gradient)

    unflatten_tensor(new_params, self._opt_policy_parameters)
    opt_policy_state = self._opt_policy.get_initial_state(batch_size)
    dists = self._opt_policy.distribution(time_steps, opt_policy_state)
    new_policy_distribution = dists.action

    kl = tf.reduce_mean(
        self._kl_divergence(time_steps, action_distribution_parameters,
                            new_policy_distribution))
    loss = self.policy_gradient_loss(
        time_steps,
        actions,
        tf.stop_gradient(act_log_probs),
        tf.stop_gradient(advantages),
        new_policy_distribution,
        weights,
    )
    if kl < self._max_kl and loss < loss_threshold:
      return new_params

  # no improvement found
  return policy_params
def _action(self, time_step, policy_state, seed):
  i = policy_state[0] % self.period  # position within the policy period
  out_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)

  action = {}
  for a in self.script:
    A = common.replicate(self.script[a][i], out_shape)
    if a == 'alpha':  # do Markovian feedback
      A *= time_step.observation['msmt'][:, -1, None]
      if policy_state[0] == 0:
        A *= 0
    action[a] = A

  return policy_step.PolicyStep(action, policy_state + 1, self._policy_info)
def _update(traj):
  self._agent.update_observation_normalizer(traj.observation)
  self._agent.update_reward_normalizer(traj.reward)

  if traj.reward.shape:
    outer_shape = nest_utils.get_outer_shape(traj.reward, reward_spec)
    batch_size = outer_shape[0]
    if len(outer_shape) > 1:
      batch_size *= outer_shape[1]
  else:
    batch_size = 1
  return batch_size