def main(args):
    """Train ``model(args)`` on ``dataset(args)``, checkpointing every epoch.

    Checkpoints are written into the directory ``args.save`` under the name
    ``"<epoch>.vloss-<val loss[:8]>.lr-<lr>"``. When ``args.ckpt`` is set,
    the model weights are restored from that file and training resumes at
    the epoch after the one encoded in its name, using the learning rate
    encoded in its name. Otherwise the command line is recorded in
    ``<args.save>/commandLineArgs.txt`` and training starts at epoch 0.

    After every epoch the learning rate of the SGD optimizer is halved
    (when ``args.lrdecay`` is set) if the validation loss went up compared
    with the previous epoch.

    Args:
        args: parsed command-line namespace. Read here: ``save``, ``ckpt``,
            ``lr``, ``epochs``, ``lrwarm``, ``lrdecay`` and (after
            ``dynArgs``) ``device``. ``args.lr`` is overwritten when resuming.
    """
    if os.path.exists(args.save):
        # Existing run directory: confirm before overwriting its checkpoints.
        # Ctrl-C raises KeyboardInterrupt here and aborts the run cleanly
        # (the old bare ``except:`` swallowed it and crashed in os.mkdir).
        input("Save File Exists, OverWrite? <CTL-C> for no")
    else:
        os.makedirs(args.save)

    ds = dataset(args)
    args = dynArgs(args, ds)
    m = model(args)
    print(args.device)
    m = m.to(args.device)

    if args.ckpt:
        # map_location="cpu" allows restoring a checkpoint written on a GPU on
        # a CPU-only host; load_state_dict copies the tensors onto m's device.
        cpt = torch.load(args.ckpt, map_location="cpu")
        m.load_state_dict(cpt)
        # Decode "<epoch>.vloss-<loss>.lr-<lr>" as written by torch.save below.
        # Split on the ".lr-" marker, not the last "-": an lr printed in
        # scientific notation ("1e-05") would otherwise be read as "05".
        ckpt_name = os.path.basename(args.ckpt)
        starte = int(ckpt_name.split(".")[0]) + 1
        args.lr = float(ckpt_name.rsplit(".lr-", 1)[-1])
        print('ckpt restored')
    else:
        with open(os.path.join(args.save, "commandLineArgs.txt"), 'w') as f:
            f.write("\n".join(sys.argv[1:]))
        starte = 0

    # NOTE(review): only the model weights are checkpointed, so the SGD
    # momentum buffers start from zero again when resuming.
    o = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9)

    # Validation loss of the previous epoch; lr is decayed when it increases.
    # inf (not a large constant) guarantees the first epoch never decays.
    lastloss = float("inf")
    for e in range(starte, args.epochs):
        print("epoch ", e, "lr", o.param_groups[0]['lr'])
        train(m, o, ds, args)
        vloss = evaluate(m, ds, args)
        if args.lrwarm:
            update_lr(o, args, e)
        print("Saving model")
        ckpt_name = (str(e) + ".vloss-" + str(vloss)[:8]
                     + ".lr-" + str(o.param_groups[0]['lr']))
        torch.save(m.state_dict(), os.path.join(args.save, ckpt_name))
        if vloss > lastloss and args.lrdecay:
            print("decay lr")
            o.param_groups[0]['lr'] *= 0.5
        lastloss = vloss
def main(args):
    """Restore the checkpoint ``args.ckpt``, run ``test`` with it, then exit.

    The checkpoint file name is expected to follow the training script's
    ``"<epoch>.vloss-<loss>.lr-<lr>"`` pattern; the lr encoded in it is
    copied into ``args.lr`` before ``test`` runs.

    Args:
        args: parsed command-line namespace. ``args.eval`` is forced to True
            and ``args.vbsz`` to 1; ``ckpt``, ``device`` and ``max`` are read.

    Raises:
        SystemExit: always, via ``sys.exit()`` once ``test`` returns.
    """
    args.eval = True
    ds = dataset(args)
    args = dynArgs(args, ds)
    m = model(args)
    print(args.device)
    m = m.to(args.device)

    # map_location="cpu" allows evaluating a GPU-saved checkpoint on a
    # CPU-only host; load_state_dict copies the tensors onto m's device.
    cpt = torch.load(args.ckpt, map_location="cpu")
    m.load_state_dict(cpt)
    # Split on the ".lr-" marker, not the last "-": an lr printed in
    # scientific notation ("1e-05") would otherwise be read as "05".
    ckpt_name = args.ckpt.split("/")[-1]
    args.lr = float(ckpt_name.rsplit(".lr-", 1)[-1])
    print('ckpt restored')

    # Decoding configuration attached to the model; presumably consumed by
    # test()/the model's generation code (not visible here) — TODO confirm.
    m.args = args
    m.maxlen = args.max
    m.starttok = ds.OUTP.vocab.stoi['<start>']
    m.endtok = ds.OUTP.vocab.stoi['<eos>']
    m.eostok = ds.OUTP.vocab.stoi['.']  # id of the "." token, not of '<eos>'
    args.vbsz = 1  # NOTE(review): presumably the validation batch size
    test(args, ds, m)
    # The previous optimizer + evaluate loop that followed was unreachable
    # (sys.exit() never returns) and has been removed.
    sys.exit()
m.train() return preds,golds ''' def metrics(preds,gold): cands = {'generated_description'+str(i):x.strip() for i,x in enumerate(preds)} refs = {'generated_description'+str(i):[x.strip()] for i,x in enumerate(gold)} x = evalMetrics.Evaluate() scores = x.evaluate(live=True, cand=cands, ref=refs) return scores ''' if __name__=="__main__": args = pargs() args.eval = True ds = dataset(args) args = dynArgs(args,ds) m = model(args) cpt = torch.load(args.save) m.load_state_dict(cpt) m = m.to(args.device) m.args = args m.maxlen = args.max m.starttok = ds.OUTP.vocab.stoi['<start>'] m.endtok = ds.OUTP.vocab.stoi['<eos>'] m.eostok = ds.OUTP.vocab.stoi['.'] args.vbsz = 1 preds,gold = test(args,ds,m) ''' scores = metrics(preds,gold) for k,v in scores.items():