Example #1
0
 def test_soft_threshold(self, x, threshold, expected_y, expected_dydx):
     """Compares soft_threshold's value and gradient with the expected ones."""
     def _thresholded(inp):
         return numeric.soft_threshold(inp, threshold)

     x_tensor = tf.convert_to_tensor(value=x, dtype=self.dtype)
     value_t, grad_t = tfp_math_gradient.value_and_gradient(
         _thresholded, x_tensor)
     actual_y, actual_dydx = self.evaluate([value_t, grad_t])
     self.assertAllClose(expected_y, actual_y)
     self.assertAllClose(expected_dydx, actual_dydx)
Example #2
0
    def test_soft_threshold(self, x, threshold, expected_y, expected_dy_dx):
        """Checks soft_threshold's output and its gradient with respect to `x`."""
        x_tensor = tf.convert_to_tensor(x)
        with tf.GradientTape() as grad_tape:
            grad_tape.watch(x_tensor)
            y_tensor = numeric.soft_threshold(x_tensor, threshold)
        grad_tensor = grad_tape.gradient(y_tensor, x_tensor)

        actual_y, actual_dy_dx = self.evaluate([y_tensor, grad_tensor])

        pairs = ((expected_y, actual_y), (expected_dy_dx, actual_dy_dx))
        for want, got in pairs:
            self.assertAllClose(want, got)
Example #3
0
        def _loop_body(  # pylint: disable=missing-docstring
                iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update):
            """Performs one coordinate-wise proximal Newton step on `x_update`.

            Updates the single coordinate `coord = iter_ % dims` of `x_update`
            (see the derivation in the comments below). It also keeps
            `x_update_diff_norm_sq` and `hess_matmul_x_update` consistent with
            the new `x_update`.

            This function closes over names from the enclosing scope (not visible
            here). These include `dims`, `x_start`, `learning_rate`,
            `l1_regularizer`, `l2_regularizer`, `grad_loss_with_l2`,
            `hessian_unregularized_loss_outer`,
            `hessian_unregularized_loss_middle`, and the helpers
            `_hessian_diag_elt_with_l2`, `_mul_ignoring_nones`,
            `soft_threshold`, `sparse_or_dense_matvecmul`,
            `_sparse_or_dense_matmul_onehot`, `_one_hot_like` and
            `prefer_static`.

            Args:
              iter_: Integer iteration counter. The coordinate updated by this
                call is `iter_ % dims`.
              x_update_diff_norm_sq: Running value of
                `||x_update_end - x_update_start||_2**2` for the current sweep.
                It is reset to zero at the start of each sweep
                (i.e. when `coord == 0`).
              x_update: Current update to `x_start`. Its rightmost dimension is
                transposed to the leftmost, so `x_update[coord, ...]` selects
                this coordinate across the batch.
              hess_matmul_x_update: Running product of `x_update` with the
                Hessian, in the same transposed layout as `x_update`. Judging by
                the `hessian_column_with_l2` increment below, the L2 term is
                included (NOTE(review): confirm against the caller).

            Returns:
              A list `[iter_ + 1, x_update_diff_norm_sq, x_update,
              hess_matmul_x_update]`. If `delta` is zero at every batch
              position, the last three entries are returned without applying an
              update. Even then, `x_update_diff_norm_sq` may already have been
              reset to zero for a new sweep.
            """
            # Inner loop of the minimizer.
            #
            # This loop updates a single coordinate of x_update.  Ideally, an
            # iteration of this loop would set
            #
            #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
            #
            # where
            #
            #   LocalLoss(x_update')
            #     = LocalLossSmoothComponent(x_update')
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #    := (UnregularizedLoss(x_start + x_update') -
            #        UnregularizedLoss(x_start + x_update)
            #         + l2_regularizer * (||x_start + x_update'||_2**2 -
            #                             ||x_start + x_update||_2**2))
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #
            # In this algorithm we approximate the above argmin using
            # (univariate) proximal gradient descent:
            #
            # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
            #                 x_update[j] -
            #                 t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
            #
            # where
            #
            #   UnivariateLocalLossSmoothComponent(z)
            #       := LocalLossSmoothComponent(x_update + z*e_j)
            #
            # and we approximate
            #
            #       d/dz UnivariateLocalLossSmoothComponent(z)
            #     = (grad LocalLossSmoothComponent(x_update))[j]
            #    ~= (grad LossSmoothComponent(x_start)
            #         + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
            #
            # To choose the parameter t, we squint and pretend that the inner term of
            # (*) is a Newton update as if we were using Newton's method to minimize
            # UnivariateLocalLossSmoothComponent.  That is, we choose t such that
            #
            #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
            #
            # at z=0.  Hence
            #
            #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
            #     = learning_rate / HessianOfLossSmoothComponent(
            #                           x_start + x_update)[j,j]
            #    ~= learning_rate / HessianOfLossSmoothComponent(
            #                           x_start)[j,j]
            #
            # The above approximation is equivalent to assuming that
            # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
            # effects.
            #
            # Note that because LossSmoothComponent is (assumed to be) convex, t is
            # positive.

            # In above notation, coord = j.  Successive iterations cycle through
            # coordinates 0, 1, ..., dims - 1, 0, ...
            coord = iter_ % dims
            # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
            # computed incrementally, where x_update_end and x_update_start are as
            # defined in the convergence criteria.  Accordingly, we reset
            # x_update_diff_norm_sq to zero at the beginning of each sweep.
            x_update_diff_norm_sq = tf.where(
                tf.equal(coord, 0), tf.zeros_like(x_update_diff_norm_sq),
                x_update_diff_norm_sq)

            # Recall that x_update and hess_matmul_x_update have the rightmost
            # dimension transposed to the leftmost dimension.
            w_old = x_start[..., coord] + x_update[coord, ...]
            # This is the coordinatewise Newton update if no L1 regularization.
            # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
            # `second_deriv` serves as the denominator of t (it divides both the
            # step and the soft-threshold level below).
            second_deriv = _hessian_diag_elt_with_l2(coord)
            newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
                learning_rate, grad_loss_with_l2[..., coord] +
                hess_matmul_x_update[coord, ...]) / second_deriv

            # Applying the soft-threshold operator accounts for L1 regularization.
            # In above notation, delta =
            #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
            delta = (soft_threshold(
                w_old + newton_step,
                _mul_ignoring_nones(learning_rate, l1_regularizer) /
                second_deriv) - w_old)

            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                """Adds `delta` at `coord` and updates the running quantities.

                Returns the new `[x_update_diff_norm_sq, x_update,
                hess_matmul_x_update]`. The new diff norm has `delta**2` added,
                and the new Hessian product has `delta` times the Hessian column
                at `coord` (with the L2 term) added.
                """
                # Column `coord` of the Hessian, built from the outer/middle
                # factors supplied by the enclosing scope.
                hessian_column_with_l2 = sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord),
                    adjoint_a=True)

                # The L2 penalty contributes 2 * l2_regularizer on the diagonal
                # only, i.e. at index `coord` of this column.
                if l2_regularizer is not None:
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)

                # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
                # order to conform to `hess_matmul_x_update`.
                # perm = [n-1, 0, 1, ..., n-2]: the last axis becomes the first.
                n = tf.rank(hessian_column_with_l2)
                perm = tf.roll(tf.range(n), shift=1, axis=0)
                hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2,
                                                      perm=perm)

                # Update the entire batch at `coord` even if `delta` may be 0 at some
                # batch coordinates. In those cases, adding `delta` is a no-op.
                # NOTE(review): `tf.tensor_scatter_add` is the older name. TF2's
                # public API exposes `tf.tensor_scatter_nd_add`. Confirm which
                # `tf` module this file imports before upgrading TensorFlow.
                x_update = tf.tensor_scatter_add(x_update, [[coord]], [delta])

                with tf.control_dependencies([x_update]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)

                    # Hint that loop vars retain the same shape.
                    x_update_diff_norm_sq_.set_shape(
                        x_update_diff_norm_sq_.shape.merge_with(
                            x_update_diff_norm_sq.shape))
                    hess_matmul_x_update_.set_shape(
                        hess_matmul_x_update_.shape.merge_with(
                            hess_matmul_x_update.shape))

                    return [
                        x_update_diff_norm_sq_, x_update, hess_matmul_x_update_
                    ]

            # Note: this list holds the diff norm after the possible per-sweep
            # reset above, so the reset is kept even when the update is skipped.
            inputs_to_update = [
                x_update_diff_norm_sq, x_update, hess_matmul_x_update
            ]
            return [iter_ + 1] + prefer_static.cond(
                # Note on why checking delta (a difference of floats) for equality to
                # zero is ok:
                #
                # First of all, x - x == 0 in floating point -- see
                # https://stackoverflow.com/a/2686671
                #
                # Delta will conceptually equal zero when one of the following holds:
                # (i)   |w_old + newton_step| <= threshold and w_old == 0
                # (ii)  |w_old + newton_step| > threshold and
                #       w_old + newton_step - sign(w_old + newton_step) * threshold
                #          == w_old
                #
                # In case (i) comparing delta to zero is fine.
                #
                # In case (ii), newton_step conceptually equals
                #     sign(w_old + newton_step) * threshold.
                # Also remember, from how `newton_step` and the threshold passed
                # to `soft_threshold` are computed above,
                #     threshold = -newton_step * l1_regularizer
                #                 / (approximation of d/dz|z=0 ULLSC).
                # So (ii) happens when
                #     (approximation of d/dz|z=0 ULLSC)
                #         == -sign(w_old + newton_step) * l1_regularizer.
                # If we did not require LossSmoothComponent to be strictly convex,
                # then this could actually happen a non-negligible amount of the time,
                # e.g. if the loss function is piecewise linear and one of the pieces
                # has slope +/-l1_regularizer.  But since LossSmoothComponent is
                # strictly convex, (ii) should not systematically happen.
                #
                # The update is skipped only when delta == 0 at every batch
                # position; otherwise the whole batch is updated.
                tf.reduce_all(input_tensor=tf.equal(delta, 0.)),
                lambda: inputs_to_update,
                lambda: _do_update(*inputs_to_update))