Beispiel #1
0
def weighted_binary_cross_entropy_loss(model, X, y_true):
    """Compute a class-balanced binary cross-entropy loss for one batch.

    Each sample's BCE term is weighted by the number of samples of the
    *opposite* class in ``y_true``. Label ``0`` counts as negative and any
    other value as positive. Under this scheme the rarer class contributes
    more per sample. The weighted per-sample losses are then averaged.

    Args:
        model: Callable applied to ``X``. Its output is fed to ``BCELoss``,
            which requires probabilities in ``[0, 1]``. Presumably the model
            ends in a sigmoid; confirm against the model definition.
        X: Input batch passed to ``model``.
        y_true: Binary targets. ``BCELoss`` requires them to have the same
            shape as ``model(X)``.

    Returns:
        The result of ``create_train_results(y_pred, loss)``, where ``loss`` is
        the scalar weighted mean BCE.
    """
    y_pred = model(X)

    # Count classes with the same predicate used below to assign weights, so
    # counts and weights can never disagree. (The previous torch.unique-based
    # unpacking relied on sort order and crashed unless exactly two distinct
    # labels were present.)
    is_negative = y_true == 0
    n_negative = is_negative.sum()
    n_positive = is_negative.numel() - n_negative

    # Inverse-frequency balancing: a sample is weighted by the size of the
    # other class. Clamping to 1 makes a single-class batch fall back to an
    # unweighted mean BCE instead of a zero loss (or an exception).
    positive_weight = n_negative.clamp(min=1)
    negative_weight = n_positive.clamp(min=1)
    sample_weights = torch.where(is_negative, negative_weight, positive_weight)

    # reduction="none" replaces the deprecated reduce=False and keeps
    # per-sample losses so they can be weighted before averaging.
    per_sample_loss = BCELoss(reduction="none")(y_pred, y_true.float())
    loss = (per_sample_loss * sample_weights).mean()
    return create_train_results(y_pred, loss)