Ejemplo n.º 1
0
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, num_edge_types):
        """Create the node-type embeddings and the stack of `RGCNConv` layers.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Stored on the instance as `self.dropout`.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            num_edge_types: Number of edge types, forwarded to every layer.
        """
        super(RGCN, self).__init__()

        self.num_layers = num_layers
        self.dropout = dropout
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels

        type_keys = list(num_nodes_dict.keys())
        type_count = len(type_keys)
        self.num_node_types = type_count
        self.num_edge_types = num_edge_types

        # Node types without input features get a learnable embedding table.
        # The tables are allocated uninitialised; `reset_parameters` fills them.
        featureless = set(type_keys).difference(set(x_types))
        embedding_tables = {
            f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in featureless
        }
        self.emb_dict = ParameterDict(embedding_tables)

        # (input width, output width) of each message passing layer, in order.
        widths = [(in_channels, hidden_channels)]
        widths += [(hidden_channels, hidden_channels)] * (num_layers - 2)
        widths.append((hidden_channels, out_channels))

        self.convs = ModuleList()
        for fan_in, fan_out in widths:
            self.convs.append(
                RGCNConv(fan_in, fan_out, type_count, num_edge_types))

        self.reset_parameters()
    def __init__(self,
                 feature_extractor_params,
                 l2=0.1,
                 gamma=0.1,
                 kernel='linear',
                 regularize_phi=False,
                 fixe_hps=False,
                 do_cv=False,
                 device='cuda'):
        """Build the feature extractor and the `l2` / kernel hyper-parameters.

        Args:
            feature_extractor_params: Keyword arguments forwarded to the
                callable returned by `FeaturesExtractorFactory()`.
            l2: Initial value of the `l2` tensor.
            gamma: Initial value of the `gamma` kernel entry (only used for
                the 'rbf' and 'sm' kernels).
            kernel: Kernel name. 'rbf' and 'sm' get a `gamma` entry in
                `kernel_params`; any other value gets an empty dict.
            regularize_phi: Stored on the instance unchanged.
            fixe_hps: When true, hyper-parameters are plain tensors rather
                than trainable `Parameter`s.
            do_cv: When true, `l2` is a plain tensor and `kernel_params` is
                not created by this constructor.
            device: Device the hyper-parameter tensors are moved to.
        """
        super(MetaKrrSingleKernelNetwork, self).__init__()

        def _scalar_tensor(value):
            # One-element float tensor placed on the requested device.
            return torch.FloatTensor([value]).to(device)

        extractor_builder = FeaturesExtractorFactory()
        self.feature_extractor = extractor_builder(**feature_extractor_params)
        self.kernel = kernel
        self.do_cv = do_cv
        self.regularize_phi = regularize_phi
        self.device = device
        self.fixe_hps = fixe_hps

        # `l2` is trainable only when it is neither fixed nor cross-validated.
        if fixe_hps or do_cv:
            self.l2 = _scalar_tensor(l2)
            if do_cv and fixe_hps:
                self.l2s = _scalar_tensor(l2)
        else:
            self.l2 = Parameter(_scalar_tensor(l2))

        if not do_cv:
            if kernel in ('rbf', 'sm'):
                # TODO: the 'sm' kernel is unfinished; for now it is set up
                # exactly like 'rbf'.
                gamma_value = _scalar_tensor(gamma)
                if fixe_hps:
                    self.kernel_params = dict(gamma=gamma_value)
                else:
                    self.kernel_params = ParameterDict(
                        dict(gamma=Parameter(gamma_value)))
            else:
                self.kernel_params = dict()

        self.phis_norms = []
Ejemplo n.º 3
0
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, num_edge_types, args):
        """Create the node-type embeddings, the conv stack and optional norm.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Stored on the instance as `self.dropout`.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            num_edge_types: Number of edge types, forwarded to every layer.
            args: Configuration object. This constructor reads
                `args.conv_name` and `args.Norm4`, and passes `args` on to
                every conv layer.
        """
        super(RGNN, self).__init__()

        self.args = args
        self.num_layers = num_layers
        self.dropout = dropout
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels

        type_keys = list(num_nodes_dict.keys())
        type_count = len(type_keys)
        self.num_node_types = type_count
        self.num_edge_types = num_edge_types

        # Node types without input features get a learnable embedding table.
        featureless = set(type_keys).difference(set(x_types))
        embedding_tables = {
            f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in featureless
        }
        self.emb_dict = ParameterDict(embedding_tables)

        # 'rgcn' selects RGCNConv; every other name falls back to RGSNConv.
        conv_cls = RGCNConv if self.args.conv_name == 'rgcn' else RGSNConv

        # (input width, output width) of each message passing layer, in order.
        widths = [(in_channels, hidden_channels)]
        widths += [(hidden_channels, hidden_channels)] * (num_layers - 2)
        widths.append((hidden_channels, out_channels))

        self.convs = ModuleList()
        for fan_in, fan_out in widths:
            self.convs.append(
                conv_cls(fan_in, fan_out, type_count, num_edge_types,
                         self.args))

        if self.args.Norm4:
            self.norm = torch.nn.LayerNorm(in_channels)

        self.reset_parameters()
Ejemplo n.º 4
0
 def param_dict(self) -> ParameterDict:
     """Return the wrapped module's parameters under flattened names.

     The result is empty unless `self.add_module_params_to_process` is truthy.
     Each name from `self.nn_module.named_parameters()` has its dots replaced
     by underscores and is prefixed with ``module_``.
     """
     collected = ParameterDict()
     if not self.add_module_params_to_process:
         return collected
     for raw_name, tensor in self.nn_module.named_parameters():
         collected['module_' + raw_name.replace('.', '_')] = tensor
     return collected
Ejemplo n.º 5
0
 def __init__(self, rank: int):
     """Create small random parameters for the two entries of `_param_dict`.

     Args:
         rank: Forwarded to the parent constructor; also the length of the
             ``cholesky_log_diag`` parameter.
     """
     super().__init__(rank=rank)
     # rank * (rank - 1) / 2 entries go into the off-diagonal parameter.
     off_diag_count = int(rank * (rank - 1) / 2)
     params = ParameterDict()
     # Both parameters start as standard-normal noise scaled by 0.01.
     params['cholesky_log_diag'] = Parameter(data=torch.randn(rank) * .01)
     params['cholesky_off_diag'] = Parameter(
         data=torch.randn(off_diag_count) * .01)
     self._param_dict = params
Ejemplo n.º 6
0
class RGCN(torch.nn.Module):
    """Stack of `RGCNConv` layers operating on per-node-type feature dicts.

    Node types listed in `x_types` are expected to arrive with features in
    `forward`; every other node type gets a learnable embedding table.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, edge_types):
        """Create the embedding tables and the conv stack.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Dropout probability applied after every layer but the
                last.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            edge_types: Edge types, forwarded to every conv layer.
        """
        super(RGCN, self).__init__()

        node_types = list(num_nodes_dict.keys())

        # Allocated uninitialised; `reset_parameters` fills them in.
        self.embs = ParameterDict({
            key: Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in set(node_types).difference(set(x_types))
        })

        self.convs = ModuleList()
        self.convs.append(
            RGCNConv(in_channels, hidden_channels, node_types, edge_types))
        for _ in range(num_layers - 2):
            self.convs.append(
                RGCNConv(hidden_channels, hidden_channels, node_types,
                         edge_types))
        self.convs.append(
            RGCNConv(hidden_channels, out_channels, node_types, edge_types))

        self.dropout = dropout

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the embeddings and reset every conv layer."""
        for emb in self.embs.values():
            torch.nn.init.xavier_uniform_(emb)
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x_dict, adj_t_dict):
        """Run all layers and return the last layer's per-type output dict.

        Args:
            x_dict: Mapping from node type to its feature tensor. It is
                shallow-copied, so the caller's dict is not modified.
            adj_t_dict: Passed unchanged to every conv layer.

        Returns:
            Whatever the last conv layer returns for `(x_dict, adj_t_dict)`.
        """
        x_dict = copy.copy(x_dict)
        for key, emb in self.embs.items():
            x_dict[key] = emb

        for conv in self.convs[:-1]:
            x_dict = conv(x_dict, adj_t_dict)
            for key, x in x_dict.items():
                # Dropout must consume the ReLU output; feeding it the raw
                # `x` would silently throw the activation away.
                x = F.relu(x)
                x_dict[key] = F.dropout(x,
                                        p=self.dropout,
                                        training=self.training)
        return self.convs[-1](x_dict, adj_t_dict)
Ejemplo n.º 7
0
    def param_dict(self) -> ModuleDict:
        """Gather the parameter containers of every component.

        Returns:
            A `ModuleDict` with one ``process:<name>`` entry per item of
            `self.processes`, followed by ``measure_cov``, ``init_state``
            (the initial mean under ``mean`` plus the initial covariance's
            parameters) and ``process_cov``.
        """
        collected = ModuleDict()
        for name, proc in self.processes.items():
            collected[f"process:{name}"] = proc.param_dict()

        collected['measure_cov'] = self.measure_covariance.param_dict()

        # The initial state combines the mean with the covariance parameters.
        initial_state = ParameterDict([('mean', self.init_mean_params)])
        collected['init_state'] = initial_state
        initial_state.update(self.init_covariance.param_dict().items())

        collected['process_cov'] = self.process_covariance.param_dict()

        return collected
Ejemplo n.º 8
0
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, edge_types):
        """Create the embedding tables and the stack of `RGCNConv` layers.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Stored on the instance as `self.dropout`.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            edge_types: Edge types, forwarded to every conv layer.
        """
        super(RGCN, self).__init__()

        node_types = list(num_nodes_dict.keys())

        # Node types without input features get a learnable embedding table.
        featureless = set(node_types).difference(set(x_types))
        embedding_tables = {
            key: Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in featureless
        }
        self.embs = ParameterDict(embedding_tables)

        # (input width, output width) of each layer, in order.
        widths = [(in_channels, hidden_channels)]
        widths += [(hidden_channels, hidden_channels)] * (num_layers - 2)
        widths.append((hidden_channels, out_channels))

        self.convs = ModuleList()
        for fan_in, fan_out in widths:
            self.convs.append(
                RGCNConv(fan_in, fan_out, node_types, edge_types))

        self.dropout = dropout

        self.reset_parameters()
Ejemplo n.º 9
0
class RGCN(torch.nn.Module):
    """Relational GCN over a graph whose nodes are tagged with a type id.

    Node types listed in `x_types` arrive with features; every other node
    type gets a learnable embedding table stored in `emb_dict` under the
    string form of its key.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        dropout,
        num_nodes_dict,
        x_types,
        num_edge_types,
    ):
        """Create the embedding tables and the conv stack.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Dropout probability used between layers in `forward`.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            num_edge_types: Number of edge types, forwarded to every layer.
        """
        super(RGCN, self).__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout

        node_types = list(num_nodes_dict.keys())
        num_node_types = len(node_types)

        self.num_node_types = num_node_types
        self.num_edge_types = num_edge_types

        # Create embeddings for all node types that do not come with features.
        self.emb_dict = ParameterDict({
            f"{key}": Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in set(node_types).difference(set(x_types))
        })

        I, H, O = in_channels, hidden_channels, out_channels  # noqa

        # Create `num_layers` many message passing layers.
        self.convs = ModuleList()
        self.convs.append(RGCNConv(I, H, num_node_types, num_edge_types))
        for _ in range(num_layers - 2):
            self.convs.append(RGCNConv(H, H, num_node_types, num_edge_types))
        self.convs.append(RGCNConv(H, O, num_node_types, num_edge_types))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the embeddings and reset every conv layer."""
        for emb in self.emb_dict.values():
            torch.nn.init.xavier_uniform_(emb)
        for conv in self.convs:
            conv.reset_parameters()

    def group_input(self, x_dict, node_type, local_node_idx):
        """Assemble one feature matrix covering the nodes of all types.

        Args:
            x_dict: Mapping from node type id to that type's feature tensor.
            node_type: Tensor holding the type id of every node.
            local_node_idx: Tensor holding, for every node, its row inside
                its own type's feature / embedding tensor.

        Returns:
            A `(node_type.size(0), in_channels)` tensor on `node_type`'s
            device; rows of types found in neither source stay zero.
        """
        # Create global node feature matrix.
        h = torch.zeros((node_type.size(0), self.in_channels),
                        device=node_type.device)

        for key, x in x_dict.items():
            mask = node_type == key
            h[mask] = x[local_node_idx[mask]]

        # Embedding keys are strings, so convert back before comparing.
        for key, emb in self.emb_dict.items():
            mask = node_type == int(key)
            h[mask] = emb[local_node_idx[mask]]

        return h

    def forward(self, x_dict, edge_index, edge_type, node_type,
                local_node_idx):
        """Run every layer on the grouped input and return log-probabilities.

        ReLU and dropout (with probability `self.dropout`) are applied after
        every layer except the last one.
        """
        x = self.group_input(x_dict, node_type, local_node_idx)

        # Index of the output layer. Taken from the actual stack because two
        # layers are built even when `num_layers` is smaller than 2.
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_type, node_type)
            if i != last:
                x = F.relu(x)
                # Honour the configured probability instead of a fixed 0.5.
                x = F.dropout(x, p=self.dropout, training=self.training)

        return x.log_softmax(dim=-1)

    def inference(self, x_dict, edge_index_dict, key2int):
        """Layer-wise inference over per-type feature dictionaries.

        Args:
            x_dict: Mapping from node type id to feature tensor; it is
                shallow-copied before the embeddings are added.
            edge_index_dict: Mapping from an edge-type key to a `(row, col)`
                pair. The first and last element of each key name the
                source and target node types.
            key2int: Mapping from node-type and edge-type keys to the
                integer ids used to index `root_lins` / `rel_lins`.

        Returns:
            Mapping from node type id to the output of the last layer. No
            dropout is applied on this path.
        """
        # We can perform full-batch inference on GPU.

        device = list(x_dict.values())[0].device

        x_dict = copy(x_dict)
        for key, emb in self.emb_dict.items():
            x_dict[int(key)] = emb

        adj_t_dict = {}
        for key, (row, col) in edge_index_dict.items():
            # Built with rows and columns swapped relative to the edge index.
            adj_t_dict[key] = SparseTensor(row=col, col=row).to(device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            out_dict = {}

            for j, x in x_dict.items():
                out_dict[j] = conv.root_lins[j](x)

            for keys, adj_t in adj_t_dict.items():
                src_key, target_key = keys[0], keys[-1]
                out = out_dict[key2int[target_key]]
                tmp = adj_t.matmul(x_dict[key2int[src_key]], reduce="mean")
                out.add_(conv.rel_lins[key2int[keys]](tmp))

            if i != last:
                for j in range(self.num_node_types):
                    F.relu_(out_dict[j])

            x_dict = out_dict

        return x_dict
Ejemplo n.º 10
0
 def param_dict(self) -> ParameterDict:
     """Return the decay parameter under ``decay``, or an empty dict.

     The dict is empty when `self.decay` is ``None``.
     """
     params = ParameterDict()
     if self.decay is None:
         return params
     params['decay'] = self.decay.parameter
     return params
Ejemplo n.º 11
0
class RGNN(torch.nn.Module):
    """Relational GNN whose conv layer type is chosen by `args.conv_name`.

    NOTE(review): `dropout` is stored as `self.dropout`, but `forward`
    applies dropout with a fixed probability of 0.5.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, num_edge_types, args):
        """Create the node-type embeddings, the conv stack and optional norm.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Stored on the instance as `self.dropout`.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            num_edge_types: Number of edge types, forwarded to every layer.
            args: Configuration object. The class reads `conv_name`,
                `Norm4`, `FDFT`, `feat_dir` and `device` from it and passes
                it on to every conv layer.
        """
        super(RGNN, self).__init__()

        self.args = args
        self.num_layers = num_layers
        self.dropout = dropout
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels

        type_keys = list(num_nodes_dict.keys())
        type_count = len(type_keys)
        self.num_node_types = type_count
        self.num_edge_types = num_edge_types

        # Node types without input features get a learnable embedding table.
        featureless = set(type_keys).difference(set(x_types))
        embedding_tables = {
            f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in featureless
        }
        self.emb_dict = ParameterDict(embedding_tables)

        # 'rgcn' selects RGCNConv; every other name falls back to RGSNConv.
        conv_cls = RGCNConv if self.args.conv_name == 'rgcn' else RGSNConv

        # (input width, output width) of each message passing layer, in order.
        widths = [(in_channels, hidden_channels)]
        widths += [(hidden_channels, hidden_channels)] * (num_layers - 2)
        widths.append((hidden_channels, out_channels))

        self.convs = ModuleList()
        for fan_in, fan_out in widths:
            self.convs.append(
                conv_cls(fan_in, fan_out, type_count, num_edge_types,
                         self.args))

        if self.args.Norm4:
            self.norm = torch.nn.LayerNorm(in_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise embeddings, conv layers and the optional norm.

        With `args.FDFT` set, each embedding is overwritten by features
        loaded from a `.npy` file under `args.feat_dir`; otherwise it is
        Xavier-initialised.

        NOTE(review): embeddings and files are paired positionally, so this
        relies on `emb_dict` iterating in author / field_of_study /
        institution order - confirm against the caller.
        """
        feature_files = (
            './author_FEAT.npy',
            './field_of_study_FEAT.npy',
            './institution_FEAT.npy',
        )
        base_dir = self.args.feat_dir
        feature_paths = [os.path.join(base_dir, name) for name in feature_files]

        for emb, path in zip(self.emb_dict.values(), feature_paths):
            if not self.args.FDFT:
                torch.nn.init.xavier_uniform_(emb)
                continue
            loaded = torch.Tensor(np.load(path))
            emb.data = loaded.to(self.args.device)
            emb.requires_grad = True

        for conv in self.convs:
            conv.reset_parameters()

        if self.args.Norm4:
            self.norm.reset_parameters()

    def group_input(self,
                    x_dict,
                    node_type,
                    local_node_idx,
                    n_id=None,
                    perturb=None):
        """Assemble one feature matrix covering the nodes of all types.

        Args:
            x_dict: Mapping from node type key to that type's features.
            node_type: Tensor holding the type id of every node.
            local_node_idx: Tensor holding each node's row inside its own
                type's feature / embedding tensor.
            n_id: Optional node index; when given, `node_type` and
                `local_node_idx` are restricted to it first.
            perturb: Optional value added to the rows taken from `x_dict`
                (rows taken from the embeddings are not perturbed).

        Returns:
            A `(len(node_type), in_channels)` tensor on `node_type`'s device.
        """
        if n_id is not None:
            node_type = node_type[n_id]
            local_node_idx = local_node_idx[n_id]

        grouped = torch.zeros((node_type.size(0), self.in_channels),
                              device=node_type.device)

        for key, feats in x_dict.items():
            selected = node_type == int(key)
            rows = feats[local_node_idx[selected]]
            if perturb is not None:
                rows = rows + perturb
            grouped[selected] = rows

        for key, table in self.emb_dict.items():
            selected = node_type == int(key)
            grouped[selected] = table[local_node_idx[selected]]

        return grouped

    def forward(self,
                n_id,
                x_dict,
                adjs,
                edge_type,
                node_type,
                local_node_idx,
                perturb=None):
        """Run the sampled mini-batch through all layers.

        Args:
            n_id: Index of the nodes taking part in this batch.
            x_dict: Mapping from node type key to feature tensor.
            adjs: One `(edge_index, e_id, size)` triple per layer.
            edge_type: Edge type of every edge, indexed by `e_id`.
            node_type: Node type of every node, indexed by `n_id`.
            local_node_idx: Per-type row index of every node.
            perturb: Optional perturbation forwarded to `group_input`.

        Returns:
            Log-softmax over the last dimension of the final layer's output.
        """
        h = self.group_input(x_dict, node_type, local_node_idx, n_id, perturb)

        if self.args.FDFT:
            h = F.dropout(h, p=0.5, training=self.training)

        if self.args.Norm4:
            h = self.norm(h)

        node_type = node_type[n_id]

        for depth, (edge_index, e_id, size) in enumerate(adjs):
            target_count = size[1]
            h_target = h[:target_count]  # Target node embeddings.
            source_types = node_type
            node_type = node_type[:target_count]  # Target node types.

            layer = self.convs[depth]
            h = layer((h, h_target), edge_index, edge_type[e_id], node_type,
                      source_types)
            if depth == self.num_layers - 1:
                continue
            h = F.relu(h)
            h = F.dropout(h, p=0.5, training=self.training)

        return h.log_softmax(dim=-1)
Ejemplo n.º 12
0
 def __init__(self, rank: int):
     """Create a small random ``log_std_devs`` parameter of length `rank`.

     Args:
         rank: Forwarded to the parent constructor; also the length of the
             ``log_std_devs`` parameter.
     """
     super().__init__(rank=rank)
     params = ParameterDict()
     # Starts as standard-normal noise scaled by 0.01.
     params['log_std_devs'] = Parameter(data=torch.randn(rank) * .01)
     self._param_dict = params
Ejemplo n.º 13
0
 def param_dict(self) -> ParameterDict:
     """Return the decay parameters of the transitions that have one.

     Only the ``position`` and ``velocity`` keys are considered, in that
     order; a key is included when it is present in
     `self.decayed_transitions`.
     """
     transitions = self.decayed_transitions
     found = ParameterDict()
     for name in ('position', 'velocity'):
         if name not in transitions:
             continue
         found[name] = transitions[name].parameter
     return found
Ejemplo n.º 14
0
class RGCN(torch.nn.Module):
    """Relational GCN with sampled, full-batch and grouped inference paths.

    Node types listed in `x_types` arrive with features; every other node
    type gets a learnable embedding table stored in `emb_dict` under the
    string form of its key.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, num_edge_types):
        """Create the embedding tables and the conv stack.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Width of every intermediate layer.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Dropout probability used between layers.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            num_edge_types: Number of edge types, forwarded to every layer.
        """
        super(RGCN, self).__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout

        node_types = list(num_nodes_dict.keys())
        num_node_types = len(node_types)

        self.num_node_types = num_node_types
        self.num_edge_types = num_edge_types

        # Create embeddings for all node types that do not come with features.
        self.emb_dict = ParameterDict({
            f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in set(node_types).difference(set(x_types))
        })

        I, H, O = in_channels, hidden_channels, out_channels  # noqa

        # Create `num_layers` many message passing layers.
        self.convs = ModuleList()
        self.convs.append(RGCNConv(I, H, num_node_types, num_edge_types))
        for _ in range(num_layers - 2):
            self.convs.append(RGCNConv(H, H, num_node_types, num_edge_types))
        self.convs.append(RGCNConv(H, O, num_node_types, num_edge_types))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the embeddings and reset every conv layer."""
        for emb in self.emb_dict.values():
            torch.nn.init.xavier_uniform_(emb)
        for conv in self.convs:
            conv.reset_parameters()

    def group_input(self, x_dict, node_type, local_node_idx, n_id=None):
        """Assemble one feature matrix covering the nodes of all types.

        Args:
            x_dict: Mapping from node type id to that type's feature tensor.
            node_type: Tensor holding the type id of every node.
            local_node_idx: Tensor holding each node's row inside its own
                type's feature / embedding tensor.
            n_id: Optional node index; when given, `node_type` and
                `local_node_idx` are restricted to it first.

        Returns:
            A `(len(node_type), in_channels)` tensor on `node_type`'s device;
            rows of types found in neither source stay zero.
        """
        # Create global node feature matrix.
        if n_id is not None:
            node_type = node_type[n_id]
            local_node_idx = local_node_idx[n_id]

        h = torch.zeros((node_type.size(0), self.in_channels),
                        device=node_type.device)

        for key, x in x_dict.items():
            mask = node_type == key
            h[mask] = x[local_node_idx[mask]]

        # Embedding keys are strings, so convert back before comparing.
        for key, emb in self.emb_dict.items():
            mask = node_type == int(key)
            h[mask] = emb[local_node_idx[mask]]

        return h

    def forward(self, n_id, x_dict, adjs, edge_type, node_type,
                local_node_idx):
        """Run a sampled mini-batch through all layers.

        Args:
            n_id: Index of the nodes taking part in this batch.
            x_dict: Mapping from node type id to feature tensor.
            adjs: One `(edge_index, e_id, size)` triple per layer.
            edge_type: Edge type of every edge, indexed by `e_id`.
            node_type: Node type of every node, indexed by `n_id`.
            local_node_idx: Per-type row index of every node.

        Returns:
            Log-softmax over the last dimension of the final layer's output.
        """
        x = self.group_input(x_dict, node_type, local_node_idx, n_id)
        node_type = node_type[n_id]

        # Index of the output layer. Taken from the actual stack because two
        # layers are built even when `num_layers` is smaller than 2.
        last = len(self.convs) - 1
        for i, (edge_index, e_id, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target node embeddings.
            node_type = node_type[:size[1]]  # Target node types.
            conv = self.convs[i]
            x = conv((x, x_target), edge_index, edge_type[e_id], node_type)
            if i != last:
                x = F.relu(x)
                # Same probability as the inference paths below, rather than
                # a fixed 0.5 that ignores the constructor argument.
                x = F.dropout(x, p=self.dropout, training=self.training)

        return x.log_softmax(dim=-1)

    def inference(self, x_dict, edge_index_dict, key2int):
        """Layer-wise inference over per-type feature dictionaries.

        Args:
            x_dict: Mapping from node type id to feature tensor; it is
                shallow-copied before the embeddings are added.
            edge_index_dict: Mapping from an edge-type key to a `(row, col)`
                pair. The first and last element of each key name the
                source and target node types.
            key2int: Mapping from node-type and edge-type keys to the
                integer ids used to index `root_lins` / `rel_lins`.

        Returns:
            Mapping from node type id to the output of the last layer. No
            dropout is applied on this path.
        """
        # We can perform full-batch inference on GPU.

        device = list(x_dict.values())[0].device

        x_dict = copy(x_dict)
        for key, emb in self.emb_dict.items():
            x_dict[int(key)] = emb

        adj_t_dict = {}
        for key, (row, col) in edge_index_dict.items():
            # Built with rows and columns swapped relative to the edge index.
            adj_t_dict[key] = SparseTensor(row=col, col=row).to(device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            out_dict = {}

            for j, x in x_dict.items():
                out_dict[j] = conv.root_lins[j](x)

            for keys, adj_t in adj_t_dict.items():
                src_key, target_key = keys[0], keys[-1]
                out = out_dict[key2int[target_key]]
                tmp = adj_t.matmul(x_dict[key2int[src_key]], reduce='mean')
                out.add_(conv.rel_lins[key2int[keys]](tmp))

            if i != last:
                for j in range(self.num_node_types):
                    F.relu_(out_dict[j])

            x_dict = out_dict

        return x_dict

    def fullBatch_inference(self, x_dict, edge_index, edge_type, node_type,
                            local_node_idx):
        """Run every layer on the whole graph and return the raw output.

        The node types handed to the layers are those of the unique target
        nodes in `edge_index[1]`. No log-softmax is applied.
        """
        x = self.group_input(x_dict, node_type, local_node_idx)
        node_type = node_type[edge_index[1].unique()]
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_type, node_type)
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def group_inference(self, x_dict, edge_index_dict, key2int):
        """Layer-wise inference that processes one target node type at a time.

        For each target type, the edge indices, edge types and source
        features of every relation ending in that type are concatenated and
        passed to the layer together with the target features.

        Args:
            x_dict: Mapping from node type id to feature tensor; it is
                shallow-copied before the embeddings are added.
            edge_index_dict: Mapping from an edge-type key to its edge index.
                The first and last element of each key name the source and
                target node types.
            key2int: Mapping from node-type and edge-type keys to integer ids.

        Returns:
            Mapping from node type id to the output of the last layer.
        """
        x_dict = copy(x_dict)
        for k, emb in self.emb_dict.items():
            x_dict[int(k)] = emb
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            out_dict = dict()
            for n, x_target in x_dict.items():
                edge_index_n = []
                edge_type_n = []
                x_nbr = []
                for k, e_i in edge_index_dict.items():
                    if key2int[k[-1]] == n:
                        edge_index_n.append(e_i)
                        edge_type_n.append(
                            e_i.new_full((e_i.size(1), ), key2int[k]))
                        x_nbr.append(x_dict[key2int[k[0]]])
                edge_index_n = torch.cat(edge_index_n, dim=1)
                edge_type_n = torch.cat(edge_type_n, dim=0)
                x = torch.cat([x_target] + x_nbr, dim=0)
                node_type_n = edge_index_n.new_full((x_target.size(0), ), n)
                x_out = conv((x, x_target), edge_index_n, edge_type_n,
                             node_type_n)
                if i != last:
                    x_out = F.relu(x_out)
                    x_out = F.dropout(x_out,
                                      p=self.dropout,
                                      training=self.training)
                out_dict[n] = x_out
            x_dict = out_dict
        return x_dict
Ejemplo n.º 15
0
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 num_layers,
                 dropout,
                 neg_slope,
                 heads,
                 num_relations,
                 num_nodes_dict,
                 x_types,
                 attention='Bilevel'):
        """Create the node-type embeddings and the attention conv stack.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Per-layer hidden width.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Dropout probability, also forwarded to every layer.
            neg_slope: Negative slope forwarded to every layer.
            heads: Head count of all layers but the last, which uses 1.
            num_relations: Number of relations, forwarded to every layer.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            attention: 'Bilevel', 'Node' or 'Relation'; selects the conv
                layer class.

        Raises:
            NotImplementedError: If `attention` is not one of the three
                supported names.
        """
        super(BRGCN, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout
        # Stored so that every branch can rely on it; previously the 'Node'
        # and 'Relation' branches read this attribute without it being set.
        self.neg_slope = neg_slope
        self.heads = heads
        self.num_relations = num_relations

        node_types = list(num_nodes_dict.keys())
        num_node_types = len(node_types)

        # Allocated uninitialised; `reset_parameters` fills them in.
        self.emb_dict = ParameterDict({
            f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in set(node_types).difference(set(x_types))
        })

        if attention == 'Bilevel':
            conv_cls = BRGCNConv
        elif attention == 'Node':
            conv_cls = BRGCNConvNode
        elif attention == 'Relation':
            conv_cls = BRGCNConvRel
        else:
            raise NotImplementedError(
                f'Unsupported attention type: {attention!r}')

        def make_layer(fan_in, fan_out, layer_heads):
            # One positional order for every layer and every attention type,
            # so `heads` and `num_node_types` can no longer be swapped.
            return conv_cls(fan_in, fan_out, neg_slope, num_relations,
                            num_node_types, layer_heads, dropout)

        # NOTE(review): hidden layers take `hidden_channels` inputs while the
        # last layer takes `heads * hidden_channels`; confirm against the
        # conv layers' output width for heads > 1 and num_layers > 2.
        self.convs = ModuleList()
        self.convs.append(make_layer(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(
                make_layer(hidden_channels, hidden_channels, heads))
        self.convs.append(
            make_layer(heads * hidden_channels, out_channels, 1))

        self.reset_parameters()
Ejemplo n.º 16
0
class BRGCN(torch.nn.Module):
    """Relational GCN with a selectable attention conv layer.

    Node types listed in `x_types` arrive with features; every other node
    type gets a learnable embedding table stored in `emb_dict` under the
    string form of its key.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 num_layers,
                 dropout,
                 neg_slope,
                 heads,
                 num_relations,
                 num_nodes_dict,
                 x_types,
                 attention='Bilevel'):
        """Create the node-type embeddings and the attention conv stack.

        Args:
            in_channels: Width of the input features / learned embeddings.
            hidden_channels: Per-layer hidden width.
            out_channels: Width of the last layer's output.
            num_layers: Requested depth. One input layer and one output
                layer are always created, plus `num_layers - 2` hidden ones.
            dropout: Dropout probability, also forwarded to every layer.
            neg_slope: Negative slope forwarded to every layer.
            heads: Head count of all layers but the last, which uses 1.
            num_relations: Number of relations, forwarded to every layer.
            num_nodes_dict: Mapping from node type to its number of nodes.
            x_types: Node types that already come with input features.
            attention: 'Bilevel', 'Node' or 'Relation'; selects the conv
                layer class.

        Raises:
            NotImplementedError: If `attention` is not one of the three
                supported names.
        """
        super(BRGCN, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout
        # Stored so that every branch can rely on it; previously the 'Node'
        # and 'Relation' branches read this attribute without it being set.
        self.neg_slope = neg_slope
        self.heads = heads
        self.num_relations = num_relations

        node_types = list(num_nodes_dict.keys())
        num_node_types = len(node_types)

        # Allocated uninitialised; `reset_parameters` fills them in.
        self.emb_dict = ParameterDict({
            f'{key}': Parameter(torch.Tensor(num_nodes_dict[key], in_channels))
            for key in set(node_types).difference(set(x_types))
        })

        if attention == 'Bilevel':
            conv_cls = BRGCNConv
        elif attention == 'Node':
            conv_cls = BRGCNConvNode
        elif attention == 'Relation':
            conv_cls = BRGCNConvRel
        else:
            raise NotImplementedError(
                f'Unsupported attention type: {attention!r}')

        def make_layer(fan_in, fan_out, layer_heads):
            # One positional order for every layer and every attention type,
            # so `heads` and `num_node_types` can no longer be swapped.
            return conv_cls(fan_in, fan_out, neg_slope, num_relations,
                            num_node_types, layer_heads, dropout)

        # NOTE(review): hidden layers take `hidden_channels` inputs while the
        # last layer takes `heads * hidden_channels`; confirm against the
        # conv layers' output width for heads > 1 and num_layers > 2.
        self.convs = ModuleList()
        self.convs.append(make_layer(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(
                make_layer(hidden_channels, hidden_channels, heads))
        self.convs.append(
            make_layer(heads * hidden_channels, out_channels, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the embeddings and reset every conv layer."""
        for embedding in self.emb_dict.values():
            init.xavier_uniform_(embedding)
        for conv in self.convs:
            conv.reset_parameters()

    def group_input(self, x_dict, node_type, local_node_idx, n_id=None):
        """Assemble one feature matrix covering the nodes of all types.

        Args:
            x_dict: Mapping from node type id to that type's feature tensor.
            node_type: Tensor holding the type id of every node.
            local_node_idx: Tensor holding each node's row inside its own
                type's feature / embedding tensor.
            n_id: Optional node index; when given, `node_type` and
                `local_node_idx` are restricted to it first.

        Returns:
            A `(len(node_type), in_channels)` tensor on `node_type`'s device;
            rows of types found in neither source stay zero.
        """
        # Create global node feature matrix.
        if n_id is not None:
            node_type = node_type[n_id]
            local_node_idx = local_node_idx[n_id]

        h = torch.zeros((node_type.size(0), self.in_channels),
                        device=node_type.device)

        for key, x in x_dict.items():
            mask = node_type == key
            h[mask] = x[local_node_idx[mask]]

        # Embedding keys are strings, so convert back before comparing.
        for key, emb in self.emb_dict.items():
            mask = node_type == int(key)
            h[mask] = emb[local_node_idx[mask]]

        return h

    def forward(self, n_id, x_dict, adjs, edge_type, node_type,
                local_node_idx):
        """Run a sampled mini-batch through all layers.

        Args:
            n_id: Node index for the source nodes.
            x_dict: Node feature dictionary with the node type as key.
            adjs: One `(edge_index, e_id, size)` triple per layer.
            edge_type: The edge type of all edges, indexed by `e_id`.
            node_type: The node type of all nodes, indexed by `n_id`.
            local_node_idx: Maps each global node id to its row within its
                own node type.

        Returns:
            Log-softmax over the last dimension of the final layer's output.
        """
        x = self.group_input(x_dict, node_type, local_node_idx, n_id)
        node_type = node_type[n_id]

        # Index of the output layer. Taken from the actual stack because two
        # layers are built even when `num_layers` is smaller than 2.
        last = len(self.convs) - 1
        for i, (edge_index, e_id, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target node embeddings.
            node_type = node_type[:size[1]]  # Target node types.
            conv = self.convs[i]
            x = conv((x, x_target), edge_index, edge_type[e_id], node_type)
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        return x.log_softmax(dim=-1)

    def fullBatch_inference(self, x_dict, edge_index, edge_type, node_type,
                            local_node_idx):
        """Run every layer on the whole graph and return the raw output.

        The node types handed to the layers are those of the unique target
        nodes in `edge_index[1]`. No log-softmax is applied.
        """
        x = self.group_input(x_dict, node_type, local_node_idx)
        node_type = node_type[edge_index[1].unique()]
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_type, node_type)
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def group_inference(self, x_dict, edge_index_dict, key2int):
        """Layer-wise inference that processes one target node type at a time.

        For each target type, the edge indices, edge types, source features
        and source node types of every relation ending in that type are
        concatenated and passed to the layer together with the target
        features.

        Args:
            x_dict: Mapping from node type id to feature tensor; it is
                shallow-copied before the embeddings are added.
            edge_index_dict: Mapping from an edge-type key to its edge index.
                The first and last element of each key name the source and
                target node types.
            key2int: Mapping from node-type and edge-type keys to integer ids.

        Returns:
            Mapping from node type id to the output of the last layer.
        """
        x_dict = copy(x_dict)
        for k, emb in self.emb_dict.items():
            x_dict[int(k)] = emb
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            out_dict = dict()
            for n, x_target in x_dict.items():
                edge_index_n = []
                edge_type_n = []
                node_type_nbr = []
                x_nbr = []
                for k, e_i in edge_index_dict.items():
                    if key2int[k[-1]] == n:
                        edge_index_n.append(e_i)
                        edge_type_n.append(
                            e_i.new_full((e_i.size(1), ), key2int[k]))
                        node_type_nbr.append(
                            e_i.new_full((x_dict[key2int[k[0]]].size(0), ),
                                         key2int[k[0]]))
                        x_nbr.append(x_dict[key2int[k[0]]])
                edge_index_n = torch.cat(edge_index_n, dim=1)
                edge_type_n = torch.cat(edge_type_n, dim=0)
                x = torch.cat([x_target] + x_nbr, dim=0)
                node_type_target = edge_index_n.new_full((x_target.size(0), ),
                                                         n)
                node_type_n = torch.cat([node_type_target] + node_type_nbr,
                                        dim=0)
                x_out = conv((x, x_target), edge_index_n, edge_type_n,
                             node_type_n)
                if i != last:
                    x_out = F.relu(x_out)
                    x_out = F.dropout(x_out,
                                      p=self.dropout,
                                      training=self.training)
                out_dict[n] = x_out
            x_dict = out_dict
        return x_dict