def test_is_binary():
    """is_binary should accept a tensor of 0s and 1s and reject anything else."""
    zeros_and_ones = torch.FloatTensor([0, 1, 0, 0])
    has_negative = torch.FloatTensor([0, -1, 0, 0])
    has_fractions = torch.FloatTensor([0, 0.1, 0.2, 0])

    assert is_binary(zeros_and_ones)
    # Neither a -1 entry nor fractional entries may be treated as binary.
    for candidate in (has_negative, has_fractions):
        assert not is_binary(candidate)
def __new__(cls, values, mask, left_justify=True):
    """Validate ``values`` and ``mask``, then construct the SequenceBatch.

    Args:
        values (Variable): the batch contents. NOTE(review): presumably
            shaped (batch_size, seq_len, ...) to line up with ``mask``;
            that alignment is not checked here.
        mask (Variable): binary mask of size (batch_size, seq_len).
        left_justify (bool): if True, also require every row of ``mask``
            to be left-justified, i.e. all of its 1s precede all of its 0s.

    Returns:
        The instance produced by the parent class's ``__new__``.

    Raises:
        ValueError: if ``values`` or ``mask`` is not a Variable, or if
            ``mask`` is not 2-dimensional, not binary, or (when
            ``left_justify`` is True) not left-justified.
    """
    if not isinstance(values, Variable) or not isinstance(mask, Variable):
        raise ValueError('values and mask must both be of type Variable.')

    m = mask.data
    num_dims = len(m.size())
    if num_dims == 0:
        raise ValueError('Mask must not be 0-dimensional')
    # Reject every other non-2-D mask up front with a clear message instead
    # of letting a later tuple-unpack fail with "too many/not enough values".
    if num_dims != 2:
        raise ValueError(
            'Mask must be 2-dimensional (batch_size, seq_len), got size {}'.format(
                tuple(m.size())))

    # check that mask is binary
    if not is_binary(m):
        raise ValueError('Mask must be binary:\n{}'.format(mask))

    # check that mask is left-justified: since it is binary, a row is
    # left-justified iff it never steps up from 0 to 1 going left to right.
    seq_len = m.size(1)
    if left_justify and seq_len > 1:
        # Compare neighbours instead of subtracting them: subtraction wraps
        # around on unsigned (uint8) masks (0 - 1 == 255), and recent PyTorch
        # versions reject `-` on bool tensors; `>` is exact for every dtype.
        steps_up = m[:, 1:] > m[:, :-1]  # (batch_size, seq_len - 1)
        if steps_up.any():
            raise ValueError(
                'Mask must be left-justified:\n{}'.format(mask))

    return super(SequenceBatch, cls).__new__(cls, values, mask)