def _transform_utility(self, val):
    # we minimize, but botorch assumes maximization
    # hence multiply by -1
    if isinstance(val, np.ndarray):
        return -1 * torch.logit(torch.from_numpy(val))
    if np.isscalar(val):
        return -1 * torch.logit(torch.tensor([val], dtype=torch.double))
    return -1 * torch.logit(val)
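
Since torch.logit is the inverse of torch.sigmoid, the corresponding back-transform (not shown in the snippet; the name below is illustrative) would undo both the sign flip and the logit:

def _untransform_utility(self, val):
    # Hypothetical inverse of _transform_utility: undo the -1 scaling,
    # then map from log-odds back to (0, 1) with sigmoid.
    return torch.sigmoid(-1 * val)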
    def test_forward_backward(self):
        target_bbox = torch.tensor(
            [
                [
                    [0, 0, 9, 9],
                    [10, 10, 490, 490],
                    [-1, -1, -1, -1],
                ],
                [
                    [32, 32, 88, 88],
                    [42, 32, 84, 96],
                    [-1, -1, -1, -1],
                ],
                [
                    [10, 20, 50, 60],
                    [10, 20, 500, 600],
                    [20, 20, 84, 84],
                ],
                [
                    [-1, -1, -1, -1],
                    [-1, -1, -1, -1],
                    [-1, -1, -1, -1],
                ],
            ]
        )

        target_cls = torch.tensor(
            [
                [0, 1, -1],
                [0, 0, -1],
                [0, 0, 1],
                [-1, -1, -1],
            ]
        ).unsqueeze_(-1)

        target_bbox.shape[0]
        num_classes = 2
        strides = (8, 16, 32, 64, 128)
        base_size = 512
        sizes: Tuple[Tuple[int, int], ...] = tuple((base_size // stride,) * 2 for stride in strides)  # type: ignore

        criterion = FCOSLoss(strides, num_classes, radius=1.5)
        pred_cls, pred_reg, pred_centerness = criterion.create_targets(target_bbox, target_cls, sizes)
        pred_cls = [torch.logit(x, 1e-4) for x in pred_cls]
        pred_centerness = [torch.logit(x.clamp_(min=0, max=1), 1e-4) for x in pred_centerness]
        pred_reg = [x.clamp_min(0) for x in pred_reg]

        output = FCOSDecoder.postprocess(pred_cls, pred_reg, pred_centerness, list(strides), from_logits=True)
        criterion(pred_cls, pred_reg, pred_centerness, target_bbox, target_cls)
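
The eps=1e-4 passed to torch.logit above matters because the probability-space targets presumably contain exact 0s and 1s, which plain logit would map to -inf/inf. A small standalone check of the clamping behaviour (illustrative only, not part of the test):

p = torch.tensor([0.0, 0.3, 1.0])
torch.logit(p)                            # tensor([-inf, -0.8473, inf])
torch.logit(p, eps=1e-4)                  # input clamped to [1e-4, 1 - 1e-4], all values finite
torch.sigmoid(torch.logit(p, eps=1e-4))   # ~tensor([1e-4, 0.3000, 0.9999]), close to p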
Example #3
 def update_adaptive_pseudo_augmentation_p(self):
     loss_sign_real = torch.logit(torch.sigmoid(self.pred_real)).sign().mean()
     self.adjust = torch.sign(loss_sign_real - self.opt.APA_target)
     lambda_adjust = self.adjust * (self.opt.batch_size * self.opt.APA_every) / (self.opt.APA_nimg * 1000)
     self.adaptive_pseudo_augmentation_p = (self.adaptive_pseudo_augmentation_p + lambda_adjust)
     if self.adaptive_pseudo_augmentation_p < 0:
         self.adaptive_pseudo_augmentation_p = self.adaptive_pseudo_augmentation_p * 0
         
     if self.adaptive_pseudo_augmentation_p > 1:
         self.adaptive_pseudo_augmentation_p = 1
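
The same update, written as a standalone sketch with the clamping made explicit (the function below is hypothetical; argument names mirror the attributes used in the snippet):

def updated_apa_p(p, pred_real, target, batch_size, apa_every, apa_nimg):
    # Recover a robust sign of the discriminator output on real images.
    loss_sign_real = torch.logit(torch.sigmoid(pred_real)).sign().mean()
    adjust = torch.sign(loss_sign_real - target)
    # Step size proportional to the number of images seen per update.
    lambda_adjust = adjust * (batch_size * apa_every) / (apa_nimg * 1000)
    # Keep the augmentation probability inside [0, 1].
    return min(max(float(p + lambda_adjust), 0.0), 1.0)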
Example #4
    def logit_mean_sigmoid(x: T, weights: T, bias: T, labels: T) -> T:
        """Computes a logistic regression on the input, averages over the class labels,
        and then maps the results back to log-odds space with the logit function. This
        function can also be used on batches of input, e.g. for posterior samples

        :param x: input tensor (n_cells x n_latent)
        :param weights: regression weights (n_latent,)
        :param bias: offsets for each condition (n_conditions,)
        :param labels: label for each cell (n_cells,)
        :return: average dose response in log-odds space (n_classes, n_conditions)
        """
        response = torch.sigmoid(
            torch.einsum("...jk,...k->...j", x, weights.squeeze())[..., None] +
            bias)

        labels = labels.view(-1, 1).expand(response.size())
        unique_labels, labels_count = labels.unique(dim=-2, return_counts=True)

        res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
            -2, labels, response)
        return torch.logit(res / labels_count.float().view(-1, 1), eps=1e-5)
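
A small usage sketch with made-up shapes (6 cells, 2 latent dimensions, 3 conditions, labels in {0, 1}), assuming the function is callable directly, e.g. as a staticmethod; the output has one row per unique label and one column per condition:

x = torch.randn(6, 2)                      # (n_cells, n_latent)
weights = torch.randn(2)                   # (n_latent,)
bias = torch.randn(3)                      # (n_conditions,)
labels = torch.tensor([0, 0, 1, 1, 1, 0])  # (n_cells,)
out = logit_mean_sigmoid(x, weights, bias, labels)
out.shape  # torch.Size([2, 3]): per-label mean response, mapped back to log-odds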
Example #5
 def pointwise_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
Example #6
 def backward(self, y):
     self._low = torch.min(y)
     self._high = torch.max(y)
     return torch.logit(y)
Example #7
    def predict(
        self,
        prediction_dataset,
        batch_size=1,
        num_workers=0,
        activation_layer=None,  # 'softmax', 'sigmoid', 'softmax_and_logit', None
        binary_preds=None,  # 'single_target', 'multi_target', None
        threshold=0.5,
        error_log=None,
    ):
        """Generate predictions on a dataset

        Choose to return any combination of scores, labels, and single-target or
        multi-target binary predictions. Also choose activation layer for scores
        (softmax, sigmoid, softmax then logit, or None).

        Note: the order of returned dataframes is (scores, preds, labels)

        Args:
            prediction_dataset:
                a Preprocessor or Dataset object that returns tensors,
                such as AudioToSpectrogramPreprocessor (no augmentation)
                or CnnPreprocessor (w/augmentation) from opensoundscape.datasets
            batch_size:
                Number of files to load simultaneously [default: 1]
            num_workers:
                parallelization (ie cpus or cores), use 0 for current process
                [default: 0]
            activation_layer:
                Optionally apply an activation layer such as sigmoid or
                softmax to the raw outputs of the model.
                options:
                - None: no activation, return raw scores (ie logit, [-inf:inf])
                - 'softmax': scores all classes sum to 1
                - 'sigmoid': all scores in [0,1] but don't sum to 1
                - 'softmax_and_logit': applies softmax first then logit
                [default: None]
            binary_preds:
                Optionally return binary (thresholded 0/1) predictions
                options:
                - 'single_target': max scoring class = 1, others = 0
                - 'multi_target': scores above threshold = 1, others = 0
                - None: do not create or return binary predictions
                [default: None]
            threshold:
                prediction threshold for sigmoid scores. Only relevant when
                binary_preds == 'multi_target'
            error_log:
                if not None, saves a list of files that raised errors to
                the specified file location [default: None]

        Returns: 3 DataFrames (or Nones), w/ index matching prediction_dataset.df
            scores: post-activation_layer scores
            predictions: 0/1 preds for each class
            labels: labels from dataset (if available)

        Note: if loading an audio file raises a PreprocessingError, the scores
            and predictions for that sample will be np.nan

        Note: if no return type selected for labels/scores/preds, returns None
        instead of a DataFrame in the returned tuple
        """
        err_msg = ("Prediction dataset must have same classes "
                   "and class order as model object, or no classes.")
        if len(prediction_dataset.df.columns) > 0:
            assert list(self.classes) == list(
                prediction_dataset.df.columns), err_msg

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.network.to(self.device)

        self.network.eval()

        # SafeDataset will not fail on bad files,
        # but will provide a different sample! Later we go back and replace scores
        # with np.nan for the bad samples (using safe_dataset._unsafe_indices)
        # this approach to error handling feels hacky
        safe_dataset = SafeDataset(prediction_dataset)

        dataloader = torch.utils.data.DataLoader(
            safe_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            # use pin_memory=True when loading files on CPU and training on GPU
            pin_memory=torch.cuda.is_available(),
        )

        ### Prediction ###
        total_scores = []
        total_preds = []
        total_tgts = []

        failed_files = []  # keep list of any samples that raise errors

        has_labels = False

        # disable gradient updates during inference
        with torch.set_grad_enabled(False):

            for batch in dataloader:
                # get batch of Tensors
                batch_tensors = batch["X"].to(self.device)
                batch_tensors.requires_grad = False
                # get batch's labels if available
                batch_targets = torch.Tensor([]).to(self.device)
                if "y" in batch.keys():
                    batch_targets = batch["y"].to(self.device)
                    batch_targets.requires_grad = False
                    has_labels = True

                # forward pass of network: feature extractor + classifier
                logits = self.network.forward(batch_tensors)

                ### Activation layer ###
                if activation_layer is None:  # scores [-inf,inf]
                    scores = logits
                elif activation_layer == "softmax":
                    # "softmax" activation: preds across all classes sum to 1
                    scores = softmax(logits, 1)
                elif activation_layer == "sigmoid":  # map [-inf,inf] to [0,1]
                    scores = torch.sigmoid(logits)
                elif activation_layer == "softmax_and_logit":  # scores [-inf,inf]
                    scores = torch.logit(softmax(logits, 1))
                else:
                    raise ValueError(
                        f"invalid option for activation_layer: {activation_layer}"
                    )

                ### Binary predictions ###
                # generate binary predictions
                if binary_preds == "single_target":
                    # predict highest scoring class only
                    batch_preds = F.one_hot(logits.argmax(1), len(logits[0]))
                elif binary_preds == "multi_target":
                    # predict 0 or 1 based on a fixed threshold
                    batch_preds = torch.sigmoid(logits) >= threshold
                elif binary_preds is None:
                    batch_preds = torch.Tensor([])
                else:
                    raise ValueError(
                        f"invalid option for binary_preds: {binary_preds}")

                # detach the returned values: currently tethered to gradients
                # and updates via optimizer/backprop. detach() returns
                # just numeric values.
                total_scores.append(scores.detach().cpu().numpy())
                total_preds.append(batch_preds.float().detach().cpu().numpy())
                total_tgts.append(batch_targets.int().detach().cpu().numpy())

        # aggregate across all batches
        total_tgts = np.concatenate(total_tgts, axis=0)
        total_scores = np.concatenate(total_scores, axis=0)
        total_preds = np.concatenate(total_preds, axis=0)

        print(np.shape(total_scores))

        # replace scores/preds with nan for samples that failed in preprocessing
        # this feels hacky (we predicted on substitute-samples rather than
        # skipping the samples that failed preprocessing)
        total_scores[safe_dataset._unsafe_indices, :] = np.nan
        if binary_preds is not None:
            total_preds[safe_dataset._unsafe_indices, :] = np.nan

        # return 3 DataFrames with same index/columns as prediction_dataset's df
        # use None for placeholder if no preds / labels
        samples = prediction_dataset.df.index.values
        score_df = pd.DataFrame(index=samples,
                                data=total_scores,
                                columns=self.classes)
        pred_df = (None if binary_preds is None else pd.DataFrame(
            index=samples, data=total_preds, columns=self.classes))
        label_df = (None if not has_labels else pd.DataFrame(
            index=samples, data=total_tgts, columns=self.classes))

        return score_df, pred_df, label_df
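
A hedged usage sketch (the model and dataset objects below are placeholders, not defined in the snippet); with activation_layer="softmax_and_logit", the scores DataFrame holds softmax probabilities mapped back to (-inf, inf) via torch.logit:

score_df, pred_df, label_df = model.predict(
    prediction_dataset,
    batch_size=16,
    num_workers=0,
    activation_layer="softmax_and_logit",
    binary_preds="multi_target",
    threshold=0.5,
)
# pred_df contains 0/1 per class; rows for files that failed preprocessing are NaN.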
Example #8
 print(x.abs())
 print(x.neg())
 print(x.acos())
 print(torch.ceil(x))
 print(torch.floor(x))
 print(torch.round(x))
 # print(torch.diff(x))
 print(torch.clamp(x, min=-0.5, max=0.5))
 print(torch.clamp(x, min=0.5))
 print(torch.clamp(x, max=0.5))
 print(torch.trunc(x))  # truncated integer values
 print(torch.frac(x))  # fractional portion of each element
 print(x.add(1))
 print(torch.exp(x))
 print(torch.expm1(x))
 print(torch.logit(x))
 print(torch.mul(x, 100))
 print(torch.addcdiv(t, t1, t2, value=0.1))  # t + value * t1 / t2
 print(torch.addcmul(t, t1, t2, value=0.1))  # t + value * t1 * t2
 print(torch.addmm(M, mat1, mat2))  # beta * M + alpha * mat1 * mat2
 print(torch.matmul(mat1, mat2))  # mat1 * mat2
 print(torch.mm(mat1, mat2))  # mat1 * mat2
 print(torch.matrix_power(mat1, 2))  # mat1 * mat1
 print(torch.addmv(x, mat1, x))  # β x+α (mat * x)
 print(torch.mv(mat1, x))  # mat * vec
 print(torch.outer(x, x))  # vec1⊗vec2
 print(torch.renorm(mat1, 1, 0, 5))
 input_ = torch.tensor([10000., 1e-07])
 other_ = torch.tensor([10000.1, 1e-08])
 print(torch.floor_divide(input_, other_))  # trunc(input_ / other_)
 print(torch.allclose(input_, other_))  # ∣input−other∣≤atol+rtol×∣other∣
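
The block above assumes x, t, t1, t2, M, mat1 and mat2 already exist; a minimal setup that would make it runnable (shapes chosen only for illustration, with x kept in [0, 1] so acos and logit stay in range):

x = torch.rand(3)
t, t1, t2 = torch.randn(3), torch.randn(3), torch.randn(3)
M = torch.randn(3, 3)
mat1, mat2 = torch.randn(3, 3), torch.randn(3, 3)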
Example #9
    def forward(
        self,
        inputs: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray],
                      Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]],
        mask=None,
    ) -> Dict[str, torch.Tensor]:
        if self.compiled_model is None:
            raise ValueError("Model has not been trained yet.")

        if isinstance(inputs, tuple):
            inputs, targets = inputs
            # Convert targets to tensors.
            for target_feature_name, target_value in targets.items():
                if not isinstance(target_value, torch.Tensor):
                    targets[target_feature_name] = torch.from_numpy(
                        target_value)
                else:
                    targets[target_feature_name] = target_value
        else:
            targets = None

        assert list(inputs.keys()) == self.input_features.keys()

        # Convert inputs to tensors.
        for input_feature_name, input_values in inputs.items():
            if not isinstance(input_values, torch.Tensor):
                inputs[input_feature_name] = torch.from_numpy(input_values)
            else:
                inputs[input_feature_name] = input_values.view(-1, 1)

        # TODO(travis): include encoder and decoder steps during inference
        # encoder_outputs = {}
        # for input_feature_name, input_values in inputs.items():
        #     encoder = self.input_features[input_feature_name]
        #     encoder_output = encoder(input_values)
        #     encoder_outputs[input_feature_name] = encoder_output

        # concatenate inputs
        inputs = torch.cat(list(inputs.values()), dim=1)

        # Invoke output features.
        output_logits = {}
        output_feature_name = list(self.output_features.keys())[0]
        output_feature = self.output_features[output_feature_name]

        preds = self.compiled_model(inputs)

        if output_feature.type() == NUMBER:
            # regression
            if len(preds.shape) == 2:
                preds = preds.squeeze(1)
            logits = preds
        else:
            # classification
            _, probs = preds
            # keep positive class only for binary feature
            probs = probs[:, 1] if output_feature.type() == BINARY else probs
            logits = torch.logit(probs)
        output_feature_utils.set_output_feature_tensor(output_logits,
                                                       output_feature_name,
                                                       LOGITS, logits)

        return output_logits
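
For a BINARY output feature, the probability tensor above typically has two columns, and only the positive-class column is converted back to a logit. An illustrative check (values made up; probabilities of exactly 0 or 1 would map to -inf/inf, which torch.logit's eps argument can guard against):

probs = torch.tensor([[0.9, 0.1],
                      [0.2, 0.8]])
torch.logit(probs[:, 1])  # tensor([-2.1972,  1.3863])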
Example #10
# logical_or
torch.logical_or(torch.tensor([True, False, True]),
                 torch.tensor([True, False, False]))
torch.logical_or(r, s)
torch.logical_or(r.double(), s.double())
torch.logical_or(r.double(), s)
torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool))

# logical_xor
torch.logical_xor(torch.tensor([True, False, True]),
                  torch.tensor([True, False, False]))
torch.logical_xor(r, s)
torch.logical_xor(r.double(), s.double())
torch.logical_xor(r.double(), s)
torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool))

# logit
torch.logit(torch.rand(5), eps=1e-6)

# hypot
torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0]))

# i0
torch.i0(torch.arange(5, dtype=torch.float32))

# igamma/igammac
a1 = torch.tensor([4.0])
a2 = torch.tensor([3.0, 4.0, 5.0])
torch.igamma(a1, a2)
torch.igammac(a1, a2)

# mul/multiply
torch.mul(torch.randn(3), 100)
Example #11
 def forward(self, a, b):
     return torch.logit(0.5 - 0.5 * dot_product(a, b, normalize=True))
def rescale_logit(img, lambd=1e-6):
    ## logit space
    return torch.logit(lambd + (1 - 2 * lambd) * img)
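
Since rescale_logit squeezes img from [0, 1] into [lambd, 1 - lambd] before applying the logit, its output is always finite. The corresponding inverse (a hypothetical helper, not part of the snippet) undoes the logit and then the affine rescale:

def inverse_rescale_logit(z, lambd=1e-6):
    # Inverse of rescale_logit above: sigmoid back to [lambd, 1 - lambd],
    # then undo the affine squeeze to recover values in [0, 1].
    return (torch.sigmoid(z) - lambd) / (1 - 2 * lambd)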
Example #13
def haarpsi(
        x: Tensor,
        y: Tensor,
        n_kernels: int = 3,
        value_range: float = 1.,
        c: float = 30. / (255.**2),
        alpha: float = 4.2,
) -> Tensor:
    r"""Returns the HaarPSI between :math:`x` and :math:`y`,
    without color space conversion.

    Args:
        x: An input tensor, :math:`(N, 3 \text{ or } 1, H, W)`.
        y: A target tensor, :math:`(N, 3 \text{ or } 1, H, W)`.
        n_kernels: The number of Haar wavelet kernels to use.
        value_range: The value range :math:`L` of the inputs (usually `1.` or `255`).

    Note:
        For the remaining arguments, refer to [Reisenhofer2018]_.

    Returns:
        The HaarPSI vector, :math:`(N,)`.

    Example:
        >>> x = torch.rand(5, 3, 256, 256)
        >>> y = torch.rand(5, 3, 256, 256)
        >>> l = haarpsi(x, y)
        >>> l.size()
        torch.Size([5])
    """

    c *= value_range**2

    # Y
    y_x, y_y = x[:, :1], y[:, :1]

    ## Gradient(s)
    g_xy: List[Tuple[Tensor, Tensor]] = []

    for j in range(1, n_kernels + 1):
        kernel_size = int(2**j)

        ### Haar wavelet kernel
        kernel = gradient_kernel(haar_kernel(kernel_size)).to(x.device)

        ### Haar filter (gradient)
        pad = kernel_size // 2

        g_x = channel_conv(y_x, kernel, padding=pad)[..., 1:, 1:].abs()
        g_y = channel_conv(y_y, kernel, padding=pad)[..., 1:, 1:].abs()

        g_xy.append((g_x, g_y))

    ## Gradient similarity(ies)
    gs = []
    for g_x, g_y in g_xy[:-1]:
        gs.append((2. * g_x * g_y + c) / (g_x**2 + g_y**2 + c))

    ## Local similarity(ies)
    ls = torch.stack(gs, dim=-1).sum(dim=-1) / 2.  # (N, 2, H, W)

    ## Weight(s)
    w = torch.stack(g_xy[-1], dim=-1).max(dim=-1)[0]  # (N, 2, H, W)

    # IQ
    if x.size(1) == 3:
        iq_x, iq_y = x[:, 1:], y[:, 1:]

        ## Mean filter
        m_x = F.avg_pool2d(iq_x, 2, stride=1, padding=1)[..., 1:, 1:].abs()
        m_y = F.avg_pool2d(iq_y, 2, stride=1, padding=1)[..., 1:, 1:].abs()

        ## Chromatic similarity(ies)
        cs = (2. * m_x * m_y + c) / (m_x**2 + m_y**2 + c)

        ## Local similarity(ies)
        ls = torch.cat([ls, cs.mean(1, True)], dim=1)  # (N, 3, H, W)

        ## Weight(s)
        w = torch.cat([w, w.mean(1, True)], dim=1)  # (N, 3, H, W)

    # HaarPSI
    hs = torch.sigmoid(ls * alpha)
    hpsi = (hs * w).sum(dim=(-1, -2, -3)) / w.sum(dim=(-1, -2, -3))
    hpsi = (torch.logit(hpsi) / alpha)**2

    return hpsi
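
The final line inverts the earlier sigmoid: for a single similarity value, logit(sigmoid(alpha * ls)) / alpha recovers ls exactly, so the weighted pooling effectively happens in the squashed (0, 1) domain before being mapped back. A quick check with illustrative values:

alpha, ls = 4.2, torch.tensor(0.7)
torch.isclose(torch.logit(torch.sigmoid(alpha * ls)) / alpha, ls)  # tensor(True)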
Example #14
def log_deriv_logit(x, eps=1.0e-8):
    """ logarithm of logit derivative """
    y = torch.logit(torch.clamp(x, eps, 1.0 - eps))
    return -log_deriv_sigmoid(y)
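
Since d/dx logit(x) = 1 / (x (1 - x)), the value returned above should equal -log(x) - log(1 - x); the function just routes this through log_deriv_sigmoid evaluated at y = logit(x), where sigmoid(y) = x. A quick numerical check, assuming log_deriv_sigmoid(y) returns log(sigmoid(y)) + log(1 - sigmoid(y)) as its name suggests:

x = torch.tensor([0.2, 0.5, 0.9])
analytic = -torch.log(x) - torch.log(1.0 - x)   # log(1 / (x * (1 - x)))
torch.allclose(log_deriv_logit(x), analytic)    # expected: True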
Example #15
 def forward(self, x, log_df_dz):
     x = torch.clamp(x, self.eps, 1.0 - self.eps)
     log_det = log_deriv_logit(x)
     log_det = torch.sum(log_det.view(x.size(0), -1), dim=1)
     return torch.logit(x), log_df_dz + log_det
Example #16
 def backward(self, x, log_df_dz):
     x = torch.clamp(x, 1.0e-8, 1.0 - 1.0e-8)
     log_det = log_deriv_logit(x)
     log_det = torch.sum(log_det.view(x.size(0), -1), dim=1)
     return torch.logit(x), log_df_dz + log_det