def train_func_checkpoint():
    """Resume from a previously saved checkpoint and keep checkpointing.

    Asserts that ``sgd.load_checkpoint()`` yields a checkpoint whose
    ``"epoch"`` entry equals 2, then calls ``sgd.save_checkpoint`` once per
    epoch from that value up to (but not including) 5.

    Returns:
        The constant ``1``.
    """
    restored = sgd.load_checkpoint()
    assert restored is not None
    assert restored["epoch"] == 2

    first_epoch = restored["epoch"]
    for epoch in range(first_epoch, 5):
        sgd.save_checkpoint(epoch=epoch)

    return 1
def train_func(config):
    """Checkpoint and report every iteration, resuming after a checkpoint.

    Args:
        config: Mapping that must provide ``"max_iter"``, the exclusive
            upper bound of the iteration range.
    """
    restored = sgd.load_checkpoint()
    # Continue right after the last checkpointed iteration, if there is one.
    first_iter = 0 if restored is None else restored["iter"] + 1

    for step in range(first_iter, config["max_iter"]):
        sgd.save_checkpoint(iter=step)
        sgd.report(test=step, training_iteration=step)
def train():
    """Report and checkpoint for the epochs remaining before epoch 2.

    Starts at the ``"epoch"`` value of the loaded checkpoint when one is
    available (and truthy), otherwise at 0, and prints that starting epoch.
    """
    restored = sgd.load_checkpoint()
    start_epoch = restored["epoch"] if restored else 0
    print("Epoch: ", start_epoch)

    for current in range(start_epoch, 2):
        sgd.report(loss=1, iter=current)
        # The checkpoint records the next epoch to run, not the one just done.
        sgd.save_checkpoint(epoch=current + 1)
def train_func():
    """Fail once at iteration 2 unless the run was restored from a checkpoint.

    Iterates up to (but not including) 4, saving a checkpoint and reporting
    each iteration. When no previous checkpoint was loaded, reaching
    iteration 2 raises before that iteration is checkpointed.

    Raises:
        Exception: At iteration 2 of a run that did not load a checkpoint.
    """
    state = sgd.load_checkpoint()
    # Was there a previous checkpoint to resume from?
    resumed = bool(state)
    first_iter = state["iter"] + 1 if resumed else 0

    for step in range(first_iter, 4):
        if not resumed and step == 2:
            raise Exception("try to fail me")
        sgd.save_checkpoint(iter=step)
        sgd.report(test=step, training_iteration=step)
def train():
    """Save a checkpoint for epoch 0 and epoch 1, without reporting."""
    for epoch in range(2):
        sgd.save_checkpoint(epoch=epoch)
def train_mismatch():
    """Save one checkpoint but report twice.

    The second report is deliberately not preceded by a checkpoint, so the
    number of checkpoints and reports differ.
    """
    sgd.save_checkpoint(epoch=0)
    sgd.report(index=0)

    # Intentionally no save_checkpoint here: skip the checkpoint.
    sgd.report(index=1)
def train():
    """Save a checkpoint and then report, once for each of two iterations."""
    for step in range(2):
        sgd.save_checkpoint(epoch=step)
        sgd.report(index=step)
def train_mismatch():
    """Save a single checkpoint for epoch 0 and never report."""
    sgd.save_checkpoint(epoch=0)
def train_func():
    """Start without a checkpoint and save one for each of three epochs.

    Asserts that ``sgd.load_checkpoint()`` returns ``None`` before saving.

    Returns:
        The constant ``1``.
    """
    initial = sgd.load_checkpoint()
    assert initial is None

    for epoch in range(3):
        sgd.save_checkpoint(epoch=epoch)

    return 1
def train_slow():
    """Checkpoint and report twice, sleeping 5 seconds after each call."""
    pause_seconds = 5
    for step in range(2):
        sgd.save_checkpoint(epoch=step)
        time.sleep(pause_seconds)
        sgd.report(index=step)
        time.sleep(pause_seconds)
def train_func():
    """Report ten times, then save one checkpoint.

    NOTE(review): the checkpoint call is placed after the loop; confirm
    against the original layout that it was not meant to run per iteration.
    """
    for step in range(10):
        sgd.report(test=step)

    sgd.save_checkpoint(hello="world")
def train():
    """Do different work depending on ``sgd.world_rank()``.

    The worker whose rank is 0 saves a checkpoint; every other worker
    reports instead.
    """
    rank = sgd.world_rank()
    if rank != 0:
        sgd.report(iter=0)
        return

    sgd.save_checkpoint(epoch=0)