def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
             num_layers, batch_size):
    super().__init__()
    self.batch_size = batch_size
    self.dim = hidden_channels
    # Shared edge network: maps 5-dimensional edge features to a dim x dim
    # weight matrix, reused by every NNConv below.
    nn = Sequential(Linear(5, 128), ReLU(), Linear(128, self.dim * self.dim))
    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        # One NNConv per relation; duplicate relation keys in a dict literal
        # would silently overwrite earlier entries, so each relation appears once.
        conv = HeteroConv(
            {
                ('x_i', 'inner_edge_i', 'x_i'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                ('x_j', 'inner_edge_j', 'x_j'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                ('x_i', 'outer_edge_ij', 'x_j'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
                ('x_j', 'outer_edge_ji', 'x_i'):
                    NNConv(self.dim, self.dim, nn, aggr='mean'),
            }, aggr='sum')
        self.convs.append(conv)
    self.lin = Linear(self.dim, outer_out_channels)
    self.lin_i = Linear(self.dim, inner_out_channels)
    self.lin_j = Linear(self.dim, inner_out_channels)
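# A minimal forward sketch for the NNConv-based model above. It assumes the
# caller provides x_dict / edge_index_dict / edge_attr_dict from a HeteroData
# batch (NNConv needs the 5-dimensional edge features, keyed per edge type),
# and the way the three readout heads are combined is an assumption rather
# than part of the original code.
def forward(self, x_dict, edge_index_dict, edge_attr_dict):
    for conv in self.convs:
        x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x_dict = {key: x.relu() for key, x in x_dict.items()}
    out_i = self.lin_i(x_dict['x_i'])  # per-node head for type 'x_i'
    out_j = self.lin_j(x_dict['x_j'])  # per-node head for type 'x_j'
    # "Outer" head on the combined representations (assumes matching node counts):
    out = self.lin(x_dict['x_i'] + x_dict['x_j'])
    return out, out_i, out_j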
def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
             num_layers, batch_size, num_node_types, num_heads):
    super().__init__()
    self.batch_size = batch_size
    self.hidden_channels = hidden_channels
    self.heads = num_heads
    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        # One GATv2Conv per relation (duplicate relation keys removed).
        # `concat=False` averages the attention heads so the output stays at
        # `hidden_channels` and matches the linear readouts below; self-loops
        # are disabled since they are not well-defined for the bipartite
        # 'outer' relations.
        conv = HeteroConv(
            {
                ('x_i', 'inner_edge_i', 'x_i'):
                    GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                              concat=False, add_self_loops=False),
                ('x_j', 'inner_edge_j', 'x_j'):
                    GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                              concat=False, add_self_loops=False),
                ('x_i', 'outer_edge_ij', 'x_j'):
                    GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                              concat=False, add_self_loops=False),
                ('x_j', 'outer_edge_ji', 'x_i'):
                    GATv2Conv(-1, self.hidden_channels, heads=num_heads,
                              concat=False, add_self_loops=False),
            }, aggr='sum')
        self.convs.append(conv)
    self.lin = Linear(self.hidden_channels, outer_out_channels)
    self.lin_i = Linear(self.hidden_channels, inner_out_channels)
    self.lin_j = Linear(self.hidden_channels, inner_out_channels)
def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
    super().__init__()
    self.lin_dict = torch.nn.ModuleDict()
    for node_type in data.node_types:
        self.lin_dict[node_type] = Linear(-1, hidden_channels)
    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                       num_heads, group='sum')
        self.convs.append(conv)
    self.lin = Linear(hidden_channels, out_channels)
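# A forward sketch matching the HGT __init__ above, in the style of the
# PyTorch Geometric HGT examples: features are first projected per node type,
# then passed through the HGTConv stack. The node type used for the final
# readout ('author' here) is an assumption.
def forward(self, x_dict, edge_index_dict):
    x_dict = {
        node_type: self.lin_dict[node_type](x).relu_()
        for node_type, x in x_dict.items()
    }
    for conv in self.convs:
        x_dict = conv(x_dict, edge_index_dict)
    return self.lin(x_dict['author'])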
def __init__(self, hidden_channels: int, out_channels: int, dropout: float):
    super().__init__()
    self.dropout = dropout
    self.conv1 = SAGEConv((-1, -1), hidden_channels)
    self.conv2 = SAGEConv((-1, -1), hidden_channels)
    self.lin = Linear(-1, out_channels)
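# A minimal forward sketch for the two-layer SAGE model above, assuming
# `import torch.nn.functional as F` and a plain (x, edge_index) input before
# any to_hetero() transformation; dropout uses the stored rate.
def forward(self, x, edge_index):
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.conv1(x, edge_index).relu()
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.conv2(x, edge_index).relu()
    return self.lin(x)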
def __init__(self, metadata, hidden_channels, out_channels, num_layers):
    super().__init__()
    self.convs = torch.nn.ModuleList()
    for _ in range(num_layers):
        conv = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        })
        self.convs.append(conv)
    self.lin = Linear(hidden_channels, out_channels)
def test_identical_linear_default_initialization(lazy):
    x = torch.randn(3, 4, 16)

    torch.manual_seed(12345)
    lin1 = Linear(-1 if lazy else 16, 32)
    lin1(x)

    torch.manual_seed(12345)
    lin2 = PTLinear(16, 32)

    assert lin1.weight.tolist() == lin2.weight.tolist()
    assert lin1.bias.tolist() == lin2.bias.tolist()
    assert lin1(x).tolist() == lin2(x).tolist()
def test_load_lazy_linear(dim1, dim2):
    lin1 = Linear(dim1, 32)
    lin2 = Linear(dim1, 32)
    lin2.load_state_dict(lin1.state_dict())

    if dim1 != -1:
        assert torch.allclose(lin1.weight, lin2.weight)
        assert torch.allclose(lin1.bias, lin2.bias)
        assert not hasattr(lin1, '_hook')
        assert not hasattr(lin2, '_hook')
    else:
        assert isinstance(lin1.weight, UninitializedParameter)
        assert isinstance(lin2.weight, UninitializedParameter)
        assert hasattr(lin1, '_hook')
        assert hasattr(lin2, '_hook')
def test_copy_linear(lazy):
    lin = Linear(-1 if lazy else 16, 32)

    copied_lin = copy.copy(lin)
    assert id(copied_lin) != id(lin)
    assert id(copied_lin.weight) == id(lin.weight)
    if not isinstance(copied_lin.weight, UninitializedParameter):
        assert copied_lin.weight.data_ptr() == lin.weight.data_ptr()
    assert id(copied_lin.bias) == id(lin.bias)
    assert copied_lin.bias.data_ptr() == lin.bias.data_ptr()

    copied_lin = copy.deepcopy(lin)
    assert id(copied_lin) != id(lin)
    assert id(copied_lin.weight) != id(lin.weight)
    if not isinstance(copied_lin.weight, UninitializedParameter):
        assert copied_lin.weight.data_ptr() != lin.weight.data_ptr()
        assert copied_lin.weight.tolist() == lin.weight.tolist()
    assert id(copied_lin.bias) != id(lin.bias)
    assert copied_lin.bias.data_ptr() != lin.bias.data_ptr()
    assert copied_lin.bias.tolist() == lin.bias.tolist()
def test_lazy_linear(weight, bias):
    x = torch.randn(3, 4, 16)
    lin = Linear(-1, 32, weight_initializer=weight, bias_initializer=bias)
    assert str(lin) == 'Linear(-1, 32, bias=True)'
    assert lin(x).size() == (3, 4, 32)
    assert str(lin) == 'Linear(16, 32, bias=True)'
def __init__(self, out_channels):
    super().__init__(aggr='add')
    self.lin = Linear(-1, out_channels)
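# A minimal propagate/message pair for the MessagePassing layer above, showing
# how the lazily-initialized Linear can transform neighbor features before the
# 'add' aggregation; applying self.lin inside message() is an assumption about
# the intended design.
def forward(self, x, edge_index):
    return self.propagate(edge_index, x=x)

def message(self, x_j):
    return self.lin(x_j)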
def __init__(self):
    super().__init__()
    self.conv1 = GCNConv(3, 16, normalize=False)
    self.conv2 = GCNConv(16, 16, normalize=False)
    self.lin = Linear(16, 2)
train_loader = HGTLoader(data, num_samples=[1024] * 4, shuffle=True,
                         input_nodes=train_input_nodes, **kwargs)
val_loader = HGTLoader(data, num_samples=[1024] * 4,
                       input_nodes=val_input_nodes, **kwargs)

model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (Linear(-1, dataset.num_classes), 'x -> x'),
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)


@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device, 'edge_index')
    model(batch.x_dict, batch.edge_index_dict)


def train():
    model.train()
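    # The body above is cut off here; a hedged sketch of how train() typically
    # continues in the OGB-MAG to_hetero example, assuming an `optimizer`,
    # `import torch.nn.functional as F`, and labels on the 'paper' node type.
    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        loss = F.cross_entropy(out, batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()
        total_examples += batch_size
        total_loss += float(loss) * batch_size
    return total_loss / total_examples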
def test_load_lazy_linear(dim1, dim2):
    lin1 = Linear(dim1, 32)
    lin2 = Linear(dim1, 32)
    lin2.load_state_dict(lin1.state_dict())

    if dim1 != -1:
        assert torch.allclose(lin1.weight, lin2.weight)
        assert torch.allclose(lin1.bias, lin2.bias)
        assert not hasattr(lin1, '_hook')
        assert not hasattr(lin2, '_hook')
    else:
        assert isinstance(lin1.weight, UninitializedParameter)
        assert isinstance(lin2.weight, UninitializedParameter)
        assert hasattr(lin1, '_hook')
        assert hasattr(lin2, '_hook')

    with pytest.raises(RuntimeError, match="in state_dict"):
        lin1.load_state_dict({}, strict=True)
    lin1.load_state_dict({}, strict=False)