def test_register_for_task_same_node_not_counted_twice(redisdb):
    """A node that registers twice must only be counted once."""
    candidate = CommitteeCandidate()
    candidate.conn = redisdb
    candidate._task_registration_channel = 'test_registration'

    # Register the same node twice on the same channel.
    for _ in range(2):
        candidate.register_for_task()

    assert candidate.get_number_of_registered_nodes() == 1
def test_disable_registration_for_client_task_only_one_task_in_queue(
        redisdb, client_task_definition_path, client_task_definition_data):
    """Disabling registration removes the pending task from the client queue."""
    with open(client_task_definition_path) as request_file:
        raw_request = request_file.read()

    candidate = CommitteeCandidate()
    candidate.conn = redisdb
    candidate.task_data = client_task_definition_data
    candidate.set_task_id(raw_request)

    # Queue exactly one task, then disable registration for it.
    redisdb.lpush(CLIENT_TASK_CHANNEL, raw_request)
    candidate.disable_registration_for_client_task()

    assert redisdb.llen(CLIENT_TASK_CHANNEL) == 0
def test_inform_client_of_hash_voting_results(client_task_definition_data):
    """Informing the client of the hash allocation pushes to its channel."""
    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data
    candidate.conn = MagicMock()

    candidate.inform_client_of_hash_allocation()

    candidate.conn.lpush.assert_called()
def test_inform_client_of_task_id(client_task_definition_data):
    """Informing the client of the task id pushes to its channel."""
    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data
    candidate.conn = MagicMock()

    candidate.inform_client_of_task_id()

    candidate.conn.lpush.assert_called()
def test_set_task_id_provides_consistent_hash(client_task_definition_path,
                                              client_task_definition_data):
    """Hashing the same request repeatedly yields the same task id."""
    with open(client_task_definition_path) as request_file:
        raw_request = request_file.read()

    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data
    candidate.set_task_id(raw_request)
    first_task_id = candidate.task_id

    # The id must remain stable across repeated invocations.
    for _ in range(10):
        candidate.set_task_id(raw_request)
        assert candidate.task_id == first_task_id
def test_inform_client_of_task_integrated(client_task_definition_data,
                                          redisdb):
    """End-to-end: the task id published by the candidate reaches the client."""
    client_address = 'test_client'

    # Candidate side: publish the task id to the client's listen address.
    candidate = CommitteeCandidate()
    candidate.conn = redisdb
    candidate.task_id = '123'
    candidate.task_data = client_task_definition_data
    candidate.task_data['client_listen_address'] = client_address
    candidate._client_response_listening_channel = 'test_cluster'
    candidate.inform_client_of_task_id()

    # Client side: consume the announcement from the same Redis instance.
    client = Client()
    client._client_id = client_task_definition_data['client_id']
    client.conn = redisdb
    client._client_listen_address = client_address
    client.obtain_cluster_task_id()

    assert client._task_id == '123'
    assert client._cluster_address == 'test_cluster'
def test_collect_segment_hash_table_simple(redisdb):
    """Hashes pushed by a client are collected into the segment hash table."""
    node = CommitteeCandidate()
    node.conn = redisdb
    cluster_listen_address = 'test_cluster'
    node._client_response_listening_channel = cluster_listen_address
    client_request = {
        'client_id': '123',
        # list(...) rather than a bare range: on Python 3, PyYAML's default
        # representer cannot serialize range objects and yaml.dump raises
        # RepresenterError (a bare range only worked under Python 2, where
        # range() returned a list).
        'hashes': list(range(NUMBER_OF_DATASET_SEGMENTS))
    }
    redisdb.lpush(cluster_listen_address, yaml.dump(client_request))
    node.collect_segment_hash_table()
    assert len(node.segment_hash_table) == NUMBER_OF_DATASET_SEGMENTS
def test_validate_request_data_invalid_type(data):
    """Non-conforming task data must be rejected with a ValueError."""
    candidate = CommitteeCandidate()
    candidate.task_data = data

    with pytest.raises(ValueError):
        candidate.validate_request_data()
def test_validate_request_data_simple(client_task_definition_data):
    """Well-formed task data passes validation without raising."""
    candidate = CommitteeCandidate()
    candidate.task_data = client_task_definition_data
    candidate.validate_request_data()
def test_committee_gets_disolved_before_completing_training(redisdb):
    """Registered nodes disappear once their registration expires."""
    channel = 'test_registration'

    nodes = []
    for _ in range(2):
        candidate = CommitteeCandidate()
        candidate.conn = redisdb
        candidate._task_registration_channel = channel
        candidate.register_for_task()
        nodes.append(candidate)

    first, second = nodes
    assert first.get_number_of_registered_nodes() == 2
    assert second.get_number_of_registered_nodes() == 2

    # Wait long enough for the registrations to lapse
    # (presumably a Redis key TTL set by register_for_task — confirm there).
    time.sleep(20)
    assert len(first.get_registered_nodes()) == 0
def test_disable_registration_for_client_task():
    """Disabling registration removes the task entry via an lrem call."""
    candidate = CommitteeCandidate()
    candidate.conn = MagicMock()

    candidate.disable_registration_for_client_task()

    candidate.conn.lrem.assert_called()
def test_get_number_of_registered_nodes_multiple(redisdb):
    """Two distinct nodes registering on one channel are both counted."""
    channel = 'test_registration'

    nodes = []
    for _ in range(2):
        candidate = CommitteeCandidate()
        candidate.conn = redisdb
        candidate._task_registration_channel = channel
        candidate.register_for_task()
        nodes.append(candidate)

    first, second = nodes
    assert first.get_number_of_registered_nodes() == 2
    assert second.get_number_of_registered_nodes() == 2
def __init__(self, redis_host, redis_port, context, is_debug=False): CommitteeCandidate.__init__(self, redis_host, redis_port, is_debug) # lambda convenience methods for message map self.zero_setter = np.vectorize(lambda int_type: int_type & (~(1 << 31)), otypes=[np.uint32]) self.one_setter = np.vectorize(lambda int_type: int_type | (1 << 31), otypes=[np.uint32]) self.ctx = context self.network_time_to_pick_next_msg = None # the hash of the current batch self.batch_hash = None # the current state of the model self.net_hash_start = None self.net_hash_end = None # set the threshold for residual gradient self.tau = None # Collect all parameters from net and its children, then initialize them. self.net = None # self.net.hybridize() # for performance reasons self.trainer = None self._peer_msg_ids = [] self._consumed_peer_msg_ids = [] # Add metrics: accuracy and cross-entropy loss accuracy = mx.metric.Accuracy() ce_loss = mx.metric.CrossEntropy() self.comp_metric = mx.metric.CompositeEvalMetric([accuracy, ce_loss]) self.loss = gluon.loss.SoftmaxCrossEntropyLoss() # variables needed for indexing self.gradients_sizes = None # cumulative sum of gradients used in indexing self.gradients_cumulative = None # gradients blueprint self.gradients_blueprint = [] self.grads = None # flag for showing that this is the first batch # (some network parameters are lazily initialized, so we need this) self.gradient_residual = None self.last_updated_iteration_index = 0 self.global_iteration_index = 0 self.message_key_template = 'it_res_{}_{}_{}' self.node_output_directory = None self._model_path = None self.miner = None