def request_model(peer_ranks, variables, mode, peer_selection_strategy):
    """Create a model-request op from ``_op_lib`` for the given variables.

    Args:
        peer_ranks: Ranks of the peers to consider. The current peer's own
            rank (``current_rank()``) is excluded before the list is passed
            to the op. The caller's list is NOT modified.
        variables: Variables forming the model. ``variables[0].dtype.size``
            is passed as ``var_type_size``, so all variables are presumably
            expected to share one dtype (TODO confirm with the op's kernel).
        mode: ``'async'`` selects ``_op_lib.async_request_model``;
            ``'sync'`` selects ``_op_lib.request_model``.
        peer_selection_strategy: Forwarded unchanged to the op.

    Returns:
        The op returned by the selected ``_op_lib`` function.

    Raises:
        ValueError: If ``mode`` is neither ``'async'`` nor ``'sync'``.
            (Subclass of ``Exception``, so existing handlers still catch it.)
    """
    # Validate the mode before doing any other work.
    if mode == 'async':
        print("Request a model asynchronously.")
        make_op = _op_lib.async_request_model
    elif mode == 'sync':
        print("Request a model synchronously.")
        make_op = _op_lib.request_model
    else:
        raise ValueError("Invalid type of model request mode")

    self_rank = current_rank()
    # Build a new list instead of calling peer_ranks.remove(). In-place
    # removal mutated the caller's list, so reusing that list on a later call
    # raised ValueError because the self rank was already gone.
    other_ranks = [rank for rank in peer_ranks if rank != self_rank]

    return make_op(
        variables,
        self_rank=self_rank,
        ranks=other_ranks,
        var_type_size=variables[0].dtype.size,
        var_sizes=[var.shape.num_elements() for var in variables],
        shapes=[var.shape for var in variables],
        peer_selection_strategy=peer_selection_strategy)
def model_averaging(peer_ranks, variables, mode, peer_selection_strategy):
    """Create a model-averaging op from ``_op_lib`` for the given variables.

    Args:
        peer_ranks: Ranks of the peers to consider. The current peer's own
            rank (``current_rank()``) is excluded before the list is passed
            to the op. The caller's list is NOT modified.
        variables: Variables forming the model. ``variables[0].dtype.size``
            is passed as ``var_type_size``, so all variables are presumably
            expected to share one dtype (TODO confirm with the op's kernel).
        mode: ``'async'`` selects ``_op_lib.async_model_averaging``;
            ``'sync'`` selects ``_op_lib.model_averaging``.
        peer_selection_strategy: Forwarded unchanged to the op.

    Returns:
        The op returned by the selected ``_op_lib`` function.

    Raises:
        ValueError: If ``mode`` is neither ``'async'`` nor ``'sync'``.
            (Subclass of ``Exception``, so existing handlers still catch it.)
    """
    # Validate the mode before doing any other work.
    if mode == 'async':
        print(
            "Applying model averaging with a model requested asynchronously.")
        make_op = _op_lib.async_model_averaging
    elif mode == 'sync':
        print("Applying model averaging with a model requested synchronously.")
        make_op = _op_lib.model_averaging
    else:
        raise ValueError("Invalid type of model request mode.")

    self_rank = current_rank()
    # Build a new list instead of calling peer_ranks.remove(). In-place
    # removal mutated the caller's list, so reusing that list on a later call
    # raised ValueError because the self rank was already gone.
    other_ranks = [rank for rank in peer_ranks if rank != self_rank]

    return make_op(
        variables,
        self_rank=self_rank,
        ranks=other_ranks,
        var_type_size=variables[0].dtype.size,
        var_sizes=[var.shape.num_elements() for var in variables],
        peer_selection_strategy=peer_selection_strategy)
def get_neighbour_mask(edges):
    """Return a boolean neighbour vector for this peer.

    Delegates to ``_op_lib.kungfu_get_neighbour_mask`` and passes it this
    peer's rank and the cluster size. For the peer of rank ``i``, ``v[j]`` is
    true if ``(i, j)`` is an edge of the MST and false otherwise. ``edges``
    presumably describes that MST's edge list (verify against callers).
    """
    my_rank = current_rank()
    n_peers = current_cluster_size()
    mask_op = _op_lib.kungfu_get_neighbour_mask
    return mask_op(edges, self_rank=my_rank, cluster_size=n_peers)