Пример #1
0
def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3),
                 per_device_batch_size=2):
    """Make trainer_lib.inputs.Inputs with random classification batches.

    Each stream created from the returned Inputs starts from the same PRNG
    seed (0), so every new stream replays the same sequence of batches.

    Args:
      n_classes: Targets are int32 values drawn with ``minval=0`` and
        ``maxval=n_classes`` (maxval presumably exclusive, as in
        ``jax.random.randint``).
      with_weights: If True, each raw batch is ``(inputs, targets, weights)``
        with uniform random per-example weights; otherwise each raw batch is
        ``(inputs, targets)``.
      input_shape: Shape of a single example; a batch of inputs has shape
        ``[batch_size] + list(input_shape)``.
      per_device_batch_size: Examples per device. The global batch size is
        this value times ``xla_bridge.device_count()``. Defaults to 2, the
        value that was previously hard-coded.

    Returns:
      An ``inputs_lib.Inputs`` whose stream is the raw stream passed through
      ``inputs_lib.add_loss_weights``.
    """
    batch_size = per_device_batch_size * xla_bridge.device_count()
    batch_shape = [batch_size] + list(input_shape)

    def input_stream(n_devices):
        del n_devices  # Batch size is already fixed from the device count.
        key = fastmath.random.get_prng(0)
        while True:
            # Always split into 4 keys so the key sequence, and therefore the
            # inputs/targets, is identical whether or not weights are used.
            keys = fastmath.random.split(key, 4)
            key = keys[0]
            inputs = fastmath.random.uniform(keys[1], batch_shape)
            targets = fastmath.random.randint(keys[2], [batch_size],
                                              dtype=jnp.int32,
                                              minval=0,
                                              maxval=n_classes)
            if with_weights:
                # Draw weights only when they are actually yielded.
                weights = fastmath.random.uniform(keys[3], [batch_size])
                yield inputs, targets, weights
            else:
                yield inputs, targets

    def input_stream_masked(n_devices):
        return inputs_lib.add_loss_weights(input_stream(n_devices))

    return inputs_lib.Inputs(input_stream_masked)
Пример #2
0
def random_inputs(seq_len, batch_size):
    """Build train/eval Inputs that emit random float batches forever.

    Each batch is the 4-tuple ``(x, y, x, mask)``:
    - ``x`` and ``y`` are independent ``np.random.uniform`` draws of shape
      ``(batch_size, seq_len)``.
    - ``mask`` is an all-ones float32 array of that same shape.

    The same stream function serves training and evaluation. It draws from
    NumPy's global random state, which is not seeded here.
    """
    shape = (batch_size, seq_len)

    def stream_fn(num_devices):
        del num_devices  # The batch shape does not depend on the device count.
        while True:
            x_batch = np.random.uniform(size=shape)
            y_batch = np.random.uniform(size=shape)
            ones_mask = np.ones(shape, dtype=np.float32)
            yield (x_batch, y_batch, x_batch, ones_mask)

    return inputs.Inputs(train_stream=stream_fn, eval_stream=stream_fn)
Пример #3
0
def signal_inputs(seq_len, batch_size, depth=1):
    """Build train/eval Inputs from batched generated signals.

    Batches come from ``batch_stream(generate_signals(...), batch_size=...)``.
    Each ``(x, y)`` pair is re-emitted as the 5-tuple
    ``(input_x, input_y, target_x, target_y, mask)`` = ``(x, y, x, y, mask)``,
    where ``mask`` is an all-ones float32 array shaped like ``x``. The same
    stream function serves training and evaluation.
    """
    def stream_fn(num_devices):
        del num_devices  # Batching here does not depend on the device count.
        signals = generate_signals(seq_len=seq_len, depth=depth)
        for x_batch, y_batch in batch_stream(signals, batch_size=batch_size):
            ones_mask = np.ones_like(x_batch, dtype=np.float32)
            yield (x_batch, y_batch, x_batch, y_batch, ones_mask)

    return trax_input.Inputs(train_stream=stream_fn, eval_stream=stream_fn)
Пример #4
0
def _test_inputs_lm(vocab_size, seq_len, per_device_batch_size=2):
  """Make trainer_lib.inputs.Inputs of random token ids for a language model.

  Raw batches are ``(inputs, targets)``, each an int32 array of shape
  ``[per_device_batch_size * xla_bridge.device_count(), seq_len]``, drawn with
  ``minval=0`` and ``maxval=vocab_size``. Each new stream restarts from PRNG
  seed 0. The raw stream is passed through ``inputs_lib.add_loss_weights``.
  """
  batch_size = per_device_batch_size * xla_bridge.device_count()
  token_shape = [batch_size, seq_len]

  def random_tokens(rng):
    # maxval is presumably exclusive, as in jax.random.randint.
    return fastmath.random.randint(
        rng, token_shape, dtype=jnp.int32, minval=0, maxval=vocab_size)

  def input_stream(_):
    rng = fastmath.random.get_prng(0)
    while True:
      subkeys = fastmath.random.split(rng, 3)
      rng = subkeys[0]
      # Inputs are drawn before targets, matching the key order above.
      yield random_tokens(subkeys[1]), random_tokens(subkeys[2])

  def input_stream_masked(n_devices):
    return inputs_lib.add_loss_weights(input_stream(n_devices))

  return inputs_lib.Inputs(input_stream_masked)