コード例 #1
0
ファイル: tree_module.py プロジェクト: orybkin/video-gcp
    def loss(self, inputs, outputs):
        """Collect the training losses for this tree module.

        Args:
            inputs: Batch inputs, passed unchanged to ``self.get_node_loss`` and
                ``self.binding.loss``.
            outputs: Model outputs. ``outputs.tree.depth`` controls the early
                exit and ``outputs.entropy`` is fed to the entropy penalty.

        Returns:
            A plain empty ``{}`` when ``outputs.tree.depth == 0``. Otherwise an
            ``AttrDict`` built by ``update``-ing, in order, the node losses and
            the binding ("explaining") losses, plus an ``entropy`` entry.
        """
        # Depth-0 tree: nothing to supervise, bail out before any loss is built.
        if outputs.tree.depth == 0:
            return {}

        collected = AttrDict()

        # Node loss first, then the explaining loss from the binding module.
        # Each helper is invoked immediately before its result is merged.
        for loss_fn in (self.get_node_loss, self.binding.loss):
            collected.update(loss_fn(inputs, outputs))

        # Entropy penalty, scaled by self._hp.entropy_weight.
        entropy_penalty = PenaltyLoss(weight=self._hp.entropy_weight)
        collected.entropy = entropy_penalty(outputs.entropy)

        return collected
コード例 #2
0
ファイル: hedge.py プロジェクト: codeaudit/video-gcp
    def loss(self, inputs, model_output):
        """Collect the training losses for the hedged tree model.

        Args:
            inputs: Batch inputs, passed to the matching and loss helpers.
            model_output: Model outputs. ``model_output.tree.depth`` controls the
                early exit and ``model_output.entropy`` is fed to the entropy
                penalty. It may already contain ``gt_matching_dists`` if the
                matching was computed during the forward pass.

        Returns:
            A plain empty ``{}`` when ``model_output.tree.depth == 0``.
            Otherwise an ``AttrDict`` built by ``update``-ing, in order, the node
            losses, the extra losses and the matcher's explaining loss, plus an
            ``entropy`` entry.
        """
        if model_output.tree.depth == 0:
            return {}

        losses = AttrDict()

        # Potentially already computed in the forward pass; only run the
        # matching here when its result is absent. (Presumably
        # compute_matching stores 'gt_matching_dists' on model_output --
        # TODO confirm.)
        if 'gt_matching_dists' not in model_output:
            self.compute_matching(inputs, model_output)

        losses.update(self.get_node_loss(inputs, model_output))

        losses.update(self.get_extra_losses(inputs, model_output))

        # Explaining loss
        losses.update(self.matcher.loss(inputs, model_output))

        # Entropy penalty, scaled by self._hp.entropy_weight.
        losses.entropy = PenaltyLoss(weight=self._hp.entropy_weight)(model_output.entropy)

        return losses