Exemplo n.º 1
0
def main():
    """Train a model for ``args.epochs`` epochs, optionally resuming.

    Options come from ``parseParams()``. The dataset object is read with
    ``torch.load(args.datafile)`` and batched with ``DS.mkbatches(args.bsz)``.

    After every epoch the ``(model, optimizer)`` pair is saved to
    ``args.savestr + <epoch> + "_bleu-" + <validation score>``. ``savestr`` is
    concatenated directly, so it is expected to end with a path separator.

    When ``args.resume`` is set, it must be the path of such a checkpoint.
    The integer before the first ``_`` in its file name is the last finished
    epoch, and training continues from the next one.
    """
    args = parseParams()
    # Create the checkpoint directory (and any missing parents).
    # This is a no-op when it already exists.
    os.makedirs(args.savestr, exist_ok=True)

    DS = torch.load(args.datafile)
    DS.mkbatches(args.bsz)
    args.vsz = DS.vsz
    args.evsz = len(DS.itos)
    print("Vocab Size: ", args.vsz)
    print("E Vocab Size: ", args.evsz)

    start = 0
    if args.resume:
        M, optimizer = torch.load(args.resume)
        # Checkpoint names look like "<epoch>_bleu-<score>"; resume after <epoch>.
        start = int(os.path.basename(args.resume).split("_")[0]) + 1
        # enc/dec are presumably torch RNN modules. flatten_parameters()
        # re-packs their weights after unpickling (TODO confirm module types).
        M.enc.flatten_parameters()
        M.dec.flatten_parameters()
    else:
        M = model(args).cuda()
        optimizer = torch.optim.Adam(M.parameters(), lr=args.lr)
    print(M)

    for epoch in range(start, args.epochs):
        args.epoch = str(epoch)
        trainloss = train(M, DS, args, optimizer)
        print("train loss epoch", epoch, trainloss)
        # validate() returns the score used in the checkpoint name
        # (printed as BLEU).
        b = validate(M, DS, args)
        print("valid bleu ", b)
        torch.save((M, optimizer),
                   args.savestr + args.epoch + "_bleu-" + str(b))
Exemplo n.º 2
0
def main():
    """Train a model, either from scratch or resuming from a checkpoint.

    Options come from ``parseParams()`` and the dataset from
    ``torch.load(args.datafile)``. With ``args.debug`` set, the batch size is
    forced to 2 and only the first two entries of ``DS.train`` and
    ``DS.valid`` are kept, for quick smoke tests.

    When ``args.resume`` names a checkpoint ``<epoch>_bleu-<score>``,
    training continues at ``epoch + 1``. Otherwise it starts at epoch 0 with
    a fresh model and Adam optimizer. After every epoch ``(model, optimizer)``
    is saved to ``args.savestr + <epoch> + "_bleu-" + <validation score>``.
    """
    import os

    args = parseParams()
    # Make sure checkpoints can be written. Otherwise torch.save would only
    # fail after the first full epoch of training.
    os.makedirs(args.savestr, exist_ok=True)

    DS = torch.load(args.datafile)
    if args.debug:
        args.bsz = 2
        DS.train = DS.train[:2]
        DS.valid = DS.valid[:2]

    args.vsz = DS.vsz
    args.svsz = DS.svsz

    # Fresh runs start at epoch 0. Previously this name was only bound
    # when resuming.
    start = 0
    if args.resume:
        M, optimizer = torch.load(args.resume)
        # enc/dec are presumably torch RNN modules. flatten_parameters()
        # re-packs their weights after unpickling (TODO confirm module types).
        M.enc.flatten_parameters()
        M.dec.flatten_parameters()
        # Checkpoint names look like "<epoch>_bleu-<score>"; resume after <epoch>.
        start = int(os.path.basename(args.resume).split("_")[0]) + 1
    else:
        # The model must exist before its parameters can be handed to the
        # optimizer.
        M = model(args).cuda()
        optimizer = torch.optim.Adam(M.parameters(), lr=args.lr)
    print(M)

    for epoch in range(start, args.epochs):
        args.epoch = str(epoch)
        trainloss = train(M, DS, args, optimizer)
        print("train loss epoch", epoch, trainloss)
        b = validate(M, DS, args)
        print("valid bleu ", b)
        torch.save((M, optimizer),
                   args.savestr + args.epoch + "_bleu-" + str(b))
Exemplo n.º 3
0
def main():
    """Train a fresh model, halving ``args.temperature`` after every epoch.

    Checkpoints always go to ``saved_models/extra/``, overriding whatever
    ``parseParams()`` put in ``args.savestr``. With ``args.debug`` set, the
    batch size is forced to 2 and the small ``data/multiref_debug.pt``
    dataset is used instead of ``args.datafile``.

    After every epoch ``(model, optimizer)`` is saved to
    ``args.savestr + <epoch> + "_bleu-" + <validation score>``.
    """
    args = parseParams()
    # NOTE(review): hard-coded output directory silently overrides the CLI
    # value.
    args.savestr = "saved_models/extra/"
    # Create the checkpoint directory (and parents); no-op if it exists.
    os.makedirs(args.savestr, exist_ok=True)
    if args.debug:
        args.bsz = 2
        args.datafile = "data/multiref_debug.pt"

    DS = torch.load(args.datafile)
    DS.mkbatches(args.bsz)
    args.vsz = DS.vsz
    M = model(args).cuda()
    # Project helper; presumably initialises the model's vocabulary-dependent
    # parts from DS (TODO confirm).
    init_vocab(M, DS, args)
    print(M)
    optimizer = torch.optim.Adam(M.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        args.epoch = str(epoch)
        trainloss = train(M, DS, args, optimizer)
        print("train loss epoch", epoch, trainloss)
        b = validate(M, DS, args)
        print("valid bleu ", b)
        torch.save((M, optimizer),
                   args.savestr + args.epoch + "_bleu-" + str(b))
        # Exponential decay of the temperature read by train/validate
        # (presumably a sampling/softmax temperature — confirm).
        args.temperature *= 0.5
Exemplo n.º 4
0
def main():
    """Run ``validate`` on every saved checkpoint in ``args.savestr``.

    A checkpoint is any file whose name starts with a digit (the epoch
    prefix). Each one is loaded as a 4-tuple whose first two items are used
    as ``I`` and ``S``. When ``args.cuda`` is false, the models are loaded
    and kept on the CPU.

    Before validation, ``args``, the end-of-sequence token id and the ids of
    ``.``, ``!`` and ``?`` are attached to the models.
    """
    args = parseParams()
    DS = torch.load(args.datafile)
    DS.args = args

    # Without CUDA, remap GPU-saved storages onto the CPU at load time.
    # Otherwise torch.load tries to restore them on a CUDA device, which
    # fails on CPU-only hosts.
    map_location = None if args.cuda else "cpu"

    # os.listdir order is arbitrary; sort so the evaluation order is
    # reproducible.
    models = sorted(x for x in os.listdir(args.savestr) if x[0].isdigit())
    for m in models:
        args.epoch = m
        I, S, _, _ = torch.load(os.path.join(args.savestr, m),
                                map_location=map_location)
        if not args.cuda:
            print('move to cpu')
            S = S.cpu()
            I = I.cpu()
        # These submodules are presumably torch RNNs. flatten_parameters()
        # re-packs their weights after unpickling/moving (TODO confirm).
        S.dec.flatten_parameters()
        I.enc.flatten_parameters()
        I.vdec.flatten_parameters()
        S.args = args
        I.args = args
        S.endtok = DS.vocab.index("<eos>")
        S.punct = [DS.vocab.index(t) for t in ['.', '!', '?']]
        validate(I, S, DS, args, m)
Exemplo n.º 5
0
def main():
    """Train a fresh model for ``args.epochs`` epochs.

    The dataset is read with ``torch.load(args.datafile)`` and batched with
    ``DS.mkbatches(args.bsz)``. After every epoch ``(model, optimizer)`` is
    saved to ``args.savestr + <epoch> + "_bleu-" + <validation score>``.
    """
    args = parseParams()
    # Create the checkpoint directory (and parents); no-op if it exists.
    os.makedirs(args.savestr, exist_ok=True)

    DS = torch.load(args.datafile)
    DS.mkbatches(args.bsz)
    args.vsz = DS.vsz
    print("Vocab Size: ", args.vsz)
    M = model(args).cuda()
    print(M)
    optimizer = torch.optim.Adam(M.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        args.epoch = str(epoch)
        trainloss = train(M, DS, args, optimizer)
        print("train loss epoch", epoch, trainloss)
        b = validate(M, DS, args)
        print("valid bleu ", b)
        torch.save((M, optimizer),
                   args.savestr + args.epoch + "_bleu-" + str(b))
Exemplo n.º 6
0
def main():
    """Train a fresh model whose size options are derived from the dataset.

    Sets ``args.vsz``, ``args.svsz`` and ``args.catsz`` from the dataset
    before building the model. With ``args.debug`` set, the batch size is
    forced to 2 and only the first two entries of ``DS.train`` and
    ``DS.valid`` are kept.

    After every epoch ``(model, optimizer)`` is saved to
    ``args.savestr + <epoch> + "_bleu-" + <validation score>``.
    """
    import os

    args = parseParams()
    # Make sure checkpoints can be written. Otherwise torch.save would only
    # fail after the first full epoch of training.
    os.makedirs(args.savestr, exist_ok=True)

    DS = torch.load(args.datafile)
    if args.debug:
        args.bsz = 2
        DS.train = DS.train[:2]
        DS.valid = DS.valid[:2]

    args.vsz = DS.vsz
    args.svsz = DS.svsz
    # One slot more than the number of categories; presumably reserved for
    # padding/unknown (TODO confirm).
    args.catsz = len(DS.vcats) + 1
    M = model(args).cuda()
    print(M)
    optimizer = torch.optim.Adam(M.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        args.epoch = str(epoch)
        trainloss = train(M, DS, args, optimizer)
        print("train loss epoch", epoch, trainloss)
        b = validate(M, DS, args)
        print("valid bleu ", b)
        torch.save((M, optimizer),
                   args.savestr + args.epoch + "_bleu-" + str(b))
Exemplo n.º 7
0
from preprocess import *
from arguments import s2s1cats as parseParams
if __name__ == "__main__":
    # Build the dataset object from the CLI options and serialize it to
    # args.datafile; the training entry points reload it with torch.load.
    args = parseParams()
    torch.save(load_data(args), args.datafile)