from models import *

print("Train using model", Net)

model = Net()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=1e-6,
                      momentum=args.momentum)

print("Args:", args)

# torch.save fails when the target directory is missing, so make sure the
# checkpoint directory exists before the first epoch writes into it.
ckpt_dir = os.path.join(args.log_dir, "ckpt")
os.makedirs(ckpt_dir, exist_ok=True)

# Validation losses, one entry per epoch (after training on that epoch).
val_losses = []
for epoch in range(1, args.epochs + 1):
    # NOTE(review): test_length is a float (10% of len(valid_dataset));
    # confirm fp_train.test accepts a non-integer here.
    if epoch == 1:
        # Baseline evaluation of the untrained model; its result is only
        # reported by fp_train.test and is not recorded in val_losses.
        fp_train.test(epoch, args, model, test_loader, fp.dxs, fp.dys,
                      test_length=0.1 * len(valid_dataset))
    fp_train.train(epoch, args, model, optimizer, train_loader, fp.dxs, fp.dys)
    test_loss = fp_train.test(epoch, args, model, test_loader, fp.dxs, fp.dys,
                              test_length=0.1 * len(valid_dataset))
    val_losses.append(test_loss)

    # Early stopping is currently disabled: loss_flag is always 1, so the
    # termination branch below never runs. The sketched criterion below would
    # keep training while epoch <= 15 or the latest validation loss is lower
    # than one of the 2nd-4th most recent ones (it would need loss_flag = 0
    # as the starting value to have any effect).
    loss_flag = 1
    #for i in range(2,5):
    #    if(epoch<=15 or val_losses[-1]<val_losses[-i]):
    #       loss_flag = 1

    path = os.path.join(ckpt_dir, "state_dict-ep_{}.pth".format(epoch))
    print("Saving model in", path)
    torch.save(model.state_dict(), path)

    if loss_flag != 1:
        # Text mode: the payload is a str, which a binary ("wb") file would
        # reject with TypeError in Python 3.
        # NOTE(review): training continues after this marker is written;
        # add a `break` here if early stopping is re-enabled.
        with open(os.path.join(args.log_dir, "termination_epoch"), "w") as f:
            f.write(str(epoch) + "\n")

# Convert the target fingerprint array with util.np2var, passing through the
# args.cuda flag (presumably moves it to the GPU when set — confirm in util).
fp_target = util.np2var(fp_target, args.cuda)

# Bundle the fingerprint inputs/targets consumed by fp_train.train/test.
# NOTE(review): dxs presumably holds input perturbations and dys the expected
# output changes; verify against Fingerprints / fp_train.
fp = Fingerprints()
fp.dxs = fp_dx
fp.dys = fp_target

# Architecture selection; the commented alternatives are kept for reference.
#from model import Net
from model import CW_Net as Net
#from small_model import Very_Small_Net as Net

model = Net()
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=args.lr)

print("Args:", args)

# torch.save fails when the target directory is missing, so make sure the
# checkpoint directory exists before the first epoch writes into it.
ckpt_dir = os.path.join(args.log_dir, "ckpt")
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(1, args.epochs + 1):
    if epoch == 1:
        # Baseline evaluation of the untrained model.
        fp_train.test(epoch, args, model, test_loader, fp.dxs, fp.dys)
    fp_train.train(epoch, args, model, optimizer, train_loader, fp.dxs, fp.dys)
    fp_train.test(epoch, args, model, test_loader, fp.dxs, fp.dys)

    # One checkpoint per epoch: <log_dir>/ckpt/state_dict-ep_<epoch>.pth
    path = os.path.join(ckpt_dir, "state_dict-ep_{}.pth".format(epoch))
    print("Saving model in", path)
    torch.save(model.state_dict(), path)