def test_get_project(client: Client):
    """Exercise ``Client.get_project`` both by name and with ``create=True``."""
    # Lookup by an explicit project name.
    project_payload = {"_id": 123, "description": "", "display_name": ""}
    client._get = Mock(return_value={"data": {"project": project_payload}})
    by_name = client.get_project(name="proj")
    assert by_name.name == "proj"
    assert by_name.client == client
    assert by_name.project_id == 123

    # Creation path: the id returned by _create_project is then looked up
    # through _get_project.
    new_id = "5eb07df99294fd2dbc3dbe6a"
    client._create_project = Mock(return_value={"data": {"id": new_id}})
    client._get_project = Mock(
        return_value={
            "project": {
                "name": "random",
                "id": new_id,
                "description": "",
                "display_name": "",
            }
        }
    )
    created = client.get_project(create=True)
    client._get_project.assert_called_with(new_id)
    assert created.name == "random"
    assert created.client == client
    assert created.project_id == new_id
def tail(args, gretel_client: Client):
    """Command handler for `gretel tail`.

    Prints each record yielded by ``gretel_client._iter_records`` for
    ``args.project``, iterating in the ``'forward'`` direction. Does nothing
    when ``args.project`` is falsy.
    """
    project = args.project
    if not project:
        return
    for rec in gretel_client._iter_records(project, direction='forward'):
        print(rec)
def test_cli_tail(client: Client):
    """`gretel --project test-proj tail` calls ``_iter_records`` exactly once."""
    client._iter_records = Mock(return_value=[])
    sys.argv = ['gretel', '--project', 'test-proj', 'tail']
    command = cli.parse_command().parse_args()
    command.func(command, client)
    assert client._iter_records.call_count == 1
def write(args, gretel_client: Client):
    """Command handler for `gretel write`.

    Uses ``args.file`` as the input source or, when ``args.stdin`` is set,
    wraps ``sys.stdin.buffer`` in a ``SeekableStreamBuffer``. Builds a reader
    via ``reader_from_args`` and a ``ConstantSampler`` from
    ``args.sample_rate`` / ``args.max_records``. Then calls
    ``gretel_client._write_records`` for ``args.project``.

    Raises:
        ValueError: if neither ``args.file`` nor ``args.stdin`` is set.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    if args.file:
        input_source = args.file
    elif args.stdin:
        # unix pipes aren't seekable. for certain readers we need
        # to be able to seek and replay bytes from a stream.
        input_source = SeekableStreamBuffer(sys.stdin.buffer)
    else:
        # A specific exception type (instead of bare ``Exception``) lets
        # callers tell bad CLI input apart from unexpected failures.
        raise ValueError("No valid input stream passed. Valid inputs include "
                         "--file or --stdin.")
    reader = reader_from_args(args, input_source)
    sampler = ConstantSampler(sample_rate=args.sample_rate,
                              record_limit=args.max_records)
    gretel_client._write_records(project=args.project, sampler=sampler,
                                 reader=reader)
def test_get_cloud_client_prompt(getenv, getpass, Client):
    """Check when ``get_cloud_client`` asks for the API key via ``getpass``."""
    host = "api.gretel.cloud"

    # No env value and "prompt": the key is requested interactively.
    getenv.return_value = None
    get_cloud_client("api", "prompt")
    assert getpass.call_count == 1

    # Env value present and "prompt": that key is used, no extra prompt.
    getenv.return_value = "abcd123"
    get_cloud_client("api", "prompt")
    Client.assert_called_with(host=host, api_key="abcd123")
    assert getpass.call_count == 1

    # "prompt_always": prompts even though the env value is set.
    get_cloud_client("api", "prompt_always")
    assert getpass.call_count == 2

    # A literal key argument is used as the API key.
    get_cloud_client("api", "abc123")
    Client.assert_called_with(host=host, api_key="abc123")
def test_api_4xx_errors(client: Client):
    """Each 4xx response from ``session.get`` maps to its exception type."""
    # (fake response, exception _get should raise), in call order.
    cases = [
        (Fake404(), NotFound),
        (Fake400(), BadRequest),
        (Fake401(), Unauthorized),
        (Fake403(), Forbidden),
    ]
    client.session.get = Mock(side_effect=[response for response, _ in cases])
    for _, exc_type in cases:
        with pytest.raises(exc_type):
            client._get("foo", None)
def test_record_writer_csv(fake, client: Client):
    """``_write_records`` posts CSV rows as header-keyed dicts."""
    client._post = Mock()
    buf = io.StringIO()
    writer = csv.writer(buf, quoting=csv.QUOTE_NONNUMERIC)
    header = [f"header_{i}" for i in range(10)]
    writer.writerow(header)

    expected_rows = []
    for _ in range(5):
        values = fake.pylist(nb_elements=10, variable_nb_elements=False)
        writer.writerow(values)
        # CSVs don't preserve types by default, so the expected payload
        # uses the stringified values.
        expected_rows.append(list(map(str, values)))

    buf.seek(0)
    client._write_records(project="test-proj", reader=CsvReader(buf))
    client._post.assert_called_with(
        "records/send/test-proj",
        data=[dict(zip(header, r)) for r in expected_rows],
    )
def test_cli_write(client: Client, generate_csv, test_records, tmpdir_factory):
    """`gretel write --file ...` forwards project, reader and sampler settings."""
    csv_path = tmpdir_factory.mktemp('test') / 'test_csv.csv'
    with open(csv_path, 'w') as fh:
        generate_csv(test_records, fh)

    client._write_records = Mock()
    # NOTE(review): the handler reads args.max_records; '--max-record' works
    # only via argparse prefix matching or an explicit dest — confirm.
    sys.argv = [
        'gretel', '--project', 'test-proj', 'write',
        '--file', str(csv_path),
        '--sample-rate', '2',
        '--max-record', '10',
    ]
    command = cli.parse_command().parse_args()
    command.func(command, client)

    kwargs = client._write_records.call_args[1]
    assert kwargs['project'] == 'test-proj'
    assert isinstance(kwargs['reader'], CsvReader)
    sampler = kwargs['sampler']
    assert sampler.sample_rate == 2
    assert sampler.record_limit == 10
def test_cli_write_json_stream(client: Client, test_records):
    """`gretel write --stdin --reader json` posts the records read from stdin."""

    class TestInput:
        """Minimal stand-in for ``sys.stdin`` exposing only ``.buffer``."""

        def __init__(self):
            self.input_buffer = io.BytesIO()

        @property
        def buffer(self):
            return self.input_buffer

    # sys.stdin and sys.argv are process-wide; restore them so later tests
    # are not left reading from this fake stream.
    saved_stdin, saved_argv = sys.stdin, sys.argv
    try:
        sys.stdin = TestInput()
        for record in test_records:
            sys.stdin.buffer.write(f"{json.dumps(record)}\n\n".encode())
        sys.stdin.buffer.seek(0)

        sys.argv = ['gretel', '--project', 'test-proj', 'write', '--stdin',
                    '--reader', 'json']
        client._post = Mock()
        command = cli.parse_command().parse_args()
        command.func(command, client)
    finally:
        sys.stdin, sys.argv = saved_stdin, saved_argv

    # The previous ``called_with`` was not an assertion. On a Mock it just
    # creates and calls a child mock, so it could never fail. The call shape
    # matches the one asserted in test_record_writer_csv.
    client._post.assert_called_with('records/send/test-proj',
                                    data=test_records)