def train_step(model, x, t, e, a, optimizer, bs=256, lambd=1., seed=0):

  """Optimizes the model for one epoch.

  Shuffles the training data, then iterates over it once in minibatches,
  minimizing `partial_ll_loss + lambd * vae_loss` via the supplied optimizer.

  Args:
    model:
      instance of CoupledDeepCPHVAE class.
    x:
      a numpy array of input features (Training Data).
    t:
      a numpy vector of event times (Training Data).
    e:
      a numpy vector of event indicators (1 if event occured, 0 otherwise)
      (Training Data).
    a:
      a numpy vector of the protected group membership (Training Data).
    optimizer:
      instance of tf.keras.optimizers (default is Adam)
    bs:
      int minibatch size.
    lambd:
      float Strength of the VAE loss term.
    seed:
      random seed.

  Returns:
    None. Trains the model inplace.

  """

  x, t, e, a = shuffle(x, t, e, a, random_state=seed)

  n = x.shape[0]

  # Ceil division: the last (possibly partial) batch is covered without
  # generating a trailing *empty* batch when n is an exact multiple of bs
  # (the previous `(n // bs) + 1` did, feeding empty slices to the model).
  batches = (n + bs - 1) // bs

  for i in range(batches):

    xb = x[i * bs:(i + 1) * bs]
    tb = t[i * bs:(i + 1) * bs]
    eb = e[i * bs:(i + 1) * bs]
    ab = a[i * bs:(i + 1) * bs]

    with tf.GradientTape() as tape:

      # Cox partial likelihood on the batch plus the weighted VAE term.
      pll = partial_ll_loss(model, xb, tb, eb, ab, l2=0.001)
      vaeloss = vae_loss(model, xb)
      loss = pll + lambd * vaeloss

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
def test_step(model, x, t, e, a, loss='concordance', lambd=1.):

  """Test the model and compute validation metric.

  Args:
    model:
      instance of CoupledDeepCPHVAE class.
    x:
      a numpy array of input features (Val/Test Data).
    t:
      a numpy vector of event times (Val/Test Data).
    e:
      a numpy vector of event indicators (1 if event occured, 0 otherwise)
      (Val/Test Data).
    a:
      a numpy vector of the protected group membership (Val/Test Data).
    loss (str):
      string the loss metric to compute. one of 'concordance' or 'pll'.
    lambd (float):
       Strength of the VAE loss term.

  Returns:
    a float loss.

  Raises:
    ValueError: if `loss` is not one of 'concordance' or 'pll'.

  """

  if loss == 'concordance':

    # Use an explicit float dtype: `a` is a group indicator (typically
    # int/bool), and np.zeros_like(a) would inherit that dtype, silently
    # truncating the float risk scores assigned below.
    risks = np.zeros_like(a, dtype=float)

    lrisksp, lrisksn = model(x)

    lrisksp, lrisksn = lrisksp[:, 0], lrisksn[:, 0]

    # Each individual's risk comes from its own group's risk head.
    risks[a == 1] = lrisksp[a == 1]
    risks[a == 0] = lrisksn[a == 0]

    # Concordance is computed per group; negate risks since higher risk
    # should mean shorter survival time.
    pci = lifelines.utils.concordance_index(t[a == 1], -risks[a == 1],
                                            e[a == 1])
    nci = lifelines.utils.concordance_index(t[a == 0], -risks[a == 0],
                                            e[a == 0])
    return 0.5 * (nci + pci)

  if loss == 'pll':

    pll = partial_ll_loss(model, x, t, e, a, l2=0.001)
    vaeloss = vae_loss(model, x)
    # Use a fresh name instead of shadowing the `loss` parameter.
    total = pll + lambd * vaeloss
    return float(total)

  # Previously an unknown metric fell through and returned None silently.
  raise ValueError("loss must be one of 'concordance' or 'pll', got: " +
                   str(loss))