def test_get_set_params_invariance(teardown):
    """Feeding a model's own params back through set_params changes nothing.

    Pipeline under test: Input -> PCA("pca") -> RandomForestClassifier("classifier").

    NOTE(review): another function named ``test_get_set_params_invariance``
    (the StandardScaler/LinearRegression variant) is defined after this one.
    If both live in the same module, the later definition rebinds the name and
    pytest only collects that one -- consider giving this test a distinct name.
    """
    reducer = PCA(name="pca")
    forest = RandomForestClassifier(name="classifier")

    inputs = Input()
    outputs = forest(reducer(inputs))
    model = Model(inputs, outputs)

    snapshot = model.get_params()
    model.set_params(**snapshot)
    assert model.get_params() == snapshot
def test_get_set_params_invariance(teardown):
    """get_params -> set_params -> get_params must return identical params.

    The model scales the target on the way in and inverse-scales the
    prediction on the way out with the *same* ``scaler`` step (its second
    call uses ``compute_func="inverse_transform"`` and ``trainable=False``),
    so one step is shared between two nodes of the graph.
    """
    scaler = StandardScaler(name="scaler")
    regressor = LinearRegression(name="regressor")

    features = Input()
    target = Input()
    target_scaled = scaler(target)
    pred_scaled = regressor(features, target_scaled)
    pred = scaler(pred_scaled, compute_func="inverse_transform", trainable=False)
    model = Model(features, pred, target)

    before = model.get_params()
    model.set_params(**before)
    after = model.get_params()
    assert after == before
def test_set_params(teardown):
    """Model.set_params rejects bad keys, updates step params and swaps steps.

    ``dummy2`` is deliberately shared between two nodes (one trainable, one
    with a custom ``compute_func`` and ``trainable=False``) so that replacing
    it via ``set_params`` exercises the transfer of per-port connectivity to
    the new step.

    NOTE(review): another function named ``test_set_params`` (the
    PCA/LogisticRegression variant) is defined after this one. If both live in
    the same module, the later definition rebinds the name and pytest only
    collects that one -- consider giving this test a distinct name.
    """
    dummy1 = DummyEstimator(name="dummy1")
    dummy2 = DummyEstimator(x=456, y="def", name="dummy2")
    concat = Concatenate(name="concat")  # a step without get_params/set_params

    # A meaningless pipeline that contains shared steps.
    x1 = Input()
    x2 = Input()
    h = dummy1(x1)
    c = concat([x1, h])
    y1 = dummy2(c)
    y2 = dummy2(x2, compute_func=lambda X: X * 2, trainable=False)
    model = Model([x1, x2], [y1, y2])

    invalid_cases = [
        # Setting params on a step that does not implement set_params.
        ({"concat__axis": 2}, AttributeError),
        # Setting params on a step that does not exist.
        ({"non_existent_step__param": 42}, ValueError),
        # Setting a non-existent param on an existing step.
        ({"dummy1__non_existent_param": 42}, ValueError),
    ]
    for wrong_params, expected_error in invalid_cases:
        with pytest.raises(expected_error):
            model.set_params(**wrong_params)

    new_dummy = DummyEstimator()
    new_params = {
        "dummy2": new_dummy,
        "dummy1__x": 100,
        "dummy1__y": "pqr",
        "dummy2__x": 789,
        "dummy2__y": "ijk",
    }
    model.set_params(**new_params)
    params = model.get_params()

    expected = {
        "dummy1": dummy1,
        "dummy2": new_dummy,
        "concat": concat,
        "dummy1__x": 100,
        "dummy1__y": "pqr",
        "dummy2__x": 789,
        "dummy2__y": "ijk",
    }
    assert params == expected

    # The replacement step must take over the old step's name and wiring.
    # Names are compared by value; string identity is an implementation detail.
    assert new_dummy.name == dummy2.name
    # dummy2 was called twice above, so it has two ports. Identity (``is``) is
    # intended here: the very same connectivity objects should now be reachable
    # through the new step.
    for port in range(2):
        assert new_dummy.get_inputs_at(port) is dummy2.get_inputs_at(port)
        assert new_dummy.get_outputs_at(port) is dummy2.get_outputs_at(port)
        assert new_dummy.get_targets_at(port) is dummy2.get_targets_at(port)
        assert new_dummy.get_trainable_at(port) is dummy2.get_trainable_at(port)
        assert new_dummy.get_compute_func_at(port) is dummy2.get_compute_func_at(port)
def test_set_params(teardown):
    """Model.set_params rejects bad keys, updates step params and swaps steps.

    Pipeline under test: Input -> PCA("pca") -> Concatenate([input, pca])
    -> RandomForestClassifier("classifier"), where the classifier is later
    replaced by a LogisticRegression through ``set_params``.
    """
    pca = PCA(name="pca")
    classifier = RandomForestClassifier(name="classifier")
    concat = Concatenate(name="concat")  # a step without get_params/set_params

    x = Input()
    h = pca(x)
    c = concat([x, h])
    y = classifier(c)
    model = Model(x, y)

    invalid_cases = [
        # Setting params on a step that does not implement set_params.
        ({"concat__axis": 2}, AttributeError),
        # Setting params on a step that does not exist.
        ({"non_existent_step__param": 42}, ValueError),
        # Setting a non-existent param on an existing step.
        ({"pca__non_existent_param": 42}, ValueError),
    ]
    for wrong_params, expected_error in invalid_cases:
        with pytest.raises(expected_error):
            model.set_params(**wrong_params)

    new_classifier = LogisticRegression()
    new_params = {
        "classifier": new_classifier,
        "pca__n_components": 4,
        "pca__whiten": True,
        "classifier__C": 100.0,
        "classifier__fit_intercept": False,
        "classifier__penalty": "l1",
    }
    model.set_params(**new_params)
    params = model.get_params()

    # Every value explicitly passed to set_params must be reported back.
    assert params["classifier"] is new_classifier
    assert params["pca__n_components"] == 4
    assert params["pca__whiten"] is True
    assert params["classifier__C"] == 100.0
    assert params["classifier__fit_intercept"] is False
    assert params["classifier__penalty"] == "l1"

    # The remaining entries are scikit-learn defaults, which differ between
    # releases (e.g. the "warn" placeholders for LogisticRegression's
    # solver/multi_class existed only in some versions). Derive them from the
    # steps themselves instead of hard-coding one release's defaults.
    expected = {"pca": pca, "classifier": new_classifier, "concat": concat}
    for step_name, step in (("pca", pca), ("classifier", new_classifier)):
        for key, value in step.get_params().items():
            expected["{}__{}".format(step_name, key)] = value
    assert params == expected