def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=4,
                     grid_dims=(32, 32, 32)):
    """
    End-to-end check of MoEBeamSearcher over a small local DHT swarm.

    Starts ``dht_size`` DHT peers, each bootstrapped from up to ``initial_peers`` previously started peers, and declares
    up to ``total_experts`` random expert uids on the grid ``grid_dims`` (duplicates drawn at random are merged).
    Then checks that both ``find_best_experts`` and ``batch_find_best_experts`` return exactly ``beam_size``
    RemoteExperts per query. Every DHT peer started here is shut down on exit, even if an assertion fails.
    """
    dht = []
    you = None
    try:
        for _ in range(dht_size):
            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
            dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i,
                                    parallel_rpc=parallel_rpc))

        # the set comprehension merges uids that happen to be drawn more than once
        real_experts = sorted({
            'expert.' + '.'.join(str(random.randint(0, dim - 1)) for dim in grid_dims)
            for _ in range(total_experts)
        })
        for batch_start in range(0, len(real_experts), batch_size):
            # randint is inclusive on both ends: 65535 is the largest valid TCP port
            random.choice(dht).declare_experts(
                real_experts[batch_start: batch_start + batch_size], wait=True,
                endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65535)}")

        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
        you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
        beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)

        # single query: one 1-d score vector per grid dimension
        for _ in range(10):
            topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
            assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
            assert len(topk_experts) == beam_size

        # batched query: one [batch_size, dim] score matrix per grid dimension
        for _ in range(10):
            batch_experts = beam_search.batch_find_best_experts([np.random.randn(batch_size, dim) for dim in grid_dims],
                                                                beam_size=beam_size)
            assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
            assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
            assert all(len(experts) == beam_size for experts in batch_experts)
    finally:
        for peer in dht:
            peer.shutdown()
        if you is not None:
            you.shutdown()
class RemoteMixtureOfExperts(nn.Module):
    """
    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
    Natively supports pytorch autograd.

    :note: By default, not all experts are guaranteed to perform forward pass. Moreover, not all of those who ran
     forward pass are guaranteed to perform backward pass. In the latter case, gradient will be averaged without
     the missing experts

    :param in_features: common input size for experts and gating function
    :param grid_size: dimensions that form expert uid (see below)
    :param uid_prefix: common prefix for all expert uids (must end with '.')
    :note: expert uid follows the pattern {uid_prefix}{i_0}.{i_1}...{i_last}, where 0 <= i_d < grid_size[d]
    :param dht: a DHT instance used to search for best experts
    :param k_best: average this many highest-scoring experts to compute activations
    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
    :param forward_timeout: forwarded to _RemoteCallMany; presumably the maximum time (seconds) to wait for
     forward results, None meaning no limit (assumption, confirm against _RemoteCallMany)
    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
     Any expert that didn't manage to return output after that delay is considered unavailable
    :param backward_k_min: forwarded to _RemoteCallMany; presumably the backward-pass counterpart of k_min
    :param backward_timeout: forwarded to _RemoteCallMany; presumably the backward-pass counterpart of forward_timeout
    :param detect_anomalies: whether to check input/output tensors for NaN and infinity values
    :param dht_kwargs: extra keyword arguments forwarded to MoEBeamSearcher
    """

    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
                 k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
                 **dht_kwargs):
        super().__init__()
        self.dht = dht
        self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
        self.timeout_after_k_min = timeout_after_k_min
        self.detect_anomalies = detect_anomalies

        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
        self._expert_info = None  # expert['info'] from one of experts in the grid

    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
        """
        Choose k best experts with beam search, then call chosen experts and average their outputs.
        Input tensor is averaged over all dimensions except for first and last
        (we assume that extra dimensions represent sequence length or image height/width)

        :param input: a tensor of values that are used to estimate gating function, batch-first.
        :param args: extra positional parameters that will be passed to each expert after input, batch-first
        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
        """
        if input.ndim != 2:
            input_for_gating = input.mean(dim=tuple(range(1, input.ndim - 1)))
        else:
            input_for_gating = input

        # 1. compute scores and find most appropriate experts with beam search
        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)

        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best)

        if self._expert_info is None:
            # Cache info from any chosen expert. If beam search found no experts at all, there is nothing to query
            # here (a bare next() would raise StopIteration); self.info below performs its own lookup instead.
            first_expert = next((expert for experts_i in chosen_experts for expert in experts_i), None)
            if first_expert is not None:
                try:
                    self._expert_info = first_expert.info
                except grpc.RpcError as e:
                    logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
            self.backward_timeout, self.detect_anomalies, self.info, *nested_flatten(((input, *args), kwargs)))
        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
        # experts that did not deliver output get -inf, i.e. zero weight after softmax
        masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
        expert_weights = torch.softmax(expert_logits, dim=1)

        averaged_outputs_flat = [
            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes

        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])

    def compute_expert_scores(
            self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
        """
        Compute scores for each expert by adding up grid scores, autograd-friendly

        :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
        :param batch_experts: list(batch) of lists(k) of up to k experts selected for this batch
        :returns: a tensor of scores, shape [batch_size, max number of experts in a row], with the same dtype and
         device as grid_scores
        :note: if some rows in batch have less than max number of experts, their scores will be padded with -inf
        """
        device, dtype = grid_scores[0].device, grid_scores[0].dtype
        expert_counts = list(map(len, batch_experts))
        batch_size = len(batch_experts)
        max_num_experts = max(expert_counts)
        total_num_experts = sum(expert_counts)

        # map every expert in the flattened batch to (row index, position within that row)
        expert_index_in_batch = torch.arange(total_num_experts, device=device)
        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
        flat_experts = [expert for row in batch_experts for expert in row]

        # parse per-dimension grid coordinates from each expert uid (the part after uid_prefix)
        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
        for i, expert in enumerate(flat_experts):
            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
        grid_indices = grid_indices.to(device)  # keep all index tensors on the same device as the scores

        scores_per_dim = [
            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices)
            else torch.zeros(0, device=device, dtype=dtype)
            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

        # same dtype/device as flat_scores: the indexed assignment below does not mix dtypes or devices
        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device, dtype=dtype)
        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
        return scores

    @property
    def info(self):
        """Expert info (including 'outputs_schema'), fetched once from some expert in the grid and then cached."""
        if self._expert_info is None:
            # grab some expert to set ensemble output shape
            proj_device = self.proj.weight.device
            with torch.no_grad():  # gating output here only guides the lookup; no graph is needed
                dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
            # find_best_experts takes one 1-d score vector per grid dimension, so drop the batch axis
            dummy_scores = [
                scores.numpy()
                for scores in dummy_scores_concat[0].cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)]
            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
            self._expert_info = dummy_experts[0].info
        return self._expert_info