Ejemplo n.º 1
0
from torch_geometric.nn import MLP, CorrectAndSmooth

# Dataset directory: <directory of this file>/../data/OGB.
root = osp.join(
    osp.dirname(osp.realpath(__file__)),
    '..',
    'data',
    'OGB',
)
# Load 'ogbn-products' with the ToSparseTensor transform applied.
dataset = PygNodePropPredDataset(
    'ogbn-products',
    root,
    transform=T.ToSparseTensor(),
)
evaluator = Evaluator(name='ogbn-products')
split_idx = dataset.get_idx_split()
data = dataset[0]

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# MLP with channel sizes: input features -> 200 -> 200 -> number of classes.
model = MLP(
    [dataset.num_features, 200, 200, dataset.num_classes],
    dropout=0.5,
    batch_norm=True,
    relu_first=True,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Move the node features, labels and split indices to the chosen device.
x = data.x.to(device)
y = data.y.to(device)
train_idx = split_idx['train'].to(device)
val_idx = split_idx['valid'].to(device)
test_idx = split_idx['test'].to(device)

# Features and labels restricted to the training indices.
x_train = x[train_idx]
y_train = y[train_idx]


def train():
    """Run a forward and backward pass of ``model`` on the training nodes.

    Switches ``model`` to training mode, zeroes the optimizer's gradients,
    evaluates the model on ``x_train`` and back-propagates the ``criterion``
    loss computed against ``y_train``.

    NOTE(review): this excerpt ends right after ``loss.backward()``; no
    ``optimizer.step()`` call or return statement is visible here. The body
    is presumably cut off by the excerpt boundary -- confirm against the
    full source before relying on this function.
    """
    model.train()
    # Clear gradients before computing new ones.
    optimizer.zero_grad()
    out = model(x_train)
    # Labels are flattened with view(-1) before being given to the criterion.
    loss = criterion(out, y_train.view(-1))
    loss.backward()
Ejemplo n.º 2
0
                    help='Balances loss from hard labels and teacher outputs')
args = parser.parse_args()

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Dataset directory: <directory of this file>/../data/Planetoid.
path = osp.join(
    osp.dirname(osp.realpath(__file__)),
    '..',
    'data',
    'Planetoid',
)
# Load Cora with the NormalizeFeatures transform and move it to the device.
dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())
data = dataset[0]
data = data.to(device)

# Two-layer GCN with 16 hidden channels; it is optimized by train_teacher().
gnn = GCN(
    dataset.num_node_features,
    hidden_channels=16,
    out_channels=dataset.num_classes,
    num_layers=2,
).to(device)
# MLP with channel sizes: input features -> 64 -> number of classes.
mlp = MLP(
    [dataset.num_node_features, 64, dataset.num_classes],
    dropout=0.5,
    batch_norm=False,
).to(device)

# One Adam optimizer per model, both with the same hyper-parameters.
gnn_optimizer = torch.optim.Adam(
    gnn.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
mlp_optimizer = torch.optim.Adam(
    mlp.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)


def train_teacher():
    """Perform one optimization step of ``gnn`` on the training nodes.

    Runs ``gnn`` on ``data.x`` and ``data.edge_index``, computes the
    cross-entropy loss on the entries selected by ``data.train_mask``,
    back-propagates it and applies one ``gnn_optimizer`` update.

    Returns:
        The loss of this step converted to a Python ``float``.
    """
    gnn.train()
    gnn_optimizer.zero_grad()

    logits = gnn(data.x, data.edge_index)

    # Only the entries selected by the training mask contribute to the loss.
    mask = data.train_mask
    train_loss = F.cross_entropy(logits[mask], data.y[mask])

    train_loss.backward()
    gnn_optimizer.step()
    return float(train_loss)


@torch.no_grad()
def test_teacher():
    gnn.eval()