class RGCNConv(torch.nn.Module):
    """Relational graph convolution for graphs with typed nodes and edges.

    Every node type owns a biased "root" linear layer applied to its own
    features, and every edge type owns a bias-free linear layer applied to
    the messages aggregated along that relation.
    """

    def __init__(self, in_channels, out_channels, node_types, edge_types):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # `ModuleDict` only accepts string keys, so each edge type is
        # registered under its first three entries joined by underscores.
        relation_layers = {}
        for edge_type in edge_types:
            name = f'{edge_type[0]}_{edge_type[1]}_{edge_type[2]}'
            relation_layers[name] = Linear(in_channels, out_channels,
                                           bias=False)
        self.rel_lins = ModuleDict(relation_layers)

        root_layers = {}
        for node_type in node_types:
            root_layers[node_type] = Linear(in_channels, out_channels,
                                            bias=True)
        self.root_lins = ModuleDict(root_layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all relation layers, then all root layers."""
        for layers in (self.rel_lins, self.root_lins):
            for lin in layers.values():
                lin.reset_parameters()

    def forward(self, x_dict, adj_t_dict):
        """Compute updated features for every node type in ``x_dict``.

        Args:
            x_dict: Mapping from node type to its feature tensor.
            adj_t_dict: Mapping from edge-type key to an adjacency object
                exposing ``matmul(x, reduce='mean')``; ``key[0]`` names the
                node type whose features are aggregated and ``key[2]`` the
                node type that receives the message.

        Returns:
            Dict mapping each node type of ``x_dict`` to its output tensor.
        """
        # Start from the per-node-type root transformation.
        out_dict = {
            node_type: self.root_lins[node_type](feat)
            for node_type, feat in x_dict.items()
        }

        for edge_type, adj_t in adj_t_dict.items():
            name = f'{edge_type[0]}_{edge_type[1]}_{edge_type[2]}'
            src_feat = x_dict[edge_type[0]]
            rel_lin = self.rel_lins[name]
            message = rel_lin(adj_t.matmul(src_feat, reduce='mean'))
            # Accumulate in place onto the receiving node type's output.
            out_dict[edge_type[2]].add_(message)

        return out_dict
class QUETCHEncoder(MetaModule):
    """Window-based encoder over aligned source/target token embeddings.

    For every token it builds a sliding window of embeddings on its own side
    and concatenates it with the mean window of the tokens aligned to it on
    the other side.
    """

    class Config(BaseConfig):
        window_size: int = 3
        """Size of sliding window."""

        embeddings: InputEmbeddingsConfig

    def __init__(self, vocabs: Dict[str, Vocabulary], config: Config,
                 pre_load_model: bool = True):
        super().__init__(config=config)

        self.embeddings = ModuleDict()
        # Target embeddings are registered first, then source embeddings.
        sides = (
            (const.TARGET, config.embeddings.target),
            (const.SOURCE, config.embeddings.source),
        )
        for side, embeddings_config in sides:
            vocab = vocabs[side]
            self.embeddings[side] = TokenEmbeddings(
                num_embeddings=len(vocab),
                pad_idx=vocab.pad_id,
                config=embeddings_config,
                vectors=vocab.vectors,
            )

        # Both sides report the same size: the summed embedding sizes times
        # the window width.
        total_size = sum(emb.size() for emb in self.embeddings.values())
        windowed_size = total_size * self.config.window_size
        self._sizes = {const.TARGET: windowed_size, const.SOURCE: windowed_size}

    @classmethod
    def input_data_encoders(cls, config: Config):
        # Returning None selects the defaults, i.e., TextEncoder.
        return None

    def size(self, field=None):
        """Return the size for ``field``, or the whole size mapping."""
        return self._sizes[field] if field else self._sizes

    @staticmethod
    def _mean_over_aligned(alignments, hidden):
        """Average ``hidden`` over aligned positions given by ``alignments``.

        Rows without any alignment are divided by one instead of zero.
        """
        summed = torch.matmul(alignments, hidden)
        counts = alignments.sum(dim=2, keepdim=True)
        counts[counts == 0] = 1.0
        return summed / counts

    def forward(self, batch_inputs):
        """Build target-side and source-side features for a batch.

        Returns:
            Dict with ``const.TARGET`` and ``const.SOURCE`` feature tensors.
        """
        target_emb = self.embeddings[const.TARGET](batch_inputs[const.TARGET])
        source_emb = self.embeddings[const.SOURCE](batch_inputs[const.SOURCE])

        # Optional POS embeddings are appended along the feature dimension.
        if const.TARGET_POS in self.embeddings:
            pos_emb, _ = self.embeddings[const.TARGET_POS](
                batch_inputs[const.TARGET_POS])
            target_emb = torch.cat((target_emb, pos_emb), dim=-1)
        if const.SOURCE_POS in self.embeddings:
            pos_emb, _ = self.embeddings[const.SOURCE_POS](
                batch_inputs[const.SOURCE_POS])
            source_emb = torch.cat((source_emb, pos_emb), dim=-1)

        # (bs, source_steps, target_steps)
        alignments = batch_inputs[const.ALIGNMENTS]

        # The embeddings may have more timesteps than the alignment matrix
        # when the last words are not aligned; pad the matrix with zeros.
        missing_source = source_emb.size(1) - alignments.size(1)
        missing_target = target_emb.size(1) - alignments.size(2)
        pad = [0, max(missing_target, 0), 0, max(missing_source, 0)]
        if any(pad):
            alignments = F.pad(alignments, pad=pad, value=0)

        window_size = self.config.window_size
        h_target = convolve_tensor(
            target_emb,
            window_size,
            pad_value=self.embeddings[const.TARGET].pad_idx,
        )
        h_source = convolve_tensor(
            source_emb,
            window_size,
            pad_value=self.embeddings[const.SOURCE].pad_idx,
        )
        # Flatten everything after the first two dimensions.
        h_target = h_target.contiguous().view(
            h_target.shape[0], h_target.shape[1], -1)
        h_source = h_source.contiguous().view(
            h_source.shape[0], h_source.shape[1], -1)

        # Target side:
        # (bs, target_steps, source_steps) x (bs, source_steps, *)
        #   -> (bs, target_steps, *)
        alignments_t = alignments.transpose(1, 2).float()
        h_source_to_target = self._mean_over_aligned(alignments_t, h_source)
        features_target = torch.cat((h_source_to_target, h_target), dim=-1)

        # Source side:
        # (bs, source_steps, target_steps) x (bs, target_steps, *)
        #   -> (bs, source_steps, *)
        alignments = alignments.float()
        h_target_to_source = self._mean_over_aligned(alignments, h_target)
        features_source = torch.cat((h_source, h_target_to_source), dim=-1)

        # (bs, ts, window * emb) -> (bs, ts, 2 * window * emb)
        return {
            const.TARGET: features_target,
            const.SOURCE: features_source,
        }
class SentenceEmbeddings(Module):
    """Combine word, POS-tag, character and plugin embeddings per token.

    The enabled feature embeddings are either concatenated along the last
    dimension (``mode="concat"``) or summed (``mode="add"``), and optionally
    passed through a layer norm.
    """

    @dataclass
    class Options(OptionsBase):
        # The string annotations below are kept exactly as in the original;
        # NOTE(review): they presumably serve as help text for OptionsBase.
        dim_word: "word embedding dim" = 100
        dim_postag: "postag embedding dim. 0 for not using postag" = 100
        dim_char_input: "character embedding input dim" = 100
        dim_char: "character embedding dim. 0 for not using character" = 100
        word_dropout: "word embedding dropout" = 0.4
        postag_dropout: "postag embedding dropout" = 0.2
        character_embedding: CharacterEmbedding.Options = field(
            default_factory=CharacterEmbedding.Options)
        input_layer_norm: "Use layer norm on input embeddings" = True
        mode: str = argfield("concat", choices=["add", "concat"])
        replace_unk_with_chars: bool = False

    def __init__(self,
                 hparams: "SentenceEmbeddings.Options",
                 statistics,
                 plugins=None):
        """Create the embedding tables selected by ``hparams``.

        Args:
            hparams: Options controlling which embeddings exist and how they
                are combined.
            statistics: Object providing ``words``, ``postags`` and
                ``characters``; their lengths size the embedding tables.
            plugins: Optional mapping of name to module; each module must
                expose ``output_dim``.

        Raises:
            ValueError: In ``"add"`` mode when the enabled features do not
                all share one dimension.
        """
        super().__init__()
        self.hparams = hparams
        self.mode = hparams.mode
        self.plugins = ModuleDict(plugins) if plugins is not None else {}

        # Dimension contributed by each enabled feature.
        input_dims = {}
        if hparams.dim_word != 0:
            self.word_embeddings = Embedding(
                len(statistics.words), hparams.dim_word, padding_idx=0)
            self.word_dropout = FeatureDropout2(hparams.word_dropout)
            input_dims["word"] = hparams.dim_word

        if hparams.dim_postag != 0:
            self.pos_embeddings = Embedding(
                len(statistics.postags), hparams.dim_postag, padding_idx=0)
            self.pos_dropout = FeatureDropout2(hparams.postag_dropout)
            input_dims["postag"] = hparams.dim_postag

        if hparams.dim_char > 0:
            self.bilm = None
            self.character_lookup = Embedding(
                len(statistics.characters), hparams.dim_char_input)
            self.char_embeded = CharacterEmbedding.get(
                hparams.character_embedding,
                dim_char_input=hparams.dim_char_input,
                input_size=hparams.dim_char)
            if not hparams.replace_unk_with_chars:
                input_dims["char"] = hparams.dim_char
            else:
                # Character vectors overwrite word vectors in place, so the
                # two dimensions have to agree.
                assert hparams.dim_word == hparams.dim_char
        else:
            self.character_lookup = None

        for name, plugin in self.plugins.items():
            input_dims[name] = plugin.output_dim

        if hparams.mode == "concat":
            self.output_dim = sum(input_dims.values())
        else:
            assert hparams.mode == "add"
            uniq_input_dims = list(set(input_dims.values()))
            if len(uniq_input_dims) != 1:
                raise ValueError(f"Different input dims: {input_dims}")
            print(input_dims)
            self.output_dim = uniq_input_dims[0]

        self.input_layer_norm = LayerNorm(self.output_dim, eps=1e-6) \
            if hparams.input_layer_norm else None

    def reset_parameters(self):
        """Xavier-initialise every embedding table that was created."""
        # `word_embeddings` only exists when word embeddings are enabled;
        # touching it unconditionally raised AttributeError for dim_word == 0.
        if self.hparams.dim_word != 0:
            torch.nn.init.xavier_normal_(self.word_embeddings.weight.data)
        if self.hparams.dim_postag != 0:
            torch.nn.init.xavier_normal_(self.pos_embeddings.weight.data)
        if self.character_lookup is not None:
            torch.nn.init.xavier_normal_(self.character_lookup.weight.data)

    def forward(self, inputs, unk_idx=1):
        """Embed a batch and combine all enabled features.

        Args:
            inputs: Batch object providing ``words`` and, depending on the
                enabled features, ``chars``, ``word_lengths`` and
                ``postags``; it is also passed to every plugin.
            unk_idx: Word index treated as "unknown" when
                ``replace_unk_with_chars`` is enabled.

        Returns:
            The concatenated or summed embeddings, layer-normalised when
            ``input_layer_norm`` is enabled.
        """
        all_features = []

        use_chars = self.character_lookup is not None
        if use_chars:
            # batch_size, bucket_size, word_length, embedding_dims
            char_embeded_4d = self.character_lookup(inputs.chars)
            word_embeded_by_char = self.char_embeded(inputs.word_lengths,
                                                     char_embeded_4d)
            if not self.hparams.replace_unk_with_chars:
                all_features.append(word_embeded_by_char)

        if self.hparams.dim_word != 0:
            word_embedding = self.word_dropout(
                self.word_embeddings(inputs.words))
            # Guard on the lookup table itself (the same condition under
            # which `word_embeded_by_char` was computed above) so that it
            # can never be read while unbound.
            if use_chars and self.hparams.replace_unk_with_chars:
                unk = inputs.words.eq(unk_idx)
                word_embedding[unk] = word_embeded_by_char[unk]
            all_features.append(word_embedding)

        if self.hparams.dim_postag != 0:
            all_features.append(
                self.pos_dropout(self.pos_embeddings(inputs.postags)))

        for plugin in self.plugins.values():
            plugin_output = plugin(inputs)
            # FIXME: remove this ugly tweak. A plugin output that is two
            # steps longer than the word sequence loses its first and last
            # step.
            if plugin_output.shape[1] == inputs.words.shape[1] + 2:
                plugin_output = plugin_output[:, 1:-1]
            all_features.append(plugin_output)

        if self.mode == "concat":
            total_input_embeded = torch.cat(all_features, -1)
        else:
            total_input_embeded = sum(all_features)

        if self.input_layer_norm is not None:
            total_input_embeded = self.input_layer_norm(total_input_embeded)

        return total_input_embeded
class HeteroConv(Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.
    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], Module]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (string, optional): The aggregation scheme to use for grouping
            node embeddings generated by different relations.
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`None`). (default: :obj:`"sum"`)
    """
    def __init__(self, convs: Dict[EdgeType, Module],
                 aggr: Optional[str] = "sum"):
        super().__init__()

        # Node types that only ever act as message sources are never written
        # to the output, which is easy to overlook; warn about them.
        src_node_types = {key[0] for key in convs}
        dst_node_types = {key[-1] for key in convs}
        if len(src_node_types - dst_node_types) > 0:
            warnings.warn(
                f"There exist node types ({src_node_types - dst_node_types}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behaviour.")

        # `ModuleDict` keys must be strings, so edge-type tuples are joined
        # with a double underscore.
        self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr

    def reset_parameters(self):
        """Reset the parameters of every wrapped layer."""
        for conv in self.convs.values():
            conv.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.

        Raises:
            ValueError: If a keyword argument name does not end with
                :obj:`"_dict"`.
        """
        # Every keyword is forwarded under its name without the `_dict`
        # suffix. Validate the suffix once up front instead of blindly
        # cutting off the last five characters of an arbitrary name.
        kwarg_names = {}
        for arg in kwargs_dict:
            if not arg.endswith('_dict'):
                raise ValueError(
                    f"Keyword arguments in '{self.__class__.__name__}' need "
                    f"to end with '_dict' (got '{arg}')")
            kwarg_names[arg] = arg[:-5]  # `{*}_dict`

        out_dict = defaultdict(list)
        for edge_type, edge_index in edge_index_dict.items():
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            # A value dictionary may be keyed by edge type or by node type;
            # node-type values are passed as a (src, dst) pair unless the
            # relation connects a node type to itself.
            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                name = kwarg_names[arg]
                if edge_type in value_dict:
                    kwargs[name] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[name] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[name] = (value_dict.get(src, None),
                                    value_dict.get(dst, None))

            conv = self.convs[str_edge_type]

            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
                           **kwargs)

            out_dict[dst].append(out)

        # Merge the per-relation outputs of each destination node type.
        for key, value in out_dict.items():
            out_dict[key] = group(value, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'
class HeteroConv(Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.
    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv(-1, 64),
            ('paper', 'written_by', 'author'): GATConv(-1, 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], Module]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (string, optional): The aggregation scheme to use for grouping
            node embeddings generated by different relations.
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
            (default: :obj:`"sum"`)
    """
    def __init__(self, convs: Dict[EdgeType, Module], aggr: str = "sum"):
        super().__init__()
        # Keep the original tuple keys; `ModuleDict` only accepts strings, so
        # the layers themselves are stored under '__'-joined names.
        self.keys = list(convs.keys())
        self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr

    def reset_parameters(self):
        """Reset the parameters of every wrapped layer."""
        for conv in self.convs.values():
            conv.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Union[Dict[EdgeType, Tensor],
                               Dict[EdgeType, Tensor]],
        edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> Dict[NodeType, Tensor]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type.
            edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional):
                A dictionary holding one-dimensional edge weight information
                for individual edge types. (default: :obj:`None`)
            edge_attr_dict (Dict[Tuple[str, str, str], Tensor], optional): A
                dictionary holding multi-dimensional edge feature information
                for individual edge types. (default: :obj:`None`)
        """
        out_dict = defaultdict(list)

        for key in self.keys:
            src, _, dst = key
            conv = self.convs['__'.join(key)]

            # Only forward the optional edge arguments that were actually
            # supplied for this edge type.
            kwargs = {}
            if edge_weight_dict is not None and key in edge_weight_dict:
                kwargs['edge_weight'] = edge_weight_dict[key]
            # Each dictionary is checked against itself: testing
            # `edge_weight_dict` here dropped edge attributes whenever no
            # weights were given, and raised TypeError when only weights
            # were given.
            if edge_attr_dict is not None and key in edge_attr_dict:
                kwargs['edge_attr'] = edge_attr_dict[key]

            if src == dst:
                out = conv(x=x_dict[src], edge_index=edge_index_dict[key],
                           **kwargs)
            else:
                out = conv(x=(x_dict[src], x_dict[dst]),
                           edge_index=edge_index_dict[key], **kwargs)

            out_dict[dst].append(out)

        # Merge the per-relation outputs of each destination node type.
        for key, values in out_dict.items():
            out_dict[key] = group(values, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'