def forward(self, input, target, mask):
    """Compute the masked, optionally time-weighted cross-entropy loss.

    ``target`` is discretised into classes using the threshold list
    ``[0.0] + rainfall_to_pixel(self._thresholds)``: each pixel receives the
    index of the last threshold (in list order) that it reaches. The
    per-pixel cross-entropy (class weights ``self._weight``) is, when
    ``self._lambda`` is set, scaled per time step by ``1 + s * self._lambda``
    for ``s = 0 .. S-1``, then multiplied by ``mask`` and averaged.

    Args:
        input: class scores laid out S*B*C*H*W; ``input.size(0)`` must equal
            ``cfg.HKO.BENCHMARK.OUT_LEN``.
        target: values laid out S*B*1*H*W.
        mask: multiplied element-wise with the S*B*1*H*W error; presumably
            the same layout -- TODO confirm.

    Returns:
        Scalar tensor: ``mean(error * mask)``.
    """
    assert input.size(0) == cfg.HKO.BENCHMARK.OUT_LEN
    # F.cross_entropy wants the class dimension second: S*B*C*H*W -> B*C*S*H*W
    input = input.permute((1, 2, 0, 3, 4))
    # S*B*1*H*W -> B*S*H*W
    target = target.permute((1, 2, 0, 3, 4)).squeeze(1)
    class_index = torch.zeros_like(target).long()
    thresholds = [0.0] + rainfall_to_pixel(self._thresholds).tolist()
    # Later thresholds overwrite earlier ones, so a pixel keeps the index of
    # the last threshold it reaches.
    for i, threshold in enumerate(thresholds):
        class_index[target >= threshold] = i
    # Per-pixel loss, B*S*H*W
    error = F.cross_entropy(input,
                            class_index,
                            self._weight,
                            reduction='none')
    if self._lambda is not None:
        S = error.size(1)
        # Time-step weights 1, 1 + lambda, 1 + 2*lambda, ...
        # Built from an integer arange so there are exactly S of them (a
        # float-step arange may produce S + 1 through rounding, and rejects a
        # zero step), and created directly on error's device: the former
        # `.to(error.get_device())` passed -1 for CPU tensors whenever CUDA
        # was available.
        w = 1.0 + self._lambda * torch.arange(S, dtype=error.dtype, device=error.device)
        error = error * w.view(1, S, 1, 1)
    # B*S*H*W -> S*B*1*H*W to line up with mask
    error = error.permute(1, 0, 2, 3).unsqueeze(2)
    return torch.mean(error * mask.float())
Beispiel #2
0
 def forward(self, input, target, mask):
     """Compute the balanced, masked, optionally time-weighted MSE + MAE loss.

     Every pixel starts with weight ``balancing_weights[0]`` and gains
     ``balancing_weights[i + 1] - balancing_weights[i]`` for each threshold
     ``rainfall_to_pixel(cfg.HKO.EVALUATION.THRESHOLDS[i])`` that ``target``
     reaches; the weights are then multiplied by ``mask``. Weighted squared
     and absolute errors are summed over dims (2, 3, 4), giving S*B values;
     when ``self._lambda`` is set, row ``s`` is scaled by
     ``1 + s * self._lambda``.

     Args:
         input: predictions laid out S*B*1*H*W.
         target: ground truth broadcastable with ``input``.
         mask: multiplied into the per-pixel weights.

     Returns:
         ``NORMAL_LOSS_GLOBAL_SCALE * (mse_weight * mean(mse)
         + mae_weight * mean(mae))``.
     """
     balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
     weights = torch.ones_like(input) * balancing_weights[0]
     thresholds = [rainfall_to_pixel(ele) for ele in cfg.HKO.EVALUATION.THRESHOLDS]
     for i, threshold in enumerate(thresholds):
         weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (target >= threshold).float()
     weights = weights * mask.float()
     # input: S*B*1*H*W
     # error: S*B
     diff = input - target
     mse = torch.sum(weights * (diff ** 2), (2, 3, 4))
     mae = torch.sum(weights * torch.abs(diff), (2, 3, 4))
     if self._lambda is not None:
         S = mse.size(0)
         # Exactly S weights 1, 1 + lambda, ... built on mse's device; a
         # float-step arange can return S + 1 entries, and
         # `.to(mse.get_device())` received -1 for CPU tensors on CUDA hosts.
         w = 1.0 + self._lambda * torch.arange(S, dtype=mse.dtype, device=mse.device)
         mse = mse * w.view(S, 1)
         mae = mae * w.view(S, 1)
     return self.NORMAL_LOSS_GLOBAL_SCALE * (self.mse_weight * torch.mean(mse) + self.mae_weight * torch.mean(mae))
    def __call__(self, prediction, ground_truth, mask, lr):
        '''
        Map per-pixel class predictions to representative values and, when
        ``self.requires_grad`` is set, take one gradient step on those values.

        prediction: class scores, S*B*C*H*W; the class of a pixel is the
            argmax over axis 2.
        ground_truth: observed values, pixel / 255.0, in [0, 1]; expected to be
            a torch tensor laid out S*B*1*H*W like the returned array -- TODO confirm.
        mask: tensor multiplied into the per-pixel balancing weights.
        lr: learning rate for the update of ``self._middle_value``.
        :return: np.float32 array of shape S*B*1*H*W holding
            ``self._middle_value[class]`` for each pixel (values taken before
            any update); pixels whose class has no entry stay 0.
        '''
        # Class index per pixel, 0 .. classes - 1, kept as S*B*1*H*W
        result = np.argmax(prediction, axis=2)[:, :, np.newaxis, ...]
        prediction_result = np.zeros(result.shape, dtype=np.float32)
        if not self.requires_grad:
            for i in range(len(self._middle_value)):
                prediction_result[result == i] = self._middle_value[i]
            return prediction_result

        # NOTE(review): assumes self._middle_value is a leaf torch tensor with
        # requires_grad=True, since the update below reads its .grad -- confirm.
        middle_value = self._middle_value
        for i in range(len(middle_value)):
            prediction_result[result == i] = middle_value[i].item()

        # Balancing weights: start at balancing_weights[0] and add the
        # increment for every threshold the ground truth reaches.
        balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
        weights = torch.ones_like(ground_truth) * balancing_weights[0]
        thresholds = [rainfall_to_pixel(ele) for ele in cfg.HKO.EVALUATION.THRESHOLDS]
        for i, threshold in enumerate(thresholds):
            weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (ground_truth >= threshold).float()
        weights = weights * mask.float()

        # Differentiable per-pixel prediction: middle_value[i] where the class
        # is i, 0 where the class has no representative value.
        class_index = torch.as_tensor(np.ascontiguousarray(result), device=ground_truth.device)
        pred = torch.zeros_like(ground_truth)
        for i in range(len(middle_value)):
            pred = pred + (class_index == i).to(pred.dtype) * middle_value[i]

        # Weighted squared/absolute error summed per S*B slice, then averaged
        # (the old per-class loop overwrote the loss instead of accumulating
        # it and summed 1-D tensors over dims (2, 3, 4)).
        diff = ground_truth - pred
        mse = torch.sum(weights * (diff ** 2), (2, 3, 4))
        mae = torch.sum(weights * torch.abs(diff), (2, 3, 4))
        loss = self.NORMAL_LOSS_GLOBAL_SCALE * (torch.mean(mse) + torch.mean(mae))

        # Drop stale gradients so the step uses only this call's loss, then
        # update outside autograd: an in-place update of a leaf tensor that
        # requires grad raises otherwise.
        middle_value.grad = None
        loss.backward()
        with torch.no_grad():
            middle_value -= lr * middle_value.grad

        return prediction_result