def run_if(cond, func):
    """Run a function if `cond` Tensor is True.

    This function is useful for conditionally executing a function only when
    a condition given by a tf Tensor is True. It is equivalent to the
    following code if `cond` is a python bool value:

    ```python
    if cond:
        func()
    ```

    However, when `cond` is a tf bool scalar tensor, the above code does not
    always do what we want, because tensorflow does not allow a bool scalar
    tensor to be used in the same way as a python bool. So we have to use
    tf.cond to do the job.

    Args:
        cond (tf.Tensor): scalar bool Tensor
        func (Callable): function to be run
    Returns:
        None
    """
    scope = get_current_scope()

    def _if_true():
        # Inside tf.cond, get_current_scope() is somehow '', which makes
        # operations and summaries inside func unscoped. This line restores
        # the original name scope.
        with tf.name_scope(scope):
            func()
        return tf.constant(True)

    tf.cond(cond, _if_true, lambda: tf.constant(False))
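# A minimal usage sketch for `run_if` (illustrative, not from the original
# source; the step value is made up). `cond` below is a scalar bool Tensor,
# which a plain python `if` cannot branch on inside a tf.function.
import tensorflow as tf

step = tf.constant(200)
run_if(tf.equal(step % 100, 0), lambda: tf.print("summary at step", step))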
def __call__(self, *args, **kwargs):
    """Call the wrapped function.

    Tensorflow creates a different instance of the Function object for each
    instance to handle instance-specific processing. We need to explicitly
    call tf_Function.__get__ to handle class methods correctly.

    Reference: tensorflow.python.eager.def_function.Function.__get__().
    """
    tf_func_instance = self._tf_func.__get__(self._instance, self._owner)
    return tf_func_instance(get_current_scope(), *args, **kwargs)
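# Minimal sketch of the descriptor protocol this __call__ relies on. This is
# an assumption-based illustration, not the actual wrapper class: the
# `_FunctionWrapper` name is hypothetical, and its attributes simply mirror
# the `_tf_func`, `_instance` and `_owner` names used above. Accessing the
# wrapper as `obj.method` triggers __get__, which records the instance/owner
# that __call__ later passes to self._tf_func.__get__ to get a bound method.
import tensorflow as tf

class _FunctionWrapper:
    def __init__(self, func):
        self._tf_func = tf.function(func)
        self._instance = None
        self._owner = None

    def __get__(self, instance, owner):
        # Remember which instance/class this wrapper is being bound to.
        self._instance = instance
        self._owner = owner
        return self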
def conditional_update(target, cond, func, *args, **kwargs):
    """Update target according to cond mask.

    Compute the result as an update of `target` based on `cond`. To be
    specific, result[row] is func(*args[row], **kwargs[row]) if cond[row] is
    True; otherwise result[row] is target[row]. Note that `target` itself
    will not be changed.

    If you simply want to do some conditional computation without actually
    returning any results, you can use conditional_update in the following
    way:

    ```
    # func needs to return an empty tuple ()
    conditional_update((), cond, func, *args, **kwargs)
    ```

    Args:
        target (nested Tensor): target to be updated
        cond (Tensor): 1d bool Tensor with shape[0] == target.shape[0]
        func (Callable): a function with arguments (*args, **kwargs) that
            returns a nest with the same structure as target
    Returns:
        nest with the same structure and shape as target
    """
    # The shape of indices from where() is [batch_size, 1], which is what
    # scatter_nd needs.
    scatter_indices = tf.where(cond)
    scope = get_current_scope()

    def _update_subset():
        gather_indices = tf.squeeze(scatter_indices, 1)
        selected_args = _gather_nest(args, gather_indices)
        selected_kwargs = _gather_nest(kwargs, gather_indices)
        # tf.case loses the original name scope. Need to restore it.
        with tf.name_scope(scope):
            updates = func(*selected_args, **selected_kwargs)
        return tf.nest.map_structure(
            lambda tgt, updt: tf.tensor_scatter_nd_update(
                tgt, scatter_indices, updt), target, updates)

    total = tf.shape(cond)[0]
    n = tf.shape(scatter_indices)[0]
    return tf.case(((n == 0, lambda: target),
                    (n == total, lambda: func(*args, **kwargs))),
                   default=_update_subset)
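# Usage sketch for conditional_update (illustrative values, not from the
# original source). Rows where `cond` is True are recomputed by `func` on
# the gathered rows of `args`; the remaining rows keep their values from
# `target`, which itself is left unmodified.
import tensorflow as tf

target = tf.zeros([4, 2])
x = tf.ones([4, 2])
cond = tf.constant([True, False, True, False])
result = conditional_update(target, cond, lambda v: 2. * v, x)
# result == [[2., 2.], [0., 0.], [2., 2.], [0., 0.]]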
def _update(self, experience, weight):
    batch_size = tf.shape(experience.step_type)[1]
    counter = tf.zeros((), tf.int32)
    initial_train_state = common.get_initial_policy_state(
        batch_size, self.train_state_spec)
    if self._use_rollout_state:
        first_train_state = tf.nest.map_structure(
            lambda state: state[0, ...], experience.state)
    else:
        first_train_state = initial_train_state
    num_steps = tf.shape(experience.step_type)[0]

    def create_ta(s):
        # TensorArray cannot use Tensor (batch_size) as element_shape
        ta_batch_size = experience.step_type.shape[1]
        return tf.TensorArray(
            dtype=s.dtype,
            size=num_steps,
            element_shape=tf.TensorShape([ta_batch_size]).concatenate(
                s.shape))

    experience_ta = tf.nest.map_structure(
        create_ta,
        nest_utils.to_distribution_param_spec(
            self.processed_experience_spec))
    experience_ta = tf.nest.map_structure(
        lambda elem, ta: ta.unstack(elem), experience, experience_ta)
    info_ta = tf.nest.map_structure(
        create_ta,
        nest_utils.to_distribution_param_spec(self.train_step_info_spec))

    scope = get_current_scope()

    def _train_loop_body(counter, policy_state, info_ta):
        exp = tf.nest.map_structure(lambda ta: ta.read(counter),
                                    experience_ta)
        exp = nest_utils.params_to_distributions(
            exp, self.processed_experience_spec)
        policy_state = common.reset_state_if_necessary(
            policy_state, initial_train_state,
            tf.equal(exp.step_type, StepType.FIRST))
        with tf.name_scope(scope):
            policy_step = self.train_step(exp, policy_state)
        info_ta = tf.nest.map_structure(
            lambda ta, x: ta.write(counter, x), info_ta,
            nest_utils.distributions_to_params(policy_step.info))
        counter += 1
        return [counter, policy_step.state, info_ta]

    with tf.GradientTape(
            persistent=True, watch_accessed_variables=False) as tape:
        tape.watch(self.trainable_variables)
        [_, _, info_ta] = tf.while_loop(
            cond=lambda counter, *_: tf.less(counter, num_steps),
            body=_train_loop_body,
            loop_vars=[counter, first_train_state, info_ta],
            back_prop=True,
            name="train_loop")
    info = tf.nest.map_structure(lambda ta: ta.stack(), info_ta)
    info = nest_utils.params_to_distributions(info,
                                              self.train_step_info_spec)
    experience = nest_utils.params_to_distributions(
        experience, self.processed_experience_spec)
    training_info = TrainingInfo(
        action=experience.action,
        reward=experience.reward,
        discount=experience.discount,
        step_type=experience.step_type,
        rollout_info=experience.rollout_info,
        info=info,
        env_id=experience.env_id)
    loss_info, grads_and_vars = self.train_complete(
        tape=tape, training_info=training_info, weight=weight)
    del tape
    return training_info, loss_info, grads_and_vars
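# Minimal sketch (with assumed shapes) of the TensorArray + tf.while_loop
# pattern `_update` is built on: unstack a [num_steps, batch_size, ...]
# tensor along time, process one step per loop iteration, and stack the
# per-step outputs back into a time-major tensor.
import tensorflow as tf

x = tf.random.normal([5, 3])  # [num_steps, batch_size]
num_steps = tf.shape(x)[0]
x_ta = tf.TensorArray(
    tf.float32, size=num_steps,
    element_shape=tf.TensorShape([3])).unstack(x)
out_ta = tf.TensorArray(
    tf.float32, size=num_steps, element_shape=tf.TensorShape([3]))

def _body(t, out_ta):
    step = x_ta.read(t)  # one time step, shape [batch_size]
    return t + 1, out_ta.write(t, 2. * step)

_, out_ta = tf.while_loop(
    cond=lambda t, *_: t < num_steps,
    body=_body,
    loop_vars=[tf.constant(0), out_ta])
y = out_ta.stack()  # back to [num_steps, batch_size]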
def _train(self, experience, num_updates, mini_batch_size,
           mini_batch_length, update_counter_every_mini_batch,
           should_summarize):
    """Train using experience."""
    experience = nest_utils.params_to_distributions(
        experience, self.experience_spec)
    experience = self.transform_timestep(experience)
    experience = self.preprocess_experience(experience)
    experience = nest_utils.distributions_to_params(experience)

    length = experience.step_type.shape[1]
    mini_batch_length = (mini_batch_length or length)
    assert length % mini_batch_length == 0, (
        "length=%s not a multiple of mini_batch_length=%s" %
        (length, mini_batch_length))

    if len(tf.nest.flatten(
            self.train_state_spec)) > 0 and not self._use_rollout_state:
        if mini_batch_length == 1:
            logging.fatal(
                "Should use TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN when minibatch_length==1.")
        else:
            common.warning_once(
                "Consider using TrainerConfig.use_rollout_state=True "
                "for off-policy training of RNN.")

    experience = tf.nest.map_structure(
        lambda x: tf.reshape(
            x, common.concat_shape([-1, mini_batch_length],
                                   tf.shape(x)[2:])), experience)

    batch_size = tf.shape(experience.step_type)[0]
    mini_batch_size = (mini_batch_size or batch_size)

    def _make_time_major(nest):
        """Put the time dim to axis=0."""
        return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                     nest)

    scope = get_current_scope()

    for u in tf.range(num_updates):
        if mini_batch_size < batch_size:
            indices = tf.random.shuffle(
                tf.range(tf.shape(experience.step_type)[0]))
            experience = tf.nest.map_structure(
                lambda x: tf.gather(x, indices), experience)
        for b in tf.range(0, batch_size, mini_batch_size):
            if update_counter_every_mini_batch:
                common.get_global_counter().assign_add(1)
            is_last_mini_batch = tf.logical_and(
                tf.equal(u, num_updates - 1),
                tf.greater_equal(b + mini_batch_size, batch_size))
            do_summary = tf.logical_or(is_last_mini_batch,
                                       update_counter_every_mini_batch)
            common.enable_summary(do_summary)
            batch = tf.nest.map_structure(
                lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                experience)
            batch = _make_time_major(batch)
            # Tensorflow graph mode loses the original name scope here. We
            # need to restore the original name scope.
            with tf.name_scope(scope):
                training_info, loss_info, grads_and_vars = self._update(
                    batch,
                    weight=tf.cast(tf.shape(batch.step_type)[1],
                                   tf.float32) / float(mini_batch_size))
            if should_summarize:
                if do_summary:
                    # Putting `if do_summary` under the above `with`
                    # statement does not help. Somehow the `if` statement
                    # also loses the original name scope.
                    with tf.name_scope(scope):
                        self.summarize_train(training_info, loss_info,
                                             grads_and_vars)

    train_steps = batch_size * mini_batch_length * num_updates
    return train_steps
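# Numeric sketch (with assumed shapes) of how `_train` re-segments the
# experience: a [batch_size, length, ...] batch is reshaped into
# [batch_size * length / mini_batch_length, mini_batch_length, ...]
# segments, sliced into mini-batches along axis 0, and transposed to
# time-major before each `_update` call.
import tensorflow as tf

B, T, l = 4, 8, 2  # batch_size, length, mini_batch_length
x = tf.reshape(tf.range(B * T, dtype=tf.float32), [B, T])
segments = tf.reshape(x, [-1, l])  # [B * T / l, l] == [16, 2]
mini_batch = segments[0:8]  # first mini-batch of 8 segments
time_major = tf.transpose(mini_batch, [1, 0])  # [l, mini_batch_size]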
def __call__(self, *args, **kwargs):
    return self._tf_func(get_current_scope(), *args, **kwargs)