Ejemplo n.º 1
0
    def _hamiltonian(
            self, model: keras.models.Model,
            inputs: tf.Tensor, targets: tf.Tensor,
            params: tf.Tensor, momenta: tf.Tensor
        ) -> tf.Tensor:
        """
            Calculate the value of the Hamiltonian of the system.

            The value computed is

                H(x, p) = 0.5 * ( alpha * loss(targets, model(inputs))
                                + beta  * tensordot(x, x, 2)
                                + gamma * tensordot(p, p, 2) )

            where x are the flattened model parameters (x in the thesis,
            ξ in the original paper) and p are the momenta.

            The state is passed as two separate arguments, params and momenta,
            instead of a single state argument, since the partial derivative
            with regard to the momenta can be calculated faster that way.

            Side effect: the model's variables are overwritten with `params`
            (via `_model_set_flat_variables`) before the prediction is made.

            Parameters
            ----------
            model : tf.keras.Model
                Model to evaluate; `model.loss(targets, prediction)` is called
                to obtain the loss term.
            inputs : tf.Tensor
                Input samples the model is evaluated on.
            targets : tf.Tensor
                Expected model outputs for `inputs`.
            params : tf.Tensor
                Flattened model parameters (position part of the state).
                NOTE(review): `tf.tensordot(..., 2)` contracts two axes, so this
                presumably is a rank-2 column vector, e.g. (n, 1) — confirm.
            momenta : tf.Tensor
                Momenta (conjugate part of the state); same rank assumption
                as `params`.

            Returns
            -------
            tf.Tensor
                Energy of the system (presumably a scalar, provided
                `model.loss` reduces to a scalar — confirm).
        """

        # Load the candidate parameters into the model so that the loss term
        # is evaluated at the given position in parameter space.
        _model_set_flat_variables(model, params)

        # Predict the outcome with the new parameters
        prediction = model(inputs)

        # Potential part: data loss plus a quadratic term on the parameters.
        data_loss = model.loss(targets, prediction)
        # tensordot over two axes = sum of element-wise products, i.e. the
        # squared norm for rank-2 tensors.
        params_sq = tf.tensordot(params, params, 2)
        # Kinetic part: quadratic term on the momenta.
        momenta_sq = tf.tensordot(momenta, momenta, 2)

        return 0.5 * (
            self.alpha * data_loss
            + self.beta * params_sq
            + self.gamma * momenta_sq
        )
Ejemplo n.º 2
0
    def train_batch(
            self, model: keras.models.Model,
            input_batch: tf.Tensor, target_batch: tf.Tensor,
            metrics=None
        ) -> tuple:
        """
            Train the model with a single batch of samples.

            The flattened model parameters are integrated by `self._integrator`
            over `self.ivp_period` with step size `self.ivp_step_size`, using
            the gradient of the potential energy of the current batch. The
            model is left holding the final integrated parameters.

            Parameters
            ----------
            model : tf.keras.Model
                The model to train. Its variables are overwritten with the
                integrated parameters.
            input_batch : tf.Tensor
                Input samples of the batch to train.
                These contain the inputs of the batch.
            target_batch : tf.Tensor
                Output samples of the batch to train.
                These contain the expected output of the model for the inputs.
            metrics : List[tf.metrics.Metric], optional
                List of tensorflow metrics. These are updated with the batch
                prediction after the training step. Defaults to no metrics.

            Returns
            -------
            (loss, energy) : tuple of tf.Tensor
                The loss of the updated model on this batch, and the value of
                the batch Hamiltonian at the final parameters and momenta.

            Raises
            ------
            ValueError
                If `input_batch` and `target_batch` have different sample
                counts (first dimension).
        """
        # `None` instead of a shared mutable `[]` default.
        metrics = [] if metrics is None else metrics

        # Basic sanity check that input sample count matches target sample
        # count. Raised explicitly rather than `assert` so it is not stripped
        # when Python runs with -O.
        sample_cnt = input_batch.shape[0]
        if sample_cnt != target_batch.shape[0]:
            raise ValueError(
                "input_batch and target_batch have different sample counts: "
                f"{sample_cnt} != {target_batch.shape[0]}"
            )

        # Call the model handler
        # (no-op if the model didn't change since the last iteration)
        self._check_model_and_state(model)

        # Redefine hamiltonian and gradient for the current batch
        batch_hamiltonian = self.get_hamiltonian(model, input_batch, target_batch)

        # NOTE(review): a new tf.function object is created on every call of
        # train_batch, so it is traced once per batch.
        @tf.function
        def loss_gradient(params: tf.Tensor) -> tf.Tensor:
            """
                Gradient with respect to x of the potential energy
                0.5 * (alpha * loss + beta * tensordot(x, x, 2)).
            """
            _model_set_flat_variables(model, params)

            with tf.GradientTape() as tape:
                # Trainable variables are watched automatically; only the data
                # loss depends on them.
                data_loss = 0.5 * self.alpha * model.loss(
                    target_batch, model(input_batch, training=True)
                )

            data_loss_grad = _flatten_variables(
                tape.gradient(data_loss, model.trainable_variables)
            )

            # `params` is a plain tensor with no path to the model variables on
            # the tape, so a beta term placed inside the tape would contribute
            # nothing to `tape.gradient`. Add its gradient analytically:
            #   d/dx [0.5 * beta * tensordot(x, x, 2)] = beta * x
            # (assumes `params` has the same layout as `_flatten_variables`
            # output, which is how it is produced below).
            return data_loss_grad + self.beta * params
        self._integrator.loss_gradient = loss_gradient

        params = _flatten_variables(model.trainable_variables)

        params, velocity = self._integrator.integrate(
            self.ivp_period, self.ivp_step_size, params
        )
        # Convert velocity to momenta with the sparse matrix self.M
        # (presumably the mass matrix — confirm).
        momenta = tf.sparse.sparse_dense_matmul(self.M, velocity)

        _model_set_flat_variables(model, params)

        # Predict output for the current batch and update the metrics accordingly
        prediction = model(input_batch)
        for metric in metrics:
            metric.update_state(target_batch, prediction)

        # We return the loss for the current batch and the energy in the system
        # for the current batch
        return model.loss(target_batch, prediction), batch_hamiltonian(params, momenta)