Example #1
0
    def call(self, inputs: Dict[str, tf.Tensor]) -> tf.Tensor:
        """Roll the state forward in time and return the target trajectory.

        Entries of ``inputs`` whose keys are in ``self.equation.evolving_keys``
        are advanced ``self.num_time_steps`` times with
        ``self.take_time_step``. Entries whose keys are in
        ``self.equation.constant_keys`` are merged back into the state before
        every step. All other entries are ignored.

        Args:
          inputs: dict of tensors with dimensions [batch, x, y].

        Returns:
          labels: tensor with dimensions [batch, time, x, y] holding the
            ``self.target`` field of the predicted state at steps
            [1, ..., self.num_time_steps], for model training.
        """
        constant_state = {}
        evolving_inputs = {}
        for name, tensor in inputs.items():
            if name in self.equation.constant_keys:
                constant_state[name] = tensor
            if name in self.equation.evolving_keys:
                evolving_inputs[name] = tensor

        def step_once(previous_state, unused_step_index):
            # Only evolving fields are carried through the scan; re-attach the
            # constant fields so take_time_step sees the complete state.
            full_state = dict(previous_state)
            full_state.update(constant_state)
            return self.take_time_step(full_state)

        trajectory = tf.scan(
            step_once,
            tf.range(self.num_time_steps),
            initializer=evolving_inputs,
        )
        # tf.scan stacks outputs along a new leading axis; move that time axis
        # to position 1, after the batch axis.
        trajectory = tensor_ops.moveaxis(trajectory, source=0, destination=1)
        # TODO(shoyer): support multiple targets, once keras does.
        # https://github.com/tensorflow/tensorflow/issues/25299
        return trajectory[self.target]
Example #2
0
    def call(self, inputs: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
        """Advance the evolving fields and return their full trajectory.

        Args:
          inputs: dict of tensors with dimensions [batch, x, y].
            - Keys in ``self.equation.constant_keys`` are held fixed and fed
              to every step.
            - Keys in ``self.equation.evolving_keys`` are stepped forward
              with ``self.take_time_step``.
            - Any other keys are ignored.

        Returns:
          labels: dict of tensors with dimensions [batch, time, x, y], giving
            the predicted state at steps [1, ..., self.num_time_steps].
        """
        def select(keys):
            return {
                name: tensor for name, tensor in inputs.items() if name in keys
            }

        constant_state = select(self.equation.constant_keys)
        evolving_inputs = select(self.equation.evolving_keys)

        def advance(evolving_state, _):
            # The scan carries only evolving fields; constants are re-merged
            # on every step.
            return self.take_time_step({**evolving_state, **constant_state})

        stacked = tf.scan(
            advance, tf.range(self.num_time_steps), initializer=evolving_inputs
        )
        # tf.scan stacks along a new leading axis; move it to position 1.
        return tensor_ops.moveaxis(stacked, source=0, destination=1)
Example #3
0
def integrate_steps(
    model: models.TimeStepModel,
    state: KeyedTensors,
    steps: ArrayLike,
    initial_time: float = 0.0,
    axis: int = 0,
    xla_compile: bool = False,
) -> KeyedTensors:
    """Integrate a model forward, saving the state at the requested steps.

    Args:
      model: model to integrate.
        - Entries of ``state`` whose keys are in
          ``model.equation.evolving_keys`` are advanced with
          ``model.take_time_step``.
        - Entries whose keys are in ``model.equation.constant_keys`` are
          passed unchanged to every step.
        - Other entries are dropped from the result.
      state: starting value of the state.
      steps: 1-D sequence of step counts, measured from the initial state, at
        which the solution is saved. For example, ``[1, 2, 4]`` saves after
        1, 2 and 4 time steps. Entries should be non-decreasing; an entry that
        is not larger than its predecessor performs no additional time steps.
      initial_time: initial time for time integration. Currently unused.
      axis: axis in result tensors along which the saved states are stacked.
      xla_compile: whether to compile the per-interval integration with XLA.

    Returns:
      Time evolved states at the steps specified in ``steps``. Each tensor has
      the same shape as the corresponding input, with an additional dimension
      of size ``len(steps)`` moved to position ``axis``. Constant entries are
      broadcast along that dimension rather than recomputed.
    """
    # TODO(shoyer): explicitly include time?
    del initial_time  # unused

    state = nest.map_structure(tf.convert_to_tensor, state)
    steps = tf.convert_to_tensor(steps, dtype=tf.int32)
    constant_state = {
        k: v for k, v in state.items() if k in model.equation.constant_keys
    }
    evolving_state = {
        k: v for k, v in state.items() if k in model.equation.evolving_keys
    }

    @tf.function
    def advance_until_saved_step(state, start_stop):
        """Apply ``stop - start`` time steps to ``state`` (none if stop <= start)."""
        start, stop = start_stop
        # can't use range() in a for loop with XLA:
        # https://github.com/tensorflow/tensorflow/issues/30182
        i = start
        while i < stop:
            # Only evolving keys are carried in the scan state, so re-attach
            # the constants before every step.
            state = model.take_time_step({**state, **constant_state})
            i += 1
        return state

    if xla_compile:
        advance_until_saved_step = _xla_decorator(advance_until_saved_step)

    # Interval k runs from steps[k - 1] (0 for k == 0) up to steps[k].
    # Concatenating then dropping the last element keeps `starts` the same
    # length as `steps` even when `steps` is empty, so the two scan inputs
    # always have matching leading dimensions.
    starts = tf.concat([[0], steps], axis=0)[:-1]
    integrated = tf.scan(advance_until_saved_step, [starts, steps],
                         initializer=evolving_state)

    # Constants never change, so broadcast them along the saved-step
    # dimension instead of carrying copies through the scan.
    saved_shape = steps.shape.as_list()
    integrated_constants = nest.map_structure(
        lambda x: tf.broadcast_to(x, saved_shape + x.shape.as_list()),
        constant_state)
    integrated.update(integrated_constants)

    return tensor_ops.moveaxis(integrated, 0, axis)