def call(self, inputs: Dict[str, tf.Tensor]) -> tf.Tensor:
  """Roll the state forward and return the trajectory of the target variable.

  The inputs are split into entries that stay fixed (keys listed in
  `self.equation.constant_keys`) and entries that are advanced in time (keys
  listed in `self.equation.evolving_keys`). The evolving entries are stepped
  `self.num_time_steps` times with `self.take_time_step`, and the result for
  the key `self.target` is returned.

  Args:
    inputs: dict of tensors with dimensions [batch, x, y].

  Returns:
    labels: tensor with dimensions [batch, time, x, y], giving the target
      value of the predicted state at steps [1, ..., self.num_time_steps]
      for model training.
  """
  constant_state = {}
  evolving_inputs = {}
  for key, value in inputs.items():
    # Two independent membership tests: each dict is filtered on its own key
    # set, so a key listed in both sets lands in both dicts.
    if key in self.equation.constant_keys:
      constant_state[key] = value
    if key in self.equation.evolving_keys:
      evolving_inputs[key] = value

  def step(carry, _):
    # The scan carry holds only the evolving entries; the constant entries
    # are merged back in before every step (constants win on a key clash).
    full_state = dict(carry)
    full_state.update(constant_state)
    return self.take_time_step(full_state)

  # The scanned elements are only a step counter; `step` ignores their value.
  step_indices = tf.range(self.num_time_steps)
  trajectory = tf.scan(step, step_indices, initializer=evolving_inputs)

  # tf.scan stacks results along axis 0 (time); move time behind batch.
  trajectory = tensor_ops.moveaxis(trajectory, source=0, destination=1)

  # TODO(shoyer): support multiple targets, once keras does.
  # https://github.com/tensorflow/tensorflow/issues/25299
  return trajectory[self.target]
def call(self, inputs: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Roll the state forward and return the trajectory of every evolving key.

  The inputs are split into entries that stay fixed (keys listed in
  `self.equation.constant_keys`) and entries that are advanced in time (keys
  listed in `self.equation.evolving_keys`). The evolving entries are stepped
  `self.num_time_steps` times with `self.take_time_step`.

  Args:
    inputs: dict of tensors with dimensions [batch, x, y].

  Returns:
    labels: dict of tensors with dimensions [batch, time, x, y], giving the
      predicted state at steps [1, ..., self.num_time_steps].
  """
  constant_state = {}
  evolving_inputs = {}
  for key, value in inputs.items():
    # Two independent membership tests: each dict is filtered on its own key
    # set, so a key listed in both sets lands in both dicts.
    if key in self.equation.constant_keys:
      constant_state[key] = value
    if key in self.equation.evolving_keys:
      evolving_inputs[key] = value

  def step(carry, _):
    # The scan carry holds only the evolving entries; the constant entries
    # are merged back in before every step (constants win on a key clash).
    return self.take_time_step({**carry, **constant_state})

  # The scanned elements are only a step counter; `step` ignores their value.
  step_indices = tf.range(self.num_time_steps)
  trajectory = tf.scan(step, step_indices, initializer=evolving_inputs)

  # tf.scan stacks results along axis 0 (time); move time behind batch.
  return tensor_ops.moveaxis(trajectory, source=0, destination=1)
def integrate_steps(
    model: models.TimeStepModel,
    state: KeyedTensors,
    steps: ArrayLike,
    initial_time: float = 0.0,
    axis: int = 0,
    xla_compile: bool = False,
) -> KeyedTensors:
  """Integrate a model, saving the state at the requested step counts.

  Args:
    model: model to integrate. Its `equation.constant_keys` and
      `equation.evolving_keys` decide which entries of `state` are held fixed
      and which are advanced with `model.take_time_step`.
    state: starting value of the state.
    steps: step counts at which the solution is saved; converted to an int32
      tensor. Integration starts at step 0 and each entry is reached by
      continuing from the previous one, so entries are presumably expected to
      be non-decreasing (not checked here).
    initial_time: initial time for time integration. Currently unused.
    axis: axis in result tensors along which the integrated solution is
      stacked.
    xla_compile: whether to compile with XLA or not.

  Returns:
    Dict holding the evolving entries at each requested step, together with
    the constant entries broadcast along the same saved-steps dimension. That
    dimension is moved from position 0 to `axis`.
  """
  # TODO(shoyer): explicitly include time?
  del initial_time  # unused

  state = nest.map_structure(tf.convert_to_tensor, state)
  steps = tf.convert_to_tensor(steps, dtype=tf.int32)

  constant_state = {}
  evolving_state = {}
  for key, value in state.items():
    # Independent membership tests, one per key set.
    if key in model.equation.constant_keys:
      constant_state[key] = value
    if key in model.equation.evolving_keys:
      evolving_state[key] = value

  @tf.function
  def advance_to_next_save(evolving, bounds):
    """Step from `bounds[0]` up to (not including) `bounds[1]`."""
    step, stop = bounds
    # can't use range() in a for loop with XLA:
    # https://github.com/tensorflow/tensorflow/issues/30182
    while step < stop:
      evolving = model.take_time_step({**evolving, **constant_state})
      step += 1
    return evolving

  scan_fn = (
      _xla_decorator(advance_to_next_save)
      if xla_compile else advance_to_next_save
  )

  # Each segment starts where the previous one stopped; the first starts at 0.
  starts = tf.concat([[0], steps[:-1]], axis=0)
  integrated = tf.scan(scan_fn, [starts, steps], initializer=evolving_state)

  def repeat_over_steps(x):
    # Prepend the static shape of `steps` so that constants line up with the
    # stacked evolving results.
    target_shape = steps.shape.as_list() + x.shape.as_list()
    return tf.broadcast_to(x, target_shape)

  integrated.update(nest.map_structure(repeat_over_steps, constant_state))
  return tensor_ops.moveaxis(integrated, 0, axis)