Beispiel #1
0
 def apply(self, x, mask=None, weights=None):
     """Encode ``x`` and attention-pool its states with ``self.numouts`` heads.

     The encoder output ``ret`` is read along its last axis as
     ``[scores of heads 0..numouts-1 | states]``: channel ``i`` holds the
     per-position scores of head ``i`` and the channels from ``numouts`` on
     are the states that get pooled. When ``self.bidir`` is set, the last axis
     is first split in half and each half is read with that layout; the two
     halves' scores are summed, and their pooled states are concatenated.

     Args:
         x: input, passed unchanged to ``self.enc``.
         mask: optional mask multiplied into the attention weights. If None,
             the encoder output's ``mask`` attribute is used when it exists;
             with neither, no masking is applied.
         weights: passed unchanged to ``self.enc``.

     Returns:
         If ``self.mode == "concat"``: the per-head pooled vectors joined
         along axis 1. If ``self.mode == "seq"``: the per-head pooled vectors
         stacked along a new axis 1 (one entry per head). For any other mode
         the encoder output itself is returned (only its first half when
         bidirectional) -- NOTE(review): probably unintended.
     """
     ret = self.enc(x, mask=mask, weights=weights)  # (batsize, seqlen, lastdim)
     # Prefer the explicit mask; otherwise use the one the encoder attached to
     # its output (SeqEncoder should attach it if all_outputs()).
     if mask is None:
         mask = getattr(ret, "mask", None)
     if self.bidir:
         # Floor division: the split point is a slice index and must be an
         # integer. True division (``/``, the Python 3 default) yields a float.
         half = ret.shape[2] // 2
         ret2 = ret[:, :, half:]  # second half, same [scores | states] layout
         ret = ret[:, :, :half]
     outs = []
     for i in range(self.numouts):
         selfweights = ret[:, :, i]  # (batsize, seqlen): scores of head i
         if self.bidir:
             selfweights = selfweights + ret2[:, :, i]
         # Presumably normalizes over the seqlen axis -- confirm in Softmax.
         selfweights = Softmax()(selfweights)
         if mask is not None:
             selfweights = selfweights * mask
         # Renormalize so the remaining (unmasked) weights sum to 1 per row.
         # NOTE(review): a row the mask zeroes entirely divides by zero here.
         selfweights = selfweights / T.sum(selfweights, axis=1).dimshuffle(
             0, "x")
         # Add a trailing broadcast axis so the weights scale every state channel.
         colweights = selfweights.dimshuffle(0, 1, "x")
         weightedstates = ret[:, :, self.numouts:] * colweights
         if self.bidir:
             weightedstates2 = ret2[:, :, self.numouts:] * colweights
             weightedstates = T.concatenate(
                 [weightedstates, weightedstates2], axis=2)
         outs.append(T.sum(weightedstates, axis=1))  # sum over seqlen
     if self.mode == "concat":
         ret = T.concatenate(outs, axis=1)
     elif self.mode == "seq":
         ret = T.concatenate([out.dimshuffle(0, "x", 1) for out in outs],
                             axis=1)
     return ret
Beispiel #2
0
 def apply(self, x, mask=None, weights=None):
     """Encode ``x`` and attention-pool its states with ``self.numouts`` heads.

     The encoder output ``ret`` is read along its last axis as
     ``[scores of heads 0..numouts-1 | states]``: channel ``i`` holds the
     per-position scores of head ``i`` and the channels from ``numouts`` on
     are the states that get pooled.

     Args:
         x: input, passed unchanged to ``self.enc``.
         mask: optional mask multiplied into the attention weights. If None,
             the encoder output's ``mask`` attribute is used when it exists;
             with neither, no masking is applied.
         weights: passed unchanged to ``self.enc``.

     Returns:
         If ``self.mode == "concat"``: the per-head pooled vectors joined
         along axis 1. If ``self.mode == "seq"``: the per-head pooled vectors
         stacked along a new axis 1 (one entry per head). For any other mode
         the encoder output itself is returned -- NOTE(review): probably
         unintended.
     """
     ret = self.enc(x, mask=mask, weights=weights)  # (batsize, seqlen, lastdim)
     # Prefer the explicit mask; otherwise use the one the encoder attached to
     # its output (SeqEncoder should attach it if all_outputs()). Reading
     # ``ret.mask`` unconditionally ignored the argument, raised when the
     # output had no mask, and multiplied by None when it was None.
     if mask is None:
         mask = getattr(ret, "mask", None)
     states = ret[:, :, self.numouts:]  # same slice for every head
     outs = []
     for i in range(self.numouts):
         # Scores of head i, (batsize, seqlen); presumably normalized over
         # seqlen by Softmax -- confirm.
         selfweights = Softmax()(ret[:, :, i])
         if mask is not None:
             selfweights = selfweights * mask
         # Renormalize so the remaining (unmasked) weights sum to 1 per row.
         # NOTE(review): a row the mask zeroes entirely divides by zero here.
         selfweights = selfweights / T.sum(selfweights, axis=1).dimshuffle(
             0, "x")
         # Trailing broadcast axis so the weights scale every state channel.
         weightedstates = states * selfweights.dimshuffle(0, 1, "x")
         outs.append(T.sum(weightedstates, axis=1))  # sum over seqlen
     if self.mode == "concat":
         ret = T.concatenate(outs, axis=1)
     elif self.mode == "seq":
         ret = T.concatenate([out.dimshuffle(0, "x", 1) for out in outs],
                             axis=1)
     return ret