import jax.numpy as np  # the source uses jax.numpy under the `np` alias
from jax import lax


def copy_values_to_cell(cell_value, value, ids):
  # Scatter each row of `value` into the cell slot named by `ids`.
  # `tiled_size` is a constant from the enclosing scope: the number of ids.
  scatter_indices = np.reshape(ids, (tiled_size, 1))
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=(1,),
      inserted_window_dims=(0,),
      scatter_dims_to_operand_dims=(0,),
  )
  return lax.scatter(cell_value, scatter_indices, value, dnums)
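# A minimal usage sketch, not from the source: the values below are assumed,
# and `tiled_size` is bound at module scope so the free variable inside
# copy_values_to_cell resolves. Each row of `value` overwrites one row of
# `cell_value`.
tiled_size = 4                            # assumed: number of ids
cell_value = np.zeros((8, 3))             # 8 cell slots, 3 features each
value = np.arange(12.0).reshape(4, 3)     # 4 rows to place
ids = np.array([2, 5, 0, 7])              # destination slot per row
new_cell_value = copy_values_to_cell(cell_value, value, ids)
# new_cell_value rows 2, 5, 0, 7 now hold the corresponding rows of `value`.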
from jax import lax
from jax.test_util import check_grads


def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                    rng_factory, rng_idx_factory):
  # Test-class method: checks lax.scatter gradients with respect to both
  # the operand `x` and the updates `y`, in forward and reverse mode,
  # up to second order, against finite differences.
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())
  idxs = rng_idx(idxs.shape, idxs.dtype)
  scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
  x = rng(arg_shape, dtype)
  y = rng(update_shape, dtype)
  check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
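# A standalone sketch of the same check with assumed concrete values, so it
# runs outside the test harness. Scatter with unique indices is linear in
# (x, y), so the finite-difference comparison is exact even with eps=1.
import jax.numpy as jnp

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,), inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
idxs = jnp.array([[0], [2]], dtype=jnp.int32)   # unique row indices
scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
x = jnp.ones((4, 3))
y = jnp.full((2, 3), 0.5)
check_grads(scatter, (x, y), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2, eps=1.)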
def copy_values_from_cell(value, cell_value, cell_id):
  # Inverse of copy_values_to_cell: flatten the per-cell data of shape
  # (cells, capacity, dim) and scatter the rows back into the flat `value`
  # array. `_cells_per_iter`, `cell_capacity`, and `output_dimension` are
  # constants from the enclosing scope.
  scatter_indices = np.reshape(cell_id, (_cells_per_iter * cell_capacity, 1))
  cell_value = np.reshape(
      cell_value, (_cells_per_iter * cell_capacity, output_dimension))
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=(1,),
      inserted_window_dims=(0,),
      scatter_dims_to_operand_dims=(0,),
  )
  return lax.scatter(value, scatter_indices, cell_value, dnums)
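# A hypothetical round-trip sketch, not from the source: the three constants
# are assumed and bound at module scope so the free variables above resolve.
_cells_per_iter = 2
cell_capacity = 3
output_dimension = 4

value = np.zeros((10, output_dimension))        # flat destination array
cell_value = np.ones(
    (_cells_per_iter, cell_capacity, output_dimension))
cell_id = np.arange(_cells_per_iter * cell_capacity).reshape(
    _cells_per_iter, cell_capacity)             # destination row per entry
value = copy_values_from_cell(value, cell_value, cell_id)
# Rows 0..5 of `value` are now the flattened rows of `cell_value`.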
from jax import lax, nn as jnn, numpy as jnp, scipy as jsp


def predict(self, params, logits, context, target=None):
  # Halfspace gating: the sign of (context_maps . context - bias) gives one
  # bit per context map; the bits combine into an integer context index in
  # [0, 2**context_map_size) for every (class, neuron) pair.
  context = jnp.expand_dims(
      jnp.expand_dims(jnp.expand_dims(context, axis=1), axis=1), axis=1)
  context_bias = params.get('context_bias', 0.0)
  context_index = (params['context_maps'] * context).sum(axis=-1) > context_bias
  context_map_values = jnp.asarray(
      [[[[1 << n for n in range(self.context_map_size)]]]])
  context_index = jnp.where(context_index, context_map_values, 0)
  context_index = context_index.sum(axis=-1, keepdims=True)

  # Build full (batch, class, neuron, 3) indices: class id, neuron id, and
  # the context index selected above.
  batch_size = logits.shape[0]
  class_neuron_index = jnp.asarray(
      [[[[c, n] for n in range(self.size)] for c in range(self.num_classes)]])
  class_neuron_index = jnp.tile(class_neuron_index, reps=(batch_size, 1, 1, 1))
  context_index = jnp.concatenate([class_neuron_index, context_index], axis=-1)

  # Gather each neuron's weight vector for its active context.
  dims = lax.GatherDimensionNumbers(
      offset_dims=(3,), collapsed_slice_dims=(0, 1, 2),
      start_index_map=(0, 1, 2))
  weights = lax.gather(
      operand=params['weights'], start_indices=context_index,
      dimension_numbers=dims,
      slice_sizes=(1, 1, 1, self.input_size + int(self.bias)))

  if self.bias:
    bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
    logits = jnp.concatenate([logits, bias], axis=-1)
  logits = jnp.expand_dims(logits, axis=-1)

  # Each neuron emits a linear combination of the input logits, clipped to
  # keep predictions inside [pred_clipping, 1 - pred_clipping].
  output_logits = jnp.matmul(weights, logits)
  output_logits = jnp.clip(output_logits,
                           jsp.special.logit(self.pred_clipping),
                           jsp.special.logit(1.0 - self.pred_clipping))

  if target is None:
    return jnp.squeeze(output_logits, axis=-1)

  # Online update: move each active weight vector by the per-neuron logistic
  # loss gradient, then scatter the result back into the weight table.
  logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
  output_preds = jnn.sigmoid(output_logits)
  target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
  params['lr_step'], learning_rate = self.learning_rate.value(
      params['lr_step'])
  delta = learning_rate * (target - output_preds) * logits

  dims = lax.ScatterDimensionNumbers(
      update_window_dims=(3,), inserted_window_dims=(0, 1, 2),
      scatter_dims_to_operand_dims=(0, 1, 2))
  if self.weight_clipping is None:
    params['weights'] = lax.scatter_add(
        operand=params['weights'], scatter_indices=context_index,
        updates=delta, dimension_numbers=dims)
  else:
    weights = jnp.clip(weights + delta,
                       -self.weight_clipping, self.weight_clipping)
    params['weights'] = lax.scatter(
        operand=params['weights'], scatter_indices=context_index,
        updates=weights, dimension_numbers=dims)
  return params, jnp.squeeze(output_logits, axis=-1)
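# A minimal sketch of the gather/scatter indexing pattern predict relies on,
# with assumed shapes (none of these values come from the source): pick one
# weight vector per (batch, class, neuron) triple, then add an update back.
num_classes, size, num_contexts, input_size = 2, 3, 4, 5
weights = jnp.ones((num_classes, size, num_contexts, input_size))

batch_size = 2
# (batch, class, neuron, 3) indices holding [class, neuron, context].
class_neuron = jnp.asarray(
    [[[[c, n] for n in range(size)] for c in range(num_classes)]])
class_neuron = jnp.tile(class_neuron, reps=(batch_size, 1, 1, 1))
context = jnp.zeros((batch_size, num_classes, size, 1), dtype=jnp.int32)
indices = jnp.concatenate([class_neuron, context], axis=-1)

gdims = lax.GatherDimensionNumbers(
    offset_dims=(3,), collapsed_slice_dims=(0, 1, 2),
    start_index_map=(0, 1, 2))
picked = lax.gather(weights, indices, gdims,
                    slice_sizes=(1, 1, 1, input_size))
# picked: (batch_size, num_classes, size, input_size), one weight vector
# per index triple.

sdims = lax.ScatterDimensionNumbers(
    update_window_dims=(3,), inserted_window_dims=(0, 1, 2),
    scatter_dims_to_operand_dims=(0, 1, 2))
# scatter_add sums contributions when several batch rows hit the same
# (class, neuron, context) slot, which is the behavior the update needs.
weights = lax.scatter_add(weights, indices, 0.1 * picked, sdims)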