def forward(self, token_ids: torch.Tensor, entity_mask: torch.IntTensor, verb_mask: torch.IntTensor,
            loc_mask: torch.IntTensor, gold_loc_seq: torch.IntTensor, gold_state_seq: torch.IntTensor,
            num_cands: torch.IntTensor, sentence_mask: torch.IntTensor, cpnet_triples: List,
            state_rel_labels: torch.IntTensor, loc_rel_labels: torch.IntTensor):
    """
    Args:
        token_ids: size (batch * max_wiki, max_ctx_tokens)
        *_mask: size (batch, max_sents, max_tokens)
        loc_mask: size (batch, max_cands, max_sents + 1, max_tokens), +1 for location 0
        gold_loc_seq: size (batch, max_sents + 1), +1 for the initial location (location 0)
        gold_state_seq: size (batch, max_sents)
        state_rel_labels: size (batch, max_sents, max_cpnet)
        loc_rel_labels: size (batch, max_sents, max_cpnet)
        num_cands: size (batch,)
    """
    assert entity_mask.size(-2) == verb_mask.size(-2) == loc_mask.size(-2) - 1 \
        == gold_state_seq.size(-1) == gold_loc_seq.size(-1) - 1
    assert entity_mask.size(-1) == verb_mask.size(-1) == loc_mask.size(-1)

    batch_size = entity_mask.size(0)
    max_tokens = entity_mask.size(-1)
    max_sents = gold_state_seq.size(-1)
    max_cands = loc_mask.size(-3)

    attention_mask = (token_ids != self.plm_tokenizer.pad_token_id).to(torch.int)
    plm_outputs = self.embed_encoder(token_ids, attention_mask=attention_mask)
    embeddings = plm_outputs[0]  # hidden states of the last layer, (batch, max_tokens, plm_hidden_size)

    token_rep, _ = self.TokenEncoder(embeddings)  # (batch, max_tokens, 2 * hidden_size)
    token_rep = self.Dropout(token_rep)
    assert token_rep.size() == (batch_size, max_tokens, 2 * self.hidden_size)

    cpnet_rep = self.CpnetEncoder(cpnet_triples, tokenizer=self.plm_tokenizer, encoder=self.cpnet_encoder)

    # state change prediction
    # size (batch, max_sents, NUM_STATES)
    tag_logits, state_attn_probs = self.StateTracker(encoder_out=token_rep, entity_mask=entity_mask,
                                                     verb_mask=verb_mask, sentence_mask=sentence_mask,
                                                     cpnet_triples=cpnet_triples, cpnet_rep=cpnet_rep)
    tag_mask = (gold_state_seq != PAD_STATE)  # mask the padded positions so they do not count in the loss
    log_likelihood = self.CRFLayer(emissions=tag_logits, tags=gold_state_seq.long(),
                                   mask=tag_mask, reduction='token_mean')
    state_loss = -log_likelihood  # state classification loss is the negative log-likelihood

    pred_state_seq = self.CRFLayer.decode(emissions=tag_logits, mask=tag_mask)
    assert len(pred_state_seq) == batch_size
    correct_state_pred, total_state_pred = compute_state_accuracy(pred=pred_state_seq,
                                                                  gold=gold_state_seq.tolist(),
                                                                  pad_value=PAD_STATE)

    # location prediction
    # size (batch, max_cands, max_sents + 1)
    empty_mask = torch.zeros((batch_size, 1, max_tokens), dtype=torch.int)
    if self.use_cuda:
        empty_mask = empty_mask.cuda()
    entity_mask = torch.cat([empty_mask, entity_mask], dim=1)
    loc_logits, loc_attn_probs = self.LocationPredictor(encoder_out=token_rep, entity_mask=entity_mask,
                                                        loc_mask=loc_mask, sentence_mask=sentence_mask,
                                                        cpnet_triples=cpnet_triples, cpnet_rep=cpnet_rep)
    loc_logits = loc_logits.transpose(-1, -2)  # size (batch, max_sents + 1, max_cands)

    masked_loc_logits = self.mask_loc_logits(loc_logits=loc_logits,
                                             num_cands=num_cands)  # (batch, max_sents + 1, max_cands + 1)
    masked_gold_loc_seq = self.mask_undefined_loc(gold_loc_seq=gold_loc_seq,
                                                  mask_value=PAD_LOC)  # (batch, max_sents + 1)
    loc_loss = self.CrossEntropy(input=masked_loc_logits.view(batch_size * (max_sents + 1), max_cands + 1),
                                 target=masked_gold_loc_seq.view(batch_size * (max_sents + 1)).long())
    correct_loc_pred, total_loc_pred = compute_loc_accuracy(logits=masked_loc_logits,
                                                            gold=masked_gold_loc_seq,
                                                            pad_value=PAD_LOC)

    if loc_attn_probs is not None:
        loc_attn_probs = self.get_gold_attn_probs(loc_attn_probs, gold_loc_seq)
    attn_loss, total_attn_pred = self.get_attn_loss(state_attn_probs, loc_attn_probs,
                                                    state_rel_labels, loc_rel_labels)

    if self.is_test:  # inference
        pred_loc_seq = get_pred_loc(loc_logits=masked_loc_logits, gold_loc_seq=gold_loc_seq)
        return pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred, \
            correct_loc_pred, total_loc_pred

    return state_loss, loc_loss, attn_loss, correct_state_pred, total_state_pred, \
        correct_loc_pred, total_loc_pred, total_attn_pred
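
# NOTE: `mask_loc_logits`, `mask_undefined_loc`, `compute_loc_accuracy` and `get_pred_loc`
# are helpers defined elsewhere in this repository. As a reading aid, below is a minimal
# sketch of the candidate masking that `mask_loc_logits` performs, assuming padded
# candidate slots (index >= num_cands[b]) should be driven to -inf so softmax/argmax can
# never select them. The body is illustrative only, not the repository's implementation
# (which, judging from the cross-entropy call above, also appends one extra column for
# the unknown location).
def _mask_loc_logits_sketch(loc_logits: torch.Tensor, num_cands: torch.Tensor) -> torch.Tensor:
    # loc_logits: (batch, max_sents + 1, max_cands); num_cands: (batch,)
    max_cands = loc_logits.size(-1)
    cand_range = torch.arange(max_cands, device=loc_logits.device).unsqueeze(0)  # (1, max_cands)
    pad_slots = cand_range >= num_cands.unsqueeze(1)  # (batch, max_cands), True on padding slots
    # broadcast the padding grid over the sentence dimension and mask
    return loc_logits.masked_fill(pad_slots.unsqueeze(1), float('-inf'))
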
def forward(self, char_paragraph: torch.Tensor, entity_mask: torch.IntTensor, verb_mask: torch.IntTensor,
            loc_mask: torch.IntTensor, gold_loc_seq: torch.IntTensor, gold_state_seq: torch.IntTensor,
            num_cands: torch.IntTensor):
    """
    Args:
        gold_loc_seq: size (batch, max_sents)
        gold_state_seq: size (batch, max_sents)
        num_cands: size (batch,)
    """
    assert entity_mask.size(-2) == verb_mask.size(-2) == loc_mask.size(-2) \
        == gold_state_seq.size(-1) == gold_loc_seq.size(-1)
    assert entity_mask.size(-1) == verb_mask.size(-1) == loc_mask.size(-1) == char_paragraph.size(-2)

    batch_size = char_paragraph.size(0)
    max_tokens = char_paragraph.size(1)
    max_sents = gold_state_seq.size(-1)
    max_cands = loc_mask.size(-3)

    embeddings = self.EmbeddingLayer(char_paragraph, verb_mask)  # (batch, max_tokens, embed_size)
    token_rep, _ = self.TokenEncoder(embeddings)  # (batch, max_tokens, 2 * hidden_size)
    token_rep = self.Dropout(token_rep)
    assert token_rep.size() == (batch_size, max_tokens, 2 * self.hidden_size)

    # state change prediction
    # size (batch, max_sents, NUM_STATES)
    tag_logits = self.StateTracker(encoder_out=token_rep, entity_mask=entity_mask, verb_mask=verb_mask)
    tag_mask = (gold_state_seq != PAD_STATE)  # mask the padded positions so they do not count in the loss
    log_likelihood = self.CRFLayer(emissions=tag_logits, tags=gold_state_seq.long(),
                                   mask=tag_mask, reduction='token_mean')
    state_loss = -log_likelihood  # state classification loss is the negative log-likelihood

    pred_state_seq = self.CRFLayer.decode(emissions=tag_logits, mask=tag_mask)
    assert len(pred_state_seq) == batch_size
    correct_state_pred, total_state_pred = compute_state_accuracy(pred=pred_state_seq,
                                                                  gold=gold_state_seq.tolist(),
                                                                  pad_value=PAD_STATE)

    # location prediction
    # size (batch, max_cands, max_sents)
    loc_logits = self.LocationPredictor(encoder_out=token_rep, entity_mask=entity_mask, loc_mask=loc_mask)
    loc_logits = loc_logits.transpose(-1, -2)  # size (batch, max_sents, max_cands)

    masked_loc_logits = self.mask_loc_logits(loc_logits=loc_logits,
                                             num_cands=num_cands)  # (batch, max_sents, max_cands)
    masked_gold_loc_seq = self.mask_undefined_loc(gold_loc_seq=gold_loc_seq,
                                                  mask_value=PAD_LOC)  # (batch, max_sents)
    loc_loss = self.CrossEntropy(input=masked_loc_logits.view(batch_size * max_sents, max_cands),
                                 target=masked_gold_loc_seq.view(batch_size * max_sents).long())
    correct_loc_pred, total_loc_pred = compute_loc_accuracy(logits=masked_loc_logits,
                                                            gold=masked_gold_loc_seq,
                                                            pad_value=PAD_LOC)

    if self.is_test:  # inference
        pred_loc_seq = get_pred_loc(loc_logits=masked_loc_logits, gold_loc_seq=gold_loc_seq)
        return pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred, \
            correct_loc_pred, total_loc_pred

    return state_loss, loc_loss, correct_state_pred, total_state_pred, correct_loc_pred, total_loc_pred
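
# For reference, a minimal sketch of the counting that `compute_state_accuracy` performs in
# both forward passes above, assuming it returns (num_correct, num_counted) over non-padded
# positions. Illustrative only; the repository's own helper is the authoritative version.
def _compute_state_accuracy_sketch(pred: List[List[int]], gold: List[List[int]], pad_value: int):
    correct, total = 0, 0
    for pred_seq, gold_seq in zip(pred, gold):
        # CRF decode already truncates each pred_seq at the mask boundary, so zip stops
        # before the trailing gold padding; the explicit pad check is purely defensive.
        for p, g in zip(pred_seq, gold_seq):
            if g == pad_value:
                continue
            total += 1
            correct += int(p == g)
    return correct, total  # (num_correct, num_counted)
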