def test_get_error_unusual_targets_shape():
    """get_error should work with a MultiOutputs model whose sub-models
    produce outputs of different lengths."""
    # NOTE: the former function-local ``from learning import error`` was never
    # used; MeanSquaredError is resolved at module level (as in test_get_error).
    model = multioutputs.MultiOutputs([
        helpers.SetOutputModel([1.0]),
        helpers.SetOutputModel([1.0, 1.0, 1.0]),
    ])

    # Targets equal the set outputs exactly.
    assert validation.get_error(
        model, [[]], [[[1.0], [1.0, 1.0, 1.0]]],
        error_func=MeanSquaredError()) == 0.0
    # Every component of the second output's target differs from its set output.
    assert validation.get_error(
        model, [[]], [[[1.0], [0.0, 0.0, 0.0]]],
        error_func=MeanSquaredError()) == 0.5
def test_benchmark(monkeypatch):
    """validation.benchmark should report the expected stats and train the
    model on the expected patterns, fold by fold, for every run."""
    # Freeze time.clock so the time stat is deterministic.
    # NOTE(review): time.clock was removed in Python 3.8, where this setattr
    # raises AttributeError (monkeypatch's raising=True) — confirm target runtime.
    monkeypatch.setattr(time, 'clock', lambda: 0.0)

    # All targets are [1], matching the model's set output.
    patterns = [([0], [1]), ([1], [1]), ([2], [1])]
    model = helpers.SetOutputModel([1])

    # Record every (input, target) pair the model is trained on.
    trained_on = []

    def post_pattern_callback(network_, input_vec, target_vec):
        trained_on.append((list(input_vec), list(target_vec)))

    # Cross validate the deterministic model.
    stats = validation.benchmark(model, zip(*patterns),
                                 num_folds=3, num_runs=2, iterations=1,
                                 post_pattern_callback=post_pattern_callback)

    assert (helpers.fix_numpy_array_equality(stats)
            == helpers.fix_numpy_array_equality(_BENCHMARK_STATS))

    # Fold k trains on every pattern except pattern k; the whole
    # three-fold sequence is expected once per run (num_runs=2).
    first_fold = [([1], [1]), ([2], [1])]
    second_fold = [([0], [1]), ([2], [1])]
    third_fold = [([0], [1]), ([1], [1])]
    assert trained_on == (first_fold + second_fold + third_fold) * 2
def test_get_error():
    """get_error with MeanSquaredError against a model whose output is fixed at [1]."""
    model = helpers.SetOutputModel([1])

    # (inputs, targets, expected error)
    cases = [
        ([[1]], [[0]], 1.0),
        ([[1]], [[1]], 0.0),
        ([[1]], [[0.5]], 0.25),
        ([[1], [1]], [[1], [0]], 0.5),
        ([[1], [1]], [[0.5], [0.5]], 0.25),
    ]
    for inputs, targets, expected in cases:
        actual = validation.get_error(
            model, numpy.array(inputs), numpy.array(targets),
            error_func=MeanSquaredError())
        assert actual == expected
def test_get_accuracy():
    """get_accuracy against a model whose output is fixed at [1]."""
    model = helpers.SetOutputModel([1])
    inputs = [[1], [1]]

    # (targets, expected accuracy)
    for targets, expected in (([[1], [0]], 0.5),
                              ([[1], [1]], 1.0),
                              ([[0], [0]], 0.0)):
        actual = validation.get_accuracy(
            model, numpy.array(inputs), numpy.array(targets))
        assert actual == expected
def test_break_on_no_improvement_completely_stagnant():
    """Training should stop once the error has not improved for
    error_improve_iters iterations."""
    model = helpers.SetOutputModel(1.0)

    # Stop training if the error does not improve after 5 iterations.
    # error_stagnant_threshold=None presumably disables the stagnation check,
    # leaving only the no-improvement rule — TODO confirm against Model.train.
    stop_criteria = {
        'error_stagnant_distance': 10,
        'error_stagnant_threshold': None,
        'error_improve_iters': 5,
    }
    model.train([[0.0]], [[0.0]], **stop_criteria)

    # The 6th iteration is 5 away from the first.
    assert model.iteration == 6
def test_bagger():
    """Bagger.activate should return the average of its models' outputs."""
    # Dummy models, each returning a fixed output.
    fixed_outputs = ([0, 1, 2], [1, 2, 3])
    bagger = ensemble.Bagger(
        [helpers.SetOutputModel(fixed) for fixed in fixed_outputs])

    averaged = bagger.activate([])
    assert list(averaged) == [0.5, 1.5, 2.5]
def test_compare(monkeypatch):
    """validation.compare should produce the expected stats for two
    fixed-output models."""
    # Freeze time.clock so the time stats are deterministic.
    # NOTE(review): time.clock was removed in Python 3.8, where this setattr
    # raises AttributeError (monkeypatch's raising=True) — confirm target runtime.
    monkeypatch.setattr(time, 'clock', lambda: 0.0)

    # All targets are [1], matching each model's set output.
    patterns = [([0], [1]), ([1], [1]), ([2], [1])]
    names = ['model', 'model2']
    models = [helpers.SetOutputModel([1]), helpers.SetOutputModel([1])]

    # Cross validate both deterministic models.
    stats = validation.compare(names, models, zip(*patterns),
                               num_folds=3, num_runs=2,
                               all_kwargs={'iterations': 1})

    actual = helpers.fix_numpy_array_equality(_drop_models_model_stat(stats))
    expected = helpers.fix_numpy_array_equality(_COMPARE_STATS)
    assert actual == expected
def test_break_on_stagnation_completely_stagnant():
    """Training should stop when the error changes by no more than
    error_stagnant_threshold over error_stagnant_distance iterations."""
    model = helpers.SetOutputModel(1.0)

    model.train(
        [[0.0]], [[0.0]],
        error_stagnant_distance=5,
        error_stagnant_threshold=0.01,
    )

    # The 6th iteration is 5 away from the first.
    assert model.iteration == 6
def test_unserialize_wrong_type():
    """Model.unserialize should raise ValueError if the serialized model is of
    the wrong type."""
    # Build and serialize outside pytest.raises: only unserialize is under test.
    # A ValueError from construction or serialize() must not satisfy the check.
    serialized = helpers.SetOutputModel(1.0).serialize()
    with pytest.raises(ValueError):
        base.Model.unserialize(serialized)
def test_unserialize():
    """A serialize/unserialize round trip should yield an equal but distinct model."""
    original = helpers.SetOutputModel(random.uniform(0, 1))
    payload = original.serialize()
    restored = helpers.SetOutputModel.unserialize(payload)

    assert restored.__dict__ == original.__dict__, 'Should have same content'
    assert restored is not original, 'Should have different id'
def test_serialize():
    """Model.serialize should produce a string."""
    serialized = helpers.SetOutputModel(1.0).serialize()
    assert isinstance(serialized, str), 'Model.serialize should return string'
def test_multioutputs_activate():
    """MultiOutputs(SetOutputModel(1), 2) should activate to [1, 1]."""
    inner = helpers.SetOutputModel(1)
    model = multioutputs.MultiOutputs(inner, 2)

    outputs = model.activate([None])
    assert outputs == [1, 1]