Esempio n. 1
0
    def margin(self, margin: float) -> None:
        """Validates and stores the margin.

        Args:
            margin: Strictly positive float.

        Raises:
            TypeError: If `margin` is not a float (raised as `e.TypeError`).
            ValueError: If `margin` is not greater than 0, including NaN
                (raised as `e.ValueError`).

        """

        if not isinstance(margin, float):
            raise e.TypeError("`margin` should be a float")

        # `not margin > 0` instead of `margin <= 0`: every comparison with NaN
        # is False, so the latter would silently accept `float("nan")`.
        if not margin > 0:
            raise e.ValueError("`margin` should be greater than 0")

        self._margin = margin
Esempio n. 2
0
    def batch_size(self, batch_size: int) -> None:
        """Validates and stores the batch size.

        Args:
            batch_size: Strictly positive integer (booleans are rejected).

        Raises:
            TypeError: If `batch_size` is not an integer or is a bool
                (raised as `e.TypeError`).
            ValueError: If `batch_size` is not greater than 0
                (raised as `e.ValueError`).

        """

        # `bool` subclasses `int`, so `isinstance(True, int)` is True; reject
        # it explicitly so `True`/`False` are not stored as batch sizes.
        if isinstance(batch_size, bool) or not isinstance(batch_size, int):
            raise e.TypeError("`batch_size` should be a integer")
        if batch_size <= 0:
            raise e.ValueError("`batch_size` should be greater than 0")

        self._batch_size = batch_size
Esempio n. 3
0
def test_value_error():
    """Checks that the project's `ValueError` can be raised and caught by its own type."""

    try:
        raise exception.ValueError("error")
    except exception.ValueError:
        # Reaching this branch is the success condition.
        return
Esempio n. 4
0
    def __init__(
        self,
        data: np.array,
        labels: np.array,
        n_pairs: Optional[int] = 2,
        batch_size: Optional[int] = 1,
        input_shape: Optional[Tuple[int, ...]] = None,
        normalize: Optional[Tuple[int, int]] = (0, 1),
        shuffle: Optional[bool] = True,
        seed: Optional[int] = 0,
    ):
        """Initialization method.

        Args:
            data: Array of samples.
            labels: Array of labels.
            n_pairs: Number of pairs.
            batch_size: Batch size.
            input_shape: Shape of the reshaped array.
            normalize: Normalization bounds.
            shuffle: Whether data should be shuffled or not.
            seed: Provides deterministic traits when using `random` module.

        Raises:
            ValueError: If `labels` is empty or all of its entries are equal
                (raised as `e.ValueError`).

        """

        logger.info("Overriding class: Dataset -> BalancedPairDataset.")

        super(BalancedPairDataset, self).__init__(batch_size, input_shape,
                                                  normalize, shuffle, seed)

        # Explicit check rather than `assert`: assertions are stripped under
        # `python -O`, which would silently disable this validation.
        # `np.asarray` makes the element-wise comparison also work when a
        # plain Python list is passed (`list == scalar` is just `False`).
        label_values = np.asarray(labels)
        try:
            all_equal = bool(np.all(label_values == label_values[0]))
        except IndexError:
            # Empty `labels`: there are no distinct values to pair.
            all_equal = True

        if all_equal:
            raise e.ValueError("`labels` should have distinct values")

        # Amount of pairs
        self.n_pairs = n_pairs

        data = self.preprocess(data)
        pairs = self.create_pairs(data, labels)

        self._build(pairs)

        logger.info("Class overrided.")
Esempio n. 5
0
    def loss_type(self, loss_type: str) -> None:
        """Stores `loss_type` after checking it is `hard` or `semi-hard`.

        Raises:
            ValueError: For any other value (raised as `e.ValueError`).

        """

        accepted = ("hard", "semi-hard")
        if loss_type in accepted:
            self._loss_type = loss_type
            return

        raise e.ValueError("`loss_type` should be `hard` or `semi-hard`")
Esempio n. 6
0
    def distance(self, distance: str) -> None:
        """Validates and stores the distance identifier.

        Args:
            distance: One of `L1`, `L2`, `squared-L2` or `angular`.

        Raises:
            ValueError: If `distance` is not one of the accepted identifiers
                (raised as `e.ValueError`).

        """

        if distance not in ["L1", "L2", "squared-L2", "angular"]:
            # The message lists every accepted value, including `squared-L2`.
            raise e.ValueError(
                "`distance` should be `L1`, `L2`, `squared-L2` or `angular`"
            )

        self._distance = distance
Esempio n. 7
0
    def distance(self, distance: str) -> None:
        """Stores `distance` after checking it is `concat` or `diff`.

        Raises:
            ValueError: For any other value (raised as `e.ValueError`).

        """

        valid_choices = ("concat", "diff")
        if distance in valid_choices:
            self._distance = distance
            return

        raise e.ValueError("`distance` should be `concat`, or `diff`")