def test_predict_num_roles():
    """
    Test predict function of HostFootprint class with
    varying number of distinct roles present.

    For each input file a model is trained and then used to predict on the
    same file. The number of predictions returned is checked against the
    expected count for that file.

    ``sys.argv`` is overwritten to drive ``HostFootprint``. It is restored
    afterwards so the argument vector set here does not leak into other tests.
    """
    # Expected number of predictions for each input file.
    expected_counts = {
        'combined_three_roles.csv': 6,
        'combined_two_roles.csv': 4,
    }
    saved_argv = sys.argv
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            testdata = os.path.join(tmpdir, 'test_data')
            shutil.copytree('./tests/test_data', testdata)
            for filename, expected in expected_counts.items():
                input_file = os.path.join(testdata, filename)

                # HostFootprint() is built without raw_args, so it presumably
                # parses sys.argv; train first, then predict on the same data.
                sys.argv = hf_args(tmpdir, 'train', input_file)
                instance = HostFootprint()
                instance.main()

                sys.argv = hf_args(tmpdir, 'predict', input_file)
                instance = HostFootprint()
                instance.main()

                predictions = json.loads(instance.predict())
                assert isinstance(predictions, dict)
                # Check if number of predictions is correct for this file.
                assert len(predictions) == expected, filename
    finally:
        sys.argv = saved_argv


def test_train():
    """Test training function of HostFootprint class.

    Trains on ``combined.csv`` from a scratch copy of the test data.
    ``sys.argv`` is restored afterwards so it does not leak into other tests.
    """
    saved_argv = sys.argv
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            testdata = os.path.join(tmpdir, 'test_data')
            shutil.copytree('./tests/test_data', testdata)
            input_file = os.path.join(testdata, 'combined.csv')
            operation = 'train'
            sys.argv = hf_args(tmpdir, operation, input_file)
            instance = HostFootprint()
            instance.main()
    finally:
        sys.argv = saved_argv


def test_train_bad_data_too_few_columns():
    """
    This test tries to train a model on a malformed csv with too few fields
    and expects training to raise.

    ``sys.argv`` is restored afterwards so it does not leak into other tests.
    """
    saved_argv = sys.argv
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            testdata = os.path.join(tmpdir, 'test_data')
            shutil.copytree('./tests/test_data', testdata)
            input_file = os.path.join(testdata, 'bad_data_too_few_columns.csv')
            operation = 'train'
            sys.argv = hf_args(tmpdir, operation, input_file)
            instance = HostFootprint()
            # NOTE(review): Exception is very broad; narrow it to the specific
            # error HostFootprint raises for malformed input once confirmed.
            with pytest.raises(Exception):
                instance.main()
    finally:
        sys.argv = saved_argv
Example No. 4
0
 def run_algorithm_stage(self, in_path):
     """Run the HostFootprint algorithm stage on ``in_path``.

     Start from ``self.stage_args['algorithm']`` as processed by
     ``self.add_opt_args``. Append ``-O <self.operation>``,
     ``-v <self.log_level>`` and ``in_path``, then run a ``HostFootprint``
     built from those arguments.

     Returns:
         Whatever ``HostFootprint.main()`` returns.
     """
     algo_args = self.add_opt_args(self.stage_args['algorithm'])
     stage_flags = ['-O', self.operation, '-v', self.log_level]
     # Extends the list returned by add_opt_args in place, as before.
     algo_args.extend(stage_flags + [in_path])
     footprint = HostFootprint(raw_args=algo_args)
     return footprint.main()