Exemplo n.º 1
0
    def training_step(self, x, r_loss, beta):
        """Run one optimization step of the VAE on a batch.

        Parameters
        ----------
        x : batch of input data to reconstruct.
        r_loss : float
            Weight applied to the reconstruction-loss term.
        beta : float
            Weight applied to the KL-divergence term.

        Returns
        -------
        Scalar total loss (weighted reconstruction + weighted KL).
        """
        with tf.GradientTape() as tape:
            # Forward pass: reconstruct the input through the model.
            reconstructed = self(x)
            # Total loss = weighted reconstruction loss plus the model's
            # accumulated KL-divergence losses (collected in self.losses).
            reconstruction = trace_loss(x, reconstructed)
            kl_divergence = sum(self.losses)
            total = r_loss * reconstruction + beta * kl_divergence

        # Backpropagate and update every trainable weight of the VAE.
        gradients = tape.gradient(total, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        return total
Exemplo n.º 2
0
    def training_step(self, data, r_loss):
        """Run one optimization step of the AE on an (input, target) batch.

        Parameters
        ----------
        data : tuple
            Pair of (input batch, target batch).
        r_loss : float
            Reconstruction-loss weight.
            NOTE(review): currently unused in this step — confirm whether
            the loss is meant to be scaled by it, as in the VAE variant.

        Returns
        -------
        (loss, mean_fidelity) : reconstruction loss of the step and the
        mean fidelity between inputs and their reconstructions.
        """
        inputs, targets = data

        with tf.GradientTape() as tape:
            # Forward pass: reconstruct the input through the model.
            reconstructed = self(inputs)
            loss = trace_loss(targets, reconstructed)

        # Backpropagate and update every trainable weight of the AE.
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))

        fid = fidelity_rho(inputs, reconstructed)

        return loss, np.mean(fid)
Exemplo n.º 3
0
    def validating_step(self, data, r_loss):
        """Evaluate the AE on an (input, target) batch without weight updates.

        Parameters
        ----------
        data : tuple
            Pair of (input batch, target batch).
        r_loss : float
            Reconstruction-loss weight.
            NOTE(review): unused here — confirm whether scaling is intended.

        Returns
        -------
        (loss, mean_fidelity) : reconstruction loss on the batch and the
        mean fidelity between inputs and their reconstructions.
        """
        inputs, targets = data

        # Forward pass only — no gradient tape, so no weights change.
        reconstructed = self(inputs)
        loss = trace_loss(targets, reconstructed)

        fid = fidelity_rho(inputs, reconstructed)

        return loss, np.mean(fid)
Exemplo n.º 4
0
    def validating_step(self, x, r_loss, beta):
        """Evaluate the VAE on a batch without updating its weights.

        Parameters
        ----------
        x : batch of input data to reconstruct.
        r_loss : float
            Weight applied to the reconstruction-loss term.
        beta : float
            Weight applied to the KL-divergence term.

        Returns
        -------
        (loss, mean_fidelity) : weighted total loss and the mean fidelity
        between inputs and their reconstructions.
        """
        # Forward pass only — no gradient tape, so no weights change.
        reconstructed = self(x)
        # Same loss composition as the training step: weighted
        # reconstruction plus weighted KL divergence (from self.losses).
        reconstruction = trace_loss(x, reconstructed)
        kl_divergence = sum(self.losses)
        total = r_loss * reconstruction + beta * kl_divergence

        fid = fidelity_rho(x, reconstructed)

        return total, np.mean(fid)