def dropout(
    a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.dropout

    When ``training`` is True, each element of ``a`` is kept with probability
    ``1 - p`` and kept elements are multiplied by ``1 / (1 - p)``. Dropped
    elements become zero. When ``training`` is False, ``a`` is returned as-is.

    Args:
        a: input tensor.
        p: probability of zeroing an element; must lie in ``[0, 1]``.
        training: if False, no dropout is applied.
        inplace: in-place dropout is not supported by this reference.

    Returns:
        ``a`` itself when not training or when ``p == 0``; a zero tensor
        shaped like ``a`` when ``p == 1``; otherwise the masked, rescaled input.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if ``p`` is outside ``[0, 1]`` (NaN is also rejected).
    """
    if inplace:
        raise NotImplementedError

    # Validate before the ``training`` short-circuit, matching
    # torch.nn.functional.dropout. Raise rather than assert so the check
    # is not stripped under ``python -O``. The chained comparison is False
    # for NaN, so NaN is rejected as well.
    if not (0.0 <= p <= 1.0):
        raise ValueError(
            "dropout probability has to be between 0 and 1, but got {}".format(p)
        )

    if not training:
        return a

    # Degenerate probabilities: drop everything, or drop nothing.
    if p == 1:
        return refs.zeros_like(a)
    if p == 0:
        return a

    keep_prob = 1 - p
    scale = 1 / keep_prob
    # A sample from uniform(0, 1) falls below keep_prob with probability
    # keep_prob, so True entries of the mask mark elements that survive.
    mask = refs.lt(
        refs.uniform(a.shape, low=0.0, high=1.0, dtype=torch.float32, device=a.device),
        keep_prob,
    )
    return refs.mul(refs.mul(a, mask), scale)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu

    Elementwise: ``scale * a`` where ``a > 0``, and
    ``scale * alpha * (exp(a) - 1)`` elsewhere.

    Raises:
        NotImplementedError: if ``inplace`` is True.
    """
    if inplace:
        raise NotImplementedError

    # The SELU alpha / scale constants.
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    # Branch used for non-positive inputs.
    negative_branch = refs.mul(alpha, refs.expm1(a))
    is_positive = refs.gt(a, 0)
    unscaled = refs.where(is_positive, a, negative_branch)
    return refs.mul(scale, unscaled)
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.margin_ranking_loss

    The unreduced loss is ``max(0, -target * (input1 - input2) + margin)``.
    It is then passed to ``_apply_loss_reduction`` with ``reduction``, which
    is first validated by ``_check_reduction_value``.

    Raises:
        RuntimeError: if ``input1``, ``input2`` and ``target`` do not all have
            the same number of dimensions.
    """
    if not (input1.ndim == input2.ndim == target.ndim):
        template = (
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            "input1: {}, input2: {}, target: {} "
        )
        raise RuntimeError(template.format(input1.shape, input2.shape, target.shape))

    _check_reduction_value(reduction)

    # -target * (input1 - input2) + margin, clamped below at zero.
    shifted = refs.add(
        refs.mul(refs.neg(target), refs.sub(input1, input2)),
        margin,
    )
    unreduced = refs.maximum(shifted, 0)
    return _apply_loss_reduction(unreduced, reduction)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    Computes ``log1p(exp(beta * a)) / beta`` elementwise. Where
    ``beta * a > threshold``, ``a`` is returned unchanged instead.
    When ``beta`` is None, no scaling or division is applied.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if ``beta``'s Python type cannot be safely cast to the
            Python type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if beta is None:
        scaled_input = a
        smooth = refs.log1p(refs.exp(scaled_input))
    else:
        target_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), target_type):
            raise ValueError(
                "beta argument of type {0} cannot be safely cast to type {1}!".format(
                    type(beta), target_type
                )
            )
        scaled_input = refs.mul(a, beta)
        smooth = refs.true_divide(refs.log1p(refs.exp(scaled_input)), beta)

    # Above the threshold the function is treated as linear, returning a.
    return refs.where(refs.gt(scaled_input, threshold), a, smooth)
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish

    Computes ``a * tanh(softplus(a))`` elementwise, using the reference
    softplus with its default arguments.

    Raises:
        NotImplementedError: if ``inplace`` is True.
    """
    if inplace:
        raise NotImplementedError

    smoothed = refs.nn.functional.softplus(a)
    gate = refs.tanh(smoothed)
    return refs.mul(a, gate)
def elu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Elementwise: ``a`` where ``a > 0``, otherwise ``alpha * (exp(a) - 1)``.
    When ``alpha`` is None, the factor is omitted, i.e. ``exp(a) - 1``.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if ``alpha``'s Python type cannot be safely cast to the
            Python type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    if alpha is None:
        negative_branch = refs.expm1(a)
    else:
        target_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), target_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), target_type
                )
            )
        negative_branch = refs.mul(alpha, refs.expm1(a))

    is_positive = refs.gt(a, 0)
    return refs.where(is_positive, a, negative_branch)