Example 1
import pytest
import torch

# expand_dims_for_broadcast and assert_tensor_equal come from the library
# under test; their import path is not shown in this snippet.

def test_expand_dims_for_broadcast():
    low_tensor = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])  # (2, 3)
    high_tensor = torch.zeros(2, 3, 8, 1)

    new_tensor = expand_dims_for_broadcast(low_tensor, high_tensor)

    assert new_tensor.size() == (2, 3, 1, 1)

    assert_tensor_equal(new_tensor.squeeze(), low_tensor)

    bad_tensor = torch.zeros(2, 4, 8, 1)  # prefix (2, 4) doesn't match low_tensor's (2, 3)
    with pytest.raises(AssertionError):
        expand_dims_for_broadcast(low_tensor, bad_tensor)
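For context, here is a minimal sketch of what expand_dims_for_broadcast is expected to do, reconstructed from the assertions in the test above (the real implementation ships with the library; this body and its error behavior are assumptions):

import torch

def expand_dims_for_broadcast(low_tensor, high_tensor):
    """Append singleton dims to low_tensor until its rank matches high_tensor's.

    Assumes low_tensor's shape must be a prefix of high_tensor's shape,
    which is what the AssertionError case in the test exercises.
    """
    low_size, high_size = low_tensor.size(), high_tensor.size()
    assert low_size == high_size[:len(low_size)]  # shape prefix must match
    new_shape = tuple(low_size) + (1,) * (len(high_size) - len(low_size))
    return low_tensor.view(new_shape)

With low_tensor of shape (2, 3) and high_tensor of shape (2, 3, 8, 1), this yields shape (2, 3, 1, 1), matching the first assertion.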
Example 2
    @classmethod
    def reduce_prod(cls, seq_batch):
        """Compute the product of each sequence in a SequenceBatch.
        
        If a sequence is empty, we return a product of 1.
        
        Args:
            seq_batch (SequenceBatch): of shape (batch_size, seq_length, X1, X2, ...)

        Returns:
            Tensor: of shape (batch_size, X1, X2, ...)
        """
        mask = seq_batch.mask
        values = seq_batch.values

        # Set all pad values to 1: log(1) = 0 contributes nothing to the sum,
        # while taking the log of a 0 pad would produce -inf
        mask_bcast = expand_dims_for_broadcast(mask, values).expand(
            values.size())  # (batch_size, seq_length, X1, X2, ...)
        values = conditional(mask_bcast, values, 1 - mask_bcast)

        logged = SequenceBatch(
            torch.log(values),
            seq_batch.mask)  # (batch_size, seq_length, X1, X2, ...)

        log_sum = SequenceBatch.reduce_sum(logged)  # (batch_size, X1, X2, ...)
        prod = torch.exp(log_sum)
        return prod
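The conditional helper called above is not included in this snippet. A minimal sketch consistent with the call site (take values where the mask is 1, the fallback elsewhere) is an elementwise blend; the signature follows the call above, and the body is an assumption:

def conditional(mask, if_true, if_false):
    # Elementwise select: take if_true where mask == 1, if_false where mask == 0.
    # Assumes mask is a float tensor of 0s and 1s, broadcast to if_true's shape.
    return mask * if_true + (1 - mask) * if_false

For example, with values [2, 3, 5] and mask [1, 1, 0], the pad entry becomes 1, and exp(log 2 + log 3 + log 1) = 6, the product of the unpadded elements. Note the log trick assumes the values are nonnegative.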
Example 3
    @classmethod
    def reduce_max(cls, seq_batch):
        """Compute the max over each sequence in a SequenceBatch.

        Args:
            seq_batch (SequenceBatch): of shape (batch_size, seq_length, X1, X2, ...)

        Returns:
            Tensor: of shape (batch_size, X1, X2, ...)

        Raises:
            ValueError: if any sequence in the batch is empty.
        """
        if cls._empty_seqs(seq_batch):
            raise ValueError("Taking max over zero elements.")
        values, mask = seq_batch.values, seq_batch.mask

        inf_mask = mask.clone()  # (batch_size, seq_length)
        inf_mask[mask == 0] = float('inf')
        inf_mask[mask == 1] = 0
        # masked elements will never win the max, because we subtract infinity from them

        inf_mask_bcast = expand_dims_for_broadcast(inf_mask, values).expand_as(values)  # (batch_size, seq_length, X1, X2, ...)

        max_values, _ = torch.max(values - inf_mask_bcast, dim=1, keepdim=True)  # (batch_size, 1, X1, X2, ...)
        max_values = torch.squeeze(max_values, 1)  # (batch_size, X1, X2, ...)

        return max_values
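A standalone check of the infinity-mask trick on toy tensors (SequenceBatch itself isn't needed to see the mechanism):

import torch

values = torch.tensor([[1.0, 9.0, 4.0]])  # (batch_size=1, seq_length=3)
mask = torch.tensor([[1.0, 0.0, 1.0]])    # middle element is padding

inf_mask = mask.clone()
inf_mask[mask == 0] = float('inf')
inf_mask[mask == 1] = 0

# The padded 9.0 becomes -inf after subtraction, so it can never win the max.
max_values, _ = torch.max(values - inf_mask, dim=1)
print(max_values)  # tensor([4.])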
Example 4
    @classmethod
    def weighted_sum(cls, seq_batch, weights):
        """Compute weighted sum of elements in a SequenceBatch.

        Args:
            seq_batch (SequenceBatch): with values of shape (batch_size, seq_length, X1, X2, ...)
            weights (Tensor): of shape (batch_size, seq_length)

        Returns:
            Tensor: of shape (batch_size, X1, X2, ...)
        """
        values = seq_batch.values
        mask = seq_batch.mask
        weights = weights * mask  # ignore weights outside mask
        weights = expand_dims_for_broadcast(weights, values).expand(values.size())
        weighted = values * weights
        return torch.sum(weighted, dim=1)  # sum over seq_length: (batch_size, X1, X2, ...)
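A standalone numeric check of the weighting arithmetic (the broadcast is inlined with unsqueeze here, rather than expand_dims_for_broadcast, just to keep the demo self-contained):

import torch

values = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])  # (1, 3, 2)
mask = torch.tensor([[1.0, 1.0, 0.0]])     # last timestep is padding
weights = torch.tensor([[0.5, 0.5, 0.9]])  # the 0.9 on the pad is ignored

w = (weights * mask).unsqueeze(-1)   # (1, 3, 1), broadcasts over the feature dim
print(torch.sum(values * w, dim=1))  # tensor([[2., 3.]])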