def test_predict_minimum_training():
    # Out of 5 numbers, predict the index of the minimum.
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(
        size, [size], embedding_size, [n_k], concat_self=False
    )
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    # torch.arange replaces the deprecated torch.range and yields the indices 0..n_k-1.
    onehots = ModelUtils.actions_to_onehot(torch.arange(0, n_k).unsqueeze(1), [n_k])[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # Create the target: the index of the minimum value.
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        # Append a one-hot encoding of each entity's index so the model can report positions.
        sliced_oh = onehots[:, : num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embeddings(inp, [inp])
        masks = EntityEmbeddings.get_masks([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1


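# Not part of the original suite: a minimal sketch of the input layout used in
# test_predict_minimum_training, with the one-hot index encoding built from plain
# torch instead of ModelUtils.actions_to_onehot. Each entity carries its random
# value plus a one-hot of its own position, so the network can name the argmin.
def demo_minimum_input_layout():
    n_k, batch_size = 5, 2
    onehots = torch.nn.functional.one_hot(torch.arange(n_k), n_k).float()
    onehots = onehots.expand((batch_size, -1, -1))  # (batch, n_k, n_k)
    values = torch.rand((batch_size, n_k, 1))
    inp = torch.cat([values, onehots], dim=2)  # (batch, n_k, 1 + n_k)
    assert inp.shape == (batch_size, n_k, 1 + n_k)
    # The cross-entropy target is the index of the smallest value per sample.
    assert torch.argmin(values, dim=1).squeeze(1).shape == (batch_size,)

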
def test_simple_transformer_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    # Note: only the transformer and the output layer are optimized here;
    # the entity embeddings keep their initial weights.
    optimizer = torch.optim.Adam(
        list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        # Sample the query and the keys uniformly in [-point_range, point_range].
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance.
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        # The error should decrease relative to the first step.
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3


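# An aside, not in the original file: the per-sample Python loop above that picks
# the closest key can be replaced by a single torch.gather over the entity
# dimension. This demo checks that both constructions agree.
def demo_vectorized_closest_key_target():
    batch_size, n_k, size = 4, 5, 3
    center = torch.rand((batch_size, size))
    key = torch.rand((batch_size, n_k, size))
    distance = torch.sum((center.reshape((batch_size, 1, size)) - key) ** 2, dim=2)
    argmin = torch.argmin(distance, dim=1)  # (batch,)
    # Loop version, as written in the tests above.
    looped = torch.stack([key[i, argmin[i], :] for i in range(batch_size)], dim=0)
    # Vectorized version: expand the indices to the feature size, then gather.
    index = argmin.reshape((batch_size, 1, 1)).expand((-1, -1, size))
    gathered = torch.gather(key, dim=1, index=index).squeeze(1)
    assert torch.allclose(looped, gathered)

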
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance.
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02


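# A sketch of the masking convention these tests assume: get_zero_entities_mask is
# expected to flag all-zero (padding) entities with 1 and real entities with 0. The
# local computation below illustrates that contract; it is not the library code.
def demo_zero_entity_masking():
    key = torch.ones((2, 3, 4))
    key[0, 1] = 0.0  # this entity is padding and should be masked
    masks = [(torch.sum(k ** 2, dim=2) < 0.01).float() for k in [key]]
    assert masks[0][0, 1] == 1.0
    assert masks[0][0, 0] == 0.0

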
@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
    # Make sure that a mask of all zeros or all ones does not trigger an error.
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance.
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        # Force every entity to share the same mask value (all masked or all visible).
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()


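# Why test_all_masking matters, as a standalone sketch: attention masks are commonly
# applied by pushing masked logits toward -inf before the softmax, and a naive
# implementation returns NaN when every entity is masked. The guard shown here is
# illustrative only and is not claimed to be what ResidualSelfAttention does.
def demo_fully_masked_softmax():
    logits = torch.zeros((1, 3))
    mask = torch.ones((1, 3))  # 1 = masked, matching the convention above
    naive = torch.softmax(logits.masked_fill(mask.bool(), float("-inf")), dim=-1)
    assert torch.isnan(naive).all()  # a row of all -inf softmaxes to NaN
    # A large-but-finite offset keeps the row well defined (uniform) instead.
    safe = torch.softmax(logits - 1e9 * mask, dim=-1)
    assert torch.allclose(safe, torch.full((1, 3), 1.0 / 3))

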
def test_multi_head_attention_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        # Sample the query and the keys uniformly in [-point_range, point_range].
        query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance.
            distance = torch.sum((query - key) ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        # The error should decrease relative to the first step.
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5


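# For reference, a single-head reduction of what MultiHeadAttention computes, in
# plain tensor ops: scaled dot-product attention. The learned q/k/v projections the
# layer applies internally are omitted, so this is a sketch of the core math only.
def demo_scaled_dot_product_attention():
    batch_size, n_q, n_k, size = 2, 1, 5, 3
    query = torch.rand((batch_size, n_q, size))
    key = torch.rand((batch_size, n_k, size))
    value = key
    scores = query @ key.transpose(1, 2) / (size ** 0.5)  # (batch, n_q, n_k)
    weights = torch.softmax(scores, dim=-1)  # rows sum to 1
    output = weights @ value  # (batch, n_q, size): a convex mix of the keys
    assert output.shape == (batch_size, n_q, size)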