Example #1
    def _prepare_common_params(self, ode_fn, initial_state, initial_time):
        error_if_wrong_dtype = functools.partial(
            util.error_if_not_real_or_complex, identifier='initial_state')

        initial_state = tf.nest.map_structure(tf.convert_to_tensor,
                                              initial_state)
        tf.nest.map_structure(error_if_wrong_dtype, initial_state)

        state_shape = tf.nest.map_structure(ps.shape, initial_state)
        common_state_dtype = dtype_util.common_dtype(initial_state)
        real_dtype = dtype_util.real_dtype(common_state_dtype)
        # Use tf.cast instead of tf.convert_to_tensor for differentiable
        # parameters because the tf.custom_gradient decorator converts raw floats
        # into tf.float32, which cannot be converted to tf.float64.
        initial_time = tf.cast(initial_time, real_dtype)
        if self._validate_args:
            initial_time = tf.ensure_shape(initial_time, [])

        rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
        atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
        safety_factor = tf.convert_to_tensor(self._safety_factor,
                                             dtype=real_dtype)

        if self._validate_args:
            safety_factor = tf.ensure_shape(safety_factor, [])

        # Convert everything to operate on a single, concatenated vector form.
        initial_state_vec = util.get_state_vec(initial_state)
        ode_fn_vec = util.get_ode_fn_vec(ode_fn, state_shape)
        num_odes = tf.size(initial_state_vec)

        return util.Bunch(
            initial_state=initial_state,
            initial_time=initial_time,
            common_state_dtype=common_state_dtype,
            real_dtype=real_dtype,
            rtol=rtol,
            atol=atol,
            safety_factor=safety_factor,
            state_shape=state_shape,
            initial_state_vec=initial_state_vec,
            ode_fn_vec=ode_fn_vec,
            num_odes=num_odes,
        )
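
# A minimal, standalone sketch (not from the source) of the cast-vs-convert
# point in the comment above: `tf.cast` happily takes a raw Python float to
# float64, while `tf.convert_to_tensor` refuses to convert an existing
# float32 tensor to float64.
import tensorflow as tf

t = tf.cast(0.5, tf.float64)  # raw Python float -> float64
try:
    tf.convert_to_tensor(tf.constant(0.5), dtype=tf.float64)
except ValueError as e:
    print('convert_to_tensor cannot retype float32 -> float64:', e)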
Example #2
        def _example_parser(
                example: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
            """Parse features and labels from a serialized tf.train.Example."""
            features_spec = _make_features_spec(self.load_domain_label)
            features = tf.io.parse_single_example(example['features'],
                                                  features_spec)

            sys_utt = tf.io.parse_tensor(features[SYS_UTT_NAME],
                                         out_type=tf.int32)
            usr_utt = tf.io.parse_tensor(features[USR_UTT_NAME],
                                         out_type=tf.int32)
            sys_utt_raw = tf.io.parse_tensor(features[SYS_UTT_RAW_NAME],
                                             out_type=tf.string)
            usr_utt_raw = tf.io.parse_tensor(features[USR_UTT_RAW_NAME],
                                             out_type=tf.string)
            state_label = tf.io.parse_tensor(features[STATE_LABEL_NAME],
                                             out_type=tf.int32)
            dialog_len = features[DIAL_LEN_NAME]

            # Extract maximum dialog and utterance lengths.
            max_dialog_len = MAX_DIALOG_LEN[self.name]
            max_utt_len = MAX_UTT_LEN[self.name]

            # Ensure shape of parsed tensors.
            sys_utt = tf.ensure_shape(sys_utt, (max_dialog_len, max_utt_len))
            usr_utt = tf.ensure_shape(usr_utt, (max_dialog_len, max_utt_len))
            sys_utt_raw = tf.ensure_shape(sys_utt_raw, (max_dialog_len, ))
            usr_utt_raw = tf.ensure_shape(usr_utt_raw, (max_dialog_len, ))
            state_label = tf.ensure_shape(state_label, (max_dialog_len, ))

            parsed_example = {
                SYS_UTT_NAME: sys_utt,
                USR_UTT_NAME: usr_utt,
                USR_UTT_RAW_NAME: usr_utt_raw,
                SYS_UTT_RAW_NAME: sys_utt_raw,
                STATE_LABEL_NAME: state_label,
                DIAL_LEN_NAME: dialog_len,
            }

            # Optionally, load domain labels.
            if self.load_domain_label:
                domain_label = tf.io.parse_tensor(features[DOMAIN_LABEL_NAME],
                                                  out_type=tf.int32)
                domain_label = tf.ensure_shape(domain_label,
                                               (max_dialog_len, ))
                parsed_example[DOMAIN_LABEL_NAME] = domain_label

            if self.add_dialog_turn_id:
                example_id = example[self._fingerprint_key]
                dialog_turn_id = tf.range(example_id * max_dialog_len,
                                          (example_id + 1) * max_dialog_len,
                                          dtype=tf.int32)
                dialog_turn_id = tf.ensure_shape(dialog_turn_id,
                                                 (max_dialog_len, ))
                parsed_example[DIAL_TURN_ID_NAME] = dialog_turn_id

            return parsed_example
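
# Self-contained sketch (illustrative, not from the source) of the
# serialize/parse round trip that `_example_parser` relies on: tensors are
# stored as variable-length strings, recovered with `tf.io.parse_tensor`, and
# `tf.ensure_shape` then restores the static shape for downstream code.
import tensorflow as tf

utt = tf.constant([[1, 2, 3], [4, 5, 0]], dtype=tf.int32)    # [dialog_len, utt_len]
serialized = tf.io.serialize_tensor(utt)                     # scalar tf.string
restored = tf.io.parse_tensor(serialized, out_type=tf.int32)
restored = tf.ensure_shape(restored, (2, 3))                 # recover static shape
print(restored.shape)  # (2, 3)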
Example #3
    def _prepare_coefficients(self, dtype):
        bdf_coefficients = tf.concat(
            [[0.], tf.cast(self._bdf_coefficients, dtype=dtype)], 0)
        if self._validate_args:
            bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6])
        util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients')
        newton_coefficients = 1. / (
            (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS)
        error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / (
            bdf_util.ORDERS + 1)

        return newton_coefficients, error_coefficients
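
# Index-alignment sketch (illustrative values, not from the source):
# prepending 0. to the user-supplied coefficients makes bdf_coefficients[k]
# correspond to BDF order k, with entry 0 a placeholder for the unused
# order 0; this is why the validated shape above is [6].
import tensorflow as tf

user_coefficients = tf.constant([0.1, 0.2, 0.3, 0.4, 0.5])   # one per order 1..5
bdf_coefficients = tf.concat([[0.], user_coefficients], 0)   # length 6, indexed by order
print(bdf_coefficients.numpy())  # [0.  0.1 0.2 0.3 0.4 0.5]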
Example #4
    def _initialize_solver_internal_state(
        self,
        ode_fn,
        initial_time,
        initial_state,
    ):
        p = self._prepare_common_params(
            ode_fn=ode_fn,
            initial_state=initial_state,
            initial_time=initial_time,
        )

        first_step_size = self._first_step_size
        if first_step_size is None:
            _, error_coefficients = self._prepare_coefficients(
                p.common_state_dtype)
            first_step_size = bdf_util.first_step_size(
                p.atol, error_coefficients[1], p.initial_state_vec,
                p.initial_time, p.ode_fn_vec, p.rtol, p.safety_factor)
        first_step_size = tf.convert_to_tensor(first_step_size,
                                               dtype=p.real_dtype)
        if self._validate_args:
            first_step_size = tf.ensure_shape(first_step_size, [])

        first_order_backward_difference = p.ode_fn_vec(
            p.initial_time, p.initial_state_vec) * tf.cast(
                first_step_size, p.common_state_dtype)
        backward_differences = tf.concat(
            [
                p.initial_state_vec[tf.newaxis, :],
                first_order_backward_difference[tf.newaxis, :],
                tf.zeros(ps.stack([bdf_util.MAX_ORDER + 1, p.num_odes]),
                         dtype=p.common_state_dtype),
            ],
            axis=0,
        )
        return _BDFSolverInternalState(
            backward_differences=backward_differences,
            order=tf.ones([], tf.int32),
            step_size=first_step_size)
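
# Shape sketch of the initial backward-differences matrix built above,
# assuming bdf_util.MAX_ORDER == 5 (an assumption; the 6-element coefficient
# vector in `_prepare_coefficients` is consistent with it). Row 0 holds the
# state vector, row 1 the first-order difference, and the remaining rows
# start at zero.
import tensorflow as tf

MAX_ORDER, num_odes = 5, 3
state_vec = tf.ones([num_odes])
first_difference = 0.1 * tf.ones([num_odes])         # stands in for h * f(t0, y0)
backward_differences = tf.concat(
    [state_vec[tf.newaxis, :],
     first_difference[tf.newaxis, :],
     tf.zeros([MAX_ORDER + 1, num_odes])],
    axis=0)
print(backward_differences.shape)  # (MAX_ORDER + 3, num_odes) == (8, 3)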
Example #5
def contrastive_loss(features,
                     labels=None,
                     temperature=1.0,
                     contrast_mode=enums.LossContrastMode.ALL_VIEWS,
                     summation_location=enums.LossSummationLocation.OUTSIDE,
                     denominator_mode=enums.LossDenominatorMode.ALL,
                     positives_cap=-1,
                     scale_by_temperature=True):
    r"""Contrastive loss over features.

  Implemented as described in: https://arxiv.org/abs/2004.11362, Equation 2.

  Given `num_views` different views of each of `batch_size` samples, let `f_i`
  (i \in [1, 2 ... (num_views * batch_size)]) denote each respective feature
  vector. The contrastive loss then takes the following form:

    L = \sum_{i} L_i

  where each L_i is computed as:

    L_i = -\tau * \sum_{k \in P(i)} \log(p_{ik})    (1)

  where P(i) is the set of positives for entry i (distinct from i) and where:

                       \exp(f_i^T f_k / \tau)
    p_{ik} = ----------------------------------------                        (2)
             \sum_{j \in A(i)} \exp(f_i^T f_j / \tau)

  where A(i) is the set of all positives or negatives (distinct from i). `i` is
  the anchor, and \tau is the temperature.

  This maximizes the likelihood of a given (anchor, positive) pair with
  respect to all possible pairs where the first member is the anchor and the
  second member is a positive or a negative.

  A typical way to define a positive is to define samples from the
  same class (but not the anchor itself) regardless of what view they are from.
  Similarly, a typical way to define a negative is for it to be any view of a
  sample from a different class.

  There are two ways to define which feature pairs should be treated as
  positives and negatives. All views of the same sample are always treated as
  positives. You can declare other samples to be positives by providing `labels`
  such that all samples with the same label will be positives for each other.

  If `labels` is not provided then we default to every sample belonging to its
  own unique class. Therefore, the only positive used is another view of the
  anchor itself. This implements the loss as described in:

    https://arxiv.org/pdf/2002.05709.pdf
    A Simple Framework for Contrastive Learning of Visual Representations
    Chen T., Kornblith S., Norouzi M., Hinton G.

  It is recommended to use features whose L_2 norm is 1, since this ensures
  that the loss does not return NaN values, while leaving the intended
  behaviour of the loss function unchanged.

  In (1) above, note that the summation over positives is located outside of the
  \log(). However, one can permute these two operations. The result is Eq. 3 in
  https://arxiv.org/abs/2004.11362. Users can specify the location of the
  summation relative to the \log() via the `summation_location` argument:
   - 'out': Eq. 2 in https://arxiv.org/abs/2004.11362.
   - 'in' : Eq. 3 in https://arxiv.org/abs/2004.11362.

  Additionally, in (2) above, note that the denominator sums over *all* entries
  distinct from i. One can change which terms are included in the denominator
  via the `denominator_mode` argument:
   - LossDenominatorMode.ALL : All entries (i.e., all negatives and all
             positives) distinct from i are included.
   - LossDenominatorMode.ONE_POSITIVE : All negatives are included but only the
             single positive in the numerator of (2) is included. Any other
             positives are excluded.
   - LossDenominatorMode.ONLY_NEGATIVES: All negatives are included but no
             positives are, not even the single positive in the numerator of
             (2).

  On TPUs, this method will internally perform the cross-replica operations that
  enable using the samples from all cores in computing the loss. The inputs to
  this function should be the features and labels from a single core and each
  core will compute the loss using just these features as anchors, but will use
  positives and negatives from the full global batch. Since the loss for each
  anchor is only computed on one TPU core, it's still necessary to have a
  cross-replica reduction in the final loss computation.

  Also, though it is not applicable to multiview contrastive learning, this
  function will work if |features| contains only 1 view. In the high batch size
  limit, the implemented contrastive loss with only 1 view, positives_cap = 1,
  and temperature = 1.0 is equivalent to the N-pairs loss
  (https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective.pdf)

  Args:
    features: A Tensor of rank at least 3, where the first 2 dimensions are
      batch_size and num_views, and the remaining dimensions are the feature
      shape. Note that when running on TPU, batch_size is the per-core batch
      size.
    labels: One-hot labels to be used to construct the supervised contrastive
      loss. Samples with the same labels are used as positives for each other.
      Labels must have shape [batch_size, num_labels] with numeric dtype and be
      0-1 valued. Note that when running on TPU, batch_size is the per-core
      batch size.
    temperature: Temperature at which softmax evaluation is done. Temperature
      must be a python scalar or scalar Tensor of numeric dtype.
    contrast_mode: LossContrastMode specifying which views get used as anchors
      (f_i in the expression above)
      'ALL_VIEWS': All the views of all samples are used as anchors (f_i in the
        expression above).
      'ONE_VIEW': Just the first view of each sample is used as an anchor (f_i
        in the expression above). This view is called the `core` view against
        which other views are contrasted.
    summation_location: LossSummationLocation specifying location of positives
      summation. See documentation above for more details.
    denominator_mode: LossDenominatorMode specifying which positives to include
      in contrastive denominator. See documentation above for more details.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of the anchor. Infinite if < 0. Must be a multiple of
      num_views. Including augmentations, a maximum of (positives_cap +
      num_views - 1) positives is possible. This parameter modifies the
      contrastive numerator by selecting which positives are present in the
      summation, and which positives contribute to the denominator if
      denominator_mode == enums.LossDenominatorMode.ALL.
    scale_by_temperature: Boolean. Whether to scale the loss by `temperature`.
      The loss gradient naturally has a 1/temperature scaling factor, so this
      counteracts it.

  Returns:
    Tensor of contrastive loss values with shape [batch_size] and dtype
    tf.float32. The loss for each batch element is the mean over all views.

  Raises:
    ValueError if the shapes of any of the Tensors are unexpected.
  """
    features = tf.convert_to_tensor(features)
    labels = tf.convert_to_tensor(labels) if labels is not None else None

    local_batch_size, num_views = _validate_contrastive_loss_inputs(
        features, labels, contrast_mode, summation_location, denominator_mode,
        positives_cap)

    # Flatten `features` to a single dimension per view per sample so it has shape
    # [local_batch_size, num_views, num_features].
    if features.shape.rank > 3:
        features = tf.reshape(
            features, tf.concat([tf.shape(features)[:2], [-1]], axis=0),
            'flattened_features')
    if features.dtype != tf.float32:
        features = tf.cast(features, tf.float32)

    # Grab the features from all TPU cores. We use the local batch as anchors and
    # the full global batch as contrastives. If not on TPU, global_features is the
    # same as features.
    global_features = utils.cross_replica_concat(features)
    global_batch_size = tf.compat.dimension_at_index(global_features.shape,
                                                     0).value
    local_replica_id = utils.local_tpu_replica_id()

    # Generate the [local_batch_size, global_batch_size] slice of the
    # [global_batch_size, global_batch_size] identity matrix that corresponds to
    # the current replica.
    diagonal_mask = tf.one_hot(
        tf.range(local_batch_size) + (local_replica_id * local_batch_size),
        global_batch_size)

    # Generate `mask` with shape [local_batch_size, global_batch_size] that
    # indicates which samples should be considered positives for each other.
    if labels is None:
        # Defaults to every sample belonging to its own unique class, containing
        # just that sample and other views of it.
        mask = diagonal_mask
    else:
        labels = tf.cast(labels,
                         tf.float32)  # TPU matmul op unsupported for ints.
        global_labels = utils.cross_replica_concat(labels)
        mask = tf.linalg.matmul(labels, global_labels, transpose_b=True)
    mask = tf.ensure_shape(mask, [local_batch_size, global_batch_size])

    # To streamline the subsequent TF, the first two dimensions of
    # `global_features` (i.e., global_batch_size and num_views) should be
    # transposed and then flattened. The result has shape
    # [num_views * global_batch_size, num_features], and its first dimension
    # elements are grouped by view, not by sample.
    all_global_features = tf.reshape(
        tf.transpose(global_features, perm=[1, 0, 2]),
        [num_views * global_batch_size, -1])

    if contrast_mode == enums.LossContrastMode.ONE_VIEW:
        anchor_features = features[:, 0]
        num_anchor_views = 1
    else:  # contrast_mode == enums.LossContrastMode.ALL_VIEWS
        # Reshape features to match how global_features is reshaped above.
        anchor_features = tf.reshape(tf.transpose(features, perm=[1, 0, 2]),
                                     [num_views * local_batch_size, -1])
        num_anchor_views = num_views

    # Generate `logits`, the tensor of (temperature-scaled) dot products of the
    # anchor features with all features. It has shape
    # [local_batch_size * num_anchor_views, global_batch_size * num_views]. To
    # improve numerical stability, subtract out the largest |logits| element in
    # each row from all elements in that row. Since |logits| is only ever used as
    # a ratio of exponentials of |logits| values, this subtraction does not change
    # the result's correctness. A stop_gradient() is needed because this change is
    # just for numerical precision.
    logits = tf.linalg.matmul(anchor_features,
                              all_global_features,
                              transpose_b=True)
    temperature = tf.cast(temperature, tf.float32)
    logits = logits / temperature
    logits = (logits -
              tf.reduce_max(tf.stop_gradient(logits), axis=1, keepdims=True))
    exp_logits = tf.exp(logits)

    # The following masks are all tiled by the number of views, i.e., they have
    # shape [local_batch_size * num_anchor_views, global_batch_size * num_views].
    positives_mask, negatives_mask = (_create_tiled_masks(
        mask, diagonal_mask, num_views, num_anchor_views, positives_cap))
    num_positives_per_row = tf.reduce_sum(positives_mask, axis=1)

    if denominator_mode == enums.LossDenominatorMode.ALL:
        denominator = tf.reduce_sum(
            exp_logits * negatives_mask, axis=1,
            keepdims=True) + tf.reduce_sum(
                exp_logits * positives_mask, axis=1, keepdims=True)
    elif denominator_mode == enums.LossDenominatorMode.ONE_POSITIVE:
        denominator = exp_logits + tf.reduce_sum(
            exp_logits * negatives_mask, axis=1, keepdims=True)
    else:  # denominator_mode == enums.LossDenominatorMode.ONLY_NEGATIVES
        denominator = tf.reduce_sum(exp_logits * negatives_mask,
                                    axis=1,
                                    keepdims=True)

    # Note that num_positives_per_row can be zero only if 1 view is used. The
    # various tf.math.divide_no_nan() calls below are to handle this case.
    if summation_location == enums.LossSummationLocation.OUTSIDE:
        log_probs = (logits - tf.math.log(denominator)) * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
    else:  # summation_location == enums.LossSummationLocation.INSIDE
        log_probs = exp_logits / denominator * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
        log_probs = tf.math.log(log_probs)

    loss = -log_probs
    if scale_by_temperature:
        loss *= temperature
    loss = tf.reshape(loss, [num_anchor_views, local_batch_size])

    if num_views != 1:
        loss = tf.reduce_mean(loss, axis=0)
    else:
        # The 1 view case requires special handling because, unlike in the
        # > 1 view case, not all samples are guaranteed to have a positive.
        # Also, no reduction over views is needed.
        num_valid_views_per_sample = (tf.reshape(num_positives_per_row,
                                                 [1, local_batch_size]))
        loss = tf.squeeze(
            tf.math.divide_no_nan(loss, num_valid_views_per_sample))

    return loss
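
# A tiny standalone check of Eqs. (1)-(2) from the docstring, with
# batch_size=2, num_views=2, and each sample its own class, so P(i) is just
# the other view of the same sample. Values are illustrative; this does not
# call `contrastive_loss` itself.
import tensorflow as tf

# Four L2-normalized features, ordered [s0v0, s0v1, s1v0, s1v1].
f = tf.math.l2_normalize(
    tf.constant([[1., 0.], [0.9, 0.1], [0., 1.], [0.1, 0.9]]), axis=1)
tau = 0.5
sim = tf.matmul(f, f, transpose_b=True) / tau                # f_i^T f_j / tau

not_self = 1. - tf.eye(4)                                    # A(i): all j != i
positives = tf.constant([[0., 1., 0., 0.],                   # P(i): other view of i
                         [1., 0., 0., 0.],
                         [0., 0., 0., 1.],
                         [0., 0., 1., 0.]])

denominator = tf.reduce_sum(tf.exp(sim) * not_self, axis=1, keepdims=True)
log_p = sim - tf.math.log(denominator)                       # log p_{ik}, Eq. (2)
loss_per_anchor = -tau * tf.reduce_sum(log_p * positives, axis=1)  # Eq. (1)
print(loss_per_anchor.numpy())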
Example #6
    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # This function consists of the following sequential stages:
        # (1) Make static assertions.
        # (2) Initialize variables.
        # (3) Make non-static assertions.
        # (4) Solve up to final time.
        # (5) Return `Results` object.
        #
        # The stages can be found in the code by searching for (n) where n=1..5.
        #
        # By static vs. non-static assertions (see stages 1 and 3), we mean
        # assertions that can be made before the graph is run vs. those that can
        # only be made at run time. The latter are constructed as a list of
        # tf.Assert operations by the function `assert_ops` (see below).
        #
        # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
        # nested loops, which can be conceptually understood as follows:
        # ```
        # current_time, current_state = initial_time, initial_state
        # order, step_size = 1, first_step_size
        # for solution_time in solution_times:
        #   while current_time < solution_time:
        #     while True:
        #       next_time = current_time + step_size
        #       next_state, error = (
        #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
        #           current_time, current_state, next_time, order))
        #       if error < tolerance:
        #         current_time, current_state = next_time, next_state
        #         order, step_size = (
        #           maybe_update_order_and_step_size(order, step_size))
        #         break
        #       else:
        #         step_size = decrease_step_size(step_size)
        # ```
        # The outermost loop advances the solver to the next `solution_time` (see
        # `advance_to_solution_time`). The middle loop advances the solver by a
        # small timestep (see `step`). The innermost loop determines the size of
        # that timestep (see `maybe_step`).
        #
        # If `solution_times` is specified as
        # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
        # and `solution_time` in the middle loop is replaced by `final_time`.

        def advance_to_solution_time(n, diagnostics, iterand,
                                     solver_internal_state, state_vec_array,
                                     time_array):
            """Takes multiple steps to advance time to `solution_times[n]`."""
            def step_cond(next_time, diagnostics, iterand, *_):
                return (iterand.time < next_time) & (tf.equal(
                    diagnostics.status, 0))

            nth_solution_time = solution_time_array.read(n)
            [
                _, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ] = tf.while_loop(step_cond, step, [
                nth_solution_time, diagnostics, iterand, solver_internal_state,
                state_vec_array, time_array
            ])
            state_vec_array = state_vec_array.write(
                n, solver_internal_state.backward_differences[0])
            time_array = time_array.write(n, nth_solution_time)
            return (n + 1, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def step(next_time, diagnostics, iterand, solver_internal_state,
                 state_vec_array, time_array):
            """Takes a single step."""
            distance_to_next_time = next_time - iterand.time
            overstepped = iterand.new_step_size > distance_to_next_time
            iterand = iterand._replace(new_step_size=tf1.where(
                overstepped, distance_to_next_time, iterand.new_step_size),
                                       should_update_step_size=overstepped
                                       | iterand.should_update_step_size)

            if not self._evaluate_jacobian_lazily:
                diagnostics = diagnostics._replace(
                    num_jacobian_evaluations=diagnostics.
                    num_jacobian_evaluations + 1)
                iterand = iterand._replace(jacobian_mat=jacobian_fn_mat(
                    iterand.time,
                    solver_internal_state.backward_differences[0]),
                                           jacobian_is_up_to_date=True)

            def maybe_step_cond(accepted, diagnostics, *_):
                return tf.logical_not(accepted) & tf.equal(
                    diagnostics.status, 0)

            _, diagnostics, iterand, solver_internal_state = tf.while_loop(
                maybe_step_cond, maybe_step,
                [False, diagnostics, iterand, solver_internal_state])

            if solution_times_chosen_by_solver:
                state_vec_array = state_vec_array.write(
                    state_vec_array.size(),
                    solver_internal_state.backward_differences[0])
                time_array = time_array.write(time_array.size(), iterand.time)

            return (next_time, diagnostics, iterand, solver_internal_state,
                    state_vec_array, time_array)

        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error."""
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            [backward_differences, order, step_size] = solver_internal_state

            if max_num_steps is not None:
                status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

            backward_differences = tf1.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf1.where(should_update_step_size, new_step_size,
                                  step_size)
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf1.where(should_update_step_size, 0,
                                            num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian_mat, newton_coefficients_array.read(order),
                    step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian_mat = jacobian_fn_mat(time,
                                                       backward_differences[0])
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian_mat, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian_mat, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian_mat, jacobian_is_up_to_date,
                    num_jacobian_evaluations, unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = p.atol + p.rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state_vec,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                p.ode_fn_vec, order, step_size, time,
                                newton_tol, unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf1.where(should_update_step_size,
                                                  newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            error_ratio = tf1.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf1.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    p.safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf1.where(accepted, time + step_size, time)
            backward_differences = tf1.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state_vec, order),
                backward_differences)
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                            num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
            should_update_order_and_step_size = accepted & (num_steps_same_size
                                                            > order + 1)

            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.shape).unstack(
                    backward_differences)
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            new_step_size = tf1.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    p.safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, step_size)
            return accepted, diagnostics, iterand, solver_internal_state

        # (1) Make static assertions.
        # TODO(b/138304296): Support specifying Jacobian sparsity patterns.
        if jacobian_sparsity is not None:
            raise NotImplementedError(
                'The BDF solver does not support specifying '
                'Jacobian sparsity patterns.')
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'The BDF solver does not support batching.')
        solution_times_chosen_by_solver = (isinstance(solution_times,
                                                      base.ChosenBySolver))

        with tf.name_scope(self._name):

            # (2) Initialize variables.
            p = self._prepare_common_params(
                ode_fn=ode_fn,
                initial_state=initial_state,
                initial_time=initial_time,
            )

            if jacobian_fn is None and dtype_util.is_complex(
                    p.common_state_dtype):
                raise NotImplementedError(
                    'The BDF solver does not support automatic '
                    'Jacobian computations for complex dtypes.')

            # Convert everything to operate on a single, concatenated vector form.
            jacobian_fn_mat = util.get_jacobian_fn_mat(
                jacobian_fn,
                p.ode_fn_vec,
                p.state_shape,
                dtype=p.common_state_dtype,
            )

            num_solution_times = 0
            if solution_times_chosen_by_solver:
                final_time = tf.cast(solution_times.final_time, p.real_dtype)
            else:
                solution_times = tf.cast(solution_times, p.real_dtype)
                final_time = tf.reduce_max(solution_times)
                num_solution_times = tf.size(solution_times)
                solution_time_array = tf.TensorArray(
                    solution_times.dtype,
                    size=num_solution_times,
                    element_shape=[]).unstack(solution_times)
                util.error_if_not_vector(solution_times, 'solution_times')
            min_step_size_factor = tf.convert_to_tensor(
                self._min_step_size_factor, dtype=p.real_dtype)
            max_step_size_factor = tf.convert_to_tensor(
                self._max_step_size_factor, dtype=p.real_dtype)
            max_num_steps = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
            max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
            max_num_newton_iters = self._max_num_newton_iters
            if max_num_newton_iters is not None:
                max_num_newton_iters = tf.convert_to_tensor(
                    max_num_newton_iters, dtype=tf.int32)
            newton_tol_factor = tf.convert_to_tensor(self._newton_tol_factor,
                                                     dtype=p.real_dtype)
            newton_step_size_factor = tf.convert_to_tensor(
                self._newton_step_size_factor, dtype=p.real_dtype)
            newton_coefficients, error_coefficients = self._prepare_coefficients(
                p.common_state_dtype)
            if self._validate_args:
                final_time = tf.ensure_shape(final_time, [])
                min_step_size_factor = tf.ensure_shape(min_step_size_factor,
                                                       [])
                max_step_size_factor = tf.ensure_shape(max_step_size_factor,
                                                       [])
                if max_num_steps is not None:
                    max_num_steps = tf.ensure_shape(max_num_steps, [])
                max_order = tf.ensure_shape(max_order, [])
                if max_num_newton_iters is not None:
                    max_num_newton_iters = tf.ensure_shape(
                        max_num_newton_iters, [])
                newton_tol_factor = tf.ensure_shape(newton_tol_factor, [])
                newton_step_size_factor = tf.ensure_shape(
                    newton_step_size_factor, [])
            newton_coefficients_array = tf.TensorArray(
                newton_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(newton_coefficients)
            error_coefficients_array = tf.TensorArray(
                error_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(error_coefficients)
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                solver_internal_state = self._initialize_solver_internal_state(
                    ode_fn=ode_fn,
                    initial_state=initial_state,
                    initial_time=initial_time,
                )
            state_vec_array = tf.TensorArray(
                p.common_state_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=p.initial_state_vec.shape)
            time_array = tf.TensorArray(
                p.real_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=tf.TensorShape([]))
            diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0,
                                          num_matrix_factorizations=0,
                                          num_ode_fn_evaluations=0,
                                          status=0)
            iterand = _BDFIterand(
                jacobian_mat=tf.zeros([p.num_odes, p.num_odes],
                                      dtype=p.common_state_dtype),
                jacobian_is_up_to_date=False,
                new_step_size=solver_internal_state.step_size,
                num_steps=0,
                num_steps_same_size=0,
                should_update_jacobian=True,
                should_update_step_size=False,
                time=p.initial_time,
                unitary=tf.zeros([p.num_odes, p.num_odes],
                                 dtype=p.common_state_dtype),
                upper=tf.zeros([p.num_odes, p.num_odes],
                               dtype=p.common_state_dtype),
            )

            # (3) Make non-static assertions.
            with tf.control_dependencies(
                    self._assert_ops(
                        previous_solver_internal_state=
                        previous_solver_internal_state,
                        initial_state_vec=p.initial_state_vec,
                        final_time=final_time,
                        initial_time=p.initial_time,
                        solution_times=solution_times,
                        max_num_steps=max_num_steps,
                        max_num_newton_iters=max_num_newton_iters,
                        atol=p.atol,
                        rtol=p.rtol,
                        first_step_size=solver_internal_state.step_size,
                        safety_factor=p.safety_factor,
                        min_step_size_factor=min_step_size_factor,
                        max_step_size_factor=max_step_size_factor,
                        max_order=max_order,
                        newton_tol_factor=newton_tol_factor,
                        newton_step_size_factor=newton_step_size_factor,
                        solution_times_chosen_by_solver=
                        solution_times_chosen_by_solver,
                    )):

                # (4) Solve up to final time.
                if solution_times_chosen_by_solver:

                    def step_cond(next_time, diagnostics, iterand, *_):
                        return (iterand.time < next_time) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(step_cond, step, [
                        final_time, diagnostics, iterand,
                        solver_internal_state, state_vec_array, time_array
                    ])

                else:

                    def advance_to_solution_time_cond(n, diagnostics, *_):
                        return (n < num_solution_times) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        state_vec_array, time_array
                    ] = tf.while_loop(
                        advance_to_solution_time_cond,
                        advance_to_solution_time, [
                            0, diagnostics, iterand, solver_internal_state,
                            state_vec_array, time_array
                        ])

                # (5) Return `Results` object.
                states = util.get_state_from_vec(state_vec_array.stack(),
                                                 p.state_shape)
                times = time_array.stack()
                if not solution_times_chosen_by_solver:
                    tensorshape_util.set_shape(times, solution_times.shape)
                    tf.nest.map_structure(
                        lambda s, ini_s: tensorshape_util.set_shape(  # pylint: disable=g-long-lambda
                            s,
                            tensorshape_util.concatenate(
                                solution_times.shape, ini_s.shape)),
                        states,
                        p.initial_state)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
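
# Hedged usage sketch, assuming this class is TensorFlow Probability's BDF
# solver (tfp.math.ode.BDF); the exponential-decay ODE below is illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

ode_fn = lambda t, y: -0.5 * y                       # dy/dt = -0.5 * y
results = tfp.math.ode.BDF().solve(
    ode_fn,
    initial_time=0.,
    initial_state=tf.constant([1.], dtype=tf.float64),
    solution_times=[0.5, 1., 2.])
print(results.times.numpy(), results.states.numpy())  # states ~ exp(-0.5 * t)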
Example #7
    def solve(self,
              ode_fn,
              initial_time,
              initial_state,
              solution_times,
              jacobian_fn=None,
              jacobian_sparsity=None,
              batch_ndims=None,
              previous_solver_internal_state=None):
        """See `tfp.math.ode.Solver.solve`."""

        # The `solve` function consists of the following sequential stages:
        # (1) Make static assertions.
        # (2) Initialize variables.
        # (3) Make non-static assertions.
        # (4) Solve up to final time.
        # (5) Return `Results` object.
        #
        # The stages can be found in the code by searching for (n) where n=1..5.
        #
        # By static vs. non-static assertions (see stages 1 and 3), we mean
        # assertions that can be made before the graph is run vs. those that can
        # only be made at run time. The latter are constructed as a list of
        # tf.Assert operations by the function `assert_ops` (see below).
        #
        # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
        # nested loops, which can be conceptually understood as follows:
        # ```
        # current_time, current_state = initial_time, initial_state
        # order, step_size = 1, first_step_size
        # for solution_time in solution_times:
        #   while current_time < solution_time:
        #     while True:
        #       next_time = current_time + step_size
        #       next_state, error = (
        #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
        #           current_time, current_state, next_time, order))
        #       if error < tolerance:
        #         current_time, current_state = next_time, next_state
        #         order, step_size = (
        #           maybe_update_order_and_step_size(order, step_size))
        #         break
        #       else:
        #         step_size = decrease_step_size(step_size)
        # ```
        # The outermost loop advances the solver to the next `solution_time` (see
        # `advance_to_solution_time`). The middle loop advances the solver by a
        # small timestep (see `step`). The innermost loop determines the size of
        # that timestep (see `maybe_step`).
        #
        # If `solution_times` is specified as
        # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
        # and `solution_time` in the middle loop is replaced by `final_time`.

        def assert_ops():
            """Creates a list of assert operations."""
            if not self._validate_args:
                return []
            assert_ops = []
            if ((not initial_state_missing)
                    and (previous_solver_internal_state is not None)):
                assert_initial_state_matches_previous_solver_internal_state = (
                    tf.assert_near(
                        tf.norm(
                            original_initial_state -
                            previous_solver_internal_state.
                            backward_differences[0], np.inf),
                        0.,
                        message=
                        '`previous_solver_internal_state` does not match '
                        '`initial_state`.'))
                assert_ops.append(
                    assert_initial_state_matches_previous_solver_internal_state
                )
            if solution_times_chosen_by_solver:
                assert_ops.append(
                    util.assert_positive(final_time - initial_time,
                                         'final_time - initial_time'))
            else:
                assert_ops += [
                    util.assert_increasing(solution_times, 'solution_times'),
                    util.assert_nonnegative(
                        solution_times[0] - initial_time,
                        'solution_times[0] - initial_time'),
                ]
            if max_num_steps is not None:
                assert_ops.append(
                    util.assert_positive(max_num_steps, 'max_num_steps'))
            if max_num_newton_iters is not None:
                assert_ops.append(
                    util.assert_positive(max_num_newton_iters,
                                         'max_num_newton_iters'))
            assert_ops += [
                util.assert_positive(rtol, 'rtol'),
                util.assert_positive(atol, 'atol'),
                util.assert_positive(first_step_size, 'first_step_size'),
                util.assert_positive(safety_factor, 'safety_factor'),
                util.assert_positive(min_step_size_factor,
                                     'min_step_size_factor'),
                util.assert_positive(max_step_size_factor,
                                     'max_step_size_factor'),
                tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER),
                          [
                              '`max_order` must be between 1 and {}.'.format(
                                  bdf_util.MAX_ORDER)
                          ]),
                util.assert_positive(newton_tol_factor, 'newton_tol_factor'),
                util.assert_positive(newton_step_size_factor,
                                     'newton_step_size_factor'),
            ]
            return assert_ops

        def advance_to_solution_time(n, diagnostics, iterand,
                                     solver_internal_state, states_array,
                                     times_array):
            """Takes multiple steps to advance time to `solution_times[n]`."""
            def step_cond(next_time, diagnostics, iterand, *_):
                return (iterand.time < next_time) & (tf.equal(
                    diagnostics.status, 0))

            solution_times_n = solution_times_array.read(n)
            [
                _, diagnostics, iterand, solver_internal_state, states_array,
                times_array
            ] = tf.while_loop(step_cond, step, [
                solution_times_n, diagnostics, iterand, solver_internal_state,
                states_array, times_array
            ])
            states_array = states_array.write(
                n, solver_internal_state.backward_differences[0])
            times_array = times_array.write(n, solution_times_n)
            return (n + 1, diagnostics, iterand, solver_internal_state,
                    states_array, times_array)

        def step(next_time, diagnostics, iterand, solver_internal_state,
                 states_array, times_array):
            """Takes a single step."""
            distance_to_next_time = next_time - iterand.time
            overstepped = iterand.new_step_size > distance_to_next_time
            iterand = iterand._replace(new_step_size=tf.where(
                overstepped, distance_to_next_time, iterand.new_step_size),
                                       should_update_step_size=overstepped
                                       | iterand.should_update_step_size)

            if not self._evaluate_jacobian_lazily:
                diagnostics = diagnostics._replace(
                    num_jacobian_evaluations=diagnostics.
                    num_jacobian_evaluations + 1)
                iterand = iterand._replace(jacobian=jacobian_fn_mat(
                    iterand.time,
                    solver_internal_state.backward_differences[0]),
                                           jacobian_is_up_to_date=True)

            def maybe_step_cond(accepted, diagnostics, *_):
                return tf.logical_not(accepted) & tf.equal(
                    diagnostics.status, 0)

            _, diagnostics, iterand, solver_internal_state = tf.while_loop(
                maybe_step_cond, maybe_step,
                [False, diagnostics, iterand, solver_internal_state])

            if solution_times_chosen_by_solver:
                states_array = states_array.write(
                    states_array.size(),
                    solver_internal_state.backward_differences[0])
                times_array = times_array.write(times_array.size(),
                                                iterand.time)

            return (next_time, diagnostics, iterand, solver_internal_state,
                    states_array, times_array)

        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error."""
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            backward_differences, order, state_shape, step_size = solver_internal_state

            if max_num_steps is not None:
                status = tf.where(tf.equal(num_steps, max_num_steps), -1, 0)

            backward_differences = tf.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf.where(should_update_step_size, new_step_size,
                                 step_size)
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf.where(should_update_step_size, 0,
                                           num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian, newton_coefficients_array.read(order), step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian = jacobian_fn_mat(time,
                                                   backward_differences[0])
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian, jacobian_is_up_to_date, num_jacobian_evaluations,
                    unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = atol + rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                ode_fn_vec, order, step_size, time, newton_tol,
                                unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf.where(should_update_step_size,
                                                 newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            error_ratio = tf.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
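            # NaN compares as False below, so a step whose Newton iteration did
            # not converge is never accepted.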
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf.where(accepted, time + step_size, time)
            backward_differences = tf.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state, order),
                backward_differences)
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf.where(accepted, num_steps_same_size + 1,
                                           num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
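            # For example, at `order == 2` a change is considered only after at
            # least 4 consecutive accepted steps of the same size.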
            should_update_order_and_step_size = accepted & (num_steps_same_size
                                                            > order + 1)

            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.get_shape()).unstack(
                    backward_differences)
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            new_step_size = tf.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, state_shape, step_size)
            return accepted, diagnostics, iterand, solver_internal_state

        # (1) Make static assertions.
        # TODO(parsiad): Support specifying Jacobian sparsity patterns.
        if jacobian_sparsity is not None:
            raise NotImplementedError(
                'The BDF solver does not support specifying '
                'Jacobian sparsity patterns.')
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'The BDF solver does not support batching.')
        solution_times_chosen_by_solver = (isinstance(solution_times,
                                                      base.ChosenBySolver))
        initial_state_missing = initial_state is None
        if initial_state_missing and previous_solver_internal_state is None:
            raise ValueError(
                'At least one of `initial_state` or `previous_solver_internal_state` '
                'must be specified')

        with tf.name_scope(self._name):

            # (2) Initialize variables.
            original_initial_state = initial_state
            if previous_solver_internal_state is None:
                initial_state = tf.convert_to_tensor(initial_state)
                original_state_shape = tf.shape(initial_state)
            else:
                initial_state = previous_solver_internal_state.backward_differences[
                    0]
                original_state_shape = previous_solver_internal_state.state_shape
            state_dtype = initial_state.dtype
            util.error_if_not_real_or_complex(initial_state, 'initial_state')
            # TODO(parsiad): Support complex automatic Jacobians.
            if jacobian_fn is None and state_dtype.is_complex:
                raise NotImplementedError(
                    'The BDF solver does not support automatic '
                    'Jacobian computations for complex dtypes.')
            num_odes = tf.size(initial_state)
            original_state_tensor_shape = initial_state.get_shape()
            initial_state = tf.reshape(initial_state, [-1])
            ode_fn_vec = util.get_ode_fn_vec(ode_fn, original_state_shape)
            # `real_dtype` is the floating point `dtype` associated with
            # `initial_state.dtype` (recall that the latter can be complex).
            real_dtype = tf.abs(initial_state).dtype
            initial_time = tf.ensure_shape(
                tf.convert_to_tensor(initial_time, dtype=real_dtype), [])
            num_solution_times = 0
            if solution_times_chosen_by_solver:
                final_time = solution_times.final_time
                final_time = tf.ensure_shape(
                    tf.convert_to_tensor(final_time, dtype=real_dtype), [])
            else:
                solution_times = tf.convert_to_tensor(solution_times,
                                                      dtype=real_dtype)
                num_solution_times = tf.size(solution_times)
                solution_times_array = tf.TensorArray(
                    solution_times.dtype,
                    size=num_solution_times,
                    element_shape=[]).unstack(solution_times)
                util.error_if_not_vector(solution_times, 'solution_times')
            jacobian_fn_mat = util.get_jacobian_fn_mat(
                jacobian_fn,
                ode_fn_vec,
                original_state_shape,
                use_pfor=self._use_pfor_to_compute_jacobian)
            rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
            atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
            safety_factor = tf.ensure_shape(
                tf.convert_to_tensor(self._safety_factor, dtype=real_dtype),
                [])
            min_step_size_factor = tf.ensure_shape(
                tf.convert_to_tensor(self._min_step_size_factor,
                                     dtype=real_dtype), [])
            max_step_size_factor = tf.ensure_shape(
                tf.convert_to_tensor(self._max_step_size_factor,
                                     dtype=real_dtype), [])
            max_num_steps = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
            max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
            max_num_newton_iters = self._max_num_newton_iters
            if max_num_newton_iters is not None:
                max_num_newton_iters = tf.convert_to_tensor(
                    max_num_newton_iters, dtype=tf.int32)
            newton_tol_factor = tf.ensure_shape(
                tf.convert_to_tensor(self._newton_tol_factor,
                                     dtype=real_dtype), [])
            newton_step_size_factor = tf.ensure_shape(
                tf.convert_to_tensor(self._newton_step_size_factor,
                                     dtype=real_dtype), [])
            bdf_coefficients = tf.cast(
                tf.concat([[0.],
                           tf.convert_to_tensor(self._bdf_coefficients,
                                                dtype=real_dtype)], 0),
                state_dtype)
            util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients')
            newton_coefficients = 1. / (
                (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS)
            newton_coefficients_array = tf.TensorArray(
                newton_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(newton_coefficients)
            error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / (
                bdf_util.ORDERS + 1)
            error_coefficients_array = tf.TensorArray(
                error_coefficients.dtype,
                size=bdf_util.MAX_ORDER + 1,
                clear_after_read=False,
                element_shape=[]).unstack(error_coefficients)
            first_step_size = self._first_step_size
            if first_step_size is None:
                first_step_size = bdf_util.first_step_size(
                    atol, error_coefficients_array.read(1), initial_state,
                    initial_time, ode_fn_vec, rtol, safety_factor)
            elif previous_solver_internal_state is not None:
                tf.logging.warn(
                    '`first_step_size` is ignored since '
                    '`previous_solver_internal_state` was specified.')
            first_step_size = tf.convert_to_tensor(first_step_size,
                                                   dtype=real_dtype)
            if self._validate_args:
                if max_num_steps is not None:
                    max_num_steps = tf.ensure_shape(max_num_steps, [])
                max_order = tf.ensure_shape(max_order, [])
                if max_num_newton_iters is not None:
                    max_num_newton_iters = tf.ensure_shape(
                        max_num_newton_iters, [])
                bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6])
                first_step_size = tf.ensure_shape(first_step_size, [])
            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
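                # Cold start: seed the backward-difference table with the
                # initial state and the forward-Euler increment
                # `first_step_size * ode_fn(t0, y0)`, i.e. begin at order 1.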
                first_order_backward_difference = ode_fn_vec(
                    initial_time, initial_state) * tf.cast(
                        first_step_size, state_dtype)
                backward_differences = tf.concat([
                    tf.reshape(initial_state, [1, -1]),
                    first_order_backward_difference[tf.newaxis, :],
                    tf.zeros(tf.stack([bdf_util.MAX_ORDER + 1, num_odes]),
                             dtype=state_dtype),
                ], 0)
                solver_internal_state = _BDFSolverInternalState(
                    backward_differences=backward_differences,
                    order=1,
                    state_shape=original_state_shape,
                    step_size=first_step_size)
            states_array = tf.TensorArray(
                state_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=initial_state.get_shape())
            times_array = tf.TensorArray(
                real_dtype,
                size=num_solution_times,
                dynamic_size=solution_times_chosen_by_solver,
                element_shape=tf.TensorShape([]))
            diagnostics = _BDFDiagnostics(num_jacobian_evaluations=0,
                                          num_matrix_factorizations=0,
                                          num_ode_fn_evaluations=0,
                                          status=0)
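            # The zero matrices below are placeholders that are replaced before
            # the first Newton solve: `should_update_jacobian` starts out True,
            # forcing a fresh Jacobian and `newton_qr` factorization.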
            iterand = _BDFIterand(
                jacobian=tf.zeros([num_odes, num_odes], dtype=state_dtype),
                jacobian_is_up_to_date=False,
                new_step_size=solver_internal_state.step_size,
                num_steps=0,
                num_steps_same_size=0,
                should_update_jacobian=True,
                should_update_step_size=False,
                time=initial_time,
                unitary=tf.zeros([num_odes, num_odes], dtype=state_dtype),
                upper=tf.zeros([num_odes, num_odes], dtype=state_dtype))

            # (3) Make non-static assertions.
            with tf.control_dependencies(assert_ops()):

                # (4) Solve up to final time.
                if solution_times_chosen_by_solver:

                    def step_cond(next_time, diagnostics, iterand, *_):
                        return (iterand.time < next_time) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        states_array, times_array
                    ] = tf.while_loop(step_cond, step, [
                        final_time, diagnostics, iterand,
                        solver_internal_state, states_array, times_array
                    ])

                else:

                    def advance_to_solution_time_cond(n, diagnostics, *_):
                        return (n < num_solution_times) & (tf.equal(
                            diagnostics.status, 0))

                    [
                        _, diagnostics, iterand, solver_internal_state,
                        states_array, times_array
                    ] = tf.while_loop(
                        advance_to_solution_time_cond,
                        advance_to_solution_time, [
                            0, diagnostics, iterand, solver_internal_state,
                            states_array, times_array
                        ])

                # (5) Return `Results` object.
                states = tf.reshape(states_array.stack(),
                                    tf.concat([[-1], original_state_shape], 0))
                times = times_array.stack()
                if not solution_times_chosen_by_solver:
                    times.set_shape(solution_times.get_shape())
                    states.set_shape(solution_times.get_shape().concatenate(
                        original_state_tensor_shape))
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
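
A minimal usage sketch, assuming the method above is the core of TensorFlow
Probability's public `tfp.math.ode.BDF` solver:

import tensorflow as tf
import tensorflow_probability as tfp

# Stiff linear ODE dy/dt = -100 * y with exact solution y(t) = exp(-100 * t).
ode_fn = lambda t, y: -100. * y

results = tfp.math.ode.BDF(rtol=1e-6, atol=1e-9).solve(
    ode_fn,
    initial_time=0.,
    initial_state=tf.constant([1.], dtype=tf.float64),
    solution_times=[0.01, 0.05, 0.1])

print(results.times)                               # requested solution times
print(results.states)                              # approximations of exp(-100 * t)
print(results.diagnostics.num_ode_fn_evaluations)  # work counter

Passing `solution_times=tfp.math.ode.ChosenBySolver(final_time=0.1)` instead
exercises the `ChosenBySolver` branch above, letting the solver pick the mesh.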
Example #8
def ensure(x):
    """Checks that `x` has static shape [batch_size, max_sequence_length]."""
    x = tf.ensure_shape(x, [p.batch_size, p.max_sequence_length])
    return x
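
A minimal sketch of what `ensure` checks, assuming `p` is a hyperparameter
object with integer `batch_size` and `max_sequence_length` attributes:

import tensorflow as tf

x = tf.zeros([8, 128])
# Passes, and the result carries the static shape [8, 128].
y = tf.ensure_shape(x, [8, 128])
# A mismatched shape raises an InvalidArgumentError (eagerly here, or at run
# time inside a tf.function):
# tf.ensure_shape(x, [8, 256])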