Example #1
0
    def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):
        """Greedy decode: return per-position vocab distributions and argmax ids.

        enc / src_padding_mask come from the encoder; the ys_* tensors are
        seq_len x bsz token ids (input, template, segment, position).
        """
        seq_len, bsz = ys_inp.size()
        attn_mask = self.attn_mask(seq_len)
        # Sum the token embedding with positional and auxiliary embeddings.
        h = self.tok_embed(ys_inp) + self.pos_embed(ys_inp)
        h = h + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        h = self.emb_layer_norm(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        # Only pass a key-padding mask when padding is actually present.
        pad_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not pad_mask.any():
            pad_mask = None
        for layer in self.layers:
            h, _, _ = layer(h,
                            self_padding_mask=pad_mask,
                            self_attn_mask=attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask)
        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        probs = torch.softmax(self.out_proj(h), -1)
        pred_y = probs.max(-1)[1]
        return probs, pred_y
Example #2
0
    def forward(self, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        """Training step: label-smoothed loss plus token-accuracy statistics."""
        enc, src_padding_mask = self.encode(ys_tpl, ys_seg, ys_pos)
        seq_len, bsz = ys_inp.size()
        attn_mask = self.attn_mask(seq_len)
        h = (self.tok_embed(ys_inp) + self.pos_embed(ys_inp)
             + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg)
             + self.tok_embed(ys_pos))
        h = self.emb_layer_norm(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        # Padding positions are identified on the gold sequence.
        pad_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not pad_mask.any():
            pad_mask = None
        for layer in self.layers:
            h, _, _ = layer(h,
                            self_padding_mask=pad_mask,
                            self_attn_mask=attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask)

        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        pred = torch.softmax(self.out_proj(h), -1)

        # NOTE(review): "label_smotthing_loss" is misspelled at its definition
        # site elsewhere in the file; the call must match it.
        loss = self.label_smotthing_loss(pred, ys_truth, msk)

        pred_y = pred.max(-1)[1]
        tot_tokens = msk.float().sum().item()
        acc = (torch.eq(pred_y, ys_truth) * msk).float().sum().item()

        return (pred_y, ys_truth), loss, acc, tot_tokens, bsz
Example #3
0
    def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos, incremental_state=None):
        """Greedy decoding step with cached attention state.

        On the first call (incremental_state is None) the full prefix is
        processed with a causal mask; on later calls only the newest position
        is fed and each layer reuses its cached keys/values.
        Returns (probs, pred_y, incremental_state).
        """
        seq_len, bsz = ys_inp.size()
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None  # no padding present -> skip key-padding masking

        if incremental_state is None:
            # First step: attend over the whole prefix with a causal mask.
            self_attn_mask = self.attn_mask(seq_len)
            incremental_state = {}
        else:
            # Later steps: only the last position is new; history is cached.
            x = x[-1, :, :].unsqueeze(0)
            self_attn_mask = None

        for layer in self.layers:
            x, _ ,_ = layer.work_incremental(x, self_padding_mask=padding_mask, \
                                             self_attn_mask=self_attn_mask, \
                                             external_memories = enc, \
                                             external_padding_mask = src_padding_mask, \
                                             incremental_state = incremental_state)

        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)

        _, pred_y = probs.max(-1)
        return probs, pred_y, incremental_state
Example #4
0
    def forward(self, truth, inp, seg, msk, nxt_snt_flag):
        """BERT-style pretraining step: masked-LM loss + next-sentence loss."""
        seq_len, bsz = inp.size()
        h = self.tok_embed(inp) + self.seg_embed(seg) + self.pos_embed(inp)
        h = self.emb_norm(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        for layer in self.layers:
            h, _, _ = layer(h)

        # LM head: output projection tied to the input token embedding.
        y = self.one_more_norm(gelu(self.one_more(h)))
        log_probs = torch.log_softmax(
            F.linear(y, self.tok_embed.weight, self.out_proj_bias), -1)

        lm_loss = F.nll_loss(log_probs.view(seq_len * bsz, -1),
                             truth.view(-1),
                             reduction='none').view(seq_len, bsz)
        lm_loss = lm_loss.masked_select(msk)

        # Next-sentence prediction from the first position's hidden state.
        z = h[0]
        nxt_snt_pred = torch.sigmoid(self.nxt_snt_pred(z).squeeze(1))
        nxt_snt_acc = torch.eq(torch.gt(nxt_snt_pred, 0.5), nxt_snt_flag).float().sum().item()
        nxt_snt_loss = F.binary_cross_entropy(nxt_snt_pred, nxt_snt_flag.float(), reduction='none')

        tot_loss = lm_loss.mean() + nxt_snt_loss.mean()

        pred = log_probs.max(-1)[1]
        tot_tokens = msk.float().sum().item()
        pred = pred.masked_select(msk)
        gold = truth.masked_select(msk)
        acc = torch.eq(pred, gold).float().sum().item()
        return (pred, gold), tot_loss, acc, tot_tokens, nxt_snt_acc, bsz
Example #5
0
    def forward(self, x, kv = None,
                self_padding_mask = None, self_attn_mask = None,
                external_memories = None, external_padding_mask=None,
                need_weights = False):
        """Post-norm transformer layer: self-attention, optional cross-attention
        over external memories, and a position-wise feed-forward network.

        x: seq_len x bsz x embed_dim. Returns (x, self_attn, external_attn).
        """
        # Self-attention sub-layer (kv, if given, replaces x as key/value).
        keyval = x if kv is None else kv
        attn_out, self_attn = self.self_attn(query=x, key=keyval, value=keyval,
                                             key_padding_mask=self_padding_mask,
                                             attn_mask=self_attn_mask,
                                             need_weights=need_weights)
        attn_out = F.dropout(attn_out, p=self.dropout, training=self.training)
        x = self.attn_layer_norm(x + attn_out)

        # Optional cross-attention over external encoder memories.
        external_attn = None
        if self.with_external:
            ext_out, external_attn = self.external_attn(query=x,
                                                        key=external_memories,
                                                        value=external_memories,
                                                        key_padding_mask=external_padding_mask,
                                                        need_weights=need_weights)
            ext_out = F.dropout(ext_out, p=self.dropout, training=self.training)
            x = self.external_layer_norm(x + ext_out)

        # Position-wise feed-forward sub-layer.
        ff = gelu(self.fc1(x))
        ff = F.dropout(ff, p=self.dropout, training=self.training)
        ff = self.fc2(ff)
        ff = F.dropout(ff, p=self.dropout, training=self.training)
        x = self.ff_layer_norm(x + ff)

        return x, self_attn, external_attn
Example #6
0
    def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
            ys_pos, msk):
        """Evaluate negative log-likelihood and perplexity on gold targets."""
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        attn_mask = self.attn_mask(seq_len)
        h = (self.tok_embed(ys_inp) + self.pos_embed(ys_inp)
             + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg)
             + self.tok_embed(ys_pos))
        h = self.emb_layer_norm(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        # Padding positions come from the gold sequence.
        pad_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not pad_mask.any():
            pad_mask = None
        for layer in self.layers:
            h, _, _ = layer(h,
                            self_padding_mask=pad_mask,
                            self_attn_mask=attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask)

        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        pred = torch.softmax(self.out_proj(h), -1)
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return nll, ppl, bsz
Example #7
0
 def work(self, inp):
     """Encode a sequence of sentence vectors and return the last output."""
     sentence_num, inp_dim = inp.size()
     mask = self.attn_mask(sentence_num)
     h = inp.unsqueeze(1)  # add a batch axis of size 1
     for layer in self.layers:
         h, _, _ = layer(h, self_attn_mask=mask)
     h = self.one_more_layer_norm(gelu(self.one_more(h)))
     return h.squeeze(1)[-1:]
Example #8
0
class TransformerLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, optional cross-attention
    over external memories, and a position-wise feed-forward network."""

    def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True):
        super(TransformerLayer, self).__init__()
        self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
        self.fc1 = nn.Linear(embed_dim, ff_embed_dim)
        self.fc2 = nn.Linear(ff_embed_dim, embed_dim)
        self.attn_layer_norm = LayerNorm(embed_dim)
        self.ff_layer_norm = LayerNorm(embed_dim)
        self.with_external = with_external
        self.dropout = dropout
        if self.with_external:
            # Cross-attention over encoder memories, with its own LayerNorm.
            self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
            self.external_layer_norm = LayerNorm(embed_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the feed-forward weights (BERT-style std=0.02)."""
        nn.init.normal_(self.fc1.weight, std=0.02)
        nn.init.normal_(self.fc2.weight, std=0.02)
        nn.init.constant_(self.fc1.bias, 0.)
        nn.init.constant_(self.fc2.bias, 0.)

    def forward(self, x, kv = None,
                self_padding_mask = None, self_attn_mask = None,
                external_memories = None, external_padding_mask=None,
                need_weights = False):
        """Run one layer.

        x: seq_len x bsz x embed_dim; kv, if given, replaces x as the
        key/value source of self-attention.
        Returns (x, self_attn, external_attn).
        """
        residual = x
        x = self.attn_layer_norm(x)  # pre-norm
        if kv is None:
            x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)
        else:
            x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x

        if self.with_external:
            residual = x
            x = self.external_layer_norm(x)
            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
        else:
            external_attn = None

        residual = x
        # BUGFIX: the original line read "= self.ff_layer_norm(x)" — the
        # assignment target was missing (a SyntaxError). The pre-norm output
        # must be rebound to x before the feed-forward network.
        x = self.ff_layer_norm(x)
        x = gelu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x

        return x, self_attn, external_attn
Example #9
0
    def work_incremental(self, x, self_padding_mask, self_attn_mask, incremental_state):
        """Pre-norm layer step for incremental decoding (inference path:
        no dropout). Attention keys/values are cached in incremental_state.

        x: seq_len x bsz x embed_dim. Returns (x, self_attn, None).
        """
        residual = x
        x = self.attn_layer_norm(x)
        x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, incremental_state=incremental_state)
        x = residual + x

        residual = x
        x = self.ff_layer_norm(x)
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x

        return x, self_attn, None
    def forward(self, x):
        """Position-wise feed-forward block with residual and LayerNorm.

        :param x: input tensor — assumes [batch_size, seq_len, dim], since
            w1/w2 are applied on the transposed channel axis (TODO confirm)
        :return: tensor of the same shape as x
        """
        # Conv-style FFN: operate on the channel axis, then restore layout.
        inner = self.w2(gelu(self.w1(x.transpose(1, 2))))
        inner = self.dropout(inner.transpose(1, 2))
        # Residual connection followed by layer normalization.
        return self.layer_norm(x + inner)
Example #11
0
    def forward(self, batch):
        """Compute NSP logits and MLM logits for one pretraining batch."""
        input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next = batch

        enc = self.allenc(input_ids, input_mask, segment_ids)

        # Next-sentence head on the [CLS] (first) position.
        nsp = self.fc2(self.tanh(self.fc1(enc[:, 0])))

        # Gather hidden states at the masked positions, then decode to vocab.
        gather_idx = masked_pos[:, :, None].expand(-1, -1, enc.size(-1))
        h_masked = torch.gather(enc, 1, gather_idx)
        h_masked = self.norm(gelu(self.linear(h_masked)))
        mlm = self.decoder(h_masked)

        return nsp, mlm
Example #12
0
    def call(self, x, seg, training, mask=False):
        """BERT-style forward pass returning (nsp_logits, mlm_predictions).

        NOTE(review): re-tying the output projection to the input embedding
        via set_weights on every call looks expensive — presumably done to
        keep the tied weights in sync after each optimizer step; verify.
        """

        res = self.encode(x, seg, training, mask)

        nsp = self.NSPdense1(res[:, 0])
        nsp = self.NSPdense2(nsp)  # For next sentence prediction

        mlm = self.MLMdense1(res)
        mlm = self.layer_norm(gelu(mlm))  # for masked token prediction

        # Tie the reverse-embedding (vocab projection) weights to the
        # encoder's token-embedding matrix.
        self.reverseEmbeddings.set_weights(
            tf.reshape(self.encode.embeddings.embedding.get_weights(),
                       (1, self.hidden_size, self.input_vocab_size)))
        pred = self.reverseEmbeddings(mlm)

        return nsp, pred
Example #13
0
    def forward(self, inp, correct, future_hard, previous_hard, future_easy,
                previous_easy):
        """Encode sentence vectors and score ordering candidates via loss_fun."""
        seq_len, input_dim = inp.size()
        mask = self.attn_mask(seq_len)
        h = inp.unsqueeze(1)  # [sentence_num, 1, input_dim] — batch of one
        for layer in self.layers:
            h, _, _ = layer(h, self_attn_mask=mask)
        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        h = h.squeeze(1)
        # loss has shape [sentence_num]; acc1/acc2 are candidate accuracies.
        loss, acc1, acc2 = self.loss_fun(h, correct, future_hard,
                                         previous_hard, future_easy,
                                         previous_easy)
        return h, loss.mean(), acc1, acc2
Example #14
0
File: biglm.py Project: lipiji/Guyu
    def ppl(self, truth, inp, msk):
        """Per-batch accuracy, NLL and perplexity for the language model."""
        seq_len, bsz = inp.size()
        attn_mask = self.attn_mask(seq_len)
        h = self.emb_layer_norm(self.tok_embed(inp) + self.pos_embed(inp))
        pad_mask = torch.eq(truth, self.vocab.padding_idx)
        if not pad_mask.any():
            pad_mask = None
        for layer in self.layers:
            h, _, _ = layer(h, self_padding_mask=pad_mask, self_attn_mask=attn_mask)

        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        pred = torch.softmax(self.out_proj(h), -1)
        pred_y = pred.max(-1)[1]
        tot_tokens = msk.float().sum().item()
        acc = (torch.eq(pred_y, truth).float() * msk).sum().item()
        nll, ppl = self.nll_loss(pred, truth, msk)
        return acc, nll, ppl, tot_tokens, bsz
Example #15
0
    def work(self, inp):
        """Return per-position vocab distributions and greedy predictions."""
        seq_len, bsz = inp.size()
        attn_mask = self.attn_mask(seq_len)
        h = self.emb_layer_norm(self.tok_embed(inp) + self.pos_embed(inp))
        h = F.dropout(h, p=self.dropout, training=self.training)
        pad_mask = torch.eq(inp, self.vocab.padding_idx)
        if not pad_mask.any():
            pad_mask = None
        for layer in self.layers:
            h, _, _ = layer(h, self_padding_mask=pad_mask, self_attn_mask=attn_mask)

        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        probs = torch.softmax(self.out_proj(h), -1)
        pred_y = probs.max(-1)[1]
        return probs, pred_y
Example #16
0
    def work_incremental(self, x, self_padding_mask = None, self_attn_mask = None,
                         external_memories = None, external_padding_mask = None, incremental_state = None):
        """Post-norm layer step for incremental decoding (inference path:
        no dropout). Keys/values are cached in incremental_state.

        x: seq_len x bsz x embed_dim. Returns (x, self_attn, external_attn).
        """
        residual = x
        attn_out, self_attn = self.self_attn(query=x, key=x, value=x,
                                             key_padding_mask=self_padding_mask,
                                             attn_mask=self_attn_mask,
                                             incremental_state=incremental_state)
        x = self.attn_layer_norm(residual + attn_out)

        external_attn = None
        if self.with_external:
            residual = x
            ext_out, external_attn = self.external_attn(query=x,
                                                        key=external_memories,
                                                        value=external_memories,
                                                        key_padding_mask=external_padding_mask)
            x = self.external_layer_norm(residual + ext_out)

        residual = x
        ff = self.fc2(gelu(self.fc1(x)))
        x = self.ff_layer_norm(residual + ff)

        return x, self_attn, external_attn
Example #17
0
    def forward(self, truth, inp, seg, msk, nxt_snt_flag):
        """BERT pretraining step: masked-LM loss (computed on masked positions
        only) plus next-sentence-prediction loss.

        Returns ((pred, gold), tot_loss, acc, tot_tokens, nxt_snt_acc, bsz).
        """
        seq_len, bsz = inp.size()
        x = self.tok_embed(inp) + self.seg_embed(seg) + self.pos_embed(inp)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask)

        # Keep only the masked positions for the LM head.
        masked_x = x.masked_select(msk.unsqueeze(-1))
        masked_x = masked_x.view(-1, self.embed_dim)
        gold = truth.masked_select(msk)

        y = self.one_more_layer_norm(gelu(self.one_more(masked_x)))
        # Output projection is tied to the input embedding matrix.
        out_proj_weight = self.tok_embed.weight

        if self.approx is None:
            log_probs = torch.log_softmax(
                F.linear(y, out_proj_weight, self.out_proj_bias), -1)
        else:
            # Softmax approximation path — presumably adaptive/sampled
            # softmax; exact semantics depend on self.approx (defined elsewhere).
            log_probs = self.approx.log_prob(y)

        loss = F.nll_loss(log_probs, gold, reduction='mean')

        # Next-sentence prediction from the first position's hidden state.
        z = torch.tanh(self.one_more_nxt_snt(x[0]))
        nxt_snt_pred = torch.sigmoid(self.nxt_snt_pred(z).squeeze(1))
        nxt_snt_acc = torch.eq(torch.gt(nxt_snt_pred, 0.5),
                               nxt_snt_flag).float().sum().item()
        nxt_snt_loss = F.binary_cross_entropy(nxt_snt_pred,
                                              nxt_snt_flag.float(),
                                              reduction='mean')

        tot_loss = loss + nxt_snt_loss

        _, pred = log_probs.max(-1)
        tot_tokens = msk.float().sum().item()
        acc = torch.eq(pred, gold).float().sum().item()

        return (pred, gold), tot_loss, acc, tot_tokens, nxt_snt_acc, bsz
Example #18
0
    def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg,
                ys_pos, msk):
        """Training step: encode the template side, decode with cross-attention,
        and return label-smoothed loss plus accuracy / NLL / PPL statistics.
        """
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(
            ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(
            ys_truth, self.vocab.padding_idx
        )  # bool mask of padding positions (pad ids were already substituted in batchify)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(
                x,
                self_padding_mask=padding_mask,
                self_attn_mask=self_attn_mask,
                external_memories=enc,
                external_padding_mask=src_padding_mask,
            )

        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)

        # NOTE(review): "label_smotthing_loss" is misspelled at its definition
        # site elsewhere in the file; the call must match it.
        loss = self.label_smotthing_loss(pred, ys_truth, msk)

        _, pred_y = pred.max(-1)
        tot_tokens = msk.float().sum().item()  # total number of target tokens
        acc = (torch.eq(pred_y, ys_truth).float() *
               msk).sum().item()  # count of correct (non-masked) predictions

        # nll_loss picks the gold-label probability, negates its log and
        # averages — a standalone helper returning both loss and perplexity.
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return (pred_y, ys_truth), loss, acc, nll, ppl, tot_tokens, bsz
Example #19
0
    def forward(self, truth, inp, msk):
        """Language-model training step: NLL loss plus accuracy statistics."""
        seq_len, bsz = inp.size()
        attn_mask = self.attn_mask(seq_len)
        h = self.emb_layer_norm(self.tok_embed(inp) + self.pos_embed(inp))
        h = F.dropout(h, p=self.dropout, training=self.training)
        pad_mask = torch.eq(truth, self.vocab.padding_idx)
        if not pad_mask.any():
            pad_mask = None
        for layer in self.layers:
            h, _, _ = layer(h,
                            self_padding_mask=pad_mask,
                            self_attn_mask=attn_mask)

        h = self.one_more_layer_norm(gelu(self.one_more(h)))
        pred = torch.softmax(self.out_proj(h), -1)

        loss = self.nll_loss(pred, truth, msk)

        pred_y = pred.max(-1)[1]
        tot_tokens = msk.float().sum().item()
        acc = torch.eq(pred_y, truth).float().sum().item()

        return (pred_y, truth), loss, acc, tot_tokens, bsz
Example #20
0
    def call(self, x):
        """Apply two dense layers with a GELU activation in between."""
        return self.dense2(gelu(self.dense1(x)))
Example #21
0
 def forward(self, x):
     """Feed-forward: fc2(gelu(fc1(x)))."""
     return self.fc2(gelu(self.fc1(x)))
Example #22
0
    def forward(self,
                embs,
                toks,
                msks,
                s_msks,
                c_msks,
                dep_root_msk,
                imgs,
                i_msks,
                poses,
                pose_msks,
                i3d_rgb=None,
                face=None,
                face_msks=None,
                bbox_meta=None):
        '''
        Character grounding + character re-identification forward pass.
        Returns (pred_ground, pred_chreid).

        params:
          embs: [B x 1 x L]
          toks: [B x 1 x L]
          msks: [B x 1 x L]
          s_msks: [B x sN x L]
          c_msks: [B x sN x cN]
          dep_root_msk: [B x sN x L]
          imgs: [B x cN x hN x 3 x 224 x 224]
          i_msks: [B x cN x hN]
          poses: [B x cN x hN x 17 x 2]
          pose_msks: [B x cN x hN x 17]
          i3d_rgb: [B x cN x hN x 1024]
          face: [B x cN x hN x 512]
          face_msks: [B x cN x hN]
          bbox_meta_5: [B x cN x hN x 3]

        NOTE(review): ts_x/ta_x, face_x and meta_x are only assigned under
        their feature flags ('someone'/'action' in t_feats, 'face'/'meta' in
        v_feats) but are used unconditionally below — presumably those flags
        are always enabled in supported configs; verify against the defaults.
        '''
        B, sN, L = s_msks.shape
        cN, hN = imgs.size(1), imgs.size(2)
        # Text embedding
        embs, toks, msks = embs.view(-1, L), toks.view(-1, L), msks.view(
            -1, L)  # [B*sN x L]
        bert_x, _ = self.bert(embs,
                              toks,
                              msks,
                              output_all_encoded_layers=False)  # [B x L x 768]
        bert_x = gelu(bert_x).view(B, 1, L, 768)
        if 'someone' in self.t_feats:
            ts_x = bert_x * s_msks.view(B, sN, L, 1)  # [B x sN x L x 768]
            ts_x = torch.sum(ts_x, 2)  # [B x sN x 768]
        if 'action' in self.t_feats:
            ta_x = bert_x * dep_root_msk.view(B, sN, L,
                                              1)  # [B x sN x L x 768]
            ta_x = torch.sum(ta_x, 2)  # [B x sN x 768]
        tce_x = torch.cat([ta_x, ts_x], -1)  # [B x sN x sdim]

        # Text Projection
        s_a_x = self.s_a_proj(tce_x)  # [B x sN x H]
        s_s_x = self.s_s_proj(tce_x)  # [B x sN x H]

        # Auxiliary Gender Classifier
        if self.use_gender:
            g_x = self.gender_fc(ts_x)  # [B x sN x 1]
            self.gender_result = g_x.squeeze(-1)  # [B x sN]

        # Visual embedding: concatenate the enabled per-tubelet features.
        i_x = torch.zeros((B * cN * hN, 0), device=bbox_meta.device)
        if 'img' in self.v_feats:
            imgs = imgs.view(-1, 3, 224, 224)
            img_x = self.act_conv(imgs)  # [B*cN*hN x 2048]
            i_x = torch.cat((i_x, img_x), -1)  #[B*cN*hN x (2048 + 4)]
        if "i3d_rgb" in self.v_feats:
            i3d_x = i3d_rgb.view(-1, 1024)  # [B*cN*hN x 1024]
            i_x = torch.cat((i_x, i3d_x), -1)  # [B*cN*hN x num_ftrs]
        if 'face' in self.v_feats:
            face_x = face.view(-1, 512)  # [B*cN*hN x 512]

        # Image projection
        i_a_x = self.i_a_proj(i_x)  # [B*cN*hN x H]
        i_a_x = i_a_x.view((B, cN * hN, -1))  # [B x cN*hN x H]
        i_s_x = self.i_s_proj(face_x)  # [B*cN*hN x H]
        i_s_x = i_s_x.view((B, cN * hN, -1))  # [B x cN*hN x H]
        if 'meta' in self.v_feats:
            meta_x = self.meta_proj(bbox_meta.view(-1, 4))  # [B*cN*hN x 50]

        # Character Grounding: pair every sentence slot with every tubelet.
        s_a_x = s_a_x.unsqueeze(2).repeat(1, 1, cN * hN,
                                          1).view(-1, self.hidden_dim)
        i_a_x = i_a_x.unsqueeze(1).repeat(1, sN, 1,
                                          1).view(-1, self.hidden_dim)
        s_s_x = s_s_x.unsqueeze(2).repeat(1, 1, cN * hN,
                                          1).view(-1, self.hidden_dim)
        i_s_x = i_s_x.unsqueeze(1).repeat(1, sN, 1,
                                          1).view(-1, self.hidden_dim)
        if 'meta' in self.v_feats:
            meta_x = meta_x.view(B, 1, cN * hN,
                                 50).repeat(1, sN, 1, 1)  # B x sN x cN*hN x 50

        f_a_x = s_a_x * i_a_x
        f_a_x = f_a_x.view(B, sN, cN * hN,
                           self.hidden_dim)  # [B x sN x cN*hN x H]
        f_s_x = s_s_x * i_s_x
        f_s_x = f_s_x.view(B, sN, cN * hN,
                           self.hidden_dim)  # [B x sN x cN*hN x H]
        f_t_x = torch.cat((meta_x, f_a_x, f_s_x), -1)  # [B x sN x cN*hN x fH]
        f_x = self.fuse_fc(f_t_x)  # [B x sN x cN*hN x 1]
        f_x = f_x.squeeze(-1)  # [B x sN x cN*hN]

        # Masking
        pred_ground = f_x * i_msks.view(B, 1, cN * hN) * s_msks.sum(
            -1, keepdim=True)

        # Character Re-Identification
        pred_chreid = 0
        to_divide = 1
        self.text_mat, self.vid_mat, self.vgv_mat = None, None, None
        # Text Re-Id
        if 'text' in self.char_reid:
            text_reid_x = ts_x.unsqueeze(1) * ts_x.unsqueeze(
                2)  # [B x sN x sN x 768]
            text_reid_x = self.text_reid_fc(text_reid_x)
            text_mat = text_reid_x.squeeze(-1)  # [B x sN x sN]
            self.text_mat = text_mat
            pred_chreid += self.text_mat
            to_divide += 1
        # Person identity representation
        if 'visual' in self.char_reid:
            vid_mat = self.vid_model(imgs.view(B, -1, 3, 224, 224),
                                     i_msks.view(B, -1),
                                     poses.view(B, -1, 17, 2),
                                     pose_msks.view(B, -1, 17),
                                     face.view(B, -1, 512),
                                     face_msks.view(B,
                                                    -1))  # B x cN*hN x cN*hN

            self.vid_mat = vid_mat
            p_msks = c_msks.view(B, sN, cN, 1).repeat(1, 1, 1,
                                                      hN).view(B, sN, cN * hN)
            i_mm = i_msks.view(B, 1, cN * hN)  # [B x 1 x cN*hN]
            p_msks = p_msks * i_mm  # [B x sN x cN*hN] (1 if exist, 0 if not)
            p_att = masked_softmax(
                f_x * 5, p_msks)  #[B x sN x cN*hN], sharpen f_x little.

            vgv_mat = torch.matmul(p_att, vid_mat)
            vgv_mat = vgv_mat * (
                1 - p_msks) - 0.1 * p_msks  # Mask for someone in same clip
            vgv_mat = torch.matmul(vgv_mat,
                                   p_att.transpose(1, 2))  # [B x sN x sN]
            vgv_mat = torch.sigmoid(vgv_mat)  # [B x sN x sN]  0 ~ 1
            self.vgv_mat = vgv_mat

            pred_chreid = (pred_chreid + vgv_mat) / to_divide

        return pred_ground, pred_chreid