def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3),
                 per_device_batch_size=2):
    """Make trainer_lib.inputs.Inputs that yield random classification batches.

    Args:
        n_classes: passed as `maxval` when drawing the integer targets
            (`minval` is 0).
        with_weights: if True each batch is `(inputs, targets, weights)`,
            otherwise it is `(inputs, targets)`.
        input_shape: shape of a single example; the batch dimension is
            prepended to it.
        per_device_batch_size: number of examples per device. The total batch
            size is this value times `xla_bridge.device_count()`. Defaults
            to 2, the value that used to be hard-coded.

    Returns:
        An `inputs_lib.Inputs` built from a stream function whose output is
        passed through `inputs_lib.add_loss_weights`.
    """
    batch_size = per_device_batch_size * xla_bridge.device_count()
    # Loop-invariant, so build the full batch shape once.
    inputs_shape = [batch_size] + list(input_shape)

    def input_stream(n_devices):
        del n_devices  # The stream does not depend on the device count.
        key = fastmath.random.get_prng(0)
        while True:
            # keys[0] becomes the key for the next step; keys[1:] are consumed
            # by the three draws below.
            keys = fastmath.random.split(key, 4)
            key = keys[0]
            inputs = fastmath.random.uniform(keys[1], inputs_shape)
            targets = fastmath.random.randint(
                keys[2], [batch_size], dtype=jnp.int32, minval=0,
                maxval=n_classes)
            # Drawn on every step, even when `with_weights` is False, so the
            # sequence of random calls stays exactly as it was.
            weights = fastmath.random.uniform(keys[3], [batch_size])
            if with_weights:
                yield inputs, targets, weights
            else:
                yield inputs, targets

    def input_stream_masked(n_devices):
        return inputs_lib.add_loss_weights(input_stream(n_devices))

    return inputs_lib.Inputs(input_stream_masked)
def random_inputs(seq_len, batch_size):
    """Build Inputs whose train and eval streams yield random batches.

    Every batch is the tuple `(x, y, x, mask)`: `x` and `y` are two separately
    drawn `np.random.uniform` arrays of shape `(batch_size, seq_len)`, and
    `mask` is a float32 array of ones shaped like `x`.

    Args:
        seq_len: second dimension of each generated array.
        batch_size: first dimension of each generated array.

    Returns:
        An `inputs.Inputs` using the same endless stream function for both
        `train_stream` and `eval_stream`.
    """
    batch_shape = (batch_size, seq_len)

    def stream_fn(num_devices):
        del num_devices  # The stream does not depend on the device count.
        while True:
            xs = np.random.uniform(size=batch_shape)
            ys = np.random.uniform(size=batch_shape)
            ones_mask = np.ones_like(xs).astype(np.float32)
            yield (xs, ys, xs, ones_mask)

    return inputs.Inputs(train_stream=stream_fn, eval_stream=stream_fn)
def signal_inputs(seq_len, batch_size, depth=1):
    """Build Inputs whose train and eval streams yield batched signals.

    The stream batches the output of `generate_signals(seq_len=..., depth=...)`
    with `batch_stream(..., batch_size=...)`. Each `(x, y)` pair it produces is
    emitted as `(x, y, x, y, mask)`, where `mask` is a float32 array of ones
    shaped like `x`.

    Args:
        seq_len: forwarded to `generate_signals`.
        batch_size: forwarded to `batch_stream`.
        depth: forwarded to `generate_signals`.

    Returns:
        A `trax_input.Inputs` using the same stream function for both
        `train_stream` and `eval_stream`.
    """

    def stream_fn(num_devices):
        del num_devices  # The stream does not depend on the device count.
        signals = generate_signals(seq_len=seq_len, depth=depth)
        for batch_x, batch_y in batch_stream(signals, batch_size=batch_size):
            ones_mask = np.ones_like(batch_x).astype(np.float32)
            # Layout: (input_x, input_y, target_x, target_y, mask)
            yield (batch_x, batch_y, batch_x, batch_y, ones_mask)

    return trax_input.Inputs(train_stream=stream_fn, eval_stream=stream_fn)
def _test_inputs_lm(vocab_size, seq_len, per_device_batch_size=2):
    """Make trainer_lib.inputs.Inputs for language model.

    Each batch is `(inputs, targets)`: two arrays drawn separately with
    `fastmath.random.randint`, each requested with shape
    `[batch_size, seq_len]`, dtype `jnp.int32`, `minval=0` and
    `maxval=vocab_size`.

    Args:
        vocab_size: passed as `maxval` for the random integer draws.
        seq_len: second dimension of every generated array.
        per_device_batch_size: number of examples per device. The total batch
            size is this value times `xla_bridge.device_count()`.

    Returns:
        An `inputs_lib.Inputs` built from a stream function whose output is
        passed through `inputs_lib.add_loss_weights`.
    """
    batch_size = per_device_batch_size * xla_bridge.device_count()

    def random_tokens(rng):
        # One [batch_size, seq_len] array of random integers.
        return fastmath.random.randint(
            rng,
            [batch_size, seq_len],
            dtype=jnp.int32,
            minval=0,
            maxval=vocab_size,
        )

    def input_stream(_):
        rng = fastmath.random.get_prng(0)
        while True:
            # subkeys[0] is carried to the next step; the other two feed the
            # input and target draws.
            subkeys = fastmath.random.split(rng, 3)
            rng = subkeys[0]
            token_inputs = random_tokens(subkeys[1])
            token_targets = random_tokens(subkeys[2])
            yield token_inputs, token_targets

    def input_stream_masked(n_devices):
        raw_stream = input_stream(n_devices)
        return inputs_lib.add_loss_weights(raw_stream)

    return inputs_lib.Inputs(input_stream_masked)