def broadcastinfo(
    a_shape: tuple[int, ...], b_shape: tuple[int, ...]
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    "Get which dimensions are added or repeated when `a` and `b` are broadcast."
    ndim = max(len(a_shape), len(b_shape))
    add_ndims_to_a = ndim - len(a_shape)
    add_ndims_to_b = ndim - len(b_shape)
    a_shape_ = np.array([1] * add_ndims_to_a + list(a_shape))
    b_shape_ = np.array([1] * add_ndims_to_b + list(b_shape))
    if not all((a_shape_ == b_shape_) | (a_shape_ == 1) | (b_shape_ == 1)):
        raise ValueError(f"could not broadcast shapes {a_shape} {b_shape}")
    a_repeatdims = (a_shape_ == 1) & (b_shape_ > 1)  # the repeated dims
    a_repeatdims[:add_ndims_to_a] = True  # the added dims
    a_repeatdims = np.where(a_repeatdims == True)[0]  # indices of axes where True
    a_repeatdims = [int(i) for i in a_repeatdims]
    b_repeatdims = (b_shape_ == 1) & (a_shape_ > 1)
    b_repeatdims[:add_ndims_to_b] = True
    b_repeatdims = np.where(b_repeatdims == True)[0]
    b_repeatdims = [int(i) for i in b_repeatdims]
    return tuple(a_repeatdims), tuple(b_repeatdims)
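
# Quick sanity check of broadcastinfo (comment-only sketch; values traced by hand
# from the function above). Broadcasting a with shape (3, 1) against b with shape
# (2, 3, 4) prepends axis 0 to a and repeats a along axes 0 and 2, while b needs
# no repeating:
#
#   >>> broadcastinfo((3, 1), (2, 3, 4))
#   ((0, 2), ())
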
def leaky_relu(a: Variable, alpha: float = 0.02) -> Variable:
    "Elementwise leaky relu."
    multiplier = np.where(a.array > 0, np.array(1, a.dtype), np.array(alpha, a.dtype))
    value = a.array * multiplier
    local_gradients = [(a, lambda path_value: path_value * multiplier)]
    return Variable(value, local_gradients)
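
# Forward/backward sketch for leaky_relu (comment-only; assumes the Variable wraps
# a NumPy array, as used above). For an input array [-1.0, 0.5] and alpha=0.02,
# the forward value is [-0.02, 0.5], and the stored local gradient multiplies an
# incoming path_value elementwise by [0.02, 1.0], i.e. d(leaky_relu)/da at the input.
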
def where(condition: np.ndarray, a: Variable, b: Variable) -> Variable:
    "Condition is a boolean NumPy array, a and b are Variables."
    value = np.where(condition, a.array, b.array)
    a_, b_ = enable_broadcast(a, b)

    def multiply_by_locgrad_a(path_value):
        return np.where(
            condition,
            path_value,
            np.zeros(path_value.shape, a.array.dtype),
        )

    def multiply_by_locgrad_b(path_value):
        return np.where(
            condition,
            np.zeros(path_value.shape, a.array.dtype),
            path_value,
        )

    local_gradients = [
        (a_, multiply_by_locgrad_a),
        (b_, multiply_by_locgrad_b),
    ]
    return Variable(value, local_gradients)
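
# Gradient-routing sketch for where (comment-only; traced by hand from the closures
# above). With condition = [True, False] and an incoming path_value of [g0, g1]:
#   multiply_by_locgrad_a returns [g0, 0]   # a only gets gradient where it was selected
#   multiply_by_locgrad_b returns [0, g1]   # b gets the gradient elsewhere
# The a_ and b_ returned by enable_broadcast presumably account for any broadcast
# axes (using broadcastinfo) when these gradients flow back to a and b.
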