示例#1
0
 def __call__(self, sequences):
     """Embed a collection of sequences given either as ints or as strings.

     Only the first element is inspected to decide the input kind, so
     `sequences` must be indexable and non-empty. When that element is a
     `str`, all sequences are first converted with
     `_encode_string_sequences`, using this object's `_domain` and
     `_length`.

     Args:
       sequences: Indexable collection of sequences; string entries are
         encoded before being embedded.

     Returns:
       The result of `utils.batch_apply` running `self._embed_fn` over the
       (possibly encoded) sequences with `self._batch_size`.
     """
     inputs = sequences
     given_as_strings = isinstance(sequences[0], str)
     if given_as_strings:
         # String inputs are encoded before being handed to the embed fn.
         inputs = _encode_string_sequences(
             sequences, domain=self._domain, length=self._length)
     return utils.batch_apply(self._embed_fn, inputs, self._batch_size)
示例#2
0
  def test_batch_apply(self, batch_size, num_inputs):
    """Checks that `utils.batch_apply` matches applying the function directly.

    The function passed to `utils.batch_apply` raises a `ValueError` for any
    batch whose length differs from `batch_size`, so the comparison is only
    reached when every batch it receives has exactly that size.

    Args:
      batch_size: Batch size passed to `utils.batch_apply`.
      num_inputs: Number of input rows to generate.
    """
    def square_of_successor(values):
      return np.power(values + 1, 2)

    def strict_batch_fn(batch):
      actual_size = len(batch)
      if actual_size == batch_size:
        return square_of_successor(batch)
      raise ValueError('fn() called with a batch that is '
                       'the wrong size (%d vs. %d).' % (actual_size,
                                                        batch_size))

    indices = np.arange(num_inputs)
    # Two columns per row: each index alongside its negation.
    inputs = np.stack([indices, -indices], axis=1)
    expected = square_of_successor(inputs)
    actual = utils.batch_apply(strict_batch_fn, inputs, batch_size)
    np.testing.assert_array_equal(expected, actual)