def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
    assert not torch.is_grad_enabled()
    (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
     detect_anomalies) = ctx._saved_non_tensors
    alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
    dummy_grad_mask, *flat_grad_outputs = raw_grads

    flat_grad_outputs_cpu = []
    for tensor in flat_grad_outputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of gradients has nan/inf values")
        flat_grad_outputs_cpu.append(tensor.cpu())

    num_samples, max_experts = dummy_grad_mask.shape
    inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
    grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0)
                                    for tensor in flat_grad_outputs_cpu))
    backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

    # dispatch tasks to all remote experts, collect responses
    pending_tasks = {}
    for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                inputs_per_expert, grad_outputs_per_expert):
        expert = expert_per_sample[i.item()][j.item()]
        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
        inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
        tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                              for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
        new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
        pending_tasks[new_task] = (i, j)

    backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
        pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
    if len(backward_survivor_indices) == 0:
        raise TimeoutError("Backward pass: no alive experts responded within timeout.")

    # assemble responses
    backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
    survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
    # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

    grad_inputs = []
    for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
        # gradient tensor with individual contributions from each expert
        grad_input_per_expert = torch.zeros(
            (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
            device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
        grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
        # sum gradients from each expert
        grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

    return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)

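# The scatter-and-sum step at the end of the backward pass above is easy to miss in the
# dense indexing. The snippet below is a minimal, self-contained sketch of that pattern
# with made-up shapes; it uses plain torch only and is not library code.
import torch

num_samples, max_experts, hidden = 4, 3, 8

# pretend that (sample, expert) pairs (0, 1), (1, 0) and (2, 2) returned gradients
survivor_ii = torch.tensor([0, 1, 2])    # sample indices of surviving responses
survivor_jj = torch.tensor([1, 0, 2])    # expert slots of surviving responses
survivor_grads = torch.randn(3, hidden)  # stacked gradients, one row per survivor

grad_per_expert = torch.zeros(num_samples, max_experts, hidden)
grad_per_expert[survivor_ii, survivor_jj] = survivor_grads  # scatter survivor gradients
grad_input = grad_per_expert.sum(dim=1)                     # sum contributions across experts
assert grad_input.shape == (num_samples, hidden)            # one gradient row per input sample
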
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
    deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
    return (DUMMY, None, None, *deserialized_grad_inputs)

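# nested_flatten above is a library utility. The toy helper below only sketches the idea it
# relies on (depth-first flattening of nested containers so saved inputs and incoming
# gradients can be serialized in a fixed order); it is not the library's implementation and
# deliberately handles only tuples/lists.
from typing import Any, Iterator

def toy_nested_flatten(value: Any) -> Iterator[Any]:
    # depth-first flattening of nested tuples/lists into a flat stream of leaves
    if isinstance(value, (tuple, list)):
        for item in value:
            yield from toy_nested_flatten(item)
    else:
        yield value

saved_inputs = ("x0", "x1")   # placeholders for ctx.saved_tensors
grad_outputs = (("g0",),)     # placeholders for incoming gradients
payload = tuple(toy_nested_flatten((saved_inputs, grad_outputs)))
assert payload == ("x0", "x1", "g0")  # inputs first, then gradients, in schema order
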
def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
            timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
            detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
    assert not torch.is_grad_enabled()
    num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

    flat_inputs_cpu = []
    for tensor in flat_inputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of inputs has nan/inf values")
        flat_inputs_cpu.append(tensor.cpu())

    flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
    assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

    # dispatch tasks to all remote experts, collect responses
    pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
    for i in range(num_samples):
        for j, expert in enumerate(experts_per_sample[i]):
            input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
                flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
            pending_tasks[new_task] = (i, j)

    alive_grid_indices, alive_flat_outputs = cls._collect_responses(
        pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
    if len(alive_grid_indices) == 0:
        raise TimeoutError("Forward pass: no alive experts responded within timeout.")

    # assemble responses
    alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
    mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
    mask[alive_ii, alive_jj] = True

    alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
    # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

    outputs = []
    for response_stacked in alive_flat_outputs_stacked:
        output = torch.zeros(
            [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
            dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
        output[alive_ii, alive_jj] = response_stacked
        outputs.append(output.to(flat_inputs[0].device))

    # save individual outputs for backward pass
    ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
    ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min,
                              experts_per_sample, detect_anomalies)
    return (mask,) + tuple(outputs)

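# The forward pass above returns a boolean mask of alive experts plus outputs padded to
# [num_samples, max_experts, ...]. The snippet below is only one plausible way a caller
# could combine that pair with per-expert gating scores (mask out dead experts, renormalize,
# weighted sum); it is an illustrative assumption, not the surrounding layer's actual logic.
import torch

num_samples, max_experts, dim = 2, 3, 4
mask = torch.tensor([[True, False, True],
                     [True, True, False]])            # which experts responded per sample
outputs = torch.randn(num_samples, max_experts, dim)  # zeros where mask is False in the real code
gates = torch.rand(num_samples, max_experts)          # hypothetical gating scores

gates = torch.where(mask, gates, torch.zeros_like(gates))
gates = gates / gates.sum(dim=-1, keepdim=True).clamp_min(1e-9)  # renormalize over alive experts
combined = (gates.unsqueeze(-1) * outputs).sum(dim=1)            # [num_samples, dim]
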
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
    serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
                          for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
    grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
    deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
    return (DUMMY, None, None, None, *deserialized_grad_inputs)

async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
    stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
    try:
        outputs = await stub.forward(runtime_pb2.ExpertRequest(
            uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
        return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")

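# Because _forward_one_expert returns None when the RPC fails, a natural usage pattern is to
# dispatch one coroutine per expert, gather them concurrently and drop the failures. The
# sketch below illustrates that pattern with a mock coroutine standing in for the real gRPC
# call; the mock and its arguments are illustrative only.
import asyncio
from typing import Optional, Tuple

async def mock_forward_one_expert(grid_index: int, fail: bool) -> Optional[Tuple[int, str]]:
    # stand-in for the real RPC: returns None on failure, (index, output) on success
    await asyncio.sleep(0.01)
    return None if fail else (grid_index, f"output_{grid_index}")

async def gather_alive_responses():
    tasks = [mock_forward_one_expert(i, fail=(i == 1)) for i in range(3)]
    results = await asyncio.gather(*tasks)
    return [r for r in results if r is not None]  # drop experts that failed (returned None)

alive = asyncio.run(gather_alive_responses())
assert [idx for idx, _ in alive] == [0, 2]
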
def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
            *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
    ctx.uid, ctx.stub = uid, stub
    ctx.save_for_backward(*inputs)

    outputs = stub.forward(runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
    deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
    return tuple(deserialized_outputs)

def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
            info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    inputs = tuple(tensor.cpu().detach() for tensor in inputs)
    ctx.uid, ctx.stub, ctx.info = uid, stub, info
    ctx.save_for_backward(*inputs)

    serialized_tensors = [serialize_torch_tensor(inp, proto.compression)
                          for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))]
    outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
    deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
    return tuple(deserialized_outputs)

def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
            timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
            detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
            *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
    assert not torch.is_grad_enabled()
    num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

    flat_inputs_cpu = []
    for tensor in flat_inputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of inputs has nan/inf values")
        flat_inputs_cpu.append(tensor.cpu())

    flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
    assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

    # dispatch tasks to all remote experts, collect responses
    pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
    for i in range(num_samples):
        for j, expert in enumerate(experts_per_sample[i]):
            input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
                flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
            pending_tasks[new_task] = (i, j)

    responded_inds, alive_flat_outputs = cls._collect_responses(
        pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
    if len(responded_inds) < k_min:
        raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")

    if not isinstance(info['outputs_schema'], tuple):
        outputs_schema = (info['outputs_schema'],)
    else:
        outputs_schema = info['outputs_schema']
    outputs = nested_map(
        lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
        outputs_schema)

    # assemble responses
    if len(responded_inds) > 0 or allow_zero_outputs:
        batch_inds, expert_inds = map(
            lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
            list(zip(*responded_inds)) or ([], []))

        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

        for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
            output[batch_inds, expert_inds] = response_stacked.to(output.device)
    else:
        raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')

    mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
    mask[batch_inds, expert_inds] = True

    # save individual outputs for backward pass
    ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
    ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min,
                              experts_per_sample, detect_anomalies)
    return (mask,) + outputs

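# _collect_responses is referenced throughout but not shown here. The sketch below is an
# assumption about its quorum semantics (wait for at least k_min responses under the overall
# timeout, then give stragglers only a short grace period), written against the standard
# concurrent.futures API rather than gRPC futures; names and behavior are illustrative only.
import time
from concurrent.futures import FIRST_COMPLETED, Future, wait
from typing import Dict, List, Optional, Tuple

def collect_with_quorum(pending: Dict[Future, Tuple[int, int]], k_min: int,
                        timeout_total: Optional[float], timeout_after_k_min: float):
    indices: List[Tuple[int, int]] = []
    outputs: List[object] = []
    not_done = set(pending)
    deadline = None if timeout_total is None else time.monotonic() + timeout_total
    while not_done:
        if len(indices) >= k_min:
            budget = timeout_after_k_min  # quorum reached: brief grace period for stragglers
        elif deadline is not None:
            budget = max(0.0, deadline - time.monotonic())
        else:
            budget = None  # no overall timeout: block until the quorum arrives
        done, not_done = wait(not_done, timeout=budget, return_when=FIRST_COMPLETED)
        if not done:
            break  # budget exhausted with no new responses
        for future in done:
            try:
                outputs.append(future.result())
                indices.append(pending[future])
            except Exception:
                pass  # a failed expert simply contributes nothing
    return indices, outputs
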