Example #1
    def fn(global_step):
        """Returns a learning rate given the current training iteration."""

        float_training_steps = ff(training_steps)
        global_step = ff(global_step)

        # ensure we don't train longer than training steps
        global_step = jnp.minimum(global_step, float_training_steps)

        constant_steps = float_training_steps * constant_fraction
        x = jnp.maximum(ff(global_step), ff(constant_steps))

        min_learning_rate = min_learning_rate_mult * learning_rate

        if warmup_fraction:
            min_warmup_fraction = jnp.maximum(warmup_fraction,
                                              constant_fraction)
            warmup_steps = float_training_steps * min_warmup_fraction
            is_warmup = ff(jnp.greater(ff(warmup_steps), ff(global_step)))
            warmup_lr = (global_step / warmup_steps) * learning_rate
        else:
            warmup_lr = learning_rate
            is_warmup = 0.0

        step = x - constant_steps

        constant_and_decay = (learning_rate - min_learning_rate) * (
            jnp.cos(step * onp.pi /
                    (float_training_steps - constant_steps)) / 2.0 +
            0.5) + min_learning_rate

        new_learning_rate = constant_and_decay * (
            1.0 - is_warmup) + is_warmup * (warmup_lr)
        return new_learning_rate
Example #2
def _localization_loss(pred_locs, gt_locs, gt_labels, num_matched_boxes):
    """Computes the localization loss.

  Computes the localization loss using smooth l1 loss.
  Args:
    pred_locs: a dict from index to tensor of predicted locations. The shape
      of each tensor is [batch_size, num_anchors, 4].
    gt_locs: a list of tensors representing box regression targets in
      [batch_size, num_anchors, 4].
    gt_labels: a list of tensors that represents the classification groundtruth
      targets. The shape is [batch_size, num_anchors, 1].
    num_matched_boxes: the number of anchors that are matched to a groundtruth
      target, used as the loss normalizer. The shape is [batch_size].
  Returns:
    box_loss: a float32 representing total box regression loss.
  """

    keys = sorted(pred_locs.keys())
    box_loss = 0
    for i, k in enumerate(keys):
        gt_label = gt_labels[i]
        gt_loc = gt_locs[i]
        pred_loc = jnp.reshape(pred_locs[k], gt_loc.shape)
        mask = jnp.greater(gt_label, 0)
        float_mask = mask.astype(jnp.float32)

        smooth_l1 = jnp.sum(huber_loss(gt_loc, pred_loc), axis=-1)
        smooth_l1 = jnp.multiply(smooth_l1, float_mask)
        box_loss = box_loss + jnp.sum(
            smooth_l1, axis=list(range(1, len(smooth_l1.shape))))
    return jnp.mean(box_loss / num_matched_boxes)
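The huber_loss helper is not shown above; a minimal element-wise smooth L1 (Huber) sketch that could stand in for it (the delta=1.0 default is an assumption matching the usual SSD formulation, not necessarily the original helper):

import jax.numpy as jnp

def huber_loss(labels, predictions, delta=1.0):
    # Element-wise Huber / smooth L1 loss: quadratic within +-delta, linear outside.
    error = predictions - labels
    abs_error = jnp.abs(error)
    quadratic = jnp.minimum(abs_error, delta)
    linear = abs_error - quadratic
    return 0.5 * quadratic ** 2 + delta * linear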
Example #3
def batch_codebook_coverage(codes: JTensor,
                            num_classes: int,
                            *,
                            paddings: JTensor,
                            data_parallel_axis: Optional[str] = None):
  """Computes codebook coverage within a batch.

  Args:
    codes:         [..., num_groups], values are in [0, num_classes).
    num_classes:   A Python int.
    paddings:      [...], 0/1 value tensor.
    data_parallel_axis: If set, a psum() over this axis is applied.

  Returns:
    A scalar JTensor, avg coverage across groups.
  """
  # [num_groups, num_classes]
  _, _, histogram = batch_pplx_entropy_from_codes(
      codes,
      num_classes,
      paddings=paddings,
      data_parallel_axis=data_parallel_axis)
  onehot = jnp.greater(histogram, 0).astype(jnp.float32)
  avg_num_covered_words = jnp.mean(jnp.sum(onehot, -1))
  return avg_num_covered_words / num_classes
Example #4
    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        tokens = data['obs']
        input_mask = jnp.greater(tokens, 0)
        seq_length = tokens.shape[1]

        # Embed the input tokens and positions.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
        token_embs = token_embedding_map(tokens)
        positional_embeddings = hk.get_parameter('pos_embs',
                                                 [seq_length, d_model],
                                                 init=embed_init)
        input_embeddings = token_embs + positional_embeddings

        # Run the transformer over the inputs.
        transformer = model.Transformer(num_heads=num_heads,
                                        num_layers=num_layers,
                                        dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask,
                                        is_training)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(output_embeddings)
Example #5
def jax_mv_normal_entropy(cov):
  k = cov.shape[0]
  eigvals = np.linalg.eigvalsh(cov)
  eps = 1e-5*np.max(abs(eigvals))
  mask = np.greater(abs(eigvals)-eps, 0.)
  log_nonzero_eigvals = np.where(mask, np.log(eigvals), np.zeros_like(eigvals))
  log_det = np.sum(log_nonzero_eigvals)
  return k/2. + (k/2.)*np.log(2*np.pi) + .5*log_det
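A short usage sketch (assuming np above is jax.numpy, as the np.linalg.eigvalsh call suggests); for a full-rank covariance the result equals 0.5 * log(det(2 * pi * e * cov)):

import jax.numpy as np

cov = np.diag(np.array([1.0, 2.0, 0.5]))  # hypothetical 3-d covariance
entropy = jax_mv_normal_entropy(cov)      # differential entropy of N(0, cov)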
Example #6
def insert(m, r, i):
  n = m.shape[0]
  a = np.concatenate([m, r[np.newaxis,:]], axis=0)
  before_inds = np.arange(n+1)*np.less(np.arange(n+1),i)
  after_inds = (np.arange(n+1)-1)*np.greater(np.arange(n+1),i)
  new_ind = np.ones(shape=[n+1], dtype=np.int32)*np.equal(np.arange(n+1),i)*n
  inds = before_inds + after_inds + new_ind
  return a[inds]
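A quick usage sketch (again assuming np is jax.numpy): the index arithmetic builds a permutation that places the appended row r at position i without dynamic shapes, so the function stays jit-friendly.

import jax.numpy as np

m = np.array([[1, 1], [3, 3]])
r = np.array([2, 2])
insert(m, r, 1)
# -> [[1 1]
#     [2 2]
#     [3 3]]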
Example #7
    def _iter_body(state):
        """One step of power iteration."""
        i, new_v, s, s_v, unused_run_step = state
        new_v = new_v / jnp.linalg.norm(new_v)

        s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
        s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
        return (i + 1, s_v, s_new, s_v,
                jnp.greater(jnp.abs(s_new - s), error_tolerance))
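The snippet above only shows the loop body; a minimal power-iteration driver built around it is sketched below. The names matrix, precision and error_tolerance are closure variables in the original, so the wrapper and its defaults are assumptions for illustration.

import jax.numpy as jnp
from jax import lax

def dominant_eigenvalue(matrix, error_tolerance=1e-6, num_iters=100,
                        precision=lax.Precision.HIGHEST):
    def _iter_body(state):
        i, new_v, s, s_v, unused_run_step = state
        new_v = new_v / jnp.linalg.norm(new_v)
        s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
        s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
        return (i + 1, s_v, s_new, s_v,
                jnp.greater(jnp.abs(s_new - s), error_tolerance))

    def _iter_cond(state):
        i, _, _, _, run_step = state
        # Keep iterating while the estimate is still changing and the
        # iteration budget has not been exhausted.
        return jnp.logical_and(i < num_iters, run_step)

    v0 = jnp.ones((matrix.shape[0],))
    init = (jnp.zeros((), jnp.int32), v0, jnp.zeros(()), v0, jnp.array(True))
    return lax.while_loop(_iter_cond, _iter_body, init)[2]

# dominant_eigenvalue(jnp.array([[2., 1.], [1., 2.]]))  # -> ~3.0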
Example #8
def i_stimulus(t, params_and_data):
    return lax.cond(
        np.logical_or(
            np.less(t, params_and_data["stimulus_start_time"]),
            np.greater(t, params_and_data["stimulus_end_time"]),
        ),
        lambda _: 0.0,
        lambda _: params_and_data["i_stimulus"],
        None,
    )
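A hypothetical usage sketch: lax.cond keeps the stimulus step function traceable, returning 0.0 outside [stimulus_start_time, stimulus_end_time] and the constant i_stimulus inside it.

params_and_data = {
    "stimulus_start_time": 10.0,
    "stimulus_end_time": 20.0,
    "i_stimulus": 2.5,
}
i_stimulus(5.0, params_and_data)   # -> 0.0 (before the pulse)
i_stimulus(15.0, params_and_data)  # -> 2.5 (during the pulse)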
Example #9
def lm_loss_fn(forward_fn, vocab_size, params, rng, data, is_training=True):
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = hk.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss
Example #10
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter('pos_embs', [seq_length, d_model],
                                             init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
Example #11
        def single_iteration_condition(args):
            """Checks if the acceptance ratio or maximum iterations is reached

            Parameters
            ----------
            args : tuple
                loop variables (described in `single_iteration`)

            Returns
            -------
            bool:
                True if acceptance_ratio is reached or the maximum number of
                iterations is reached
            """
            return np.logical_and(np.greater(args[-3], acceptance_ratio),
                                  np.less(args[-2], max_iteration))
Example #12
    def _update_controller(args):
        counter = args[0]
        jax.lax.cond(
            np.logical_and(
                np.logical_and(
                    np.greater(counter, 0),
                    np.equal(counter % print_rate, 0)),
                np.not_equal(counter, max_iterations - remainder)),
            lambda _: id_tap(
                _update_pbar, (print_rate,) + args),
            lambda _: (print_rate,) + args,
            operand=None)

        jax.lax.cond(
            np.equal(counter, max_iterations - remainder),
            lambda _: id_tap(
                _update_pbar, (remainder,) + args),
            lambda _: (remainder,) + args,
            operand=None)
Example #13
    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        tokens = data['obs']
        input_mask = jnp.greater(tokens, 0)
        batch_size, seq_length = tokens.shape

        # Embed the input tokens and positions.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
        input_embeddings = token_embedding_map(tokens)
        positional_embeddings = hk.get_parameter('pos_embs',
                                                 [seq_length, d_model],
                                                 init=embed_init)

        x = input_embeddings + positional_embeddings
        h = jnp.zeros_like(x)

        # Create transformer block
        transformer_block = model.UTBlock(num_heads=num_heads,
                                          num_layers=num_layers,
                                          dropout_rate=dropout_rate)

        transformed_net = hk.transform(transformer_block)

        # lift params
        inner_params = hk.experimental.lift(transformed_net.init)(
            hk.next_rng_key(), h, x, input_mask, is_training)

        def f(_params, _rng, _z, *args):
            return transformed_net.apply(_params,
                                         _rng,
                                         _z,
                                         *args,
                                         is_training=is_training)

        z_star = deq(inner_params, hk.next_rng_key(), h, f, max_iter, x,
                     input_mask)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(z_star)
Example #14
def _greater(a, b):
    return jnp.greater(a, b)
Example #15
def pi_adjusted_inverse(
    factor_0: jnp.ndarray,
    factor_1: jnp.ndarray,
    damping: jnp.ndarray,
    pmap_axis_name: str,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Performs inversion with pi-adjusted damping."""
    # Compute the norms of each factor
    norm_0 = jnp.trace(factor_0)
    norm_1 = jnp.trace(factor_1)

    # We need to sync the norms here, because reductions can be non-deterministic;
    # they are non-deterministic on GPUs by default for better performance.
    # Hence, although factor_0 and factor_1 are synced, the trace operation above
    # can still produce different answers on different devices.
    norm_0, norm_1 = pmean_if_pmap((norm_0, norm_1), axis_name=pmap_axis_name)

    # Compute the overall scale
    scale = norm_0 * norm_1

    def regular_inverse(
            operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
        factor0, factor1, norm0, norm1, s, d = operand
        # Special cases with one or two scalar factors
        if factor0.size == 1 and factor1.size == 1:
            value = jnp.ones_like(factor0) / jnp.sqrt(s)
            return value, value
        if factor0.size == 1:
            factor1_normed = factor1 / norm1
            damping1 = d / norm1
            factor1_inv = psd_inv_cholesky(factor1_normed, damping1)
            return jnp.full((1, 1), s), factor1_inv
        if factor1.size == 1:
            factor0_normed = factor0 / norm0
            damping0 = d / norm0
            factor0_inv = psd_inv_cholesky(factor0_normed, damping0)
            return factor0_inv, jnp.full((1, 1), s)

        # Invert first factor
        factor0_normed = factor0 / norm0
        damping0 = jnp.sqrt(d * factor1.shape[0] / (s * factor0.shape[0]))
        factor0_inv = psd_inv_cholesky(factor0_normed, damping0) / jnp.sqrt(s)

        # Invert second factor
        factor1_normed = factor1 / norm1
        damping1 = jnp.sqrt(d * factor0.shape[0] / (s * factor1.shape[0]))
        factor1_inv = psd_inv_cholesky(factor1_normed, damping1) / jnp.sqrt(s)
        return factor0_inv, factor1_inv

    def zero_inverse(
            operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
        return (jnp.eye(factor_0.shape[0]) / jnp.sqrt(operand[-1]),
                jnp.eye(factor_1.shape[0]) / jnp.sqrt(operand[-1]))

    # In the special case where, for some reason, one of the factors is zero, the
    # correct inverse of `(0 kron A + lambda I)` is
    # `I/sqrt(lambda) kron I/sqrt(lambda)`. However, because one of the norms is
    # zero, `pi` and `1/pi` would be 0 and infinity, leading to NaN values.
    # Hence, we need to make this check explicit.
    return lax.cond(jnp.greater(scale, 0.0),
                    regular_inverse,
                    zero_inverse,
                    operand=(factor_0, factor_1, norm_0, norm_1, scale,
                             damping))
Example #16
    def update(self, params, x, y, loss=None):
        """
        Description: Updates parameters based on correct value, loss and learning rate.
        Args:
            params (dict): Parameters of the pred method
            x (float): input to the method
            y (float): true label
            loss (function): loss function. Defaults to the input value.
        Returns:
            Updated parameters in same shape as input
        """
        assert self.initialized
        assert type(
            params
        ) == dict, "optimizers can only take params in dictionary format"

        grad = self.gradient(params, x, y,
                             loss=loss)  # defined in optimizers core class

        if self.theta is None:
            self.theta = {
                k: -dw
                for (k, w), dw in zip(params.items(), grad.values())
            }
        else:
            self.theta = {
                k: v - dw
                for (k, v), dw in zip(self.theta.items(), grad.values())
            }

        if self.eta is None:
            self.eta = {
                k: dw * dw
                for (k, w), dw in zip(params.items(), grad.values())
            }
        else:
            self.eta = {
                k: v + dw * dw
                for (k, v), dw in zip(self.eta.items(), grad.values())
            }

        if self.theta_max is None:
            self.theta_max = {
                k: np.absolute(v)
                for (k, v) in self.theta.items()
            }
        else:
            self.theta_max = {
                k: np.where(np.greater(np.absolute(v), v_max), np.absolute(v),
                            v_max)
                for (k, v), v_max in zip(self.theta.items(),
                                         self.theta_max.values())
            }

        new_params = {
            k: np.where(np.equal(0.0, np.maximum(theta_max, eta)), theta,
                        theta / np.sqrt(np.maximum(theta_max, eta)))
            for (k, w), theta, theta_max, eta in zip(params.items(
            ), self.theta.values(), self.theta_max.values(), self.eta.values())
        }

        x_new = np.roll(x, 1)
        x_new = jax.ops.index_update(x_new, 0, y)  # deprecated API; equivalent to x_new.at[0].set(y)
        y_t = self.pred(params=new_params, x=x_new)

        #        print('y before {0}'.format(y_t))
        x_plus_bias_new = np.vstack((np.ones((1, 1)), x_new))
        new_mapped_params = {
            k: self.norm_project(
                np.where(np.equal(0.0, np.maximum(theta_max, eta)), 0.0,
                         1.0 / np.sqrt(np.maximum(theta_max, eta))),
                x_plus_bias_new, y_t, p)
            for (k, p), theta_max, eta in zip(
                new_params.items(), self.theta_max.values(), self.eta.values())
        }

        #        y_t = self.pred(params=new_mapped_params, x=x_new)
        #        print('y after {0}'.format(y_t))
        return new_mapped_params
Example #17
    def fit(self,
            λ,
            ϵ,
            rng=None,
            patience=100,
            min_iterations=100,
            max_iterations=int(1e5),
            print_rate=None,
            best=True):
        """Fitting routine for the IMNN

        Parameters
        ----------
        λ : float
            Coupling strength of the regularisation
        ϵ : float
            Closeness criterion describing how close to the 1 the determinant
            of the covariance (and inverse covariance) of the network outputs
            is desired to be
        rng : int(2,) or None, default=None
            Stateless random number generator
        patience : int, default=100
            Number of iterations where there is no increase in the value of the
            determinant of the Fisher information matrix, used for early
            stopping
        min_iterations : int, default=100
            Number of iterations that should be run before considering early
            stopping using the patience counter
        max_iterations : int, default=int(1e5)
            Maximum number of iterations to run the fitting procedure for
        print_rate : int or None, default=None,
            Number of iterations before updating the progress bar whilst
            fitting. There is a performance hit from updating the progress bar
            more often and there is a large performance hit from using the
            progress bar at all. (Possible ``RET_CHECK`` failure if
            ``print_rate`` is not ``None`` when using GPUs).
            For this reason it is set to None as default
        best : bool, default=True
            Whether to set the network parameter attribute ``self.w`` to the
            parameter values that obtained the maximum determinant of
            the Fisher information matrix or the parameter values at the final
            iteration of fitting

        Example
        -------

        We are going to summarise the mean and variance of some random Gaussian
        noise with 10 data points per example using an AggregatedSimulatorIMNN.
        In this case we are going to generate the simulations on-the-fly with a
        simulator written in jax (from the examples directory). These
        simulations will be generated on-the-fly and passed through the network
        on each of the GPUs in ``jax.devices("gpu")`` and we will make 100
        simulations on each device at a time. The main computation will be done
        on the CPU. We will use 1000 simulations to estimate the covariance of
        the network outputs and the derivative of the mean of the network
        outputs with respect to the model parameters (Gaussian mean and
        variance) and generate the simulations at a fiducial μ=0 and Σ=1. The
        network will be a stax model with hidden layers of ``[128, 128, 128]``
        activated with leaky relu and outputting 2 summaries. Optimisation will
        be via Adam with a step size of ``1e-3``. Rather arbitrarily we'll set
        the regularisation strength and covariance identity constraint to λ=10
        and ϵ=0.1 (these are relatively unimportant for such an easy model).

        .. code-block:: python

            import jax
            import jax.numpy as np
            from jax.experimental import stax, optimizers
            from imnn import AggregatedSimulatorIMNN

            rng = jax.random.PRNGKey(0)

            n_s = 1000
            n_d = 1000
            n_params = 2
            n_summaries = 2
            input_shape = (10,)
            θ_fid = np.array([0., 1.])

            def simulator(rng, θ):
                return θ[0] + jax.random.normal(
                    rng, shape=input_shape) * np.sqrt(θ[1])

            model = stax.serial(
                stax.Dense(128),
                stax.LeakyRelu,
                stax.Dense(128),
                stax.LeakyRelu,
                stax.Dense(128),
                stax.LeakyRelu,
                stax.Dense(n_summaries))
            optimiser = optimizers.adam(step_size=1e-3)

            λ = 10.
            ϵ = 0.1

            model_key, fit_key = jax.random.split(rng)

            host = jax.devices("cpu")[0]
            devices = jax.devices("gpu")

            n_per_device = 100

            imnn = AggregatedSimulatorIMNN(
                n_s=n_s, n_d=n_d, n_params=n_params, n_summaries=n_summaries,
                input_shape=input_shape, θ_fid=θ_fid, model=model,
                optimiser=optimiser, key_or_state=model_key,
                simulator=simulator, host=host, devices=devices,
                n_per_device=n_per_device)

            imnn.fit(λ, ϵ, rng=fit_key, min_iterations=1000, patience=250,
                     print_rate=None)


        Notes
        -----
        A minimum number of iterations should be run before stopping based
        on a maximum determinant of the Fisher information achieved since the
        loss function has dual objectives. Since the determinant of the
        covariance of the network outputs is forced to 1 quickly, this can be
        at the detriment to the value of the determinant of the Fisher
        information matrix early in the fitting procedure. For this reason
        starting early stopping after the covariance has converged is advised.
        This is not currently implemented but could be considered in the
        future.

        The best fit network parameter values are probably not the most
        representative set of parameters when simulating on-the-fly since there
        is a high chance of a statistically overly-informative set of data
        being generated. Instead, if using
        :func:`~imnn.AggregatedSimulatorIMNN.fit()` consider using
        ``best=False`` which sets ``self.w=self.final_w`` which are the network
        parameter values obtained in the last iteration. Also consider using a
        larger ``patience`` value if using :func:`~imnn.SimulatorIMNN.fit()`
        to overcome the fact that a flukish high value for the determinant
        might have been obtained due to the realisation of the dataset.

        Raises
        ------
        TypeError
            If any input has the wrong type
        ValueError
            If any input (except ``rng``) are ``None``
        ValueError
            If ``rng`` has the wrong shape
        ValueError
            If ``rng`` is ``None`` but simulating on-the-fly

        Methods
        -------
        get_keys_and_params:
            Jitted collection of parameters and random numbers
        calculate_loss:
            Returns the jitted gradient of the loss function wrt summaries
        validation_loss:
            Jitted loss and auxiliary statistics from validation set

        Todo
        ----
        - ``rng`` is currently only used for on-the-fly simulation but could
          easily be updated to allow for stochastic models
        - Automatic detection of convergence based on value ``r`` when early
          stopping can be started
        """
        @jax.jit
        def get_keys_and_params(rng, state):
            """Jitted collection of parameters and random numbers

            Parameters
            ----------
            rng : int(2,) or None, default=None
                Stateless random number generator
            state : :obj:state
                The optimiser state used for updating the network parameters
                and optimisation algorithm

            Returns
            -------
            int(2,) or None, default=None:
                Stateless random number generator
            int(2,) or None, default=None:
                Stateless random number generator for training
            int(2,) or None, default=None:
                Stateless random number generator for validation
            list:
                Network parameter values
            """
            rng, training_key, validation_key = self._get_fitting_keys(rng)
            w = self._get_parameters(state)
            return rng, training_key, validation_key, w

        @jax.jit
        @partial(jax.grad, argnums=(0, 1), has_aux=True)
        def calculate_loss(summaries, summary_derivatives):
            """Returns the jitted gradient of the loss function wrt summaries

            Used to calculate the gradient of the loss function wrt summaries
            and derivatives of the summaries with respect to model parameters
            which will be used to calculate the aggregated gradient of the
            Fisher information with respect to the network parameters via the
            chain rule.

            Parameters
            ----------
            summaries : float(n_s, n_summaries)
                The network outputs
            summary_derivatives : float(n_d, n_summaries, n_params)
                The derivative of the network outputs wrt the model parameters

            Returns
            -------
            tuple:
                Gradient of the loss function with respect to network outputs
                and their derivatives with respect to physical model parameters
            tuple:
                Fitting statistics calculated on a single iteration
                    - **F** *(float(n_params, n_params))* -- Fisher information
                      matrix
                    - **C** *(float(n_summaries, n_summaries))* -- covariance
                      of network outputs
                    - **invC** *(float(n_summaries, n_summaries))* -- inverse
                      covariance of network outputs
                    - **Λ2** *(float)* -- covariance regularisation
                    - **r** *(float)* -- regularisation coupling strength
            """
            return self._calculate_loss(summaries, summary_derivatives, λ, α)

        @jax.jit
        def validation_loss(summaries, derivatives):
            """Jitted loss and auxillary statistics from validation set

            Parameters
            ----------
            summaries : float(n_s, n_summaries)
                The network outputs
            summary_derivatives : float(n_d, n_summaries, n_params)
                The derivative of the network outputs wrt the model parameters

            Returns
            -------
            tuple:
                Fitting statistics calculated on a single validation iteration
                    - **F** *(float(n_params, n_params))* -- Fisher information
                      matrix
                    - **C** *(float(n_summaries, n_summaries))* -- covariance
                      of network outputs
                    - **invC** *(float(n_summaries, n_summaries))* -- inverse
                      covariance of network outputs
                    - **Λ2** *(float)* -- covariance regularisation
                    - **r** *(float)* -- regularisation coupling strength
            """
            F, C, invC, *_ = self._calculate_F_statistics(
                summaries, derivatives)
            _Λ2 = self._get_regularisation(C, invC)
            _r = self._get_regularisation_strength(_Λ2, λ, α)
            return (F, C, invC, _Λ2, _r)

        λ = _check_type(λ, float, "λ")
        ϵ = _check_type(ϵ, float, "ϵ")
        α = self.get_α(λ, ϵ)
        patience = _check_type(patience, int, "patience")
        min_iterations = _check_type(min_iterations, int, "min_iterations")
        max_iterations = _check_type(max_iterations, int, "max_iterations")
        best = _check_boolean(best, "best")
        if self.simulate and (rng is None):
            raise ValueError("`rng` is necessary when simulating.")
        rng = _check_input(rng, (2, ), "rng", allow_None=True)
        max_detF, best_w, detF, detC, detinvC, Λ2, r, counter, \
            patience_counter, state, rng = self._set_inputs(
                rng, max_iterations)
        pbar, print_rate, remainder = self._setup_progress_bar(
            print_rate, max_iterations)
        while self._fit_cond((max_detF, best_w, detF, detC, detinvC, Λ2, r,
                              counter, patience_counter, state, rng),
                             patience=patience,
                             max_iterations=max_iterations):
            rng, training_key, validation_key, w = get_keys_and_params(
                rng, state)
            summaries, summary_derivatives = self.get_summaries(
                w=w, key=training_key)
            dΛ_dx, results = calculate_loss(summaries, summary_derivatives)
            grad = self.get_gradient(dΛ_dx, w, key=training_key)
            state = self._update(counter, grad, state)
            w = self._get_parameters(state)
            detF, detC, detinvC, Λ2, r = self._update_history(
                results, (detF, detC, detinvC, Λ2, r), counter, 0)
            if self.validate:
                summaries, summary_derivatives = self.get_summaries(
                    w=w, key=training_key, validate=True)
                results = validation_loss(summaries, summary_derivatives)
                detF, detC, detinvC, Λ2, r = self._update_history(
                    results, (detF, detC, detinvC, Λ2, r), counter, 1)
            _detF = np.linalg.det(results[0])
            patience_counter, counter, _, max_detF, __, best_w = \
                jax.lax.cond(
                    np.greater(_detF, max_detF),
                    self._update_loop_vars,
                    lambda inputs: self._check_loop_vars(
                        inputs, min_iterations),
                    (patience_counter, counter, _detF, max_detF, w, best_w))
            self._update_progress_bar(pbar, counter, patience_counter,
                                      max_detF, detF[counter], detC[counter],
                                      detinvC[counter], Λ2[counter],
                                      r[counter], print_rate, max_iterations,
                                      remainder)
            counter += 1
        self._update_progress_bar(pbar,
                                  counter,
                                  patience_counter,
                                  max_detF,
                                  detF[counter - 1],
                                  detC[counter - 1],
                                  detinvC[counter - 1],
                                  Λ2[counter - 1],
                                  r[counter - 1],
                                  print_rate,
                                  max_iterations,
                                  remainder,
                                  close=True)
        self.history["max_detF"] = max_detF
        self.best_w = best_w
        self._set_history((detF[:counter], detC[:counter], detinvC[:counter],
                           Λ2[:counter], r[:counter]))
        self.state = state
        self.final_w = self._get_parameters(self.state)
        if best:
            w = self.best_w
        else:
            w = self.final_w
        self.set_F_statistics(w, key=rng)
Example #18
def greater(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.greater(x1, x2))
Example #19
def gt(a: Numeric, b: Numeric):
    return jnp.greater(a, b)
Example #20
  def unstack(self, stacked):
    """Inverts stacking over time.

    Given 'stacked' outputs from this StackingOverTime layer,

      stacked, _ = this_layer.FProp(inputs),

    this method attempts to reconstruct the original 'inputs'.

    If stride > window_size, the original input cannot be recovered, and a
    ValueError is raised.

    Otherwise, if right_context + 1 >= stride, this method returns a JTensor
    that is identical to 'inputs' but potentially longer due to paddings.

    If right_context + 1 < stride, this method returns a JTensor that may be up
    to ```stride - right_context - 1``` frames shorter than the original input,
    but identical in the frames that are returned. e.g.::

      left_context = 2, right_context = 1, stride = 4
      input sequence:     1 2 3 4 5 6 7 8
      after padding:  0 0 1 2 3 4 5 6 7 8 0
      windows:
        [0 0 (1) 2] 3 4 5 6 7 8 0
         0 0 1 2 [3 4 (5) 6] 7 8 0
      stacked:
        [[0 0 1 2], [3 4 5 6]]
      unstacked:
        [1 2 3 4 5 6], which is 4 - 1 - 1 = 2 (stride - right_context - 1)
        frames shorter than the original input.

    `unstack()` can be used to project the outputs of downstream layers back to
    the shape of the original unstacked inputs. For example::

        inputs = ...  # [batch, length, input_dim]
        # [batch, ceil(length / stride), rnn_dim]
        rnn_out = rnn.fprop(stacking.fprop(inputs)[0])
        # [batch, length, rnn_dim]
        back_projected_rnn_out = py_utils.PadOrTrimTo(
            stacking.unstack(jnp.tile(rnn_out, [1, 1, stacking.window_size])),
            inputs.shape)

    Note this method does not take or return a separate padding JTensor. The
    caller is responsible for knowing which of outputs are padding (e.g. based
    on the padding of the original FProp inputs).

    Args:
      stacked: JTensor of shape [batch, time, window_size * feature_dim],
        assumed to be the output of `fprop`.

    Returns:
      The reconstructed input JTensor, with shape
      [batch, (frames - 1) * stride + right_context + 1, feature_dim].

    Raises:
      ValueError: if stride > window_size.
    """
    p = self.params
    if 0 == p.left_context == p.right_context and 1 == p.stride:
      return stacked

    if p.stride > self.window_size:
      raise ValueError(
          "Can't invert StackingOverTime with stride (%d) > window_size (%d)" %
          (p.stride, self.window_size))

    # Reshape to allow indexing individual frames within each stacked window.
    batch_size, stacked_length, _ = stacked.shape
    stacked = jnp.reshape(stacked,
                          [batch_size, stacked_length, self.window_size, -1])

    # Compute the index of the window and frame in 'stacked' where each frame of
    # the original input is located, and extract them with a gather_nd-style
    # lookup (implemented with vmap below).
    # First compute for all except the last window, since these elements have
    # the potential of being looked up from the next window.
    input_indices = jnp.arange(0, (stacked_length - 1) * p.stride)
    mod = input_indices % p.stride
    in_next_window = jnp.greater(mod, p.right_context).astype(jnp.int32)
    window_index = input_indices // p.stride + in_next_window
    frame_index = p.left_context + mod - p.stride * in_next_window
    # Now handle the last window explicitly and concatenate onto the existing
    # window_index/frame_index tensors.
    last_window_length = p.right_context + 1
    window_index = jnp.concatenate([
        window_index,
        jnp.repeat(jnp.array([stacked_length - 1]), last_window_length)
    ],
                                   axis=0)
    frame_index = jnp.concatenate(
        [frame_index, p.left_context + jnp.arange(last_window_length)], axis=0)
    # Stack the indices for gather_nd operation below
    window_and_frame_indices = jnp.stack([window_index, frame_index], axis=1)
    window_and_frame_indices = jnp.tile(
        jnp.expand_dims(window_and_frame_indices, 0), [batch_size, 1, 1])

    # jax equivalent of tf.gather_nd
    def gather_nd_unbatched(params, indices):
      return params[tuple(jnp.moveaxis(indices, -1, 0))]

    return vmap(gather_nd_unbatched, (0, 0), 0)(stacked,
                                                window_and_frame_indices)
Example #21
erf = utils.copy_docstring(tf.math.erf,
                           lambda x, name=None: scipy_special.erf(x))

erfc = utils.copy_docstring(tf.math.erfc,
                            lambda x, name=None: scipy_special.erfc(x))

exp = utils.copy_docstring(tf.math.exp, lambda x, name=None: np.exp(x))

expm1 = utils.copy_docstring(tf.math.expm1, lambda x, name=None: np.expm1(x))

floor = utils.copy_docstring(tf.math.floor, lambda x, name=None: np.floor(x))

floordiv = utils.copy_docstring(tf.math.floordiv,
                                lambda x, y, name=None: np.floor_divide(x, y))

greater = utils.copy_docstring(tf.math.greater,
                               lambda x, y, name=None: np.greater(x, y))

greater_equal = utils.copy_docstring(
    tf.math.greater_equal, lambda x, y, name=None: np.greater_equal(x, y))

igamma = utils.copy_docstring(
    tf.math.igamma, lambda a, x, name=None: scipy_special.gammainc(a, x))

igammac = utils.copy_docstring(
    tf.math.igammac, lambda a, x, name=None: scipy_special.gammaincc(a, x))

imag = utils.copy_docstring(tf.math.imag,
                            lambda input, name=None: np.imag(input))

# in_top_k = utils.copy_docstring(
#     tf.math.in_top_k,
Example #22
    def w_cond(self, args):
        _, loc, counter = args
        return np.logical_and(
            np.logical_or(np.any(np.greater(loc, self.high)),
                          np.any(np.less(loc, self.low))),
            np.less(counter, self.max_counter))
Example #23
def delete(m, i):
  n = m.shape[0]
  before_inds = np.arange(n-1)*np.less(np.arange(n-1),i)
  after_inds = (np.arange(n-1)+1)*np.greater(np.arange(n-1)+1,i)
  inds = before_inds + after_inds
  return m[inds]
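A quick usage sketch, mirroring the insert example above (#6): the same index-arithmetic trick builds a gather that skips row i, so the deletion is also jit-friendly.

import jax.numpy as np

m = np.array([[1, 1], [2, 2], [3, 3]])
delete(m, 1)
# -> [[1 1]
#     [3 3]]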
Example #24
    def stable_exp(x):
        """If x is greater than thresh, use a first-order Taylor expansion."""
        return jnp.where(jnp.greater(x, thresh),
                         jnp.exp(thresh) * (1 + x - thresh), jnp.exp(x))
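A minimal sketch of how stable_exp might be bound and used; thresh is a closure variable in the original snippet, so the factory below is an assumption for illustration.

import jax.numpy as jnp

def make_stable_exp(thresh=10.0):
    def stable_exp(x):
        """If x is greater than thresh, use a first-order Taylor expansion."""
        # Beyond thresh, exp is replaced by its tangent line at thresh,
        # so selected values and their gradients stay bounded.
        return jnp.where(jnp.greater(x, thresh),
                         jnp.exp(thresh) * (1 + x - thresh), jnp.exp(x))
    return stable_exp

stable_exp = make_stable_exp(thresh=10.0)
stable_exp(jnp.array([1.0, 50.0]))  # exp(1) unchanged; linear extrapolation at 50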