def matching_log_probas(embeddings, targets, test_embeddings, num_classes, eps=1e-8):
    """Return (num_classes, num_test) class log-probabilities for a matching network.

    Each test embedding is compared to every support embedding by cosine
    similarity; per-class attention mass is accumulated with a scatter-add and
    normalized over all support samples.
    """
    num_test = test_embeddings.shape[0]
    sims = pairwise_cosine_similarity(embeddings, test_embeddings, eps=eps)
    # Log-normalizer over all support samples, one per test sample.
    log_normalizer = nn.logsumexp(sims, axis=0, keepdims=True)
    # Numerically stable per-class sum of exp(similarity): shift by the max.
    sims_max = jnp.max(sims, axis=0, keepdims=True)
    shifted_exp = jnp.exp(sims - sims_max)
    class_sums = jnp.zeros((num_classes, num_test), dtype=shifted_exp.dtype)
    dnums = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    class_sums = scatter_add(class_sums, jnp.expand_dims(targets, axis=-1),
                             shifted_exp, dnums)
    # log(sum_c exp(s)) recovered from the shifted sums, minus the normalizer.
    return jnp.log(class_sums) + sims_max - log_normalizer
def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias, precision):
    """Resample `x` to `output_shape` with per-axis scale/translate.

    For every axis that actually changes, a sparse (in, out) weight matrix is
    built by scattering kernel span weights, and all matrices are contracted
    with the input in a single einsum.
    """
    input_shape = x.shape
    assert len(input_shape) == len(output_shape)
    assert len(input_shape) == len(scale)
    assert len(input_shape) == len(translate)
    # Only axes whose size, scale, or translation differs need resampling.
    spatial_dims, = np.nonzero(np.not_equal(input_shape, output_shape)
                               | np.not_equal(scale, 1)
                               | np.not_equal(translate, 0))
    if len(spatial_dims) == 0:
        return x
    einsum_operands = []
    in_indices = list(range(len(output_shape)))
    out_indices = list(range(len(output_shape)))
    for k, axis in enumerate(spatial_dims):
        in_size = input_shape[axis]
        out_size = output_shape[axis]
        starts, span_weights = _compute_spans(in_size, out_size, scale[axis],
                                              translate[axis], kernel,
                                              antialias=antialias)
        # Scatter each output column's span weights into an (in, out) matrix.
        dnums = lax.ScatterDimensionNumbers(
            update_window_dims=(1,),
            inserted_window_dims=(1,),
            scatter_dims_to_operand_dims=(0, 1))
        weight_matrix = lax.scatter_add(
            jnp.zeros((in_size, out_size), x.dtype),
            np.stack([starts, np.arange(out_size)], axis=-1),
            span_weights.astype(x.dtype), dnums)
        einsum_operands.append(weight_matrix)
        einsum_operands.append([axis, len(output_shape) + k])
        out_indices[axis] = len(output_shape) + k
    einsum_operands.append(out_indices)
    return jnp.einsum(x, in_indices, *einsum_operands, precision=precision)
def f_jax(x, upd):
    """Scatter-add `upd` into `x` at fixed indices.

    NOTE(review): `idx` and `dimension_numbers` are free variables captured
    from the enclosing scope (presumably a test fixture) — this function is
    not self-contained. `unique_indices=True` asserts the captured `idx`
    contains no duplicates; behavior is undefined if that does not hold.
    """
    return lax.scatter_add(
        x,
        scatter_indices=idx,
        updates=upd,
        dimension_numbers=lax.ScatterDimensionNumbers(*dimension_numbers),
        indices_are_sorted=False,
        unique_indices=True)
def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory):
    """Check 2nd-order forward- and reverse-mode gradients of lax.scatter_add."""
    value_rng = rng_factory(self.rng())
    index_rng = rng_idx_factory(self.rng())
    # Re-sample concrete indices with the shape/dtype of the parameterized spec.
    idxs = index_rng(idxs.shape, idxs.dtype)

    def do_scatter_add(operand, updates):
        return lax.scatter_add(operand, idxs, updates, dimension_numbers=dnums)

    operand = value_rng(arg_shape, dtype)
    updates = value_rng(update_shape, dtype)
    check_grads(do_scatter_add, (operand, updates), 2, ["fwd", "rev"],
                1e-2, 1e-2, 1.)
def vec_to_tril_matrix(t, diagonal=0):
    """Expand a batch of vectors into lower-triangular (n, n) matrices.

    The last axis of `t` holds the triangle entries in row-major order; all
    other positions of the output are zero.
    """
    # NB: this size formula is only valid for diagonal <= 0.
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    flat_size = n * n
    # Flat positions of the kept (lower-triangular) entries in an (n, n) grid.
    tril_positions = jnp.reshape(jnp.arange(flat_size), (n, n))[jnp.tril_indices(n, diagonal)]
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=tuple(range(t.ndim - 1)),
        inserted_window_dims=(t.ndim - 1,),
        scatter_dims_to_operand_dims=(t.ndim - 1,))
    flat = lax.scatter_add(
        jnp.zeros(t.shape[:-1] + (flat_size,)),
        jnp.expand_dims(tril_positions, axis=-1),
        t, dnums)
    return jnp.reshape(flat, flat.shape[:-1] + (n, n))
def get_num_samples(targets, num_classes, dtype=None): ones = jnp.ones_like(targets, dtype=dtype) indices = jnp.expand_dims(targets, axis=-1) num_samples = jnp.zeros(targets.shape[:-1] + (num_classes, ), dtype=ones.dtype) dimension_numbers = ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0, ), scatter_dims_to_operand_dims=(0, )) return scatter_add(num_samples, indices, ones, dimension_numbers)
def _scatter_add_one(operand, indices, updates): return lax.scatter_add( operand, indices, updates, lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), ), )
def get_prototypes(embeddings, targets, num_classes):
    """Average embeddings per class to form (num_classes, dim) prototypes."""
    feature_dim, dtype = embeddings.shape[-1], embeddings.dtype
    counts = get_num_samples(targets, num_classes, dtype=dtype)
    # Clamp empty classes to 1 so the division below never hits zero.
    counts = jnp.expand_dims(jnp.maximum(counts, 1), axis=-1)
    class_sums = jnp.zeros((num_classes, feature_dim), dtype=dtype)
    dnums = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    class_sums = scatter_add(class_sums, jnp.expand_dims(targets, axis=-1),
                             embeddings, dnums)
    return class_sums / counts
def predict(self, params, logits, context, target=None):
    """Run one layer of context-gated linear neurons; optionally update weights.

    With `target is None` returns output logits only; otherwise also applies
    an online weight update and returns `(params, output_logits)`.
    NOTE(review): looks like a Gated Linear Network layer — weights are chosen
    per-sample by halfspace context gating; confirm against the class docs.
    """
    # Broadcast context so it aligns with the (class, neuron, map) axes below.
    context = jnp.expand_dims(jnp.expand_dims(jnp.expand_dims(context, axis=1), axis=1), axis=1)
    context_bias = params.get('context_bias', 0.0)
    # Halfspace test per context map: which side of each hyperplane is `context` on.
    context_index = (params['context_maps'] * context).sum(axis=-1) > context_bias
    # Powers of two so the boolean halfspace bits pack into one integer index.
    context_map_values = jnp.asarray(
        [[[[1 << n for n in range(self.context_map_size)]]]])
    context_index = jnp.where(context_index, context_map_values, 0)
    context_index = context_index.sum(axis=-1, keepdims=True)
    batch_size = logits.shape[0]
    # Static (class, neuron) coordinates, tiled across the batch, then joined
    # with the gated context index to form full gather coordinates.
    class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                       for c in range(self.num_classes)]])
    class_neuron_index = jnp.tile(class_neuron_index, reps=(batch_size, 1, 1, 1))
    context_index = jnp.concatenate([class_neuron_index, context_index], axis=-1)
    dims = lax.GatherDimensionNumbers(offset_dims=(3, ),
                                      collapsed_slice_dims=(0, 1, 2),
                                      start_index_map=(0, 1, 2))
    # Select the per-sample weight vectors chosen by the context gating.
    weights = lax.gather(operand=params['weights'],
                         start_indices=context_index,
                         dimension_numbers=dims,
                         slice_sizes=(1, 1, 1, self.input_size + int(self.bias)))
    if self.bias:
        bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
        logits = jnp.concatenate([logits, bias], axis=-1)
    logits = jnp.expand_dims(logits, axis=-1)
    output_logits = jnp.matmul(weights, logits)
    # Clip in logit space so predicted probabilities stay inside
    # [pred_clipping, 1 - pred_clipping].
    output_logits = jnp.clip(output_logits,
                             a_min=jsp.special.logit(self.pred_clipping),
                             a_max=jsp.special.logit(1.0 - self.pred_clipping))
    if target is None:
        return jnp.squeeze(output_logits, axis=-1)
    else:
        logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
        output_preds = jnn.sigmoid(output_logits)
        target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
        # Advance the learning-rate schedule state stored in params.
        params['lr_step'], learning_rate = self.learning_rate.value(
            params['lr_step'])
        delta = learning_rate * (target - output_preds) * logits
        dims = lax.ScatterDimensionNumbers(update_window_dims=(3, ),
                                           inserted_window_dims=(0, 1, 2),
                                           scatter_dims_to_operand_dims=(0, 1, 2))
        if self.weight_clipping is None:
            # Unclipped update: add deltas back at the gathered coordinates.
            params['weights'] = lax.scatter_add(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=delta,
                                                dimension_numbers=dims)
        else:
            # Clipped update: overwrite the gathered slots with clipped weights.
            weights = jnp.clip(weights + delta, a_min=-self.weight_clipping,
                               a_max=self.weight_clipping)
            params['weights'] = lax.scatter(operand=params['weights'],
                                            scatter_indices=context_index,
                                            updates=weights,
                                            dimension_numbers=dims)
        return params, jnp.squeeze(output_logits, axis=-1)