コード例 #1
0
ファイル: train.py プロジェクト: zaixizhang/Alchemy
def train(model="sch", epochs=80, device=th.device("cpu")):
    """Train a 12-output model on the default TencentAlchemyDataset.

    Prints the model, then per-epoch mean (per-batch average) MSE loss and
    MAE over the training set.

    Args:
        model: ``"sch"`` (SchNetModel), ``"mgcn"`` (MGCNModel) or ``"MPNN"``
            (MPNNModel).
        epochs: Number of passes over the training set.
        device: Torch device the model and batch labels are moved to.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    print("start")
    alchemy_dataset = TencentAlchemyDataset()
    alchemy_loader = DataLoader(dataset=alchemy_dataset,
                                batch_size=20,
                                collate_fn=batcher(),
                                shuffle=False,
                                num_workers=0)

    if model == "sch":
        model = SchNetModel(norm=True, output_dim=12)
    elif model == "mgcn":
        model = MGCNModel(norm=True, output_dim=12)
    elif model == "MPNN":
        model = MPNNModel(output_dim=12)
    else:
        # Otherwise ``model`` stays a str and fails later on ``model.name``.
        raise ValueError(
            "unknown model {!r}; expected 'sch', 'mgcn' or 'MPNN'".format(model))
    print(model)
    if model.name in ["MGCN", "SchNet"]:
        # The norm=True models above receive the dataset label statistics.
        model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
    model.to(device)

    loss_fn = nn.MSELoss()
    MAE_fn = nn.L1Loss()
    optimizer = th.optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(epochs):

        w_loss, w_mae = 0, 0
        # Stays 0 if the loader yields nothing, so the averaging below is safe.
        n_batches = 0
        model.train()

        for n_batches, batch in enumerate(alchemy_loader, start=1):
            # NOTE(review): the return value is discarded; this relies on
            # ``to`` moving the graph in place (older DGL) — confirm.
            batch.graph.to(device)
            batch.label = batch.label.to(device)

            res = model(batch.graph)
            loss = loss_fn(res, batch.label)
            mae = MAE_fn(res, batch.label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            w_mae += mae.detach().item()
            w_loss += loss.detach().item()
        # Report per-batch averages for both metrics (previously only MAE was
        # averaged while loss was printed as a sum).
        w_mae /= max(n_batches, 1)
        w_loss /= max(n_batches, 1)

        print("Epoch {:2d}, loss: {:.7f}, mae: {:.7f}".format(
            epoch, w_loss, w_mae))
コード例 #2
0
def sch_train(epochs=80, device=torch.device('cpu')):
    """Train a 15-output SchNet model, save it and plot the loss curve.

    Trains on the default ``TencentAlchemyDataset`` and prints per-epoch mean
    loss, mean MAE and elapsed time. Afterwards the whole model object is
    saved to ``./model/schnet-<epochs>.pth``, and the per-epoch loss curve is
    saved to ``./img/schnet_<YYYY-MM-DD>_<final mae>.png`` and shown.

    Args:
        epochs: Number of passes over the training set.
        device: Torch device the model and batch labels are moved to.
    """
    import os

    alchemy_dataset = TencentAlchemyDataset()
    alchemy_loader = DataLoader(dataset=alchemy_dataset,
                                batch_size=20,
                                collate_fn=batcher(),
                                shuffle=False,
                                num_workers=0)

    model = SchNetModel(norm=True, output_dim=15)
    model.set_mean_std(alchemy_dataset.mean, alchemy_dataset.std, device)
    model.to(device)

    loss_fn = nn.MSELoss()
    MAE_fn = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    losses = []
    # Defined up front so the image filename below works even if epochs == 0.
    w_mae = 0.0
    for epoch in range(epochs):
        start = time()
        w_loss, w_mae = 0, 0
        # Stays 0 if the loader yields nothing, so the averaging below is safe.
        n_batches = 0
        model.train()

        for n_batches, batch in enumerate(alchemy_loader, start=1):
            # NOTE(review): the return value is discarded; this relies on
            # ``to`` moving the graph in place (older DGL) — confirm.
            batch.graph.to(device)
            batch.label = batch.label.to(device)

            res = model(batch.graph)
            loss = loss_fn(res, batch.label)
            mae = MAE_fn(res, batch.label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            w_mae += mae.detach().item()
            w_loss += loss.detach().item()
        w_mae /= max(n_batches, 1)
        w_loss /= max(n_batches, 1)

        print(
            "Epoch {:2d}, \tloss: {:.7f}, \tmae: {:.7f}, \ttime: {:.3f}".format(
                epoch, w_loss, w_mae, time() - start))
        losses.append(w_loss)

    # Neither torch.save nor savefig creates missing parent directories.
    os.makedirs('./model', exist_ok=True)
    os.makedirs('./img', exist_ok=True)
    torch.save(model, r'./model/schnet-{}.pth'.format(epochs))
    plt.plot(range(epochs), losses, color='red')
    # Pass the float itself: '{:.4f}' applied to str(w_mae) raised ValueError.
    plt.savefig(
        r'./img/schnet_{}_{:.4f}.png'.format(strftime('%Y-%m-%d'), w_mae))
    plt.show()
コード例 #3
0
def eval(model="sch",
         epochs=80,
         device=th.device("cpu"),
         train_dataset='',
         eval_dataset='',
         epoch=1):
    """Evaluate a saved checkpoint on a cross-dataset CSV and log predictions.

    Loads ``<train_dataset>_<eval_dataset>_cross.csv``, restores the weights
    saved at ``./<train_dataset>/model_<epoch>``, writes one
    ``prediction<TAB>label`` line per sample (first output column) to
    ``<train_dataset>_<eval_dataset><epoch>_crossres.txt``, then prints the
    mean MSE / MAE and the dataset mean/std.

    NOTE: this function shadows the built-in ``eval``; the name is kept for
    existing callers.

    Args:
        model: ``"sch"``, ``"mgcn"`` or ``"MPNN"``.
        epochs: Unused; kept for signature compatibility.
        device: Torch device to run inference on; the checkpoint is mapped here.
        train_dataset: Name of the dataset the checkpoint was trained on.
        eval_dataset: Name of the dataset being evaluated.
        epoch: Checkpoint epoch (anything ``int()`` accepts).

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    print("start")
    epoch = int(epoch)
    test_dataset = TencentAlchemyDataset()
    test_file = train_dataset + '_' + eval_dataset + "_cross.csv"
    # NOTE(review): mode "Train" is presumably what loads labels — confirm.
    test_dataset.mode = "Train"
    test_dataset.transform = None
    test_dataset.file_path = test_file
    test_dataset._load()

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=10,
        collate_fn=batcher(),
        shuffle=False,
        num_workers=0,
    )

    if model == "sch":
        model = SchNetModel(norm=False, output_dim=1)
    elif model == "mgcn":
        model = MGCNModel(norm=False, output_dim=1)
    elif model == "MPNN":
        model = MPNNModel(output_dim=1)
    else:
        raise ValueError(
            "unknown model {!r}; expected 'sch', 'mgcn' or 'MPNN'".format(model))
    print(model)
    # map_location lets checkpoints saved on a GPU load onto ``device``.
    model.load_state_dict(
        th.load('./' + train_dataset + "/model_" + str(epoch),
                map_location=device))
    model.to(device)
    # Inference mode: disables training-only behaviour (dropout etc., if any).
    model.eval()

    loss_fn = nn.MSELoss()
    MAE_fn = nn.L1Loss()

    val_loss, val_mae = 0, 0
    # Stays 0 if the loader yields nothing, so the averaging below is safe.
    n_batches = 0
    res_path = (train_dataset + '_' + eval_dataset + str(epoch) +
                "_crossres.txt")
    # ``with`` closes the results file; no_grad skips building autograd graphs.
    with open(res_path, 'w') as res_file, th.no_grad():
        for n_batches, batch in enumerate(test_loader, start=1):
            # NOTE(review): the return value is discarded; this relies on
            # ``to`` moving the graph in place (older DGL) — confirm.
            batch.graph.to(device)
            batch.label = batch.label.to(device)

            res = model(batch.graph)
            res_np = res.cpu().numpy()
            label_np = batch.label.cpu().numpy()
            for pred_row, label_row in zip(res_np, label_np):
                res_file.write(str(pred_row[0]) + '\t')
                res_file.write(str(label_row[0]) + '\n')

            val_loss += loss_fn(res, batch.label).item()
            val_mae += MAE_fn(res, batch.label).item()
    val_mae /= max(n_batches, 1)
    val_loss /= max(n_batches, 1)
    print("Epoch {:2d}, val_loss: {:.7f}, val_mae: {:.7f}".format(
        epoch, val_loss, val_mae))
    print("test_dataset.mean= %s" % (test_dataset.mean))
    print("test_dataset.std= %s" % (test_dataset.std))
コード例 #4
0
def train(model="sch",
          epochs=80,
          device=th.device("cpu"),
          dataset='',
          save=''):
    """Train on ``<dataset>_train.csv`` and validate on ``<dataset>_valid.csv``.

    Each epoch prints the mean training loss/MAE and the mean validation
    loss/MAE. The learning rate is managed by ``ReduceLROnPlateau`` driven by
    the training MAE. Every 200 epochs (including epoch 0) the model
    ``state_dict`` is saved as ``model_<epoch>`` inside ``save``.

    Args:
        model: ``"sch"`` (SchNetModel) or ``"mgcn"`` (MGCNModel).
        epochs: Number of training epochs.
        device: Torch device the model and batch labels are moved to.
        dataset: Prefix of the train/validation CSV file names.
        save: Checkpoint directory; ``''`` means the current directory.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    import os

    print("start")
    alchemy_dataset = TencentAlchemyDataset()
    alchemy_dataset.mode = "Train"
    alchemy_dataset.transform = None
    alchemy_dataset.file_path = dataset + "_train.csv"
    alchemy_dataset._load()

    test_dataset = TencentAlchemyDataset()
    # NOTE(review): mode "Train" is presumably what loads labels — confirm.
    test_dataset.mode = "Train"
    test_dataset.transform = None
    test_dataset.file_path = dataset + "_valid.csv"
    test_dataset._load()

    alchemy_loader = DataLoader(
        dataset=alchemy_dataset,
        batch_size=10,
        collate_fn=batcher(),
        shuffle=False,
        num_workers=0,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=10,
        collate_fn=batcher(),
        shuffle=False,
        num_workers=0,
    )

    if model == "sch":
        model = SchNetModel(norm=False, output_dim=1)
    elif model == "mgcn":
        model = MGCNModel(norm=False, output_dim=1)
    else:
        # Otherwise ``model`` stays a str and fails later on ``model.to``.
        raise ValueError(
            "unknown model {!r}; expected 'sch' or 'mgcn'".format(model))
    print(model)
    model.to(device)

    loss_fn = nn.MSELoss()
    MAE_fn = nn.L1Loss()
    optimizer = th.optim.Adam(model.parameters(), lr=0.0001)
    # ``verbose`` is omitted: False was already its default, and the argument
    # is deprecated in recent PyTorch releases.
    scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        factor=0.9,
                                                        patience=10,
                                                        threshold=0.0000001,
                                                        threshold_mode='rel',
                                                        cooldown=0,
                                                        min_lr=0.000001,
                                                        eps=1e-08)

    for epoch in range(epochs):

        w_loss, w_mae = 0, 0
        # Stays 0 if the loader yields nothing, so the averaging below is safe.
        n_batches = 0
        model.train()

        for n_batches, batch in enumerate(alchemy_loader, start=1):
            # NOTE(review): the return value is discarded; this relies on
            # ``to`` moving the graph in place (older DGL) — confirm.
            batch.graph.to(device)
            batch.label = batch.label.to(device)

            res = model(batch.graph)
            loss = loss_fn(res, batch.label)
            mae = MAE_fn(res, batch.label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            w_mae += mae.detach().item()
            w_loss += loss.detach().item()
        w_mae /= max(n_batches, 1)
        w_loss /= max(n_batches, 1)
        scheduler.step(w_mae)

        print("Epoch {:2d}, loss: {:.7f}, mae: {:.7f}".format(
            epoch, w_loss, w_mae))

        val_loss, val_mae = 0, 0
        n_val_batches = 0
        # Validation runs in eval mode without autograd; model.train() at the
        # top of the next epoch switches back.
        model.eval()
        with th.no_grad():
            for n_val_batches, batch in enumerate(test_loader, start=1):
                batch.graph.to(device)
                batch.label = batch.label.to(device)

                res = model(batch.graph)
                val_loss += loss_fn(res, batch.label).item()
                val_mae += MAE_fn(res, batch.label).item()
        val_mae /= max(n_val_batches, 1)
        val_loss /= max(n_val_batches, 1)
        print("Epoch {:2d}, val_loss: {:.7f}, val_mae: {:.7f}".format(
            epoch, val_loss, val_mae))

        if epoch % 200 == 0:
            # os.path.join: with save == '' the old concatenation produced
            # "/model_<epoch>" at the filesystem root.
            th.save(model.state_dict(),
                    os.path.join(save, "model_" + str(epoch)))
コード例 #5
0
def train(model="sch", epochs=80, device=th.device("cpu"), dataset=''):
    """Train on ``train_smi.csv`` and validate on ``val_smi.csv``.

    Each epoch prints the mean training loss/MAE and the mean validation
    loss/MAE. Every 50 epochs (including epoch 0) the validation predictions
    are written to ``val_results_<epoch>.txt`` as ``prediction<TAB>label``
    lines (first output column), and the model ``state_dict`` is saved to
    ``./<dataset>/model_<epoch>``.

    Args:
        model: ``"sch"``, ``"mgcn"`` or ``"MPNN"``.
        epochs: Number of training epochs.
        device: Torch device the model and batch labels are moved to.
        dataset: Sub-directory (under ``./``) where checkpoints are saved.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    import os

    print("start")
    alchemy_dataset = TencentAlchemyDataset()
    alchemy_dataset.mode = "Train"
    alchemy_dataset.transform = None
    alchemy_dataset.file_path = "train_smi.csv"
    alchemy_dataset._load()

    test_dataset = TencentAlchemyDataset()
    # NOTE(review): mode "Train" is presumably what loads labels — confirm.
    test_dataset.mode = "Train"
    test_dataset.transform = None
    test_dataset.file_path = "val_smi.csv"
    test_dataset._load()

    alchemy_loader = DataLoader(
        dataset=alchemy_dataset,
        batch_size=10,
        collate_fn=batcher(),
        shuffle=False,
        num_workers=0,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=10,
        collate_fn=batcher(),
        shuffle=False,
        num_workers=0,
    )

    if model == "sch":
        model = SchNetModel(norm=False, output_dim=1)
    elif model == "mgcn":
        model = MGCNModel(norm=False, output_dim=1)
    elif model == "MPNN":
        model = MPNNModel(output_dim=1)
    else:
        # Otherwise ``model`` stays a str and fails later on ``model.to``.
        raise ValueError(
            "unknown model {!r}; expected 'sch', 'mgcn' or 'MPNN'".format(model))
    print(model)
    model.to(device)

    loss_fn = nn.MSELoss()
    MAE_fn = nn.L1Loss()
    optimizer = th.optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(epochs):

        w_loss, w_mae = 0, 0
        # Stays 0 if the loader yields nothing, so the averaging below is safe.
        n_batches = 0
        model.train()

        for n_batches, batch in enumerate(alchemy_loader, start=1):
            # NOTE(review): the return value is discarded; this relies on
            # ``to`` moving the graph in place (older DGL) — confirm.
            batch.graph.to(device)
            batch.label = batch.label.to(device)

            res = model(batch.graph)
            loss = loss_fn(res, batch.label)
            mae = MAE_fn(res, batch.label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            w_mae += mae.detach().item()
            w_loss += loss.detach().item()
        w_mae /= max(n_batches, 1)
        w_loss /= max(n_batches, 1)

        print("Epoch {:2d}, loss: {:.7f}, mae: {:.7f}".format(
            epoch, w_loss, w_mae))

        save_results = epoch % 50 == 0
        result_lines = []
        val_loss, val_mae = 0, 0
        n_val_batches = 0
        # Validation is evaluation only. The previous version ran
        # zero_grad/backward/step here, i.e. it trained on the validation set
        # with an MAE objective, which made the validation metrics meaningless.
        model.eval()
        with th.no_grad():
            for n_val_batches, batch in enumerate(test_loader, start=1):
                batch.graph.to(device)
                batch.label = batch.label.to(device)

                res = model(batch.graph)
                val_loss += loss_fn(res, batch.label).item()
                val_mae += MAE_fn(res, batch.label).item()

                if save_results:
                    res_np = res.cpu().numpy()
                    label_np = batch.label.cpu().numpy()
                    for pred_row, label_row in zip(res_np, label_np):
                        result_lines.append(
                            str(pred_row[0]) + '\t' + str(label_row[0]) + '\n')

        if save_results:
            # ``with`` ensures the results file is closed (it previously leaked).
            with open('val_results_%s.txt' % (epoch), 'w') as res_file:
                res_file.writelines(result_lines)

        val_mae /= max(n_val_batches, 1)
        val_loss /= max(n_val_batches, 1)
        print("Epoch {:2d}, val_loss: {:.7f}, val_mae: {:.7f}".format(
            epoch, val_loss, val_mae))

        if epoch % 50 == 0:
            th.save(model.state_dict(),
                    os.path.join('.', dataset, "model_" + str(epoch)))