def test_Transformer():
    args = argparse.Namespace()
    B, T, C, V = 4, 3, 6, 5
    args.__dict__.update(d_model=C, d_hidden=C, n_heads=3,
                         drop_ratio=0, n_layers=2, length_ratio=1.5)
    field = MaskedBatchField()
    field.vocab = list(range(V))
    xs = [Variable(torch.LongTensor(1, random.randint(1, T)).random_(V))
          for i in range(B)]
    ys = [Variable(torch.LongTensor(1, random.randint(2, T)).random_(V))
          for i in range(B)]
    xb = MaskedBatch.fromlist(xs, (True,))
    yb = MaskedBatch.fromlist(ys, (True,))
    model = Transformer(field, field, args)
    mb_assert(model, (xs, ys), (xb, yb), B)

    def loss(x, y):
        b = namedtuple('_batch', ('src', 'trg'))(x, y)
        return model.loss(b, reduce=False)

    mb_assert(loss, (xs, ys), (xb, yb), B)

def split(batch, split_size_or_sections, dim=0):
    if not isinstance(batch, MaskedBatch):
        return torch.split(batch, split_size_or_sections, dim)
    if dim < 0:
        dim += batch.dim()
    if dim > 0 and batch.dims[dim - 1]:
        # splitting a dynamic dim: the mask must be split alongside the data
        return tuple(MaskedBatch(data, mask, batch.dims)
                     for data, mask in zip(
                         torch.split(batch.data, split_size_or_sections, dim),
                         torch.split(batch.mask, split_size_or_sections, dim)))
    # static dim: every piece shares the original mask
    return tuple(MaskedBatch(data, batch.mask, batch.dims)
                 for data in torch.split(
                     batch.data, split_size_or_sections, dim))

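# A minimal usage sketch (hypothetical, not part of the library): it assumes
# the module-level imports used by the tests in this file (torch, random,
# Variable) and that MaskedBatch.fromlist pads a list of 1 x len x ... tensors
# into a single batch, as the tests here do.
def _example_split():
    # splitting along a static feature dim reuses the mask unchanged
    xs = [Variable(torch.rand(1, random.randint(1, 6), 4)) for _ in range(4)]
    xb = MaskedBatch.fromlist(xs, (True, False))
    halves = split(xb, 2, dim=2)
    assert len(halves) == 2 and halves[0].data.size(2) == 2
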
def unbind(batch, dim):
    if not isinstance(batch, MaskedBatch):
        return torch.unbind(batch, dim)
    if dim == 0:
        raise ValueError("cannot unbind over batch dimension")
    dims = tuple(b for d, b in enumerate(batch.dims) if d != dim - 1)
    if batch.dims[dim - 1]:
        return tuple(MaskedBatch(data, mask, dims)
                     for data, mask in zip(torch.unbind(batch.data, dim),
                                           torch.unbind(batch.mask, dim)))
    else:
        mask = batch.mask.squeeze(dim)
        return tuple(MaskedBatch(data, mask, dims)
                     for data in torch.unbind(batch.data, dim))

def _synchronize(batch):
    if not isinstance(batch, MaskedBatch):
        return batch
    if any(batch.dims):
        raise ValueError("cannot synchronize batch with dynamic dimensions")
    # build an all-ones mask of the same shape and type as the original
    mask = batch.mask + (1 - batch.mask)
    return MaskedBatch(batch.data, mask, batch.dims)

def inner(batch, dim=None, keepdim=False):
    if dim is None:
        # full reduction to a scalar; this module shadows the builtin `any`
        # with a MaskedBatch-aware version, so fetch the builtin explicitly
        if not zero_preserving and __builtins__['any'](batch.dims):
            raise NotImplementedError(
                "cannot reduce to scalar with non-zero-preserving kernel "
                "if dynamic dims present")
        mask = batch.mask[(slice(None), *(0 for d in batch.dims))]
        dims = ()
    else:
        if dim < 0:
            dim += batch.dim()
        if not zero_preserving and batch.dims[dim - 1]:
            raise NotImplementedError("cannot reduce over dynamic dim "
                                      "with non-zero-preserving kernel")
        if keepdim:
            mask = batch.mask[tuple(
                slice(0, 1) if i == dim else slice(None)
                for i in range(batch.mask.dim()))]
            dims = tuple(False if i == dim - 1 else d
                         for i, d in enumerate(batch.dims))
        else:
            mask = batch.mask[tuple(0 if i == dim else slice(None)
                                    for i in range(batch.mask.dim()))]
            dims = tuple(d for i, d in enumerate(batch.dims) if i != dim - 1)
    data = fn(batch.data * batch.mask, dim=dim, keepdim=keepdim)
    return MaskedBatch(data, mask, dims)

def inner(batch, *args, **kwargs):
    if not isinstance(batch, MaskedBatch):
        return fn(batch, *args, **kwargs)
    data = fn(batch.data, *args, **kwargs)
    mask = batch.mask.type_as(data)
    dims = batch.dims
    return MaskedBatch(data, mask, dims)

def embedding(batch, weight, padding_idx=None, max_norm=None, norm_type=2,
              scale_grad_by_freq=False, sparse=False):
    def compat_embedding(batch, weight, padding_idx, max_norm, norm_type,
                         scale_grad_by_freq, sparse):
        if torch.__version__ >= '0.4':
            return F.embedding(batch, weight, padding_idx, max_norm,
                               norm_type, scale_grad_by_freq, sparse)
        if padding_idx is not None:
            raise ValueError(
                "F.embedding doesn't support padding_idx for torch < 0.4")
        return F.embedding(batch, weight, max_norm, norm_type,
                           scale_grad_by_freq, sparse)

    if not isinstance(batch, MaskedBatch):
        return compat_embedding(batch, weight, padding_idx, max_norm,
                                norm_type, scale_grad_by_freq, sparse)
    #data = batch.data - batch.mask
    data = batch.data
    data = compat_embedding(data, weight, padding_idx, max_norm, norm_type,
                            scale_grad_by_freq, sparse)
    mask = batch.mask.unsqueeze(-1).float()
    dims = batch.dims + (False,)
    return MaskedBatch(data, mask, dims)

def linear(batch, weight, bias=None):
    if not isinstance(batch, MaskedBatch):
        return F.linear(batch, weight, bias)
    if batch.dims[-1]:
        raise ValueError("cannot contract static and dynamic dimensions")
    data = F.linear(batch.data, weight, bias)
    return MaskedBatch(data, batch.mask, batch.dims)

def join_dims(batch, dim1, dim2):
    if dim1 < 0:
        dim1 += batch.dim()
    if dim2 < 0:
        dim2 += batch.dim()
    if dim2 != dim1 + 1:
        # move dim2 so that it is adjacent to dim1 before joining the pair
        order = [n for n in range(batch.dim()) if n != dim2]
        order.insert(dim1 + 1, dim2)
        batch = batch.permute(*order)
        if dim2 < dim1:
            dim1 -= 1
    if not isinstance(batch, MaskedBatch):
        sizes = (batch.size(d + 1) * s if d == dim1 else s
                 for d, s in enumerate(batch.size()) if d != dim1 + 1)
        return batch.contiguous().view(*sizes)
    sizes = (batch.data.size(d + 1) * s if d == dim1 else s
             for d, s in enumerate(batch.data.size()) if d != dim1 + 1)
    data = batch.data.contiguous().view(*sizes)
    if dim1 == 0:
        # joining into the batch dim: materialize the mask over the joined dim
        mask = batch.mask.expand(*(s if d == dim1 + 1 else -1
                                   for d, s in enumerate(batch.data.size())))
        sizes = (s * mask.size(d + 1) if d == dim1 else s
                 for d, s in enumerate(mask.size()) if d != dim1 + 1)
        mask = mask.contiguous().view(*sizes)
    else:
        mask = batch.mask.squeeze(dim1 + 1)
    dims = batch.dims[:dim1] + batch.dims[dim1 + 1:]
    return MaskedBatch(data, mask, dims)

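# Another hedged sketch under the same assumptions as _example_split: folding
# a static heads dim into the batch dim, the multi-head attention pattern.
def _example_join_dims():
    xs = [Variable(torch.rand(1, 2, random.randint(1, 5), 3))
          for _ in range(4)]
    xb = MaskedBatch.fromlist(xs, (False, True, False))
    folded = join_dims(xb, 0, 1)
    assert folded.data.size(0) == 8      # 4 examples * 2 heads
    assert folded.dims == (True, False)  # the time dim stays dynamic
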
def inner(self, *sizes):
    source = self.data if isinstance(self, MaskedBatch) else self
    if not any(isinstance(size, MaskedBatch) for size in sizes):
        return original(source, *(int(size) for size in sizes))
    if isinstance(sizes[0], MaskedBatch):
        raise ValueError("batch size dimension must be static")
    dims = tuple(isinstance(size, MaskedBatch) for size in sizes[1:])
    maxsizes = [size.data.max() if isinstance(size, MaskedBatch) else int(size)
                for size in sizes]
    bs = maxsizes[0]
    masksizes = [s if b else 1 for s, b in zip(maxsizes[1:], dims)]
    data = original(source, *maxsizes)
    mask = source.new_zeros(bs, *masksizes)
    # TODO this should be
    # mask[range(bs), *(s - 1 for s in masksizes)] = 1
    # mask = mask[:, *(slice(None, None, -1) if b
    #                  else slice(None, None, None) for b in dims)]
    # for d, b in enumerate(dims):
    #     if not b: continue
    #     mask = mask.cumsum(d + 1)
    # mask = mask[:, *(slice(None, None, -1) if b
    #                  else slice(None, None, None) for b in dims)]
    # if faking negative strides is fast enough;
    # we can also use numpy if it's worth it.
    for i in range(bs):
        inds = [slice(0, int(size.data[i])) if b else slice(None)
                for size, b in zip(sizes[1:], dims)]
        mask[(slice(i, i + 1), *inds)] = 1
    return MaskedBatch(data, mask, dims)

def getitem(batch, index):
    if not isinstance(index, tuple) or index[0] != slice(None):
        raise ValueError("first index must be :")
    if None in index:
        raise NotImplementedError("cannot index with None")
    data = batch.data[index]
    index = list(index)
    for i, (ind, b) in enumerate(zip(index[1:], batch.dims)):
        if b:
            if isinstance(ind, int) and ind < 0:
                raise NotImplementedError("cannot index dynamic dim with "
                                          "negative integer")
            if (isinstance(ind, slice) and ind.stop is not None
                    and ind.stop < 0):
                if ind.step is not None or ind.start is not None:
                    raise NotImplementedError("cannot index dynamic dim with "
                                              "complex slice")
                # a negative stop on a dynamic dim drops elements from the end
                # of each sequence; shifting the mask forward by the same
                # amount shortens every row's mask by that many positions
                index[i + 1] = slice(-ind.stop, None)
    index = tuple(index)
    mask = batch.mask[tuple(
        i if b else 0 if isinstance(i, int) else slice(None)
        for i, b in zip(index, (True,) + batch.dims))]
    dims = tuple(b for i, b in zip(
        index[1:] + (slice(None),) * len(batch.dims), batch.dims)
        if not isinstance(i, int))  # could be faster
    return MaskedBatch(data, mask, dims)

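# A hedged sketch of the dynamic-slice semantics, under the same assumptions
# as _example_split: x[:, :-1] drops the last *valid* element of each
# sequence, not the last padded position.
def _example_getitem():
    xs = [Variable(torch.rand(1, random.randint(2, 5))) for _ in range(4)]
    xb = MaskedBatch.fromlist(xs, (True,))
    trimmed = getitem(xb, (slice(None), slice(None, -1)))
    # each of the 4 rows loses exactly one valid position
    assert int(trimmed.mask.sum()) == int(xb.mask.sum()) - 4
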
def test_embedding():
    xs = [Variable(torch.LongTensor(1, random.randint(1, 3)).random_(5))
          for i in range(4)]
    W = Variable(torch.rand(5, 2))
    xb = MaskedBatch.fromlist(xs, (True,))
    mb_assert(F.embedding, (xs, W), (xb, W), 4)

def _update(batch, new, update_mask=None):
    if not isinstance(batch, MaskedBatch) and not isinstance(new, MaskedBatch):
        return new
    update_mask = (new.mask.byte() if update_mask is None
                   else update_mask.data * update_mask.mask)
    if isinstance(batch, MaskedBatch):
        data = torch.where(update_mask, new.data, batch.data)
    else:
        data = torch.where(update_mask, new.data, batch)
    return MaskedBatch(data, update_mask.type_as(data), new.dims)

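# A hedged sketch of the masked-update pattern used in stepwise decoding,
# under the same assumptions as _example_split plus torch >= 0.4 for
# torch.where (which _update itself already relies on).
def _example_update():
    xs = [Variable(torch.rand(1, random.randint(1, 4))) for _ in range(3)]
    xb = MaskedBatch.fromlist(xs, (True,))
    buf = torch.zeros_like(xb.data)
    out = _update(buf, xb)  # valid positions come from xb, padding stays 0
    assert int((out.data * (1 - out.mask)).abs().sum()) == 0
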
def view(batch, *sizes):
    bs = batch.data.size(0)
    if sizes[0] not in (1, -1, bs):
        raise ValueError("first dim in view must be 1, -1, or batch size")
    sizes = (bs,) + sizes[1:]
    data = batch.data.view(*sizes)  # TODO can throw
    mask_sizes = (bs,) + tuple(batch.data.size(i) if sizes[i] == -1 else 1
                               for i in range(1, len(sizes)))
    mask = batch.mask.view(*mask_sizes)  # TODO can this throw if data doesn't?
    dims = tuple(sizes[i] == -1 for i in range(1, len(sizes)))
    return MaskedBatch(data, mask, dims)

def mb_rand(*dims):
    dims = [dim for dim in dims if dim != ()]
    xs = [Variable(torch.rand(1, *(random.randint(1, size) if b else size
                                   for b, size in dims[1:])))
          for i in range(dims[0])]
    xb = MaskedBatch.fromlist(xs, tuple(b for b, d in dims[1:]))
    return xs, xb

def inner(batch1, batch2, **kwargs):
    if (not isinstance(batch1, MaskedBatch)
            and not isinstance(batch2, MaskedBatch)):
        return fn(batch1, batch2, **kwargs)
    # note: if batch2 is a MaskedBatch, batch1 is assumed to be one as well
    if isinstance(batch2, MaskedBatch):
        data = fn(batch1.data, batch2.data, **kwargs)
        mask = batch1.mask * batch2.mask
        dims = tuple(b1 or b2 for b1, b2 in zip(batch1.dims, batch2.dims))
    else:
        data = fn(batch1.data, batch2, **kwargs)
        mask = batch1.mask.type_as(data)
        dims = batch1.dims
    return MaskedBatch(data, mask, dims)

def cat(sequence, dim):
    sequence = list(sequence)
    if len(sequence) == 0:
        raise ValueError("cannot cat empty sequence")
    first = sequence[0]
    if not isinstance(first, MaskedBatch):
        return torch.cat(sequence, dim)
    data = torch.cat([batch.data for batch in sequence], dim)
    if first.dims[dim - 1]:
        mask = torch.cat([batch.mask for batch in sequence], dim)
    else:
        mask = first.mask
    return MaskedBatch(data, mask, first.dims)

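# A hedged sketch under the same assumptions as _example_split: concatenating
# along a static feature dim grows the data while the mask is shared.
def _example_cat():
    xs = [Variable(torch.rand(1, random.randint(1, 4), 2)) for _ in range(3)]
    xb = MaskedBatch.fromlist(xs, (True, False))
    wide = cat([xb, xb], 2)
    assert wide.data.size(2) == 4 and wide.dims == (True, False)
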
def size_as_tensor(batch, dim):
    if not isinstance(batch, MaskedBatch):
        return MAYBE_VARIABLE(torch.LongTensor([batch.size(dim)]))
    if dim is None:
        return tuple(batch.size(d) for d in range(len(batch.dims) + 1))
    if dim < 0:
        dim += batch.dim()
    if dim == 0 or not batch.dims[dim - 1]:
        return MAYBE_VARIABLE(torch.LongTensor([batch.data.size(dim)]))
    if any(batch.dims[:dim - 1] + batch.dims[dim:]):
        raise NotImplementedError(
            "cannot get the size of one dynamic dimension when two or "
            "more dynamic dimensions are present")
    data = batch.mask.long().sum(dim).view(-1)
    mask = data.new(batch.mask.size(0)).fill_(1)
    return MaskedBatch(data, mask, ())

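# A hedged sketch under the same assumptions as _example_split: the size of a
# dynamic dim comes back as a per-example vector of lengths.
def _example_size_as_tensor():
    xs = [Variable(torch.rand(1, random.randint(1, 5))) for _ in range(4)]
    xb = MaskedBatch.fromlist(xs, (True,))
    lengths = size_as_tensor(xb, 1)
    assert [int(n) for n in lengths.data] == [x.size(1) for x in xs]
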
def test_Encoder():
    args = argparse.Namespace()
    args.__dict__.update(d_model=6, d_hidden=6, n_heads=3,
                         drop_ratio=0, n_layers=2)
    field = MaskedBatchField()
    field.out = nn.Linear(args.d_model, 5)
    xs = [Variable(torch.LongTensor(1, random.randint(1, 3)).random_(5))
          for i in range(4)]
    xb = MaskedBatch.fromlist(xs, (True,))
    mb_assert(Encoder(field, args), (xs,), (xb,), 4)

def transpose(batch, dim1, dim2):
    if dim1 > batch.dim() or dim2 > batch.dim():
        if dim1 < 0:
            dim1 += batch.dim()
        if dim2 < 0:
            dim2 += batch.dim()
        permutation = [dim2 if i == dim1 else dim1 if i == dim2 else i
                       for i in range(batch.dim() + 1)][:batch.dim()]
        return batch.permute(*permutation)
    if not isinstance(batch, MaskedBatch):
        return torch.transpose(batch, dim1, dim2)
    data = batch.data.transpose(dim1, dim2)
    mask = batch.mask.transpose(dim1, dim2)
    dims = list(batch.dims)
    dims[dim1 - 1], dims[dim2 - 1] = dims[dim2 - 1], dims[dim1 - 1]
    dims = tuple(dims)
    return MaskedBatch(data, mask, dims)

def softmax(batch, dim=-1):
    if not isinstance(batch, MaskedBatch):
        return F.softmax(batch, dim)
    if dim == 0:
        raise ValueError("cannot softmax over batch dimension")
    elif dim < 0:
        dim += batch.dim()
    dims = batch.dims
    if dims[dim - 1]:
        data = F.softmax(batch.data * batch.mask, dim) * batch.mask
        data = data / data.sum(dim, keepdim=True)
        # fully masked rows divide zero by zero; zero out the resulting NaNs
        data[data.ne(data).detach()] = 0
        mask = batch.mask.narrow(dim, 0, 1)
        dims = dims[:dim - 1] + (False,) + dims[dim:]
    else:
        data = F.softmax(batch.data, dim)
        mask = batch.mask
    return MaskedBatch(data, mask, dims)

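# A hedged sketch under the same assumptions as _example_split: softmax over
# a dynamic dim renormalizes within each sequence, so every row sums to one
# over its valid positions while padding gets probability zero.
def _example_softmax():
    xs = [Variable(torch.rand(1, random.randint(2, 5))) for _ in range(4)]
    xb = MaskedBatch.fromlist(xs, (True,))
    probs = softmax(xb, 1)
    assert int((probs.data.sum(1) - 1).abs().gt(1e-5).sum()) == 0
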
def stack(sequence, dim, dynamic=None):
    sequence = list(sequence)
    if len(sequence) == 0:
        raise ValueError("cannot stack empty sequence")
    first = sequence[0]
    if not isinstance(first, MaskedBatch):
        return torch.stack(sequence, dim)
    if dim < 0:
        dim += first.dim() + 1
    if dynamic is None:
        dynamic = not first.mask.eq(sequence[-1].mask).all()
    data = torch.cat([batch.data.unsqueeze(dim) for batch in sequence], dim)
    if dynamic:
        mask = torch.cat([batch.mask.unsqueeze(dim) for batch in sequence],
                         dim)
    else:
        mask = first.mask.unsqueeze(dim)
    dims = first.dims[:dim - 1] + (dynamic,) + first.dims[dim - 1:]
    return MaskedBatch(data, mask, dims)

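# A hedged sketch under the same assumptions as _example_split: stacking
# batches with identical masks creates a new static dim.
def _example_stack():
    xs = [Variable(torch.rand(1, random.randint(1, 4))) for _ in range(3)]
    xb = MaskedBatch.fromlist(xs, (True,))
    stacked = stack([xb, xb], 2)
    assert stacked.data.size(2) == 2 and stacked.dims == (True, False)
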
def matmul(batch1, batch2):
    if (not isinstance(batch1, MaskedBatch)
            and not isinstance(batch2, MaskedBatch)):
        return batch1 @ batch2
    if isinstance(batch1, MaskedBatch) and isinstance(batch2, MaskedBatch):
        dims1 = len(batch1.dims)
        dims2 = len(batch2.dims)
        data1 = batch1.data * batch1.mask
        data2 = batch2.data * batch2.mask
        if dims1 == 1:
            data1 = data1.unsqueeze(-2)
        if dims2 == 1 and dims1 == 1:
            data2 = data2.unsqueeze(-1)
        data = data1 @ data2
        if dims1 == 1 and dims2 == 1:
            #if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask).all():
            #    raise ValueError("cannot contract non-matching dimensions")
            mask = batch1.mask[:, :1]
            dims = ()
        elif dims1 == 2 and dims2 == 1:
            #if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask).all():
            #    raise ValueError("cannot contract non-matching dimensions")
            mask = batch1.mask[:, :, :1] @ batch2.mask[:, :1]
            dims = batch1.dims[:1]
        elif dims1 == 1 and dims2 == 2:
            #if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask[:, :, 0]).all():
            #    raise ValueError("cannot contract non-matching dimensions")
            mask = batch1.mask[:, :1].unsqueeze(-2) @ batch2.mask[:, :1, :]
            dims = batch2.dims[1:]
        elif dims1 == 2 and dims2 == 2:
            #if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask[:, :, 0]).all():
            #    raise ValueError("cannot contract non-matching dimensions")
            mask = batch1.mask[:, :, :1] @ batch2.mask[:, :1, :]
            dims = batch1.dims[:1] + batch2.dims[1:]
        else:
            raise NotImplementedError(
                "matmul not implemented with batches of 3+D tensors")
    else:
        raise NotImplementedError(
            "matmul not implemented between MaskedBatch and tensor")
    return MaskedBatch(data, mask, dims)

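# A hedged sketch under the same assumptions as _example_split: attention-
# style scores between a batch and its own transpose, (B, T, C) @ (B, C, T),
# give a (B, T, T) result with both non-batch dims dynamic.
def _example_matmul():
    xs = [Variable(torch.rand(1, random.randint(1, 4), 3)) for _ in range(2)]
    xb = MaskedBatch.fromlist(xs, (True, False))
    scores = matmul(xb, transpose(xb, 1, 2))
    assert scores.dims == (True, True)
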
def split_dim(batch, dim, split_by):
    if dim < 0:
        dim += batch.dim()
    if batch.data.size(dim) % split_by != 0:
        raise ValueError("size of dim not divisible by split_by")
    sizes = ((s // split_by, split_by) if d == dim else (s,)
             for d, s in enumerate(batch.data.size()))
    if not isinstance(batch, MaskedBatch):
        return batch.contiguous().view(*(n for tup in sizes for n in tup))
    if dim == 0:
        msizes = ((s // split_by, split_by) if d == dim else (s,)
                  for d, s in enumerate(batch.mask.size()))
        mask = batch.mask.contiguous().view(
            *(n for tup in msizes for n in tup))
        mask = mask.narrow(1, 0, 1)
    else:
        if batch.dims[dim - 1]:
            raise ValueError("cannot split dynamic dimension")
        mask = batch.mask.unsqueeze(dim)
    data = batch.data.contiguous().view(*(n for tup in sizes for n in tup))
    dims = batch.dims[:dim] + (False,) + batch.dims[dim:]
    return MaskedBatch(data, mask, dims)

def causal_mask(batch, in_dim, out_dim):
    '''if in_dim is indexed by i and out_dim by j, masks ret[i,j] where i > j'''
    if not isinstance(batch, MaskedBatch):
        # TODO or we could just promote to MaskedBatch /shrug
        if in_dim == 1 and out_dim == 2:
            return batch - batch.new(
                *batch.size()[1:]).fill_(1e10).tril(-1).unsqueeze(0)
        elif in_dim == 2 and out_dim == 1:
            return batch - batch.new(
                *batch.size()[1:]).fill_(1e10).triu(1).unsqueeze(0)
        else:
            raise NotImplementedError("unsupported arguments for causal_mask")
    if in_dim == 1 and out_dim == 2:
        mask = batch.mask * batch.mask.new(
            *batch.data.size()[1:]).fill_(1).triu(0).unsqueeze(0)
    elif in_dim == 2 and out_dim == 1:
        mask = batch.mask * batch.mask.new(
            *batch.data.size()[1:]).fill_(1).tril(0).unsqueeze(0)
    else:
        raise NotImplementedError("unsupported arguments for causal_mask")
    dims = tuple(True if d + 1 in (in_dim, out_dim) else b
                 for d, b in enumerate(batch.dims))
    return MaskedBatch(batch.data, mask, dims)

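# A hedged sketch under the same assumptions as _example_matmul: a lower-
# triangular mask applied to self-attention scores, so the first row keeps
# only its own position.
def _example_causal_mask():
    xs = [Variable(torch.rand(1, random.randint(2, 4), 3)) for _ in range(2)]
    xb = MaskedBatch.fromlist(xs, (True, False))
    scores = matmul(xb, transpose(xb, 1, 2))  # (B, T, T)
    causal = causal_mask(scores, 2, 1)
    assert int(causal.mask[0, 0, 1:].sum()) == 0
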
def cross_entropy(input, target, weight=None, size_average=True,
                  ignore_index=-1, reduce=True):
    if (not isinstance(input, MaskedBatch)
            and not isinstance(target, MaskedBatch)):
        ret = F.cross_entropy(input.contiguous().view(-1, input.size(-1)),
                              target.contiguous().view(-1), weight,
                              size_average, ignore_index, reduce)
        if reduce:
            return ret
        return ret.view(input.size(0), input.size(1))
    # shift padded targets to -1 so they hit F.cross_entropy's ignore_index
    target_data = (target.data + target.mask - 1).view(-1)
    input_data = input.data.view(target_data.size(0), -1)
    if ignore_index != -1:
        raise ValueError("cannot set ignore_index with MaskedBatch")
    data = F.cross_entropy(input_data, target_data, weight, size_average,
                           ignore_index, reduce)
    if reduce:
        return data
    data = data.view(input.maxsize(0), input.maxsize(1))
    mask = input.mask.squeeze(-1) * target.mask.float()
    return MaskedBatch(data, mask, target.dims)

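# A hedged sketch under the same assumptions as _example_split, additionally
# assuming fromlist gives an integer mask for integer data (which the
# target-shifting trick in cross_entropy above requires): with reduce=False,
# padded targets map to ignore_index -1 and contribute nothing.
def _example_cross_entropy():
    ys = [Variable(torch.LongTensor(1, random.randint(1, 4)).random_(5))
          for _ in range(3)]
    yb = MaskedBatch.fromlist(ys, (True,))
    logits = MaskedBatch(Variable(torch.rand(3, yb.data.size(1), 5)),
                         yb.mask.unsqueeze(-1).float(), (True, False))
    loss = cross_entropy(logits, yb, reduce=False)
    assert loss.dims == (True,)
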
def permute(batch, *permutation):
    data = batch.data.permute(*permutation)
    mask = batch.mask.permute(*permutation)
    dims = tuple(batch.dims[i - 1] for i in permutation[1:])
    return MaskedBatch(data, mask, dims)

def contiguous(batch):
    return MaskedBatch(
        batch.data.contiguous(), batch.mask.contiguous(), batch.dims)

def dropout(batch, p=0.5, training=False, inplace=False):
    if not isinstance(batch, MaskedBatch):
        return F.dropout(batch, p, training, inplace)
    data = F.dropout(batch.data, p, training, inplace)
    return MaskedBatch(data, batch.mask, batch.dims)
