示例#1
0
def test_bad_newline(live_server):
    """Empty rows in the input CSV are dropped and a warning is emitted."""
    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())

    run_batch_predictions(base_url=base_url,
                          base_headers={},
                          user='******',
                          pwd='password',
                          api_token=None,
                          create_api_token=False,
                          pid='56dd9570018e213242dfa93c',
                          lid='56dd9570018e213242dfa93d',
                          import_id=None,
                          n_retry=3,
                          concurrent=1,
                          resume=False,
                          n_samples=10,
                          out_file='out.csv',
                          keep_cols=None,
                          delimiter=',',
                          dataset='tests/fixtures/diabetes_bad_newline.csv',
                          pred_name=None,
                          timeout=None,
                          ui=ui,
                          auto_sample=False,
                          fast_mode=False,
                          dry_run=False,
                          encoding='',
                          skip_dialect=False)

    # Use a context manager so the output handle is closed deterministically
    # (the original leaked the file object returned by open()).
    with open('out.csv', 'rb') as out_fh:
        lines = len(out_fh.readlines())

    assert lines == 5
    ui.warning.assert_any_call('Detected empty rows in the CSV file. '
                               'These rows will be discarded.')
def test_no_delimiter(live_server):
    """Sniffing must fail when the requested delimiter never occurs."""
    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93d',
                import_id=None,
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file='out.csv',
                keep_cols=None,
                delimiter=';',
                dataset='tests/fixtures/temperatura_predict.csv',
                pred_name=None,
                timeout=None,
                ui=ui,
                auto_sample=False,
                fast_mode=False,
                dry_run=False,
                encoding='',
                skip_dialect=False)
    with pytest.raises(csv.Error) as excinfo:
        run_batch_predictions(**opts)
    assert str(excinfo.value) == "Could not determine delimiter"
示例#3
0
def test_header_only(live_server):
    """A CSV consisting of only a header line is rejected as empty input."""
    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    kwargs = dict(base_url=base_url,
                  base_headers={},
                  user='******',
                  pwd='password',
                  api_token=None,
                  create_api_token=False,
                  pid='56dd9570018e213242dfa93c',
                  lid='56dd9570018e213242dfa93d',
                  import_id=None,
                  n_retry=3,
                  concurrent=1,
                  resume=False,
                  n_samples=10,
                  out_file='out.csv',
                  keep_cols=None,
                  delimiter=',',
                  dataset='tests/fixtures/header_only.csv',
                  pred_name=None,
                  timeout=30,
                  ui=ui,
                  auto_sample=False,
                  fast_mode=False,
                  dry_run=False,
                  encoding='',
                  skip_dialect=False)
    with pytest.raises(ValueError) as excinfo:
        run_batch_predictions(**kwargs)
    assert str(excinfo.value) == ("Input file 'tests/fixtures/header_only.csv' "
                                  "is empty.")
def test_no_delimiter(live_server):
    """A delimiter that is absent from the input raises a ValueError."""
    ui = mock.Mock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    call_args = dict(base_url=base_url,
                     base_headers={},
                     user='******',
                     pwd='password',
                     api_token=None,
                     create_api_token=False,
                     pid='56dd9570018e213242dfa93c',
                     lid='56dd9570018e213242dfa93d',
                     n_retry=3,
                     concurrent=1,
                     resume=False,
                     n_samples=10,
                     out_file='out.csv',
                     keep_cols=None,
                     delimiter=';',
                     dataset='tests/fixtures/temperatura_predict.csv',
                     pred_name=None,
                     timeout=30,
                     ui=ui)
    with pytest.raises(ValueError) as excinfo:
        run_batch_predictions(**call_args)
    assert str(excinfo.value) == ("Delimiter ';' not found. "
                                  "Please check your input file "
                                  "or consider the flag `--delimiter=''`.")
def test_header_only(live_server):
    """Header-only input (no data rows) is reported as an empty file."""
    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    scoring_opts = dict(base_url=base_url,
                        base_headers={},
                        user='******',
                        pwd='password',
                        api_token=None,
                        create_api_token=False,
                        pid='56dd9570018e213242dfa93c',
                        lid='56dd9570018e213242dfa93d',
                        n_retry=3,
                        concurrent=1,
                        resume=False,
                        n_samples=10,
                        out_file='out.csv',
                        keep_cols=None,
                        delimiter=',',
                        dataset='tests/fixtures/header_only.csv',
                        pred_name=None,
                        timeout=30,
                        ui=ui,
                        auto_sample=False,
                        fast_mode=False,
                        dry_run=False,
                        encoding='',
                        skip_dialect=False)
    with pytest.raises(ValueError) as excinfo:
        run_batch_predictions(**scoring_opts)
    assert str(excinfo.value) == ("Input file 'tests/fixtures/header_only.csv' "
                                  "is empty.")
def test_empty_file(live_server):
    """A completely empty input file raises a ValueError."""
    ui = mock.Mock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93d',
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file='out.csv',
                keep_cols=None,
                delimiter=',',
                dataset='tests/fixtures/empty.csv',
                pred_name=None,
                timeout=30,
                ui=ui)
    with pytest.raises(ValueError) as excinfo:
        run_batch_predictions(**opts)
    assert str(excinfo.value) == "Input file 'tests/fixtures/empty.csv' is empty."
def test_gzipped_csv(live_server, ui):
    """A gzip-compressed dataset is scored successfully."""
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93d',
                import_id=None,
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file='out.csv',
                keep_cols=None,
                delimiter=None,
                dataset='tests/fixtures/temperatura_predict.csv.gz',
                pred_name=None,
                timeout=None,
                ui=ui,
                auto_sample=False,
                fast_mode=False,
                dry_run=False,
                encoding='',
                skip_dialect=False,
                max_batch_size=1000)
    ret = run_batch_predictions(**opts)
    assert ret is None
示例#8
0
def test_explicit_delimiter_gzip(live_server):
    """An explicitly supplied delimiter works for gzipped input."""
    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93d',
                import_id=None,
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file='out.csv',
                keep_cols=None,
                delimiter=',',
                dataset='tests/fixtures/temperatura_predict.csv.gz',
                pred_name=None,
                timeout=30,
                ui=ui,
                auto_sample=False,
                fast_mode=False,
                dry_run=False,
                encoding='',
                skip_dialect=False)
    ret = run_batch_predictions(**opts)
    assert ret is None
示例#9
0
def main_standalone(argv=None):
    """CLI entry point for standalone (``import_id``-based) batch scoring.

    Returns an exit code: the value returned by ``run_batch_predictions``
    on success, ``1`` otherwise.
    """
    # Use a None sentinel instead of `argv=sys.argv[1:]`: a sys.argv default
    # is evaluated once at import time and would ignore later argv changes.
    if argv is None:
        argv = sys.argv[1:]
    freeze_support()
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter('ignore')
    parsed_args = parse_args(argv, standalone=True)
    exit_code = 1

    generic_opts = parse_generic_options(parsed_args)
    import_id = parsed_args['import_id']
    try:
        exit_code = run_batch_predictions(base_headers={},
                                          user=None,
                                          pwd=None,
                                          api_token=None,
                                          create_api_token=False,
                                          pid=None,
                                          lid=None,
                                          import_id=import_id,
                                          ui=ui,
                                          **generic_opts)
    except SystemError:
        pass
    except ShelveError as e:
        ui.error(str(e))
    except KeyboardInterrupt:
        ui.info('Keyboard interrupt')
    except Exception as e:
        ui.fatal(str(e))
    finally:
        ui.close()
        # NOTE(review): `return` inside `finally` swallows any exception
        # raised by the handlers above (e.g. SystemExit from ui.fatal) --
        # presumably intentional so an exit code is always returned; confirm
        # before restructuring.
        return exit_code
def test_regression(live_server, tmpdir):
    """Score the regression fixture and compare against expected output."""
    out = tmpdir.join('out.csv')
    ui = mock.Mock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93e',
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file=str(out),
                keep_cols=None,
                delimiter=None,
                dataset='tests/fixtures/regression_predict.csv',
                pred_name=None,
                timeout=30,
                ui=ui)
    ret = run_batch_predictions(**opts)
    assert ret is None

    actual = out.read_text('utf-8')
    with open('tests/fixtures/regression_output.csv', 'r') as f:
        assert actual == f.read()
示例#11
0
def test_explicit_delimiter_gzip(live_server):
    """An explicit ',' delimiter is honored for a gzip-compressed dataset."""
    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    scoring_opts = dict(base_url=base_url,
                        base_headers={},
                        user='******',
                        pwd='password',
                        api_token=None,
                        create_api_token=False,
                        pid='56dd9570018e213242dfa93c',
                        lid='56dd9570018e213242dfa93d',
                        n_retry=3,
                        concurrent=1,
                        resume=False,
                        n_samples=10,
                        out_file='out.csv',
                        keep_cols=None,
                        delimiter=',',
                        dataset='tests/fixtures/temperatura_predict.csv.gz',
                        pred_name=None,
                        timeout=30,
                        ui=ui,
                        auto_sample=False,
                        fast_mode=False,
                        dry_run=False,
                        encoding='',
                        skip_dialect=False)
    ret = run_batch_predictions(**scoring_opts)
    assert ret is None
def test_tab_delimiter(live_server):
    """A tab-separated dataset is scored when '\\t' is given explicitly."""
    ui = mock.Mock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93d',
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file='out.csv',
                keep_cols=None,
                delimiter='\t',
                dataset='tests/fixtures/temperatura_predict_tab.csv',
                pred_name=None,
                timeout=30,
                ui=ui)
    ret = run_batch_predictions(**opts)
    assert ret is None
示例#13
0
def main(argv=None):
    """CLI entry point for project/model (pid/lid based) batch scoring.

    Parses arguments, validates object ids, resolves credentials, then runs
    the batch prediction loop.  Returns an exit code: the value returned by
    ``run_batch_predictions`` on success, ``1`` otherwise.
    """
    # Use a None sentinel instead of `argv=sys.argv[1:]`: a sys.argv default
    # is evaluated once at import time and would ignore later argv changes.
    if argv is None:
        argv = sys.argv[1:]
    freeze_support()
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter('ignore')
    parsed_args = parse_args(argv)
    exit_code = 1

    generic_opts = parse_generic_options(parsed_args)

    # parse args
    pid = parsed_args['project_id']
    lid = parsed_args['model_id']
    try:
        verify_objectid(pid)
        verify_objectid(lid)
    except ValueError as e:
        ui.fatal(str(e))

    # auth only ---
    datarobot_key = parsed_args.get('datarobot_key')
    api_token = parsed_args.get('api_token')
    create_api_token = parsed_args.get('create_api_token')
    user = parsed_args.get('user')
    pwd = parsed_args.get('password')

    if not generic_opts['dry_run']:
        # Only prompt for credentials when we will actually hit the server.
        user = user or ui.prompt_user()
        user = user.strip()

        if not api_token and not pwd:
            pwd = ui.getpass()

    base_headers = {}
    if datarobot_key:
        base_headers['datarobot-key'] = datarobot_key
    # end auth ---

    try:
        exit_code = run_batch_predictions(base_headers=base_headers,
                                          user=user,
                                          pwd=pwd,
                                          api_token=api_token,
                                          create_api_token=create_api_token,
                                          pid=pid,
                                          lid=lid,
                                          import_id=None,
                                          ui=ui,
                                          **generic_opts)
    except SystemError:
        pass
    except ShelveError as e:
        ui.error(str(e))
    except KeyboardInterrupt:
        ui.info('Keyboard interrupt')
    except Exception as e:
        ui.fatal(str(e))
    finally:
        ui.close()
        # NOTE(review): `return` inside `finally` swallows any exception
        # raised by the handlers above (e.g. SystemExit from ui.fatal) --
        # presumably intentional so an exit code is always returned; confirm
        # before restructuring.
        return exit_code
示例#14
0
def test_regression(live_server,
                    tmpdir,
                    ui,
                    keep_cols=None,
                    in_fixture='tests/fixtures/regression_predict.csv',
                    out_fixture='tests/fixtures/regression_output.csv',
                    fast_mode=False,
                    skip_row_id=False,
                    output_delimiter=None,
                    skip_dialect=False,
                    n_samples=500,
                    max_batch_size=None,
                    expected_ret=None):
    """Run a regression scoring job and diff the output against a fixture.

    Also reused by other tests via the keyword parameters, which tune the
    input/output fixtures and scoring options.
    """
    # train one model in project
    out = tmpdir.join('out.csv')

    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    ret = run_batch_predictions(base_url=base_url,
                                base_headers={},
                                user='******',
                                pwd='password',
                                api_token=None,
                                create_api_token=False,
                                pid='56dd9570018e213242dfa93c',
                                lid='56dd9570018e213242dfa93e',
                                import_id=None,
                                n_retry=3,
                                concurrent=1,
                                resume=False,
                                n_samples=n_samples,
                                out_file=str(out),
                                keep_cols=keep_cols,
                                delimiter=None,
                                dataset=in_fixture,
                                pred_name=None,
                                timeout=30,
                                ui=ui,
                                auto_sample=False,
                                fast_mode=fast_mode,
                                dry_run=False,
                                encoding='',
                                skip_dialect=skip_dialect,
                                skip_row_id=skip_row_id,
                                output_delimiter=output_delimiter,
                                max_batch_size=max_batch_size)
    assert ret is expected_ret

    if out_fixture:
        actual = out.read_text('utf-8')
        # 'rU' was removed in Python 3.11; plain text mode already performs
        # universal newline translation, so 'r' is equivalent.
        with open(out_fixture, 'r') as f:
            expected = f.read()
            print(len(actual), len(expected))
            assert actual == expected
示例#15
0
def test_422(live_server, tmpdir):
    """An HTTP 422 from the server is surfaced through ui.fatal."""
    # train one model in project
    out = tmpdir.join('out.csv')

    ui_class = mock.Mock(spec=UI)
    ui = ui_class.return_value
    ui.fatal.side_effect = SystemExit
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    scoring_opts = dict(base_url=base_url,
                        base_headers={},
                        user='******',
                        pwd='password',
                        api_token=None,
                        create_api_token=False,
                        pid='56dd9570018e213242dfa93c',
                        lid='56dd9570018e213242eee422',
                        import_id=None,
                        n_retry=3,
                        concurrent=1,
                        resume=False,
                        n_samples=10,
                        out_file=str(out),
                        keep_cols=None,
                        delimiter=None,
                        dataset='tests/fixtures/temperatura_predict.csv.gz',
                        pred_name=None,
                        timeout=None,
                        ui=ui,
                        auto_sample=False,
                        fast_mode=False,
                        dry_run=False,
                        encoding='',
                        skip_dialect=False)
    with pytest.raises(SystemExit):
        run_batch_predictions(**scoring_opts)
    ui.fatal.assert_called()
    ui.fatal.assert_called_with(
        '''Predictions are not available because: "Server raised 422.".'''
    )
def test_quotechar_in_keep_cols(live_server):
    """Quoted fields containing escaped quotes survive --keep_cols filtering."""
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ui = PickableMock()
    with tempfile.NamedTemporaryFile(prefix='test_',
                                     suffix='.csv',
                                     delete=False) as fd:
        # Read the fixture parts with context managers so the handles are
        # closed deterministically (the original leaked three open files).
        with open("tests/fixtures/quotes_input_head.csv", "rb") as f:
            head = f.read()
        with open("tests/fixtures/quotes_input_first_part.csv", "rb") as f:
            body_1 = f.read()
        with open("tests/fixtures/quotes_input_bad_part.csv", "rb") as f:
            body_2 = f.read()
        fd.file.write(head)
        # Repeat the "good" part until the slow-detection sample window is
        # exceeded, then append the problematic part past the window.
        size = 0
        while size < DETECT_SAMPLE_SIZE_SLOW:
            fd.file.write(body_1)
            size += len(body_1)
        fd.file.write(body_2)
        fd.close()

        ret = run_batch_predictions(
            base_url=base_url,
            base_headers={},
            user='******',
            pwd='password',
            api_token=None,
            create_api_token=False,
            pid='56dd9570018e213242dfa93c',
            lid='56dd9570018e213242dfa93d',
            import_id=None,
            n_retry=3,
            concurrent=1,
            resume=False,
            n_samples=10,
            out_file='out.csv',
            keep_cols=["b", "c"],
            delimiter=None,
            dataset=fd.name,
            pred_name=None,
            timeout=None,
            ui=ui,
            auto_sample=True,
            fast_mode=False,
            dry_run=False,
            encoding='',
            skip_dialect=False
        )
        assert ret is None

        with open("out.csv", "rb") as out_fh:
            last_line = out_fh.readlines()[-1]
        expected_last_line = b'1044,2,"eeeeeeee ""eeeeee"" eeeeeeeeeeee'
        assert last_line[:len(expected_last_line)] == expected_last_line
def test_multiclass_pid_lid(live_server, tmpdir, ui, keep_cols=None,
                            in_fixture='tests/fixtures/iris_predict.csv',
                            out_fixture='tests/fixtures/iris_out.csv',
                            fast_mode=False, output_delimiter=None,
                            skip_row_id=False, skip_dialect=False,
                            n_samples=500,
                            max_batch_size=None,
                            expected_ret=None):
    """Score a multiclass model (pid/lid) and diff the output fixture.

    Keyword parameters allow reuse with different fixtures and options.
    """
    out = tmpdir.join('out.csv')

    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())

    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='5a29097f962d7465d1a81946',
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=n_samples,
        out_file=str(out),
        keep_cols=keep_cols,
        delimiter=None,
        dataset=in_fixture,
        pred_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=fast_mode,
        dry_run=False,
        encoding='',
        skip_dialect=skip_dialect,
        skip_row_id=skip_row_id,
        output_delimiter=output_delimiter,
        max_batch_size=max_batch_size
    )

    assert ret is expected_ret

    if out_fixture:
        actual = out.read_text('utf-8')
        # 'rU' was removed in Python 3.11; plain text mode already performs
        # universal newline translation, so 'r' is equivalent.
        with open(out_fixture, 'r') as f:
            expected = f.read()
            assert actual == expected
示例#18
0
def test_regression(live_server, tmpdir, keep_cols=None,
                    in_fixture='tests/fixtures/regression_predict.csv',
                    out_fixture='tests/fixtures/regression_output.csv',
                    fast_mode=False, skip_row_id=False, output_delimiter=None,
                    skip_dialect=False,
                    n_samples=500,
                    max_batch_size=None):
    """Run a regression scoring job under a real UI and diff the output.

    Keyword parameters allow reuse with different fixtures and options.
    """
    # train one model in project
    out = tmpdir.join('out.csv')

    with UI(False, 'DEBUG', False) as ui:
        base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
        ret = run_batch_predictions(
            base_url=base_url,
            base_headers={},
            user='******',
            pwd='password',
            api_token=None,
            create_api_token=False,
            pid='56dd9570018e213242dfa93c',
            lid='56dd9570018e213242dfa93e',
            n_retry=3,
            concurrent=1,
            resume=False,
            n_samples=n_samples,
            out_file=str(out),
            keep_cols=keep_cols,
            delimiter=None,
            dataset=in_fixture,
            pred_name=None,
            timeout=30,
            ui=ui,
            auto_sample=False,
            fast_mode=fast_mode,
            dry_run=False,
            encoding='',
            skip_dialect=skip_dialect,
            skip_row_id=skip_row_id,
            output_delimiter=output_delimiter,
            max_batch_size=max_batch_size
        )
        assert ret is None

        if out_fixture:
            actual = out.read_text('utf-8')
            # 'rU' was removed in Python 3.11; plain text mode already
            # performs universal newline translation, so 'r' is equivalent.
            with open(out_fixture, 'r') as f:
                expected = f.read()
                print(len(actual), len(expected))
                assert actual == expected
示例#19
0
def test_simple_with_wrong_encoding(live_server, tmpdir, func_params):
    """Decoding a utf-8 dataset as cp932 raises UnicodeDecodeError."""
    out = tmpdir.join('out.csv')
    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    opts = dict(base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                deployment_id=func_params['deployment_id'],
                pid=func_params['pid'],
                lid=func_params['lid'],
                import_id=None,
                n_retry=3,
                concurrent=1,
                resume=False,
                n_samples=10,
                out_file=str(out),
                keep_cols=None,
                delimiter=None,
                dataset='tests/fixtures/jpReview_books_reg.csv',
                pred_name=None,
                pred_threshold_name=None,
                pred_decision_name=None,
                timeout=None,
                ui=ui,
                auto_sample=False,
                fast_mode=False,
                dry_run=False,
                encoding='cp932',
                skip_dialect=False)
    with pytest.raises(UnicodeDecodeError) as execinfo:
        run_batch_predictions(**opts)

    # Fixture dataset encoding 'utf-8' and we trying to decode it with 'cp932'
    assert "'cp932' codec can't decode byte" in str(execinfo.value)
示例#20
0
def test_quotechar_in_keep_cols(live_server):
    """Quoted fields containing escaped quotes survive --keep_cols filtering."""
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ui = PickableMock()
    with tempfile.NamedTemporaryFile(prefix='test_',
                                     suffix='.csv',
                                     delete=False) as fd:
        # Read the fixture parts with context managers so the handles are
        # closed deterministically (the original leaked three open files).
        with open("tests/fixtures/quotes_input_head.csv", "rb") as f:
            head = f.read()
        with open("tests/fixtures/quotes_input_first_part.csv", "rb") as f:
            body_1 = f.read()
        with open("tests/fixtures/quotes_input_bad_part.csv", "rb") as f:
            body_2 = f.read()
        fd.file.write(head)
        # Repeat the "good" part until the slow-detection sample window is
        # exceeded, then append the problematic part past the window.
        size = 0
        while size < DETECT_SAMPLE_SIZE_SLOW:
            fd.file.write(body_1)
            size += len(body_1)
        fd.file.write(body_2)
        fd.close()

        ret = run_batch_predictions(base_url=base_url,
                                    base_headers={},
                                    user='******',
                                    pwd='password',
                                    api_token=None,
                                    create_api_token=False,
                                    pid='56dd9570018e213242dfa93c',
                                    lid='56dd9570018e213242dfa93d',
                                    import_id=None,
                                    n_retry=3,
                                    concurrent=1,
                                    resume=False,
                                    n_samples=10,
                                    out_file='out.csv',
                                    keep_cols=["b", "c"],
                                    delimiter=None,
                                    dataset=fd.name,
                                    pred_name=None,
                                    timeout=None,
                                    ui=ui,
                                    auto_sample=True,
                                    fast_mode=False,
                                    dry_run=False,
                                    encoding='',
                                    skip_dialect=False)
        assert ret is None

        with open("out.csv", "rb") as out_fh:
            last_line = out_fh.readlines()[-1]
        expected_last_line = b'1044,2,"eeeeeeee ""eeeeee"" eeeeeeeeeeee'
        assert last_line[:len(expected_last_line)] == expected_last_line
示例#21
0
def check_regression_jp(live_server, tmpdir, fast_mode, gzipped):
    """Use utf8 encoded input data.

    Helper shared by fast/slow and gzipped/plain regression tests; writes to
    a per-mode output file and diffs it against the Japanese fixture.
    """
    if fast_mode:
        out_fname = 'out_fast.csv'
    else:
        out_fname = 'out.csv'
    out = tmpdir.join(out_fname)

    dataset_suffix = '.gz' if gzipped else ''

    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93e',
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=500,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/regression_jp.csv' + dataset_suffix,
        pred_name='new_name',
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=fast_mode,
        dry_run=False,
        encoding='',
        skip_dialect=False,
        compression=True
    )
    assert ret is None

    actual = out.read_text('utf-8')

    # 'rU' was removed in Python 3.11; plain text mode already performs
    # universal newline translation, so 'r' is equivalent.
    with open('tests/fixtures/regression_output_jp.csv', 'r') as f:
        assert actual == f.read()
示例#22
0
def test_keep_wrong_cols(live_server, tmpdir, func_params, fast_mode=False):
    """Requesting a nonexistent column in keep_cols aborts via ui.fatal."""
    # train one model in project
    out = tmpdir.join('out.csv')

    ui_class = mock.Mock(spec=UI)
    ui = ui_class.return_value
    ui.fatal.side_effect = SystemExit

    with pytest.raises(SystemExit):
        base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
        opts = dict(base_url=base_url,
                    base_headers={},
                    user='******',
                    pwd='password',
                    api_token=None,
                    create_api_token=False,
                    deployment_id=func_params['deployment_id'],
                    pid=func_params['pid'],
                    lid=func_params['lid'],
                    import_id=None,
                    n_retry=3,
                    concurrent=1,
                    resume=False,
                    n_samples=10,
                    out_file=str(out),
                    keep_cols=['not_present', 'x'],
                    delimiter=None,
                    dataset='tests/fixtures/temperatura_predict.csv',
                    pred_name=None,
                    pred_threshold_name=None,
                    pred_decision_name=None,
                    timeout=None,
                    ui=ui,
                    auto_sample=False,
                    fast_mode=fast_mode,
                    dry_run=False,
                    encoding='',
                    skip_dialect=False)
        ret = run_batch_predictions(**opts)

        # Unreachable when SystemExit fires, kept for parity with the call.
        assert ret is None

    ui.fatal.assert_called()
    ui.fatal.assert_called_with(
        '''keep_cols "['not_present']" not in columns ['', 'x'].'''
    )
示例#23
0
def check_regression_jp(live_server, tmpdir, fast_mode, gzipped):
    """Use utf8 encoded input data.

    Helper shared by fast/slow and gzipped/plain regression tests; writes to
    a per-mode output file and diffs it against the Japanese fixture.
    """
    if fast_mode:
        out_fname = 'out_fast.csv'
    else:
        out_fname = 'out.csv'
    out = tmpdir.join(out_fname)

    dataset_suffix = '.gz' if gzipped else ''

    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93e',
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=500,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/regression_jp.csv' + dataset_suffix,
        pred_name='new_name',
        timeout=30,
        ui=ui,
        auto_sample=False,
        fast_mode=fast_mode,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )
    assert ret is None

    actual = out.read_text('utf-8')

    # 'rU' was removed in Python 3.11; plain text mode already performs
    # universal newline translation, so 'r' is equivalent.
    with open('tests/fixtures/regression_output_jp.csv', 'r') as f:
        assert actual == f.read()
示例#24
0
def test_request_client_timeout(live_server, tmpdir):
    """A client-side ReadTimeout on every request still finishes the run
    and emits the retry-advice warning via ui.warning."""
    out = tmpdir.join('out.csv')
    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    # Patch the network layer so every send() raises a client-side read
    # timeout, exercising the retry/give-up path without a slow server.
    with mock.patch('datarobot_batch_scoring.'
                    'network.requests.Session') as nw_mock:
        nw_mock.return_value.send = mock.Mock(
            side_effect=requests.exceptions.ReadTimeout)

        ret = run_batch_predictions(
            base_url=base_url,
            base_headers={},
            user='******',
            pwd='password',
            api_token=None,
            create_api_token=False,
            pid='56dd9570018e213242dfa93c',
            lid='56dd9570018e213242dfa93d',
            n_retry=3,
            concurrent=1,
            resume=False,
            n_samples=10,
            out_file=str(out),
            keep_cols=None,
            delimiter=None,
            dataset='tests/fixtures/temperatura_predict.csv.gz',
            pred_name=None,
            timeout=30,
            ui=ui,
            auto_sample=False,
            fast_mode=False,
            dry_run=False,
            encoding='',
            skip_dialect=False
        )

    assert ret is None
    returned = out.read_text('utf-8')
    # NOTE(review): `'' in returned` is vacuously true for any string --
    # presumably a stronger expected-content check was intended; confirm.
    assert '' in returned, returned
    ui.warning.assert_called_with(textwrap.dedent("""The server did not send any data
in the allotted amount of time.
You might want to decrease the "--n_concurrent" parameters
or
increase "--timeout" parameter.
"""))
def test_keep_wrong_cols(live_server, tmpdir, fast_mode=False):
    """Asking to keep a column that the dataset lacks must be fatal."""
    out = tmpdir.join('out.csv')

    ui_class = mock.Mock(spec=UI)
    ui = ui_class.return_value
    # The real UI.fatal exits the process; emulate that so the run aborts.
    ui.fatal.side_effect = SystemExit

    with pytest.raises(SystemExit):
        base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
        opts = dict(
            base_url=base_url,
            base_headers={},
            user='******',
            pwd='password',
            api_token=None,
            create_api_token=False,
            pid='56dd9570018e213242dfa93c',
            lid='56dd9570018e213242dfa93d',
            import_id=None,
            n_retry=3,
            concurrent=1,
            resume=False,
            n_samples=10,
            out_file=str(out),
            keep_cols=['not_present', 'x'],
            delimiter=None,
            dataset='tests/fixtures/temperatura_predict.csv',
            pred_name=None,
            timeout=None,
            ui=ui,
            auto_sample=False,
            fast_mode=fast_mode,
            dry_run=False,
            encoding='',
            skip_dialect=False,
        )
        ret = run_batch_predictions(**opts)

        # Never reached: SystemExit fires inside run_batch_predictions.
        assert ret is None

    ui.fatal.assert_called()
    ui.fatal.assert_called_with(
        '''keep_cols "['not_present']" not in columns ['', 'x'].'''
    )
示例#26
0
def test_wrong_result_order(live_server, tmpdir, ui):
    """Responses arriving out of order must still be written in row order.

    Batches 8-10 are delayed by decreasing amounts so they complete in
    reverse submission order; the assembled output must still match the
    reference fixture byte for byte.
    """
    out = tmpdir.join('out.csv')
    live_server.app.config["DELAY_AT"] = {
        8: 3.0,
        9: 2.0,
        10: 1.0
    }

    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93e',
        import_id=None,
        n_retry=3,
        concurrent=4,
        resume=False,
        n_samples=100,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/regression_jp.csv',
        pred_name='new_name',
        timeout=30,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False,
        compression=True
    )
    assert ret is None

    actual = out.read_text('utf-8')

    # BUGFIX: the 'U' open() mode flag was removed in Python 3.11; plain
    # text mode already applies universal newline handling.
    with open('tests/fixtures/regression_output_jp.csv', 'r') as f:
        assert actual == f.read()
示例#27
0
def test_prediction_explanations_keepcols(live_server, tmpdir):
    """Prediction explanations plus kept input columns in the output.

    Scores 10kDiabetes with 5 prediction explanations per row and keeps two
    source columns; the result must match the reference fixture.
    """
    out = tmpdir.join('out.csv')

    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='5afb150782c7dd45fcc03951',
        lid='5b2cad28aa1d12847310acf4',
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=['medical_specialty', 'number_diagnoses'],
        delimiter=None,
        dataset='tests/fixtures/10kDiabetes.csv',
        pred_name=None,
        pred_threshold_name=None,
        pred_decision_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False,
        max_prediction_explanations=5
    )

    assert ret is None
    actual = out.read_text('utf-8')
    file_path = 'tests/fixtures/10kDiabetes_5explanations_keepcols.csv'
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # default in text mode.
    with open(file_path, 'r') as f:
        expected = f.read()
    assert str(actual) == str(expected), expected
示例#28
0
def main_standalone(argv=sys.argv[1:]):
    """Entry point for the standalone (import_id based) scoring command.

    Parses CLI options, resolves the prediction API endpoint (unless doing
    a dry run) and runs batch predictions, mapping known failures onto an
    exit code.
    """
    freeze_support()
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter('ignore')
    parsed_args = parse_args(argv, standalone=True)
    exit_code = 1  # pessimistic default; replaced by the run's result below

    generic_opts = parse_generic_options(parsed_args)
    import_id = parsed_args['import_id']

    if generic_opts['dry_run']:
        # Dry runs never touch the network, so no endpoint is required.
        base_url = ''
        ui.info('Running in dry-run mode')
    else:
        try:
            base_url = get_endpoint(parsed_args['host'], PRED_API_V10)
            ui.info('Will be using API endpoint: {}'.format(base_url))
        except ValueError as e:
            ui.fatal(str(e))

    try:
        # Standalone mode authenticates via import_id only; user/password
        # and project/model ids are intentionally None here.
        exit_code = run_batch_predictions(base_url=base_url,
                                          base_headers={},
                                          user=None,
                                          pwd=None,
                                          api_token=None,
                                          create_api_token=False,
                                          pid=None,
                                          lid=None,
                                          import_id=import_id,
                                          ui=ui,
                                          **generic_opts)
    except SystemError:
        pass
    except ShelveError as e:
        ui.error(str(e))
    except KeyboardInterrupt:
        ui.info('Keyboard interrupt')
    except Exception as e:
        ui.fatal(str(e))
    finally:
        # NOTE(review): returning from ``finally`` swallows any in-flight
        # exception, including a SystemExit raised by ui.fatal above
        # (flake8 B012) -- confirm this is intentional before refactoring.
        ui.close()
        return exit_code
示例#29
0
def test_pred_threshold_classification(live_server, tmpdir, func_params):
    """Classification run emitting both a prediction and a threshold column."""
    out = tmpdir.join('out.csv')

    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        deployment_id=func_params['deployment_id'],
        pid=func_params['pid'],
        lid=func_params['lid'],
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/temperatura_predict.csv',
        pred_name='healthy',
        pred_threshold_name='threshold',
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )

    assert ret is None

    # Renamed misleading local: this is the produced output, not the
    # expectation.
    actual = out.read_text('utf-8')
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open(
        'tests/fixtures/temperatura_output_healthy_threshold.csv', 'r'
    ) as f:
        assert actual == f.read(), actual
示例#30
0
def test_request_client_timeout(live_server, tmpdir, ui):
    """Server delays past the 1s client timeout; the run must fail with
    exit code 1 and log the timeout advice message."""
    live_server.app.config['PREDICTION_DELAY'] = 3
    out = tmpdir.join('out.csv')
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93d',
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/temperatura_predict.csv.gz',
        pred_name=None,
        timeout=1,  # shorter than the server's 3s delay -> guaranteed timeout
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )

    # BUGFIX: use equality, not identity -- ``ret is 1`` depends on CPython's
    # small-int caching and raises SyntaxWarning on modern Pythons.
    assert ret == 1
    returned = out.read_text('utf-8')
    # NOTE(review): '' is a substring of every string, so this assertion is
    # vacuous -- it only proves the output file could be read.
    assert '' in returned, returned
    logs = read_logs()
    assert textwrap.dedent("""The server did not send any data
in the allotted amount of time.
You might want to decrease the "--n_concurrent" parameters
or
increase "--timeout" parameter.
""") in logs
def test_multiclass_import_id(live_server, tmpdir, ui):
    """Multiclass scoring addressed by import_id instead of pid/lid."""
    out = tmpdir.join('out.csv')

    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())

    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid=None,
        lid=None,
        import_id='098fa761405d1c9d8a5ea71dc0f3d2bb5ce898b5',
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=500,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/iris_predict.csv',
        pred_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False,
        skip_row_id=False,
        output_delimiter=None,
        max_batch_size=None
    )

    assert not ret

    actual = out.read_text('utf-8')
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open('tests/fixtures/iris_out.csv', 'r') as f:
        expected = f.read()
        assert actual == expected
示例#32
0
def test_lost_retry(live_server, tmpdir, ui):
    """A batch that fails once (batch 14) must be retried and recovered.

    Output row order is not guaranteed after a retry, so both sides are
    sorted before comparing against the fixture.
    """
    out = tmpdir.join('out.csv')
    live_server.app.config["PREDICTION_DELAY"] = 1.0
    live_server.app.config["FAIL_AT"] = [14]

    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93e',
        import_id=None,
        n_retry=3,
        concurrent=4,
        resume=False,
        n_samples=100,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/regression_jp.csv',
        pred_name='new_name',
        timeout=30,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )
    assert ret is None

    actual = out.read_text('utf-8').splitlines()
    actual.sort()

    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open('tests/fixtures/regression_output_jp.csv', 'r') as f:
        expected = f.read().splitlines()
        expected.sort()
        assert actual == expected
示例#33
0
def test_simple_transferable(live_server, tmpdir):
    """Score via a transferable model import_id; compare with the fixture."""
    out = tmpdir.join('out.csv')

    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        import_id='0ec5bcea7f0f45918fa88257bfe42c09',
        pid=None,
        lid=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/regression_predict.csv',
        pred_name=None,
        pred_threshold_name=None,
        pred_decision_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )

    assert ret is None
    actual = out.read_text('utf-8')
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open('tests/fixtures/regression_output.csv', 'r') as f:
        expected = f.read()
    assert str(actual) == str(expected), expected
示例#34
0
def test_request_log_client_error(live_server, tmpdir, ui):
    """Batches failing gracefully (HTTP 400) are logged but do not abort."""
    live_server.app.config["FAIL_GRACEFULLY_AT"] = [8, 9]

    out = tmpdir.join('out.csv')
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    request_opts = dict(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93d',
        import_id=None,
        n_retry=3,
        concurrent=2,
        resume=False,
        n_samples=5,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/temperatura_predict.csv.gz',
        pred_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False,
    )
    ret = run_batch_predictions(**request_opts)
    assert ret is None

    # Header plus 100 prediction rows.
    written = out.read_text('utf-8')
    assert len(written.splitlines()) == 101

    assert 'failed with status code 400 message: Requested failure' in read_logs()
示例#35
0
def test_simple_with_unicode(live_server, tmpdir, func_params, dataset_name):
    """Score a dataset containing non-ASCII (Japanese) text; the output
    must match the reference fixture."""
    out = tmpdir.join('out.csv')
    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        deployment_id=func_params['deployment_id'],
        pid=func_params['pid'],
        lid=func_params['lid'],
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/{}'.format(dataset_name),
        pred_name=None,
        pred_threshold_name=None,
        pred_decision_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False)

    assert ret is None
    actual = out.read_text('utf-8')
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open('tests/fixtures/jpReview_books_reg_out.csv', 'r') as f:
        expected = f.read()
    assert str(actual) == str(expected), expected
def test_simple(live_server, tmpdir):
    """Happy-path scoring of a gzipped dataset against predApi/v1.0."""
    out = tmpdir.join('out.csv')

    ui = PickableMock()
    base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93d',
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/temperatura_predict.csv.gz',
        pred_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )

    assert ret is None
    actual = out.read_text('utf-8')
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open('tests/fixtures/temperatura_output.csv', 'r') as f:
        expected = f.read()
    assert str(actual) == str(expected), expected
示例#37
0
def main_standalone(argv=sys.argv[1:]):
    """Entry point for the standalone (import_id based) scoring command.

    Parses CLI options, resolves the prediction API endpoint (unless doing
    a dry run) and runs batch predictions, mapping known failures onto an
    exit code.
    """
    freeze_support()
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter('ignore')
    parsed_args = parse_args(argv, standalone=True)
    exit_code = 1  # pessimistic default; replaced by the run's result below

    generic_opts = parse_generic_options(parsed_args)
    import_id = parsed_args['import_id']

    if generic_opts['dry_run']:
        # Dry runs never touch the network, so no endpoint is required.
        base_url = ''
        ui.info('Running in dry-run mode')
    else:
        try:
            base_url = get_endpoint(parsed_args['host'],
                                    PRED_API_V10)
            ui.info('Will be using API endpoint: {}'.format(base_url))
        except ValueError as e:
            ui.fatal(str(e))

    try:
        # Standalone mode authenticates via import_id only; user/password
        # and project/model ids are intentionally None here.
        exit_code = run_batch_predictions(
            base_url=base_url,
            base_headers={}, user=None, pwd=None,
            api_token=None, create_api_token=False,
            pid=None, lid=None, import_id=import_id, ui=ui, **generic_opts
        )
    except SystemError:
        pass
    except ShelveError as e:
        ui.error(str(e))
    except KeyboardInterrupt:
        ui.info('Keyboard interrupt')
    except Exception as e:
        ui.fatal(str(e))
    finally:
        # NOTE(review): returning from ``finally`` swallows any in-flight
        # exception, including a SystemExit raised by ui.fatal above
        # (flake8 B012) -- confirm this is intentional before refactoring.
        ui.close()
        return exit_code
示例#38
0
def test_simple_api_v1(live_server, tmpdir):
    """Happy-path scoring against the legacy api/v1 endpoint."""
    out = tmpdir.join('out.csv')

    ui = PickableMock()
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    ret = run_batch_predictions(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93f',
        import_id=None,
        n_retry=3,
        concurrent=1,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/temperatura_predict.csv.gz',
        pred_name=None,
        timeout=None,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False
    )

    assert ret is None
    actual = out.read_text('utf-8')
    # BUGFIX: 'rU' was removed in Python 3.11; universal newlines is the
    # text-mode default.
    with open('tests/fixtures/temperatura_api_v1_output.csv', 'r') as f:
        expected = f.read()
    assert str(actual) == str(expected), expected
示例#39
0
def test_compression(live_server, tmpdir, ui):
    """Request compression; all rows are scored and savings are logged."""
    out = tmpdir.join('out.csv')
    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    scoring_opts = dict(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93d',
        import_id=None,
        n_retry=3,
        concurrent=2,
        resume=False,
        n_samples=100,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/regression_jp.csv.gz',
        pred_name=None,
        timeout=30,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False,
        compression=True,
    )
    assert run_batch_predictions(**scoring_opts) is None

    # Header plus 1410 prediction rows.
    written = out.read_text('utf-8')
    assert len(written.splitlines()) == 1411

    # The compression code path reports how much bandwidth it saved.
    assert "space savings" in read_logs()
示例#40
0
def test_os_env_proxy_handling(live_server, tmpdir, ui):
    """An unreachable HTTP_PROXY taken from the environment must abort.

    ``requests`` honours the HTTP_PROXY environment variable; pointing it
    at a dead proxy makes every request fail, ending in SystemExit.
    """
    os.environ["HTTP_PROXY"] = "http://localhost"
    try:
        out = tmpdir.join('out.csv')
        base_url = '{webhost}/predApi/v1.0/'.format(webhost=live_server.url())
        with pytest.raises(SystemExit):
            # BUGFIX: the original asserted ``ret is 1`` after this call; that
            # assert was unreachable (SystemExit is raised inside the call)
            # and used identity comparison on an int literal. Dead code removed.
            run_batch_predictions(
                base_url=base_url,
                base_headers={},
                user='******',
                pwd='password',
                api_token=None,
                create_api_token=False,
                pid='56dd9570018e213242dfa93c',
                lid='56dd9570018e213242dfa93d',
                import_id=None,
                n_retry=1,
                concurrent=2,
                resume=False,
                n_samples=1,
                out_file=str(out),
                keep_cols=None,
                delimiter=None,
                dataset='tests/fixtures/temperatura_predict.csv.gz',
                pred_name=None,
                timeout=None,
                ui=ui,
                auto_sample=False,
                fast_mode=False,
                dry_run=False,
                encoding='',
                skip_dialect=False)

        logs = read_logs()
        assert "Failed to establish a new connection" in logs
    finally:
        # BUGFIX: restore the environment even if an assertion fails so the
        # proxy setting cannot leak into other tests.
        os.environ["HTTP_PROXY"] = ""
示例#41
0
def test_request_pool_is_full(live_server, tmpdir, ui):
    """30 concurrent workers must not overflow the HTTP connection pool."""
    live_server.app.config["PREDICTION_DELAY"] = 1

    out = tmpdir.join('out.csv')

    base_url = '{webhost}/api/v1/'.format(webhost=live_server.url())
    scoring_opts = dict(
        base_url=base_url,
        base_headers={},
        user='******',
        pwd='password',
        api_token=None,
        create_api_token=False,
        pid='56dd9570018e213242dfa93c',
        lid='56dd9570018e213242dfa93d',
        import_id=None,
        n_retry=3,
        concurrent=30,
        resume=False,
        n_samples=10,
        out_file=str(out),
        keep_cols=None,
        delimiter=None,
        dataset='tests/fixtures/criteo_top30_1m.csv.gz',
        pred_name=None,
        timeout=30,
        ui=ui,
        auto_sample=False,
        fast_mode=False,
        dry_run=False,
        encoding='',
        skip_dialect=False,
    )
    assert run_batch_predictions(**scoring_opts) is None

    # urllib3 emits this warning when its pool is exhausted; it must not
    # appear in the logs.
    assert "Connection pool is full" not in read_logs()
示例#42
0
def main(argv=sys.argv[1:]):
    """CLI entry point: build the argument parser, merge config-file
    defaults, validate inputs and run batch predictions for a
    project/model pair.

    NOTE(review): unlike ``main_standalone`` this function does not return
    an exit code -- failures are reported only through ``ui``.
    """
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter("ignore")
    parser = argparse.ArgumentParser(
        description=DESCRIPTION, epilog=EPILOG, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Provides status updates while " "the script is running."
    )
    parser.add_argument("--version", action="version", version=VERSION_TEMPLATE, help="Show version")
    dataset_gr = parser.add_argument_group("Dataset and server")
    dataset_gr.add_argument(
        "--host",
        type=str,
        help="Specifies the protocol (http or https) and "
        "hostname of the prediction API endpoint. "
        'E.g. "https://example.orm.datarobot.com"',
    )
    dataset_gr.add_argument("project_id", type=str, help="Specifies the project " "identification string.")
    dataset_gr.add_argument("model_id", type=str, help="Specifies the model identification string.")
    dataset_gr.add_argument("dataset", type=str, help="Specifies the .csv input file that " "the script scores.")
    dataset_gr.add_argument(
        "--out",
        type=str,
        nargs="?",
        default="out.csv",
        help="Specifies the file name, "
        "and optionally path, "
        "to which the results are written. "
        "If not specified, "
        "the default file name is out.csv, "
        "written to the directory containing the script. "
        "(default: %(default)r)",
    )
    auth_gr = parser.add_argument_group("Authentication parameters")
    auth_gr.add_argument(
        "--user",
        type=str,
        help="Specifies the username used to acquire " "the api-token. " "Use quotes if the name contains spaces.",
    )
    auth_gr.add_argument(
        "--password",
        type=str,
        nargs="?",
        help="Specifies the password used to acquire " "the api-token. " "Use quotes if the name contains spaces.",
    )
    auth_gr.add_argument(
        "--api_token",
        type=str,
        nargs="?",
        help="Specifies the api token for the requests; "
        "if you do not have a token, "
        "you must specify the password argument.",
    )
    auth_gr.add_argument(
        "--create_api_token",
        action="store_true",
        default=False,
        help="Requests a new API token. To use this option, "
        "you must specify the "
        "password argument for this request "
        "(not the api_token argument). "
        "(default: %(default)r)",
    )
    auth_gr.add_argument(
        "--datarobot_key",
        type=str,
        nargs="?",
        help="An additional datarobot_key " "for dedicated prediction instances.",
    )
    conn_gr = parser.add_argument_group("Connection control")
    conn_gr.add_argument(
        "--timeout", type=int, default=30, help="The timeout for each post request. " "(default: %(default)r)"
    )
    conn_gr.add_argument(
        "--n_samples",
        type=int,
        nargs="?",
        default=False,
        help="Specifies the number of samples (rows) to use "
        'per batch. If not defined the "auto_sample" option '
        "will be used.",
    )
    conn_gr.add_argument(
        "--n_concurrent",
        type=int,
        nargs="?",
        default=4,
        help="Specifies the number of concurrent requests " "to submit. (default: %(default)r)",
    )
    # NOTE(review): the help text below claims -1 is the default, but the
    # actual default is 3 -- the visible user-facing text is misleading.
    conn_gr.add_argument(
        "--n_retry",
        type=int,
        default=3,
        help="Specifies the number of times DataRobot "
        "will retry if a request fails. "
        "A value of -1, the default, specifies "
        "an infinite number of retries."
        "(default: %(default)r)",
    )
    conn_gr.add_argument(
        "--resume",
        action="store_true",
        default=False,
        help="Starts the prediction from the point at which "
        "it was halted. "
        "If the prediction stopped, for example due "
        "to error or network connection issue, you can run "
        "the same command with all the same "
        "all arguments plus this resume argument.",
    )
    # NOTE(review): "CVS parameters" looks like a typo for "CSV parameters";
    # it is user-visible help output (left unchanged here).
    csv_gr = parser.add_argument_group("CVS parameters")
    csv_gr.add_argument(
        "--keep_cols",
        type=str,
        nargs="?",
        help="Specifies the column names to append " "to the predictions. " "Enter as a comma-separated list.",
    )
    csv_gr.add_argument(
        "--delimiter",
        type=str,
        nargs="?",
        default=None,
        help="Specifies the delimiter to recognize in "
        'the input .csv file. E.g. "--delimiter=,". '
        "If not specified, the script tries to automatically "
        'determine the delimiter. The special keyword "tab" '
        "can be used to indicate a tab delimited csv.",
    )
    csv_gr.add_argument(
        "--pred_name",
        type=str,
        nargs="?",
        default=None,
        help="Specifies column name for prediction results, "
        "empty name is used if not specified. For binary "
        "predictions assumes last class in lexical order "
        "as positive",
    )
    csv_gr.add_argument(
        "--fast",
        action="store_true",
        default=False,
        help="Experimental: faster CSV processor. " "Note: does not support multiline csv. ",
    )
    csv_gr.add_argument(
        "--auto_sample",
        action="store_true",
        default=False,
        help='Override "n_samples" and instead '
        "use chunks of about 1.5 MB. This is recommended and "
        'enabled by default if "n_samples" is not defined.',
    )
    csv_gr.add_argument(
        "--encoding",
        type=str,
        default="",
        help="Declare the dataset encoding. "
        "If an encoding is not provided the batch_scoring "
        'script attempts to detect it. E.g "utf-8", "latin-1" '
        'or "iso2022_jp". See the Python docs for a list of '
        "valid encodings "
        "https://docs.python.org/3/library/codecs.html"
        "#standard-encodings",
    )
    csv_gr.add_argument(
        "--skip_dialect",
        action="store_true",
        default=False,
        help="Tell the batch_scoring script " "to skip csv dialect detection.",
    )
    csv_gr.add_argument("--skip_row_id", action="store_true", default=False, help="Skip the row_id column in output.")
    csv_gr.add_argument("--output_delimiter", type=str, default=None, help="Set the delimiter for output file.")
    misc_gr = parser.add_argument_group("Miscellaneous")
    misc_gr.add_argument("-y", "--yes", dest="prompt", action="store_true", help="Always answer 'yes' for user prompts")
    misc_gr.add_argument("-n", "--no", dest="prompt", action="store_false", help="Always answer 'no' for user prompts")
    misc_gr.add_argument(
        "--dry_run", dest="dry_run", action="store_true", help="Only read/chunk input data but dont send " "requests."
    )
    misc_gr.add_argument(
        "--stdout", action="store_true", dest="stdout", default=False, help="Send all log messages to stdout."
    )

    # Baseline defaults; values from a config file (if found) override them.
    defaults = {
        "prompt": None,
        "out": "out.csv",
        "create_api_token": False,
        "timeout": 30,
        "n_samples": False,
        "n_concurrent": 4,
        "n_retry": 3,
        "resume": False,
        "fast": False,
        "stdout": False,
        "auto_sample": False,
    }

    conf_file = get_config_file()
    if conf_file:
        file_args = parse_config_file(conf_file)
        defaults.update(file_args)
    parser.set_defaults(**defaults)
    # Required positionals already satisfied by the config file must not be
    # demanded again on the command line: relax them to optional.
    for action in parser._actions:
        if action.dest in defaults and action.required:
            action.required = False
            if "--" + action.dest not in argv:
                action.nargs = "?"
    # Drop None values so .get() fallbacks below behave as "not provided".
    parsed_args = {k: v for k, v in vars(parser.parse_args(argv)).items() if v is not None}
    loglevel = logging.DEBUG if parsed_args["verbose"] else logging.INFO
    stdout = parsed_args["stdout"]
    ui = UI(parsed_args.get("prompt"), loglevel, stdout)
    # Never log the password; strip it from the debug dump.
    printed_args = copy.copy(parsed_args)
    printed_args.pop("password", None)
    ui.debug(printed_args)
    ui.info("platform: {} {}".format(sys.platform, sys.version))

    # parse args
    host = parsed_args["host"]
    pid = parsed_args["project_id"]
    lid = parsed_args["model_id"]
    n_retry = int(parsed_args["n_retry"])
    if parsed_args.get("keep_cols"):
        keep_cols = [s.strip() for s in parsed_args["keep_cols"].split(",")]
    else:
        keep_cols = None
    concurrent = int(parsed_args["n_concurrent"])
    dataset = parsed_args["dataset"]
    n_samples = int(parsed_args["n_samples"])
    delimiter = parsed_args.get("delimiter")
    resume = parsed_args["resume"]
    out_file = parsed_args["out"]
    datarobot_key = parsed_args.get("datarobot_key")
    timeout = int(parsed_args["timeout"])
    fast_mode = parsed_args["fast"]
    auto_sample = parsed_args["auto_sample"]
    # n_samples defaults to False; an unset batch size forces auto sampling.
    if not n_samples:
        auto_sample = True
    encoding = parsed_args["encoding"]
    skip_dialect = parsed_args["skip_dialect"]
    skip_row_id = parsed_args["skip_row_id"]
    output_delimiter = parsed_args.get("output_delimiter")

    if "user" not in parsed_args:
        user = ui.prompt_user()
    else:
        user = parsed_args["user"].strip()

    if not os.path.exists(parsed_args["dataset"]):
        ui.fatal("file {} does not exist.".format(parsed_args["dataset"]))

    try:
        verify_objectid(pid)
        verify_objectid(lid)
    except ValueError as e:
        ui.fatal(str(e))

    if delimiter == "\\t" or delimiter == "tab":
        # NOTE: on bash you have to use Ctrl-V + TAB
        delimiter = "\t"

    if delimiter and delimiter not in VALID_DELIMITERS:
        ui.fatal('Delimiter "{}" is not a valid delimiter.'.format(delimiter))

    if output_delimiter == "\\t" or output_delimiter == "tab":
        # NOTE: on bash you have to use Ctrl-V + TAB
        output_delimiter = "\t"

    if output_delimiter and output_delimiter not in VALID_DELIMITERS:
        ui.fatal('Output delimiter "{}" is not a valid delimiter.'.format(output_delimiter))

    api_token = parsed_args.get("api_token")
    create_api_token = parsed_args.get("create_api_token")
    pwd = parsed_args.get("password")
    pred_name = parsed_args.get("pred_name")
    dry_run = parsed_args.get("dry_run", False)

    base_url = parse_host(host, ui)

    # Dedicated prediction instances require the extra datarobot-key header.
    base_headers = {}
    if datarobot_key:
        base_headers["datarobot-key"] = datarobot_key

    ui.debug("batch_scoring v{}".format(__version__))
    ui.info("connecting to {}".format(base_url))

    try:
        run_batch_predictions(
            base_url=base_url,
            base_headers=base_headers,
            user=user,
            pwd=pwd,
            api_token=api_token,
            create_api_token=create_api_token,
            pid=pid,
            lid=lid,
            n_retry=n_retry,
            concurrent=concurrent,
            resume=resume,
            n_samples=n_samples,
            out_file=out_file,
            keep_cols=keep_cols,
            delimiter=delimiter,
            dataset=dataset,
            pred_name=pred_name,
            timeout=timeout,
            ui=ui,
            fast_mode=fast_mode,
            auto_sample=auto_sample,
            dry_run=dry_run,
            encoding=encoding,
            skip_dialect=skip_dialect,
            skip_row_id=skip_row_id,
            output_delimiter=output_delimiter,
        )
    except SystemError:
        pass
    except ShelveError as e:
        ui.error(str(e))
    except KeyboardInterrupt:
        ui.info("Keyboard interrupt")
    except Exception as e:
        ui.fatal(str(e))
    finally:
        ui.close()
示例#43
0
def _main(argv, deployment_aware=False):
    """Shared body of the batch-scoring console commands.

    Resolves the scoring target (a deployment id, or a project/model id
    pair for the deprecated legacy command), gathers credentials, picks
    the API endpoint and hands everything to ``run_batch_predictions``.

    Returns the integer exit code (1 unless the run succeeds).
    """
    freeze_support()
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter('ignore')

    args = parse_args(argv, deployment_aware=deployment_aware)
    exit_code = 1
    opts = parse_generic_options(args)

    # Which command are we running?  The deployment-aware flavour takes a
    # deployment id; the legacy one takes project + model ids.
    if deployment_aware:
        deployment_id = args['deployment_id']
        verify_objectid(deployment_id)
        pid = lid = None
    else:
        ui.warning('batch_scoring command is deprecated. '
                   'Use batch_scoring_deployment_aware command instead.')
        deployment_id = None
        pid = args['project_id']
        lid = args['model_id']
        try:
            verify_objectid(pid)
            verify_objectid(lid)
        except ValueError as err:
            ui.fatal(str(err))

    # --- authentication ---
    datarobot_key = args.get('datarobot_key')
    api_token = args.get('api_token')
    create_api_token = args.get('create_api_token')
    user = args.get('user')
    pwd = args.get('password')

    if not opts['dry_run']:
        # Prompt interactively for anything not supplied on the CLI.
        if not user:
            user = ui.prompt_user()
        user = user.lower().strip()
        if not api_token and not pwd:
            pwd = ui.getpass()

    base_headers = {'datarobot-key': datarobot_key} if datarobot_key else {}
    # --- end authentication ---

    if opts['dry_run']:
        base_url = ''
        ui.info('Running in dry-run mode')
    else:
        try:
            base_url = get_endpoint(args['host'], args['api_version'])
            ui.info('Will be using API endpoint: {}'.format(base_url))
        except ValueError as err:
            ui.fatal(str(err))

    try:
        exit_code = run_batch_predictions(base_url=base_url,
                                          base_headers=base_headers,
                                          user=user,
                                          pwd=pwd,
                                          api_token=api_token,
                                          create_api_token=create_api_token,
                                          pid=pid,
                                          lid=lid,
                                          import_id=None,
                                          deployment_id=deployment_id,
                                          ui=ui,
                                          **opts)
    except SystemError:
        pass
    except ShelveError as err:
        ui.error(str(err))
    except KeyboardInterrupt:
        ui.info('Keyboard interrupt')
    except UnicodeDecodeError as err:
        ui.error(str(err))
        if opts.get('fast_mode'):
            ui.error("You are using --fast option, which uses a small sample "
                     "of data to figuring out the encoding of your file. You "
                     "can try to specify the encoding directly for this file "
                     "by using the encoding flag (e.g. --encoding utf-8). "
                     "You could also try to remove the --fast mode to auto-"
                     "detect the encoding with a larger sample size")
    except Exception as err:
        ui.fatal(str(err))
    finally:
        ui.close()
        return exit_code
Example #44
0
def main(argv=None):
    """Console entry point for the ``batch_scoring`` command.

    Parses *argv*, validates the project/model ids, resolves the user's
    credentials and the API endpoint, then delegates the actual work to
    ``run_batch_predictions``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments.  Defaults to ``sys.argv[1:]`` resolved
        at call time.  (The previous ``argv=sys.argv[1:]`` default was
        evaluated once at import time, so any mutation of ``sys.argv``
        made after import — e.g. by a test harness or an embedding
        application — was silently ignored.)

    Returns
    -------
    int
        Exit code: whatever ``run_batch_predictions`` returns on
        success, 1 otherwise.
    """
    if argv is None:
        # Resolve the default lazily so sys.argv is read when main() is
        # actually invoked, not when this module was imported.
        argv = sys.argv[1:]
    freeze_support()
    global ui  # global variable hack, will get rid of a bit later
    warnings.simplefilter('ignore')
    parsed_args = parse_args(argv)
    exit_code = 1

    generic_opts = parse_generic_options(parsed_args)

    # parse args
    pid = parsed_args['project_id']
    lid = parsed_args['model_id']
    try:
        verify_objectid(pid)
        verify_objectid(lid)
    except ValueError as e:
        ui.fatal(str(e))

    # auth only ---
    datarobot_key = parsed_args.get('datarobot_key')
    api_token = parsed_args.get('api_token')
    create_api_token = parsed_args.get('create_api_token')
    user = parsed_args.get('user')
    pwd = parsed_args.get('password')

    if not generic_opts['dry_run']:
        # NOTE(review): unlike _main, the user name is NOT lower-cased
        # here -- confirm whether that difference is intentional.
        user = user or ui.prompt_user()
        user = user.strip()

        if not api_token and not pwd:
            pwd = ui.getpass()

    base_headers = {}
    if datarobot_key:
        base_headers['datarobot-key'] = datarobot_key
    # end auth ---

    if generic_opts['dry_run']:
        base_url = ''
        ui.info('Running in dry-run mode')
    else:
        try:
            base_url = get_endpoint(parsed_args['host'],
                                    parsed_args['api_version'])
            ui.info('Will be using API endpoint: {}'.format(base_url))
        except ValueError as e:
            ui.fatal(str(e))

    try:
        exit_code = run_batch_predictions(
            base_url=base_url, base_headers=base_headers, user=user, pwd=pwd,
            api_token=api_token, create_api_token=create_api_token,
            pid=pid, lid=lid, import_id=None, ui=ui, **generic_opts
        )
    except SystemError:
        pass
    except ShelveError as e:
        ui.error(str(e))
    except KeyboardInterrupt:
        ui.info('Keyboard interrupt')
    except UnicodeDecodeError as e:
        ui.error(str(e))
        if generic_opts.get('fast_mode'):
            ui.error("You are using --fast option, which uses a small sample "
                     "of data to figuring out the encoding of your file. You "
                     "can try to specify the encoding directly for this file "
                     "by using the encoding flag (e.g. --encoding utf-8). "
                     "You could also try to remove the --fast mode to auto-"
                     "detect the encoding with a larger sample size")
    except Exception as e:
        ui.fatal(str(e))
    finally:
        ui.close()
        # NOTE: returning from ``finally`` deliberately swallows any
        # in-flight exception (including the SystemExit that ui.fatal
        # presumably raises -- TODO confirm), so the command always
        # exits via its integer exit code.
        return exit_code