def __call__(self, tangent_func: phase_space.SymplecticTangentFunction,
              t: jnp.ndarray, y: phase_space.PhaseSpace,
              dt: jnp.ndarray) -> phase_space.PhaseSpace:
     q, p = y.q, y.p
     # Delete y intentionally so it cannot accidentally be used below
     del y
     # We broadcast opposite to NumPy: leading (batch) dims are aligned, so trailing singleton dims are appended
     if dt.ndim > 0:
         dt = dt.reshape(dt.shape + (1, ) * (q.ndim - dt.ndim))
     if t.ndim > 0:
         t = t.reshape(t.shape + (1, ) * (q.ndim - t.ndim))
     t_q = t
     t_p = t
     for c, d in zip(self.momentum_coefficients,
                     self.position_coefficients):
         # Update momentum
         if c != 0.0:
             dp_dt = tangent_func(t_p, phase_space.PhaseSpace(q, p)).p
             p = p + c * dt * dp_dt
             t_p = t_p + c * dt
         # Update position
         if d != 0.0:
             dq_dt = tangent_func(t_q, phase_space.PhaseSpace(q, p)).q
             q = q + d * dt * dq_dt
             t_q = t_q + d * dt
     return phase_space.PhaseSpace(position=q, momentum=p)
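With momentum coefficients (1/2, 1/2) and position coefficients (1, 0), the loop above reduces to the classic leapfrog (velocity-Verlet) step. A minimal self-contained sketch of that special case for a harmonic oscillator, assuming only jax.numpy (the phase_space module is not needed here):

import jax.numpy as jnp

def leapfrog_step(q, p, dt, grad_potential):
    # Half kick, full drift, half kick: the (c, d) pairs (0.5, 1.0), (0.5, 0.0) above,
    # with dp/dt = -dV/dq supplied by the tangent function.
    p = p - 0.5 * dt * grad_potential(q)
    q = q + dt * p
    p = p - 0.5 * dt * grad_potential(q)
    return q, p

# Harmonic oscillator H = p**2 / 2 + q**2 / 2, so dV/dq = q.
q, p = jnp.array(1.0), jnp.array(0.0)
for _ in range(10):
    q, p = leapfrog_step(q, p, 0.1, lambda q: q)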
Example #2
    def compute_metrics(
        masked_lm_logits: jnp.ndarray,
        next_sentence_logits: jnp.ndarray,
        masked_lm_labels: jnp.ndarray,
        masked_lm_weights: jnp.ndarray,
        next_sentence_labels: jnp.ndarray,
    ):
        """Computes the pre-training loss and its components."""
        masked_lm_logits = nn.log_softmax(masked_lm_logits)
        masked_lm_labels = onehot(masked_lm_labels.reshape((-1, )),
                                  masked_lm_logits.shape[-1])
        masked_lm_weights = masked_lm_weights.reshape((-1, ))
        masked_lm_loss = -jnp.sum(
            jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
            masked_lm_weights) / jnp.sum(masked_lm_weights)

        next_sentence_logits = nn.log_softmax(next_sentence_logits)
        next_sentence_labels = next_sentence_labels.reshape((-1, ))
        next_sentence_loss = -jnp.mean(
            jnp.sum(
                onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
                next_sentence_logits,
                axis=-1,
            ))
        return {
            "loss": masked_lm_loss + next_sentence_loss,
            "masked_lm_loss": masked_lm_loss,
            "next_sentence_loss": next_sentence_loss,
        }
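The metrics above depend on an onehot helper defined elsewhere in that codebase. A minimal sketch of what such a helper typically looks like (an assumption about its behaviour, not the original definition):

import jax.numpy as jnp

def onehot(labels: jnp.ndarray, num_classes: int) -> jnp.ndarray:
    # One-hot encodes integer labels into float vectors of length num_classes.
    return (labels[..., None] == jnp.arange(num_classes)).astype(jnp.float32)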
Example #3
def diagonal_between(x: np.ndarray,
                     start_axis: int = 0,
                     end_axis: int = -1) -> np.ndarray:
    """Returns the diagonal along all dimensions between start and end axes."""
    if end_axis == -1:
        end_axis = x.ndim
    half_ndim, ragged = divmod(end_axis - start_axis, 2)
    if ragged:
        raise ValueError(
            f'Need even number of axes to flatten, got {end_axis - start_axis}.'
        )
    if half_ndim == 0:
        return x

    side_shape = x.shape[start_axis:start_axis + half_ndim]
    side_size = size_at(side_shape)

    shape_2d = x.shape[:start_axis] + (side_size,
                                       side_size) + x.shape[end_axis:]
    shape_result = x.shape[:start_axis] + side_shape + x.shape[end_axis:]

    x = np.diagonal(x.reshape(shape_2d),
                    axis1=start_axis,
                    axis2=start_axis + 1)
    x = np.moveaxis(x, -1, start_axis)
    return x.reshape(shape_result)
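For illustration, a small self-contained NumPy example of the reshape-then-diagonal trick used above (shapes chosen arbitrarily):

import numpy as np

# Take the diagonal over the paired (4, 5) axes of a (3, 4, 5, 4, 5) array,
# i.e. result[b, i, j] = x[b, i, j, i, j], giving shape (3, 4, 5).
x = np.arange(3 * 4 * 5 * 4 * 5).reshape(3, 4, 5, 4, 5)
x2d = x.reshape(3, 20, 20)  # flatten each (4, 5) pair of axes
diag = np.moveaxis(np.diagonal(x2d, axis1=1, axis2=2), -1, 1)
result = diag.reshape(3, 4, 5)
assert result[1, 2, 3] == x[1, 2, 3, 2, 3]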
Example #4
def ab_decomposition(
        u: jnp.ndarray,
        v: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Decompose vector v as follows
    v = u @ a + u_orth @ b. If v is a tangent vector, then the matrix a
    is skew-Hermitian.

    Args:
        u: array-like of shape (..., n, m).
        v: array-like of shape (..., n, m).

    Returns:
        Elements of the decomposition: a, b, and u_orth."""

    n, m = u.shape[-2:]
    tail = u.shape[:-2]
    u = u.reshape((-1, n, m))
    v = v.reshape((-1, n, m))
    u_orth = vmap(lambda x: jnp.linalg.qr(x, mode='complete')[0])(u)[..., m:]
    a = u.conj().transpose((0, 2, 1)) @ v
    b = u_orth.conj().transpose((0, 2, 1)) @ v
    a = a.reshape((*tail, -1, m))
    b = b.reshape((*tail, -1, m))
    u_orth = u_orth.reshape((*tail, n, -1))
    return a, b, u_orth
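A hypothetical usage sketch (it assumes jax is available, that vmap and jnp are imported as in the snippet, and that u has orthonormal columns so the two components reconstruct v exactly):

import jax
import jax.numpy as jnp

key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
u, _ = jnp.linalg.qr(jax.random.normal(key_u, (4, 6, 3)))  # batch of isometries
v = jax.random.normal(key_v, (4, 6, 3))

a, b, u_orth = ab_decomposition(u, v)
# v is recovered from its two components.
assert jnp.allclose(u @ a + u_orth @ b, v, atol=1e-4)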
Example #5
def diag_part(a: jnp.ndarray) -> jnp.ndarray:
    """Returns the diagonal part of a matrix.

    Args:
        a: tensor of shape (..., n, n).

    Returns:
        tensor of shape (..., n)."""

    bs_shape = a.shape[:-2]
    matrix_shape = a.shape[-2:]
    a = vmap(jnp.diag)(a.reshape((-1, *matrix_shape)))
    a = a.reshape((*bs_shape, -1))
    return a
Example #6
def transp(a: jnp.ndarray) -> jnp.ndarray:
    """Returns transposed matrix.

    Args:
        a: tensor of shape (..., n1, n2)

    Returns:
        tensor of shape (..., n2, n1)"""

    matrix_shape = a.shape[-2:]
    bs_shape = a.shape[:-2]
    a = a.reshape((-1, *matrix_shape))
    a = a.transpose((0, 2, 1))
    a = a.reshape((*bs_shape, matrix_shape[1], matrix_shape[0]))
    return a
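transp (like adj below) only reshapes so that a fixed (0, 2, 1) transpose applies; for jnp arrays the same result comes from jnp.swapaxes directly. A small check, assuming nothing beyond jax:

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 3, 4))
assert jnp.allclose(transp(a), jnp.swapaxes(a, -2, -1))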
Example #7
    def project(self, points: jnp.ndarray):
        """Projects a 3D point (x,y,z) to a pixel position (x,y)."""
        batch_shape = points.shape[:-1]
        points = points.reshape((-1, 3))
        local_points = self.points_to_local_points(points)

        # Get normalized local pixel positions.
        x = local_points[..., 0] / local_points[..., 2]
        y = local_points[..., 1] / local_points[..., 2]
        r2 = x**2 + y**2

        # Apply radial distortion.
        distortion = 1.0 + r2 * (
            self.radial_distortion[0] + r2 *
            (self.radial_distortion[1] + self.radial_distortion[2] * r2))

        # Apply tangential distortion.
        x_times_y = x * y
        x = (x * distortion + 2.0 * self.tangential_distortion[0] * x_times_y +
             self.tangential_distortion[1] * (r2 + 2.0 * x**2))
        y = (y * distortion + 2.0 * self.tangential_distortion[1] * x_times_y +
             self.tangential_distortion[0] * (r2 + 2.0 * y**2))

        # Map the distorted ray to the image plane.
        pixel_x = self.focal_length * x + self.skew * y + self.principal_point_x
        pixel_y = (self.focal_length * self.pixel_aspect_ratio * y +
                   self.principal_point_y)

        pixels = jnp.stack([pixel_x, pixel_y], axis=-1)
        return pixels.reshape((*batch_shape, 2))
Example #8
    def pixels_to_rays(self,
                       pixels: jnp.ndarray) -> jnp.ndarray:
        """Returns the rays for the provided pixels.

        Args:
          pixels: [A1, ..., An, 2] tensor or np.array containing 2d pixel positions.

        Returns:
            An array containing the normalized ray directions in world coordinates.
        """
        if pixels.shape[-1] != 2:
            raise ValueError("The last dimension of pixels must be 2.")
        if pixels.dtype != self.dtype:
            raise ValueError(
                f"pixels dtype ({pixels.dtype!r}) must match camera "
                f"dtype ({self.dtype!r})")

        batch_shape = pixels.shape[:-1]
        pixels = pixels.reshape((-1, 2))

        local_rays_dir = self.pixel_to_local_rays(pixels)
        rays_dir = self.orientation.T @ local_rays_dir[..., jnp.newaxis]
        rays_dir = jnp.squeeze(rays_dir, axis=-1)

        # Normalize rays.
        rays_dir /= jnp.linalg.norm(rays_dir, axis=-1, keepdims=True)
        rays_dir = rays_dir.reshape((*batch_shape, 3))
        return rays_dir
Example #9
  def apply(self,
            x: jnp.ndarray,
            config: ModelConfig,
            num_classes: int,
            train: bool = True) -> jnp.ndarray:
    """Returns the output of the head block.

    Args:
      x: The input to the block.
      config: A set of model parameters.
      num_classes: Dimension of the output of the model.
      train: Whether we are training or predicting.
    """
    # Build top
    x = conv2d(
        x,
        round_filters(config.top_base_filters, config),
        config,
        activation=config.activation,
        train=train)

    # Build classifier
    x = flax.nn.avg_pool(x, x.shape[1:3])
    if config.dropout_rate and config.dropout_rate > 0:
      x = flax.nn.dropout(x, config.dropout_rate, deterministic=not train)
    x = flax.nn.Dense(
        x, num_classes, kernel_init=dense_kernel_init_fn, name='dense')
    x = x.reshape([x.shape[0], -1])
    return x
Example #10
def _samplewise_log_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Based on: https://github.com/scikit-learn/scikit-learn/blob/ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/_classification.py#L2123"""
    if y_true.ndim == 0:  # A 0-d input means a binary classification problem
        y_true = y_true.reshape(1)[:, jnp.newaxis]
        y_pred = y_pred.reshape(1)[:, jnp.newaxis]
    if y_true.shape[0] == 1:  # Reshuffle data to compute log loss correctly
        y_true = jnp.append(1 - y_true, y_true)
        y_pred = jnp.append(1 - y_pred, y_pred)

    # Clipping
    eps = 1e-15
    y_pred = y_pred.astype(jnp.float32).clip(eps, 1 - eps)

    loss = (y_true * -jnp.log(y_pred)).sum()
    return loss
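A brief usage sketch for the zero-dimensional binary case handled by the first branch (assumes jax.numpy as jnp, as in the snippet):

import jax.numpy as jnp

# Scalar binary example: true label 1, predicted probability 0.9.
loss = _samplewise_log_loss(jnp.asarray(1.0), jnp.asarray(0.9))
# The function expands this to labels [0, 1] and predictions [0.1, 0.9],
# so loss = -(0 * log(0.1) + 1 * log(0.9)) ≈ 0.105.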
Example #11
    def apply_cost(self,
                   arr: jnp.ndarray,
                   axis: int = 0,
                   fn=None) -> jnp.ndarray:
        """Applies cost matrix to array (vector or matrix).

    This function applies the ground geometry's cost matrix to compute either
    output = K arr (if axis=1)
    output = K' arr (if axis=0)
    where K is [num_a, num_b].

    Args:
      arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied
        by the cost matrix.
      axis: uses the cost matrix K itself if axis=1, its transpose K' if axis=0.
      fn: function applied element-wise to the cost matrix before the dot
        product.

    Returns:
      A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1.
    """
        if arr.ndim == 1:
            return jax.vmap(
                lambda x: self._apply_cost_to_vec(x, axis, fn),
                1,
                1,
            )(arr.reshape(-1, 1))
        return jax.vmap(
            lambda x: self._apply_cost_to_vec(x, axis, fn),
            1,
            1,
        )(arr)
Example #12
 def spatial_de_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
     if self.de_aggregation_type is None:
         assert x.ndim >= 4
         if self.data_format == "NHWC":
             assert x.shape[1:3] == self.initial_spatial_shape
         elif self.data_format == "NCHW":
             assert x.shape[2:4] == self.initial_spatial_shape
         return x
     elif self.de_aggregation_type == "linear_projection":
         assert x.ndim == 2
         n, d = x.shape
         d = min(d, self.max_de_aggregation_dims or d)
         out_d = (d * self.initial_spatial_shape[0] *
                  self.initial_spatial_shape[1])
         x = hk.Linear(out_d, name="LinearProjection")(x)
         if self.data_format == "NHWC":
             shape = (n, ) + self.initial_spatial_shape + (d, )
         else:
             shape = (n, d) + self.initial_spatial_shape
         return x.reshape(shape)
     elif self.de_aggregation_type == "tile":
         assert x.ndim == 2
         if self.data_format == "NHWC":
             repeats = (1, ) + self.initial_spatial_shape + (1, )
             x = x[:, None, None, :]
         else:
             repeats = (1, 1) + self.initial_spatial_shape
             x = x[:, :, None, None]
         return jnp.tile(x, repeats)
     else:
         raise NotImplementedError()
Example #13
def mod_demod_conv_transpose(
    inputs: jnp.ndarray,
    styles: jnp.ndarray,
    orig_weight: jnp.ndarray,
    channel_index: int,
    demodulate: bool = True,
    **kwargs,
):
    assert styles.ndim == 1
    num_spatial = orig_weight.ndim - 2
    if channel_index == -1:
        new_shape = (1, ) * num_spatial + (1, styles.size)
        # Compute normalization over all axes except for output-channel
        reduce_axes = tuple(range(num_spatial)) + (-1, )
    else:
        new_shape = (1, styles.size) + (1, ) * num_spatial
        reduce_axes = tuple(range(1, 2 + num_spatial))

    # Apply styles over input-channel dimension of weights
    weight = orig_weight * styles.reshape(new_shape)

    if demodulate:
        norm = jax.lax.square(weight).sum(axis=reduce_axes, keepdims=True)
        weight = weight * jax.lax.rsqrt(norm + 1e-8)

    inputs = jnp.expand_dims(inputs, axis=0)
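    # conv_transpose keeps the singleton batch axis added above; tuple-unpacking
    # the result below strips that leading axis again.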
    (result, ) = jax.lax.conv_transpose(inputs, weight, **kwargs)
    return result
Example #14
def forward(params: Sequence[jnp.ndarray], fns: Sequence[Callable],
            x: jnp.ndarray) -> jnp.ndarray:
    """Forward transformation composing RealNVP bijectors with a permutation
    bijector between them.

    Args:
        params: List of arrays parameterizing the RealNVP bijectors.
        fns: List of functions that compute the shift and scale of the RealNVP
            affine transformation.
        x: Input to transform according to the composition of RealNVP
            transformations and permutations.

    Returns:
        y: The transformed input.

    """
    num_dims = x.shape[-1]
    num_dims_sq = num_dims**2
    half_num_dims_sq = num_dims_sq // 2
    num_masked = num_dims_sq - half_num_dims_sq
    perm = jnp.roll(jnp.arange(num_dims_sq), 1)
    y = x.reshape((-1, num_dims_sq))
    for i in range(len(fns)):
        y = realnvp.forward(y, num_masked, params[i], fns[i])
        y = permute.forward(y, perm)
    return y.reshape(x.shape)
Example #15
 def select_action(self, state: jnp.ndarray):
     return apply_td3_actor_model(
         self.actor_optimizer.target,
         self.action_dim,
         self.max_action,
         state.reshape(1, -1),
     ).flatten()
Example #16
def apply_im2col_conv(x: jnp.ndarray, w: jnp.ndarray,
                      filter_shape: Sequence[int], stride: Sequence[int],
                      padding: Union[str, Sequence[Tuple[int, int]]],
                      lhs_dilation: Sequence[int], rhs_dilation: Sequence[int],
                      dimension_numbers: Sequence[str], transpose: bool,
                      **kwargs):
    H, W, C_in = x.shape[-3:]
    Kx, Ky = filter_shape
    C_out = w.shape[-1]

    # assert w.shape == (H, W, Kx, Ky, C_in, C_out)
    w = w.reshape((H, W, Kx * Ky * C_in, C_out))

    x_i2c = im2col(x,
                   filter_shape=filter_shape,
                   stride=stride,
                   padding=padding,
                   lhs_dilation=lhs_dilation,
                   rhs_dilation=rhs_dilation,
                   dimension_numbers=dimension_numbers)

    # if transpose:
    #   x_i2c = jax.ops.index_update(x_i2c, jax.ops.index[...,:], x_i2c[...,::-1])

    assert x_i2c.shape[-3:] == (H, W, C_in * Kx * Ky)

    if x.ndim == 3:
        out = jnp.einsum("hwi,hwio->hwo", x_i2c, w)
    else:
        out = jnp.einsum("bhwi,hwio->bhwo", x_i2c, w)

    return out
Example #17
def adj(a: jnp.ndarray) -> jnp.ndarray:
    """Returns adjoint matrix.

    Args:
        a: complex valued tensor of shape (..., n1, n2)

    Returns:
        complex valued tensor of shape (..., n2, n1)"""

    matrix_shape = a.shape[-2:]
    bs_shape = a.shape[:-2]
    a = a.reshape((-1, *matrix_shape))
    a = a.transpose((0, 2, 1))
    a = a.reshape((*bs_shape, matrix_shape[1], matrix_shape[0]))
    a = a.conj()
    return a
Example #18
def log_prob(params: Sequence[jnp.ndarray], fns: Sequence[Callable],
             y: jnp.ndarray) -> jnp.ndarray:
    """Compute the log-probability of ambient observations under the transformation
    given by composing RealNVP bijectors and a permutation bijector between
    them. Assumes that the base distribution is a standard multivariate normal.

    Args:
        params: List of arrays parameterizing the RealNVP bijectors.
        fns: List of functions that compute the shift and scale of the RealNVP
            affine transformation.
        y: Observations whose likelihood under the composition of bijectors
            should be computed.

    Returns:
        out: The log-probability of the observations given the parameters of the
            bijection composition.

    """
    num_dims = y.shape[-1]
    num_dims_sq = num_dims**2
    half_num_dims_sq = num_dims_sq // 2
    num_masked = num_dims_sq - half_num_dims_sq
    perm = jnp.roll(jnp.arange(num_dims_sq), 1)
    fldj = 0.
    y = y.reshape((-1, num_dims_sq))
    for i in reversed(range(len(fns))):
        y = permute.inverse(y, perm)
        fldj += permute.forward_log_det_jacobian()
        y = realnvp.inverse(y, num_masked, params[i], fns[i])
        fldj += realnvp.forward_log_det_jacobian(y, num_masked, params[i],
                                                 fns[i])
    logprob = jspst.multivariate_normal.logpdf(y, jnp.zeros((num_dims_sq, )),
                                               1.)
    return logprob - fldj
Example #19
    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Vectorized softmax Jacobian

        Args:
            softmax (jnp.ndarray): the softmax output to differentiate.
        """
        s = softmax.reshape(-1, 1)
        return jnp.diagflat(s) - jnp.dot(s, s.T)
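The expression diag(s) - s s^T is the standard Jacobian of softmax with respect to its logits; a quick self-contained sanity check against jax.jacobian (an illustrative addition, not part of the original snippet):

import jax
import jax.numpy as jnp

logits = jnp.array([0.5, -1.0, 2.0])
s = jax.nn.softmax(logits)
manual = jnp.diagflat(s) - jnp.outer(s, s)
auto = jax.jacobian(jax.nn.softmax)(logits)
assert jnp.allclose(manual, auto, atol=1e-6)

Example #20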
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:

        residual = hidden_states

        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
        )
        hidden_states = self.dropout_layer(hidden_states,
                                           deterministic=deterministic)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        residual = hidden_states

        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)

        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states,
                                           deterministic=deterministic)

        hidden_states = (residual + hidden_states).reshape(hidden_states_shape)

        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        return outputs
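Example #21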
 def __call__(self, embeddings: jnp.ndarray) -> jnp.ndarray:
     batch_size, seq_len, _ = embeddings.shape
     output = jnp.matmul(
         embeddings.reshape([-1, self._d_model]),  # Flatten batch and sequence dims
         jnp.transpose(self._embedding_matrix))
     bias = hk.get_parameter('bias',
                             shape=[self._vocab_size],
                             init=jnp.zeros)
     output = output + bias
     return output.reshape([batch_size, seq_len, self._vocab_size])
Example #22
    def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
        b_axes = utils.canonicalize_axis(b_axes, b)
        last_b_axes = range(-len(b_axes), 0)
        x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

        b = np.moveaxis(b, b_axes, last_b_axes)
        b = b.reshape((A.shape[1], -1))

        x = sp.linalg.cho_solve(C, b)
        x = x.reshape(x_shape)
        return x
Example #23
def _vmap_2d(fn: Callable[[float, float, float], float], cov12: np.ndarray,
             var1: np.ndarray, var2: Optional[np.ndarray],
             diagonal_batch: bool, diagonal_spatial: bool) -> np.ndarray:
    """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn:
      scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.

    cov12:
      covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.

    var1:
      variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.

    var2:
      variance tensor (`q22`), has shape `(N1[, X, Y, ...])`.

    diagonal_batch:
      `True` if `cov12` has only one batch dimension.

    diagonal_spatial:
      `True` if `cov12` has spatial dimensions appearing once (vs twice).

  Returns:
    Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
    shape as `cov12`.
  """
    batch_ndim = 1 if diagonal_batch else 2
    start = 2 - batch_ndim
    cov_end = batch_ndim if diagonal_spatial else cov12.ndim
    _cov12 = utils.make_2d(cov12, start, cov_end)

    var_end = 1 if diagonal_spatial else var1.ndim
    var1 = var1.reshape(var1.shape[:start] + (-1, ) + var1.shape[var_end:])
    var2 = var1 if var2 is None else var2.reshape(var2.shape[:start] + (-1, ) +
                                                  var2.shape[var_end:])

    fn = vmap(vmap(np.vectorize(fn),
                   in_axes=(start, None, start),
                   out_axes=start),
              in_axes=(start, start, None),
              out_axes=start)
    out = fn(_cov12, var1, var2)  # type: np.ndarray
    out_shape = (cov12.shape[:start] + cov12.shape[start:cov_end:2] +
                 cov12.shape[start + 1:cov_end:2] + cov12.shape[cov_end:])
    out = out.reshape(out_shape)
    out = utils.zip_axes(out, start, cov_end)
    return out
Example #24
 def select_action(self, state: jnp.ndarray) -> jnp.ndarray:
     mu, _ = apply_gaussian_policy_model(
         self.actor_optimizer.target,
         self.action_dim,
         self.max_action,
         state.reshape(1, -1),
         None,
         False,
         True,
     )
     return mu
Example #25
 def spatial_aggregation(self, x: jnp.ndarray) -> jnp.ndarray:
     if self.aggregation_type is None:
         return x
     axis = (1, 2) if self.data_format == "NHWC" else (2, 3)
     if self.aggregation_type == "max":
         return jnp.max(x, axis=axis)
     if self.aggregation_type == "mean":
         return jnp.mean(x, axis=axis)
     if self.aggregation_type == "linear_projection":
         x = x.reshape(x.shape[:-3] + (-1, ))
         return hk.Linear(self.output_dim, name="LinearProjection")(x)
     raise NotImplementedError()
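Example #26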
    def apply(self,
              x: jnp.ndarray,
              blocks_per_group: int,
              channel_multiplier: int,
              num_outputs: int,
              train: bool = True,
              true_gradient: bool = False) -> jnp.ndarray:
        """Implements a WideResnet with ShakeShake regularization module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      blocks_per_group: How many resnet blocks to add to each group (should be
        4 blocks for a WRN26 as per standard shake shake implementation).
      channel_multiplier: The multiplier to apply to the number of filters in
        the model (1 is classical resnet, 6 for WRN26-2x6, etc...).
      num_outputs: Dimension of the output of the model (i.e. number of classes
        for a classification problem).
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.
      true_gradient: If true, the same mixing parameter will be used for the
        forward and backward pass (see paper for more details).

    Returns:
      The output of the WideResnet with ShakeShake regularization, a tensor of
      shape [batch_size, num_classes].
    """
        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    kernel_init=utils.conv_kernel_init_fn,
                    bias=False,
                    name='init_conv')
        x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      16 * channel_multiplier,
                                      train=train,
                                      true_gradient=true_gradient)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      32 * channel_multiplier, (2, 2),
                                      train=train,
                                      true_gradient=true_gradient)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      64 * channel_multiplier, (2, 2),
                                      train=train,
                                      true_gradient=true_gradient)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, x.shape[1:3])
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
Example #27
 def sample_action(self, rng: PRNGSequence, state: jnp.ndarray) -> jnp.ndarray:
     mu, log_sig = apply_gaussian_policy_model(
         self.actor_optimizer.target,
         self.action_dim,
         self.max_action,
         state.reshape(1, -1),
         None,
         False,
         True,
     )
     sig = jnp.exp(log_sig)
     return mu + random.normal(rng, mu.shape) * sig
Example #28
    def linear_classifier(x: jnp.ndarray, net_params, rng,
                          initializing) -> elegy.OutputStates:
        x = x.reshape((x.shape[0], -1)) / 255

        if initializing:
            w = jax.random.uniform(rng.next(), shape=[x.shape[-1], 10])
            b = jax.random.uniform(rng.next(), shape=[10])
            net_params = (w, b)

        w, b = net_params
        y_pred = jnp.dot(x, w) + b

        return elegy.OutputStates(y_pred, net_params, None)
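Example #29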
 def __call__(self, tangent_func: GeneralTangentFunction, t: jnp.ndarray,
              y: M, dt: jnp.ndarray) -> M:  # pytype: disable=invalid-annotation
     k = [tangent_func(t, y)]
     zero = jax.tree_map(jnp.zeros_like, k[0])
     # We broadcast opposite to NumPy: leading (batch) dims are aligned, so trailing singleton dims are appended
     if dt.ndim > 0:
         dt = dt.reshape(dt.shape + (1, ) * (y.ndim - dt.ndim))
     if t.ndim > 0:
         t = t.reshape(t.shape + (1, ) * (y.ndim - t.ndim))
     for c_n, a_n_row in zip(self.c_tableau, self.a_tableau):
         t_n = t + dt * c_n
         products = [
             a_i * k_i for a_i, k_i in zip(a_n_row, k) if a_i != 0.0
         ]
         delta_n = sum(products, zero)
         y_n = y + dt * delta_n
         k.append(tangent_func(t_n, y_n))
     products = [
         b_i * k_i for b_i, k_i in zip(self.b_tableau, k) if b_i != 0.0
     ]
     delta = sum(products, zero)
     return y + dt * delta
Example #30
    def __call__(
            self,
            x: jnp.ndarray) -> tp.Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        x = x.reshape((x.shape[0], -1))  # flatten
        x = hk.Linear(self.hidden_size)(x)
        x = jax.nn.relu(x)

        mean = hk.Linear(self.latent_size, name="linear_mean")(x)
        log_stddev = hk.Linear(self.latent_size, name="linear_std")(x)
        stddev = jnp.exp(log_stddev)

        z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)

        return z, mean, stddev