Beispiel #1
0
 def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
     """Run one edge-aware, gated aggregation step over the nodes.

     Each (row, column) pair selects one of three projections of the column
     node according to the sign of its edge label, adds a bias looked up by
     the label itself, and scales the result by a sigmoid gate. Pairs whose
     label is zero are gated off unless they lie on the diagonal.

     Shape notes follow the inline comments and are not checked at runtime;
     the fixed-arity ``expand`` calls assume a single leading batch
     dimension -- TODO confirm against callers.

     Args:
         input_t: Node representations, presumably ``[*, L, D]``.
         edges: Signed integer edge labels, presumably ``[*, L, L]``.
         mask_t: Validity mask over nodes, presumably ``[*, L]``.

     Returns:
         The aggregated node representations after ``self.drop_node``; when
         ``self.ln`` is set, the input is added back and ``self.ln`` applied.
     """
     hidden_dim = self.conf._isize
     num_types = self.conf.type_num
     seq_len = BK.get_shape(edges, -1)
     # --
     # two index views of the edges: the sign mapped onto {0, 1, 2}, and the
     # raw label shifted by type_num so it can index the bias tables
     sign_idx = (edges.clamp(min=-1, max=1) + 1).unsqueeze(-1)  # [*, L, L, 1]
     bias_idx = edges + num_types  # [*, L, L]
     # pairs allowed to contribute: real edges plus the diagonal, restricted
     # by the node mask
     allowed = (edges != 0) | (BK.eye(seq_len) > 0)
     pair_mask = allowed.float() * mask_t.unsqueeze(-2)  # [*, L, L]
     # gate branch
     gate_all = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
     gate_pairs = gate_all.unsqueeze(-3).expand(
         -1, seq_len, -1, -1)  # [*, L, L, 3]
     gate_logit = gate_pairs.gather(-1, sign_idx) \
         + self.b_gate[bias_idx].unsqueeze(-1)  # [*, L, L, 1]
     gate = BK.sigmoid(gate_logit) * pair_mask.unsqueeze(-1)  # [*, L, L, 1]
     # hidden branch
     proj_shape = BK.get_shape(input_t)[:-1] + [3, hidden_dim]
     hid_all = BK.matmul(input_t, self.W_hid).view(proj_shape)  # [*, L, 3, D]
     hid_pairs = hid_all.unsqueeze(-4).expand(
         -1, seq_len, -1, -1, -1)  # [*, L, L, 3, D]
     hid_picked = BK.gather_first_dims(
         hid_pairs.contiguous(), sign_idx, -2).squeeze(-2)  # [*, L, L, D]
     message = hid_picked + self.b_hid[bias_idx]  # [*, L, L, D]
     # combine: gated sum over the second-to-last axis, then dropout
     aggregated = BK.relu((message * gate).sum(-2))  # [*, L, D]
     out = self.drop_node(aggregated)
     # optional add & norm
     if self.ln is None:
         return out
     return self.ln(out + input_t)
Beispiel #2
0
 def pred(self, all_logprobs: List[BK.Expr], all_cfs: List[BK.Expr],
          **kwargs):
     """Return the log-probabilities of the layer with the highest confidence.

     Shape notes follow the inline comments and are not checked at runtime.

     Args:
         all_logprobs: Per-layer log-probabilities, presumably ``[*, L]``
             each; stacked on a new second-to-last axis.
         all_cfs: Per-layer confidence scores, presumably ``[*]`` each;
             stacked on a new last axis and passed through a sigmoid.
         **kwargs: Accepted and ignored.

     Returns:
         For every position, the entry of the stacked log-probabilities
         taken at the index of the maximal confidence (``[*, L]``).
     """
     stack_t = BK.stack(all_logprobs, -2)  # [*, NL, L]
     cf_t = BK.stack(all_cfs, -1).sigmoid()  # [*, NL]
     # only the position of the maximum is needed, not its value
     _, _lidxes = cf_t.max(-1, keepdim=True)  # [*, 1]
     ret_t = BK.gather_first_dims(stack_t, _lidxes,
                                  -2).squeeze(-2)  # [*, L]
     return ret_t
Beispiel #3
0
 def pred(self, all_logprobs: List[BK.Expr], all_cfs: List[BK.Expr],
          **kwargs):
     """Combine per-layer log-probabilities using the merged confidences.

     Three modes, chosen by the config:

     * ``pred_argmax``: take the log-probabilities of the layer with the
       highest merged confidence.
     * ``pred_mix_probs``: weight the per-layer probabilities by the
       confidences, sum over layers, and return the log of the mixture
       (``1e-6`` is added before the log to avoid ``log(0)``).
     * otherwise: weight the per-layer log-probabilities directly and sum
       over layers.

     Shape notes follow the inline comments and are not checked at runtime.

     Args:
         all_logprobs: Per-layer log-probabilities, presumably ``[*, L]``
             each; stacked on a new second-to-last axis.
         all_cfs: Per-layer confidence scores, merged by
             ``self._merge_cf_f`` into what is presumably ``[*, NL]``.
         **kwargs: Accepted and ignored.

     Returns:
         The combined scores, ``[*, L]`` in every mode.
     """
     conf: IdecHelperCWConf = self.conf
     # --
     stack_t = BK.stack(all_logprobs, -2)  # [*, NL, L]
     cf_t = self._merge_cf_f(all_cfs)  # [*, NL]
     if conf.pred_argmax:
         _, _lidxes = cf_t.max(-1, keepdim=True)  # [*, 1]
         return BK.gather_first_dims(stack_t, _lidxes,
                                     -2).squeeze(-2)  # [*, L]
     # one weight per layer, broadcast over the label axis
     weight_t = cf_t.unsqueeze(-1)  # [*, NL, 1]
     # NOTE: the reduction must run over the layer axis (-2); reducing over
     # -1 would collapse the labels and yield [*, NL] instead of [*, L].
     if conf.pred_mix_probs:
         prob_t = (stack_t.exp() * weight_t).sum(-2)  # [*, L]
         return (prob_t + 1e-6).log()  # [*, L]
     return (stack_t * weight_t).sum(-2)  # [*, L]