Пример #1
0
def test_is_binary():
    """is_binary accepts a 0/1 tensor and rejects tensors with other values."""
    zeros_and_ones = torch.FloatTensor([0, 1, 0, 0])
    with_negative = torch.FloatTensor([0, -1, 0, 0])
    with_fractions = torch.FloatTensor([0, 0.1, 0.2, 0])

    assert is_binary(zeros_and_ones)
    # A -1 entry and fractional entries must both be rejected.
    for non_binary in (with_negative, with_fractions):
        assert not is_binary(non_binary)
Пример #2
0
    def __new__(cls, values, mask, left_justify=True):
        """Validate ``values`` and ``mask`` and construct a new SequenceBatch.

        Args:
            values: a ``Variable`` holding the batch contents (its shape is not
                checked here).
            mask: a ``Variable`` whose data must be a 2-D binary tensor of
                shape (batch_size, seq_len).
            left_justify: if True, additionally require every row of the mask
                to be monotonically non-increasing from left to right, i.e.
                all ones come before any zeros.

        Returns:
            The object produced by the parent class's ``__new__`` called with
            ``(values, mask)``.

        Raises:
            ValueError: if ``values`` or ``mask`` is not a ``Variable``; if the
                mask is 0-dimensional, not binary, or not 2-dimensional; or if
                ``left_justify`` is True and some mask row is not
                left-justified.
        """
        if not isinstance(values, Variable) or not isinstance(mask, Variable):
            raise ValueError('values and mask must both be of type Variable.')

        m = mask.data

        if m.dim() == 0:
            raise ValueError('Mask must not be 0-dimensional')

        # check that mask is binary
        if not is_binary(m):
            raise ValueError('Mask must be binary:\n{}'.format(mask))

        # The unpacking below needs exactly two dimensions. Fail with a clear
        # message instead of an opaque "too many / not enough values to unpack".
        if m.dim() != 2:
            raise ValueError(
                'Mask must be 2-dimensional (batch_size, seq_len), '
                'got {} dimensions'.format(m.dim()))

        # check that mask is left-justified
        # since mask is binary, we just need to check that it is monotonically
        # non-increasing from left to right
        batch_size, seq_len = m.size()
        if seq_len > 1 and left_justify:
            # Compare neighbours directly rather than subtracting them.
            # Subtraction wraps around for uint8 masks (0 - 1 == 255), and
            # bool tensors do not support subtraction at all.
            non_increasing = m[:, 1:] <= m[:, :-1]  # (batch_size, seq_len - 1)
            if not non_increasing.all():
                raise ValueError(
                    'Mask must be left-justified:\n{}'.format(mask))

        return super(SequenceBatch, cls).__new__(cls, values, mask)