def train_step(model, x, t, e, a, optimizer, bs=256, lambd=1., seed=0):
  """Optimizes the model for one epoch over shuffled minibatches.

  Args:
    model: instance of CoupledDeepCPHVAE class.
    x: a numpy array of input features (Training Data).
    t: a numpy vector of event times (Training Data).
    e: a numpy vector of event indicators (1 if event occurred, 0 otherwise)
      (Training Data).
    a: a numpy vector of the protected group membership (Training Data).
    optimizer: instance of tf.keras.optimizers (default is Adam).
    bs: int minibatch size.
    lambd: float strength of the VAE loss term.
    seed: random seed passed to `shuffle` as `random_state`.

  Returns:
    None. Trains the model inplace.
  """
  x, t, e, a = shuffle(x, t, e, a, random_state=seed)

  n = x.shape[0]
  # Ceiling division: never produce a trailing empty minibatch when n is an
  # exact multiple of bs (or when n == 0), since an empty batch would still be
  # pushed through the losses and the optimizer.
  batches = (n + bs - 1) // bs

  for i in range(batches):
    start, stop = i * bs, (i + 1) * bs
    xb = x[start:stop]
    tb = t[start:stop]
    eb = e[start:stop]
    ab = a[start:stop]

    with tf.GradientTape() as tape:
      pll = partial_ll_loss(model, xb, tb, eb, ab, l2=0.001)
      vaeloss = vae_loss(model, xb)
      # Joint objective: survival partial likelihood + weighted VAE loss.
      loss = pll + lambd * vaeloss

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
def test_step(model, x, t, e, a, loss='concordance', lambd=1.):
  """Tests the model and computes a validation metric.

  Args:
    model: instance of CoupledDeepCPHVAE class.
    x: a numpy array of input features (Val/Test Data).
    t: a numpy vector of event times (Val/Test Data).
    e: a numpy vector of event indicators (1 if event occurred, 0 otherwise)
      (Val/Test Data).
    a: a numpy vector of the protected group membership (Val/Test Data).
    loss (str): the metric to compute, one of 'concordance' or 'pll'.
    lambd (float): strength of the VAE loss term (used only for 'pll').

  Returns:
    a float. For 'concordance': the mean of the concordance indices computed
    separately within the a == 1 group and the a == 0 group. For 'pll': the
    partial log-likelihood loss plus `lambd` times the VAE loss.

  Raises:
    ValueError: if `loss` is not 'concordance' or 'pll'.
  """
  if loss == 'concordance':
    lrisksp, lrisksn = model(x)
    # np.asarray accepts both numpy arrays and eager tensors, so the boolean
    # mask indexing below is plain numpy indexing.
    lrisksp = np.asarray(lrisksp)[:, 0]
    lrisksn = np.asarray(lrisksn)[:, 0]

    pos = a == 1
    neg = a == 0
    # Explicit float dtype: zeros_like(a) alone would inherit a's dtype and,
    # for integer group labels, truncate the risk scores on assignment.
    risks = np.zeros_like(a, dtype=float)
    risks[pos] = lrisksp[pos]
    risks[neg] = lrisksn[neg]

    # Risks are negated because concordance_index expects scores that grow
    # with survival time, whereas a higher risk means an earlier event.
    pci = lifelines.utils.concordance_index(t[pos], -risks[pos], e[pos])
    nci = lifelines.utils.concordance_index(t[neg], -risks[neg], e[neg])
    return 0.5 * (nci + pci)

  if loss == 'pll':
    pll = partial_ll_loss(model, x, t, e, a, l2=0.001)
    vaeloss = vae_loss(model, x)
    return float(pll + lambd * vaeloss)

  raise ValueError(
      "Unknown loss %r; expected 'concordance' or 'pll'." % (loss,))