コード例 #1
0
ファイル: test_model.py プロジェクト: thinksee/jdit
 def test__weight_init(self):
     if torch.cuda.is_available():
         net = Model(self.mode, [0], "kaiming", show_structure=False)
     elif torch.cuda.device_count() > 1:
         net = Model(self.mode, [0, 1], "kaiming", show_structure=False)
     elif torch.cuda.device_count() > 2:
         net = Model(self.mode, [2, 3], "kaiming", show_structure=False)
     else:
         net = Model(self.mode, [], "kaiming", show_structure=False)
     net.init_fc = init.kaiming_normal_
     self.mode.apply(net._weight_init)
コード例 #2
0
 def test_weightsInit(self):
     net = Model()
     net.init_fc = init.kaiming_normal_
     self.mode.apply(net._weight_init)