def test_expand_dims_for_broadcast():
    """expand_dims_for_broadcast appends singleton dims; mismatched prefixes fail."""
    small = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
    big = torch.zeros(2, 3, 8, 1)

    expanded = expand_dims_for_broadcast(small, big)

    # Two trailing singleton dims were added; the data itself is untouched.
    assert expanded.size() == (2, 3, 1, 1)
    assert_tensor_equal(expanded.squeeze(), small)

    # A tensor whose leading dims don't match small's shape must be rejected.
    with pytest.raises(AssertionError):
        mismatched = torch.zeros(2, 4, 8, 1)  # prefix doesn't match
        expand_dims_for_broadcast(small, mismatched)
def reduce_prod(cls, seq_batch):
    """Compute the product of each sequence in a SequenceBatch.

    If a sequence is empty (all-zero mask row), its product is 1.

    Unlike an exp(sum(log(x))) formulation, this computes the product
    directly, so it is exact and also correct for zero and negative
    values (log of a non-positive value would yield -inf or NaN).

    Args:
        seq_batch (SequenceBatch): values of shape
            (batch_size, seq_length, X1, X2, ...), mask of shape
            (batch_size, seq_length) with 1.0 at real entries and
            0.0 at padding.

    Returns:
        Tensor: of shape (batch_size, X1, X2, ...)
    """
    values = seq_batch.values
    mask = seq_batch.mask
    # Broadcast the (batch_size, seq_length) mask across the value dims:
    # append one singleton dim per extra value dimension, then expand.
    bcast_shape = [mask.size(0), mask.size(1)] + [1] * (values.dim() - 2)
    mask_bcast = mask.view(*bcast_shape).expand_as(values)
    # Force padded positions to exactly 1 so they are identity elements
    # of the product; an all-pad sequence therefore yields 1.
    padded_to_one = values * mask_bcast + (1 - mask_bcast)
    return torch.prod(padded_to_one, dim=1)  # (batch_size, X1, X2, ...)
def reduce_max(cls, seq_batch):
    """Take the max over each sequence in a SequenceBatch.

    Args:
        seq_batch (SequenceBatch): values of shape
            (batch_size, seq_length, X1, X2, ...), mask of shape
            (batch_size, seq_length).

    Returns:
        Tensor: of shape (batch_size, X1, X2, ...)

    Raises:
        ValueError: if any sequence in the batch is empty.
    """
    if cls._empty_seqs(seq_batch):
        raise ValueError("Taking max over zero elements.")
    values, mask = seq_batch.values, seq_batch.mask

    # Build an additive penalty: +inf at padded positions, 0 at real ones.
    # Subtracting it guarantees padded elements can never win the max.
    penalty = mask.clone()  # (batch_size, seq_length)
    penalty[mask == 0] = float('inf')
    penalty[mask == 1] = 0
    penalty_bcast = expand_dims_for_broadcast(penalty, values).expand_as(values)

    maxed, _ = torch.max(values - penalty_bcast, 1)
    return torch.squeeze(maxed, 1)  # (batch_size, X1, X2, ...)
def weighted_sum(cls, seq_batch, weights):
    """Compute weighted sum of elements in a SequenceBatch.

    Args:
        seq_batch (SequenceBatch): with values of shape
            (batch_size, seq_length, X1, X2, ...)
        weights (Variable): of shape (batch_size, seq_length)

    Returns:
        Variable: of shape (batch_size, X1, X2, ...)
    """
    values, mask = seq_batch.values, seq_batch.mask
    # Zero out weights at padded positions so they contribute nothing.
    masked_weights = weights * mask
    # Broadcast weights over the trailing value dims, then reduce seq dim.
    weights_bcast = expand_dims_for_broadcast(masked_weights, values).expand(values.size())
    return torch.sum(values * weights_bcast, dim=1).squeeze(dim=1)