Code example #1
    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size):
        super().__init__()

        self.batch_size = batch_size
        self.dim = hidden_channels

        # Edge-feature MLP shared by every edge type below: it maps the
        # 5-dimensional edge attributes to a (dim x dim) kernel for NNConv.
        nn = Sequential(Linear(5, 128), ReLU(), Linear(128,
                                                       self.dim * self.dim))

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    ('x_i', 'inner_edge_i', 'x_i'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                    ('x_j', 'inner_edge_j', 'x_j'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                    ('x_i', 'outer_edge_ij', 'x_j'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                    ('x_j', 'outer_edge_ji', 'x_i'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                },
                aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(self.dim, outer_out_channels)

        self.lin_i = Linear(self.dim, inner_out_channels)
        self.lin_j = Linear(self.dim, inner_out_channels)
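The excerpt above only shows __init__; for context, a forward pass for such a HeteroConv stack conventionally threads per-node-type dictionaries through each layer. The sketch below is a hypothetical reconstruction, not part of the excerpt (NNConv additionally consumes the 5-dimensional edge attributes, passed here via edge_attr_dict):

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        for conv in self.convs:
            # Each HeteroConv layer maps {node_type: features} to the same
            # structure, aggregating messages across edge types with 'sum'.
            x_dict = conv(x_dict, edge_index_dict,
                          edge_attr_dict=edge_attr_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        # Per-node-type output heads defined in __init__:
        return self.lin_i(x_dict['x_i']), self.lin_j(x_dict['x_j'])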
Code example #2
    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size, num_node_types, num_heads):
        super().__init__()

        self.batch_size = batch_size
        self.hidden_channels = hidden_channels
        self.heads = num_heads

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    # concat=False averages the attention heads so that node
                    # embeddings stay at `hidden_channels`, matching the
                    # linear heads below; self-loops are disabled on the
                    # bipartite ('x_i' <-> 'x_j') edge types.
                    ('x_i', 'inner_edge_i', 'x_i'):
                    GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                              concat=False),
                    ('x_j', 'inner_edge_j', 'x_j'):
                    GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                              concat=False),
                    ('x_i', 'outer_edge_ij', 'x_j'):
                    GATv2Conv((-1, -1), self.hidden_channels, heads=num_heads,
                              concat=False, add_self_loops=False),
                    ('x_j', 'outer_edge_ji', 'x_i'):
                    GATv2Conv((-1, -1), self.hidden_channels, heads=num_heads,
                              concat=False, add_self_loops=False),
                },
                aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(self.hidden_channels, outer_out_channels)

        self.lin_i = Linear(self.hidden_channels, inner_out_channels)
        self.lin_j = Linear(self.hidden_channels, inner_out_channels)
Code example #3
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # `data` (a HeteroData object) is assumed to be in the enclosing scope.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels,
                           hidden_channels,
                           data.metadata(),
                           num_heads,
                           group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)
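Only __init__ is shown; the usual HGT forward pass (sketched below, with 'paper' as a stand-in for whichever node type is being classified) first projects every node type into the shared hidden space, then applies the stacked convolutions:

    def forward(self, x_dict, edge_index_dict):
        # Type-specific input projections to `hidden_channels`:
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        # 'paper' is a placeholder for the target node type:
        return self.lin(x_dict['paper'])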
Code example #4
    def __init__(self, hidden_channels: int, out_channels: int,
                 dropout: float):
        super().__init__()
        self.dropout = dropout

        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(-1, out_channels)
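The stored dropout rate implies a forward pass roughly like the following sketch (the excerpt omits the method; `F` is torch.nn.functional):

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)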
Code example #5
    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in metadata[1]
            })
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)
Code example #6
# `Linear` is torch_geometric.nn.Linear; `PTLinear` is torch.nn.Linear.
@pytest.mark.parametrize('lazy', [True, False])
def test_identical_linear_default_initialization(lazy):
    x = torch.randn(3, 4, 16)

    torch.manual_seed(12345)
    lin1 = Linear(-1 if lazy else 16, 32)
    lin1(x)

    torch.manual_seed(12345)
    lin2 = PTLinear(16, 32)

    assert lin1.weight.tolist() == lin2.weight.tolist()
    assert lin1.bias.tolist() == lin2.bias.tolist()
    assert lin1(x).tolist() == lin2(x).tolist()
Code example #7
@pytest.mark.parametrize('lazy', [True, False])
def test_copy_linear(lazy):
    lin = Linear(-1 if lazy else 16, 32)

    copied_lin = copy.copy(lin)
    assert id(copied_lin) != id(lin)
    assert id(copied_lin.weight) == id(lin.weight)
    if not isinstance(copied_lin.weight, UninitializedParameter):
        assert copied_lin.weight.data_ptr() == lin.weight.data_ptr()
    assert id(copied_lin.bias) == id(lin.bias)
    assert copied_lin.bias.data_ptr() == lin.bias.data_ptr()

    copied_lin = copy.deepcopy(lin)
    assert id(copied_lin) != id(lin)
    assert id(copied_lin.weight) != id(lin.weight)
    if not isinstance(copied_lin.weight, UninitializedParameter):
        assert copied_lin.weight.data_ptr() != lin.weight.data_ptr()
        assert copied_lin.weight.tolist() == lin.weight.tolist()
    assert id(copied_lin.bias) != id(lin.bias)
    assert copied_lin.bias.data_ptr() != lin.bias.data_ptr()
    assert copied_lin.bias.tolist() == lin.bias.tolist()
Code example #8
# Initializer choices below are illustrative; any supported values work.
@pytest.mark.parametrize('weight', ['glorot', None])
@pytest.mark.parametrize('bias', ['zeros', None])
def test_lazy_linear(weight, bias):
    x = torch.randn(3, 4, 16)
    lin = Linear(-1, 32, weight_initializer=weight, bias_initializer=bias)
    assert str(lin) == 'Linear(-1, 32, bias=True)'
    assert lin(x).size() == (3, 4, 32)
    assert str(lin) == 'Linear(16, 32, bias=True)'
Code example #9
    def __init__(self, out_channels):
        super().__init__(aggr='add')
        self.lin = Linear(-1, out_channels)
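The lazy Linear(-1, out_channels) is materialized on the first propagation; a minimal message/forward pair for this MessagePassing module (a sketch, not shown in the excerpt) would be:

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # The first call infers the lazy Linear's input size from x_j:
        return self.lin(x_j)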
Code example #10
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(3, 16, normalize=False)
        self.conv2 = GCNConv(16, 16, normalize=False)
        self.lin = Linear(16, 2)
Code example #11
train_loader = HGTLoader(data,
                         num_samples=[1024] * 4,
                         shuffle=True,
                         input_nodes=train_input_nodes,
                         **kwargs)
val_loader = HGTLoader(data,
                       num_samples=[1024] * 4,
                       input_nodes=val_input_nodes,
                       **kwargs)

model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (Linear(-1, dataset.num_classes), 'x -> x'),
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)


@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device, 'edge_index')
    model(batch.x_dict, batch.edge_index_dict)


def train():
    model.train()
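The excerpt breaks off inside train(); a typical continuation for this loader/model pairing looks like the sketch below ('paper' and `optimizer` are assumed names, not shown above):

    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        loss = F.cross_entropy(out, batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()
        total_examples += batch_size
        total_loss += float(loss) * batch_size
    return total_loss / total_examples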
Code example #12
# Parametrization assumed: -1 exercises the lazy branch, 16 the eager one.
@pytest.mark.parametrize('dim1', [-1, 16])
@pytest.mark.parametrize('dim2', [-1, 16])  # `dim2` is unused in this excerpt
def test_load_lazy_linear(dim1, dim2):
    lin1 = Linear(dim1, 32)
    lin2 = Linear(dim1, 32)
    lin2.load_state_dict(lin1.state_dict())

    if dim1 != -1:
        assert torch.allclose(lin1.weight, lin2.weight)
        assert torch.allclose(lin1.bias, lin2.bias)
        assert not hasattr(lin1, '_hook')
        assert not hasattr(lin2, '_hook')
    else:
        assert isinstance(lin1.weight, UninitializedParameter)
        assert isinstance(lin2.weight, UninitializedParameter)
        assert hasattr(lin1, '_hook')
        assert hasattr(lin2, '_hook')

    with pytest.raises(RuntimeError, match="in state_dict"):
        lin1.load_state_dict({}, strict=True)
    lin1.load_state_dict({}, strict=False)