def set_weights(model: torch.nn.Module, weights: fl.common.Weights) -> None:
    """Set model weights from a list of NumPy ndarrays.

    Args:
        model: Module whose ``state_dict()`` entries are overwritten. Arrays
            are matched to entries positionally, in ``state_dict()`` key order.
        weights: One ndarray per ``state_dict()`` entry, in the same order
            that ``model.state_dict()`` yields its keys.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict()`` entries. Without this check, surplus arrays
            would be silently dropped by ``zip``.
    """
    keys = list(model.state_dict().keys())
    weights = list(weights)
    if len(weights) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} weight arrays, got {len(weights)}"
        )
    # torch.tensor keeps each array's dtype and shape, including 0-d
    # buffers such as BatchNorm's int64 ``num_batches_tracked``.
    # torch.Tensor, by contrast, forces float32 and rejects 0-d arrays.
    # torch.tensor also copies, so read-only arrays do not trigger a warning.
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, weights)
    )
    model.load_state_dict(state_dict, strict=True)
def get_weights(model: torch.nn.Module) -> fl.common.Weights:
    """Get model weights as a list of NumPy ndarrays.

    Args:
        model: Module whose ``state_dict()`` entries are exported.

    Returns:
        One ndarray per ``state_dict()`` entry (parameters and buffers), in
        ``state_dict()`` key order. The arrays are produced from CPU copies
        of the tensors.
    """
    # Only the tensors are needed; iteration order is the state_dict key order.
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]