def linesearch():
    """Re-evaluate the loss and refine ``deltas`` with the line search.

    Closure: ``fn_loss``, ``arguments``, ``variables``, ``deltas``,
    ``loss_before`` and ``self`` are taken from the enclosing scope.

    Returns:
        Tuple of the values of the TensorDict returned by
        ``self.line_search.solve``.
    """
    loss_after = fn_loss(**arguments.to_kwargs())

    def flat_key(variable):
        # Replace "/" with "_" to ensure TensorDict is flat; the last two
        # characters of the name are dropped as well (presumably the ":0"
        # suffix -- TODO confirm).
        return variable.name[:-2].replace('/', '_')

    # TODO: should be moved to initialize_given_variables, but fn_loss...
    def evaluate_step(arguments, deltas):
        # Add each candidate delta to its variable, then evaluate the loss
        # only once all of those assignments are control dependencies.
        assignments = [
            variable.assign_add(delta=delta, read_value=False)
            for variable, delta in zip(variables, deltas.values())
        ]
        with tf.control_dependencies(control_inputs=assignments):
            return fn_loss(**arguments.to_kwargs())

    with tf.control_dependencies(control_inputs=(loss_after, )):
        initial_deltas = TensorDict(
            ((flat_key(var), delta) for var, delta in zip(variables, deltas)))

        refined_deltas = self.line_search.solve(arguments=arguments,
                                                x_init=initial_deltas,
                                                base_value=loss_before,
                                                zero_value=loss_after,
                                                fn_x=evaluate_step)
        return tuple(refined_deltas.values())
Example #2
0
    def step(self, *, arguments, variables, fn_loss, **kwargs):
        """Run the internal optimizer, then refine its deltas by line search.

        All loss values are negated because the line search maximizes.

        Args:
            arguments: Object whose ``to_kwargs()`` result is passed to
                ``fn_loss`` as keyword arguments.
            variables: Variables being optimized; each is updated through
                ``assign_add`` and identified by its ``name``.
            fn_loss: Callable returning the loss for the given arguments.
            **kwargs: Forwarded unchanged to ``self.optimizer.step``.

        Returns:
            List of the values of the TensorDict returned by
            ``self.line_search.solve``.

        Raises:
            TensorforceError: If the internal optimizer returns a tuple
                whose length is not 2.
        """
        # Negative value since line search maximizes
        loss_before = -fn_loss(**arguments.to_kwargs())

        with tf.control_dependencies(control_inputs=(loss_before, )):
            optimized = self.optimizer.step(arguments=arguments,
                                            variables=variables,
                                            fn_loss=fn_loss,
                                            **kwargs,
                                            return_estimated_improvement=True)

        # Unpack the optimizer output *before* using it as control inputs:
        # tf.control_dependencies expects a flat sequence of ops/tensors,
        # not a nested (deltas, estimated_improvement) pair.
        has_estimate = isinstance(optimized, tuple)
        if has_estimate:
            # If 'return_estimated_improvement' argument exists.
            if len(optimized) != 2:
                raise TensorforceError(
                    message="Unexpected output of internal optimizer.")
            deltas, estimated_improvement = optimized
            control_inputs = list(deltas) + [estimated_improvement]
        else:
            deltas = optimized
            control_inputs = deltas

        with tf.control_dependencies(control_inputs=control_inputs):
            # Negative value since line search maximizes.
            loss_after = -fn_loss(**arguments.to_kwargs())

            if has_estimate:
                # Negative value since line search maximizes.
                estimated_improvement = -estimated_improvement
            else:
                # Some big value:
                # 1000 * max(|loss_after - loss_before|, loss_after, 1.0)
                estimated_improvement = tf.math.maximum(
                    x=tf.math.abs(x=(loss_after - loss_before)),
                    y=tf.math.maximum(
                        x=loss_after,
                        y=tf_util.constant(value=1.0,
                                           dtype='float'))) * tf_util.constant(
                                               value=1000.0, dtype='float')

            # TODO: debug assertion
            dependencies = [loss_after]
            if self.config.create_debug_assertions:
                dependencies.append(
                    tf.debugging.assert_none_equal(x=loss_before,
                                                   y=loss_after))

        with tf.control_dependencies(control_inputs=dependencies):

            # TODO: should be moved to initialize_given_variables, but fn_loss...
            def evaluate_step(arguments, deltas):
                # Add each candidate delta to its variable, then evaluate the
                # loss only once all assignments are control dependencies.
                assignments = [
                    variable.assign_add(delta=delta, read_value=False)
                    for variable, delta in zip(variables, deltas.values())
                ]
                with tf.control_dependencies(control_inputs=assignments):
                    # Negative value since line search maximizes.
                    return -fn_loss(**arguments.to_kwargs())

            # Replace "/" with "_" to ensure TensorDict is flat; the last two
            # characters of the name are dropped as well (presumably the
            # ":0" suffix -- TODO confirm).
            deltas = TensorDict(
                ((var.name[:-2].replace('/', '_'), delta)
                 for var, delta in zip(variables, deltas)))
            deltas = self.line_search.solve(arguments=arguments,
                                            x_init=deltas,
                                            base_value=loss_before,
                                            zero_value=loss_after,
                                            estimated=estimated_improvement,
                                            fn_x=evaluate_step)

            return list(deltas.values())