def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """
    Returns the mean of the tensor but ignoring the values specified by masks.
    Used for masking out loss functions.

    The tensor's dimensions are reversed before the mask is applied, so the
    mask is broadcast against the reversed view (e.g. a mask of shape
    (mask_length,) lines up with the first axis of a (mask_length, N) tensor).
    The divisor is clamped to at least 1.0, so a mask with every entry off
    yields 0 instead of a division by zero.

    :param tensor: Tensor which needs mean computation.
    :param masks: Boolean tensor of masks, broadcastable against the tensor
        with its dimensions reversed.
    :return: Scalar tensor holding the masked sum divided by the number of
        unmasked elements (at least 1).
    """
    # `Tensor.T` is deprecated for tensors whose rank is not 2, so reverse the
    # dimensions explicitly. For rank 0 and 1 the reversal is the identity.
    if tensor.ndim > 1:
        reversed_view = tensor.permute(*range(tensor.ndim - 1, -1, -1))
    else:
        reversed_view = tensor

    masked_total = (reversed_view * masks).sum()
    # Count the elements that survive the mask, after broadcasting.
    kept_count = (torch.ones_like(reversed_view) * masks).float().sum()
    return masked_total / torch.clamp(kept_count, min=1.0)
def test_lstm_layer():
    """
    Checks the parameters of an LSTM layer built with zero initialization for
    both kernel and bias: every weight must be all zeros, and entries 4..8 of
    every bias must be all ones.
    """
    torch.manual_seed(0)
    # Test zero for LSTM
    zero_init_lstm = lstm_layer(
        4, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
    )
    for param_name, parameter in zero_init_lstm.named_parameters():
        if "weight" in param_name:
            weights = parameter.data
            assert torch.all(torch.eq(weights, torch.zeros_like(weights)))
            continue
        if "bias" in param_name:
            # NOTE(review): presumably the forget-gate slice for a hidden size
            # of 4 -- confirm against lstm_layer.
            bias_slice = parameter.data[4:8]
            assert torch.all(torch.eq(bias_slice, torch.ones_like(bias_slice)))
def test_all_masking(mask_value):
    """
    Runs a few optimization steps through the entity embedding, residual
    self-attention and linear layers with a mask filled entirely with
    ``mask_value``, checking only that nothing raises.

    :param mask_value: Value every entry of the attention mask is set to.
    """
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    trainable_params = [
        *entity_embeddings.parameters(),
        *transformer.parameters(),
        *l_layer.parameters(),
    ]
    optimizer = torch.optim.Adam(trainable_params, lr=0.001, weight_decay=1e-6)
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            offsets = center.reshape((batch_size, 1, size)) - key
            squared_distance = torch.sum(offsets ** 2, dim=2)
            nearest = torch.argmin(squared_distance, dim=1)
            closest_keys = [
                sample_keys[index] for sample_keys, index in zip(key, nearest)
            ]
            target = torch.stack(closest_keys, dim=0)
        target = target.detach()

        embeddings = entity_embeddings(center, key)
        # One mask per entity type; here a single mask with a uniform value.
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        attended = transformer.forward(embeddings, masks)
        prediction = l_layer(attended).reshape((batch_size, size))
        per_sample_error = torch.mean((prediction - target) ** 2, dim=1)
        loss = torch.mean(per_sample_error) / 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def test_masked_mean():
    """
    Exercises ModelUtils.masked_mean with an all-on mask, a partial mask, an
    all-off mask, and a 2d input whose first axis matches the mask length.
    """
    values = torch.tensor([1, 2, 3, 4, 5])

    all_on = torch.ones_like(values).bool()
    assert ModelUtils.masked_mean(values, masks=all_on) == 3.0

    last_three = torch.tensor([False, False, True, True, True])
    assert ModelUtils.masked_mean(values, masks=last_three) == 4.0

    # Make sure it works if all masks are off
    all_off = torch.tensor([False, False, False, False, False])
    assert ModelUtils.masked_mean(values, masks=all_off) == 0.0

    # Make sure it works with 2d arrays of shape (mask_length, N)
    values_2d = torch.tensor([1, 2, 3, 4, 5]).repeat(2, 1).T
    partial_mask = torch.tensor([False, False, True, True, True])
    assert ModelUtils.masked_mean(values_2d, masks=partial_mask) == 4.0