コード例 #1
0
def set_weights(model: torch.nn.Module,
                weights: fl.common.Weights) -> None:
    """Set model weights from a list of NumPy ndarrays.

    Arrays are matched positionally to the entries of ``model.state_dict()``
    (parameters and buffers), so ``weights`` must follow that same order --
    e.g. the list produced by ``get_weights``.

    Args:
        model: Model whose parameters and buffers are overwritten in place.
        weights: One array per ``state_dict`` entry, in ``state_dict`` order.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries.
        Errors from ``load_state_dict`` (e.g. a shape mismatch) propagate.
    """
    keys = list(model.state_dict().keys())
    weights = list(weights)
    # zip() would silently ignore surplus arrays; fail loudly instead so a
    # mismatched architecture is not mistaken for a successful update.
    if len(weights) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} weight arrays, got {len(weights)}")
    # torch.tensor keeps each array's dtype. The legacy torch.Tensor
    # constructor always produced float32, truncating float64 parameters and
    # large int64 buffers. atleast_1d is kept for backward compatibility:
    # load_state_dict accepts a 1-element tensor for a 0-dim entry.
    state_dict = OrderedDict(
        (k, torch.tensor(np.atleast_1d(v))) for k, v in zip(keys, weights)
    )
    model.load_state_dict(state_dict, strict=True)
コード例 #2
0
def get_weights(model: torch.nn.Module) -> fl.common.Weights:
    """Get model weights as a list of NumPy ndarrays.

    Returns one array per ``model.state_dict()`` entry (parameters and
    buffers), in ``state_dict`` order, so the result can be fed back to
    ``set_weights``.

    Each array is an independent copy. ``Tensor.numpy()`` on a CPU tensor
    shares memory with it, so without the copy the returned "snapshot" would
    silently change as the model continues training.
    """
    # to("cpu", copy=True) always yields a fresh tensor, costing exactly one
    # copy whether the model lives on the CPU or on an accelerator.
    return [tensor.to("cpu", copy=True).numpy()
            for tensor in model.state_dict().values()]