import torch
import torch.nn as nn
from torch.distributions import Independent
from torch.distributions.normal import Normal


def get_device(tensor):
    # Minimal stand-in for the project's device helper (assumed to live in
    # the codebase's utils); returns the device the tensor is stored on.
    return tensor.device


def compute_binary_CE_loss(label_predictions, mortality_label):
    #print("Computing binary classification loss: compute_CE_loss")

    mortality_label = mortality_label.reshape(-1)

    if len(label_predictions.size()) == 1:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples = label_predictions.size(0)
    label_predictions = label_predictions.reshape(n_traj_samples, -1)

    idx_not_nan = ~torch.isnan(mortality_label)
    if idx_not_nan.sum() == 0:
        print("All labels are NaNs!")
        # Nothing to compute the loss on -- return a zero scalar
        return torch.tensor(0.).to(get_device(mortality_label))

    label_predictions = label_predictions[:, idx_not_nan]
    mortality_label = mortality_label[idx_not_nan]

    if torch.sum(mortality_label == 0.) == 0 or torch.sum(
            mortality_label == 1.) == 0:
        print(
            "Warning: all examples in a batch belong to the same class -- please increase the batch size."
        )

    assert (not torch.isnan(label_predictions).any())
    assert (not torch.isnan(mortality_label).any())

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    mortality_label = mortality_label.repeat(n_traj_samples, 1)
    ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

    # Divide by the number of z0 samples per trajectory (not the batch size)
    ce_loss = ce_loss / n_traj_samples
    return ce_loss
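

# Minimal usage sketch (shapes and labels are illustrative, not from the
# original code): logits for 3 samples of z0 over a batch of 8 patients.
def _example_binary_ce():
    preds = torch.randn(3, 8)
    labels = torch.tensor([0., 1., 0., 1., 0., 1., 0., 1.])
    return compute_binary_CE_loss(preds, labels)
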
def compute_masked_likelihood(mu, data, mask, likelihood_func):
    # Compute the likelihood per patient and per attribute so that we don't
    # prioritize patients with more measurements
    n_traj_samples, n_traj, n_timepoints, n_dims = data.size()

    res = []
    for i in range(n_traj_samples):
        for k in range(n_traj):
            for j in range(n_dims):
                mask_bool = mask[i, k, :, j].bool()
                data_masked = torch.masked_select(data[i, k, :, j], mask_bool)
                mu_masked = torch.masked_select(mu[i, k, :, j], mask_bool)
                log_prob = likelihood_func(mu_masked, data_masked,
                                           indices=(i, k, j))
                res.append(log_prob)
    # shape after stacking: [n_traj_samples * n_traj * n_dims]
    res = torch.stack(res, 0).to(get_device(data))
    res = res.reshape((n_traj_samples, n_traj, n_dims))
    # Take the mean (rather than the sum) over attributes so that each
    # attribute contributes equally
    res = torch.mean(res, -1)
    res = res.transpose(0, 1)
    return res
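

# Minimal usage sketch (shapes illustrative): masked per-attribute MSE via the
# `mse` helper defined below; compute_masked_likelihood passes an `indices`
# keyword that `mse` accepts but ignores.
def _example_masked_mse():
    mu = torch.randn(1, 2, 5, 3)   # [n_traj_samples, n_traj, n_timepoints, n_dims]
    data = torch.randn(1, 2, 5, 3)
    mask = torch.ones(1, 2, 5, 3)  # 1 = observed, 0 = missing
    mask[0, 0, :2, 1] = 0.         # mark a few entries as missing
    return compute_masked_likelihood(mu, data, mask, mse)  # shape [n_traj, n_traj_samples]
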
def mse(mu, data, indices=None):
    # `indices` is unused; it keeps the signature compatible with the
    # likelihood_func interface expected by compute_masked_likelihood
    n_data_points = mu.size()[-1]

    if n_data_points > 0:
        mse = nn.MSELoss()(mu, data)
    else:
        mse = torch.zeros([1]).to(get_device(data)).squeeze()
    return mse


def compute_multiclass_CE_loss(label_predictions, true_label, mask):
    #print("Computing multi-class classification loss: compute_multiclass_CE_loss")

    if len(label_predictions.size()) == 3:
        # Add the n_traj_samples dimension if it is missing
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    true_label = true_label.repeat(n_traj_samples, 1, 1)

    label_predictions = label_predictions.reshape(
        n_traj_samples * n_traj * n_tp, n_dims)
    true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims)

    # choose time points with at least one measurement
    mask = torch.sum(mask, -1) > 0

    # replicate the time-point mask across the n_dims class logits, marking
    # where a prediction has a corresponding observed label
    pred_mask = mask.repeat(n_dims, 1, 1).permute(1, 2, 0)

    label_mask = mask
    pred_mask = pred_mask.repeat(n_traj_samples, 1, 1, 1)
    label_mask = label_mask.repeat(n_traj_samples, 1, 1, 1)

    pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims)
    label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1)

    if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
        assert (label_predictions.size(-1) == true_label.size(-1))
        # targets are in one-hot encoding -- convert to indices
        _, true_label = true_label.max(-1)

    res = []
    for i in range(true_label.size(0)):
        pred_masked = torch.masked_select(label_predictions[i],
                                          pred_mask[i].bool())
        labels = torch.masked_select(true_label[i], label_mask[i].bool())

        pred_masked = pred_masked.reshape(-1, n_dims)

        if len(labels) == 0:
            continue

        ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long())
        res.append(ce_loss)

    if len(res) == 0:
        # No time point had any observed measurements
        return torch.zeros([1]).to(get_device(label_predictions)).squeeze()

    ce_loss = torch.stack(res, 0).to(get_device(label_predictions))
    ce_loss = torch.mean(ce_loss)
    return ce_loss
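

# Minimal usage sketch (shapes illustrative): 2 z0 samples, 3 trajectories,
# 4 time points, 5 classes, and a 6-feature observation mask; the first time
# point of the first trajectory has no measurements and is skipped.
def _example_multiclass_ce():
    preds = torch.randn(2, 3, 4, 5)
    true_label = nn.functional.one_hot(
        torch.randint(0, 5, (3, 4)), num_classes=5).float()
    mask = torch.ones(3, 4, 6)
    mask[0, 0] = 0.  # no measurements at this time point
    return compute_multiclass_CE_loss(preds, true_label, mask)
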
def poisson_log_likelihood(masked_log_lambdas, masked_data, indices,
                           int_lambdas):
    # masked_log_lambdas and masked_data are 1-D tensors of observed entries;
    # int_lambdas[indices] is the integrated rate for that (sample,
    # trajectory, attribute) triple
    n_data_points = masked_data.size()[-1]

    if n_data_points > 0:
        log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices]
    else:
        log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze()
    return log_prob
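

# compute_masked_likelihood passes only (mu, data, indices) to its likelihood
# function, so int_lambdas has to be bound beforehand. A sketch of one way to
# do that, assuming int_lambdas is a tensor indexable by the (i, k, j) tuples
# produced in compute_masked_likelihood's loop:
def _bind_poisson(int_lambdas):
    def func(log_lam, data, indices):
        return poisson_log_likelihood(log_lam, data, indices, int_lambdas)
    return func
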
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices=None):
    n_data_points = mu_2d.size()[-1]

    if n_data_points > 0:
        gaussian = Independent(
            Normal(loc=mu_2d, scale=obsrv_std.repeat(n_data_points)), 1)
        log_prob = gaussian.log_prob(data_2d)
        # normalize by the number of observed points so that densely sampled
        # series do not dominate
        log_prob = log_prob / n_data_points
    else:
        log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
    return log_prob
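

# Similarly, obsrv_std has to be bound before gaussian_log_likelihood can be
# used as a likelihood_func; a minimal sketch assuming a fixed observation
# noise of 0.01:
def _example_gaussian_density():
    obsrv_std = torch.tensor([0.01])
    func = lambda mu_2d, data_2d, indices: gaussian_log_likelihood(
        mu_2d, data_2d, obsrv_std, indices=indices)
    mu = torch.randn(1, 2, 5, 3)
    data = mu + 0.01 * torch.randn_like(mu)
    mask = torch.ones(1, 2, 5, 3)
    return compute_masked_likelihood(mu, data, mask, func)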