예제 #1
0
    def test_iris_empty_test(self):
        """Requesting all 150 iris samples for training must leave the test split empty.

        Only ``train_size`` is passed, so every other option (including
        ``test_size`` and ``target_length``) comes from ``load_data``'s defaults.
        """
        dataset = load_data("iris", train_size=150)

        # getattr inside the loop reads each attribute right before its own
        # assertion, in the same order a flat list of asserts would.
        for attr, expected_len in (
            ("train_data", 150),
            ("train_target", 150),
            ("test_data", 0),
            ("test_target", 0),
        ):
            assert len(getattr(dataset, attr)) == expected_len

        # Even the empty test split must keep a second axis of width 16.
        # NOTE(review): 16 looks like load_data's default target width for
        # iris when target_length is omitted — confirm against load_data.
        for attr in ("train_target", "test_target"):
            assert getattr(dataset, attr).shape[1] == 16
예제 #2
0
    def test_circles_basic(self):
        """Check split sizes and per-sample widths for the ``circles`` dataset."""
        dataset = load_data(
            "circles",
            train_size=150,
            test_size=50,
            target_length=2,
        )

        splits = ("train_data", "train_target", "test_data", "test_target")

        # Each split must hold exactly the requested number of samples.
        for attr, expected_len in zip(splits, (150, 150, 50, 50)):
            assert len(getattr(dataset, attr)) == expected_len

        # Inputs are expected to have two entries; targets have target_length == 2.
        for attr in splits:
            assert len(getattr(dataset, attr)[0]) == 2
예제 #3
0
    def test_iris_basic(self):
        """Check split sizes and per-sample widths for a 120/30 iris split."""
        n_classes = 3
        dataset = load_data(
            "iris",
            train_size=120,
            test_size=30,
            target_length=n_classes,  # one target entry per class
        )

        expected_lengths = (
            ("train_data", 120),
            ("train_target", 120),
            ("test_data", 30),
            ("test_target", 30),
        )
        for attr, size in expected_lengths:
            assert len(getattr(dataset, attr)) == size

        # Iris inputs are expected to carry 4 features; each target has
        # exactly target_length (== n_classes) entries.
        expected_widths = (
            ("train_data", 4),
            ("train_target", n_classes),
            ("test_data", 4),
            ("test_target", n_classes),
        )
        for attr, width in expected_widths:
            assert len(getattr(dataset, attr)[0]) == width
예제 #4
0
    def test_mnist_basic(self):
        """Check split sizes and widths for MNIST restricted to classes (6, 7, 8)."""
        digits = (6, 7, 8)
        n_wires = 5
        dataset = load_data(
            "mnist",
            wires=n_wires,
            classes=digits,
            train_size=150,
            test_size=50,
            target_length=len(digits),
        )

        # Each split must hold exactly the requested number of samples.
        for attr, size in (
            ("train_data", 150),
            ("train_target", 150),
            ("test_data", 50),
            ("test_target", 50),
        ):
            assert len(getattr(dataset, attr)) == size

        # The input width is expected to equal the requested wire count, and
        # each target has one entry per selected class.
        for attr, width in (
            ("train_data", n_wires),
            ("train_target", len(digits)),
            ("test_data", n_wires),
            ("test_target", len(digits)),
        ):
            assert len(getattr(dataset, attr)[0]) == width
예제 #5
0
파일: main.py 프로젝트: cirKITers/masKIT
            exclude=("data", "target"),
        )
    seed = train_params.pop("seed", 1337)
    np.random.seed(seed)
    random.seed(seed)

    data_params = {
        "wires": train_params["wires"],
        "classes": [6, 9],
        "train_size": 120,
        "test_size": 100,
        "shuffle": True,
        "target_length": len(train_params.get("interpret", (0, ))),
    }
    try:
        data = load_data(train_params.pop("dataset"), **data_params)
    except ValueError:
        data = DataSet(None, None, None, None)
    testing = train_params.pop("testing", False)
    result = train(
        **train_params,
        data=data.train_data,
        target=data.train_target,
        validation_data=data.validation_data,
        validation_target=data.validation_target,
    )
    if testing:
        test(
            result["__circuit"],
            masked_circuit=result["__masked_circuit"],
            rotations=result["__rotations"],