def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
    """One gated graph-convolution step over every (receiver i, sender j) node pair.

    For each pair, the sender's hidden vector is projected with the slice of
    ``self.W_hid`` picked by the sign of ``edges[..., i, j]`` (3 slices: sign -1/0/+1).
    A bias looked up by the full edge label (``self.b_hid``) is then added. Each message
    is scaled by a sigmoid gate built the same way from ``self.W_gate`` / ``self.b_gate``.
    Messages are summed over senders and passed through ReLU and ``self.drop_node``.
    If ``self.ln`` is set, a residual ``+ input_t`` followed by ``self.ln`` is applied.

    Args:
        input_t: node features; shape comments below treat it as [*, L, D_in].
        edges: integer edge labels [*, L, L]. Non-zero marks an edge and 0 means no edge;
            labels are offset by ``conf.type_num`` to index the bias tables, so they
            presumably lie in [-type_num, type_num] -- TODO confirm.
        mask_t: per-node mask [*, L]; it zeroes out contributions from masked senders j.

    Returns:
        [*, L, D] updated node features.
    """
    cfg = self.conf
    hid_dim, n_type = cfg._isize, cfg.type_num
    seq_len = BK.get_shape(edges, -1)
    # sign of the label -> {0, 1, 2} selects one of the 3 projection slices
    dir_idx = edges.clamp(min=-1, max=1) + 1
    dir_idx_ex = dir_idx.unsqueeze(-1)
    type_idx = edges + n_type  # offset to positive!
    # ===== messages: msg[..., i, j, :] comes from node j =====
    proj = BK.matmul(input_t, self.W_hid).view(
        BK.get_shape(input_t)[:-1] + [3, hid_dim])  # [*, L, 3, D]
    # NOTE: expand() with 5 args assumes exactly one leading batch dim
    proj_pairs = proj.unsqueeze(-4).expand(-1, seq_len, -1, -1, -1)  # [*, L, L, 3, D]
    msg = BK.gather_first_dims(proj_pairs.contiguous(), dir_idx_ex, -2).squeeze(-2) \
        + self.b_hid[type_idx]  # [*, L, L, D]
    # ===== gates =====
    gate_logits = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
    gate_pairs = gate_logits.unsqueeze(-3).expand(-1, seq_len, -1, -1)  # [*, L, L, 3]
    gate = BK.sigmoid(gate_pairs.gather(-1, dir_idx_ex)
                      + self.b_gate[type_idx].unsqueeze(-1))  # [*, L, L, 1]
    # keep labelled edges plus self-loops (eye), and drop masked senders (last axis)
    keep = ((edges != 0) | (BK.eye(seq_len) > 0)).float() * mask_t.unsqueeze(-2)  # [*,L,L]
    gate = gate * keep.unsqueeze(-1)  # [*,L,L,1]
    # ===== aggregate over senders j =====
    out = self.drop_node(BK.relu((msg * gate).sum(-2)))  # [*, L, D]
    if self.ln is None:
        return out
    # add & norm
    return self.ln(out + input_t)
def pred(self, all_logprobs: List[BK.Expr], all_cfs: List[BK.Expr], **kwargs):
    """Return, per instance, the log-probs of the layer with the highest confidence.

    Args:
        all_logprobs: NL tensors, each [*, L], stacked into [*, NL, L].
        all_cfs: NL confidence scores, each [*], stacked into [*, NL].
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        [*, L] log-probs taken from the most confident layer.
    """
    stack_t = BK.stack(all_logprobs, -2)  # [*, NL, L]
    # Argmax on the raw scores. Sigmoid is monotonic, so it can never change the
    # winner, but it saturates to exactly 1.0 for large inputs (fp32 ~> 17, fp16 ~> 8).
    # That turned distinct confidences into ties that silently resolved to the first layer.
    cf_t = BK.stack(all_cfs, -1)  # [*, NL]
    _, _lidxes = cf_t.max(-1, keepdim=True)  # [*, 1]
    ret_t = BK.gather_first_dims(stack_t, _lidxes, -2).squeeze(-2)  # [*, L]
    return ret_t
def pred(self, all_logprobs: List[BK.Expr], all_cfs: List[BK.Expr], **kwargs):
    """Combine per-layer predictions into a single [*, L] log-prob tensor.

    The confidences from ``self._merge_cf_f(all_cfs)`` ([*, NL]) are used in one of two ways:
      - ``conf.pred_argmax``: take the log-probs of the most confident layer;
      - otherwise: a confidence-weighted sum over the NL layers, either in
        probability space (``conf.pred_mix_probs``, then re-logged) or directly
        on the log-probs. Whether the weights sum to 1 depends on
        ``_merge_cf_f`` -- not visible here.

    Args:
        all_logprobs: NL tensors, each [*, L].
        all_cfs: NL per-layer confidence inputs passed to ``self._merge_cf_f``.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        [*, L] combined log-probs.
    """
    conf: IdecHelperCWConf = self.conf
    # --
    stack_t = BK.stack(all_logprobs, -2)  # [*, NL, L]
    cf_t = self._merge_cf_f(all_cfs)  # [*, NL]
    if conf.pred_argmax:
        _, _lidxes = cf_t.max(-1, keepdim=True)  # [*, 1]
        return BK.gather_first_dims(stack_t, _lidxes, -2).squeeze(-2)  # [*, L]
    w_t = cf_t.unsqueeze(-1)  # [*, NL, 1], broadcast over L
    # Reduce over the layer axis NL (-2), not over L (-1). The old sum(-1) returned
    # [*, NL], which contradicted the documented [*, L] output.
    if conf.pred_mix_probs:
        prob_t = (stack_t.exp() * w_t).sum(-2)  # [*, L]
        return (prob_t + 1e-6).log()  # eps guards log(0)
    return (stack_t * w_t).sum(-2)  # [*, L]