Exemple #1
0
        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error.

            Args:
              accepted: Ignored on entry; it is recomputed below as
                `error_ratio < 1.`.
              diagnostics: Sequence unpacked as `[num_jacobian_evaluations,
                num_matrix_factorizations, num_ode_fn_evaluations, status]`.
              iterand: Sequence unpacked as `[jacobian_mat,
                jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper]`.
              solver_internal_state: Sequence unpacked as
                `[backward_differences, order, step_size]`.

            Returns:
              The tuple `(accepted, diagnostics, iterand,
              solver_internal_state)` holding the updated values, repacked as
              `_BDFDiagnostics`, `_BDFIterand` and `_BDFSolverInternalState`.
            """
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            [backward_differences, order, step_size] = solver_internal_state

            # When a step budget is configured, status becomes -1 once
            # `num_steps` equals it and 0 otherwise.
            if max_num_steps is not None:
                status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

            # Apply the step-size change carried in via `iterand`: the backward
            # differences are re-interpolated for the ratio
            # `new_step_size / step_size`, and the same-size step counter is
            # reset.
            backward_differences = tf1.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf1.where(should_update_step_size, new_step_size,
                                  step_size)
            # A changed step size also triggers a refactorization.
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf1.where(should_update_step_size, 0,
                                            num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian_mat, newton_coefficients_array.read(order),
                    step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian_mat = jacobian_fn_mat(time,
                                                       backward_differences[0])
                    # NOTE(review): `update_factorization` closes over
                    # `jacobian_mat`, not `new_jacobian_mat` -- confirm that
                    # factorizing the previous Jacobian here is intended.
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian_mat, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian_mat, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian_mat, jacobian_is_up_to_date,
                    num_jacobian_evaluations, unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                # NOTE(review): `num_matrix_factorizations` is only incremented
                # on this branch; factorizations performed on the lazy branch
                # above are not counted -- confirm this is intended.
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = atol + rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state_vec,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                ode_fn_vec, order, step_size, time, newton_tol,
                                unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf1.where(should_update_step_size,
                                                  newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            # The error ratio is NaN when Newton's method did not converge, so
            # the `< 1.` comparison below is False and the step is rejected.
            error_ratio = tf1.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf1.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf1.where(accepted, time + step_size, time)
            backward_differences = tf1.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state_vec, order),
                backward_differences)
            # After an accepted step the Jacobian is marked as no longer up to
            # date.
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                            num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
            should_update_order_and_step_size = accepted & (num_steps_same_size
                                                            > order + 1)

            # Unstack the backward differences into a TensorArray so that rows
            # can be read at the (tensor-valued) index `proposed_order + 1`.
            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.get_shape()).unstack(
                    backward_differences)
            # Try the neighbouring orders (order - 1, then order + 1, clipped
            # to [1, max_order]) and keep whichever gives the lowest error
            # ratio seen so far.
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            # Propose a step size for the (possibly changed) order; it is
            # carried in `iterand` rather than written to `step_size` here.
            new_step_size = tf1.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, step_size)
            return accepted, diagnostics, iterand, solver_internal_state
    def update(self,
               expert_dataset_iter,
               policy_dataset_iter,
               discount,
               replay_regularization=0.05,
               nu_reg=10.0):
        """Runs one gradient step on both the nu network and the actor.

        When replay regularization is non-zero, the quantity being learned is
        (d_pi * (1 - replay_regularization) + d_rb * replay_regularization) /
        (d_expert * (1 - replay_regularization) + d_rb * replay_regularization)
        instead.

        Args:
          expert_dataset_iter: A TensorFlow graph iterator over expert data;
            each element unpacks into (states, actions, next_states).
          policy_dataset_iter: A TensorFlow graph iterator over training policy
            data, used for regularization. The first entry of each element
            unpacks into five values, of which the first three are used as
            (states, actions, next_states).
          discount: An MDP discount.
          replay_regularization: A fraction of samples to add from a replay
            buffer.
          nu_reg: A grad penalty regularization coefficient.
        """
        expert_states, expert_actions, expert_next_states = (
            expert_dataset_iter.get_next())

        replay_batch = policy_dataset_iter.get_next()[0]
        rb_states, rb_actions, rb_next_states, _, _ = replay_batch

        def _state_action(states, actions):
            # Joins states and actions along axis 1 to form a network input.
            return tf.concat([states, actions], 1)

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self.actor.variables)
            tape.watch(self.nu_net.variables)

            # The actor is queried in a fixed order: expert next states, replay
            # next states, then the expert states again, which double as the
            # initial states.
            _, policy_next_actions, _ = self.actor(expert_next_states)
            _, rb_next_actions, rb_log_prob = self.actor(rb_next_states)
            _, policy_initial_actions, _ = self.actor(expert_states)

            # Inputs for the linear part of DualDICE loss.
            expert_init_inputs = _state_action(expert_states,
                                               policy_initial_actions)

            expert_inputs = _state_action(expert_states, expert_actions)
            expert_next_inputs = _state_action(expert_next_states,
                                               policy_next_actions)

            rb_inputs = _state_action(rb_states, rb_actions)
            rb_next_inputs = _state_action(rb_next_states, rb_next_actions)

            expert_nu_0 = self.nu_net(expert_init_inputs)
            expert_nu = self.nu_net(expert_inputs)
            expert_nu_next = self.nu_net(expert_next_inputs)

            rb_nu = self.nu_net(rb_inputs)
            rb_nu_next = self.nu_net(rb_next_inputs)

            # Residuals nu(s, a) - discount * nu(s', a').
            expert_residual = expert_nu - discount * expert_nu_next
            rb_residual = rb_nu - discount * rb_nu_next

            linear_term_expert = tf.reduce_mean(expert_nu_0 * (1 - discount))
            linear_term_rb = tf.reduce_mean(rb_residual)

            # Expert and replay residuals are stacked together with
            # per-sample weights that are normalised to sum to one.
            all_residuals = tf.concat([expert_residual, rb_residual], 0)
            expert_weights = (tf.ones(expert_residual.shape) *
                              (1 - replay_regularization))
            rb_weights = tf.ones(rb_residual.shape) * replay_regularization
            all_weights = tf.concat([expert_weights, rb_weights], 0)
            all_weights /= tf.reduce_sum(all_weights)

            # The softmax weights are treated as constants for the gradient.
            softmax_weights = tf.stop_gradient(
                weighted_softmax(all_residuals, all_weights, axis=0))
            non_linear_loss = tf.reduce_sum(softmax_weights * all_residuals)

            linear_loss = (linear_term_expert * (1 - replay_regularization) +
                           linear_term_rb * replay_regularization)

            loss = non_linear_loss - linear_loss

            # Gradient penalty evaluated on random interpolations between
            # expert and replay inputs (current and next).
            alpha = tf.random.uniform(shape=(expert_inputs.shape[0], 1))

            mixed_inputs = alpha * expert_inputs + (1 - alpha) * rb_inputs
            mixed_next_inputs = (alpha * expert_next_inputs +
                                 (1 - alpha) * rb_next_inputs)
            penalty_inputs = tf.concat([mixed_inputs, mixed_next_inputs], 0)

            with tf.GradientTape(watch_accessed_variables=False) as tape2:
                tape2.watch(penalty_inputs)
                penalty_output = self.nu_net(penalty_inputs)
            input_grad = tape2.gradient(penalty_output,
                                        [penalty_inputs])[0] + EPS
            input_grad_norm = tf.norm(input_grad, axis=-1, keepdims=True)
            nu_grad_penalty = tf.reduce_mean(tf.square(input_grad_norm - 1))

            nu_loss = loss + nu_grad_penalty * nu_reg
            pi_loss = -loss + keras_utils.orthogonal_regularization(
                self.actor.trunk)

        nu_grads = tape.gradient(nu_loss, self.nu_net.variables)
        pi_grads = tape.gradient(pi_loss, self.actor.variables)

        self.nu_optimizer.apply_gradients(zip(nu_grads, self.nu_net.variables))
        self.actor_optimizer.apply_gradients(
            zip(pi_grads, self.actor.variables))

        # The persistent tape is released explicitly once both gradients
        # have been taken.
        del tape

        # Accumulate running metrics for the periodic summaries below.
        self.avg_nu_expert(expert_nu)
        self.avg_nu_rb(rb_nu)

        self.nu_reg_metric(nu_grad_penalty)
        self.avg_loss(loss)

        self.avg_actor_loss(pi_loss)
        self.avg_actor_entropy(-rb_log_prob)

        # Every `log_interval` nu-optimizer steps: write and reset nu metrics.
        if tf.equal(self.nu_optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar(
                'train dual dice/loss',
                self.avg_loss.result(),
                step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_loss)

            tf.summary.scalar(
                'train dual dice/nu expert',
                self.avg_nu_expert.result(),
                step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_nu_expert)

            tf.summary.scalar(
                'train dual dice/nu rb',
                self.avg_nu_rb.result(),
                step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_nu_rb)

            tf.summary.scalar(
                'train dual dice/nu reg',
                self.nu_reg_metric.result(),
                step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.nu_reg_metric)

        # Every `log_interval` actor-optimizer steps: write and reset actor
        # metrics.
        if tf.equal(self.actor_optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar(
                'train sac/actor_loss',
                self.avg_actor_loss.result(),
                step=self.actor_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_actor_loss)

            tf.summary.scalar(
                'train sac/actor entropy',
                self.avg_actor_entropy.result(),
                step=self.actor_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_actor_entropy)
Exemple #3
0
def spherical_uniform(
    shape,
    dimension,
    dtype=tf.float32,
    seed=None,
    name=None):
  """Draws samples uniformly distributed on the unit sphere.

  Args:
    shape: Vector-shaped, `int` `Tensor` giving the leading shape of the
      output; `dimension` is appended to it as the final axis.
    dimension: Scalar `int` `Tensor`, the dimensionality of the space in which
      the sphere is embedded.
    dtype: (Optional) TF `dtype` of the output.
      Default value: `tf.float32`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
      Default value: `None` (i.e., no seed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'spherical_uniform').

  Returns:
    spherical_uniform: `Tensor` of the given `dtype` whose last axis has size
      `dimension` and whose leading axes are given by `shape`, holding real
      values drawn from a spherical uniform distribution.
  """
  with tf.name_scope(name or 'spherical_uniform'):
    seed = samplers.sanitize_seed(seed)
    dimension = ps.convert_to_shape_tensor(ps.cast(dimension, dtype=tf.int32))
    shape = ps.convert_to_shape_tensor(shape, dtype=tf.int32)
    static_dim = tf.get_static_value(dimension)
    full_shape = ps.convert_to_shape_tensor(
        ps.concat([shape, [dimension]], axis=0))

    def _draw_angles(angle_seed):
      # One angle in [0, 2 * pi) per output point.
      return samplers.uniform(
          shape, minval=0, maxval=2 * np.pi, dtype=dtype, seed=angle_seed)

    def _to_unit_length(vectors):
      # Divides every vector along the last axis by its Euclidean norm.
      lengths = tf.norm(vectors, ord=2, axis=-1)
      return vectors / lengths[..., tf.newaxis]

    # Dimensions one and two are handled separately to guard against normal
    # samples that are zero, which can happen in those dimensions.
    if static_dim is not None:
      if static_dim == 1:
        # On a 0-sphere this amounts to sampling Rademacher variables.
        return rademacher(full_shape, dtype=dtype, seed=seed)
      if static_dim == 2:
        angles = _draw_angles(seed)
        return tf.stack([tf.math.cos(angles), tf.math.sin(angles)], axis=-1)
      gaussian = samplers.normal(
          shape=ps.concat([shape, [static_dim]], axis=0),
          seed=seed,
          dtype=dtype)
      return _to_unit_length(gaussian)

    # The dimension is only known at run time: build all three candidates
    # with a common shape and pick between them with tf.where.
    sign_seed, angle_seed, gaussian_seed = samplers.split_seed(
        seed, n=3, salt='spherical_uniform_dynamic_shape')
    sign_samples = rademacher(full_shape, dtype=dtype, seed=sign_seed)

    angles = _draw_angles(angle_seed)
    cos_column = tf.math.cos(angles)[..., tf.newaxis]
    # The sine column is tiled `dimension - 1` times so that the concatenated
    # result has a last axis of size `dimension`, like the other candidates.
    sin_columns = tf.math.sin(angles)[..., tf.newaxis] * tf.ones(
        [dimension - 1], dtype=dtype)
    circle_samples = tf.concat([cos_column, sin_columns], axis=-1)

    gaussian = samplers.normal(
        shape=ps.concat([shape, [dimension]], axis=0),
        seed=gaussian_seed,
        dtype=dtype)
    general_samples = _to_unit_length(gaussian)

    is_one_dim = tf.math.equal(dimension, 1)
    is_two_dim = tf.math.equal(dimension, 2)
    two_or_more = tf.where(is_two_dim, circle_samples, general_samples)
    return tf.where(is_one_dim, sign_samples, two_or_more)
Exemple #4
0
def minimize_one_step(gradient_unregularized_loss,
                      hessian_unregularized_loss_outer,
                      hessian_unregularized_loss_middle,
                      x_start,
                      tolerance,
                      l1_regularizer,
                      l2_regularizer=None,
                      maximum_full_sweeps=1,
                      learning_rate=None,
                      name=None):
    """One step of (the outer loop of) the minimization algorithm.

  This function returns a new value of `x`, equal to `x_start + x_update`.  The
  increment `x_update in R^n` is computed by a coordinate descent method, that
  is, by a loop in which each iteration updates exactly one coordinate of
  `x_update`.  (Some updates may leave the value of the coordinate unchanged.)

  The particular update method used is to apply an L1-based proximity operator,
  "soft threshold", whose fixed point `x_update_fix` is the desired minimum

  ```none
  x_update_fix = argmin{
      Loss(x_start + x_update')
        + l1_regularizer * ||x_start + x_update'||_1
        + l2_regularizer * ||x_start + x_update'||_2**2
      : x_update' }
  ```

  where in each iteration `x_update'` is constrained to have at most one nonzero
  coordinate.

  This update method preserves sparsity, i.e., tends to find sparse solutions if
  `x_start` is sparse.  Additionally, the choice of step size is based on
  curvature (Hessian), which significantly speeds up convergence.

  This algorithm assumes that `Loss` is convex, at least in a region surrounding
  the optimum.  (If `l2_regularizer > 0`, then only weak convexity is needed.)

  Args:
    gradient_unregularized_loss: (Batch of) `Tensor` with the same shape and
      dtype as `x_start` representing the gradient, evaluated at `x_start`, of
      the unregularized loss function (denoted `Loss` above).  (In all current
      use cases, `Loss` is the negative log likelihood.)
    hessian_unregularized_loss_outer: (Batch of) `Tensor` or `SparseTensor`
      having the same dtype as `x_start`, and shape `[N, n]` where `x_start` has
      shape `[n]`, satisfying the property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_outer
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    hessian_unregularized_loss_middle: (Batch of) vector-shaped `Tensor` having
      the same dtype as `x_start`, and shape `[N]` where
      `hessian_unregularized_loss_outer` has shape `[N, n]`, satisfying the
      property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_outer
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    x_start: (Batch of) vector-shaped, `float` `Tensor` representing the current
      value of the argument to the Loss function.
    tolerance: scalar, `float` `Tensor` representing the convergence threshold.
      The optimization step will terminate early, returning its current value of
      `x_start + x_update`, once the following condition is met:
      `||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
      < sqrt(tolerance)`,
      where `x_update_end` is the value of `x_update` at the end of a sweep and
      `x_update_start` is the value of `x_update` at the beginning of that
      sweep.
    l1_regularizer: scalar, `float` `Tensor` representing the weight of the L1
      regularization term (see equation above).  If L1 regularization is not
      required, then `tfp.glm.fit_one_step` is preferable.
    l2_regularizer: scalar, `float` `Tensor` representing the weight of the L2
      regularization term (see equation above).
      Default value: `None` (i.e., no L2 regularization).
    maximum_full_sweeps: Python integer specifying maximum number of sweeps to
      run.  A "sweep" consists of an iteration of coordinate descent on each
      coordinate. After this many sweeps, the algorithm will terminate even if
      convergence has not been reached.
      Default value: `1`.
    learning_rate: scalar, `float` `Tensor` representing a multiplicative factor
      used to dampen the proximal gradient descent steps.
      Default value: `None` (i.e., factor is conceptually `1`).
    name: Python string representing the name of the TensorFlow operation.
      The default name is `"minimize_one_step"`.

  Returns:
    x: (Batch of) `Tensor` having the same shape and dtype as `x_start`,
      representing the updated value of `x`, that is, `x_start + x_update`.
    is_converged: scalar, `bool` `Tensor` indicating whether convergence
      occurred across all batches within the specified number of sweeps.
    iter: scalar `Tensor` equal to `iter_ / dims`, where `iter_` is the number
      of loop iterations executed (one coordinate per iteration) and `dims` is
      the last entry of the shape of `x_start`.  The loop runs at most
      `maximum_full_sweeps * dims` iterations, so this value is at most
      `maximum_full_sweeps`.  NOTE(review): the value is produced by the `/`
      operator on an `int32` counter, so it is presumably not an integer
      `Tensor` -- confirm against callers before relying on its dtype.

  #### References

  [1]: Jerome Friedman, Trevor Hastie and Rob Tibshirani. Regularization Paths
       for Generalized Linear Models via Coordinate Descent. _Journal of
       Statistical Software_, 33(1), 2010.
       https://www.jstatsoft.org/article/view/v033i01/v33i01.pdf

  [2]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for
       L1-regularized Logistic Regression. _Journal of Machine Learning
       Research_, 13, 2012.
       http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf
  """
    with tf.name_scope(name or 'minimize_one_step'):
        # The last axis of `x_start` indexes coordinates; any leading axes are
        # treated as batch dimensions.
        x_shape = _get_shape(x_start)
        batch_shape = x_shape[:-1]
        dims = x_shape[-1]

        def _hessian_diag_elt_with_l2(coord):  # pylint: disable=missing-docstring
            # Returns the (coord, coord) entry of
            #
            #   Hessian(UnregularizedLoss(x) + l2_regularizer * ||x||_2**2)
            #
            # evaluated at x = x_start.
            inner_square = tf.reduce_sum(_sparse_or_dense_matmul_onehot(
                hessian_unregularized_loss_outer, coord)**2,
                                         axis=-1)
            unregularized_component = (
                hessian_unregularized_loss_middle[..., coord] * inner_square)
            l2_component = _mul_or_none(2., l2_regularizer)
            return _add_ignoring_nones(unregularized_component, l2_component)

        # Gradient of the smooth part of the objective at `x_start`: the
        # unregularized gradient plus `2 * l2_regularizer * x_start` (the
        # helper names suggest the L2 term is dropped when `l2_regularizer` is
        # None -- TODO confirm against `_mul_or_none`).
        grad_loss_with_l2 = _add_ignoring_nones(
            gradient_unregularized_loss,
            _mul_or_none(2., l2_regularizer, x_start))

        # We define `x_update_diff_norm_sq_convergence_threshold` such that the
        # convergence condition
        #     ||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
        #     < sqrt(tolerance)
        # is equivalent to
        #     ||x_update_end - x_update_start||_2**2
        #     < x_update_diff_norm_sq_convergence_threshold.
        x_update_diff_norm_sq_convergence_threshold = (
            tolerance * (1. + tf.norm(tensor=x_start, ord=2, axis=-1))**2)

        # Reshape update vectors so that the coordinate sweeps happen along the
        # first dimension. This is so that we can use tf.tensor_scatter_nd_add
        # to make sparse updates along the first axis.
        # TODO(b/118789120): Switch to something like tf.tensor_scatter_nd_add if
        # or when it exists.
        # NOTE(review): the TODO above looks stale -- `_do_update` below already
        # calls tf.tensor_scatter_nd_add.
        update_shape = tf.concat([[dims], batch_shape], axis=-1)

        def _loop_cond(iter_, x_update_diff_norm_sq, x_update,
                       hess_matmul_x_update):
            # Keep looping while the iteration budget is not exhausted and at
            # least one batch member has not converged.  Convergence is only
            # evaluated at sweep boundaries (iter_ a positive multiple of dims).
            del x_update
            del hess_matmul_x_update
            sweep_complete = (iter_ > 0) & tf.equal(iter_ % dims, 0)
            small_delta = (x_update_diff_norm_sq <
                           x_update_diff_norm_sq_convergence_threshold)
            converged = sweep_complete & small_delta
            allowed_more_iterations = iter_ < maximum_full_sweeps * dims
            return allowed_more_iterations & tf.reduce_any(~converged)

        def _loop_body(  # pylint: disable=missing-docstring
                iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update):
            # Inner loop of the minimizer.
            #
            # This loop updates a single coordinate of x_update.  Ideally, an
            # iteration of this loop would set
            #
            #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
            #
            # where
            #
            #   LocalLoss(x_update')
            #     = LocalLossSmoothComponent(x_update')
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #    := (UnregularizedLoss(x_start + x_update') -
            #        UnregularizedLoss(x_start + x_update)
            #         + l2_regularizer * (||x_start + x_update'||_2**2 -
            #                             ||x_start + x_update||_2**2)
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #
            # In this algorithm approximate the above argmin using (univariate)
            # proximal gradient descent:
            #
            # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
            #                 x_update[j] -
            #                 t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
            #
            # where
            #
            #   UnivariateLocalLossSmoothComponent(z)
            #       := LocalLossSmoothComponent(x_update + z*e_j)
            #
            # and we approximate
            #
            #       d/dz UnivariateLocalLossSmoothComponent(z)
            #     = grad LocalLossSmoothComponent(x_update))[j]
            #    ~= (grad LossSmoothComponent(x_start)
            #         + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
            #
            # To choose the parameter t, we squint and pretend that the inner term of
            # (*) is a Newton update as if we were using Newton's method to minimize
            # UnivariateLocalLossSmoothComponent.  That is, we choose t such that
            #
            #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
            #
            # at z=0.  Hence
            #
            #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
            #     = learning_rate / HessianOfLossSmoothComponent(
            #                           x_start + x_update)[j,j]
            #    ~= learning_rate / HessianOfLossSmoothComponent(
            #                           x_start)[j,j]
            #
            # The above approximation is equivalent to assuming that
            # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
            # effects.
            #
            # Note that because LossSmoothComponent is (assumed to be) convex, t is
            # positive.

            # In above notation, coord = j.
            coord = iter_ % dims
            # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
            # computed incrementally, where x_update_end and x_update_start are as
            # defined in the convergence criteria.  Accordingly, we reset
            # x_update_diff_norm_sq to zero at the beginning of each sweep.
            x_update_diff_norm_sq = tf.where(
                tf.equal(coord, 0),
                dtype_util.as_numpy_dtype(x_update_diff_norm_sq.dtype)(0.),
                x_update_diff_norm_sq)

            # Recall that x_update and hess_matmul_x_update have the rightmost
            # dimension transposed to the leftmost dimension.
            w_old = x_start[..., coord] + x_update[coord, ...]
            # This is the coordinatewise Newton update if no L1 regularization.
            # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
            second_deriv = _hessian_diag_elt_with_l2(coord)
            newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
                learning_rate, grad_loss_with_l2[..., coord] +
                hess_matmul_x_update[coord, ...]) / second_deriv

            # Applying the soft-threshold operator accounts for L1 regularization.
            # In above notation, delta =
            #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
            delta = (soft_threshold(
                w_old + newton_step,
                _mul_ignoring_nones(learning_rate, l1_regularizer) /
                second_deriv) - w_old)

            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                # Applies `delta` at `coord` and updates the two running
                # quantities that depend on x_update: the squared norm of the
                # change within this sweep and the product Hessian @ x_update.

                # Column `coord` of the approximate Hessian:
                # outer^T @ (middle * outer[:, coord]).
                hessian_column_with_l2 = sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord),
                    adjoint_a=True)

                if l2_regularizer is not None:
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)

                # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
                # order to conform to `hess_matmul_x_update`.
                n = tf.rank(hessian_column_with_l2)
                perm = tf.roll(tf.range(n), shift=1, axis=0)
                hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2,
                                                      perm=perm)

                # Update the entire batch at `coord` even if `delta` may be 0 at some
                # batch coordinates. In those cases, adding `delta` is a no-op.
                x_update = tf.tensor_scatter_nd_add(x_update, [[coord]],
                                                    [delta])

                with tf.control_dependencies([x_update]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)

                    # Hint that loop vars retain the same shape.
                    x_update_diff_norm_sq_.set_shape(
                        x_update_diff_norm_sq_.shape.merge_with(
                            x_update_diff_norm_sq.shape))
                    hess_matmul_x_update_.set_shape(
                        hess_matmul_x_update_.shape.merge_with(
                            hess_matmul_x_update.shape))

                    return [
                        x_update_diff_norm_sq_, x_update, hess_matmul_x_update_
                    ]

            inputs_to_update = [
                x_update_diff_norm_sq, x_update, hess_matmul_x_update
            ]
            # Skip the update entirely when delta is zero for every batch
            # member; otherwise apply it via `_do_update`.
            return [iter_ + 1] + prefer_static.cond(
                # Note on why checking delta (a difference of floats) for equality to
                # zero is ok:
                #
                # First of all, x - x == 0 in floating point -- see
                # https://stackoverflow.com/a/2686671
                #
                # Delta will conceptually equal zero when one of the following holds:
                # (i)   |w_old + newton_step| <= threshold and w_old == 0
                # (ii)  |w_old + newton_step| > threshold and
                #       w_old + newton_step - sign(w_old + newton_step) * threshold
                #          == w_old
                #
                # In case (i) comparing delta to zero is fine.
                #
                # In case (ii), newton_step conceptually equals
                #     sign(w_old + newton_step) * threshold.
                # Also remember
                #     threshold = -newton_step / (approximation of d/dz|z=0 ULLSC).
                # So (ii) happens when
                #     (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step).
                # If we did not require LossSmoothComponent to be strictly convex,
                # then this could actually happen a non-negligible amount of the time,
                # e.g. if the loss function is piecewise linear and one of the pieces
                # has slope 1.  But since LossSmoothComponent is strictly convex, (ii)
                # should not systematically happen.
                tf.reduce_all(tf.equal(delta, 0.)),
                lambda: inputs_to_update,
                lambda: _do_update(*inputs_to_update))

        base_dtype = x_start.dtype.base_dtype
        # Loop variables, in order: iteration counter, squared norm of the
        # change in x_update within the current sweep, x_update itself, and
        # the running product of the Hessian with x_update.  All start at zero.
        iter_, x_update_diff_norm_sq, x_update, _ = tf.while_loop(
            cond=_loop_cond,
            body=_loop_body,
            loop_vars=[
                tf.zeros([], dtype=np.int32, name='iter'),
                tf.zeros(batch_shape,
                         dtype=base_dtype,
                         name='x_update_diff_norm_sq'),
                tf.zeros(update_shape, dtype=base_dtype, name='x_update'),
                tf.zeros(update_shape,
                         dtype=base_dtype,
                         name='hess_matmul_x_update'),
            ])

        # Convert back x_update to the shape of x_start by transposing the leftmost
        # dimension to the rightmost.
        n = tf.rank(x_update)
        perm = tf.roll(tf.range(n), shift=-1, axis=0)
        x_update = tf.transpose(a=x_update, perm=perm)

        converged = tf.reduce_all(
            x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold
        )
        # The third return value is the iteration count divided by the number
        # of coordinates, i.e. expressed in sweeps.
        return x_start + x_update, converged, iter_ / dims
def main(unused_args):
    """Runs the training loop for the energy function `u` and the sampler `q`.

    Creates the output directories and log file under `FLAGS.logdir`, builds
    the dataset, models and optimizers, restores a checkpoint when one exists
    and then alternates P and Q updates. Every `FLAGS.plot_steps` steps it
    writes sample plots, statistics and a log line; every `FLAGS.save_steps`
    steps it writes a checkpoint.

    Args:
      unused_args: Ignored; deleted immediately.

    Raises:
      ValueError: If `FLAGS.dataset`, `FLAGS.q_type`, `FLAGS.p_loss` or
        `FLAGS.q_loss` holds an unsupported value.
    """
    del unused_args

    #
    # General setup.
    #

    ebm_util.init_tf2()

    ebm_util.set_seed(FLAGS.seed)

    output_dir = FLAGS.logdir
    checkpoint_dir = os.path.join(output_dir, 'checkpoint')
    samples_dir = os.path.join(output_dir, 'samples')

    tf.io.gfile.makedirs(samples_dir)
    tf.io.gfile.makedirs(checkpoint_dir)

    log_f = tf.io.gfile.GFile(os.path.join(output_dir, 'log.out'), mode='w')
    logger = ebm_util.setup_logging('main', log_f, console=False)
    logger.info({k: v._value for (k, v) in FLAGS._flags().items()})  # pylint: disable=protected-access

    #
    # Data
    #

    if FLAGS.dataset == 'mnist':
        x_train = ebm_util.mnist_dataset(N_CH)
    elif FLAGS.dataset == 'celeba':
        x_train = ebm_util.celeba_dataset()
    else:
        raise ValueError(f'Unknown dataset. {FLAGS.dataset}')
    train_ds = tf.data.Dataset.from_tensor_slices(x_train).shuffle(
        10000).batch(FLAGS.batch_size)

    #
    # Models
    #

    if FLAGS.q_type == 'mean_field_gaussian':
        q = MeanFieldGaussianQ()
    else:
        # Without this branch `q` would be unbound and the failure would only
        # surface later as a NameError at the checkpoint construction.
        raise ValueError(f'Unknown Q type {FLAGS.q_type}')
    u = make_u()

    #
    # Optimizers
    #

    def lr_p(step):
        """Linearly decayed learning rate for the P optimizer."""
        lr = FLAGS.p_learning_rate * (1. - (step / (1.5 * FLAGS.train_steps)))
        return lr

    def lr_q(step):
        """Linearly decayed learning rate for the Q optimizer."""
        lr = FLAGS.q_learning_rate * (1. - (step / (1.5 * FLAGS.train_steps)))
        return lr

    opt_q = tf.optimizers.Adam(learning_rate=ebm_util.LambdaLr(lr_q))
    opt_p = tf.optimizers.Adam(learning_rate=ebm_util.LambdaLr(lr_p),
                               beta_1=FLAGS.p_adam_beta_1)

    #
    # Checkpointing
    #

    global_step_var = tf.Variable(0, trainable=False)
    checkpoint = tf.train.Checkpoint(opt_p=opt_p,
                                     opt_q=opt_q,
                                     u=u,
                                     q=q,
                                     global_step_var=global_step_var)

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')
    if tf.io.gfile.exists(checkpoint_path + '.index'):
        print(f'Restoring from {checkpoint_path}')
        checkpoint.restore(checkpoint_path)

    #
    # Stats initialization
    #

    stat_i = []
    stat_keys = [
        'E_pos',  # Mean energy of the positive samples.
        'E_neg_q',  # Mean energy of the negative samples (pre-HMC).
        'E_neg_p',  # Mean energy of the negative samples (post-HMC).
        'H',  # Entropy of Q (if known).
        'pd_pos',  # Pairwise differences of the positive samples.
        'pd_neg_q',  # Pairwise differences of the negative samples (pre-HMC).
        'pd_neg_p',  # Pairwise differences of the negative samples (post-HMC).
        'hmc_disp',  # L2 distance between initial and final HMC samples.
        'hmc_p_accept',  # HMC P(accept).
        'hmc_step_size',  # HMC step size.
        'x_neg_p_min',  # Minimum value of the negative samples (post-HMC).
        'x_neg_p_max',  # Maximum value of the negative samples (post-HMC).
        'time',  # Time taken to do the training step.
    ]
    stat = {k: [] for k in stat_keys}

    def array_to_str(a, fmt='{:>8.4f}'):
        """Formats every element of `a` with `fmt`, joined by spaces."""
        return ' '.join([fmt.format(v) for v in a])

    def stats_callback(step, entropy, pd_neg_q):
        """Placeholder hook; intentionally discards its arguments."""
        del step, entropy, pd_neg_q

    step_size = FLAGS.mcmc_step_size

    # Two fixed batches, used below for the latent-space interpolation plots.
    train_ds_iter = iter(train_ds)
    x_pos_1 = ebm_util.data_preprocess(next(train_ds_iter))
    x_pos_2 = ebm_util.data_preprocess(next(train_ds_iter))

    global_step = global_step_var.numpy()

    while global_step < (FLAGS.train_steps + 1):
        for x_pos in train_ds:

            # The `while` condition is only re-evaluated between epochs, so
            # stop explicitly once the requested number of steps has run.
            if global_step > FLAGS.train_steps:
                break

            # Drop partial batches.
            if x_pos.shape[0] != FLAGS.batch_size:
                continue

            #
            # Update
            #

            start_time = time.time()

            x_pos = ebm_util.data_preprocess(x_pos)
            x_pos = ebm_util.data_discrete_noise(x_pos)

            if FLAGS.p_loss == 'neutra_hmc':
                (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated,
                 neg_e_q, neg_e_p,
                 neg_e_p_updated) = train_p(q, u, x_pos, step_size, opt_p)
            elif FLAGS.p_loss == 'neutra_iid':
                (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated,
                 neg_e_q, neg_e_p,
                 neg_e_p_updated) = train_p_mh(q, u, x_pos, step_size, opt_p)
            else:
                raise ValueError(f'Unknown P loss {FLAGS.p_loss}')

            # Defaults for the quantities that only some Q losses produce.
            # The log line below formats all of them unconditionally, so they
            # must be bound whichever branch runs (and also when
            # `FLAGS.q_sub_steps` is 0).
            entropy = 0.0
            mle_loss = 0.0
            norm_grads_ebm = 0.0
            norm_grads_mle = 0.0

            if FLAGS.q_loss == 'forward_kl':
                train_q_fwd_kl(q, x_neg_p, opt_q)
            elif FLAGS.q_loss == 'reverse_kl':
                for _ in range(10):
                    _, entropy = train_q_rev_kl(q, u, opt_q)
            elif FLAGS.q_loss == 'reverse_kl_mle':
                for _ in range(FLAGS.q_sub_steps):
                    alpha = FLAGS.q_rkl_weight
                    (_, entropy, _, mle_loss, norm_grads_ebm,
                     norm_grads_mle) = train_q_rev_kl_mle(
                         q, u, x_pos, tf.convert_to_tensor(alpha), opt_q)
            elif FLAGS.q_loss == 'mle':
                mle_loss = train_q_mle(q, x_pos, opt_q)
            else:
                raise ValueError(f'Unknown Q loss {FLAGS.q_loss}')

            end_time = time.time()

            #
            # Stats
            #

            # Mean per-sample L2 distance between the pre- and post-HMC
            # negatives. The leading dimension is FLAGS.batch_size, matching
            # the reshapes used for the plots below.
            hmc_disp = tf.reduce_mean(
                tf.norm(tf.reshape(x_neg_q, [FLAGS.batch_size, -1]) -
                        tf.reshape(x_neg_p, [FLAGS.batch_size, -1]),
                        axis=1))

            if global_step % FLAGS.plot_steps == 0:

                # Positives + negatives.
                ebm_util.plot(
                    tf.reshape(ebm_util.data_postprocess(x_neg_q),
                               [FLAGS.batch_size, N_WH, N_WH, N_CH]),
                    os.path.join(samples_dir, f'x_neg_q_{global_step}.png'))
                ebm_util.plot(
                    tf.reshape(ebm_util.data_postprocess(x_neg_p),
                               [FLAGS.batch_size, N_WH, N_WH, N_CH]),
                    os.path.join(samples_dir, f'x_neg_p_{global_step}.png'))
                ebm_util.plot(
                    tf.reshape(ebm_util.data_postprocess(x_pos),
                               [FLAGS.batch_size, N_WH, N_WH, N_CH]),
                    os.path.join(samples_dir, f'x_pos_{global_step}.png'))

                # Samples for various temperatures.
                for t in [0.1, 0.5, 1.0, 2.0, 4.0]:
                    _, x_neg_q_t, _ = q.sample_with_log_prob(FLAGS.batch_size,
                                                             temp=t)
                    ebm_util.plot(
                        tf.reshape(ebm_util.data_postprocess(x_neg_q_t),
                                   [FLAGS.batch_size, N_WH, N_WH, N_CH]),
                        os.path.join(samples_dir,
                                     f'x_neg_t_{t}_{global_step}.png'))

                stats_callback(global_step, entropy,
                               ebm_util.nearby_difference(x_neg_q))

                stat_i.append(global_step)
                stat['E_pos'].append(pos_e_updated)
                stat['E_neg_q'].append(neg_e_q)
                stat['E_neg_p'].append(neg_e_p)
                stat['H'].append(entropy)
                stat['pd_neg_q'].append(ebm_util.nearby_difference(x_neg_q))
                stat['pd_neg_p'].append(ebm_util.nearby_difference(x_neg_p))
                stat['pd_pos'].append(ebm_util.nearby_difference(x_pos))
                stat['hmc_disp'].append(hmc_disp)
                stat['hmc_p_accept'].append(p_accept)
                stat['hmc_step_size'].append(step_size)
                stat['x_neg_p_min'].append(tf.reduce_min(x_neg_p))
                stat['x_neg_p_max'].append(tf.reduce_max(x_neg_p))
                stat['time'].append(end_time - start_time)

                ebm_util.plot_stat(stat_keys, stat, stat_i, output_dir)

                # Doing a linear interpolation in the latent space.
                z_pos_1 = q.forward(x_pos_1)[0]
                z_pos_2 = q.forward(x_pos_2)[0]

                x_alphas = []
                n_steps = 10
                for j in range(0, n_steps + 1):
                    alpha = (j / n_steps)
                    z_alpha = (1. - alpha) * z_pos_1 + (alpha) * z_pos_2
                    x_alpha = q.reverse(z_alpha)[0]
                    x_alphas.append(x_alpha)

                ebm_util.plot_n_by_m(
                    ebm_util.data_postprocess(
                        tf.reshape(tf.stack(x_alphas, axis=1), [
                            (n_steps + 1) * FLAGS.batch_size, N_WH, N_WH, N_CH
                        ])),
                    os.path.join(samples_dir, f'x_alpha_{global_step}.png'),
                    FLAGS.batch_size, n_steps + 1)

                # Doing random perturbations in the latent space.
                for eps in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 2.5e0, 3e0]:
                    z_pos_2_eps = z_pos_2 + eps * tf.random.normal(
                        z_pos_2.shape)
                    x_alpha = q.reverse(z_pos_2_eps)[0]
                    ebm_util.plot(
                        tf.reshape(ebm_util.data_postprocess(x_alpha),
                                   [FLAGS.batch_size, N_WH, N_WH, N_CH]),
                        os.path.join(samples_dir,
                                     f'x_alpha_eps_{eps}_{global_step}.png'))

                # Checking the log-probabilities of positive and negative
                # examples under Q.
                z_neg_test, x_neg_test, _ = q.sample_with_log_prob(
                    FLAGS.batch_size, temp=FLAGS.q_temperature)
                z_pos_test = q.forward(x_pos)[0]

                z_neg_test_pd = ebm_util.nearby_difference(z_neg_test)
                z_pos_test_pd = ebm_util.nearby_difference(z_pos_test)

                z_norms_neg = tf.reduce_mean(tf.norm(z_neg_test, axis=1))
                z_norms_pos = tf.reduce_mean(tf.norm(z_pos_test, axis=1))

                log_prob_neg = tf.reduce_mean(q.log_prob(x_neg_test))
                log_prob_pos = tf.reduce_mean(q.log_prob(x_pos))

                logger.info('  '.join([
                    f'i={global_step:6d}',
                    # Pre-update, post-update
                    (f'E_pos=[{pos_e:10.4f} {pos_e_updated:10.4f} ' +
                     f'{pos_e_updated - pos_e:10.4f}]'),
                    # Pre-update pre-HMC, pre-update post-HMC, post-update post-HMC
                    (f'E_neg=[{neg_e_q:10.4f} {neg_e_p:10.4f} ' +
                     f'{neg_e_p_updated:10.4f} {neg_e_p_updated - neg_e_p:10.4f}]'
                     ),
                    f'mle={tf.reduce_mean(mle_loss):8.4f}',
                    f'H={entropy:8.4f}',
                    f'norm_grads_ebm={norm_grads_ebm:8.4f}',
                    f'norm_grads_mle={norm_grads_mle:8.4f}',
                    f'pd(x_pos)={ebm_util.nearby_difference(x_pos):8.4f}',
                    f'pd(x_neg_q)={ebm_util.nearby_difference(x_neg_q):8.4f}',
                    f'pd(x_neg_p)={ebm_util.nearby_difference(x_neg_p):8.4f}',
                    f'hmc_disp={hmc_disp:8.4f}',
                    f'p(accept)={p_accept:8.4f}',
                    f'step_size={step_size:8.4f}',
                    # Min, max.
                    (f'x_neg_q=[{tf.reduce_min(x_neg_q):8.4f} ' +
                     f'{tf.reduce_max(x_neg_q):8.4f}]'),
                    (f'x_neg_p=[{tf.reduce_min(x_neg_p):8.4f} ' +
                     f'{tf.reduce_max(x_neg_p):8.4f}]'),
                    f'z_neg_norm={array_to_str(z_norms_neg)}',
                    f'z_pos_norm={array_to_str(z_norms_pos)}',
                    f'z_neg_test_pd={z_neg_test_pd:>8.2f}',
                    f'z_pos_test_pd={z_pos_test_pd:>8.2f}',
                    f'log_prob_neg={log_prob_neg:12.2f}',
                    f'log_prob_pos={log_prob_pos:12.2f}',
                ]))

            if global_step % FLAGS.save_steps == 0:

                # Persist the step counter alongside the models so a restored
                # run resumes from the same step.
                global_step_var.assign(global_step)
                checkpoint.write(checkpoint_path)

            global_step += 1
Exemple #6
0
 def _sample_n(self, n, seed=None):
   """Draws `n` normal samples and rescales each to unit L2 norm.

   The samples have shape `[n] + batch_shape + [dimension]`; normalisation
   is done along the last axis.
   """
   sample_shape = ps.concat([[n], self.batch_shape, [self.dimension]], axis=0)
   gaussian_draws = samplers.normal(
       shape=sample_shape, seed=seed, dtype=self.dtype)
   # Re-add the reduced axis so the norms broadcast against the draws.
   lengths = tf.expand_dims(tf.norm(gaussian_draws, ord=2, axis=-1), axis=-1)
   return gaussian_draws / lengths
Exemple #7
0
def train_model(model,
                ds_train,
                ds_test,
                logdir,
                total_steps=5000,
                batch_size=128,
                val_batch_size=1000,
                save_freq=5,
                log_freq=250,
                use_metainit=False,
                oneshot_prune_fraction=0.,
                gradient_regularization=0):
    """Training of the CNN on MNIST.

    Args:
      model: Keras model to train; its pruning-wrapped layers are inspected
        for masks.
      ds_train: Training dataset of `(x, y)` elements. The first
        `2 * val_batch_size` elements are also used as two validation batches.
      ds_test: Dataset passed to `test_model` when logging.
      logdir: Directory for checkpoints; summaries go to `logdir/train_logs`.
      total_steps: Training stops once the optimizer's iteration counter
        reaches this value.
      batch_size: Batch size used for the training stream.
      val_batch_size: Size of each of the two validation batches.
      save_freq: A checkpoint is saved whenever
        `global_step % save_freq == 0`.
      log_freq: `log_fn` runs whenever `global_step % log_freq == 0`.
      use_metainit: If True, runs `metainit.meta_init` before training.
      oneshot_prune_fraction: If positive, prunes this fraction once before
        training; requires a mask updater.
      gradient_regularization: Coefficient of the squared gradient-norm
        penalty added to the loss; 0 disables it.

    Returns:
      The trained `model`.

    Raises:
      ValueError: If `oneshot_prune_fraction > 0` but no mask updater is
        configured.
    """
    logging.info('Writing training logs to %s', logdir)
    writer = tf.summary.create_file_writer(os.path.join(logdir, 'train_logs'))
    optimizer = utils.get_optimizer(total_steps)
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    train_batch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_batch_accuracy')
    # Let's create 2 disjoint validation sets. They are the first two batches
    # of `val_batch_size` elements taken from `ds_train`.
    (val_x, val_y), (val2_x, val2_y) = list(
        ds_train.take(val_batch_size * 2).batch(val_batch_size))

    def loss_fn(x, y):
        """Returns the data loss plus the model's regularization losses."""
        predictions = model(x, training=True)
        reg_loss = tf.add_n(model.losses) if model.losses else 0
        return loss_object(y, predictions) + reg_loss

    mask_updater = mask_updaters.get_mask_updater(model, optimizer, loss_fn)
    if mask_updater:
        # The mask updater gets the second validation batch; the first one
        # (val_x, val_y) is used for the gradient-norm logging below.
        mask_updater.set_validation_data(val2_x, val2_y)
    update_prune_step(model, 0)
    if oneshot_prune_fraction > 0:
        logging.info('Running one shot prunning at the beginning.')
        if not mask_updater:
            raise ValueError(
                'mask_updater does not exists. Please set '
                'mask_updater.update_alg flag for one shot pruning.')
        mask_updater.prune(oneshot_prune_fraction)
    if use_metainit:
        # Count the entries of all masks and hand the total to meta_init.
        n_params = 0
        for layer in model.layers:
            if isinstance(layer, utils.PRUNING_WRAPPER):
                for _, mask, _ in layer.pruning_vars:
                    n_params += tf.reduce_sum(mask)
        metainit.meta_init(model,
                           loss_object, (128, 28, 28, 1), (128, 10),
                           n_params,
                           mask_gradient_fn=mask_gradients)
    # This is used to calculate some distances, would give incorrect results when
    # we restart the training.
    initial_params = [var.numpy() for var in model.trainable_variables]

    # Create the checkpoint object and restore if there is a checkpoint in the
    # folder.
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt,
                                              directory=logdir,
                                              max_to_keep=None)
    if ckpt_manager.latest_checkpoint:
        logging.info('Restored from %s', ckpt_manager.latest_checkpoint)
        ckpt.restore(ckpt_manager.latest_checkpoint)
        is_restored = True
    else:
        logging.info('Starting from scratch.')
        is_restored = False
    # Obtain global_step after loading checkpoint.
    global_step = optimizer.iterations
    tf.summary.experimental.set_step(global_step)
    trainable_vars = model.trainable_variables

    def get_gradients(x, y, log_batch_gradient=False, is_regularized=True):
        """Gets sparse gradients and possibly logs some statistics.

        Args:
          x: Input batch.
          y: Label batch.
          log_batch_gradient: If True, writes loss/accuracy summaries for
            this batch.
          is_regularized: If False, neither the model's regularization losses
            nor the gradient-norm penalty are added to the loss.

        Returns:
          Masked gradients of the loss w.r.t. `trainable_vars`.
        """
        is_grad_regularized = gradient_regularization != 0
        # A persistent tape is needed because the gradient penalty calls
        # `tape.gradient` once inside the context and once after it.
        with tf.GradientTape(persistent=is_grad_regularized) as tape:
            predictions = model(x, training=True)
            batch_loss = loss_object(y, predictions)
            if is_regularized and is_grad_regularized:
                gradients = tape.gradient(batch_loss, trainable_vars)
                gradients = mask_gradients(model, gradients, trainable_vars)
                grad_vec = flatten_list_of_vars(gradients)
                batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization
            # Regularization might have been disabled.
            reg_loss = tf.add_n(model.losses) if model.losses else 0
            if is_regularized:
                batch_loss += reg_loss
        gradients = tape.gradient(batch_loss, trainable_vars)
        # Gradients are dense, we should mask them to ensure updates are sparse;
        # So is the norm calculation.
        gradients = mask_gradients(model, gradients, trainable_vars)
        # If batch gradient log it.
        if log_batch_gradient:
            tf.summary.scalar('train_batch_loss', batch_loss)
            tf.summary.scalar('train_batch_reg_loss', reg_loss)
            train_batch_accuracy.update_state(y, predictions)
            tf.summary.scalar('train_batch_accuracy',
                              train_batch_accuracy.result())
            train_batch_accuracy.reset_states()
        return gradients

    def log_fn():
        """Writes test metrics, gradient/momentum norms, distances and masks."""
        logging.info('Logging at iter: %d', global_step.numpy())
        log_sparsities(model)
        test_loss, test_acc = test_model(model, ds_test)
        tf.summary.scalar('test_loss', test_loss)
        tf.summary.scalar('test_acc', test_acc)
        # Log gradient norm.
        # We want to obtain/log gradients without regularization term.
        gradients = get_gradients(val_x,
                                  val_y,
                                  log_batch_gradient=False,
                                  is_regularized=False)
        for var, grad in zip(trainable_vars, gradients):
            tf.summary.scalar(f'gradnorm/{var.name}', tf.norm(grad))
        # Log all gradients together
        all_norm = tf.norm(flatten_list_of_vars(gradients))
        tf.summary.scalar('.allparams/gradnorm', all_norm)
        # Log momentum values:
        for s_name in optimizer.get_slot_names():
            # Currently we only log momentum.
            if s_name not in ['momentum']:
                continue
            all_slots = [
                optimizer.get_slot(var, s_name) for var in trainable_vars
            ]
            all_norm = tf.norm(flatten_list_of_vars(all_slots))
            tf.summary.scalar(f'.allparams/norm_{s_name}', all_norm)
        # Log distance to init.
        for initial_val, val in zip(initial_params, model.trainable_variables):
            tf.summary.scalar(f'dist_init_l2/{val.name}',
                              tf.norm(initial_val - val))
            cos_distance = cosine_distance(initial_val, val)
            tf.summary.scalar(f'dist_init_cosine/{val.name}', cos_distance)
        # Mask update logs:
        if mask_updater:
            tf.summary.scalar('drop_fraction', mask_updater.last_drop_fraction)
        # Log all distances together.
        flat_initial = flatten_list_of_vars(initial_params)
        flat_current = flatten_list_of_vars(model.trainable_variables)
        tf.summary.scalar('.allparams/dist_init_l2/',
                          tf.norm(flat_initial - flat_current))
        tf.summary.scalar('.allparams/dist_init_cosine/',
                          cosine_distance(flat_initial, flat_current))
        # Log masks
        for layer in model.layers:
            if isinstance(layer, utils.PRUNING_WRAPPER):
                for _, mask, _ in layer.pruning_vars:
                    tf.summary.image('mask/%s' % mask.name, var_to_img(mask))
        writer.flush()

    def save_fn(step=None):
        """Saves a checkpoint numbered `step`, or `global_step` if omitted."""
        # Compare against None rather than relying on truthiness: an explicit
        # step of 0 (e.g. `-global_step + 1` at step 1) must be honoured.
        save_step = step if step is not None else global_step
        saved_ckpt = ckpt_manager.save(checkpoint_number=save_step)
        logging.info('Saved checkpoint: %s', saved_ckpt)

    with writer.as_default():
        for x, y in ds_train.repeat().shuffle(
                buffer_size=60000).batch(batch_size):
            if global_step >= total_steps:
                logging.info('Total steps: %d is completed',
                             global_step.numpy())
                save_fn()
                break
            update_prune_step(model, global_step)
            if tf.equal(global_step, 0):
                logging.info('Seed: %s First 10 Label: %s', FLAGS.seed, y[:10])
            if global_step % save_freq == 0:
                # If just loaded, don't save it again.
                if is_restored:
                    is_restored = False
                else:
                    save_fn()
            if global_step % log_freq == 0:
                log_fn()
            gradients = get_gradients(x, y, log_batch_gradient=True)
            tf.summary.scalar('lr', optimizer.lr(global_step))
            optimizer.apply_gradients(zip(gradients, trainable_vars))
            if mask_updater and mask_updater.is_update_iter(global_step):
                # Save the network before mask_update, we want to use negative integers
                # for this.
                save_fn(step=(-global_step + 1))
                # Gradient norm before.
                gradients = get_gradients(val_x,
                                          val_y,
                                          log_batch_gradient=False,
                                          is_regularized=False)
                norm_before = tf.norm(flatten_list_of_vars(gradients))
                results = mask_updater.update(global_step)
                # Save network again
                save_fn(step=-global_step)
                if results:
                    for mask_name, drop_frac in results.items():
                        tf.summary.scalar('drop_fraction/%s' % mask_name,
                                          drop_frac)

                # Gradient norm after mask update.
                gradients = get_gradients(val_x,
                                          val_y,
                                          log_batch_gradient=False,
                                          is_regularized=False)
                norm_after = tf.norm(flatten_list_of_vars(gradients))
                tf.summary.scalar('.allparams/gradnorm_mask_update_improvment',
                                  norm_after - norm_before)

        logging.info('Performance after training:')
        log_fn()
    return model
 def compute_norm(self, x):
     """Returns the sum over rows of the cubed L2 norm taken along axis 1."""
     row_norms = tf.norm(x, ord=2, axis=1)
     cubed = row_norms**3
     return tf.reduce_sum(cubed)
Exemple #9
0
 def _gradients_order2_norm(self, gradients):
     """Returns the norm of the vector of per-gradient norms.

     Entries of `gradients` that are `None` are ignored.
     """
     per_gradient_norms = []
     for grad in gradients:
         if grad is None:
             continue
         per_gradient_norms.append(tf.norm(grad))
     return tf.norm(tf.stack(per_gradient_norms))
Exemple #10
0
    def prune_one_unit(self,
                       pruning_pool,
                       baselines=None,
                       normalized_scores=True,
                       pruning_method=None,
                       is_bp=None):
        """Picks a layer and prunes a single unit using the scoring function.

        Args:
          pruning_pool: list, of layers that are considered for pruning.
          baselines: dict, if exists, subtracts the given constant from the
            scores of individual layers. The keys should be a subset of
            pruning_pool.
          normalized_scores: bool, if True the scores are normalized with l2
            norm.
          pruning_method: str, one of ALL_SCORING_FUNCTIONS (e.g. 'norm',
            'mrs', 'rs', 'rand', 'abs_mrs'). If given, overwrites the default
            value.
          is_bp: bool, if True Mean Replacement Pruning is used and bias
            propagation is made. If given (including an explicit False),
            overwrites the default value.

        Raises:
          ValueError: if `pruning_method` is not one of ALL_SCORING_FUNCTIONS.
        """
        pruning_method = pruning_method if pruning_method else self.pruning_method
        # Test against None, not truthiness: an explicit `is_bp=False` must
        # override `self.is_bp`.
        is_bp = self.is_bp if is_bp is None else is_bp
        if pruning_method not in ALL_SCORING_FUNCTIONS:
            raise ValueError('%s is not one of %s' %
                             (pruning_method, ALL_SCORING_FUNCTIONS))
        if baselines is None:
            baselines = {}
        logging.info('Prunning with: %s, is_bp: %s', pruning_method, is_bp)

        # Calculating the scoring function/mean value.
        is_abs = pruning_method.startswith('abs')
        is_mrs = pruning_method.endswith('mrs')
        is_rs = pruning_method.endswith('rs') and not is_mrs
        is_grad = is_mrs or is_rs
        # Called for its side effects; the per-layer values are read back
        # below through `get_saved_values`.
        train_utils.cross_entropy_loss(
            self.model,
            self.subset_val,
            training=False,
            compute_mean_replacement_saliency=is_mrs,
            compute_removal_saliency=is_rs,
            is_abs=is_abs,
            aggregate_values=True,
            run_gradient=is_grad)
        scores = {}
        mean_values = {}
        smallest_score = None
        smallest_l_name = None
        smallest_nprune = None

        for l_name in pruning_pool:
            l_ts = getattr(self.model, l_name + '_ts')
            l = getattr(self.model, l_name)
            mean_values[l_name] = l_ts.get_saved_values('mean')
            # Make sure the masks are applied after last gradient update. Note
            # that this is necessary for `norm` functions, since it doesn't call the
            # model and therefore the masks are not applied.
            l.apply_masks()
            if pruning_method == 'rand':
                scores[l_name] = unitscorers.random_score(
                    l.get_layer().weights[0])
            elif pruning_method == 'norm':
                scores[l_name] = unitscorers.norm_score(
                    l.get_layer().weights[0])
            else:
                # mrs or rs.
                score_name = 'rs' if is_rs else 'mrs'
                scores[l_name] = l_ts.get_saved_values(score_name)
            if normalized_scores:
                scores[l_name] /= tf.norm(scores[l_name])
            baseline_score = baselines.get(l_name, 0)
            if baseline_score != 0:
                # Regularizing the scores with c_flop weights.
                scores[l_name] -= baseline_score
            # If there is an existing mask we have to make sure pruned connections
            # are indicated. Let's set them to very small negative number (-1e10).
            # Note that the elements of `l.mask_bias` consist of zeros and ones only.
            if l.mask_bias is not None:
                # Setting the scores of the pruned units to zero.
                scores[l_name] = scores[l_name] * l.mask_bias
                # Setting the scores of the pruned units to -1e10.
                scores[l_name] += -1e10 * (1 - l.mask_bias)
                # Number of previously pruned units.
                n_pruned = tf.count_nonzero(l.mask_bias - 1).numpy()
                # Minimum is taken over the still-active units only.
                layer_smallest_score = tf.reduce_min(
                    tf.boolean_mask(scores[l_name], l.mask_bias)).numpy()
                # Do not prune the last unit.
                if tf.equal(n_pruned + 1, tf.size(l.mask_bias)):
                    continue
            else:
                n_pruned = 0
                layer_smallest_score = tf.reduce_min(scores[l_name]).numpy()

            logging.info('Layer:%s, min:%f', l_name, layer_smallest_score)
            if smallest_score is None or (layer_smallest_score <
                                          smallest_score):
                smallest_score = layer_smallest_score
                smallest_l_name = l_name
                # We want to prune one more than before.
                smallest_nprune = n_pruned + 1
        # NOTE(review): if `pruning_pool` is empty or every layer was skipped
        # above, `smallest_l_name` is still None here and the dictionary
        # lookups below raise KeyError.
        logging.info('UNIT_PRUNED, layer:%s, n_pruned:%d', smallest_l_name,
                     smallest_nprune)
        # Restrict everything to the single selected layer.
        mean_values = {smallest_l_name: mean_values[smallest_l_name]}
        scores = {smallest_l_name: scores[smallest_l_name]}
        input_shapes = {
            smallest_l_name: getattr(self.model,
                                     smallest_l_name + '_ts').xshape
        }
        layers2prune = [smallest_l_name]
        prune_model_with_scores(self.model, scores, is_bp, layers2prune, None,
                                smallest_nprune, mean_values, input_shapes)
def train_step_black_box(data,
                         labels_one_hot,
                         samples,
                         weights,
                         _lambda,
                         trainable=tf.constant(False)):
    """Applies one optimizer step to `black_box` on a blend of two gradients.

    Two losses are differentiated w.r.t. `black_box.trainable_variables`: a
    Monte-Carlo expectation of a KL term over `samples` (the "share loss")
    and the categorical cross-entropy on `(data, labels_one_hot)`. The
    applied gradient is `tau * g_share + (1 - tau) * g_ce`, where `tau` is
    `g_ce . (g_ce - g_share) / ||g_share - g_ce||^2` clipped to
    `[0, 1 - 1 / (1 + _lambda)]`.

    Args:
      data: Inputs for the cross-entropy term.
      labels_one_hot: One-hot targets for the cross-entropy term.
      samples: Samples over which the share loss expectation is taken.
      weights: Forwarded to the share loss; not used by it.
      _lambda: Controls the upper clip of `tau`.
      trainable: Forwarded to `black_box` inside the share loss.

    Returns:
      Tuple `(cross_entropy, share_loss, tau)`.
    """
    print("----Tracing--train_step_black_box")

    @tf.function
    def compute_share_loss(X, weights):
        print("----Tracing---share_loss")

        def kl_divergence(x_d):
            print("---Tracing the KL")
            kl = tf.keras.losses.KLDivergence()
            return kl(tf.exp(model.compute_log_conditional_distribution(x_d)),
                      black_box(x_d, trainable=trainable))

        return tfp.monte_carlo.expectation(f=kl_divergence,
                                           samples=X,
                                           log_prob=model.log_pdf,
                                           use_reparametrization=False)

    with tf.GradientTape() as tape1:
        print("--tracing-gradient_persistent")
        # The helper has its own name so this result does not rebind it.
        share_loss = compute_share_loss(X=samples, weights=weights)

    with tf.GradientTape() as tape2:
        cross_ent = tf.keras.losses.CategoricalCrossentropy()
        logits = black_box(data, trainable=tf.constant(True))
        cross_entropy = cross_ent(labels_one_hot, logits)
    gradients1 = tape1.gradient(share_loss, black_box.trainable_variables)
    gradients2 = tape2.gradient(cross_entropy, black_box.trainable_variables)

    numerator = tf.constant(0.0)
    denominator = tf.constant(0.0)

    for grads1, grads2 in zip(gradients1, gradients2):
        numerator = numerator + tf.reduce_sum(grads2 * grads2 -
                                              grads1 * grads2)
        denominator = denominator + tf.norm(grads1 - grads2)**2
    qiota = 1. - 1. / (1. + _lambda)
    # When both gradient sets coincide the denominator is 0; a plain division
    # would yield NaN and push NaNs into the variables. `divide_no_nan`
    # returns 0 instead, which is harmless because the blend is then the same
    # for every tau.
    ratio = tf.math.divide_no_nan(numerator, denominator)
    tau = tf.math.maximum(tf.math.minimum(ratio, qiota), 0.0)
    gradients = [
        tau * grads1 + (1 - tau) * grads2
        for grads1, grads2 in zip(gradients1, gradients2)
    ]
    tf.print("Tau param: ", tau)
    optimizer_black_box.apply_gradients(
        zip(gradients, black_box.trainable_variables))

    return cross_entropy, share_loss, tau
Exemple #12
0
def error_ratio(backward_difference, error_coefficient, tol):
    """Returns the norm of the scaled error divided by the tolerance.

    `tol` is first cast to the dtype of `backward_difference`; the result is
    cast to the dtype that `tf.abs(backward_difference)` has.
    """
    state_dtype = backward_difference.dtype
    scaled_error = (error_coefficient * backward_difference /
                    tf.cast(tol, state_dtype))
    ratio = tf.norm(scaled_error)
    result_dtype = tf.abs(backward_difference).dtype
    return tf.cast(ratio, result_dtype)
 def while_loop_condition(iteration, eigenvector, old_eigenvector):
     """Returns true while the loop should keep running.

     The loop continues only if the iteration budget is not exhausted and
     the last update moved the eigenvector by more than `epsilon`.
     """
     step_change = tf.norm(eigenvector - old_eigenvector)
     return tf.logical_and(iteration < maximum_iterations,
                           step_change > epsilon)
def _maximal_eigenvector_power_method(matrix,
                                      epsilon=1e-6,
                                      maximum_iterations=100):
    """Returns a maximal right-eigenvector of "matrix" using the power method.

  Args:
    matrix: 2D Tensor, the matrix of which we will find a maximal
      right-eigenvector.
    epsilon: non-negative float, if two iterations of the power method differ
      (in L2 norm) by no more than epsilon, we will terminate.
    maximum_iterations: non-negative int, if we perform this many iterations, we
      will terminate.

  Returns:
    A maximal right-eigenvector of "matrix".

  Raises:
    TypeError: if the "matrix" `Tensor` is not floating-point.
    ValueError: if the "epsilon" or "maximum_iterations" parameters violate
      their bounds.
  """
    if not matrix.dtype.is_floating:
        raise TypeError("multipliers must have a floating-point dtype")
    if epsilon <= 0.0:
        raise ValueError("epsilon must be strictly positive")
    if maximum_iterations <= 0:
        raise ValueError("maximum_iterations must be strictly positive")

    def while_loop_condition(iteration, eigenvector, old_eigenvector):
        """Returns false if the while loop should terminate."""
        not_done = (iteration < maximum_iterations)
        not_converged = (tf.norm(eigenvector - old_eigenvector) > epsilon)
        return tf.logical_and(not_done, not_converged)

    def while_loop_body(iteration, eigenvector, old_eigenvector):
        """Performs one iteration of the power method."""
        del old_eigenvector  # Needed by the condition, but not the body (for lint).
        iteration += 1
        # We need to use tf.matmul() and tf.expand_dims(), instead of
        # tf.tensordot(), since the former will infer the shape of the result, while
        # the latter will not (tf.while_loop() needs the shapes).
        new_eigenvector = tf.matmul(matrix, tf.expand_dims(eigenvector, 1))[:,
                                                                            0]
        new_eigenvector /= tf.norm(new_eigenvector)
        return (iteration, new_eigenvector, eigenvector)

    iteration = tf.constant(0)
    eigenvector = tf.ones_like(matrix[:, 0])
    eigenvector /= tf.norm(eigenvector)

    # We actually want a do-while loop, so we explicitly call while_loop_body()
    # once before tf.while_loop().
    iteration, eigenvector, old_eigenvector = while_loop_body(
        iteration, eigenvector, eigenvector)
    iteration, eigenvector, old_eigenvector = tf.while_loop(
        while_loop_condition,
        while_loop_body,
        loop_vars=(iteration, eigenvector, old_eigenvector),
        name="power_method")

    return eigenvector
Exemple #15
0
def minimize(value_and_gradients_function,
             initial_position,
             tolerance=1e-8,
             x_tolerance=0,
             f_relative_tolerance=0,
             max_iterations=50,
             parallel_iterations=1,
             stopping_condition=None,
             params=None,
             name=None):
  """Minimizes a differentiable function.

  Implementation of the algorithm described in [HZ2006]. The updated formula
  for the next search direction was taken from [HZ2013].

  Supports batches with 1-dimensional batch shape.

  ### References:
  [HZ2006] Hager, William W., and Hongchao Zhang. "Algorithm 851: CG_DESCENT,
    a conjugate gradient method with guaranteed descent."
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf
  [HZ2013] W. W. Hager and H. Zhang (2013) The limited memory conjugate gradient
    method.
    https://pdfs.semanticscholar.org/8769/69f3911777e0ff0663f21b67dff30518726b.pdf

  ### Usage:
  The following example demonstrates this optimizer attempting to find the
  minimum for a simple two dimensional quadratic objective function.

  ```python
    minimum = np.array([1.0, 1.0])  # The center of the quadratic bowl.
    scales = np.array([2.0, 3.0])  # The scales along the two axes.

    # The objective function and the gradient.
    def quadratic(x):
      value = tf.reduce_sum(scales * (x - minimum) ** 2)
      return value, tf.gradients(value, x)[0]

    start = tf.constant([0.6, 0.8])  # Starting point for the search.
    optim_results = conjugate_gradient.minimize(
        quadratic, initial_position=start, tolerance=1e-8)

    with tf.Session() as session:
      results = session.run(optim_results)
      # Check that the search converged
      assert(results.converged)
      # Check that the argmin is close to the actual value.
      np.testing.assert_allclose(results.position, minimum)
  ```

  Args:
    value_and_gradients_function:  A Python callable that accepts a point as a
      real `Tensor` and returns a tuple of `Tensor`s of real dtype containing
      the value of the function and its gradient at that point. The function to
      be minimized. The input should be of shape `[..., n]`, where `n` is the
      size of the domain of input points, and all others are batching
      dimensions. The first component of the return value should be a real
      `Tensor` of matching shape `[...]`. The second component (the gradient)
      should also be of shape `[..., n]` like the input value to the function.
    initial_position: Real `Tensor` of shape `[..., n]`. The starting point, or
      points when using batching dimensions, of the search procedure. At these
      points the function value and the gradient norm should be finite.
    tolerance: Scalar `Tensor` of real dtype. Specifies the gradient tolerance
      for the procedure. If the supremum norm of the gradient vector is below
      this number, the algorithm is stopped.
    x_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the
      position between one iteration and the next is smaller than this number,
      the algorithm is stopped.
    f_relative_tolerance: Scalar `Tensor` of real dtype. If the relative change
      in the objective value between one iteration and the next is smaller than
      this value, the algorithm is stopped.
    max_iterations: Scalar positive int32 `Tensor`. The maximum number of
      iterations.
    parallel_iterations: Positive integer. The number of iterations allowed to
      run in parallel.
    stopping_condition: (Optional) A Python function that takes as input two
      Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor. The
      input tensors are `converged` and `failed`, indicating the current status
      of each respective batch member; the return value states whether the
      algorithm should stop. The default is tfp.optimizer.converged_all which
      only stops when all batch members have either converged or failed. An
      alternative is tfp.optimizer.converged_any which stops as soon as one
      batch member has converged, or when all have failed.
    params: ConjugateGradientParams object with adjustable parameters of the
      algorithm. If not supplied, default parameters will be used.
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'minimize' is used.

  Returns:
    optimizer_results: A namedtuple containing the following items:
      converged: boolean tensor of shape `[...]` indicating for each batch
        member whether the minimum was found within tolerance.
      failed:  boolean tensor of shape `[...]` indicating for each batch
        member whether a line search step failed to find a suitable step size
        satisfying Wolfe conditions. In the absence of any constraints on the
        number of objective evaluations permitted, this value will
        be the complement of `converged`. However, if there is
        a constraint and the search stopped due to available
        evaluations being exhausted, both `failed` and `converged`
        will be simultaneously False.
      num_iterations: The number of iterations of the main loop that were
        performed.
      num_objective_evaluations: The total number of objective
        evaluations performed.
      position: A tensor of shape `[..., n]` containing the last argument value
        found during the search from each starting point. If the search
        converged, then this value is the argmin of the objective function.
      objective_value: A tensor of shape `[...]` with the value of the
        objective function at the `position`. If the search converged, then
        this is the (local) minimum of the objective function.
      objective_gradient: A tensor of shape `[..., n]` containing the gradient
        of the objective function at the `position`. If the search converged
        the max-norm of this tensor should be below the tolerance.

  """
  with tf.compat.v1.name_scope(name, 'minimize', [initial_position, tolerance]):
    if params is None:
      params = ConjugateGradientParams()

    # All tolerances and algorithm parameters are converted to the dtype of
    # the initial position.
    initial_position = tf.convert_to_tensor(
        value=initial_position, name='initial_position')
    dtype = initial_position.dtype
    tolerance = tf.convert_to_tensor(
        value=tolerance, dtype=dtype, name='grad_tolerance')
    f_relative_tolerance = tf.convert_to_tensor(
        value=f_relative_tolerance, dtype=dtype, name='f_relative_tolerance')
    x_tolerance = tf.convert_to_tensor(
        value=x_tolerance, dtype=dtype, name='x_tolerance')
    max_iterations = tf.convert_to_tensor(
        value=max_iterations, name='max_iterations')
    stopping_condition = stopping_condition or converged_all
    delta = tf.convert_to_tensor(
        params.sufficient_decrease_param, dtype=dtype, name='delta')
    sigma = tf.convert_to_tensor(
        params.curvature_param, dtype=dtype, name='sigma')
    # Named 'eps' (not 'sigma') so this op is not confused with the curvature
    # parameter created just above.
    eps = tf.convert_to_tensor(
        params.threshold_use_approximate_wolfe_condition,
        dtype=dtype,
        name='eps')
    eta = tf.convert_to_tensor(
        params.direction_update_param, dtype=dtype, name='eta')
    psi_1 = tf.convert_to_tensor(
        params.initial_guess_small_factor, dtype=dtype, name='psi_1')
    psi_2 = tf.convert_to_tensor(
        params.initial_guess_step_multiplier, dtype=dtype, name='psi_2')

    f0, df0 = value_and_gradients_function(initial_position)
    # NOTE: this initial check uses the L2 norm (tf.norm default) with a strict
    # inequality, whereas the in-loop check below uses `_norm_inf` with `<=`.
    converged = tf.norm(df0, axis=-1) < tolerance

    initial_state = _OptimizerState(
        converged=converged,
        failed=tf.zeros_like(converged),  # All false.
        num_iterations=tf.convert_to_tensor(value=0),
        num_objective_evaluations=tf.convert_to_tensor(value=1),
        position=initial_position,
        objective_value=f0,
        objective_gradient=df0,
        direction=-df0,  # First search direction is steepest descent.
        prev_step=tf.ones_like(f0),
    )

    def _cond(state):
      """Continue if iterations remain and stopping condition is not met."""
      return (
          (state.num_iterations < max_iterations)
          & tf.logical_not(stopping_condition(state.converged, state.failed)))

    def _body(state):
      """Main optimization loop."""
      # We use notation of [HZ2006] for brevity.
      x_k = state.position
      d_k = state.direction
      f_k = state.objective_value
      g_k = state.objective_gradient
      a_km1 = state.prev_step  # Means a_{k-1}.

      # Define scalar function, which is objective restricted to direction.
      def ls_func(alpha):
        pt = x_k + tf.expand_dims(alpha, axis=-1) * d_k
        objective_value, gradient = value_and_gradients_function(pt)
        return ValueAndGradient(
            x=alpha,
            f=objective_value,
            df=_dot(gradient, d_k),
            full_gradient=gradient)

      # Generate initial guess for line search.
      # [HZ2006] suggests to generate first initial guess separately, but
      # [JuliaLineSearches] generates it as if previous step length was 1, and
      # we do the same.
      phi_0 = f_k
      dphi_0 = _dot(g_k, d_k)
      ls_val_0 = ValueAndGradient(
          x=tf.zeros_like(phi_0), f=phi_0, df=dphi_0, full_gradient=g_k)
      step_guess_result = _init_step(ls_val_0, a_km1, ls_func, psi_1, psi_2,
                                     params.quad_step)
      init_step = step_guess_result.step

      # Check if initial step size already satisfies Wolfe condition, and in
      # that case don't perform line search.
      c = init_step.x
      phi_lim = phi_0 + eps * tf.abs(phi_0)
      phi_c = init_step.f
      dphi_c = init_step.df
      # Original Wolfe conditions, T1 in [HZ2006].
      suff_decrease_1 = delta * dphi_0 >= (phi_c - phi_0) / c
      curvature = dphi_c >= sigma * dphi_0
      wolfe1 = suff_decrease_1 & curvature
      # Approximate Wolfe conditions, T2 in [HZ2006].
      suff_decrease_2 = (2 * delta - 1) * dphi_0 >= dphi_c
      curvature = dphi_c >= sigma * dphi_0
      wolfe2 = suff_decrease_2 & curvature & (phi_c <= phi_lim)
      wolfe = wolfe1 | wolfe2
      # Batch members that already failed or converged also skip the search.
      skip_line_search = (step_guess_result.may_terminate
                          & wolfe) | state.failed | state.converged

      # Call Hager-Zhang line search (L0-L3 in [HZ2006]).
      # Parameter theta from [HZ2006] is not adjustable, it's always 0.5.
      ls_result = linesearch.hager_zhang(
          ls_func,
          value_at_zero=ls_val_0,
          converged=skip_line_search,
          initial_step_size=init_step.x,
          value_at_initial_step=init_step,
          shrinkage_param=params.shrinkage_param,
          expansion_param=params.expansion_param,
          sufficient_decrease_param=delta,
          curvature_param=sigma,
          threshold_use_approximate_wolfe_condition=eps)

      # Moving to the next point, using step length from line search.
      # If line search was skipped, take step length from initial guess.
      # To save objective evaluation, use objective value and gradient returned
      # by line search or initial guess.
      a_k = tf.compat.v1.where(
          skip_line_search, init_step.x, ls_result.left.x)
      x_kp1 = state.position + tf.expand_dims(a_k, -1) * d_k
      f_kp1 = tf.compat.v1.where(
          skip_line_search, init_step.f, ls_result.left.f)
      g_kp1 = tf.compat.v1.where(skip_line_search, init_step.full_gradient,
                                 ls_result.left.full_gradient)

      # Evaluate next direction.
      # Use formulas (2.7)-(2.11) from [HZ2013] with P_k=I.
      y_k = g_kp1 - g_k
      d_dot_y = _dot(d_k, y_k)
      b_k = (_dot(y_k, g_kp1) -
             _norm_sq(y_k) * _dot(g_kp1, d_k) / d_dot_y) / d_dot_y
      eta_k = eta * _dot(d_k, g_k) / _norm_sq(d_k)
      b_k = tf.maximum(b_k, eta_k)
      d_kp1 = -g_kp1 + tf.expand_dims(b_k, -1) * d_k

      # Check convergence criteria.
      grad_converged = _norm_inf(g_kp1) <= tolerance
      x_converged = (_norm_inf(x_kp1 - x_k) <= x_tolerance)
      f_converged = (
          tf.math.abs(f_kp1 - f_k) <= f_relative_tolerance * tf.math.abs(f_k))
      converged = grad_converged | x_converged | f_converged

      # Construct new state for next iteration.
      # Batch members that had already converged keep their previous position,
      # objective value and gradient.
      new_state = _OptimizerState(
          converged=converged,
          failed=state.failed,
          num_iterations=state.num_iterations + 1,
          num_objective_evaluations=state.num_objective_evaluations +
          step_guess_result.func_evals + ls_result.func_evals,
          position=tf.compat.v1.where(state.converged, x_k, x_kp1),
          objective_value=tf.compat.v1.where(state.converged, f_k, f_kp1),
          objective_gradient=tf.compat.v1.where(state.converged, g_k, g_kp1),
          direction=d_kp1,
          prev_step=a_k)
      return (new_state,)

    final_state = tf.while_loop(
        _cond, _body, (initial_state,),
        parallel_iterations=parallel_iterations)[0]
    return OptimizerResult(
        converged=final_state.converged,
        failed=final_state.failed,
        num_iterations=final_state.num_iterations,
        num_objective_evaluations=final_state.num_objective_evaluations,
        position=final_state.position,
        objective_value=final_state.objective_value,
        objective_gradient=final_state.objective_gradient)
 def _assert_ops(
     self,
     previous_solver_internal_state,
     initial_state_vec,
     final_time,
     initial_time,
     solution_times,
     max_num_steps,
     max_num_newton_iters,
     atol,
     rtol,
     first_step_size,
     safety_factor,
     min_step_size_factor,
     max_step_size_factor,
     max_order,
     newton_tol_factor,
     newton_step_size_factor,
     solution_times_chosen_by_solver,
 ):
     """Creates a list of assert operations."""
     if not self._validate_args:
         return []
     assert_ops = []
     if previous_solver_internal_state is not None:
         assert_initial_state_matches_previous_solver_internal_state = (
             tf.debugging.assert_near(
                 tf.norm(
                     initial_state_vec -
                     previous_solver_internal_state.backward_differences[0],
                     np.inf),
                 0.,
                 message='`previous_solver_internal_state` does not match '
                 '`initial_state`.'))
         assert_ops.append(
             assert_initial_state_matches_previous_solver_internal_state)
     assert_ops.append(
         util.assert_positive(final_time - initial_time,
                              'final_time - initial_time'))
     if not solution_times_chosen_by_solver:
         assert_ops += [
             util.assert_increasing(solution_times, 'solution_times'),
             util.assert_nonnegative(solution_times[0] - initial_time,
                                     'solution_times[0] - initial_time'),
         ]
     if max_num_steps is not None:
         assert_ops.append(
             util.assert_positive(max_num_steps, 'max_num_steps'))
     if max_num_newton_iters is not None:
         assert_ops.append(
             util.assert_positive(max_num_newton_iters,
                                  'max_num_newton_iters'))
     assert_ops += [
         util.assert_positive(rtol, 'rtol'),
         util.assert_positive(atol, 'atol'),
         util.assert_positive(first_step_size, 'first_step_size'),
         util.assert_positive(safety_factor, 'safety_factor'),
         util.assert_positive(min_step_size_factor, 'min_step_size_factor'),
         util.assert_positive(max_step_size_factor, 'max_step_size_factor'),
         tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [
             '`max_order` must be between 1 and {}.'.format(
                 bdf_util.MAX_ORDER)
         ]),
         util.assert_positive(newton_tol_factor, 'newton_tol_factor'),
         util.assert_positive(newton_step_size_factor,
                              'newton_step_size_factor'),
     ]
     return assert_ops
Exemple #17
0
def _norm(x):
  """Returns the L2 norm of `x` computed along its last axis."""
  last_axis = -1
  return tf.norm(x, axis=last_axis)
Exemple #18
0
def get_dice_pose_results(
        bounding_boxes,
        classes,
        scores,
        y_rotation_angles,
        camera_matrix: np.ndarray,
        distortion_coefficients: np.ndarray,
        score_threshold: float = 0.5):
    """Estimates a pose result for every detected die.

    Detections scoring above `score_threshold` are split into dots (class 0)
    and dice (any other class). Dice are processed in descending order of
    their box's index-2 coordinate. For each die, the dots that lie inside its
    box are used to estimate the pose; dots reported as inliers for that die
    are then removed so later dice cannot reuse them.

    Args:
        bounding_boxes: Detected boxes; indexed as [top, left, bottom, right]
            by the intersection code below.
        classes: Detected class per box; 0 is treated as a dot.
        scores: Detection score per box.
        y_rotation_angles: Estimated rotation angle per box.
        camera_matrix: Camera matrix forwarded to the pose helpers.
        distortion_coefficients: Distortion coefficients forwarded to the
            pose helpers.
        score_threshold: Detections must score strictly above this value.

    Returns:
        A list with one pose result per die, in processing order.
    """
    # Discard low-confidence detections.
    confident = tf.math.greater(scores, score_threshold)
    confident_classes = tf.boolean_mask(classes, confident)
    confident_boxes = tf.boolean_mask(bounding_boxes, confident)
    confident_y_angles = tf.boolean_mask(y_rotation_angles, confident)

    # Split the remaining detections into dots (class 0) and dice.
    is_dot = tf.equal(confident_classes, 0)
    is_die = tf.logical_not(is_dot)
    dice_bounding_boxes = tf.boolean_mask(confident_boxes, is_die)
    dice_y_angles = tf.boolean_mask(confident_y_angles, is_die)
    dice_classes = tf.boolean_mask(confident_classes, is_die)
    dot_bounding_boxes = tf.boolean_mask(confident_boxes, is_dot)

    dot_centers = _get_dot_centers(dot_bounding_boxes)
    dot_sizes = _get_dot_sizes(dot_bounding_boxes)

    def box_area(boxes):
        """Area of each box; inverted (empty) boxes count as zero."""
        width = tf.math.maximum(boxes[:, 3] - boxes[:, 1], 0)
        height = tf.math.maximum(boxes[:, 2] - boxes[:, 0], 0)
        return width * height

    # NB: the largest box[2] is the lowest box in the image, handled first.
    lower_edges = dice_bounding_boxes[:, 2]
    processing_order = tf.argsort(
        lower_edges, axis=-1, direction='DESCENDING').numpy()

    box_pose_estimates = [
        _get_die_image_bounding_box_pose(
            dice_bounding_boxes[die_index, :], camera_matrix,
            distortion_coefficients)
        for die_index in processing_order
    ]
    up_vector_pyrender = _get_approximate_dice_up_vector(
        box_pose_estimates, in_pyrender_coords=True)

    pose_results = []
    for die_index, box_pose in zip(processing_order, box_pose_estimates):
        die_box = dice_bounding_boxes[die_index, :]
        die_y_angle = dice_y_angles[die_index]
        die_class = dice_classes[die_index]

        # Dot centres as a fraction of the die box, then a rounded-rectangle
        # distance test. `rounded_rectangle_radius` is not defined in this
        # function; presumably a module-level constant (confirm).
        die_box_size = -die_box[0:2] + die_box[2:4]
        relative_centers = (dot_centers - die_box[0:2]) / die_box_size
        corner_overshoot = tf.math.maximum(
            tf.math.abs(relative_centers - 0.5) - 0.5
            + rounded_rectangle_radius,
            0.0)
        rounded_distance = (
            tf.norm(corner_overshoot, axis=-1) - rounded_rectangle_radius)
        inside_rounded_rectangle = rounded_distance < 0

        # Fraction of each dot's box that overlaps the die box.
        overlap_left = tf.math.maximum(dot_bounding_boxes[:, 1], die_box[1])
        overlap_right = tf.math.minimum(dot_bounding_boxes[:, 3], die_box[3])
        overlap_top = tf.math.maximum(dot_bounding_boxes[:, 0], die_box[0])
        overlap_bottom = tf.math.minimum(dot_bounding_boxes[:, 2], die_box[2])
        overlap_boxes = tf.stack(
            [overlap_top, overlap_left, overlap_bottom, overlap_right],
            axis=1)
        overlap_area = box_area(overlap_boxes)
        dot_area = box_area(dot_bounding_boxes)
        overlap_fraction = overlap_area / dot_area
        mostly_inside = tf.greater(overlap_fraction, 0.9)

        dots_are_in_box = tf.logical_and(
            mostly_inside, inside_rounded_rectangle)

        die_dot_centers_cv = _convert_tensorflow_points_to_opencv(
            tf.boolean_mask(dot_centers, dots_are_in_box))
        die_pose_result = _get_die_pose(
            die_box, die_class, die_y_angle, die_dot_centers_cv, box_pose,
            up_vector_pyrender, camera_matrix, distortion_coefficients)
        die_pose_result.calculate_comparison(
            die_dot_centers_cv, camera_matrix, distortion_coefficients)
        # Note: receives the sizes of all remaining dots, not only those
        # inside this die's box.
        die_pose_result.calculate_inliers(
            _convert_tensorflow_points_to_opencv(dot_sizes))
        pose_results.append(die_pose_result)

        # Map the inlier indices (relative to the in-box subset) back to
        # indices into the full dot tensors and remove those dots.
        claimed_dot_indices = tf.gather(
            tf.where(dots_are_in_box),
            die_pose_result.comparison_inlier_indices)
        dot_centers = _delete_tf(dot_centers, claimed_dot_indices)
        dot_sizes = _delete_tf(dot_sizes, claimed_dot_indices)
        dot_bounding_boxes = _delete_tf(
            dot_bounding_boxes, claimed_dot_indices)

    return pose_results