import math
import random

import torch
from torch.distributions import Binomial, Geometric


def create_subsequence_mask(o, r=.15, lm=3, stateful=True, sync=False):
    """Build a boolean mask over `o` where roughly a fraction `r` of the steps is masked,
    in contiguous runs with mean length `lm` (when `stateful=True`)."""
    device = o.device
    if o.ndim == 2: o = o[None]
    n_masks, mask_dims, mask_len = o.shape
    if sync == 'random': sync = random.random() > .5
    dims = 1 if sync else mask_dims
    if stateful:
        # Alternate unmasked/masked runs whose lengths are drawn from geometric distributions,
        # so masked values form contiguous subsequences with mean length `lm`.
        numels = n_masks * dims * mask_len
        pm = torch.tensor([1 / lm], device=device)
        pu = torch.clip(pm * (r / max(1e-6, 1 - r)), 1e-3, 1)
        zot, proba_a, proba_b = (torch.as_tensor([False, True], device=device), pu, pm) if random.random() > pm else \
                                (torch.as_tensor([True, False], device=device), pm, pu)
        max_len = max(1, 2 * math.ceil(numels // (1 / pm + 1 / pu)))
        for i in range(10):
            # Sample run lengths until their total covers every element that needs a mask value.
            _dist_a = (Geometric(probs=proba_a).sample([max_len]) + 1).long()
            _dist_b = (Geometric(probs=proba_b).sample([max_len]) + 1).long()
            dist_a = _dist_a if i == 0 else torch.cat((dist_a, _dist_a), dim=0)
            dist_b = _dist_b if i == 0 else torch.cat((dist_b, _dist_b), dim=0)
            add = torch.add(dist_a, dist_b)
            if torch.gt(torch.sum(add), numels): break
        dist_len = torch.argmax((torch.cumsum(add, 0) >= numels).float()) + 1
        if dist_len % 2: dist_len += 1
        repeats = torch.cat((dist_a[:dist_len], dist_b[:dist_len]), -1).flatten()
        zot = zot.repeat(dist_len)
        # Expand the alternating False/True pattern by the sampled run lengths, then trim and reshape.
        mask = torch.repeat_interleave(zot, repeats)[:numels].reshape(n_masks, dims, mask_len)
    else:
        # Stateless masking: each step is masked independently with probability `r`.
        probs = torch.tensor(r, device=device)
        mask = Binomial(1, probs).sample((n_masks, dims, mask_len)).bool()
    if sync: mask = mask.repeat(1, mask_dims, 1)  # apply the same mask to every variable/channel
    return mask
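
# Example usage (a minimal sketch, not part of the original code): the batch shape and the
# wrapper name `_demo_subsequence_mask` are assumptions for illustration only.
def _demo_subsequence_mask():
    x = torch.randn(8, 3, 96)                        # hypothetical batch: 8 samples, 3 variables, 96 steps
    mask = create_subsequence_mask(x, r=.15, lm=3)   # ~15% of steps masked, in runs of mean length 3
    x_masked = x.masked_fill(mask, 0.)               # e.g. zero out the masked positions
    print(mask.shape, mask.float().mean().item())
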
def create_future_mask(o, r=.15, sync=False):
    """Build a boolean mask where roughly a fraction `r` of the steps is masked as one
    contiguous block along the time axis."""
    if o.ndim == 2: o = o[None]
    n_masks, mask_dims, mask_len = o.shape
    if sync == 'random': sync = random.random() > .5
    dims = 1 if sync else mask_dims
    probs = torch.tensor(r, device=o.device)
    mask = Binomial(1, probs).sample((n_masks, dims, mask_len))
    if sync: mask = mask.repeat(1, mask_dims, 1)  # apply the same mask to every variable/channel
    # Sorting each row groups all masked positions into a single contiguous block along dim -1.
    mask = torch.sort(mask, dim=-1, descending=True)[0].bool()
    return mask
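
# Example usage (a minimal sketch, not part of the original code): the tensor shape and the
# wrapper name `_demo_future_mask` are assumptions for illustration only.
def _demo_future_mask():
    x = torch.randn(8, 3, 96)                        # hypothetical batch: 8 samples, 3 variables, 96 steps
    mask = create_future_mask(x, r=.2, sync=True)    # same contiguous block masked across all 3 variables
    print(mask.shape, mask.float().mean().item())
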
# The two methods below belong to a repeat-copy dataset class (not shown here); they rely on
# self.seq_width, self.min_seq_len, self.max_seq_len, self.min_repeat, self.max_repeat and self.normalise.
def get_sample_wlen(self, seq_len, bs=1):
    """Generate one repeat-copy sample of length `seq_len` with batch size `bs`."""
    rep = torch.randint(self.min_repeat, self.max_repeat, (1,), dtype=torch.long).item()
    prob = 0.5 * torch.ones([seq_len, bs, self.seq_width], dtype=torch.float64)
    seq = Binomial(1, prob).sample()  # random binary sequence to be copied
    # Input: start-delimiter step + data + repeat-count step; two extra channels
    # (delimiter and repeat count) on top of the data bits.
    input_seq = torch.zeros([seq_len + 2, bs, self.seq_width + 2])
    input_seq[0, :, self.seq_width] = 1.0  # start delimiter
    input_seq[1:seq_len + 1, :, :self.seq_width] = seq
    input_seq[seq_len + 1, :, self.seq_width + 1] = self.normalise(rep)  # normalised repeat count
    # Target: the sequence repeated `rep` times, followed by an end marker.
    target_seq = torch.zeros([seq_len * rep + 1, bs, self.seq_width + 1])
    target_seq[:seq_len * rep, :, :self.seq_width] = seq.repeat(rep, 1, 1)
    target_seq[seq_len * rep, :, self.seq_width] = 1.0  # end marker
    return {'input': input_seq, 'target': target_seq}
def __getitem__(self, idx):
    # `idx` only acts as a counter while generating batches; each call returns a freshly sampled item.
    seq_len = torch.randint(self.min_seq_len, self.max_seq_len, (1,), dtype=torch.long).item()
    rep = torch.randint(self.min_repeat, self.max_repeat, (1,), dtype=torch.long).item()
    prob = 0.5 * torch.ones([seq_len, self.seq_width], dtype=torch.float64)
    seq = Binomial(1, prob).sample()  # random binary sequence to be copied
    # Input: start-delimiter step + data + repeat-count step; two extra channels
    # (delimiter and repeat count) on top of the data bits.
    input_seq = torch.zeros([seq_len + 2, self.seq_width + 2])
    input_seq[0, self.seq_width] = 1.0  # start delimiter
    input_seq[1:seq_len + 1, :self.seq_width] = seq
    input_seq[seq_len + 1, self.seq_width + 1] = self.normalise(rep)  # normalised repeat count
    # Target: the sequence repeated `rep` times, followed by an end marker.
    target_seq = torch.zeros([seq_len * rep + 1, self.seq_width + 1])
    target_seq[:seq_len * rep, :self.seq_width] = seq.repeat(rep, 1)
    target_seq[seq_len * rep, self.seq_width] = 1.0  # end marker
    return {'input': input_seq, 'target': target_seq}
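
# A minimal sketch (not part of the original code) of the enclosing dataset class that the two
# methods above assume. The class name `RepeatCopyTask`, its default parameters, and the linear
# `normalise` scheme are assumptions for illustration only.
class RepeatCopyTask:
    def __init__(self, seq_width=8, min_seq_len=1, max_seq_len=10, min_repeat=1, max_repeat=10):
        self.seq_width = seq_width
        self.min_seq_len, self.max_seq_len = min_seq_len, max_seq_len
        self.min_repeat, self.max_repeat = min_repeat, max_repeat

    def normalise(self, rep):
        # Scale the raw repeat count into [0, 1] so it can share the input encoding.
        return rep / self.max_repeat

    # Attach the module-level sample generators above as methods of this class.
    get_sample_wlen = get_sample_wlen
    __getitem__ = __getitem__


if __name__ == '__main__':
    # Example usage (also an assumption): draw one unbatched item and one batched sample.
    task = RepeatCopyTask()
    item = task[0]                                   # uses __getitem__
    sample = task.get_sample_wlen(seq_len=5, bs=2)
    print(item['input'].shape, item['target'].shape)
    print(sample['input'].shape, sample['target'].shape)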