def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus

    The result is computed elementwise:

    * ``log1p(exp(beta * a)) / beta`` when ``beta`` is given;
    * ``log1p(exp(a))`` when ``beta`` is None.

    Wherever the scaled input (``beta * a``, or ``a`` itself) is strictly
    greater than ``threshold``, ``a`` is returned unchanged instead.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if the Python type of ``beta`` cannot be safely cast to
            the Python scalar type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is None:
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))
    else:
        target_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), target_type):
            raise ValueError(
                "beta argument of type {0} cannot be safely cast to type {1}!".format(
                    type(beta), target_type
                )
            )
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]

    # Past the threshold the identity branch is selected instead of the
    # log1p/exp expression.
    return torch.where(scaled_input > threshold, a, rhs)
def leaky_relu(
    a: TensorLikeType,
    negative_slope: float = 0.01,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu

    Returns ``a`` wherever it is strictly greater than zero, and
    ``a * negative_slope`` everywhere else (including at zero).

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if the Python type of ``negative_slope`` cannot be safely
            cast to the Python scalar type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    python_type = utils.dtype_to_type(a.dtype)
    slope_ok = utils.is_weakly_lesser_type(type(negative_slope), python_type)
    if not slope_ok:
        raise ValueError(
            f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        )

    is_positive = torch.gt(a, 0)
    leaked = torch.mul(a, negative_slope)
    return torch.where(is_positive, a, leaked)
def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    """
    Convert ``a`` to ``dtype`` if it is a tensor, a number, or a sequence of them.

    NOTE(review): this module appears to define ``_maybe_convert_to_dtype``
    more than once. Only the last definition executed is bound at import
    time, so consider deleting the duplicate.

    * A tensor whose dtype differs from ``dtype`` is converted with
      ``prims.convert_element_type``. A tensor already of ``dtype`` is
      returned as-is (the same object).
    * A number is converted to the Python scalar type that
      ``utils.dtype_to_type`` maps ``dtype`` to.
    * A sequence (other than ``str``) is converted elementwise, recursively,
      and returned as a tuple.
    * ``None`` is returned unchanged.

    Raises:
        ValueError: for any other input, including ``str``.
    """
    import torch._prims as prims

    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a

    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)

    # A str is a Sequence whose elements are themselves one-character strs.
    # Recursing into it never terminates (RecursionError), so reject it below.
    if isinstance(a, Sequence) and not isinstance(a, str):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(type(a))
    )
def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    """
    Convert ``a`` to ``dtype`` if it is a tensor, a number, or a sequence of them.

    * A tensor whose dtype differs from ``dtype`` is converted with
      ``prims.convert_element_type``. A tensor already of ``dtype`` is
      returned as-is (the same object).
    * A number is converted to the Python scalar type that
      ``utils.dtype_to_type`` maps ``dtype`` to.
    * A sequence (other than ``str``) is converted elementwise, recursively,
      and returned as a tuple.
    * ``None`` is returned unchanged.

    Raises:
        ValueError: for any other input, including ``str``.
    """
    import torch._prims as prims

    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a

    if isinstance(a, Number):
        return utils.dtype_to_type(dtype)(a)

    # A str is a Sequence whose elements are themselves one-character strs.
    # Recursing into it never terminates (RecursionError), so reject it below.
    if isinstance(a, Sequence) and not isinstance(a, str):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(type(a))
    )
def elu(
    a: TensorLikeType,
    alpha: Optional[NumberType] = None,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu

    Returns ``a`` wherever it is strictly greater than zero, and
    ``expm1(a)`` everywhere else. When ``alpha`` is given, the
    ``expm1(a)`` branch is multiplied by it.

    Raises:
        NotImplementedError: if ``inplace`` is True.
        ValueError: if the Python type of ``alpha`` cannot be safely cast to
            the Python scalar type corresponding to ``a.dtype``.
    """
    if inplace:
        raise NotImplementedError

    negative_branch: TensorLikeType
    if alpha is None:
        negative_branch = torch.expm1(a)
    else:
        expected_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), expected_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), expected_type
                )
            )
        negative_branch = alpha * torch.expm1(a)

    return torch.where(a > 0, a, negative_branch)
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    """
    Fill ``self`` with a scalar wherever ``mask`` is True, out of place.

    The name suggests this implements the ``masked_fill.Scalar`` overload;
    confirm against its registration.

    ``value`` is first converted to the Python scalar type that
    ``utils.dtype_to_type`` maps ``self.dtype`` to (e.g. an int tensor
    truncates a float value). The result is ``torch.where(mask, value, self)``.
    """
    scalar_type = utils.dtype_to_type(self.dtype)
    fill = scalar_type(value)
    return torch.where(mask, fill, self)