def predictions_to_classes(self,
                           predictions: tf.Tensor,
                           best=True) -> pd.DataFrame:
    """Convert raw per-class prediction scores into a pandas DataFrame.

    Args:
      predictions: 2D array of scores with shape (n_samples, n_classes).
        NOTE(review): annotated as tf.Tensor, but `.argmax(-1)` / `.max(-1)`
        are ndarray methods — presumably a NumPy array (or eager tensor
        converted upstream) in practice; confirm at call sites.
      best: If True, return one row per sample with the winning class id,
        its score, and its human-readable label. If False, return the full
        score matrix with one column per class id.

    Returns:
      A DataFrame with columns ("class_id", "score", "label") when
      ``best`` is True, otherwise the raw score table keyed by class id.
    """
    if not best:
        # Full score table: one column per class, columns taken from the
        # keys of the id -> label mapping.
        return pd.DataFrame(predictions, columns=self.id_to_label)

    top_ids = predictions.argmax(-1)
    top_scores = predictions.max(-1)
    frame = pd.DataFrame({"class_id": top_ids, "score": top_scores})
    # Raise KeyError (via __getitem__) on unknown ids instead of
    # silently producing NaN, matching the original `.map` behavior.
    frame["label"] = frame["class_id"].map(self.id_to_label.__getitem__)
    return frame
def fit(mps: classifier.MatrixProductState,
        optimizer,
        x: tf.Tensor,
        y: tf.Tensor,
        x_val: Optional[tf.Tensor] = None,
        y_val: Optional[tf.Tensor] = None,
        n_epochs: int = 20,
        batch_size: int = 10,
        n_message: int = 1):
    """Supervised training of an MPS classifier on a dataset.

    Args:
      mps: MatrixProductState classifier object to optimize in place.
      optimizer: TensorFlow optimizer object to use in training.
        A working option is AdamOptimizer with learning_rate=1e-4.
      x: Training data (encoded images) of shape (n_data, n_sites, d_phys).
      y: Training labels in one-hot format of shape (n_data, n_labels).
      x_val: Optional validation data; when given, validation loss and
        accuracy are tracked each epoch (y_val must then be given too).
      y_val: Optional validation labels (one-hot).
      n_epochs: Total number of epochs to train.
      batch_size: Batch size for training (trailing partial batch is dropped).
      n_message: Print a progress message every `n_message` epochs.

    Returns:
      Tuple of (mps, history) where history maps
      "loss"/"acc"/"total_time"/"val_loss"/"val_acc" to per-epoch lists
      (validation lists stay empty when no validation data is given).
    """
    train_x = tf.cast(x, dtype=mps.dtype)
    train_y = tf.cast(y, dtype=mps.dtype)
    # Integer division drops any trailing partial batch.
    train_batches = len(x) // batch_size

    track_validation = x_val is not None
    if track_validation:
        valid_x = tf.cast(x_val, dtype=mps.dtype)
        valid_y = tf.cast(y_val, dtype=mps.dtype)
        valid_batches = len(x_val) // batch_size

    history = {key: [] for key in
               ("loss", "acc", "total_time", "val_loss", "val_acc")}

    def minibatches(inputs, targets, count):
        # Lazily yield `count` consecutive (input, target) slice pairs.
        for b in range(count):
            lo = b * batch_size
            hi = lo + batch_size
            yield inputs[lo:hi], targets[lo:hi]

    start_time = time.time()
    for epoch in range(n_epochs):
        loss, logits = run_epoch(
            mps, minibatches(train_x, train_y, train_batches),
            train_batches, optimizer)
        # NOTE(review): loss is averaged over all of len(x) even though a
        # trailing partial batch is never trained on — kept as-is.
        history["loss"].append(loss / len(x))
        predicted = logits.numpy().argmax(axis=1)
        history["acc"].append((predicted == y.argmax(axis=1)).mean())
        history["total_time"].append(time.time() - start_time)

        if track_validation:
            # No optimizer: run_epoch only evaluates on validation data.
            val_loss, val_logits = run_epoch(
                mps, minibatches(valid_x, valid_y, valid_batches),
                valid_batches)
            history["val_loss"].append(val_loss / len(x_val))
            val_predicted = val_logits.numpy().argmax(axis=1)
            history["val_acc"].append(
                (val_predicted == y_val.argmax(axis=1)).mean())

        if epoch % n_message == 0:
            print(f"\nEpoch: {epoch}")
            print(f"Time: {history['total_time'][-1]}")
            print(f"Loss: {history['loss'][-1]}")
            print(f"Accuracy: {history['acc'][-1]}")
            if track_validation:
                print(f"Validation Accuracy: {history['val_acc'][-1]}")

    return mps, history