Example #1
import jax.numpy as jnp
from jax import nn
from jax.lax import ScatterDimensionNumbers, scatter_add


def matching_log_probas(embeddings,
                        targets,
                        test_embeddings,
                        num_classes,
                        eps=1e-8):
    num_samples = test_embeddings.shape[0]
    # `pairwise_cosine_similarity` is a project-local helper returning a
    # (num_support, num_test) matrix of cosine similarities.
    similarities = pairwise_cosine_similarity(embeddings,
                                              test_embeddings,
                                              eps=eps)
    logsumexp = nn.logsumexp(similarities, axis=0, keepdims=True)

    # Numerically stable exponentiation: subtract the per-column maximum.
    max_similarities = jnp.max(similarities, axis=0, keepdims=True)
    exp_similarities = jnp.exp(similarities - max_similarities)

    # Scatter-add the exponentiated similarities into one row per class.
    sum_exp = jnp.zeros((num_classes, num_samples),
                        dtype=exp_similarities.dtype)
    indices = jnp.expand_dims(targets, axis=-1)
    dimension_numbers = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    sum_exp = scatter_add(sum_exp, indices, exp_similarities,
                          dimension_numbers)

    return jnp.log(sum_exp) + max_similarities - logsumexp
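
A minimal usage sketch (not from the original project); the `pairwise_cosine_similarity` stand-in below is hypothetical and replaces the project's own helper:

import jax.numpy as jnp

# Hypothetical stand-in: rows are support embeddings, columns test embeddings.
def pairwise_cosine_similarity(a, b, eps=1e-8):
    a = a / jnp.maximum(jnp.linalg.norm(a, axis=-1, keepdims=True), eps)
    b = b / jnp.maximum(jnp.linalg.norm(b, axis=-1, keepdims=True), eps)
    return a @ b.T

support = jnp.array([[1., 0.], [0., 1.], [1., 1.]])
targets = jnp.array([0, 1, 1])
test = jnp.array([[1., 0.], [0.5, 0.5]])
log_probas = matching_log_probas(support, targets, test, num_classes=2)
print(log_probas.shape)  # (2, 2): (num_classes, num_test)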
Example #2
import numpy as np
import jax.numpy as jnp
from jax import lax


def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias,
                         precision):
    input_shape = x.shape
    assert len(input_shape) == len(output_shape)
    assert len(input_shape) == len(scale)
    assert len(input_shape) == len(translate)
    # Only dimensions whose size, scale, or translation actually changes
    # need resampling.
    spatial_dims, = np.nonzero(
        np.not_equal(input_shape, output_shape) | np.not_equal(scale, 1)
        | np.not_equal(translate, 0))
    if len(spatial_dims) == 0:
        return x
    contractions = []
    in_indices = list(range(len(output_shape)))
    out_indices = list(range(len(output_shape)))
    for i, d in enumerate(spatial_dims):
        m = input_shape[d]
        n = output_shape[d]
        # `_compute_spans` is a module-private helper returning, for each
        # output position, the start index and weights of the kernel span.
        starts, span_weights = _compute_spans(m, n, scale[d], translate[d],
                                              kernel, antialias=antialias)
        # Scatter each span of weights into a dense (m, n) resampling matrix:
        # column j receives span_weights[j] starting at row starts[j].
        dnums = lax.ScatterDimensionNumbers(update_window_dims=(1,),
                                            inserted_window_dims=(1,),
                                            scatter_dims_to_operand_dims=(0,
                                                                          1))
        w = lax.scatter_add(jnp.zeros((m, n), x.dtype),
                            np.stack([starts, np.arange(n)], axis=-1),
                            span_weights.astype(x.dtype), dnums)
        # Contract each resampled dimension against its weight matrix in a
        # single einsum at the end.
        contractions.append(w)
        contractions.append([d, len(output_shape) + i])
        out_indices[d] = len(output_shape) + i
    contractions.append(out_indices)
    return jnp.einsum(x, in_indices, *contractions, precision=precision)
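
This appears to be the private worker behind the public `jax.image.scale_and_translate` API; a minimal call through that public entry point (illustrative values) exercises the same scatter-built weight matrices:

import jax.image
import jax.numpy as jnp

image = jnp.arange(16., dtype=jnp.float32).reshape(4, 4)
out = jax.image.scale_and_translate(
    image, shape=(8, 8), spatial_dims=(0, 1),
    scale=jnp.array([2.0, 2.0]), translation=jnp.array([0.0, 0.0]),
    method="linear", antialias=True)
print(out.shape)  # (8, 8)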
Example #3
def f_jax(x, upd):
    # `idx` and `dimension_numbers` are free variables captured from the
    # enclosing test case.
    return lax.scatter_add(
        x,
        scatter_indices=idx,
        updates=upd,
        dimension_numbers=lax.ScatterDimensionNumbers(*dimension_numbers),
        indices_are_sorted=False,
        unique_indices=True)
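
A self-contained variant with concrete, illustrative values for the `idx` and `dimension_numbers` the test closes over (these values are not from the original test):

import jax.numpy as jnp
from jax import lax

idx = jnp.array([[1], [3]], dtype=jnp.int32)
# (update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)
dimension_numbers = ((), (0,), (0,))

def f_jax(x, upd):
    return lax.scatter_add(
        x, scatter_indices=idx, updates=upd,
        dimension_numbers=lax.ScatterDimensionNumbers(*dimension_numbers),
        indices_are_sorted=False, unique_indices=True)

print(f_jax(jnp.zeros(5), jnp.array([10., 20.])))  # [ 0. 10.  0. 20.  0.]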
Example #4
def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                       rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    # Check first- and second-order gradients in both forward and reverse
    # mode, with tolerances 1e-2 and finite-difference eps 1.
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
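The same kind of check can be run standalone; a sketch with illustrative shapes and dimension numbers (not the parameterized values of the original test):

import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

idxs = jnp.array([[0], [2]], dtype=jnp.int32)
dnums = lax.ScatterDimensionNumbers(update_window_dims=(1,),
                                    inserted_window_dims=(0,),
                                    scatter_dims_to_operand_dims=(0,))
scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                           dimension_numbers=dnums)
check_grads(scatter_add, (jnp.ones((3, 4)), jnp.ones((2, 4))),
            order=2, modes=["fwd", "rev"])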
Example #5
File: util.py Project: while519/numpyro
import math

import jax.numpy as jnp
from jax import lax


def vec_to_tril_matrix(t, diagonal=0):
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # Flat positions of the lower-triangular entries of an (n, n) matrix.
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # Scatter the vector into the flattened matrix, then reshape to (n, n).
    x = lax.scatter_add(
        jnp.zeros(t.shape[:-1] + (n2,)), jnp.expand_dims(idx, axis=-1), t,
        lax.ScatterDimensionNumbers(
            update_window_dims=range(t.ndim - 1),
            inserted_window_dims=(t.ndim - 1,),
            scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return jnp.reshape(x, x.shape[:-1] + (n, n))
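
A quick sanity check, assuming `vec_to_tril_matrix` from above is in scope:

import jax.numpy as jnp

t = jnp.arange(1., 7.)  # 6 entries fill a 3x3 lower triangle
print(vec_to_tril_matrix(t))
# [[1. 0. 0.]
#  [2. 3. 0.]
#  [4. 5. 6.]]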
Example #6
import jax.numpy as jnp
from jax.lax import ScatterDimensionNumbers, scatter_add


def get_num_samples(targets, num_classes, dtype=None):
    # Histogram of class labels: scatter-add a one for every target.
    ones = jnp.ones_like(targets, dtype=dtype)
    indices = jnp.expand_dims(targets, axis=-1)
    num_samples = jnp.zeros(targets.shape[:-1] + (num_classes,),
                            dtype=ones.dtype)
    dimension_numbers = ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    return scatter_add(num_samples, indices, ones, dimension_numbers)
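
A toy invocation, assuming `get_num_samples` from above is in scope:

import jax.numpy as jnp

targets = jnp.array([0, 1, 1, 2, 2, 2])
print(get_num_samples(targets, num_classes=4))  # [1 2 3 0]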
Example #7
from jax import lax


def _scatter_add_one(operand, indices, updates):
    # Scatter-add `updates` into `operand` along the leading axis;
    # duplicate indices accumulate.
    return lax.scatter_add(
        operand,
        indices,
        updates,
        lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        ),
    )
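
A toy invocation; note that duplicate indices accumulate, which is the point of scatter-add:

import jax.numpy as jnp

counts = jnp.zeros(4)
idx = jnp.array([[2], [2], [0]])
print(_scatter_add_one(counts, idx, jnp.ones(3)))  # [1. 0. 2. 0.]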
Example #8
import jax.numpy as jnp
from jax.lax import ScatterDimensionNumbers, scatter_add


def get_prototypes(embeddings, targets, num_classes):
    embedding_size, dtype = embeddings.shape[-1], embeddings.dtype
    # Per-class counts, clamped to at least 1 to avoid division by zero.
    num_samples = get_num_samples(targets, num_classes, dtype=dtype)
    num_samples = jnp.expand_dims(jnp.maximum(num_samples, 1), axis=-1)

    # Scatter-add each embedding into the row of its class, then average.
    prototypes = jnp.zeros((num_classes, embedding_size), dtype=dtype)
    indices = jnp.expand_dims(targets, axis=-1)
    dimension_numbers = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    prototypes = scatter_add(prototypes, indices, embeddings,
                             dimension_numbers)

    return prototypes / num_samples
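
A toy invocation, assuming `get_prototypes` and the `get_num_samples` helper from Example #6 are in scope:

import jax.numpy as jnp

embeddings = jnp.array([[1., 1.], [3., 3.], [2., 0.]])
targets = jnp.array([0, 0, 1])
print(get_prototypes(embeddings, targets, num_classes=2))
# [[2. 2.]
#  [2. 0.]]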
Example #9
    def predict(self, params, logits, context, target=None):
        # Broadcast the context vector against the (class, neuron, map) axes.
        context = jnp.expand_dims(context, axis=(1, 2, 3))
        context_bias = params.get('context_bias', 0.0)
        # Halfspace gating: one boolean per context map.
        context_index = (params['context_maps'] *
                         context).sum(axis=-1) > context_bias

        # Pack the booleans into an integer context index per (class, neuron).
        context_map_values = jnp.asarray(
            [[[[1 << n for n in range(self.context_map_size)]]]])
        context_index = jnp.where(context_index, context_map_values, 0)
        context_index = context_index.sum(axis=-1, keepdims=True)

        # Build [class, neuron, context] index triples for gather/scatter.
        batch_size = logits.shape[0]
        class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                           for c in range(self.num_classes)]])
        class_neuron_index = jnp.tile(class_neuron_index,
                                      reps=(batch_size, 1, 1, 1))
        context_index = jnp.concatenate([class_neuron_index, context_index],
                                        axis=-1)

        # Gather each neuron's weight row for its active context.
        dims = lax.GatherDimensionNumbers(offset_dims=(3,),
                                          collapsed_slice_dims=(0, 1, 2),
                                          start_index_map=(0, 1, 2))
        weights = lax.gather(operand=params['weights'],
                             start_indices=context_index,
                             dimension_numbers=dims,
                             slice_sizes=(1, 1, 1,
                                          self.input_size + int(self.bias)))

        if self.bias:
            bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
            logits = jnp.concatenate([logits, bias], axis=-1)
        logits = jnp.expand_dims(logits, axis=-1)

        # Mix the incoming logits; clip to keep predictions away from 0 and 1.
        output_logits = jnp.matmul(weights, logits)
        output_logits = jnp.clip(output_logits,
                                 a_min=jsp.special.logit(self.pred_clipping),
                                 a_max=jsp.special.logit(1.0 -
                                                         self.pred_clipping))

        if target is None:
            return jnp.squeeze(output_logits, axis=-1)

        else:
            # Online update: gradient of the per-neuron logistic loss.
            logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
            output_preds = jnn.sigmoid(output_logits)
            target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
            params['lr_step'], learning_rate = self.learning_rate.value(
                params['lr_step'])
            delta = learning_rate * (target - output_preds) * logits

            # Scatter the update back into the same rows the gather read.
            dims = lax.ScatterDimensionNumbers(
                update_window_dims=(3,),
                inserted_window_dims=(0, 1, 2),
                scatter_dims_to_operand_dims=(0, 1, 2))

            if self.weight_clipping is None:
                params['weights'] = lax.scatter_add(
                    operand=params['weights'],
                    scatter_indices=context_index,
                    updates=delta,
                    dimension_numbers=dims)
            else:
                # Clip the updated rows and overwrite them with lax.scatter.
                weights = jnp.clip(weights + delta,
                                   a_min=-self.weight_clipping,
                                   a_max=self.weight_clipping)
                params['weights'] = lax.scatter(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=weights,
                                                dimension_numbers=dims)

            return params, jnp.squeeze(output_logits, axis=-1)
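
The gather/scatter pair above addresses each neuron's per-context weight row with the same [class, neuron, context] triple. A stripped-down sketch of that round trip, with illustrative shapes not taken from the class above:

import jax.numpy as jnp
from jax import lax

weights = jnp.zeros((2, 3, 4, 5))  # (classes, neurons, contexts, inputs)
# One [class, neuron, context] triple per (batch, class, neuron) entry.
index = jnp.array([[[[0, 0, 1], [0, 1, 2], [0, 2, 0]],
                    [[1, 0, 3], [1, 1, 1], [1, 2, 2]]]])  # (1, 2, 3, 3)

gdims = lax.GatherDimensionNumbers(offset_dims=(3,),
                                   collapsed_slice_dims=(0, 1, 2),
                                   start_index_map=(0, 1, 2))
rows = lax.gather(weights, index, gdims, slice_sizes=(1, 1, 1, 5))

sdims = lax.ScatterDimensionNumbers(update_window_dims=(3,),
                                    inserted_window_dims=(0, 1, 2),
                                    scatter_dims_to_operand_dims=(0, 1, 2))
weights = lax.scatter_add(weights, index, jnp.ones_like(rows),
                          dimension_numbers=sdims)
print(rows.shape, weights.sum())  # (1, 2, 3, 5) 30.0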