Example #1
 def permute_bwd(ip, permuted_grad):
     # JAX autodiff would synthesize a scatter operation because it doesn't
     # know that the indices are a permutation. However on TPU, gathers are
     # faster than scatters (at least in the regime the LSH attention uses).
     return (None, None, jnp.take(permuted_grad, ip, axis=axis))
Example #2
 def vjpfun(permuted_grad):
   # JAX autodiff would synthesize a scatter operation because it doesn't
   # know that the indices are a permutation. However on TPU, gathers are
   # faster than scatters (at least in the regime the LSH attention uses).
   return (jnp.take(permuted_grad, inverse_permutation, axis=axis),)
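Pulling the fragments in Examples #1 and #2 together: a minimal, self-contained sketch of a permutation whose backward pass is a gather with the inverse permutation rather than the scatter JAX autodiff would otherwise synthesize. The name `permute_via_gather` and the argument layout below are illustrative assumptions, not taken from any particular library.

import jax
import jax.numpy as jnp

def permute_via_gather(val, permutation, inverse_permutation, axis=0):
    # Forward pass: gather the values in permuted order.
    def permute_impl(v):
        return jnp.take(v, permutation, axis=axis)

    def permute_fwd(v):
        return permute_impl(v), None

    def permute_bwd(_, permuted_grad):
        # Gather with the inverse permutation instead of scattering,
        # which is the faster option on TPU in this regime.
        return (jnp.take(permuted_grad, inverse_permutation, axis=axis),)

    permute = jax.custom_vjp(permute_impl)
    permute.defvjp(permute_fwd, permute_bwd)
    return permute(val)

# Usage: differentiate through the permutation.
x = jnp.arange(4.0)
perm = jnp.array([2, 0, 3, 1])
inv_perm = jnp.argsort(perm)
y, vjp_fn = jax.vjp(lambda v: permute_via_gather(v, perm, inv_perm), x)
(grad_x,) = vjp_fn(jnp.ones_like(y))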
Example #3
def cross_entropy(params, buffers, p, y):
    return -jnp.take(jnp.log(p.squeeze()), y.squeeze())
Example #4
def take(data, indices, dim):
    return jnp.take(
        data,
        indices=indices,
        axis=dim,
    )
Example #5
import numpy as np
from timeit import default_timer as timer

NUSTAR_IMAGE_LENGTH = 64
PSF_IMAGE_LENGTH = 1300

# In radians/pixel
NUSTAR_PIXEL_SIZE = 5.5450564776903175e-05
PSF_PIXEL_SIZE = 2.9793119397393605e-06

PSF_HALF_LENGTH = NUSTAR_IMAGE_LENGTH / 2

PSF = np.load("psf_test.npy")
PSF = np.array(PSF)
print(np.max(PSF))
print(np.min(PSF))
print(np.take(PSF, 10000))
print(np.take(PSF, -10235325))
print(np.take(PSF, 0))
print()


def pixel_psf_powerlaw(i, j, source_x, source_y, psf=PSF):
    # d = (
    # 	((PSF_HALF_LENGTH - i) - source_y/NUSTAR_PIXEL_SIZE)**2 +
    # 	((j - PSF_HALF_LENGTH) - source_x/NUSTAR_PIXEL_SIZE)**2
    # )
    def clip(x):
        return np.max([0, np.min([x, NUSTAR_IMAGE_LENGTH])])

    t_x, t_y = source_x / NUSTAR_PIXEL_SIZE, source_y / NUSTAR_PIXEL_SIZE
    i_x, i_y = clip((j - t_x).astype(int)), clip((i + t_y).astype(int))
Example #6
 def __call__(self, input_ids):
     return jnp.take(self.embeddings, input_ids, axis=0)
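The same embedding-lookup pattern in isolation: each input id selects one row of the embedding table. A minimal sketch with made-up sizes:

import jax
import jax.numpy as jnp

vocab_size, hidden_size = 100, 16
embeddings = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, hidden_size))
input_ids = jnp.array([3, 7, 7, 42])

# One embedding row per token id; result has shape (4, 16).
token_embeddings = jnp.take(embeddings, input_ids, axis=0)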
Example #7
def gather_row(data, row_index):
    return jnp.take(data, row_index, 0)
Example #8
    def __call__(self,
                 rng,
                 n_points,
                 percentile=None,
                 acceptance_ratio=0.1,
                 max_iteration=10,
                 max_acceptance=1,
                 max_samples=int(1e5),
                 n_initial_points=None,
                 n_parallel_simulations=None,
                 proposed=None,
                 summaries=None,
                 distances=None,
                 smoothing=None,
                 replace=False):
        """Run the PMC

        Parameters
        ----------
        rng : int(2,)
            A random number generator for generating simulations and randomly
            drawing parameter values from the prior
        n_points : int
            Number of points desired in the final (approximate) posterior
            (note that if `acceptance_ratio` is large then this approximation
            will be bad)
        percentile : int or None, default=None
            The percentage of points to define as the population, i.e. 75 for
            75% of the points. If None only the furthest point is moved one at
            a time
        acceptance_ratio : float
            The ratio of the number of accepted points to total proposals. When
            this gets small it suggests that points aren't being added to the
            population any more because the population is stationary
        max_iteration : int, default=10
            The cutoff number of iterations to break out of the while loop even
            if `acceptance_ratio` is not reached
        max_acceptance : int, default=1
            The cutoff number of attempts in a single iteration to get a
            sample accepted
        max_samples : int, default=100000
            The number of attempts to get parameter values from the truncated
            normal distribution
        n_initial_points : int or None, default=None
            The number of points to run in the initial ABC to define the
            population before the PMC starts. The PMC will always run from
            scratch if n_initial_points is passed
        n_parallel_simulations : int or None, default=None
            Number of simulations to do at once; only the innermost (closest)
            summaries are accepted, but this can massively reduce sampling
            time as long as the simulations can easily be parallelised
        proposed : float(any, n_params) or None, default=None
            A set of proposed parameters which have been used to premake
            simulations. Summaries of these simulations must also be passed as
            `summaries`. These can be used instead of running an initial ABC
            step
        summaries : float(any, n_summaries) or None, default=None
            A set of summaries of simulations which have been premade at
            parameter values corresponding to `proposed`. These can be used
            instead of running an initial ABC step
        distances : float(n_targets, any) or None, default=None
            An optional distance calculation from `summaries`, if this is not
            passed then distances is calculated in the call
        smoothing : float or None, default=None
            A Gaussian smoothing for the marginal distributions
        replace : bool, default=False
            Whether to replace the summaries, parameters and distances already
            obtained when running again

        Returns
        -------
        parameters container:
            All parameters with accepted and rejected attributes
        summaries container:
            All summaries with accepted and rejected attributes
        distances container:
            All distances with accepted and rejected attributes

        Raises
        ------
        ValueError
            if `n_initial_points` is less than `n_points`

        Todo
        ----
        type checking and pytests need writing
        """
        if n_initial_points is not None:
            if n_initial_points < n_points:
                raise ValueError(
                    "`n_initial_points` must be greater than or equal to " +
                    "the final number of points (`n_points`)")
            if n_parallel_simulations is not None:
                rng, *keys = jax.random.split(rng,
                                              num=n_parallel_simulations + 1)
                proposed, summaries = jax.vmap(lambda key: self.get_samples(
                    key, n_initial_points // n_parallel_simulations))(
                        np.array(keys))
            else:
                rng, *keys = jax.random.split(rng, num=n_initial_points + 1)
                proposed, summaries = jax.vmap(
                    lambda key: self.get_samples(key, 1))(np.array(keys))
            proposed = proposed.reshape((n_initial_points, -1))
            summaries = summaries.reshape((n_initial_points, -1))
            distances = jax.vmap(
                lambda target, F: self.distance_measure(summaries, target, F))(
                    self.target_summaries, self.F)
        elif (proposed is not None) and (summaries is not None):
            if distances is None:
                distances = jax.vmap(lambda target, F: self.distance_measure(
                    summaries, target, F))(self.target_summaries, self.F)
        elif (self.parameters.all is not None) and (not replace):
            proposed = self.parameters.all.reshape((-1, self.n_params))
            summaries = self.summaries.all.reshape((-1, self.n_summaries))
            distances = jax.vmap(
                lambda target, F: self.distance_measure(summaries, target, F))(
                    self.target_summaries, self.F)
        else:
            raise ValueError(
                "`proposed` and `summaries` (and optionally `distances`) or " +
                "`n_initial_points` must be provided if PMC has not been " +
                "previously called")

        sample_indices = np.argsort(distances, axis=1)[:, :n_points]
        samples = jax.vmap(lambda x: proposed[x])(sample_indices)
        summaries = jax.vmap(lambda x: summaries[x])(sample_indices)
        distances = np.take(distances, sample_indices)

        weighting = self.prior.prob(samples)

        if percentile is None:
            ϵ_ind = -1
        else:
            ϵ_ind = int(percentile / 100 * n_points)

        key = np.array(jax.random.split(rng, num=self.n_targets))
        (rng, samples, summaries, distances, weighting, acceptance_reached,
         iteration_counter, total_draws) = jax.vmap(
             partial(self.move_samples,
                     ϵ_ind=ϵ_ind,
                     acceptance_ratio=acceptance_ratio,
                     max_iteration=max_iteration,
                     max_acceptance=max_acceptance,
                     max_samples=max_samples,
                     n_parallel_simulations=n_parallel_simulations))(
                         key, samples, summaries, distances, weighting,
                         self.target_summaries, self.F)
        self.set_samples(samples,
                         summaries,
                         distances=distances,
                         replace=replace)
        self.set_accepted(smoothing=smoothing)
        self.acceptance_reached = self.acceptance_reached + acceptance_reached
        self.iterations = self.iterations + iteration_counter
        self.total_draws = self.total_draws + total_draws
        print(f"Acceptance reached {self.acceptance_reached} in " +
              f"{self.iterations} iterations with a total of " +
              f"{self.total_draws} draws")
        return self.parameters, self.distances, self.summaries
Example #9
    def apply(self,
              inputs,
              inputs_spatial_positions,
              inputs_scale_positions,
              inputs_masks,
              spatial_pos_grid_size,
              num_scales,
              num_layers,
              mlp_dim,
              use_sinusoid_pos_emb=False,
              use_scale_emb=True,
              dropout_rate=0.1,
              train=False,
              dtype=jnp.float32,
              stochastic_layer_drop_rate=0.0,
              **attention_kwargs):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_spatial_positions: input spatial positions for each embedding.
      inputs_scale_positions: input scale positions for each embedding.
      inputs_masks: bool, input mask.
      spatial_pos_grid_size: spatial positional encoding hash grid size.
      num_scales: number of scales input.
      num_layers: number of layers.
      mlp_dim: dimension of the MLP on top of the attention block.
      use_sinusoid_pos_emb: whether to use sinusoidal positional embeddings.
      use_scale_emb: whether to use scale embeddings.
      dropout_rate: dropout rate.
      train: whether the model is in training mode.
      dtype: dtype of activations.
      stochastic_layer_drop_rate: probability of dropping a layer linearly grows
        from 0 to the provided value. Our implementation of stochastic depth
        follows timm library, which does per-example layer dropping and uses
        independent dropping patterns for each skip-connection.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output of a transformer encoder.
    """
        assert inputs.ndim == 3  # (batch, len, emb)
        dtype = jax.dtypes.canonicalize_dtype(dtype)

        if not use_sinusoid_pos_emb:
            x = AddHashSpatialPositionEmbs(
                inputs,
                spatial_pos_grid_size,
                inputs_positions=inputs_spatial_positions,
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input")
        else:
            pos_emb_shape = (1, spatial_pos_grid_size * spatial_pos_grid_size,
                             inputs.shape[2])
            pe = get_sinusoid_encoding(pos_emb_shape[1], pos_emb_shape[2])
            pe = jnp.expand_dims(pe, axis=0)
            x = inputs + jnp.take(pe[0], inputs_spatial_positions, axis=0)

        if use_scale_emb:
            x = AddScaleEmbs(
                x,
                num_scales=num_scales,
                inputs_positions=inputs_scale_positions,
                scale_emb_init=nn.initializers.normal(stddev=0.02),
                name="scaleembed_input")

        n, _, c = x.shape
        cls = self.param("cls", (1, 1, c), nn.initializers.zeros)
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

        cls_mask = jnp.ones((n, 1), dtype=inputs_masks.dtype)
        inputs_masks = jnp.concatenate([cls_mask, inputs_masks], axis=1)

        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        # Input Encoder
        for lyr in range(num_layers):
            layer_drop_p = (
                lyr / max(num_layers - 1, 1)) * stochastic_layer_drop_rate
            x = Encoder1DBlock(x,
                               mlp_dim=mlp_dim,
                               inputs_masks=inputs_masks,
                               dropout_rate=dropout_rate,
                               deterministic=not train,
                               name=f"encoderblock_{lyr}",
                               dtype=dtype,
                               layer_drop_p=layer_drop_p,
                               **attention_kwargs)
        encoded = nn.LayerNorm(x, name="encoder_norm")

        return encoded
Example #10
 def _build_messages(self, node_x, edge_x, sources, targets):
     del edge_x  # Unused.
     source_x = jnp.take(node_x, sources, axis=0)
     return source_x
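The gather in `_build_messages` is one half of a typical message-passing step; combined with a segment sum over the target indices it aggregates the gathered messages per node. A small sketch with illustrative sizes (not from the original module):

import jax
import jax.numpy as jnp

node_x = jnp.arange(12.0).reshape(4, 3)   # 4 nodes, 3 features each
sources = jnp.array([0, 1, 2, 3])         # source node of each edge
targets = jnp.array([1, 2, 3, 0])         # target node of each edge

messages = jnp.take(node_x, sources, axis=0)                    # (n_edges, 3)
aggregated = jax.ops.segment_sum(messages, targets,
                                 num_segments=node_x.shape[0])  # (4, 3)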
Example #11
def lsh_attention_single_head(query,
                              value,
                              n_buckets,
                              n_hashes,
                              causal_mask=True,
                              length_norm=False):
    """Applies LSH attention on a single head and a single batch.

  Args:
    query: query tensor of shape [qlength, dims].
    value: value tensor of shape [vlength, dims].
    n_buckets: integer, number of buckets.
    n_hashes: integer, number of hashes.
    causal_mask: boolean, to use causal mask or not.
    length_norm: boolean, to normalize k or not.
  Returns:
    output tensor of shape [qlength, dims]
  """

    qdim, vdim = query.shape[-1], value.shape[-1]
    chunk_size = n_hashes * n_buckets

    seqlen = query.shape[0]

    with nn.stochastic(jax.random.PRNGKey(0)):
        rng = nn.make_rng()

    buckets = hash_vectors(query,
                           rng,
                           num_buckets=n_buckets,
                           num_hashes=n_hashes)
    # buckets should have shape (n_hashes * seqlen,)
    assert buckets.shape[-1] == n_hashes * seqlen

    total_hashes = n_hashes

    # create sort and unsort
    ticker = jax.lax.tie_in(query, jnp.arange(n_hashes * seqlen))
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
    # ticker = jnp.tile(jnp.reshape(ticker, [1, -1]), [batch_size, 1])
    sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                   ticker,
                                                   dimension=-1)
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)

    st = (sticker % seqlen)

    sqk = jnp.take(query, st, axis=0)
    sv = jnp.take(value, st, axis=0)

    bkv_t = jnp.reshape(st, (chunk_size, -1))
    bqk = jnp.reshape(sqk, (chunk_size, -1, qdim))
    bv = jnp.reshape(sv, (chunk_size, -1, vdim))
    bq = bqk
    bk = bqk

    if length_norm:
        bk = length_normalized(bk)

    # get previous chunks
    bk = look_one_back(bk)
    bv = look_one_back(bv)
    bkv_t = look_one_back(bkv_t)

    # compute dot product attention
    # scaled dot-product attention: scale the scores by 1/sqrt(d)
    dots = jnp.einsum('hie,hje->hij', bq, bk) * (qdim ** -0.5)

    if causal_mask:
        # apply causal mask
        # TODO(yitay): This is not working yet
        # We don't need causal reformer for any task YET.
        pass

    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    slogits = jnp.reshape(dots_logsumexp, [-1])
    dots = jnp.exp(dots - dots_logsumexp)

    x = jnp.matmul(dots, bv)
    x = jnp.reshape(x, [-1, qdim])

    # Unsort
    o = permute_via_gather(x, undo_sort, sticker, axis=0)
    logits = permute_via_sort(slogits, sticker, undo_sort, axis=0)
    logits = jnp.reshape(logits, [total_hashes, seqlen, 1])
    probs = jnp.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = jnp.reshape(o, [n_hashes, seqlen, qdim])
    out = jnp.sum(o * probs, axis=0)
    out = jnp.reshape(out, [seqlen, qdim])

    return out
Example #12
 def test_gather(self):
     values = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
     indices = np.array([0, 1], dtype=np.int32)
     for axis in (0, 1):
         f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
         self.ConvertAndCompare(f_jax, values, indices, with_function=True)
Example #13
def sample_from_array(rng_key, x, n, axis):
    """ Samples n elements from a given array without replacement.

    Uses the Feistel shuffle to uniformly draw
    n unique elements from x along the given axis.

    :param rng_key: jax prng key used for sampling.
    :param x: the array from which elements are sampled
    :param n: how many elements to return
    :param axis: axis along which samples are drawn
    """
    capacity = np.uint32(np.shape(x)[axis])
    data = np.arange(n, dtype=np.uint32)

    seed = jax.random.randint(rng_key,
                              shape=(1, ),
                              minval=0,
                              maxval=capacity,
                              dtype=np.uint32).squeeze()

    def permute32(vals):
        def hash_func_in(x):
            x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))
            x *= jnp.uint32(0x85ebca6b)
            x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(13)))
            x *= jnp.uint32(0xc2b2ae35)
            x = jnp.bitwise_xor(x, jnp.right_shift(x, jnp.uint32(16)))

            return x

        num_iters = np.uint32(8)

        bits = jnp.uint32(len(bin(capacity)) - 2)
        bits_lower = jnp.right_shift(bits, 1)
        bits_upper = bits - bits_lower
        mask_lower = (jnp.left_shift(jnp.uint32(1),
                                     bits_lower)) - jnp.uint32(1)

        seed_offst = hash_func_in(seed)
        position = vals

        def iter_func(position):
            for j in range(num_iters):
                j = jnp.uint32(j)
                upper = jnp.right_shift(position, bits_lower)
                lower = jnp.bitwise_and(position, mask_lower)
                mixer = hash_func_in(upper + seed_offst + j)

                tmp = jnp.bitwise_xor(lower, mixer)
                position = upper + (jnp.left_shift(
                    jnp.bitwise_and(tmp, mask_lower), bits_upper))
            return position

        position = iter_func(position)
        position = jax.lax.while_loop(lambda position: position >= capacity,
                                      iter_func, position)

        return position

    func = jax.vmap(permute32)
    a = func(data)

    return jnp.take(x, a, axis=axis)
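A short usage sketch for `sample_from_array`, assuming the function above (and its numpy/jax imports) is in scope:

key = jax.random.PRNGKey(0)
x = jnp.arange(100).reshape(20, 5)

# Draw 8 distinct rows (without replacement) along axis 0.
subset = sample_from_array(key, x, n=8, axis=0)   # shape (8, 5)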
Example #14
def nms(boxes: Tensor,
        scores: Tensor,
        overlap_threshold: float = .5,
        score_threshold: float = .5,
        boxes_fmt: BoxesFormat = BoxesFormat.xyxy) -> Tensor:
    """
    Performs Non maxima supperssion with the given boxes
    Selects the boxes with higher score and discards the ones pointing to the
    same object.

    Parameters
    ----------
    boxes: Tensor of shape [N, 4]
        Boxes formated according to the input parameter `boxes_fmt`
    scores: Tensor of shape [N]
        Boxes scores ranging from 0 to 1 being 1 a higher value
    boxes_fmt: BoxesFormat, default xyxy
        Format of the boxes, by default it is set to 
        [x_min, y_min, x_max, y_max]
    overlap_threshold: float, default .5
        Overlapping boxes pointing to the same object with an iou larger than
        this threshold are going to be discarded. NMS will only keep the one
        with the highest score
    score_threshold: float, default 0.5
        Boxes with a lower score than score_threshold will be discarded

    Returns
    -------
    Tensor of shape [N]
        A mask containing True if the box has to be kept and False otherwise
    
    Examples
    --------
    >>> N = 10
    >>> boxes = jax.random.uniform(key, shape=(N, 4))
    >>> scores = np.ones(N)
    >>> keep_mask = aj.ops.nms(boxes, scores)
    >>> desired_boxes = boxes[keep_mask]
    """

    if boxes_fmt != BoxesFormat.xyxy:
        convert_fn = getattr(boxes_utils, f'{boxes_fmt.value}_to_xyxy')
        boxes = convert_fn(boxes)

    n = boxes.shape[0]

    # Sort boxes by score
    sort_idx = np.argsort(-scores)
    scores = np.take(scores, sort_idx)
    boxes = boxes[sort_idx]

    # Compute iou of sorted boxes, hence the boxes with lower scores are going
    # to be at the end of the rows
    ious = iou(boxes, boxes)

    # Set each box's IoU with itself to 0 (zero the diagonal of the IoU matrix)
    ious_mask = 1 - np.eye(n)
    ious = ious * ious_mask

    score_keep_mask = scores > score_threshold

    overlapping_boxes = ious > overlap_threshold

    # Since boxes with lower scores are at the end of the iou matrix rows
    # we create a mask containing ones for every element which is over the
    # diagonal. For box at index 3 all elements after the third index of its
    # row are going to have a lower score
    smaller_score_than_mask = sum([np.eye(n, k=i) for i in range(n)])
    smaller_score_than_mask = smaller_score_than_mask.astype('bool')

    # If a box overlaps and has a lower score we discard it
    # To compute the keep mask we invert the discard mask
    iou_keep_mask = ~(overlapping_boxes & smaller_score_than_mask)
    iou_keep_mask = np.all(iou_keep_mask, axis=0)

    keep_mask = iou_keep_mask & score_keep_mask

    # Undo the sort and return the mask according to the input boxes
    init_correspondence_idx = jax.ops.index_update(
        np.zeros_like(sort_idx), sort_idx, np.arange(sort_idx.shape[0]))

    return keep_mask[init_correspondence_idx]
Example #15
 def apply_fun(params, inputs, **kwargs):
   del kwargs
   dense_embedding = params
   return np.take(dense_embedding, inputs, axis=0)
Example #16
def bilinear_sampler(imgs, coords, mask_value):
    """Construct a new image by bilinear sampling from the input image.
    Points falling outside the source image boundary have value of mask_value.
    Args:
        imgs: source image to be sampled from [b, h, w, c]
        coords: coordinates of source pixels to sample from [b, h, w, 2].
            height_t/width_t correspond to the dimensions of the output
            image (don't need to be the same as height_s/width_s).
            The two channels correspond to x and y coordinates respectively.
        mask_value: value of points outside of image. -1 for edge sampling.
        Returns:
            A new sampled image [height_t, width_t, channels]
    """
    coords_x, coords_y = jnp.split(coords, 2, axis=2)
    inp_size = imgs.shape
    out_size = list(coords.shape)
    out_size[2] = imgs.shape[2]

    coords_x = jnp.array(coords_x, dtype='float32')
    coords_y = jnp.array(coords_y, dtype='float32')

    y_max = jnp.array(jnp.shape(imgs)[0] - 1, dtype='float32')
    x_max = jnp.array(jnp.shape(imgs)[1] - 1, dtype='float32')
    zero = jnp.zeros([1], dtype='float32')
    eps = jnp.array([0.5], dtype='float32')

    coords_x_clipped = jnp.clip(coords_x, zero, x_max - eps)
    coords_y_clipped = jnp.clip(coords_y, zero, y_max - eps)

    x0 = jnp.floor(coords_x_clipped)
    x1 = x0 + 1
    y0 = jnp.floor(coords_y_clipped)
    y1 = y0 + 1

    x0_safe = jnp.clip(x0, zero, x_max)
    y0_safe = jnp.clip(y0, zero, y_max)
    x1_safe = jnp.clip(x1, zero, x_max)
    y1_safe = jnp.clip(y1, zero, y_max)

    # bilinear interp weights, with points outside the grid having weight 0
    # wt_x0 = (x1 - coords_x) * jnp.equal(x0, x0_safe).astype('float32')
    # wt_x1 = (coords_x - x0) * jnp.equal(x1, x1_safe).astype('float32')
    # wt_y0 = (y1 - coords_y) * jnp.equal(y0, y0_safe).astype('float32')
    # wt_y1 = (coords_y - y0) * jnp.equal(y1, y1_safe).astype('float32')

    wt_x0 = x1_safe - coords_x  # 1
    wt_x1 = coords_x - x0_safe  # 0
    wt_y0 = y1_safe - coords_y  # 1
    wt_y1 = coords_y - y0_safe  # 0

    # indices in the flat image to sample from
    dim2 = jnp.array(inp_size[1], dtype='float32')

    base_y0 = y0_safe * dim2
    base_y1 = y1_safe * dim2
    idx00 = jnp.reshape(x0_safe + base_y0, [-1])
    idx01 = x0_safe + base_y1
    idx10 = x1_safe + base_y0
    idx11 = x1_safe + base_y1

    # sample from imgs
    imgs_flat = jnp.reshape(imgs, [-1, inp_size[2]])
    imgs_flat = imgs_flat.astype('float32')
    im00 = jnp.reshape(
        jnp.take(imgs_flat, idx00.astype('int32'), axis=0), out_size)
    im01 = jnp.reshape(
        jnp.take(imgs_flat, idx01.astype('int32'), axis=0), out_size)
    im10 = jnp.reshape(
        jnp.take(imgs_flat, idx10.astype('int32'), axis=0), out_size)
    im11 = jnp.reshape(
        jnp.take(imgs_flat, idx11.astype('int32'), axis=0), out_size)

    w00 = wt_x0 * wt_y0
    w01 = wt_x0 * wt_y1
    w10 = wt_x1 * wt_y0
    w11 = wt_x1 * wt_y1

    output = jnp.clip(jnp.round(w00 * im00 + w01 * im01 + w10 * im10 +
                                w11 * im11), 0, 255)

    return jnp.where(jnp.all(mask_value >= 0),
                     jnp.where(
                         compute_mask(coords_x, coords_y, x_max, y_max),
                         output,
                         jnp.ones_like(output) *
                         jnp.reshape(jnp.array(mask_value), [1, 1, -1])
                     ),
                     output)
Example #17
class JaxBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for ``numpy.ndarray``.

    For more details, please refer to the :class:`~.TensorBox` documentation.
    """

    abs = wrap_output(lambda self: jnp.abs(self.data))
    angle = wrap_output(lambda self: jnp.angle(self.data))
    arcsin = wrap_output(lambda self: jnp.arcsin(self.data))
    cast = wrap_output(lambda self, dtype: jnp.array(self.data, dtype=dtype))
    expand_dims = wrap_output(
        lambda self, axis: jnp.expand_dims(self.data, axis=axis))
    ones_like = wrap_output(lambda self: jnp.ones_like(self.data))
    sqrt = wrap_output(lambda self: jnp.sqrt(self.data))
    sum = wrap_output(lambda self, axis=None, keepdims=False: jnp.sum(
        self.data, axis=axis, keepdims=keepdims))
    T = wrap_output(lambda self: self.data.T)
    take = wrap_output(lambda self, indices, axis=None: jnp.take(
        self.data, indices, axis=axis, mode="wrap"))

    def __init__(self, tensor):
        tensor = jnp.asarray(tensor)

        super().__init__(tensor)

    @staticmethod
    def astensor(tensor):
        return jnp.asarray(tensor)

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        return jnp.concatenate(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        x, y = JaxBox.unbox_list([x, y])
        x = jnp.asarray(x)
        y = jnp.asarray(y)

        if x.ndim == 0 and y.ndim == 0:
            return x * y

        if x.ndim == 2 and y.ndim == 2:
            return x @ y

        return jnp.dot(x, y)

    @property
    def interface(self):
        return "jax"

    def numpy(self):
        return self.data

    @property
    def requires_grad(self):
        return True

    @property
    def shape(self):
        return self.data.shape

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        return jnp.stack(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        return jnp.where(condition, *JaxBox.unbox_list([x, y]))
Example #18
 def __call__(self, inputs):
     embedding = self.param("weight", self.emb_init,
                            (self.vocab_size, self.hidden_size))
     return jnp.take(embedding, inputs, axis=0)
Example #19
    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example #20
        def _full_update_step(
            state,
            transitions,
            in_initial_bc_iters,
        ):

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)
            if adaptive_entropy_coefficient:
                alpha = jnp.exp(state.alpha_params)
            else:
                alpha = entropy_coefficient

            if snr_applied_to == 'critic':
                (critic_loss_value,
                 (q_loss, snr_term, masked_s, c_matrix, snr_state,
                  snr_loss_weight)), critic_grads = critic_val_and_grad(
                      state.q_params, state.policy_params,
                      state.target_q_params, alpha, transitions, key_critic,
                      state.snr_state, in_initial_bc_iters)
            else:
                (critic_loss_value,
                 (q_loss, _, _, _, _, _)), critic_grads = critic_val_and_grad(
                     state.q_params, state.policy_params,
                     state.target_q_params, alpha, transitions, key_critic,
                     state.snr_state, in_initial_bc_iters)

            if snr_applied_to == 'policy':
                (actor_loss_value,
                 (min_q, log_prob, snr_term, masked_s, c_matrix, snr_state,
                  snr_loss_weight)), actor_grads = actor_val_and_grad(
                      state.policy_params,
                      state.q_params,
                      state.target_q_params,
                      alpha,
                      transitions,
                      key_actor,
                      state.snr_state,
                      in_initial_bc_iters,
                  )
            else:
                (actor_loss_value, (min_q, log_prob, _, _, _, _,
                                    _)), actor_grads = actor_val_and_grad(
                                        state.policy_params,
                                        state.q_params,
                                        state.target_q_params,
                                        alpha,
                                        transitions,
                                        key_actor,
                                        state.snr_state,
                                        in_initial_bc_iters,
                                    )

            # Apply policy gradients
            actor_update, policy_optimizer_state = policy_optimizer.update(
                actor_grads, state.policy_optimizer_state)
            policy_params = optax.apply_updates(state.policy_params,
                                                actor_update)

            # Apply critic gradients
            critic_update, q_optimizer_state = q_optimizer.update(
                critic_grads, state.q_optimizer_state)
            q_params = optax.apply_updates(state.q_params, critic_update)

            # new_target_q_params = jax.tree_map(
            #     lambda x, y: x * (1 - tau) + y * tau, state.target_q_params, q_params)
            new_target_q_params = q_params

            metrics = OrderedDict()
            metrics['critic_loss'] = q_loss
            metrics['actor_loss'] = actor_loss_value
            metrics['actor_log_probs'] = jnp.mean(log_prob)
            metrics['q/avg'] = jnp.mean(min_q)
            metrics['q/std'] = jnp.std(min_q)
            metrics['q/max'] = jnp.max(min_q)
            metrics['q/min'] = jnp.min(min_q)
            metrics['snr/loss'] = snr_term
            metrics['snr/loss_weight'] = snr_loss_weight
            num_gt_zero = jnp.sum(masked_s > 0.)
            metrics['snr/num_gt_zero'] = num_gt_zero
            min_s = jnp.take(masked_s, [num_gt_zero - 1], axis=0)[0]
            num_gt_zero = num_gt_zero + 1e-6
            mean_s = jnp.sum(masked_s) / num_gt_zero
            std_s = jnp.sqrt((jnp.sum(masked_s**2) / num_gt_zero) - mean_s**2)
            metrics['snr/avg'] = mean_s
            metrics['snr/std'] = std_s
            metrics['snr/max'] = jnp.max(masked_s)
            metrics['snr/min'] = min_s

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=new_target_q_params,
                key=key,
                snr_state=snr_state,
            )

            # alpha update step
            if (not in_initial_bc_iters) and adaptive_entropy_coefficient:
                alpha_loss, alpha_grads = alpha_val_and_grad(
                    state.alpha_params, jnp.mean(log_prob))
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                alpha_params = optax.apply_updates(state.alpha_params,
                                                   alpha_update)
                # metrics.update({
                #     'alpha_loss': alpha_loss,
                #     'alpha': jnp.exp(alpha_params),
                # })
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=alpha_params)
                metrics['alpha'] = jnp.exp(alpha_params)
                metrics['alpha_loss'] = alpha_loss
            else:
                new_state = new_state._replace(
                    alpha_optimizer_state=state.alpha_optimizer_state,
                    alpha_params=state.alpha_params)
                metrics['alpha'] = alpha
                metrics['alpha_loss'] = 0.

            return new_state, metrics
Example #21
def slice_axis(data, axis, begin, end):
    return jnp.take(
        data,
        indices=jnp.arange(start=begin, stop=end),
        axis=axis,
    )
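This gather-based helper mimics basic slicing along one axis; for static bounds it should agree with ordinary indexing. A quick sanity check, assuming `slice_axis` as defined above:

import jax.numpy as jnp

data = jnp.arange(24).reshape(4, 6)
# Equivalent to data[:, 2:5]
assert jnp.array_equal(slice_axis(data, axis=1, begin=2, end=5), data[:, 2:5])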
Example #22
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Update node features.

        Parameters
        ----------
        node_feats : ndarray of shape (N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch
        adj : ndarray of shape (2, E)
            Batch adjacency list.
            E is the total number of edges in the batch
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (N, out_feats)
            Batch new node features.
        """
        dropout = self.dropout if is_training is True else 0.0
        num_nodes = node_feats.shape[0]

        # affine transformation
        new_node_feats = jnp.dot(node_feats, self.w)
        if self.bias:
            new_node_feats += self.b

        # update nodes
        if self.normalize:
            # add self connection
            self_loop = jnp.tile(jnp.arange(num_nodes), (2, 1))
            adj = jnp.concatenate((adj, self_loop), axis=1)
            src_idx, dest_idx = adj[0], adj[1]

            # calculate the norm
            degree = segment_sum(jnp.ones(len(dest_idx)),
                                 dest_idx,
                                 num_segments=num_nodes)
            deg_inv_sqrt = jax.lax.pow(degree, -0.5)
            norm = deg_inv_sqrt[src_idx] * deg_inv_sqrt[dest_idx]

            # update nodes
            source_feats = jnp.take(new_node_feats, src_idx, axis=0)
            source_feats = norm.reshape(-1, 1) * source_feats
            new_node_feats = segment_sum(source_feats,
                                         dest_idx,
                                         num_segments=num_nodes)
        else:
            src_idx, dest_idx = adj[0], adj[1]
            source_feats = jnp.take(new_node_feats, src_idx, axis=0)
            aggregated_messages = segment_sum(source_feats,
                                              dest_idx,
                                              num_segments=num_nodes)
            new_node_feats = jnp.add(aggregated_messages, new_node_feats)

        new_node_feats = self.activation(new_node_feats)

        if dropout != 0.0:
            new_node_feats = hk.dropout(hk.next_rng_key(), dropout,
                                        new_node_feats)
        if self.batch_norm:
            new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats,
                                                           is_training)

        return new_node_feats
Example #23
def onnx_gather(x, indices, axis=0):
    return jnp.take(x, indices, axis=axis)
Example #24
        "",
        lambda lhs, rhs: lax.dot_general(
            lhs, rhs, dimension_numbers=(((2, ), (1, )), ((0, ), (0, )))),
        [RandArg(
            (3, 4, 4), _f32), RandArg((3, 4), _f32)],
        poly_axes=[0, 0]),
    _make_harness(
        "dynamic_slice",
        "",
        # x:shape: (b, 4)
        lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
        [RandArg((3, 4), _f32)],
        poly_axes=[0]),
    _make_harness("jnp_take",
                  "",
                  lambda a, i: jnp.take(a, i, axis=1),
                  [RandArg((3, 4, 5), _f32),
                   np.array([1, 2], np.int32)],
                  poly_axes=[0, None]),
    _make_harness(
        "jnp_getitem",
        "",
        lambda a, i: a[i], [RandArg(
            (3, 4), _f32), np.array([2, 2], np.int32)],
        poly_axes=[None, 0]),

    # TODO(necula): not supported yet
    # _make_harness("jnp_getitem", "",
    #               lambda a, i: a[i],
    #               [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
    #               poly_axes=[0, 0]),
Example #25
 def permute_impl(val):
   return jnp.take(val, permutation, axis=axis)
Example #26
def _piecewise_constant(boundaries, values, t):
    index = jnp.sum(boundaries < t)
    return jnp.take(values, index)
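A brief usage sketch of `_piecewise_constant` as a step-wise learning-rate schedule; the boundaries and values below are made up:

import jax.numpy as jnp

boundaries = jnp.array([1000, 2000])
values = jnp.array([1e-3, 1e-4, 1e-5])

lr = _piecewise_constant(boundaries, values, t=500)    # -> 1e-3
lr = _piecewise_constant(boundaries, values, t=1500)   # -> 1e-4
lr = _piecewise_constant(boundaries, values, t=3000)   # -> 1e-5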
Example #27
           dtype=dtype)
   for dtype in jtu.dtypes.all_floating
   for arg1, arg2, arg3 in [
     (np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype),
      np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype),
      np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
               dtype=np.float32))
  ]
)


_gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10))
lax_gather = tuple(
  # Construct gather harnesses using take
  [Harness(f"from_take_indices_shape={indices.shape}_axis={axis}",
           lambda a, i, axis: jnp.take(a, i, axis=axis),
           [_gather_input,
            indices,
            StaticArg(axis)])
  for indices in [
    # Ensure each set of indices has a distinct shape
    np.array(2, dtype=np.int32),
    np.array([2], dtype=np.int32),
    np.array([2, 4], dtype=np.int32),
    np.array([[2, 4], [5, 6]], dtype=np.int32),
    np.array([0, 1, 10], dtype=np.int32),  # Index out of bounds
    np.array([0, 1, 2, -1], dtype=np.int32),  # Index out of bounds
  ]
  for axis in [0, 1, 2]] +

  # Directly from lax.gather in lax_test.py.
Example #28
    def __call__(self, x: Array) -> Array:
        """Applies the symmetrized linear transformation to the inputs along the last dimension.

        Args:
          x: The nd-array to be transformed.

        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)
        # infer in_features and ensure input dimensions (batch, in_features, n_sites)

        # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
        if x.ndim < 3:
            old_shape = x.shape
            if x.ndim == 1:
                x = jnp.expand_dims(x, (0, 1))
            elif x.ndim == 2:
                x = jnp.expand_dims(x, 1)
            symm_input_warning(old_shape, x.shape, "DenseSymm")

        in_features = x.shape[1]

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_sites),
            self.dtype,
        )

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (self.features, in_features, n_sites)
        # to a full dense kernel of shape (self.features, in_features, n_symm, n_sites).
        # result[out, in, g, r] == kernel[out, in, g^{-1}r]
        kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
        kernel = jnp.asarray(kernel, dtype)

        # x is      (batches,       in_features,         n_sites)
        # kernel is (self.features, in_features, n_symm, n_sites)
        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 2, x.ndim - 1), (1, 3)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)

            # Convert symmetry-reduced bias of shape (features,) to the full bias of
            # shape (..., features, 1).
            bias = jnp.expand_dims(bias, 1)
            bias = jnp.asarray(bias, dtype)

            x += bias

        return x
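The kernel expansion above can be checked on a toy case: with `symmetries` of shape (n_symm, n_sites), `jnp.take(kernel, symmetries, 2)` turns a (features, in_features, n_sites) kernel into one of shape (features, in_features, n_symm, n_sites). The sizes below are made up:

import jax.numpy as jnp

features, in_features, n_sites, n_symm = 2, 3, 4, 5
kernel = jnp.arange(features * in_features * n_sites, dtype=jnp.float32)
kernel = kernel.reshape(features, in_features, n_sites)
# Each row of `symmetries` is a permutation of the site indices; identity here.
symmetries = jnp.tile(jnp.arange(n_sites), (n_symm, 1))

expanded = jnp.take(kernel, symmetries, 2)
assert expanded.shape == (features, in_features, n_symm, n_sites)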
Example #29
 def f(a, i):
   return jnp.take(a, i, axis=1)
Example #30
 def permute_fwd(p, ip, val):
     return jnp.take(val, p, axis=axis), ip