def forward(self, img, volatile=False):
    """Run ``self.lstm`` down the H axis of each image column; keep the last step.

    The input is laid out BDWH (batch, depth, width, height). Every one of the
    ``b * w`` (batch, column) pairs is fed to the LSTM as an independent
    sequence of length ``h``. Only the output at the final height step is kept.

    Args:
        img: 4-D tensor/Variable of size (b, d, w, h).
        volatile: forwarded to the zero initial-state Variables (legacy,
            pre-0.4 PyTorch flag for graph-free inference).

    Returns:
        Contiguous tensor of size (b, self.noutput, w).
    """
    b, d, w, h = img.size()
    bs = b * w
    # BDWH -> HBWD -> HBsD: height becomes the time axis and each
    # (batch, column) pair becomes one entry of the LSTM batch.
    seq = img.permute(3, 0, 2, 1).contiguous().view(h, bs, d)
    # Leading state dim of 1: presumably a single-layer, unidirectional
    # LSTM -- confirm against the constructor.
    # helpers.typeas presumably casts the zeros to img's tensor type.
    h0 = Variable(helpers.typeas(torch.zeros(1, bs, self.noutput), img),
                  volatile=volatile)
    c0 = Variable(helpers.typeas(torch.zeros(1, bs, self.noutput), img),
                  volatile=volatile)
    # HBsD -> HBsD
    assert seq.size() == (h, bs, d), (seq.size(), (h, bs, d))
    post_lstm, _ = self.lstm(seq, (h0, c0))
    assert post_lstm.size() == (h, bs, self.noutput), \
        (post_lstm.size(), (h, bs, self.noutput))
    # HBsD -> BsD -> BWD: select the output of the last time step only.
    final = post_lstm.select(0, h - 1).view(b, w, self.noutput)
    assert final.size() == (b, w, self.noutput), \
        (final.size(), (b, w, self.noutput))
    # BWD -> BDW
    final = final.permute(0, 2, 1).contiguous()
    # The message must report the same shape that is being checked.
    expected = (b, self.noutput, w)
    assert final.size() == expected, (final.size(), expected)
    return final
def forward(self, seq, volatile=False):
    """Apply ``self.lstm`` to a BDL sequence batch and return the full output.

    Args:
        seq: sequence batch in BDL layout; ``bdl2lbd`` presumably reorders it
            to LBD (length, batch, depth) -- the depth must equal
            ``self.ninput``.
        volatile: forwarded to the zero initial-state Variables (legacy,
            pre-0.4 PyTorch flag).

    Returns:
        The LSTM output sequence converted back with ``lbd2bdl``. The feature
        size presumably equals ``self.noutput * self.ndir`` -- confirm.
    """
    lbd = bdl2lbd(seq)
    _, batch, depth = lbd.size()
    assert depth == self.ninput, lbd.size()

    def zero_state():
        # One state slot per direction (self.ndir); values start at zero.
        blank = torch.zeros(self.ndir, batch, self.noutput)
        return Variable(helpers.typeas(blank, lbd), volatile=volatile)

    hidden = zero_state()
    cell = zero_state()
    outputs, _ = self.lstm(lbd, (hidden, cell))
    return lbd2bdl(outputs)
def forward(self, seq):
    """Run ``self.lstm`` over a BDL sequence batch; return the last step's output.

    Graph construction is disabled (``volatile``) when ``seq`` is a plain
    tensor, or when it is a volatile Variable (legacy, pre-0.4 PyTorch).

    Returns:
        Tensor of size (batch, self.noutput) taken from the final time step.
    """
    volatile = seq.volatile if isinstance(seq, Variable) else True
    steps = bdl2lbd(seq)
    length, batch, depth = steps.size()
    assert depth == self.ninput, (depth, self.ninput)

    def zero_state():
        # Leading dim of 1: presumably a single-layer, unidirectional LSTM.
        blank = torch.zeros(1, batch, self.noutput)
        return Variable(helpers.typeas(blank, steps), volatile=volatile)

    h0 = zero_state()
    c0 = zero_state()
    assert steps.size() == (length, batch, depth)
    outputs, _ = self.lstm(steps, (h0, c0))
    assert outputs.size() == (length, batch, self.noutput)
    last = outputs.select(0, length - 1)
    return last.view(batch, self.noutput)
def forward(self, img):
    """Run ``self.lstm`` along the W axis of every image row.

    The input is laid out BDHW. Each of the ``h * b`` rows is one sequence of
    length ``w``. Graph construction is disabled (``volatile``) for plain
    tensors or volatile Variables (legacy, pre-0.4 PyTorch).

    Returns:
        Tensor of size (b, self.noutput * self.ndir, h, w); it is a permuted
        view and is not made contiguous.
    """
    volatile = img.volatile if isinstance(img, Variable) else True
    b, d, h, w = img.size()
    nrows = h * b
    # BDHW -> WHBD -> WB'D, where B' = h * b
    seq = img.permute(3, 2, 0, 1).contiguous().view(w, nrows, d)
    state_shape = (self.ndir, nrows, self.noutput)
    h0 = Variable(helpers.typeas(torch.zeros(*state_shape), img),
                  volatile=volatile)
    c0 = Variable(helpers.typeas(torch.zeros(*state_shape), img),
                  volatile=volatile)
    seqresult, _ = self.lstm(seq, (h0, c0))
    # WB'D' -> WHBD' -> BD'HW
    nfeatures = self.noutput * self.ndir
    unfolded = seqresult.view(w, h, b, nfeatures)
    return unfolded.permute(2, 3, 1, 0)
def ctc_align(prob, target):
    """Perform CTC alignment on torch sequence batches (using ocrolstm).

    Inputs are in BDL format. Both must pass ``dlh.sequence_is_normalized``,
    have the same batch size and the same D, and ``prob`` must be at least
    as long (L) as ``target``.

    Returns:
        The alignment produced by ``cctc.ctc_align_targets_batch``, permuted
        from BLD back to BDL and passed through ``dlh.typeas`` with ``prob``
        (presumably cast to prob's tensor type -- confirm).
    """
    import cctc
    assert dlh.sequence_is_normalized(prob), prob
    assert dlh.sequence_is_normalized(target), target
    # inputs are BDL; the C extension works on contiguous CPU BLD tensors
    prob_ = dlh.novar(prob).permute(0, 2, 1).cpu().contiguous()
    target_ = dlh.novar(target).permute(0, 2, 1).cpu().contiguous()
    # prob_ and target_ are both BLD now
    assert prob_.size(0) == target_.size(0), (prob_.size(), target_.size())
    assert prob_.size(2) == target_.size(2), (prob_.size(), target_.size())
    assert prob_.size(1) >= target_.size(1), (prob_.size(), target_.size())
    # Output buffer handed to cctc, which presumably resizes and fills it
    # (it is permuted as 3-D below). Use zeros rather than rand: the initial
    # values are irrelevant, and torch.rand would advance the global RNG on
    # every call.
    result = torch.zeros(1)
    cctc.ctc_align_targets_batch(result, prob_, target_)
    return dlh.typeas(result.permute(0, 2, 1).contiguous(), prob)