Ejemplo n.º 1
0
Archivo: MVP.py Proyecto: duyniem/tsai
def create_subsequence_mask(o, r=.15, lm=3, stateful=True, sync=False):
    """Create a boolean mask made of alternating True/False runs.

    Args:
        o: tensor whose shape and device the mask follows. A 2D input is
            treated as a single sample (a leading axis of size 1 is added),
            so the returned mask is always 3D.
        r: target fraction of True positions.
        lm: mean length of a True run (stateful mode only).
        stateful: if True, True and False runs alternate with geometrically
            distributed lengths: True runs average ``lm`` and False runs
            average ``lm * (1 - r) / r``. The expected True fraction is then
            ``r``, unless ``pu`` below hits its clip range for very small or
            very large ``r``. If False, every position is True independently
            with probability ``r``.
        sync: if True, one mask row is drawn per sample and repeated across
            axis 1. ``'random'`` picks True/False with equal probability.

    Returns:
        Bool tensor of shape ``(n_masks, mask_dims, mask_len)`` on ``o.device``.
    """
    device = o.device
    if o.ndim == 2: o = o[None]
    n_masks, mask_dims, mask_len = o.shape
    if sync == 'random': sync = random.random() > .5
    dims = 1 if sync else mask_dims
    if not stateful:
        probs = torch.tensor(r, device=device)
        mask = Binomial(1, probs).sample((n_masks, dims, mask_len)).bool()
    elif r <= 0:
        # pu is clipped at 1e-3, so the run model would still emit short True
        # runs; r <= 0 means nothing is masked (as in the Binomial path).
        mask = torch.zeros(n_masks, dims, mask_len, dtype=torch.bool, device=device)
    elif r >= 1:
        # pu is clipped at 1, which would leave False runs of length 1.
        mask = torch.ones(n_masks, dims, mask_len, dtype=torch.bool, device=device)
    else:
        numels = n_masks * dims * mask_len
        # Geometric(p) + 1 has mean 1/p: pm drives True runs, pu False runs.
        pm = torch.tensor([1 / lm], device=device)
        pu = torch.clip(pm * (r / max(1e-6, 1 - r)), 1e-3, 1)
        # Start in the True state with probability r, the stationary
        # probability of that state, so the first run is not biased.
        if random.random() > r:
            zot, proba_a, proba_b = torch.as_tensor([False, True], device=device), pu, pm
        else:
            zot, proba_a, proba_b = torch.as_tensor([True, False], device=device), pm, pu
        # Draw run lengths in chunks of ~2x the expected number of (a, b)
        # pairs, continuing until they cover every position (always at least
        # one chunk, so numels == 0 also works).
        chunk = max(1, 2 * math.ceil(numels / (1 / pm + 1 / pu).item()))
        runs_a, runs_b, covered = [], [], 0
        while True:
            _a = (Geometric(probs=proba_a).sample([chunk]) + 1).long().flatten()
            _b = (Geometric(probs=proba_b).sample([chunk]) + 1).long().flatten()
            runs_a.append(_a)
            runs_b.append(_b)
            covered += int((_a + _b).sum())
            if covered >= numels: break
        dist_a, dist_b = torch.cat(runs_a), torch.cat(runs_b)
        # Smallest number of pairs whose cumulative length reaches numels.
        n_pairs = int(torch.argmax((torch.cumsum(dist_a + dist_b, 0) >= numels).float())) + 1
        # Interleave as [a0, b0, a1, b1, ...] to line up with zot repeated
        # n_pairs times ([s0, s1, s0, s1, ...]).
        repeats = torch.stack((dist_a[:n_pairs], dist_b[:n_pairs]), dim=1).flatten()
        mask = torch.repeat_interleave(zot.repeat(n_pairs), repeats)[:numels]
        mask = mask.reshape(n_masks, dims, mask_len)
    if sync: mask = mask.repeat(1, mask_dims, 1)
    return mask
Ejemplo n.º 2
0
Archivo: MVP.py Proyecto: duyniem/tsai
def create_future_mask(o, r=.15, sync=False):
    if o.ndim == 2: o = o[None]
    n_masks, mask_dims, mask_len = o.shape
    if sync == 'random': sync = random.random() > .5
    dims = 1 if sync else mask_dims
    probs = torch.tensor(r, device=o.device)
    mask = Binomial(1, probs).sample((n_masks, dims, mask_len))
    if sync: mask = mask.repeat(1, mask_dims, 1)
    mask = torch.sort(mask,dim=-1, descending=True)[0].bool()
    return mask
Ejemplo n.º 3
0
    def get_sample_wlen(self, seq_len, bs=1):
        """Build a batch of ``bs`` samples of length ``seq_len`` sharing one repeat count.

        Args:
            seq_len: number of random bit vectors per sample.
            bs: batch size (axis 1 of both returned tensors).

        Returns:
            dict with two float32 tensors:

            * ``'input'``, shape ``(seq_len + 2, bs, seq_width + 2)``. Row 0
              has a delimiter flag in column ``seq_width``. Rows
              ``1..seq_len`` hold the bits in columns ``:seq_width``. The
              last row holds ``self.normalise(rep)`` in column
              ``seq_width + 1``.
            * ``'target'``, shape ``(seq_len * rep + 1, bs, seq_width + 1)``.
              It holds the bits tiled ``rep`` times along axis 0, then a
              delimiter flag in column ``seq_width`` of the last row.
        """
        # NOTE(review): torch.randint's upper bound is exclusive, so rep is
        # drawn from [min_repeat, max_repeat - 1]; confirm that is intended.
        n_rep = torch.randint(self.min_repeat, self.max_repeat, (1, ),
                              dtype=torch.long).item()
        width = self.seq_width
        # Every bit is 1 with probability 0.5 (one Bernoulli trial each).
        coin = torch.ones([seq_len, bs, width], dtype=torch.float64) / 2
        bits = Binomial(1, coin).sample()

        inp = torch.zeros([seq_len + 2, bs, width + 2])
        inp[0, :, width] = 1.0
        inp[1:seq_len + 1, :, :width] = bits
        inp[seq_len + 1, :, width + 1] = self.normalise(n_rep)

        out_len = seq_len * n_rep
        tgt = torch.zeros([out_len + 1, bs, width + 1])
        tgt[:out_len, :, :width] = bits.repeat(n_rep, 1, 1)
        tgt[out_len, :, width] = 1.0

        return {'input': inp, 'target': tgt}
Ejemplo n.º 4
0
    def __getitem__(self, idx):
        """Return a freshly generated random sample (``idx`` is not used).

        The sequence length is drawn from ``[min_seq_len, max_seq_len)`` and
        the repeat count ``rep`` from ``[min_repeat, max_repeat)``; both upper
        bounds are exclusive (torch.randint).

        Returns:
            dict with two unbatched float32 tensors:

            * ``'input'``, shape ``(seq_len + 2, seq_width + 2)``. Row 0 has
              a delimiter flag in column ``seq_width``. Rows ``1..seq_len``
              hold the bits. The last row holds ``self.normalise(rep)`` in
              column ``seq_width + 1``.
            * ``'target'``, shape ``(seq_len * rep + 1, seq_width + 1)``. It
              holds the bits tiled ``rep`` times, then a delimiter flag in
              column ``seq_width`` of the last row.
        """
        # idx is ignored: every call draws a brand-new random sample.
        length = torch.randint(self.min_seq_len, self.max_seq_len, (1, ),
                               dtype=torch.long).item()
        # NOTE(review): max_repeat itself is never drawn (exclusive bound).
        n_rep = torch.randint(self.min_repeat, self.max_repeat, (1, ),
                              dtype=torch.long).item()
        width = self.seq_width
        # Every bit is 1 with probability 0.5 (one Bernoulli trial each).
        coin = torch.ones([length, width], dtype=torch.float64) / 2
        bits = Binomial(1, coin).sample()

        inp = torch.zeros([length + 2, width + 2])
        inp[0, width] = 1.0
        inp[1:length + 1, :width] = bits
        inp[length + 1, width + 1] = self.normalise(n_rep)

        out_len = length * n_rep
        tgt = torch.zeros([out_len + 1, width + 1])
        tgt[:out_len, :width] = bits.repeat(n_rep, 1)
        tgt[out_len, width] = 1.0

        return {'input': inp, 'target': tgt}