Example #1
 def conv2d(x: xb.ArrayList, params: dict = None) -> xb.ArrayList:
     # Unpack the single input array; `self.stride`, `self.padding`, and
     # `groups` are captured from the enclosing layer definition.
     (inputs,) = x
     w = params["weights"]
     conv = jax.lax.conv_general_dilated(
         inputs, w, self.stride, self.padding, feature_group_count=groups
     )
     return xb.pack(conv)
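For reference, a minimal self-contained sketch of the same call, assuming NCHW inputs and an OIHW kernel (the defaults for `jax.lax.conv_general_dilated`), with explicit values standing in for the captured `self.stride`, `self.padding`, and `groups`:

 import jax
 import jax.numpy as jnp

 x = jnp.ones((8, 3, 32, 32))   # NCHW batch
 w = jnp.ones((16, 3, 3, 3))    # OIHW kernel: 16 filters over 3 input channels, 3x3
 out = jax.lax.conv_general_dilated(
     x, w, window_strides=(1, 1), padding="SAME", feature_group_count=1
 )
 print(out.shape)  # (8, 16, 32, 32)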
Example #2
 def train_dropout(x: ArrayList, params=None) -> ArrayList:
     array = x[0]
     # The last batch in an epoch can be smaller than the others, so the
     # keep-mask is drawn from the current array's shape on every call.
     mask = xr.bernoulli(array.shape)
     # Inverted dropout: rescale by the keep probability. The max() guards
     # against division by zero when drop_p == 1.
     output = pack(array * mask / max(1 - self.drop_p, 0.00001))
     return output
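A plain-JAX equivalent for comparison, assuming `xr.bernoulli` draws a keep-mask with probability `1 - drop_p`; the explicit PRNG-key handling here is illustrative, not part of the original API:

 import jax
 import jax.numpy as jnp

 def dropout(key, array, drop_p=0.5):
     keep_p = 1.0 - drop_p
     mask = jax.random.bernoulli(key, keep_p, array.shape)  # True with prob keep_p
     return array * mask / max(keep_p, 1e-5)                # inverted-dropout rescaling

 out = dropout(jax.random.PRNGKey(0), jnp.ones((4, 8)), drop_p=0.2)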
Example #3
 def batchnorm2d(inputs: ArrayList, params: dict) -> ArrayList:
     (x,) = inputs
     weights, bias = params["weights"], params["bias"]
     num_features = x.shape[1]

     # Reshape the per-channel statistics to (1, C, 1, 1) so they
     # broadcast against the NCHW input.
     x_mean = jnp.mean(x, axis=(0, 2, 3)).reshape(1, num_features, 1, 1)
     x_var = jnp.mean((x - x_mean) ** 2, axis=(0, 2, 3)).reshape(1, num_features, 1, 1)

     # `epsilon` is captured from the enclosing scope for numerical stability.
     x_norm = (x - x_mean) / jnp.sqrt(x_var + epsilon)
     y = weights * x_norm + bias
     return pack(y)
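The same normalization as a self-contained sketch, with assumed shapes and `epsilon` made explicit instead of captured:

 import jax.numpy as jnp

 x = jnp.arange(2 * 3 * 4 * 4, dtype=jnp.float32).reshape(2, 3, 4, 4)  # NCHW
 weights, bias = jnp.ones((1, 3, 1, 1)), jnp.zeros((1, 3, 1, 1))
 epsilon = 1e-5

 mean = jnp.mean(x, axis=(0, 2, 3)).reshape(1, 3, 1, 1)
 var = jnp.mean((x - mean) ** 2, axis=(0, 2, 3)).reshape(1, 3, 1, 1)
 y = weights * (x - mean) / jnp.sqrt(var + epsilon) + bias
 print(y.mean(), y.std())  # approximately 0 and 1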
Example #4
 def linear(x: xb.ArrayList, params=None) -> xb.ArrayList:
     (inputs,) = x
     # Bias-free linear layer: a single matrix multiply.
     w = params["weights"]
     return xb.pack(jnp.matmul(inputs, w))
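Assuming a (batch, in_features) input and an (in_features, out_features) weight matrix, the shapes work out as:

 import jax.numpy as jnp

 inputs = jnp.ones((32, 128))      # (batch, in_features)
 w = jnp.ones((128, 64)) * 0.01    # (in_features, out_features)
 print(jnp.matmul(inputs, w).shape)  # (32, 64)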
Example #5
 def flatten(x: ArrayList, params: dict) -> ArrayList:
     # `func` is the flattening callable captured from the enclosing scope.
     return pack(func(x[0]))
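A plausible body for the captured `func`, assuming it collapses everything except the batch dimension:

 import jax.numpy as jnp

 def flatten_array(a):
     # Keep the batch dimension, collapse the rest.
     return a.reshape(a.shape[0], -1)

 print(flatten_array(jnp.ones((8, 3, 4, 4))).shape)  # (8, 48)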
Example #6
 def avg_pool2d(inputs: ArrayList, params) -> ArrayList:
     # `window_shape`, `strides`, `padding`, and `window_size` are captured
     # from the enclosing layer; summing then dividing gives the average.
     sums = jax.lax.reduce_window(
         inputs[0], 0.0, jax.lax.add, window_shape, strides, padding
     )
     return pack(sums / window_size)
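A self-contained 2x2 average pool over an NCHW input, with the captured window parameters spelled out; the window only covers the spatial dimensions:

 import jax
 import jax.numpy as jnp

 x = jnp.arange(16.0).reshape(1, 1, 4, 4)    # NCHW
 window_shape, strides = (1, 1, 2, 2), (1, 1, 2, 2)
 sums = jax.lax.reduce_window(x, 0.0, jax.lax.add, window_shape, strides, "VALID")
 print(sums / 4)  # window_size = 2 * 2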