def test_state_dict_note(self): # モデルの重みにメモをかきこんだら怒られるのだろうか model = GRU( input_size=1, # Sequential MNIST タスクなら入力は1次元 output_size=10, # Sequential MNIST タスクなら出力は10次元 num_layers=1, # GRU ブロックの積み重ね数 d_hidden=128) # GRU ブロックの出力次元数(隠れ状態の次元数) model.load_state_dict( torch.load('./weights/gru_sequential_mnist_sample.dict')) # state_dict をとった後 clone して参照を切る state_dict = model.state_dict() for k, v in state_dict.items(): state_dict[k] = v.clone() # メモする state_dict['loss'] = 10.0 state_dict['accuracy'] = 0.97 state_dict['note'] = 'hoge' # メモ付きの重みを普通にモデルに流し込もうとすると怒られる with pytest.raises(RuntimeError): model.load_state_dict(state_dict) # strict=False にしておけばOK # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/modules/module.py#L1010-L1012 model.gru.weight_ih_l0.data[0, 0] = 1.0 # ここをわざと変更しておいて流し込まれたことを確認 model.load_state_dict(state_dict, strict=False) assert model.gru.weight_ih_l0.data[0, 0].item() == approx(0.2248, rel=1e-3)
def test_state_dict_ref(self): # モデルの state_dict をとった後モデルのパラメータを更新すると state_dict も連動する torch.manual_seed(0) model = GRU(input_size=1, output_size=4, num_layers=1, d_hidden=8) state_dict = model.state_dict() assert state_dict['gru.weight_ih_l0'][0, 0].item() == approx(-0.002647, rel=1e-3) model.gru.weight_ih_l0.data[0, 0] = 1.0 assert state_dict['gru.weight_ih_l0'][0, 0].item() == approx(1.0, rel=1e-3) # clone すれば参照は切れる state_dict['gru.weight_ih_l0'] = state_dict['gru.weight_ih_l0'].clone() model.gru.weight_ih_l0.data[0, 0] = 2.0 assert state_dict['gru.weight_ih_l0'][0, 0].item() == approx(1.0, rel=1e-3)