def softshrink(a: TensorLikeType, lambd: float = 0.5):
    """
    Reference implementation of torch.nn.functional.softshrink

    Formula for reference:
        softshrink(x) = x - lambd  if x > lambd
                      = x + lambd  if x < -lambd
                      = 0          otherwise

    ``lambd`` must be non-negative; otherwise ``check`` raises.
    """
    check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    ge_mask = a > lambd
    le_mask = a < -lambd
    # Since lambd >= 0 the two masks are disjoint, so nesting order is irrelevant.
    # The "otherwise" band is filled with ``a * 0`` instead of a literal 0 so that
    # NaN inputs (which fail both comparisons) stay NaN rather than being zeroed.
    inner = refs.where(le_mask, a + lambd, a * 0)
    return refs.where(ge_mask, a - lambd, inner)
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hinge_embedding_loss

    Unreduced loss, elementwise:
        loss = input                    if target == 1
             = max(0, margin - input)   if target == -1

    Implemented as the sum of two masked terms: the hinge term is kept wherever
    ``target != 1`` and the raw ``input`` term wherever ``target != -1``. For a
    target that is neither 1 nor -1, both terms therefore contribute. The result
    is then reduced according to ``reduction``.
    """
    _check_reduction_value(reduction)

    # Hinge term, used for the target == -1 entries.
    hinge = refs.maximum(refs.sub(margin, input), 0)
    negative_part = refs.where(refs.ne(target, 1), hinge, 0)

    # Identity term, used for the target == 1 entries.
    positive_part = refs.where(refs.ne(target, -1), input, 0)

    return _apply_loss_reduction(refs.add(negative_part, positive_part), reduction)
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    """
    Reference implementation of torch.nn.functional.softshrink

    Formula for reference:
        softshrink(x) = x - lambd  if x > lambd
                      = x + lambd  if x < -lambd
                      = 0          otherwise

    ``lambd`` must be non-negative; otherwise ``check`` raises.
    """
    check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    ge_mask = a > lambd
    le_mask = a < -lambd
    # Since lambd >= 0 the two masks are disjoint, so nesting order is irrelevant.
    # The "otherwise" band is filled with ``a * 0`` instead of a literal 0 so that
    # NaN inputs (which fail both comparisons) stay NaN rather than being zeroed.
    inner = refs.where(le_mask, a + lambd, a * 0)
    return refs.where(ge_mask, a - lambd, inner)
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu

    Returns ``a`` where ``a > 0`` and ``a * weight`` elsewhere.

    ``weight`` must be 0-D or 1-D. Unless it has exactly one element, it
    supplies one value per channel, where the channel dimension is dim 1 of
    ``a``. A 0-D or 1-D ``a`` is treated as having a single channel, so it
    needs a single-element weight.
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )

    if a.ndim == 0:
        # A 0-D input only reaches here with a single-element weight. A rank-1
        # weight cannot be broadcast_in_dim'ed to shape (), so take its only
        # element instead.
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        # Map a 1-D weight onto a's channel dim (dim 1). A 1-D input has no
        # dim 1, so its single-element weight maps onto dim 0 instead.
        weight = prims.broadcast_in_dim(
            weight, a.shape, tuple() if weight.ndim == 0 else (0 if a.ndim == 1 else 1,)
        )

    return refs.where(a > 0, a, a * weight)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Computes ``log1p(exp(beta * a)) / beta`` elementwise. Wherever
    ``beta * a`` exceeds ``threshold`` the input ``a`` is returned unchanged.
    When ``beta`` is None, both the scaling and the division are skipped.
    A ``beta`` whose Python type cannot be safely cast to ``a``'s dtype raises
    ValueError. ``inplace=True`` is not supported.
    """
    if inplace:
        raise NotImplementedError

    smooth: TensorLikeType
    if beta is None:
        scaled = a
        smooth = refs.log1p(refs.exp(scaled))
    else:
        expected_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), expected_type):
            raise ValueError(
                "beta argument of type {0} cannot be safely cast to type {1}!".format(
                    type(beta), expected_type
                )
            )
        scaled = refs.mul(a, beta)
        smooth = refs.true_divide(refs.log1p(refs.exp(scaled)), beta)

    above_threshold = refs.gt(scaled, threshold)
    return refs.where(above_threshold, a, smooth)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu

    Computes ``scale * a`` where ``a > 0`` and ``scale * alpha * expm1(a)``
    elsewhere, using the fixed SELU constants below. ``inplace=True`` is not
    supported.
    """
    if inplace:
        raise NotImplementedError

    # Fixed SELU constants (the values listed in the torch.nn.SELU docs).
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    negative_branch = refs.mul(alpha, refs.expm1(a))
    is_positive = refs.gt(a, 0)
    unscaled = refs.where(is_positive, a, negative_branch)
    return refs.mul(scale, unscaled)
def elu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Returns ``a`` where ``a > 0`` and ``alpha * expm1(a)`` elsewhere. When
    ``alpha`` is None, the negative branch is plain ``expm1(a)``. An ``alpha``
    whose Python type cannot be safely cast to ``a``'s dtype raises ValueError.
    ``inplace=True`` is not supported.
    """
    if inplace:
        raise NotImplementedError

    negative_branch: TensorLikeType
    if alpha is None:
        negative_branch = refs.expm1(a)
    else:
        expected_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), expected_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), expected_type
                )
            )
        negative_branch = refs.mul(alpha, refs.expm1(a))

    return refs.where(refs.gt(a, 0), a, negative_branch)
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    """
    Reference implementation of torch.nn.functional.hardshrink

    Formula for reference:
        hardshrink(x) = x  if x > lambd
                      = x  if x < -lambd
                      = 0  otherwise

    Only entries inside the closed band [-lambd, lambd] are zeroed. NaN fails
    both band comparisons and is passed through unchanged.
    """
    lower_ok = a >= -lambd
    upper_ok = a <= lambd
    inside_band = refs.logical_and(lower_ok, upper_ok)
    return refs.where(inside_band, 0, a)
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    """
    Reference implementation of torch.nn.functional.hardshrink

    Formula for reference:
        hardshrink(x) = x  if x > lambd
                      = x  if x < -lambd
                      = 0  otherwise

    Only entries inside the closed band [-lambd, lambd] are zeroed.
    """
    # Test membership in the zero band rather than "|a| > |lambd|" for two
    # reasons:
    #  * NaN fails both comparisons and is passed through instead of being
    #    zeroed;
    #  * a negative lambd gives an empty band (the input is returned as-is)
    #    instead of being silently folded to |lambd|.
    inside_band = refs.logical_and(a >= -lambd, a <= lambd)
    return refs.where(inside_band, 0, a)