コード例 #1
0
def test_get_csv_path_test_relative_path(client_task_definition_csv_data,
                                         csv_dataset_path):
    """A relative ``features`` path resolves to ``csv_dataset_path``.

    The task definition's dataset source is pointed at the relative path
    ``tests/data/abalone.csv`` and ``get_csv_path`` must return the
    expected CSV path.
    """
    client = Client()
    client._cluster_request_data = client_task_definition_csv_data

    # NOTE(review): this edits the fixture's nested dict in place; that is
    # only isolated from other tests if the fixture is function-scoped.
    dataset_source = client._cluster_request_data['ml']['dataset']['source']
    dataset_source['features'] = 'tests/data/abalone.csv'

    expected_path = csv_dataset_path
    assert client.get_csv_path() == expected_path
コード例 #2
0
def test_send_data_segments_to_client_response_message_order(
        mocker, tmpdir, redisdb, client_task_definition_data):
    """Segment messages are pushed in ``client.segment_hashes`` order.

    Three segment files are written and registered with the client.
    ``segment_hashes`` is then set to an order that differs from the
    registration order. The messages read back from the cluster's Redis list
    must carry their ``hash`` fields in that order.
    """
    segment_paths = {}
    for name in ('test0', 'test1', 'test2'):
        segment = tmpdir.join(name + '.segment')
        segment.write('content')
        segment_paths[name] = str(segment)

    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = str(uuid.uuid4())
    client._dataset_segments = segment_paths

    client.segment_hashes = ['test1', 'test2', 'test0']

    client.send_data_segments_to_cluster()

    message_raw = redisdb.lrange(client._cluster_address, 0, -1)
    # LRANGE returns a list (empty when nothing was pushed), never None,
    # so check for emptiness to make this assertion meaningful.
    assert message_raw, 'no messages were pushed to the cluster address'

    messages = [
        yaml.load(raw, Loader=yaml.UnsafeLoader) for raw in message_raw
    ]

    # zip() stops at the shorter input; without this guard, missing messages
    # would silently shrink the comparison below instead of failing the test.
    assert len(messages) >= len(client.segment_hashes)
    for key, message in zip(client.segment_hashes, messages):
        assert message['hash'] == key
コード例 #3
0
def test_load_dataset_is_mnist_called(client_task_definition_data):
    """``load_dataset`` calls the MNIST loader exactly once.

    The task data comes from the unmodified ``client_task_definition_data``
    fixture, and the MNIST loader is replaced with a mock.
    """
    mnist_loader = MagicMock()

    client = Client()
    client._load_mnist_dataset = mnist_loader
    client._cluster_request_data = client_task_definition_data

    client.load_dataset()

    mnist_loader.assert_called_once()
コード例 #4
0
def test_load_csv_dataset_from_hardrive(client_task_definition_csv_data):
    """``_load_csv_dataset`` returns a gluon ``DataLoader`` of 100-row batches.

    The MNIST loader is replaced with a ``MagicMock``. The test then checks
    that the loader produces at least one batch and that every feature batch
    has length 100.
    """
    client = Client()
    client._load_mnist_dataset = MagicMock()

    client._cluster_request_data = client_task_definition_csv_data

    training_data = client._load_csv_dataset()
    assert isinstance(training_data, mxnet.gluon.data.DataLoader)

    batch_count = 0
    for feature_batch, _label_batch in training_data:
        assert len(feature_batch) == 100
        batch_count += 1
    # Without this, an empty loader would let the loop above pass vacuously.
    assert batch_count > 0, 'DataLoader produced no batches'
コード例 #5
0
def test_load_dataset_is_csv_called(client_task_definition_data):
    """``load_dataset`` calls the CSV loader once when the format is 'CSV'.

    Both dataset loaders are replaced with mocks, and the dataset format in
    the task data is switched to ``'CSV'`` before loading.
    """
    client = Client()
    client._load_mnist_dataset = MagicMock()
    csv_loader = MagicMock()
    client._load_csv_dataset = csv_loader

    client._cluster_request_data = client_task_definition_data
    # NOTE(review): this edits the fixture's dict in place; that is only
    # isolated from other tests if the fixture is function-scoped.
    dataset_config = client._cluster_request_data['ml']['dataset']
    dataset_config['format'] = 'CSV'

    client.load_dataset()

    csv_loader.assert_called_once()
コード例 #6
0
def test_send_dataset_hashes_message_format_testing(
        redisdb, client_task_definition_data, tmpdir):
    """``send_dataset_hashes`` pushes a well-formed hashes message.

    The message popped from ``test_cluster_address`` must carry the client's
    ``client_id`` and exactly ``NUMBER_OF_DATASET_SEGMENTS`` hashes.
    """
    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = 'test_cluster_address'
    client._worker_output_directory = str(tmpdir)

    client.send_dataset_hashes()

    message = redisdb.lpop('test_cluster_address')
    # LPOP returns None when the list is empty or missing; fail here with a
    # clear message instead of an obscure error inside yaml.load.
    assert message is not None, \
        'no message was pushed to test_cluster_address'
    message_data = yaml.load(message, Loader=yaml.UnsafeLoader)

    assert message_data['client_id'] == client._client_id
    assert len(message_data['hashes']) == NUMBER_OF_DATASET_SEGMENTS
コード例 #7
0
def test_send_data_segments_to_client_cleanup_procedure(
        mocker, tmpdir, redisdb, client_task_definition_data):
    """A segment file no longer exists on disk after it has been sent.

    A single segment file is registered under the hash ``'test'``. After
    ``send_data_segments_to_cluster`` runs, that path must not be a file.
    """
    segment_file = tmpdir.join('test.segment')
    segment_file.write('content')
    segment_path = str(segment_file)

    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = str(uuid.uuid4())
    client._dataset_segments = {'test': segment_path}
    client.segment_hashes = ['test']

    client.send_data_segments_to_cluster()

    file_still_present = os.path.isfile(segment_path)
    assert not file_still_present
コード例 #8
0
def test_send_data_segments_to_client_response_message(
        mocker, tmpdir, redisdb, client_task_definition_data):
    """A sent segment message contains ``hash``, ``bucket`` and ``key``.

    One segment is sent to a fresh cluster address. The first message read
    back from that Redis list must include all three fields.
    """
    segment = tmpdir.join('test.segment')
    segment.write('content')

    client = Client()
    client.conn = redisdb
    client._cluster_request_data = client_task_definition_data
    client._cluster_address = str(uuid.uuid4())
    client._dataset_segments = {'test': str(segment)}

    client.segment_hashes = ['test']

    client.send_data_segments_to_cluster()

    message_raw = redisdb.lrange(client._cluster_address, 0, -1)
    # LRANGE returns a list (empty when nothing was pushed), never None;
    # require at least one message before indexing into the result.
    assert message_raw, 'no messages were pushed to the cluster address'

    messages = [
        yaml.load(raw, Loader=yaml.UnsafeLoader) for raw in message_raw
    ]

    message = messages[0]
    # Report exactly which fields are absent rather than a bare False.
    missing = {'hash', 'bucket', 'key'} - set(message)
    assert not missing, 'message is missing fields: %s' % sorted(missing)
コード例 #9
0
def test_get_csv_path_test_absolute_path(client_task_definition_csv_data,
                                         csv_dataset_path):
    """``get_csv_path`` returns ``csv_dataset_path`` for the CSV fixture as-is.

    Unlike the relative-path test, the task data is used without modifying
    the ``features`` entry. Per the test name, that entry is presumably an
    absolute path; confirm against the fixture.
    """
    client = Client()
    client._cluster_request_data = client_task_definition_csv_data

    resolved_path = client.get_csv_path()
    assert resolved_path == csv_dataset_path