Beispiel #1
0
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of the margin ranking loss.

    The unreduced, elementwise loss is::

        max(0, -target * (input1 - input2) + margin)

    and the result is then reduced according to ``reduction``.

    Args:
        input1: first score tensor.
        input2: second score tensor.
        target: tensor of ranking labels multiplied against the score difference.
        margin: additive margin applied before the hinge at zero.
        reduction: reduction mode, validated by ``_check_reduction_value``.

    Returns:
        The loss after ``_apply_loss_reduction`` has been applied.

    Raises:
        RuntimeError: if ``input1``, ``input2`` and ``target`` do not all have
            the same number of dimensions.
    """
    # Only the rank is compared here; the sizes are reported in the error
    # message so the mismatch is easy to diagnose.
    ranks_agree = input1.ndim == input2.ndim and input1.ndim == target.ndim
    if not ranks_agree:
        message = (
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            "input1: {}, input2: {}, target: {} "
        ).format(input1.shape, input2.shape, target.shape)
        raise RuntimeError(message)

    _check_reduction_value(reduction)

    # Spelled out step by step: -target * (input1 - input2) + margin, clamped at 0.
    flipped_target = refs.neg(target)
    score_gap = refs.sub(input1, input2)
    hinge_input = refs.add(refs.mul(flipped_target, score_gap), margin)
    elementwise_loss = refs.maximum(hinge_input, 0)
    return _apply_loss_reduction(elementwise_loss, reduction)
Beispiel #2
0
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    """Reference implementation of ``logit``: ``log(x / (1 - x))``.

    When ``eps`` is given, the input is first clamped to ``[eps, 1 - eps]``.

    When ``eps`` is None, the input is clamped to ``[-1.0, 2.0]`` instead.
    Every value outside ``[0, 1]`` (including +/-inf) already produces NaN
    through ``log(x / (1 - x))``. Every value inside ``[0, 1]`` lies within the
    bound. So this wider clamp leaves the result unchanged and behaves like
    "no clamping", while keeping a single code path.

    Args:
        self: input tensor.
        eps: optional clamping margin.

    Returns:
        The elementwise log-odds of the (possibly clamped) input.
    """
    lower = -1.0 if eps is None else eps
    upper = 1 - lower
    bounded = refs.clamp(self, lower, upper)
    odds = refs.true_divide(bounded, refs.sub(1, bounded))
    return refs.log(odds)
Beispiel #3
0
def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.tanhshrink.

    Computes ``a - tanh(a)`` elementwise.

    Args:
        a: input tensor.

    Returns:
        Tensor holding ``a - tanh(a)``.

    Raises:
        RuntimeError: if ``a`` is not a ``TensorLike``.
    """
    if isinstance(a, TensorLike):
        return refs.sub(a, refs.tanh(a))
    raise RuntimeError("Expected a tensor input for an elementwise unary operation!")
Beispiel #4
0
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of the hinge embedding loss.

    Unreduced, elementwise loss::

        input                   where target == 1
        max(0, margin - input)  where target == -1

    The value is built as the sum of two masked terms. The hinge term is kept
    wherever ``target != 1``, and the raw ``input`` term is kept wherever
    ``target != -1``. A target value that is neither 1 nor -1 therefore
    receives both terms.

    Args:
        input: input tensor.
        target: label tensor (expected to hold 1 / -1).
        margin: margin used in the hinge branch.
        reduction: reduction mode, validated by ``_check_reduction_value``.

    Returns:
        The loss after ``_apply_loss_reduction`` has been applied.
    """
    _check_reduction_value(reduction)

    hinge = refs.maximum(refs.sub(margin, input), 0)
    negative_term = refs.where(refs.ne(target, 1), hinge, 0)
    positive_term = refs.where(refs.ne(target, -1), input, 0)

    return _apply_loss_reduction(refs.add(negative_term, positive_term), reduction)