def test_linear_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = RecurrentTransformerEncoder([
        RecurrentTransformerEncoderLayer(
            RecurrentAttentionLayer(
                RecurrentLinearAttention(),
                d_model,
                n_heads
            ),
            d_model,
            n_heads
        )
        for i in range(6)
    ])

    xs = []
    memory = None
    for i in range(7):
        x, memory = transformer(torch.rand(10, d_model), state=memory)
        xs.append(x)

    for i in range(7):
        self.assertEqual(xs[i].shape, (10, d_model))
    self.assertEqual(len(memory), 6)
    for i in range(6):
        self.assertEqual(len(memory[i]), 2)
        self.assertEqual(memory[i][0].shape, (10, n_heads, 32, 32))
        self.assertEqual(memory[i][1].shape, (10, n_heads, 32))
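# Not part of the test above: a minimal sketch of the per-step linear attention
# recurrence that produces the state shapes asserted there, i.e. an (N, H, D, D)
# running matrix and an (N, H, D) normalizer with D = d_model // n_heads = 32.
# The elu(x) + 1 feature map is assumed; the names here are illustrative only.
def linear_attention_step(q, k, v, s, z, eps=1e-6):
    """One recurrent step; q, k, v are (N, H, D) tensors for a single position."""
    phi_q = torch.nn.functional.elu(q) + 1
    phi_k = torch.nn.functional.elu(k) + 1
    s = s + torch.einsum("nhd,nhm->nhdm", phi_k, v)   # accumulate phi(k) v^T
    z = z + phi_k                                     # accumulate phi(k)
    out = torch.einsum("nhd,nhdm->nhm", phi_q, s)
    out = out / (torch.einsum("nhd,nhd->nh", phi_q, z)[..., None] + eps)
    return out, (s, z)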
def test_forward(self):
    att = RecurrentAttentionLayer(
        self._assert_sizes_attention((10, 4, 25), (10, 4, 25), (10, 4, 25)),
        100,
        4
    )
    v, m = att(torch.rand(10, 100), torch.rand(10, 100), torch.rand(10, 100),
               "test memory")
    self.assertEqual(v.shape, (10, 100))
    self.assertEqual(m, "test memory")

    att = RecurrentAttentionLayer(
        self._assert_sizes_attention((10, 4, 32), (10, 4, 32), (10, 4, 64)),
        100,
        4,
        d_keys=32,
        d_values=64
    )
    v, m = att(torch.rand(10, 100), torch.rand(10, 100), torch.rand(10, 100),
               "test memory")
    self.assertEqual(v.shape, (10, 100))
    self.assertEqual(m, "test memory")
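# The test above relies on a `_assert_sizes_attention` helper that is not shown
# in this excerpt. A minimal sketch of what such a helper could look like,
# assuming the wrapped inner attention is called as attention(query, key, value,
# memory) and must return (new_values, new_memory); the exact call signature may
# differ in the real test suite.
def _assert_sizes_attention(self, query_shape, key_shape, value_shape):
    def inner(query, key, value, memory):
        self.assertEqual(tuple(query.shape), query_shape)
        self.assertEqual(tuple(key.shape), key_shape)
        self.assertEqual(tuple(value.shape), value_shape)
        # Return the values and memory unchanged so the outer layer can project
        # the values back to (N, d_model) and pass the memory through verbatim.
        return value, memory
    return inner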
def test_compare_with_batch(self):
    N = 10
    L = 42
    S = 100
    D = 1024
    E = D // 4
    x = torch.rand(N, L, D)
    m = torch.rand(N, S, D)
    tests = [
        ("full", FullAttention, FullAttention,
         RecurrentFullAttention, RecurrentCrossFullAttention),
        ("linear", partial(CausalLinearAttention, E), partial(LinearAttention, E),
         partial(RecurrentLinearAttention, E), partial(RecurrentCrossLinearAttention, E))
    ]

    for name, a1, a2, a3, a4 in tests:
        dec = TransformerDecoder([
            TransformerDecoderLayer(
                AttentionLayer(a1(), D, 4),
                AttentionLayer(a2(), D, 4),
                D
            )
            for i in range(4)
        ])
        rdec = RecurrentTransformerDecoder([
            RecurrentTransformerDecoderLayer(
                RecurrentAttentionLayer(a3(), D, 4),
                RecurrentCrossAttentionLayer(a4(), D, 4),
                D
            )
            for i in range(4)
        ])
        dec.eval()
        rdec.eval()
        rdec.load_state_dict(dec.state_dict())

        x_mask = TriangularCausalMask(L)
        x_length = LengthMask(torch.full((N,), L, dtype=torch.int64))
        m_mask = FullMask(L, S)
        m_length = LengthMask(torch.full((N,), S, dtype=torch.int64))

        y1 = dec(x, m, x_mask=x_mask, x_length_mask=x_length,
                 memory_mask=m_mask, memory_length_mask=m_length)

        state = None
        y2 = []
        for i in range(L):
            y2i, state = rdec(x[:, i], m, memory_length_mask=m_length, state=state)
            y2.append(y2i)
        y2 = torch.stack(y2, dim=1)

        self.assertLess(torch.abs(y1 - y2).max(), 1e-5)
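# Sketch (not part of the tests): the same recurrent decoder can be driven
# autoregressively at inference time by feeding each output back in as the next
# input, which is the use case the equivalence check above motivates. `embed`
# and `to_logits` are hypothetical stand-ins for a real embedding layer and
# output head.
def generate(rdec, memory, start, steps, embed, to_logits):
    state = None
    x = start                    # (N, D) representation of the first position
    outputs = []
    for _ in range(steps):
        y, state = rdec(x, memory, state=state)  # one O(1) decoding step
        outputs.append(y)
        x = embed(to_logits(y).argmax(dim=-1))   # greedy feedback of the prediction
    return torch.stack(outputs, dim=1)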
def test_mask_creation(self):
    N = 10
    L = 42
    S = 100
    D = 1024
    x = torch.rand(N, D)
    m = torch.rand(N, S, D)
    rdec = RecurrentTransformerDecoder([
        RecurrentTransformerDecoderLayer(
            RecurrentAttentionLayer(RecurrentFullAttention(), D, 4),
            RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), D, 4),
            D
        )
        for i in range(4)
    ])
    rdec(x, m)
def __init__(self, alphabet_size, hidden_size, num_window_components, num_mixture_components):
    super(HandwritingGenerator, self).__init__()
    self.alphabet_size = alphabet_size
    self.hidden_size = hidden_size
    self.num_window_components = num_window_components
    self.num_mixture_components = num_mixture_components

    self.input_size = input_size = 3
    n_heads_1 = 2
    n_heads_2 = 10
    query_dimensions = 1
    self.n_pre_layers = 2
    self.n_layers = 4

    # First LSTM layer, takes as input a tuple (x, y, eol)
    self.lstm1_layer = LSTM(input_size=input_size, hidden_size=hidden_size,
                            batch_first=True)
    # Alternative pre-layers that were tried instead of the first LSTM:
    # self.transformers1_layers = [
    #     RecurrentTransformerEncoderLayer(
    #         RecurrentAttentionLayer(RecurrentLinearAttention(query_dimensions), input_size, n_heads_1),
    #         input_size,
    #         hidden_size,
    #         activation="gelu",
    #     ) for l in range(self.n_pre_layers)
    # ]
    # self.norm1_layer = torch.nn.Linear(input_size, hidden_size)

    # Gaussian Window layer
    self.window_layer = GaussianWindow(
        input_size=hidden_size, num_components=num_window_components
    )

    # Recurrent transformer stack, replacing the second LSTM layer; it takes as
    # input the concatenation of the input, the output of the first LSTM layer
    # and the output of the Window layer.
    # self.lstm2_layer = LSTM(
    #     input_size=3 + hidden_size + alphabet_size + 1,
    #     hidden_size=hidden_size,
    #     batch_first=True,
    # )
    # Registered as a ModuleList so the layer parameters are tracked by the module.
    self.transformers2_layers = torch.nn.ModuleList([
        RecurrentTransformerEncoderLayer(
            RecurrentAttentionLayer(
                RecurrentLinearAttention(query_dimensions),
                3 + hidden_size + alphabet_size + 1,
                n_heads_2,
            ),
            3 + hidden_size + alphabet_size + 1,
            # RecurrentAttentionLayer(RecurrentLinearAttention(query_dimensions), hidden_size, n_heads_2),
            # hidden_size,
            hidden_size,
            activation="gelu",
        )
        for l in range(self.n_layers)
    ])

    # Third LSTM layer (unused), took as input the concatenation of the outputs
    # of the first and second LSTM layers and the output of the Window layer.
    # self.lstm3_layer = LSTM(
    #     input_size=hidden_size, hidden_size=hidden_size, batch_first=True
    # )

    self.norm2_layer = torch.nn.LayerNorm(3 + hidden_size + alphabet_size + 1)

    # Mixture Density Network layer
    self.output_layer = MDN(
        input_size=3 + hidden_size + alphabet_size + 1,
        num_mixtures=num_mixture_components,
    )

    # Hidden state variables
    self.prev_kappa = None
    self.hidden1 = None
    self.hidden2 = [None] * self.n_layers

    # Initialize parameters
    self.reset_parameters()
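# Sketch (not from the original model): how the per-layer states stored in
# self.hidden2 might be threaded through self.transformers2_layers for a single
# time step, assuming each recurrent layer is called as layer(x, state=...) and
# returns (x, new_state). The real forward() of HandwritingGenerator may differ.
def _transformer_step(self, x_t):
    # x_t: (batch, 3 + hidden_size + alphabet_size + 1) features for one step
    for i, layer in enumerate(self.transformers2_layers):
        x_t, self.hidden2[i] = layer(x_t, state=self.hidden2[i])
    return self.norm2_layer(x_t)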