def test_load_point(self):
    """Write a checkpoint with ``Model.check_point`` and read it back with ``Model.load_point``.

    The model is built on up to two CUDA devices: ``[]`` when none are
    visible, ``[0]`` with one, ``[0, 1]`` with two or more. The
    ``test_model/`` directory is removed afterwards even if saving or
    loading raises, so a failure does not leave files behind for later tests.
    """
    import os

    ckpt_dir = "test_model/"
    try:
        # 0 devices -> [], 1 -> [0], 2+ -> [0, 1]
        gpu_ids = list(range(min(torch.cuda.device_count(), 2)))
        net = Model(self.mode, gpu_ids, "kaiming", show_structure=False)
        net.check_point("tm", self.epoch, "test_model")
        # The original test relied on rmtree raising to detect a missing
        # directory; make that expectation explicit.
        assert os.path.isdir(ckpt_dir), "check_point did not create %s" % ckpt_dir
        net.load_point("tm", self.epoch, "test_model")
    finally:
        # ignore_errors so cleanup never masks the real failure above.
        shutil.rmtree(ckpt_dir, ignore_errors=True)
def test_check_point(self):
    """Check that ``Model.check_point`` writes its output under ``test_model/``.

    Device selection: ``[]`` with no visible CUDA devices, ``[0]`` with one,
    ``[0, 1]`` with two or more. The previous chain tested
    ``torch.cuda.is_available()`` first, which made the multi-GPU branches
    unreachable. Its ``> 2`` branch also asked for devices ``[2, 3]``, which
    need at least four GPUs. The ``test_model/`` directory is always removed
    afterwards.
    """
    import os

    ckpt_dir = "test_model/"
    try:
        # 0 devices -> [], 1 -> [0], 2+ -> [0, 1]
        gpu_ids = list(range(min(torch.cuda.device_count(), 2)))
        net = Model(self.mode, gpu_ids, "kaiming", show_structure=False)
        net.check_point("tm", self.epoch, "test_model")
        # This is the test's only verification: the checkpoint directory exists.
        assert os.path.isdir(ckpt_dir), "check_point did not create %s" % ckpt_dir
    finally:
        # ignore_errors so cleanup never masks the real failure above.
        shutil.rmtree(ckpt_dir, ignore_errors=True)
def test_save_load_weights(self):
    """Write a checkpoint, then load its weights file with ``Model.load_weights``.

    The weights file is expected at
    ``test_model/checkpoint/Weights_tm_<epoch>.pth`` after
    ``check_point("tm", epoch, "test_model")``. The model is built on up to
    two CUDA devices (``[]``, ``[0]`` or ``[0, 1]``). The ``test_model/``
    directory is removed afterwards even if a step raises.
    """
    import os

    ckpt_dir = "test_model/"
    try:
        # 0 devices -> [], 1 -> [0], 2+ -> [0, 1]
        gpu_ids = list(range(min(torch.cuda.device_count(), 2)))
        net = Model(self.mode, gpu_ids, "kaiming", show_structure=False)
        net.check_point("tm", self.epoch, "test_model")
        assert os.path.isdir(ckpt_dir), "check_point did not create %s" % ckpt_dir
        net.load_weights("test_model/checkpoint/Weights_tm_%d.pth" % self.epoch)
    finally:
        # ignore_errors so cleanup never masks the real failure above.
        shutil.rmtree(ckpt_dir, ignore_errors=True)