Esempio n. 1
0
    def train_func_checkpoint():
        """Resume from epoch 2 and checkpoint epochs 2 through 4.

        Requires ``sgd.load_checkpoint()`` to return a non-``None``
        checkpoint whose ``"epoch"`` entry equals 2, then calls
        ``sgd.save_checkpoint(epoch=e)`` for each ``e`` in ``range(2, 5)``.

        Returns:
            The constant ``1``.
        """
        restored = sgd.load_checkpoint()
        assert restored is not None
        start_epoch = restored["epoch"]
        assert start_epoch == 2

        # The resumed epoch itself is checkpointed again before moving on.
        for epoch in range(start_epoch, 5):
            sgd.save_checkpoint(epoch=epoch)
        return 1
Esempio n. 2
0
    def train_func(config):
        """Checkpoint and report each iteration below ``config["max_iter"]``.

        If ``sgd.load_checkpoint()`` returns a checkpoint, iteration resumes
        from its stored ``"iter"`` value plus one; otherwise it starts at 0.
        Each iteration ``i`` calls ``sgd.save_checkpoint(iter=i)`` and then
        ``sgd.report(test=i, training_iteration=i)``.

        Args:
            config: Mapping that must provide a ``"max_iter"`` key.
        """
        previous = sgd.load_checkpoint()
        start = 0 if previous is None else previous["iter"] + 1

        for step in range(start, config["max_iter"]):
            sgd.save_checkpoint(iter=step)
            sgd.report(test=step, training_iteration=step)
Esempio n. 3
0
 def train():
     """Run epochs from the checkpointed epoch (or 0) up to, not including, 2.

     A truthy result from ``sgd.load_checkpoint()`` supplies the starting
     epoch through its ``"epoch"`` key; a falsy result (including ``None``)
     means starting from 0. The starting epoch is printed, then each epoch
     ``i`` calls ``sgd.report(loss=1, iter=i)`` followed by
     ``sgd.save_checkpoint(epoch=i + 1)``.
     """
     checkpoint = sgd.load_checkpoint()
     # Truthiness test, not ``is None``: assuming a dict-like checkpoint, an
     # empty one is treated the same as having no checkpoint.
     start_epoch = checkpoint["epoch"] if checkpoint else 0
     print("Epoch: ", start_epoch)
     for epoch in range(start_epoch, 2):
         sgd.report(loss=1, iter=epoch)
         # Store the *next* epoch so a resume continues after this one.
         sgd.save_checkpoint(epoch=epoch + 1)
Esempio n. 4
0
    def train_func():
        """Checkpoint and report iterations 0-3, failing once at iteration 2.

        On a fresh start (``sgd.load_checkpoint()`` returns a falsy value),
        an exception is raised on reaching iteration 2, after iterations 0
        and 1 have been checkpointed and reported. When a truthy checkpoint
        is found, iteration resumes at its ``"iter"`` value plus one and no
        failure is injected.

        Raises:
            Exception: ``"try to fail me"`` on iteration 2 of a fresh run.
        """
        ckpt = sgd.load_checkpoint()
        if ckpt:
            restored, first_iter = True, ckpt["iter"] + 1
        else:
            restored, first_iter = False, 0

        for step in range(first_iter, 4):
            # Simulated failure, only when not resuming from a checkpoint.
            if not restored and step == 2:
                raise Exception("try to fail me")
            sgd.save_checkpoint(iter=step)
            sgd.report(test=step, training_iteration=step)
Esempio n. 5
0
 def train():
     """Call ``sgd.save_checkpoint`` with ``epoch=0`` and then ``epoch=1``."""
     for epoch in (0, 1):
         sgd.save_checkpoint(epoch=epoch)
Esempio n. 6
0
 def train_mismatch():
     """Report twice but checkpoint only once.

     ``sgd.save_checkpoint(epoch=0)`` precedes the first report
     (``index=0``); the second report (``index=1``) is intentionally not
     preceded by a checkpoint, so the counts of the two calls differ.
     """
     sgd.save_checkpoint(epoch=0)
     for index in range(2):
         # No checkpoint inside the loop: only index 0 has a matching one.
         sgd.report(index=index)
Esempio n. 7
0
 def train():
     """Alternate checkpoint and report calls for two steps.

     Step ``n`` (0, then 1) calls ``sgd.save_checkpoint(epoch=n)`` and then
     ``sgd.report(index=n)``.
     """
     for step in [0, 1]:
         sgd.save_checkpoint(epoch=step)
         sgd.report(index=step)
Esempio n. 8
0
 def train_mismatch():
     """Save a single checkpoint (``epoch=0``) without ever calling ``sgd.report``."""
     only_epoch = 0
     sgd.save_checkpoint(epoch=only_epoch)
Esempio n. 9
0
 def train_func():
     """Checkpoint epochs 0-2 on a run that starts without a checkpoint.

     Asserts that ``sgd.load_checkpoint()`` returns ``None``, then calls
     ``sgd.save_checkpoint(epoch=e)`` for ``e`` in 0, 1, 2.

     Returns:
         The constant ``1``.
     """
     # The load call stays inside the assert so that running under ``-O``
     # skips it exactly as before.
     assert sgd.load_checkpoint() is None
     for epoch in (0, 1, 2):
         sgd.save_checkpoint(epoch=epoch)
     return 1
Esempio n. 10
0
 def train_slow():
     """Checkpoint and report twice, pausing 5 seconds after every call.

     For step ``n`` in 0 and 1 the sequence is
     ``sgd.save_checkpoint(epoch=n)``, sleep, ``sgd.report(index=n)``, sleep.
     """
     pause_seconds = 5
     for step in range(2):
         sgd.save_checkpoint(epoch=step)
         time.sleep(pause_seconds)
         sgd.report(index=step)
         time.sleep(pause_seconds)
Esempio n. 11
0
 def train_func():
     """Report ``test`` values 0-9, then save one checkpoint.

     The checkpoint (``hello="world"``) is saved only after all ten
     ``sgd.report`` calls have been made.
     """
     for value in range(0, 10):
         sgd.report(test=value)
     sgd.save_checkpoint(hello="world")
Esempio n. 12
0
 def train():
     """Branch on the value returned by ``sgd.world_rank()``.

     Rank 0 calls ``sgd.save_checkpoint(epoch=0)``; every other rank calls
     ``sgd.report(iter=0)`` instead.
     """
     is_rank_zero = sgd.world_rank() == 0
     if not is_rank_zero:
         sgd.report(iter=0)
     else:
         sgd.save_checkpoint(epoch=0)