def greedy(stage=1):
    """Train one new stage and append it to the saved stack checkpoint.

    A fresh single-stage ``ModelStack`` is built and handed to ``train``.
    When ``stage > 1`` a ``ModelStack`` with ``stage - 1`` stages is first
    restored from the checkpoint and passed to ``train`` as its second
    argument; otherwise ``None`` is passed. Afterwards the new stage's state
    dict is re-keyed under ``m.{stage-1}.``, merged with the restored stack's
    state dict (if any), and written back to the same checkpoint file.

    NOTE: a later definition of ``greedy`` in this file rebinds the name, so
    this version is only reachable if that one is removed.

    Args:
        stage: 1-based number of the stage to train; stage ``k`` is stored
            under the key prefix ``m.{k-1}.``.
    """
    checkpoint = "save/01-10g.tar"

    prev = None
    model = DataParallel(ModelStack(1)).to(o.device)
    if stage > 1:
        prev = DataParallel(ModelStack(stage - 1)).to(o.device)
        load(prev, torch.load(checkpoint))
        # Set on the DataParallel wrapper itself, not on prev.module.
        # NOTE(review): prev is not switched to eval() here — confirm intended.
        prev.stage = stage - 1

    train(model, prev)

    # concat and save
    # change_key presumably rewrites every key through the given function —
    # TODO confirm against its definition.
    merged = change_key(
        model.module.m[0].state_dict(), lambda x: f"m.{stage-1}." + x
    )
    # Explicit None check: do not rely on the truthiness of an nn.Module.
    if prev is not None:
        merged.update(prev.module.state_dict())
    torch.save(merged, checkpoint)
def greedy(stage=1):
    """Train one new stage, warm-started from the previous one, and save it.

    A fresh single-stage ``ModelStack`` is built and handed to ``train``.
    When ``stage > 1`` a ``ModelStack`` with ``stage - 1`` stages is restored
    from the checkpoint, put in eval mode, and its last stage's state dict is
    re-keyed under ``m.0.`` and loaded into the new model before training.
    The restored stack (or ``None`` for the first stage) is passed to
    ``train`` as its second argument. Afterwards the new stage's state dict
    is re-keyed under ``m.{stage-1}.``, merged with the restored stack's
    state dict (if any), and written back to the same checkpoint file.

    NOTE: this definition rebinds the name of an earlier ``greedy`` in this
    file.

    Args:
        stage: 1-based number of the stage to train; stage ``k`` is stored
            under the key prefix ``m.{k-1}.``.
    """
    checkpoint = "save/01-10g.tar"

    prev = None
    model = DataParallel(ModelStack(1)).to(o.device)
    if stage > 1:
        prev = DataParallel(ModelStack(stage - 1)).to(o.device)
        load(prev, checkpoint)
        prev.eval()
        # Set on the DataParallel wrapper itself, not on prev.module.
        prev.stage = stage - 1
        # init stage using stage-1
        # change_key presumably rewrites every key through the given
        # function — TODO confirm against its definition.
        init_state = change_key(
            prev.module.m[-1].state_dict(), lambda x: f"m.0.{x}"
        )
        load(model, init_state)

    train(model, prev)

    # concat and save
    merged = change_key(
        model.module.m[0].state_dict(), lambda x: f"m.{stage-1}." + x
    )
    # Explicit None check: do not rely on the truthiness of an nn.Module.
    if prev is not None:
        merged.update(prev.module.state_dict())
    torch.save(merged, checkpoint)