def Bidirectional(forward_layer, axis=1, merge_layer=None):
  """Bidirectional combinator for RNNs.

  Runs `forward_layer` on the input as-is and a deep copy of it on the
  input flipped along the time axis (flipping back afterwards), then merges
  the two outputs with `merge_layer`.

  Args:
    forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`.
    axis: a time axis of the inputs. Default value is `1`.
    merge_layer: A combinator used to combine outputs of the forward
        and backward RNNs. Default value is `trax.layers.Concatenate`
        (instantiated per call; pass `None` to get the default).

  Example:
      Bidirectional(RNN(n_units=8))

  Returns:
    The Bidirectional combinator for RNNs.
  """
  # Avoid an evaluated-once default argument: `merge_layer=Concatenate()`
  # would build a single layer instance at import time and share it across
  # every Bidirectional model ever constructed.
  if merge_layer is None:
    merge_layer = Concatenate()
  # The backward branch must have its own weights, hence the deep copy.
  backward_layer = copy.deepcopy(forward_layer)
  flip = base.Fn('_FlipAlongTimeAxis', lambda x: jnp.flip(x, axis=axis))
  # Flip time, run the RNN, flip back so outputs align with the forward pass.
  backward = Serial(
      flip,
      backward_layer,
      flip,
  )
  return Serial(
      Branch(forward_layer, backward),
      merge_layer,
  )
def random_minibatches(length_list):
  """Generate a stream of random mini-batches for a copy/reverse task.

  For every batch, an even total `length` is drawn from `length_list`. A
  random integer payload `w` of width `length // 2 - 1` is laid out as
  `[0, w, 0, w]` (or `[0, w, 0, flip(w)]` when reversing), so the model
  sees the payload, a zero separator, and must reproduce it. Loss weights
  are zero over the prompt (`w_length + 2` columns) and one over the part
  to be reproduced.

  NOTE(review): relies on enclosing-scope names `vocab_size`, `batch_size`,
  `reverse`, `pad_to_multiple` and `_pad_to_multiple_of` — confirm against
  the outer function.

  Args:
    length_list: sequence of even ints; a total sequence length is drawn
        uniformly (via `random.choice`) from it for every batch.

  Yields:
    Tuples `(x, x, loss_weights)` — inputs and targets are the same.
  """
  while True:
    length = random.choice(length_list)
    assert length % 2 == 0
    w_length = (length // 2) - 1
    w = np.random.randint(low=1, high=vocab_size - 1,
                          size=(batch_size, w_length))
    zero = np.zeros([batch_size, 1], np.int32)
    loss_weights = np.concatenate([np.zeros((batch_size, w_length + 2)),
                                   np.ones((batch_size, w_length))], axis=1)
    if reverse:
      # Use np.flip (not jnp.flip): `w` is a plain numpy array, and routing
      # it through JAX here would silently create a device array only for
      # np.concatenate to copy it straight back to host memory.
      x = np.concatenate([zero, w, zero, np.flip(w, axis=1)], axis=1)
    else:
      x = np.concatenate([zero, w, zero, w], axis=1)
    x = _pad_to_multiple_of(x, pad_to_multiple, 1)
    loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.