def _hamiltonian(
    self,
    model: keras.models.Model,
    inputs: tf.Tensor,
    targets: tf.Tensor,
    params: tf.Tensor,
    momenta: tf.Tensor,
) -> tf.Tensor:
    """
    Calculate the value of the Hamiltonian of the system.

    The parameter state (x in the thesis, ξ in the original paper) is a
    tf.Tensor. The method takes two arguments, params and momenta,
    instead of a single state argument, because the partial derivative
    with respect to the momenta can then be computed faster.
    """
    # Assign the state parameters to the model
    _model_set_flat_variables(model, params)

    # Predict the outcome with the new parameters
    prediction = model(inputs)

    # H = (alpha * loss + beta * |params|^2 + gamma * |momenta|^2) / 2;
    # tensordot with axes=2 contracts the (n, 1) column vectors to
    # their squared Euclidean norms.
    return 0.5 * (
        self.alpha * model.loss(targets, prediction)
        + self.beta * tf.tensordot(params, params, 2)
        + self.gamma * tf.tensordot(momenta, momenta, 2)
    )
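# The helpers _model_set_flat_variables and _flatten_variables are used
# above but defined elsewhere in the module. Below is a minimal sketch of
# what they might look like, assuming the flat state is kept as an (n, 1)
# column vector so that the tensordot contractions with axes=2 and the
# sparse matrix product in train_batch operate on rank-2 tensors; the
# actual implementations may differ.


def _flatten_variables(variables) -> tf.Tensor:
    """Concatenate variables (or gradients) into one (n, 1) column vector."""
    return tf.concat([tf.reshape(v, (-1, 1)) for v in variables], axis=0)


def _model_set_flat_variables(model: keras.models.Model, params: tf.Tensor):
    """Slice a flat column vector back into the model's trainable variables."""
    offset = 0
    for var in model.trainable_variables:
        size = int(tf.size(var))
        var.assign(tf.reshape(params[offset:offset + size], var.shape))
        offset += size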
def train_batch(
    self,
    model: keras.models.Model,
    input_batch: tf.Tensor,
    target_batch: tf.Tensor,
    metrics=(),
) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Train the model with a single batch of samples.

    Parameters
    ----------
    model : tf.keras.Model
        The model to train.
    input_batch : tf.Tensor
        Input samples of the batch to train on.
    target_batch : tf.Tensor
        Expected outputs of the model for the inputs of the batch.
    metrics : List[tf.metrics.Metric]
        List of TensorFlow metrics. These are updated after the
        training step.

    Returns
    -------
    Tuple[tf.Tensor, tf.Tensor]
        The loss for the current batch and the energy (Hamiltonian) of
        the system for the current batch.
    """
    # Basic sanity check that the input sample count matches the target
    # sample count
    sample_cnt = input_batch.shape[0]
    assert sample_cnt == target_batch.shape[0]

    # Call the model handler
    # (a no-op if the model did not change since the last iteration)
    self._check_model_and_state(model)

    # Redefine the Hamiltonian and the gradient for the current batch
    batch_hamiltonian = self.get_hamiltonian(model, input_batch, target_batch)

    @tf.function
    def loss_gradient(params: tf.Tensor) -> tf.Tensor:
        # Gradient of the potential-energy part of the Hamiltonian
        # (loss term plus weight-decay term) with respect to the
        # trainable variables, returned as a flat column vector.
        _model_set_flat_variables(model, params)
        with tf.GradientTape() as tape:
            tape.watch(input_batch)
            loss = 0.5 * (
                self.alpha
                * model.loss(target_batch, model(input_batch, training=True))
                + self.beta * tf.tensordot(params, params, 2)
            )
        return _flatten_variables(
            tape.gradient(loss, model.trainable_variables)
        )

    self._integrator.loss_gradient = loss_gradient

    # Integrate the equations of motion over one IVP period, starting
    # from the current parameters
    params = _flatten_variables(model.trainable_variables)
    params, velocity = self._integrator.integrate(
        self.ivp_period, self.ivp_step_size, params
    )
    # The momenta follow from the velocities via the mass matrix: p = M v
    momenta = tf.sparse.sparse_dense_matmul(self.M, velocity)
    _model_set_flat_variables(model, params)

    # Predict the output for the current batch and update the metrics
    # accordingly
    prediction = model(input_batch)
    for metric in metrics:
        metric.update_state(target_batch, prediction)

    # Return the loss for the current batch and the energy in the
    # system for the current batch
    return (
        model.loss(target_batch, prediction),
        batch_hamiltonian(params, momenta),
    )
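# Example usage (a sketch, not part of the original module): the name
# HamiltonianTrainer and its constructor arguments are hypothetical
# stand-ins for the class that owns train_batch; only the model, data,
# and metrics handling follow the method signature above.

if __name__ == "__main__":
    import tensorflow as tf
    from tensorflow import keras

    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(16, activation="tanh"),
        keras.layers.Dense(1),
    ])
    model.compile(loss=keras.losses.MeanSquaredError())

    # Random toy data, batched for per-batch training
    x_train = tf.random.normal((128, 4))
    y_train = tf.random.normal((128, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)

    trainer = HamiltonianTrainer(...)  # hypothetical constructor
    mae = tf.metrics.MeanAbsoluteError()

    for input_batch, target_batch in dataset:
        loss, energy = trainer.train_batch(
            model, input_batch, target_batch, metrics=[mae]
        )
        print(f"loss={float(loss):.4f}  H={float(energy):.4f}")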