예제 #1
0
def run_if(cond, func):
    """Execute `func` only when the scalar bool Tensor `cond` is True.

    For a plain python bool this would simply be:
    ```python
    if cond:
        func()
    ```
    A tf bool scalar tensor cannot always be used like a python bool, so the
    branching is delegated to tf.cond instead.

    Args:
        cond (tf.Tensor): scalar bool Tensor
        func (Callable): function to be run; it is called without arguments
            and its return value is ignored
    Returns:
        None
    """
    # Capture the scope now: inside tf.cond, get_current_scope() is somehow
    # '', which would leave the operations and summaries of func unscoped.
    outer_scope = get_current_scope()

    def _run_func():
        # Restore the name scope that was active when run_if was called.
        with tf.name_scope(outer_scope):
            func()
            return tf.constant(True)

    def _skip():
        # Both branches of tf.cond must produce an output of the same type.
        return tf.constant(False)

    tf.cond(cond, _run_func, _skip)
예제 #2
0
    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function with the current scope prepended.

        Tensorflow creates a separate Function object per instance to handle
        instance specific processing, so the wrapped function is first bound
        through tf_Function.__get__ to make class methods work correctly.

        Reference: tensorflow.python.eager.def_function.Function.__get__().

        Args:
            *args: positional arguments forwarded after the scope
            **kwargs: keyword arguments forwarded unchanged
        Returns:
            whatever the bound function returns
        """
        bound_func = self._tf_func.__get__(self._instance, self._owner)
        current_scope = get_current_scope()
        return bound_func(current_scope, *args, **kwargs)
예제 #3
0
def conditional_update(target, cond, func, *args, **kwargs):
    """Update target according to cond mask.

    Compute result as an update of `target` based on `cond`. To be specific,
    result[row] is func(*args[row], **kwargs[row]) if cond[row] is True,
    otherwise result[row] will be target[row]. Note that target will not be
    changed.

    Three cases are dispatched with tf.case:
    * no row satisfies `cond`: `target` is returned as is;
    * every row satisfies `cond`: `func` is called on the full arguments;
    * otherwise: `func` is called on the selected rows only and its result
      is scattered into `target`.

    If you simply want to do some conditional computation without actually
    returning any results. You can use conditional_update in the following way:
    ```
    # func needs to return an empty tuple ()
    conditional_update((), cond, func, *args, **kwargs)
    ```

    Args:
        target (nested Tensor): target to be updated
        cond (Tensor): 1d bool Tensor with shape[0] == target.shape[0]
        func (Callable): a function with arguments (*args, **kwargs) and returning
            a nest with same structure as target
        *args: positional arguments for `func`; the rows selected by `cond`
            are passed when only a subset of rows is updated
        **kwargs: keyword arguments for `func`, handled like `*args`
    Returns:
        nest with the same structure and shape as target.
    """
    # shape of indices from where() is [batch_size,1], which is what scatter_nd
    # needs
    scatter_indices = tf.where(cond)
    scope = get_current_scope()

    def _update_all():
        # tf.case loses the original name scope in every branch, so restore it
        # here as well. Otherwise ops and summaries created by func would be
        # named differently depending on how many rows satisfy cond.
        with tf.name_scope(scope):
            return func(*args, **kwargs)

    def _update_subset():
        gather_indices = tf.squeeze(scatter_indices, 1)
        # presumably _gather_nest picks the given rows from every tensor of
        # the nest -- verify against its definition
        selected_args = _gather_nest(args, gather_indices)
        selected_kwargs = _gather_nest(kwargs, gather_indices)
        with tf.name_scope(scope):
            # tf.case loses the original name scope. Need to restore it.
            updates = func(*selected_args, **selected_kwargs)
        return tf.nest.map_structure(
            lambda tgt, updt: tf.tensor_scatter_nd_update(
                tgt, scatter_indices, updt), target, updates)

    total = tf.shape(cond)[0]
    n = tf.shape(scatter_indices)[0]
    # tf.equal always yields a bool Tensor, whereas `==` on tensors only does
    # so when Tensor equality is enabled.
    return tf.case(((tf.equal(n, 0), lambda: target),
                    (tf.equal(n, total), _update_all)),
                   default=_update_subset)
예제 #4
0
    def _update(self, experience, weight):
        """Unroll `train_step` over the experience and complete one update.

        The experience is unstacked along axis 0 into TensorArrays, and
        `train_step` is run once per entry inside a tf.while_loop recorded by
        a persistent GradientTape. The collected infos are then handed to
        `train_complete`.

        Args:
            experience (nest): experience whose tensors are indexed by step
                along axis 0 and by batch entry along axis 1
            weight: passed through to `train_complete`
        Returns:
            tuple of (training_info, loss_info, grads_and_vars)
        """
        batch_size = tf.shape(experience.step_type)[1]
        step_counter = tf.zeros((), tf.int32)
        initial_train_state = common.get_initial_policy_state(
            batch_size, self.train_state_spec)
        if self._use_rollout_state:
            # Start from the state recorded at the first step.
            first_train_state = tf.nest.map_structure(
                lambda state: state[0, ...], experience.state)
        else:
            first_train_state = initial_train_state
        num_steps = tf.shape(experience.step_type)[0]

        # TensorArray cannot use Tensor (batch_size) as element_shape, so the
        # static batch dimension is used instead.
        static_batch_size = experience.step_type.shape[1]

        def _new_tensor_array(spec):
            elem_shape = tf.TensorShape([static_batch_size]).concatenate(
                spec.shape)
            return tf.TensorArray(
                dtype=spec.dtype, size=num_steps, element_shape=elem_shape)

        experience_param_spec = nest_utils.to_distribution_param_spec(
            self.processed_experience_spec)
        experience_ta = tf.nest.map_structure(_new_tensor_array,
                                              experience_param_spec)
        experience_ta = tf.nest.map_structure(
            lambda elem, ta: ta.unstack(elem), experience, experience_ta)
        info_param_spec = nest_utils.to_distribution_param_spec(
            self.train_step_info_spec)
        info_ta = tf.nest.map_structure(_new_tensor_array, info_param_spec)

        # Captured here so that it can be restored inside the loop body.
        scope = get_current_scope()

        def _train_loop_body(step, state, infos):
            exp = tf.nest.map_structure(lambda ta: ta.read(step),
                                        experience_ta)
            exp = nest_utils.params_to_distributions(
                exp, self.processed_experience_spec)
            is_first = tf.equal(exp.step_type, StepType.FIRST)
            state = common.reset_state_if_necessary(state,
                                                    initial_train_state,
                                                    is_first)

            with tf.name_scope(scope):
                policy_step = self.train_step(exp, state)

            step_info = nest_utils.distributions_to_params(policy_step.info)
            infos = tf.nest.map_structure(
                lambda ta, x: ta.write(step, x), infos, step_info)

            return [step + 1, policy_step.state, infos]

        def _has_more_steps(step, *_):
            return tf.less(step, num_steps)

        with tf.GradientTape(persistent=True,
                             watch_accessed_variables=False) as tape:
            tape.watch(self.trainable_variables)
            _, _, info_ta = tf.while_loop(
                cond=_has_more_steps,
                body=_train_loop_body,
                loop_vars=[step_counter, first_train_state, info_ta],
                back_prop=True,
                name="train_loop")
            info = tf.nest.map_structure(lambda ta: ta.stack(), info_ta)
            info = nest_utils.params_to_distributions(
                info, self.train_step_info_spec)
            experience_dist = nest_utils.params_to_distributions(
                experience, self.processed_experience_spec)
            training_info = TrainingInfo(
                action=experience_dist.action,
                reward=experience_dist.reward,
                discount=experience_dist.discount,
                step_type=experience_dist.step_type,
                rollout_info=experience_dist.rollout_info,
                info=info,
                env_id=experience_dist.env_id)

        loss_info, grads_and_vars = self.train_complete(
            tape=tape, training_info=training_info, weight=weight)

        # The tape is persistent, so drop the reference explicitly.
        del tape

        return training_info, loss_info, grads_and_vars
예제 #5
0
    def _train(self, experience, num_updates, mini_batch_size,
               mini_batch_length, update_counter_every_mini_batch,
               should_summarize):
        """Train using experience.

        The experience is transformed and preprocessed, cut into sequences of
        `mini_batch_length` steps, and then `num_updates` passes are made over
        it, calling `_update` on one time-major mini-batch at a time.

        Args:
            experience (nest): experience whose tensors have the sequence
                length at axis 1
            num_updates: number of passes over the experience
            mini_batch_size: number of sequences per mini-batch; a falsy value
                means all sequences are used as a single mini-batch
            mini_batch_length: length of each sequence; a falsy value means
                the full length of `experience`. The length of `experience`
                must be a multiple of it.
            update_counter_every_mini_batch: whether to increase the global
                counter, and allow summary, for every mini-batch
            should_summarize: whether to call `summarize_train`
        Returns:
            the number of training steps, computed as
            batch_size * mini_batch_length * num_updates
        """
        experience = nest_utils.params_to_distributions(
            experience, self.experience_spec)
        experience = self.transform_timestep(experience)
        experience = self.preprocess_experience(experience)
        experience = nest_utils.distributions_to_params(experience)

        length = experience.step_type.shape[1]
        mini_batch_length = (mini_batch_length or length)
        assert length % mini_batch_length == 0, (
            "length=%s not a multiple of mini_batch_length=%s" %
            (length, mini_batch_length))

        if len(tf.nest.flatten(
                self.train_state_spec)) > 0 and not self._use_rollout_state:
            if mini_batch_length == 1:
                logging.fatal(
                    "Should use TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN when minibatch_length==1.")
            else:
                common.warning_once(
                    "Consider using TrainerConfig.use_rollout_state=True "
                    "for off-policy training of RNN.")

        # Cut every sequence into pieces of mini_batch_length steps and fold
        # the pieces into axis 0.
        experience = tf.nest.map_structure(
            lambda x: tf.reshape(
                x, common.concat_shape([-1, mini_batch_length],
                                       tf.shape(x)[2:])), experience)

        batch_size = tf.shape(experience.step_type)[0]
        # NOTE: when mini_batch_size is falsy it becomes a Tensor here.
        mini_batch_size = (mini_batch_size or batch_size)

        def _make_time_major(nest):
            """Put the time dim to axis=0."""
            return tf.nest.map_structure(lambda x: common.transpose2(x, 0, 1),
                                         nest)

        scope = get_current_scope()

        for u in tf.range(num_updates):
            if mini_batch_size < batch_size:
                # Reshuffle the sequences before every pass.
                indices = tf.random.shuffle(
                    tf.range(tf.shape(experience.step_type)[0]))
                experience = tf.nest.map_structure(
                    lambda x: tf.gather(x, indices), experience)
            for b in tf.range(0, batch_size, mini_batch_size):
                if update_counter_every_mini_batch:
                    common.get_global_counter().assign_add(1)
                is_last_mini_batch = tf.logical_and(
                    tf.equal(u, num_updates - 1),
                    tf.greater_equal(b + mini_batch_size, batch_size))
                do_summary = tf.logical_or(is_last_mini_batch,
                                           update_counter_every_mini_batch)
                common.enable_summary(do_summary)
                batch = tf.nest.map_structure(
                    lambda x: x[b:tf.minimum(batch_size, b + mini_batch_size)],
                    experience)
                batch = _make_time_major(batch)
                # Tensorflow graph mode loses the original name scope here. We
                # need to restore the original name scope
                with tf.name_scope(scope):
                    # The last mini-batch may be smaller than mini_batch_size,
                    # so it is weighted by its relative size. tf.cast is used
                    # instead of float() because mini_batch_size may be a
                    # symbolic Tensor, which float() cannot convert in graph
                    # mode.
                    training_info, loss_info, grads_and_vars = self._update(
                        batch,
                        weight=tf.cast(
                            tf.shape(batch.step_type)[1], tf.float32) /
                        tf.cast(mini_batch_size, tf.float32))
                if should_summarize:
                    if do_summary:
                        # Putting `if do_summary` under the above `with` statement
                        # does not help. Somehow `if` statement will also lose
                        # the original name scope.
                        with tf.name_scope(scope):
                            self.summarize_train(training_info, loss_info,
                                                 grads_and_vars)

        train_steps = batch_size * mini_batch_length * num_updates
        return train_steps
예제 #6
0
 def __call__(self, *args, **kwargs):
     """Call the wrapped function with the current scope as first argument.

     Args:
         *args: positional arguments forwarded after the scope
         **kwargs: keyword arguments forwarded unchanged
     Returns:
         whatever the wrapped function returns
     """
     wrapped = self._tf_func
     current_scope = get_current_scope()
     return wrapped(current_scope, *args, **kwargs)