예제 #1
0
    def __init__(self, feature_map, training_dataset, test_dataset=None, datapoints=None,
                 multiclass_extension=None):
        """Constructor.

        Validates the training dataset against the supplied multiclass extension,
        splits the training (and optional test) dataset into data and labels,
        stores the prediction datapoints, and creates the underlying binary or
        multiclass QSVM implementation.

        Args:
            feature_map (FeatureMap): feature map module, used to transform data
            training_dataset (dict): training dataset.
            test_dataset (Optional[dict]): testing dataset.
            datapoints (Optional[numpy.ndarray]): prediction dataset; input that is
                not already a ``numpy.ndarray`` is converted with ``numpy.asarray``.
            multiclass_extension (Optional[MultiExtension]): if number of classes > 2 then
                a multiclass scheme is needed. When the training dataset has only two
                classes a supplied extension is ignored (a warning is logged) and the
                binary implementation is used.

        Raises:
            ValueError: if training_dataset is None
            AquaError: if the training dataset has more than two classes and no
                multiclass extension is provided
        """
        super().__init__()
        if training_dataset is None:
            raise ValueError('Training dataset must be provided')

        is_multiclass = get_num_classes(training_dataset) > 2
        if is_multiclass:
            if multiclass_extension is None:
                raise AquaError('Dataset has more than two classes. '
                                'A multiclass extension must be provided.')
        elif multiclass_extension is not None:
            logger.warning("Dataset has just two classes. "
                           "Supplied multiclass extension will be ignored")
            # Actually honour the warning: previously the extension was still
            # used below, building a multiclass instance for a binary problem.
            multiclass_extension = None

        self.training_dataset, self.class_to_label = split_dataset_to_data_and_labels(
            training_dataset)
        if test_dataset is not None:
            # Reuse the training label mapping so test labels are encoded consistently.
            self.test_dataset = split_dataset_to_data_and_labels(test_dataset,
                                                                 self.class_to_label)
        else:
            self.test_dataset = None

        # Inverse mapping: numeric label -> class name.
        self.label_to_class = {label: class_name for class_name, label
                               in self.class_to_label.items()}
        self.num_classes = len(self.class_to_label)

        if datapoints is not None and not isinstance(datapoints, np.ndarray):
            datapoints = np.asarray(datapoints)
        self.datapoints = datapoints

        self.feature_map = feature_map
        self.num_qubits = self.feature_map.num_qubits

        if multiclass_extension is None:
            qsvm_instance = _QSVM_Binary(self)
        else:
            qsvm_instance = _QSVM_Multiclass(self, multiclass_extension)

        self.instance = qsvm_instance
예제 #2
0
    def __init__(
            self,
            feature_map: FeatureMap,
            training_dataset: Optional[Dict[str, np.ndarray]] = None,
            test_dataset: Optional[Dict[str, np.ndarray]] = None,
            datapoints: Optional[np.ndarray] = None,
            multiclass_extension: Optional[MulticlassExtension] = None
    ) -> None:
        """Create the QSVM algorithm and its binary or multiclass implementation.

        When a training dataset is given, its class count is checked against the
        supplied multiclass extension. The datasets are then handed to
        ``setup_training_data``, ``setup_test_data`` and ``setup_datapoint``.

        Args:
            feature_map: feature map module, used to transform data
            training_dataset: training dataset. If None, the class-count
                validation below is skipped.
            test_dataset: testing dataset.
            datapoints: prediction dataset.
            multiclass_extension: if number of classes > 2 then
                a multiclass scheme is needed. When the training dataset has only
                two classes a supplied extension is ignored (a warning is logged)
                and the binary implementation is used.

        Raises:
            AquaError: if the training dataset has more than two classes and no
                multiclass extension is provided
        """
        super().__init__()
        # check the validity of provided arguments if possible
        if training_dataset is not None:
            is_multiclass = get_num_classes(training_dataset) > 2
            if is_multiclass:
                if multiclass_extension is None:
                    raise AquaError('Dataset has more than two classes. '
                                    'A multiclass extension must be provided.')
            elif multiclass_extension is not None:
                logger.warning(
                    "Dataset has just two classes. "
                    "Supplied multiclass extension will be ignored")
                # Actually honour the warning: previously the extension was still
                # used below, building a multiclass instance for a binary problem.
                multiclass_extension = None

        # Declare all dataset-related attributes up front; the setup_* calls
        # below are expected to populate them.
        self.training_dataset = None
        self.test_dataset = None
        self.datapoints = None
        self.class_to_label = None
        self.label_to_class = None
        self.num_classes = None

        self.setup_training_data(training_dataset)
        self.setup_test_data(test_dataset)
        self.setup_datapoint(datapoints)

        self.feature_map = feature_map
        self.num_qubits = self.feature_map.num_qubits

        if multiclass_extension is None:
            qsvm_instance = _QSVM_Binary(self)
        else:
            qsvm_instance = _QSVM_Multiclass(self, multiclass_extension)

        self.instance = qsvm_instance