def test__weight_init(self):
    """Build a ``Model`` on the available devices and apply its weight-init hook.

    The device list passed to ``Model`` depends on how many CUDA devices are
    visible. Cases are checked from the most devices down, so every branch is
    reachable:

    * 4 or more devices -> ``[2, 3]``
    * 2 or 3 devices    -> ``[0, 1]``
    * 1 device          -> ``[0]``
    * CUDA unavailable  -> ``[]``

    The test then sets ``net.init_fc`` to ``init.kaiming_normal_`` and runs
    ``net._weight_init`` over ``self.mode`` via ``self.mode.apply``.
    """
    # The previous chain tested ``is_available()`` first, so the multi-GPU
    # branches could never run. Count the devices once and test the largest
    # requirement first.
    n_devices = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if n_devices > 3:
        # ``[2, 3]`` needs device index 3 to exist, i.e. at least 4 devices.
        device_ids = [2, 3]
    elif n_devices > 1:
        device_ids = [0, 1]
    elif n_devices > 0:
        device_ids = [0]
    else:
        device_ids = []
    net = Model(self.mode, device_ids, "kaiming", show_structure=False)
    net.init_fc = init.kaiming_normal_
    # NOTE(review): this test makes no assertions; it only checks that
    # applying the init hook does not raise. Consider asserting on weights.
    self.mode.apply(net._weight_init)
def test_weightsInit(self):
    """Apply a default ``Model``'s weight-init hook over ``self.mode``.

    Constructs ``Model()`` with no arguments and sets its ``init_fc`` to
    ``init.kaiming_normal_``. It then passes the bound ``_weight_init`` method
    to ``self.mode.apply``.
    """
    model = Model()
    model.init_fc = init.kaiming_normal_
    # Look up the hook only after ``init_fc`` is set, matching the original order.
    hook = model._weight_init
    self.mode.apply(hook)