def predict(self, params, logits, context, target=None):
    """Compute this layer's output logits and, if ``target`` is given, update weights.

    For every (class, neuron) pair the context selects one weight vector
    out of ``params['weights']``; the output logit is the dot product of
    that vector with the (optionally bias-augmented) input logits, clipped
    to ``[logit(pred_clipping), logit(1 - pred_clipping)]``.

    Args:
        params: Mutable mapping holding the layer state. Reads
            ``'context_maps'``, ``'weights'``, optionally ``'context_bias'``
            (defaults to 0.0), ``'bias'`` when ``self.bias`` is truthy and
            ``'lr_step'`` when training. When ``target`` is given,
            ``'weights'`` and ``'lr_step'`` are overwritten in this mapping.
        logits: Input logits; axis 0 is read as the batch size.
            Presumably shaped (batch, num_classes, input_size) -- TODO confirm.
        context: Side information compared against the context maps.
            Presumably shaped (batch, context_dim) -- TODO confirm.
        target: Optional training targets. Presumably shaped
            (batch, num_classes) -- TODO confirm. ``None`` means inference only.

    Returns:
        The clipped output logits with the trailing unit axis removed when
        ``target`` is ``None``; otherwise the tuple ``(params, output_logits)``.
    """
    # Insert three unit axes after the batch axis so the context broadcasts
    # against params['context_maps'].
    context = jnp.expand_dims(context, axis=(1, 2, 3))
    context_bias = params.get('context_bias', 0.0)
    side_bits = (params['context_maps'] * context).sum(axis=-1) > context_bias

    # Read the per-map booleans as the bits of an integer (map n is worth
    # 2**n); that integer is the index of the selected weight vector.
    bit_values = jnp.asarray(
        [[[[1 << n for n in range(self.context_map_size)]]]])
    region_index = jnp.where(side_bits, bit_values, 0)
    region_index = region_index.sum(axis=-1, keepdims=True)

    # Prefix every region index with its (class, neuron) coordinates so each
    # entry becomes a full (class, neuron, region) index into the weights.
    batch_size = logits.shape[0]
    class_neuron_index = jnp.asarray(
        [[[[c, n] for n in range(self.size)]
          for c in range(self.num_classes)]])
    class_neuron_index = jnp.tile(class_neuron_index,
                                  reps=(batch_size, 1, 1, 1))
    context_index = jnp.concatenate([class_neuron_index, region_index],
                                    axis=-1)

    # The first three operand axes are indexed and collapsed; the last axis
    # is returned whole as the selected weight vector.
    gather_dims = lax.GatherDimensionNumbers(
        offset_dims=(3, ),
        collapsed_slice_dims=(0, 1, 2),
        start_index_map=(0, 1, 2))
    weights = lax.gather(
        operand=params['weights'],
        start_indices=context_index,
        dimension_numbers=gather_dims,
        slice_sizes=(1, 1, 1, self.input_size + int(self.bias)))

    if self.bias:
        bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
        logits = jnp.concatenate([logits, bias], axis=-1)
    logits = jnp.expand_dims(logits, axis=-1)

    output_logits = jnp.matmul(weights, logits)
    # Bounds are passed positionally: the ``a_min``/``a_max`` keywords are
    # deprecated in newer JAX releases, while the positional form is
    # accepted by both the old and the new ``jnp.clip`` signature.
    output_logits = jnp.clip(output_logits,
                             jsp.special.logit(self.pred_clipping),
                             jsp.special.logit(1.0 - self.pred_clipping))

    if target is None:
        return jnp.squeeze(output_logits, axis=-1)

    # Training step: delta = learning_rate * (target - prediction) * input.
    logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
    output_preds = jnn.sigmoid(output_logits)
    target = jnp.expand_dims(target, axis=(-2, -1))
    params['lr_step'], learning_rate = self.learning_rate.value(
        params['lr_step'])
    delta = learning_rate * (target - output_preds) * logits

    scatter_dims = lax.ScatterDimensionNumbers(
        update_window_dims=(3, ),
        inserted_window_dims=(0, 1, 2),
        scatter_dims_to_operand_dims=(0, 1, 2))

    if self.weight_clipping is None:
        # No clipping: accumulate the deltas into the selected vectors.
        params['weights'] = lax.scatter_add(
            operand=params['weights'],
            scatter_indices=context_index,
            updates=delta,
            dimension_numbers=scatter_dims)
    else:
        # Clipping: write back the clipped, updated copies of the gathered
        # vectors (same positional-bounds form as above).
        weights = jnp.clip(weights + delta,
                           -self.weight_clipping,
                           self.weight_clipping)
        params['weights'] = lax.scatter(
            operand=params['weights'],
            scatter_indices=context_index,
            updates=weights,
            dimension_numbers=scatter_dims)

    return params, jnp.squeeze(output_logits, axis=-1)
StaticArg(axis)]) for indices in [ # Ensure each set of indices has a distinct shape np.array(2, dtype=np.int32), np.array([2], dtype=np.int32), np.array([2, 4], dtype=np.int32), np.array([[2, 4], [5, 6]], dtype=np.int32), np.array([0, 1, 10], dtype=np.int32), # Index out of bounds np.array([0, 1, 2, -1], dtype=np.int32), # Index out of bounds ] for axis in [0, 1, 2]] + # Directly from lax.gather in lax_test.py. [Harness( f"_shape={shape}_idxs_shape={idxs.shape}_dnums={dnums}_slice_sizes={slice_sizes}", lambda op, idxs, dnums, slice_sizes: lax.gather(op, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes), [RandArg(shape, np.float32), idxs, StaticArg(dnums), StaticArg(slice_sizes)]) for shape, idxs, dnums, slice_sizes in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
# Demo: pick one element along the last axis of a 3-D array, first with
# NumPy-style advanced indexing and then with several `gather` configurations.
shape = (5, 4, 3)
# presumably numpy.random.randn / numpy.random.choice -- confirm against the
# file's imports. z holds one index into the last axis per leading position.
x = randn(*shape)
z = choice(shape[-1], size=shape[:-1])
# Reference result: y[i, j] = x[i, j, z[i, j]].
y = x[np.arange(shape[0])[:, None], np.arange(shape[1]), z, ]
print("Correct answer")
print(y)
print("Wrong shapes")
# GatherDimensionNumbers is called positionally here; presumably the order is
# (offset_dims, collapsed_slice_dims, start_index_map) -- confirm.
# Only the resulting shapes are printed for the next two configurations.
print(
    gather(x, z[:, :, None], GatherDimensionNumbers((
        0,
        2,
    ), (1, ), (2, )), (1, 1, 1)).shape)
print(
    gather(x, z[:, :, None], GatherDimensionNumbers((
        0,
        1,
    ), (2, ), (2, )), (1, 1, 1)).shape)
print("Right shape and answer with gather:")
# NOTE(review): with start_index_map=(2, ) only axis 2 appears to be indexed
# by z, which would leave axes 0 and 1 starting at 0 -- verify that the
# printed values really match y and not just its shape.
print(
    gather(x, z[:, :, None], GatherDimensionNumbers((2, ), (
        0,
        1,
    ), (2, )), (1, 1, 1)).squeeze(-1), )
print("Totally wrong answer with incorrect indexing")