Exemplo n.º 1
0
    def test_load_point(self):
        """Save a checkpoint and load it back through ``load_point``.

        Builds a ``Model`` with up to two CUDA device indices, calls
        ``check_point`` and then ``load_point`` with identical arguments,
        and finally deletes the ``test_model/`` directory.
        """
        # [] without CUDA devices, [0] with one, [0, 1] with two or more.
        device_ids = list(range(min(torch.cuda.device_count(), 2)))
        model = Model(self.mode, device_ids, "kaiming", show_structure=False)

        model.check_point("tm", self.epoch, "test_model")
        model.load_point("tm", self.epoch, "test_model")

        # Presumably check_point wrote its files here; remove them so the
        # working tree is left clean.
        shutil.rmtree("test_model/")
Exemplo n.º 2
0
 def test_check_point(self):
     """Run ``check_point`` on a ``Model`` built for the visible devices.

     The device list handed to ``Model`` depends on how many CUDA devices
     are present: ``[2, 3]`` with four or more, ``[0, 1]`` with two or
     three, ``[0]`` with one, and ``[]`` when CUDA is unavailable. After
     the checkpoint call the ``test_model/`` directory is deleted.
     """
     device_count = torch.cuda.device_count()
     # Test the largest configurations first. Checking plain CUDA
     # availability first would shadow every multi-device branch, since
     # any machine with several devices also reports CUDA as available.
     if device_count >= 4:
         # Index 3 only exists when at least four devices are present.
         device_ids = [2, 3]
     elif device_count >= 2:
         device_ids = [0, 1]
     elif torch.cuda.is_available():
         device_ids = [0]
     else:
         device_ids = []
     net = Model(self.mode, device_ids, "kaiming", show_structure=False)

     net.check_point("tm", self.epoch, "test_model")

     # Named so the builtin ``dir`` is not shadowed. rmtree raises if the
     # directory is missing, so this also fails the test when the
     # checkpoint call produced no ``test_model/`` directory.
     checkpoint_root = "test_model/"
     shutil.rmtree(checkpoint_root)
Exemplo n.º 3
0
 def test_save_load_weights(self):
     """Save a checkpoint, then reload its weights file by explicit path.

     Prints ``self.mode``, builds a ``Model`` with up to two CUDA device
     indices, calls ``check_point`` and then ``load_weights`` on
     ``test_model/checkpoint/Weights_tm_<epoch>.pth``, and finally deletes
     the ``test_model/`` directory.
     """
     print(self.mode)

     # [] without CUDA devices, [0] with one, [0, 1] with two or more.
     device_ids = list(range(min(torch.cuda.device_count(), 2)))
     model = Model(self.mode, device_ids, "kaiming", show_structure=False)

     model.check_point("tm", self.epoch, "test_model")

     # Presumably this is the file the check_point call above produced
     # for this epoch; verify against Model.check_point.
     weights_path = "test_model/checkpoint/Weights_tm_%d.pth" % self.epoch
     model.load_weights(weights_path)

     shutil.rmtree("test_model/")