import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
    batch_size = h.size(0)
    sos = int2onehot(h.new_full((batch_size, 1), self.sos),
                     num_classes=self.label_vec_size).float()
    eos = int2onehot(h.new_full((batch_size, 1), self.eos),
                     num_classes=self.label_vec_size).float()
    hidden = None
    y_hats = list()
    attentions = list()
    in_mask = self.get_mask(h, x_seq_lens) if self.masked_attend else None
    # initial decoder input: <sos> one-hot concatenated with the first encoder frame
    x = torch.cat([sos, h.narrow(1, 0, 1)], dim=-1)
    y_hats_seq_lens = torch.ones((batch_size, ), dtype=torch.int) * self.max_seq_lens
    bi = torch.zeros((self.num_eos, batch_size, )).byte()
    if x.is_cuda:
        bi = bi.cuda()
    for t in range(self.max_seq_lens):
        s, hidden = self.rnns(x, hidden)
        s = self.norm(s)
        c, a = self.attention(s, h, in_mask)
        y_hat = self.chardist(torch.cat([s, c], dim=-1))
        y_hat = self.softmax(y_hat)
        y_hats.append(y_hat)
        attentions.append(a)
        # check num_eos consecutive eos occurrences
        bi[t % self.num_eos] = onehot2int(y_hat.squeeze()).eq(self.eos)
        ri = y_hats_seq_lens.gt(t)
        if bi.is_cuda:
            ri = ri.cuda()
        y_hats_seq_lens[bi.prod(dim=0, dtype=torch.uint8) * ri] = t + 1
        # early termination once every sample has finished
        if y_hats_seq_lens.le(t + 1).all():
            break
        if y is None or not self._is_sample_step():
            # non-sampling step: feed back the model's own prediction
            x = torch.cat([y_hat, c], dim=-1)
        elif t < y.size(1):
            # scheduled sampling step: feed the ground-truth label
            x = torch.cat([y.narrow(1, t, 1), c], dim=-1)
        else:
            x = torch.cat([eos, c], dim=-1)
    y_hats = torch.cat(y_hats, dim=1)
    attentions = torch.cat(attentions, dim=2)
    return y_hats, y_hats_seq_lens, attentions
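# The forward pass above relies on a get_mask helper (used when self.masked_attend
# is set) that is not shown in this excerpt. A minimal sketch of what it plausibly
# does, assuming a (batch, enc_len) mask with ones over valid encoder frames; the
# name and exact semantics are assumptions, not the actual implementation:
def get_mask(self, h, x_seq_lens):
    # h: (batch, enc_len, enc_dim); mask is 1 for valid frames, 0 for padding
    mask = h.new_zeros((h.size(0), h.size(1)))
    for b, length in enumerate(x_seq_lens.tolist()):
        mask[b, :length] = 1
    return mask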
def _train_forward(self, x, x_seq_lens, y, y_seq_lens):
    # keep only samples with x_seq_lens > y_seq_lens and y_seq_lens < max_seq_lens
    bi = x_seq_lens.gt(y_seq_lens) * y_seq_lens.lt(self.spell.max_seq_lens)
    if not bi.all():
        logger.warn("there are samples of x_seq_lens < y_seq_lens or y_seq_lens > max_seq_lens")
    x, x_seq_lens = x[bi], x_seq_lens[bi]
    # listen
    h = self.listen(x, x_seq_lens)
    # make ys from y including trailing eos
    eos_t = y.new_full((self.spell.num_eos, ), self.eos)
    ys = [torch.cat((yb, eos_t)) for yb in torch.split(y, y_seq_lens.tolist())]
    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=self.blank)
    ys, ys_seq_lens = ys[bi], y_seq_lens[bi] + self.spell.num_eos
    # soften the one-hot targets with a random floor and multiplicative noise
    floor = np.random.random_sample() * 0.1
    yss = int2onehot(ys, num_classes=self.label_vec_size, floor=floor).float()
    noise = torch.rand_like(yss) * 0.1
    yss = F.softmax(yss * noise, dim=-1)
    y_hats, y_hats_seq_lens, self.attentions = self.spell(h, x_seq_lens, yss, ys_seq_lens)
    # add regions to attentions
    self.regions = torch.IntTensor([(frames - 1, labels - 1)
                                    for frames, labels in zip(x_seq_lens, ys_seq_lens)])
    # match seq lens between y_hats and ys
    s1, s2 = y_hats.size(1), ys.size(1)
    if s1 < s2:
        dummy = y_hats.new_full((y_hats.size(0), s2 - s1, ), fill_value=self.blank, dtype=torch.int)
        dummy = int2onehot(dummy, num_classes=self.label_vec_size).float()
        y_hats = torch.cat([y_hats, dummy], dim=1)
        #y_hats = F.pad(y_hats, (0, 0, 0, s2 - s1))
    elif s1 > s2:
        ys = F.pad(ys, (0, s1 - s2), value=self.blank)
    y_hats = self.log(y_hats)
    return y_hats, y_hats_seq_lens, ys, ys_seq_lens
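# Hypothetical usage sketch, not part of this code base: _train_forward returns
# per-step log-probabilities and length-matched padded targets, so a simple
# negative log-likelihood over the valid target positions could look like the
# following. The loss actually used for training is not shown in this excerpt.
def _example_nll(model, x, x_seq_lens, y, y_seq_lens):
    y_hats, y_hats_seq_lens, ys, ys_seq_lens = model._train_forward(x, x_seq_lens, y, y_seq_lens)
    loss = 0.
    for b, length in enumerate(ys_seq_lens.tolist()):
        # pick the log-prob of the reference label at each valid decoding step
        ref = ys[b, :length].long().unsqueeze(1)
        loss = loss - y_hats[b, :length].gather(1, ref).mean()
    return loss / ys.size(0)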
def forward(self, x, x_seq_lens, y=None, y_seq_lens=None):
    # listener
    h, x_seq_lens = self.listen(x, x_seq_lens)
    # speller
    if self.training:
        if y_seq_lens.ge(self.max_seq_len).nonzero().numel():
            return None, None, None
        # change y to one-hot tensors, appending eos to each target
        eos = self.spell.eos
        eos_tensor = torch.cuda.IntTensor([eos, ]) if y.is_cuda else torch.IntTensor([eos, ])
        ys = [torch.cat([yb, eos_tensor]) for yb in torch.split(y, y_seq_lens.tolist())]
        ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=eos)
        # speller with teacher forcing rate
        if self._is_teacher_force():
            yss = int2onehot(ys, num_classes=self.label_vec_size).float() * 2. - 1.
            y_hats, y_hats_seq_lens, self.attentions = self.spell(h, yss)
        else:
            y_hats, y_hats_seq_lens, self.attentions = self.spell(h)
        # match seq lens between y_hats and ys
        s1, s2 = y_hats.size(1), ys.size(1)
        if s1 < s2:
            # append one-hot tensors of eos to y_hats
            dummy = y_hats.new_full((y_hats.size(0), s2 - s1, ), fill_value=eos)
            dummy = int2onehot(dummy, num_classes=self.label_vec_size).float()
            y_hats = torch.cat([y_hats, dummy], dim=1)
        elif s1 > s2:
            ys = F.pad(ys, (0, s1 - s2), value=eos)
        # return with seq lens
        return y_hats, y_hats_seq_lens, ys
    else:
        # inference: run the speller without targets
        y_hats, y_hats_seq_lens, _ = self.spell(h)
        y_hats = self.softmax(y_hats[:, :, :-2])
        # return with seq lens
        return y_hats, y_hats_seq_lens
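# _is_teacher_force here (and _is_sample_step in the Speller) are assumed to be
# simple Bernoulli draws against a configured ratio; they are not defined in this
# excerpt. A sketch under that assumption, with teacher_force_rate as a
# hypothetical attribute name:
def _is_teacher_force(self):
    return np.random.random_sample() < self.teacher_force_rate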
def target_to_loglikes(self, ys, label_lens):
    max_len = max(label_lens.tolist())
    num_classes = self.labeler.get_num_labels()
    # split the flat target vector per sample, prepend a blank (index 0) and pad to max_len
    ys_hat = [torch.cat((torch.zeros(1).int(), ys[s:s + l], torch.zeros(max_len - l).int()))
              for s, l in zip([0] + label_lens[:-1].cumsum(0).tolist(), label_lens.tolist())]
    # one-hot encode with a small floor so the subsequent log stays finite
    ys_hat = [int2onehot(z, num_classes, floor=1e-3) for z in ys_hat]
    ys_hat = torch.stack(ys_hat)
    ys_hat = torch.log(ys_hat)
    return ys_hat
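# int2onehot / onehot2int are project helpers used throughout this file but not
# defined in this excerpt. A plausible sketch, assuming "floor" spreads a small
# probability mass over the non-target classes so that torch.log stays finite;
# the real helpers may differ in detail:
def int2onehot(idx, num_classes, floor=0.):
    # idx: integer tensor of any shape; result has a trailing num_classes dimension
    onehot = torch.full(idx.shape + (num_classes, ), floor, device=idx.device)
    onehot.scatter_(-1, idx.long().unsqueeze(-1), 1. - floor * (num_classes - 1))
    return onehot

def onehot2int(onehot):
    # inverse mapping: argmax along the class dimension
    return onehot.argmax(dim=-1)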
def forward(self, h, y=None):
    batch_size = h.size(0)
    sos = int2onehot(h.new_full((batch_size, 1), self.sos),
                     num_classes=self.label_vec_size).float()
    x = torch.cat([sos, h.narrow(1, 0, 1)], dim=-1)
    hidden = [None] * self.rnn_num_layers
    y_hats = list()
    attentions = list()
    max_seq_len = self.max_seq_len if y is None else y.size(1)
    unit_len = torch.ones((batch_size, ))
    for i in range(max_seq_len):
        for l, rnn in enumerate(self.rnns):
            x, hidden[l] = rnn(x, unit_len, hidden[l])
        c, a = self.attention(x, h)
        y_hat = self.chardist(torch.cat([x, c], dim=-1))
        y_hats.append(y_hat)
        attentions.append(a)
        # stop iterating once eos occurs in every sample of the batch
        if not onehot2int(y_hat.squeeze()).ne(self.eos).nonzero().numel():
            break
        if y is None:
            x = torch.cat([y_hat, c], dim=-1)
        else:
            # teacher forcing
            x = torch.cat([y.narrow(1, i, 1), c], dim=-1)
    y_hats = torch.cat(y_hats, dim=1)
    attentions = torch.cat(attentions, dim=2)
    # per-sample length: position of the first eos, or max_seq_len if none occurs
    seq_lens = torch.full((batch_size, ), max_seq_len, dtype=torch.int)
    for b, y_hat in enumerate(y_hats):
        idx = onehot2int(y_hat).eq(self.eos).nonzero()
        if idx.numel():
            seq_lens[b] = idx[0][0]
    return y_hats, seq_lens, attentions
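# Hypothetical decoding sketch, not from this code base: convert the speller's
# outputs back into label index sequences using the returned per-sample lengths.
def decode_batch(speller, h):
    y_hats, seq_lens, attentions = speller(h)
    return [onehot2int(y_hats[b, :l]).tolist() for b, l in enumerate(seq_lens.tolist())]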