def train():
    """Train ``MainModel`` on ``trainDataset`` and checkpoint it every epoch.

    The model is built from the two vocabulary sizes reported by the dataset,
    moved to the GPU and optimised with Adam (lr=1e-4) using a batch size of 1.
    Each batch must provide the keys ``"source"``, ``"target"``,
    ``"alignment"`` and ``"predictions"``. The loss is binary cross-entropy
    between the flattened model outputs and the flattened ``"predictions"``
    cast to float (this assumes the model emits probabilities in [0, 1] --
    TODO confirm against ``MainModel``).

    After every epoch the mean batch loss is printed and the model's
    ``state_dict`` is written to ``model_lstm.pth``, overwriting the previous
    checkpoint.
    """
    dataset = trainDataset()
    vocab_sizes = dataset.vocab_size()
    model = MainModel(vocab_sizes[0], vocab_sizes[1])
    # To warm-start from a checkpoint:
    #   model.load_state_dict(torch.load("model2.pth"))
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             pin_memory=True)

    for epoch in range(EPOCHS):
        print(f"EPOCH: {epoch + 1}/{EPOCHS}")
        losses = []

        # Wrap the loader itself (not enumerate(...)) so tqdm can read its
        # length and show a real progress bar with a total and an ETA.
        for data in tqdm(dataloader):
            outputs = model(data["source"].cuda(), data["target"].cuda(),
                            data["alignment"].cuda())
            loss = torch.nn.functional.binary_cross_entropy(
                outputs.view(-1), data["predictions"].cuda().view(-1).float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Keep the detached tensor so no GPU sync is forced per step.
            losses.append(loss.detach())

        if losses:
            # Reduce once on the device and convert to a plain float so the
            # log shows a number rather than a tensor repr.
            mean_loss = torch.stack(losses).mean().item()
            print(f"Mean Loss for Epoch: {epoch + 1} is {mean_loss}")
        else:
            # An empty loader would otherwise raise ZeroDivisionError.
            print(f"Mean Loss for Epoch: {epoch + 1} is undefined (no batches)")
        torch.save(model.state_dict(), "model_lstm.pth")
Esempio n. 2
0
        print("g_loss: " + str(g_avg_loss))
        logging.info("g_loss: " + str(g_avg_loss))

        # Testing
        network.cpu().eval()
        p_1, r_1, f1_1 = evaluator.sparse_network_reconstruct(network, args.pos_size[0])
        p_2, r_2, f1_2 = evaluator.sparse_network_reconstruct(network, args.pos_size[1])
        p_3, r_3, f1_3 = evaluator.sparse_network_reconstruct(network, args.pos_size[2])
        print(p_3, r_3, f1_3)
        if f1_3 > max_f1:
            max_f1 = f1_3
            max_p = p_3
            max_r = r_3
            torch.save(network.state_dict(), "model/" + model_dir + ".pt")
        network.cuda(args.cuda)
        network.generator.sample_linear.cpu()
        print("Network Reconstruction Results (K=" + str(args.pos_size[0]) + "): ")
        print("Prec: " + str(p_1) + " Rec: " + str(r_1) + " F1: " + str(f1_1))
        print("Network Reconstruction Results (K=" + str(args.pos_size[1]) + "): ")
        print("Prec: " + str(p_2) + " Rec: " + str(r_2) + " F1: " + str(f1_2))
        print("Network Reconstruction Results (K=" + str(args.pos_size[2]) + "): ")
        print("Prec: " + str(max_p) + " Rec: " + str(max_r) + " F1: " + str(max_f1))
        logging.info("K=" + str(args.pos_size[0]) + "Prec: " + str(p_1) + " Rec: " + str(r_1) + " F1: " + str(f1_1))
        logging.info("K=" + str(args.pos_size[1]) + "Prec: " + str(p_2) + " Rec: " + str(r_2) + " F1: " + str(f1_2))
        logging.info("K=" + str(args.pos_size[2]) + "Prec: " + str(max_p) + " Rec: " + str(max_r) + " F1: " + str(max_f1))
        accu, mle = evaluator.sparse_seq_predict(network, test_marker, test_time, test_mask, adj_list, 1)
        print("accu", accu)
        print("mle", mle)
        logging.info("accu: " + str(accu))
        logging.info("mle: " + str(mle))