コード例 #1
0
 def forward_anchor(self, decoder_output, encoder_output, encoder_mask):
     """Apply the anchor classifier to the gradient-scaled decoder output.

     Returns ``None`` when ``self.anchor_classifier`` is ``None``. Otherwise
     the decoder output is passed through ``scale_grad`` with the "anchor"
     entry of ``self.loss_weights`` and handed, together with the encoder
     output and encoder mask, to the anchor classifier.
     """
     classifier = self.anchor_classifier
     if classifier is None:
         return None

     # NOTE(review): scale_grad is defined elsewhere; presumably it leaves the
     # forward values untouched and only rescales the gradient - confirm.
     anchor_weight = self.loss_weights["anchor"]
     scaled_output = scale_grad(decoder_output, anchor_weight)
     return classifier(scaled_output, encoder_output, encoder_mask)
コード例 #2
0
ファイル: edge_classifier.py プロジェクト: brkmnd/DcrParser
    def forward(self, x, loss_weights):
        """Compute the enabled edge predictions from ``x``.

        Each of the three heads (presence, label, attribute) is evaluated only
        when its corresponding flag on ``self`` is truthy; a disabled head
        yields ``None``. Before a head is applied, ``x`` is passed through
        ``scale_grad`` with that head's entry in ``loss_weights``.

        Returns:
            A ``(presence, label, attribute)`` tuple.
        """
        presence = None
        label = None
        attribute = None

        if self.presence:
            presence_input = scale_grad(x, loss_weights["edge presence"])
            # Drop the trailing dimension; shape: (B, T, T)
            presence = self.edge_presence(presence_input).squeeze(-1)

        if self.label:
            label_input = scale_grad(x, loss_weights["edge label"])
            label = self.edge_label(label_input)  # shape: (B, T, T, O_1)

        if self.attribute:
            attribute_input = scale_grad(x, loss_weights["edge attribute"])
            attribute = self.edge_attribute(
                attribute_input)  # shape: (B, T, T, O_2)

        return presence, label, attribute
コード例 #3
0
    def forward_property(self, decoder_output):
        """Return per-property log-softmax scores for the decoder output.

        For every key in ``self.property_keys`` the decoder output is passed
        through ``scale_grad`` with the ``"property <key>"`` loss weight, fed
        to the matching classifier in ``self.property_classifier``, and
        normalised with ``log_softmax`` over the last dimension.

        Returns:
            A dict mapping each formatted property key to its log-softmax
            output.
        """
        log_probs = {}
        for key in self.property_keys:
            weight = self.loss_weights[f"property {key}"]
            scaled = scale_grad(decoder_output, weight)
            scores = self.property_classifier[key](scaled)
            log_probs[f"{key}"] = F.log_softmax(scores, dim=-1)
        return log_probs
コード例 #4
0
ファイル: ucca_head.py プロジェクト: brkmnd/DcrParser
 def forward_label(self, decoder_output, decoder_lens):
     """Return softmax-normalised label scores for the decoder output.

     Returns ``None`` when ``self.label_classifier`` is ``None``. Otherwise
     the decoder output is passed through ``scale_grad`` with the "label"
     loss weight, classified, and normalised with softmax over the last
     dimension. ``decoder_lens`` is accepted but not used here.
     """
     classifier = self.label_classifier
     if classifier is None:
         return None

     label_weight = self.loss_weights["label"]
     scores = classifier(scale_grad(decoder_output, label_weight))
     return torch.softmax(scores, dim=-1)
コード例 #5
0
 def forward_property(self, decoder_output):
     """Return the property classifier output with its last axis squeezed.

     Returns ``None`` when ``self.property_classifier`` is ``None``.
     Otherwise the decoder output is passed through ``scale_grad`` with the
     "property" loss weight before being classified.
     """
     classifier = self.property_classifier
     if classifier is None:
         return None

     property_weight = self.loss_weights["property"]
     scores = classifier(scale_grad(decoder_output, property_weight))
     return scores.squeeze(-1)
コード例 #6
0
 def forward_top(self, decoder_output):
     """Return the top classifier output with its last axis squeezed.

     Returns ``None`` when ``self.top_classifier`` is ``None``. Otherwise the
     decoder output is passed through ``scale_grad`` with the "top" loss
     weight before being classified.
     """
     classifier = self.top_classifier
     if classifier is None:
         return None

     top_weight = self.loss_weights["top"]
     scores = classifier(scale_grad(decoder_output, top_weight))
     return scores.squeeze(-1)
コード例 #7
0
 def forward_label(self, decoder_output, decoder_lens):
     """Run the label classifier on the gradient-scaled decoder output.

     Returns ``None`` when ``self.label_classifier`` is ``None``. Otherwise
     the classifier is called with the scaled decoder output, the decoder
     lengths, and the size of the scaled output along dimension 1.
     """
     classifier = self.label_classifier
     if classifier is None:
         return None

     label_weight = self.loss_weights["label"]
     scaled_output = scale_grad(decoder_output, label_weight)
     # The size is read from the scaled tensor, as the classifier's third
     # argument.
     return classifier(scaled_output, decoder_lens, scaled_output.size(1))