Example #1
def copy_values_to_cell(cell_value, value, ids):
    # Scatter the rows of `value` into `cell_value` at the row positions
    # given by `ids`. `tiled_size` and the `np` alias come from the
    # enclosing scope.
    scatter_indices = np.reshape(ids, (tiled_size, 1))
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter(cell_value, scatter_indices, value, dnums)
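For reference, here is a minimal self-contained sketch of the same ScatterDimensionNumbers pattern, with toy shapes chosen purely for illustration: each row of updates overwrites the row of operand named in scatter_indices.

import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((5, 3))              # destination buffer
updates = jnp.ones((2, 3))               # two rows to write
scatter_indices = jnp.array([[1], [3]])  # target rows 1 and 3
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
result = lax.scatter(operand, scatter_indices, updates, dnums)
# rows 1 and 3 of `result` are now ones, the rest remain zero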
Example #2
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_factory, rng_idx_factory):
    # Second-order forward- and reverse-mode gradient check for lax.scatter
    # (check_grads comes from jax.test_util).
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
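A rough stand-alone instantiation of the parameterized test above; the shapes, indices, and dimension numbers here are arbitrary choices for illustration, not the values used in the JAX test suite.

import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

idxs = jnp.array([[0], [2]], dtype=jnp.int32)
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
x = jnp.arange(12.0).reshape(4, 3)  # operand
y = jnp.ones((2, 3))                # updates written to rows 0 and 2
# scatter is linear in (x, y), so a large finite-difference eps is fine
check_grads(scatter, (x, y), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2, eps=1.)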
Example #3
File: smap.py Project: xcbat/jax-md
def copy_values_from_cell(value, cell_value, cell_id):
    # Counterpart of copy_values_to_cell above: scatter per-cell values back
    # into the flat `value` array. `_cells_per_iter`, `cell_capacity`, and
    # `output_dimension` are closure variables from the enclosing scope.
    scatter_indices = np.reshape(
        cell_id, (_cells_per_iter * cell_capacity, 1))
    cell_value = np.reshape(
        cell_value, (_cells_per_iter * cell_capacity, output_dimension))
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter(value, scatter_indices, cell_value, dnums)
Example #4
    def predict(self, params, logits, context, target=None):
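        # Context gating: broadcast the context against the per-(class, neuron)
        # context maps and reduce the resulting sign pattern to one integer
        # index per neuron.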
        context = jnp.expand_dims(context, axis=(1, 2, 3))
        context_bias = params.get('context_bias', 0.0)
        context_index = (params['context_maps'] *
                         context).sum(axis=-1) > context_bias

        context_map_values = jnp.asarray(
            [[[[1 << n for n in range(self.context_map_size)]]]])
        context_index = jnp.where(context_index, context_map_values, 0)
        context_index = context_index.sum(axis=-1, keepdims=True)

        batch_size = logits.shape[0]
        class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                           for c in range(self.num_classes)]])
        class_neuron_index = jnp.tile(class_neuron_index,
                                      reps=(batch_size, 1, 1, 1))
        context_index = jnp.concatenate([class_neuron_index, context_index],
                                        axis=-1)

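        # Gather one weight vector per (batch, class, neuron) entry, using the
        # integer context index as the third coordinate into the weight table.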
        dims = lax.GatherDimensionNumbers(offset_dims=(3, ),
                                          collapsed_slice_dims=(0, 1, 2),
                                          start_index_map=(0, 1, 2))
        weights = lax.gather(operand=params['weights'],
                             start_indices=context_index,
                             dimension_numbers=dims,
                             slice_sizes=(1, 1, 1,
                                          self.input_size + int(self.bias)))

        if self.bias:
            bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
            logits = jnp.concatenate([logits, bias], axis=-1)
        logits = jnp.expand_dims(logits, axis=-1)

        output_logits = jnp.matmul(weights, logits)
        output_logits = jnp.clip(output_logits,
                                 a_min=jsp.special.logit(self.pred_clipping),
                                 a_max=jsp.special.logit(1.0 -
                                                         self.pred_clipping))

        if target is None:
            return jnp.squeeze(output_logits, axis=-1)

        else:
            logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
            output_preds = jnn.sigmoid(output_logits)
            target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
            params['lr_step'], learning_rate = self.learning_rate.value(
                params['lr_step'])
            delta = learning_rate * (target - output_preds) * logits

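            # Scatter the per-example weight updates back into the same
            # (class, neuron, context) slots of the weight table.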
            dims = lax.ScatterDimensionNumbers(
                update_window_dims=(3, ),
                inserted_window_dims=(0, 1, 2),
                scatter_dims_to_operand_dims=(0, 1, 2))

            if self.weight_clipping is None:
                params['weights'] = lax.scatter_add(
                    operand=params['weights'],
                    scatter_indices=context_index,
                    updates=delta,
                    dimension_numbers=dims)
            else:
                weights = jnp.clip(weights + delta,
                                   a_min=-self.weight_clipping,
                                   a_max=self.weight_clipping)
                params['weights'] = lax.scatter(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=weights,
                                                dimension_numbers=dims)

            return params, jnp.squeeze(output_logits, axis=-1)
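For comparison, a minimal sketch of the gather/scatter_add pairing used above, on a toy weight table; all shapes and index values here are made up for illustration.

import jax.numpy as jnp
from jax import lax

# Toy table with layout (class=2, neuron=3, context=4, input=5).
weights = jnp.zeros((2, 3, 4, 5))

# One (class, neuron, context) triple per (batch, class, neuron) cell;
# batch_size is 1 to keep things small.
context_index = jnp.array([[[[c, n, (c + n) % 4] for n in range(3)]
                            for c in range(2)]])             # (1, 2, 3, 3)

gdims = lax.GatherDimensionNumbers(offset_dims=(3,),
                                   collapsed_slice_dims=(0, 1, 2),
                                   start_index_map=(0, 1, 2))
gathered = lax.gather(weights, context_index, gdims,
                      slice_sizes=(1, 1, 1, 5))               # (1, 2, 3, 5)

sdims = lax.ScatterDimensionNumbers(update_window_dims=(3,),
                                    inserted_window_dims=(0, 1, 2),
                                    scatter_dims_to_operand_dims=(0, 1, 2))
# Add 1.0 to exactly the rows that were gathered above.
updated = lax.scatter_add(weights, context_index,
                          jnp.ones_like(gathered), sdims)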