Exemple #1
0
def broadcastinfo(
        a_shape: tuple[int, ...],
        b_shape: tuple[int, ...]) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Get which dimensions are added or repeated when `a` and `b` are broadcast.

    Both shapes are first left-padded with 1s to a common number of
    dimensions. An axis counts as "repeated" for an operand when that
    operand has size 1 there and the other operand does not. The axes that
    had to be prepended to an operand always count for that operand.

    Args:
        a_shape: shape of the first operand.
        b_shape: shape of the second operand.

    Returns:
        A pair ``(a_repeatdims, b_repeatdims)`` of tuples of axis indices
        (Python ints, ascending), expressed in the padded/broadcast rank.

    Raises:
        ValueError: if some axis has different sizes and neither is 1.
    """
    ndim = max(len(a_shape), len(b_shape))

    a_added = ndim - len(a_shape)
    b_added = ndim - len(b_shape)

    # Left-pad with 1s so both shapes can be compared axis by axis.
    # dtype is pinned so that empty shapes still give an integer array.
    a_padded = np.array((1,) * a_added + tuple(a_shape), dtype=int)
    b_padded = np.array((1,) * b_added + tuple(b_shape), dtype=int)

    compatible = (a_padded == b_padded) | (a_padded == 1) | (b_padded == 1)
    if not compatible.all():
        raise ValueError(f"could not broadcast shapes {a_shape} {b_shape}")

    # `!= 1` rather than `> 1`: a size-1 axis paired with a size-0 axis is
    # also changed by broadcasting (it becomes size 0), so it must be listed.
    a_mask = (a_padded == 1) & (b_padded != 1)  # the repeated dims
    a_mask[:a_added] = True  # the added dims

    b_mask = (b_padded == 1) & (a_padded != 1)
    b_mask[:b_added] = True

    # flatnonzero yields the indices of the True entries; tolist() converts
    # them to plain Python ints.
    return (
        tuple(np.flatnonzero(a_mask).tolist()),
        tuple(np.flatnonzero(b_mask).tolist()),
    )
Exemple #2
0
def leaky_relu(a: Variable, alpha: float = 0.02) -> Variable:
    """Elementwise leaky relu.

    Elements of `a.array` that are greater than 0 are kept as they are;
    all other elements are scaled by `alpha`. The same per-element factor
    is applied to the incoming path value when gradients are propagated.
    """
    is_positive = a.array > 0
    # Per-element slope: 1 where the input is positive, alpha elsewhere.
    slopes = np.where(
        is_positive,
        np.array(1, a.dtype),
        np.array(alpha, a.dtype),
    )

    def scale_by_slope(path_value):
        return path_value * slopes

    return Variable(a.array * slopes, [(a, scale_by_slope)])
Exemple #3
0
def where(condition: np.ndarray, a: Variable, b: Variable) -> Variable:
    """Condition is a boolean NumPy array, a and b are Variables.

    The result takes its values from `a.array` where `condition` holds and
    from `b.array` elsewhere. Each local gradient lets the path value
    through on the positions that its operand supplied and puts zeros
    (of dtype `a.array.dtype`) on the others.
    """
    selected = np.where(condition, a.array, b.array)
    # Gradients are attached to the Variables returned by enable_broadcast.
    a_bc, b_bc = enable_broadcast(a, b)

    def grad_wrt_a(path_value):
        blank = np.zeros(path_value.shape, a.array.dtype)
        return np.where(condition, path_value, blank)

    def grad_wrt_b(path_value):
        blank = np.zeros(path_value.shape, a.array.dtype)
        return np.where(condition, blank, path_value)

    return Variable(selected, [(a_bc, grad_wrt_a), (b_bc, grad_wrt_b)])
Exemple #4
0
 def multiply_by_locgrad_b(path_value):
     # Zero the path value where `condition` holds and pass it through
     # elsewhere. NOTE(review): `condition` and `a` are free names here,
     # presumably bound by an enclosing function not shown in this excerpt.
     blank = np.zeros(path_value.shape, a.array.dtype)
     return np.where(condition, blank, path_value)