예제 #1
0
def test_register_for_task_same_node_not_counted_twice(redisdb):
    """Registering the same candidate twice must count it only once."""
    candidate = CommitteeCandidate()
    candidate.conn = redisdb
    candidate._task_registration_channel = 'test_registration'

    # A repeated registration from the same node must not inflate the count.
    for _ in range(2):
        candidate.register_for_task()

    assert candidate.get_number_of_registered_nodes() == 1
예제 #2
0
def test_disable_registration_for_client_task_only_one_task_in_queue(
        redisdb, client_task_definition_path, client_task_definition_data):
    """Disabling registration removes the single queued client task."""
    with open(client_task_definition_path) as definition_file:
        raw_request = definition_file.read()

    candidate = CommitteeCandidate()
    candidate.conn = redisdb
    candidate.task_data = client_task_definition_data
    candidate.set_task_id(raw_request)

    # Seed the client-task queue with exactly one entry, then withdraw it.
    redisdb.lpush(CLIENT_TASK_CHANNEL, raw_request)
    candidate.disable_registration_for_client_task()

    assert redisdb.llen(CLIENT_TASK_CHANNEL) == 0
예제 #3
0
def test_inform_client_of_hash_voting_results(client_task_definition_data):
    """Announcing the hash allocation must push a message on the connection.

    NOTE(review): despite the test name, the method exercised is
    ``inform_client_of_hash_allocation``.
    """
    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data

    fake_conn = MagicMock()
    candidate.conn = fake_conn

    candidate.inform_client_of_hash_allocation()

    fake_conn.lpush.assert_called()
예제 #4
0
def test_inform_client_of_task_id(client_task_definition_data):
    """Announcing the task id must push a message on the connection."""
    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data

    fake_conn = MagicMock()
    candidate.conn = fake_conn

    candidate.inform_client_of_task_id()

    fake_conn.lpush.assert_called()
예제 #5
0
def test_set_task_id_provides_consistent_hash(client_task_definition_path,
                                              client_task_definition_data):
    """``set_task_id`` must derive the same id every time for identical input."""
    with open(client_task_definition_path) as definition_file:
        raw_request = definition_file.read()

    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data
    candidate.set_task_id(raw_request)
    expected_id = candidate.task_id

    # Recompute the id repeatedly; every result must equal the first one.
    for _attempt in range(10):
        candidate.set_task_id(raw_request)
        assert expected_id == candidate.task_id
예제 #6
0
def test_inform_client_of_task_integrated(client_task_definition_data,
                                          redisdb):
    """End-to-end check: a client reads back the task id the node published.

    The node publishes task id '123' and its response channel; a ``Client``
    listening on the same redis connection must obtain both values.
    """
    client_address = 'test_client'

    node = CommitteeCandidate()
    node.conn = redisdb

    node.task_id = '123'

    # Use a shallow copy so that adding 'client_listen_address' does not
    # mutate the shared fixture object (which would leak into other tests if
    # the fixture is ever given a wider scope than 'function').
    node.task_data = dict(client_task_definition_data)
    node.task_data['client_listen_address'] = client_address
    node._client_response_listening_channel = 'test_cluster'

    node.inform_client_of_task_id()

    client = Client()
    client._client_id = client_task_definition_data['client_id']
    client.conn = redisdb
    client._client_listen_address = client_address

    client.obtain_cluster_task_id()
    assert client._task_id == '123'
    assert client._cluster_address == 'test_cluster'
예제 #7
0
def test_collect_segment_hash_table_simple(redisdb):
    """The node collects one hash per dataset segment from a client reply."""
    node = CommitteeCandidate()
    node.conn = redisdb
    cluster_listen_address = 'test_cluster'
    node._client_response_listening_channel = cluster_listen_address

    client_request = {
        'client_id': '123',
        # Must be a real list: on Python 3 ``range`` is a lazy object, and
        # ``yaml.dump`` would emit a ``!!python/object/apply:builtins.range``
        # tag (rejected by safe loaders) instead of a plain YAML sequence.
        'hashes': list(range(NUMBER_OF_DATASET_SEGMENTS)),
    }

    redisdb.lpush(cluster_listen_address, yaml.dump(client_request))

    node.collect_segment_hash_table()
    assert len(node.segment_hash_table) == NUMBER_OF_DATASET_SEGMENTS
예제 #8
0
def test_validate_request_data_invalid_type(data):
    """Invalid request data must make validation raise ``ValueError``.

    ``data`` is presumably a (parametrized) fixture supplying invalid
    payloads -- confirm in the fixture definitions.
    """
    candidate = CommitteeCandidate()
    candidate.task_data = data

    with pytest.raises(ValueError):
        candidate.validate_request_data()
예제 #9
0
def test_validate_request_data_simple(client_task_definition_data):
    """The reference client task definition passes validation."""
    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data

    # No explicit assert: the test passes as long as validation does not raise.
    candidate.validate_request_data()
예제 #10
0
def test_committee_gets_disolved_before_completing_training(redisdb):
    """Registered nodes disappear after a 20 s wait with no further activity."""
    channel = 'test_registration'

    # Two distinct candidates register on the same channel.
    committee = []
    for _ in range(2):
        member = CommitteeCandidate()
        member.conn = redisdb
        member._task_registration_channel = channel
        member.register_for_task()
        committee.append(member)

    # Every member must see both registrations.
    for member in committee:
        assert member.get_number_of_registered_nodes() == 2

    # NOTE(review): this wait presumably has to exceed a registration expiry
    # configured inside CommitteeCandidate -- confirm; it makes the test slow.
    time.sleep(20)

    assert len(committee[0].get_registered_nodes()) == 0
예제 #11
0
def test_disable_registration_for_client_task():
    """Disabling registration must remove an entry via ``lrem``."""
    candidate = CommitteeCandidate()
    fake_conn = MagicMock()
    candidate.conn = fake_conn

    candidate.disable_registration_for_client_task()

    fake_conn.lrem.assert_called()
예제 #12
0
def test_get_number_of_registered_nodes_multiple(redisdb):
    """Two distinct candidates on one channel are both counted by each."""
    channel = 'test_registration'

    members = []
    for _ in range(2):
        member = CommitteeCandidate()
        member.conn = redisdb
        member._task_registration_channel = channel
        member.register_for_task()
        members.append(member)

    # The count is shared: each member observes both registrations.
    for member in members:
        assert member.get_number_of_registered_nodes() == 2
예제 #13
0
    def __init__(self, redis_host, redis_port, context, is_debug=False):
        """Initialise training state on top of the base committee candidate.

        Args:
            redis_host: Redis host, forwarded to ``CommitteeCandidate.__init__``.
            redis_port: Redis port, forwarded to ``CommitteeCandidate.__init__``.
            context: Stored as ``self.ctx`` (presumably the MXNet device
                context used for training -- confirm against callers).
            is_debug: Debug flag, forwarded to ``CommitteeCandidate.__init__``.

        Network, trainer, gradient bookkeeping and output paths are set to
        ``None``/empty here and populated elsewhere.
        """
        CommitteeCandidate.__init__(self, redis_host, redis_port, is_debug)

        # Vectorised convenience helpers for the message map: clear / set
        # bit 31 of every element, returning uint32 values.
        # The clear mask is the non-negative literal 0x7FFFFFFF rather than
        # ``~(1 << 31)``: the latter is the negative Python int -2**31 - 1,
        # which NumPy >= 2 (NEP 50 promotion rules) refuses to combine with
        # uint32 operands (OverflowError). Both masks clear exactly bit 31,
        # so the uint32 results are identical.
        low_31_bits_mask = 0x7FFFFFFF
        high_bit = 1 << 31
        self.zero_setter = np.vectorize(
            lambda int_type: int_type & low_31_bits_mask,
            otypes=[np.uint32])
        self.one_setter = np.vectorize(lambda int_type: int_type | high_bit,
                                       otypes=[np.uint32])

        self.ctx = context

        self.network_time_to_pick_next_msg = None

        # Hash of the current batch.
        self.batch_hash = None

        # Hashes describing the current state of the model (start / end).
        self.net_hash_start = None
        self.net_hash_end = None

        # Threshold for the residual gradient.
        self.tau = None

        # Network and its trainer: None here, populated elsewhere.
        # TODO(review): consider ``self.net.hybridize()`` for performance
        # once the network has been built.
        self.net = None
        self.trainer = None

        self._peer_msg_ids = []
        self._consumed_peer_msg_ids = []

        # Metrics: accuracy and cross-entropy loss, combined into one metric.
        accuracy = mx.metric.Accuracy()
        ce_loss = mx.metric.CrossEntropy()
        self.comp_metric = mx.metric.CompositeEvalMetric([accuracy, ce_loss])

        self.loss = gluon.loss.SoftmaxCrossEntropyLoss()

        # Gradient sizes, needed for indexing.
        self.gradients_sizes = None

        # Cumulative sum of gradient sizes, used in indexing.
        self.gradients_cumulative = None

        # Gradients blueprint and the current gradients.
        self.gradients_blueprint = []
        self.grads = None

        # Residual gradient. Some network parameters are lazily initialised,
        # so None presumably also marks "first batch not processed yet" --
        # confirm against the training loop.
        self.gradient_residual = None

        self.last_updated_iteration_index = 0
        self.global_iteration_index = 0

        # Format template for message keys (three positional fields).
        self.message_key_template = 'it_res_{}_{}_{}'
        self.node_output_directory = None
        self._model_path = None

        self.miner = None