Example #1
import torch
from torch.nn import Linear, ModuleDict


class RGCNConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, node_types, edge_types):
        super(RGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # `ModuleDict` does not allow tuples :(
        self.rel_lins = ModuleDict({
            f'{key[0]}_{key[1]}_{key[2]}': Linear(in_channels,
                                                  out_channels,
                                                  bias=False)
            for key in edge_types
        })

        self.root_lins = ModuleDict({
            key: Linear(in_channels, out_channels, bias=True)
            for key in node_types
        })

        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.rel_lins.values():
            lin.reset_parameters()
        for lin in self.root_lins.values():
            lin.reset_parameters()

    def forward(self, x_dict, adj_t_dict):
        out_dict = {}
        for key, x in x_dict.items():
            out_dict[key] = self.root_lins[key](x)

        for key, adj_t in adj_t_dict.items():
            key_str = f'{key[0]}_{key[1]}_{key[2]}'
            x = x_dict[key[0]]
            out = self.rel_lins[key_str](adj_t.matmul(x, reduce='mean'))
            out_dict[key[2]].add_(out)

        return out_dict
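A minimal usage sketch (not part of the original snippet), assuming `torch_sparse.SparseTensor` adjacencies and invented toy sizes; each `adj_t` is the transposed adjacency of shape (num_dst_nodes, num_src_nodes) keyed by an (src, rel, dst) edge type.

import torch
from torch_sparse import SparseTensor

# Toy heterogeneous graph: 3 'author' nodes writing 4 'paper' nodes.
x_dict = {'paper': torch.randn(4, 16), 'author': torch.randn(3, 16)}
row_dst = torch.tensor([0, 0, 3])   # target ('paper') indices
col_src = torch.tensor([0, 1, 2])   # source ('author') indices
adj_t_dict = {
    ('author', 'writes', 'paper'):
        SparseTensor(row=row_dst, col=col_src, sparse_sizes=(4, 3)),
}

conv = RGCNConv(16, 32,
                node_types=['paper', 'author'],
                edge_types=[('author', 'writes', 'paper')])
out_dict = conv(x_dict, adj_t_dict)
# out_dict['paper']: (4, 32), out_dict['author']: (3, 32)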
Example #2
class QUETCHEncoder(MetaModule):
    class Config(BaseConfig):
        window_size: int = 3
        """Size of sliding window."""

        embeddings: InputEmbeddingsConfig

    def __init__(self,
                 vocabs: Dict[str, Vocabulary],
                 config: Config,
                 pre_load_model: bool = True):
        super().__init__(config=config)

        self.embeddings = ModuleDict()
        self.embeddings[const.TARGET] = TokenEmbeddings(
            num_embeddings=len(vocabs[const.TARGET]),
            pad_idx=vocabs[const.TARGET].pad_id,
            config=config.embeddings.target,
            vectors=vocabs[const.TARGET].vectors,
        )
        self.embeddings[const.SOURCE] = TokenEmbeddings(
            num_embeddings=len(vocabs[const.SOURCE]),
            pad_idx=vocabs[const.SOURCE].pad_id,
            config=config.embeddings.source,
            vectors=vocabs[const.SOURCE].vectors,
        )

        total_size = sum(emb.size() for emb in self.embeddings.values())
        self._sizes = {
            const.TARGET: total_size * self.config.window_size,
            const.SOURCE: total_size * self.config.window_size,
        }

    @classmethod
    def input_data_encoders(cls, config: Config):
        return None  # Use defaults, i.e., TextEncoder

    def size(self, field=None):
        if field:
            return self._sizes[field]
        return self._sizes

    def forward(self, batch_inputs):
        target_emb = self.embeddings[const.TARGET](batch_inputs[const.TARGET])
        source_emb = self.embeddings[const.SOURCE](batch_inputs[const.SOURCE])

        if const.TARGET_POS in self.embeddings:
            pos_emb, _ = self.embeddings[const.TARGET_POS](
                batch_inputs[const.TARGET_POS])
            target_emb = torch.cat((target_emb, pos_emb), dim=-1)
        if const.SOURCE_POS in self.embeddings:
            pos_emb, _ = self.embeddings[const.SOURCE_POS](
                batch_inputs[const.SOURCE_POS])
            source_emb = torch.cat((source_emb, pos_emb), dim=-1)

        # (bs, source_steps, target_steps)
        matrix_alignments = batch_inputs[const.ALIGNMENTS]
        # The embeddings may have more timesteps than the alignment matrix when
        # the last words are unaligned, so pad the matrix to match.
        pad = [0, 0, 0, 0]
        if matrix_alignments.size(1) < source_emb.size(1):
            pad[3] = source_emb.size(1) - matrix_alignments.size(1)
        if matrix_alignments.size(2) < target_emb.size(1):
            pad[1] = target_emb.size(1) - matrix_alignments.size(2)
        if any(pad):
            matrix_alignments = F.pad(matrix_alignments, pad=pad, value=0)

        h_target = convolve_tensor(
            target_emb,
            self.config.window_size,
            pad_value=self.embeddings[const.TARGET].pad_idx,
        )
        h_source = convolve_tensor(
            source_emb,
            self.config.window_size,
            pad_value=self.embeddings[const.SOURCE].pad_idx,
        )
        h_target = h_target.contiguous().view(h_target.shape[0],
                                              h_target.shape[1], -1)
        h_source = h_source.contiguous().view(h_source.shape[0],
                                              h_source.shape[1], -1)

        # Target side
        matrix_alignments_t = matrix_alignments.transpose(1, 2).float()
        # (bs, target_steps, source_steps) x (bs, source_steps, *)
        # -> (bs, target_steps, *)
        # Take the mean of aligned tokens
        h_source_to_target = torch.matmul(matrix_alignments_t, h_source)
        z = matrix_alignments_t.sum(dim=2, keepdim=True)
        z[z == 0] = 1.0
        h_source_to_target = h_source_to_target / z
        # h_source_to_target[h_source_to_target.sum(-1) == 0] = self.unaligned_source
        # assert torch.all(torch.eq(h_source_to_target, h_source_to_target))
        features_target = torch.cat((h_source_to_target, h_target), dim=-1)

        # Source side
        matrix_alignments = matrix_alignments.float()
        # (bs, source_steps, target_steps) x (bs, target_steps, *)
        # -> (bs, source_steps, *)
        # Take the mean of aligned tokens
        h_target_to_source = torch.matmul(matrix_alignments, h_target)
        z = matrix_alignments.sum(dim=2, keepdim=True)
        z[z == 0] = 1.0
        h_target_to_source = h_target_to_source / z
        # h_target_to_source[h_target_to_source.sum(-1) == 0] = self.unaligned_target
        features_source = torch.cat((h_source, h_target_to_source), dim=-1)

        # (bs, ts, window * emb) -> (bs, ts, 2 * window * emb)
        features = {
            const.TARGET: features_target,
            const.SOURCE: features_source
        }

        return features
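The key step in `forward` above is pooling aligned source windows into each target position through a 0/1 alignment matrix and normalizing by the number of aligned tokens. A standalone sketch of that step with invented shapes:

import torch

bs, src_len, tgt_len, dim = 2, 5, 4, 8
alignments = torch.randint(0, 2, (bs, src_len, tgt_len)).float()  # 1 = aligned
h_source = torch.randn(bs, src_len, dim)

# (bs, tgt_len, src_len) x (bs, src_len, dim) -> (bs, tgt_len, dim)
align_t = alignments.transpose(1, 2)
h_source_to_target = torch.matmul(align_t, h_source)

# Divide by the number of aligned tokens; positions with no alignment keep a zero vector.
counts = align_t.sum(dim=2, keepdim=True)
counts[counts == 0] = 1.0
h_source_to_target = h_source_to_target / counts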
Example #3
class SentenceEmbeddings(Module):
    @dataclass
    class Options(OptionsBase):
        dim_word: "word embedding dim" = 100
        dim_postag: "postag embedding dim. 0 for not using postag" = 100
        dim_char_input: "character embedding input dim" = 100
        dim_char: "character embedding dim. 0 for not using character" = 100
        word_dropout: "word embedding dropout" = 0.4
        postag_dropout: "postag embedding dropout" = 0.2
        character_embedding: CharacterEmbedding.Options = field(
            default_factory=CharacterEmbedding.Options)
        input_layer_norm: "Use layer norm on input embeddings" = True
        mode: str = argfield("concat", choices=["add", "concat"])
        replace_unk_with_chars: bool = False

    def __init__(self,
                 hparams: "SentenceEmbeddings.Options",
                 statistics,
                 plugins=None):

        super().__init__()
        self.hparams = hparams
        self.mode = hparams.mode
        self.plugins = ModuleDict(plugins) if plugins is not None else {}

        # embedding
        input_dims = {}
        if hparams.dim_word != 0:
            self.word_embeddings = Embedding(len(statistics.words),
                                             hparams.dim_word,
                                             padding_idx=0)
            self.word_dropout = FeatureDropout2(hparams.word_dropout)
            input_dims["word"] = hparams.dim_word

        if hparams.dim_postag != 0:
            self.pos_embeddings = Embedding(len(statistics.postags),
                                            hparams.dim_postag,
                                            padding_idx=0)
            self.pos_dropout = FeatureDropout2(hparams.postag_dropout)
            input_dims["postag"] = hparams.dim_postag

        if hparams.dim_char > 0:
            self.bilm = None
            self.character_lookup = Embedding(len(statistics.characters),
                                              hparams.dim_char_input)
            self.char_embeded = CharacterEmbedding.get(
                hparams.character_embedding,
                dim_char_input=hparams.dim_char_input,
                input_size=hparams.dim_char)
            if not hparams.replace_unk_with_chars:
                input_dims["char"] = hparams.dim_char
            else:
                assert hparams.dim_word == hparams.dim_char
        else:
            self.character_lookup = None

        for name, plugin in self.plugins.items():
            input_dims[name] = plugin.output_dim

        if hparams.mode == "concat":
            self.output_dim = sum(input_dims.values())
        else:
            assert hparams.mode == "add"
            uniq_input_dims = list(set(input_dims.values()))
            if len(uniq_input_dims) != 1:
                raise ValueError(f"Different input dims: {input_dims}")
            print(input_dims)
            self.output_dim = uniq_input_dims[0]

        self.input_layer_norm = LayerNorm(self.output_dim, eps=1e-6) \
            if hparams.input_layer_norm else None

    def reset_parameters(self):
        if self.hparams.dim_word != 0:
            torch.nn.init.xavier_normal_(self.word_embeddings.weight.data)
        if self.hparams.dim_postag != 0:
            torch.nn.init.xavier_normal_(self.pos_embeddings.weight.data)
        if self.character_lookup is not None:
            torch.nn.init.xavier_normal_(self.character_lookup.weight.data)

    def forward(self, inputs, unk_idx=1):
        all_features = []

        if self.character_lookup is not None:
            # use character embedding instead
            # batch_size, bucket_size, word_length, embedding_dims
            char_embeded_4d = self.character_lookup(inputs.chars)
            word_embeded_by_char = self.char_embeded(inputs.word_lengths,
                                                     char_embeded_4d)
            if not self.hparams.replace_unk_with_chars:
                all_features.append(word_embeded_by_char)

        if self.hparams.dim_word != 0:
            word_embedding = self.word_dropout(
                self.word_embeddings(inputs.words))
            if self.hparams.dim_char and self.hparams.replace_unk_with_chars:
                unk = inputs.words.eq(unk_idx)
                # noinspection PyUnboundLocalVariable
                unk_word_embeded_by_char = word_embeded_by_char[unk]
                word_embedding[unk] = unk_word_embeded_by_char
            all_features.append(word_embedding)

        if self.hparams.dim_postag != 0:
            all_features.append(
                self.pos_dropout(self.pos_embeddings(inputs.postags)))

        for plugin in self.plugins.values():
            plugin_output = plugin(inputs)
            # FIXME: remove these two ugly tweak
            if plugin_output.shape[1] == inputs.words.shape[1] + 2:
                plugin_output = plugin_output[:, 1:-1]
            # pad external embedding to dim_word
            # if self.mode == "add" and plugin_output.shape[-1] < self.hparams.dim_word:
            #     plugin_output = torch.cat(
            #         [plugin_output,
            #          plugin_output.new_zeros(
            #              (*inputs.words.shape, self.hparams.dim_word - plugin_output.shape[-1]))], -1)
            all_features.append(plugin_output)

        if self.mode == "concat":
            total_input_embeded = torch.cat(all_features, -1)
        else:
            total_input_embeded = sum(all_features)

        if self.input_layer_norm is not None:
            total_input_embeded = self.input_layer_norm(total_input_embeded)

        return total_input_embeded
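The `replace_unk_with_chars` branch above swaps the lookup embedding of every `<unk>` token for its character-composed embedding, which is why `dim_word == dim_char` is asserted. A self-contained sketch of that masked replacement with invented ids and dims:

import torch

unk_idx = 1
words = torch.tensor([[3, 1, 5],
                      [1, 2, 1]])          # token ids, 1 = <unk>
word_emb = torch.randn(2, 3, 100)          # lookup-table embeddings
char_emb = torch.randn(2, 3, 100)          # character-composed embeddings

unk = words.eq(unk_idx)                    # boolean mask of <unk> positions
word_emb[unk] = char_emb[unk]              # replace only the <unk> rows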
Example #4
class HeteroConv(Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.
    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], Module]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (string, optional): The aggregation scheme to use for grouping
            node embeddings generated by different relations.
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`None`). (default: :obj:`"sum"`)
    """
    def __init__(self,
                 convs: Dict[EdgeType, Module],
                 aggr: Optional[str] = "sum"):
        super().__init__()

        src_node_types = set([key[0] for key in convs.keys()])
        dst_node_types = set([key[-1] for key in convs.keys()])
        if len(src_node_types - dst_node_types) > 0:
            warnings.warn(
                f"There exist node types ({src_node_types - dst_node_types}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behaviour.")

        self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr

    def reset_parameters(self):
        for conv in self.convs.values():
            conv.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.
        """
        out_dict = defaultdict(list)
        for edge_type, edge_index in edge_index_dict.items():
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (value_dict.get(src, None),
                                   value_dict.get(dst, None))

            conv = self.convs[str_edge_type]

            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
                           **kwargs)

            out_dict[dst].append(out)

        for key, value in out_dict.items():
            out_dict[key] = group(value, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'
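A toy sketch of the `*_dict` keyword routing described in the docstring (invented sizes; assumes the `HeteroConv` above plus `GATConv` from `torch_geometric` are available):

import torch
from torch_geometric.nn import GATConv

x_dict = {'paper': torch.randn(4, 16)}
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 8)

conv = HeteroConv({
    ('paper', 'cites', 'paper'): GATConv(-1, 32, edge_dim=8, add_self_loops=False),
}, aggr='sum')

# 'edge_attr_dict' is stripped to 'edge_attr' and routed to the matching layer.
out_dict = conv(
    x_dict,
    {('paper', 'cites', 'paper'): edge_index},
    edge_attr_dict={('paper', 'cites', 'paper'): edge_attr},
)
# out_dict['paper']: (4, 32)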
Example #5
class HeteroConv(Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.
    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv(-1, 64),
            ('paper', 'written_by', 'author'): GATConv(-1, 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], Module]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (string, optional): The aggregation scheme to use for grouping
            node embeddings generated by different relations.
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
            (default: :obj:`"sum"`)
    """
    def __init__(self, convs: Dict[EdgeType, Module], aggr: str = "sum"):
        super().__init__()
        self.keys = list(convs.keys())
        self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr

    def reset_parameters(self):
        for conv in self.convs.values():
            conv.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Union[Dict[EdgeType, Tensor], Dict[EdgeType, SparseTensor]],
        edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> Dict[NodeType, Tensor]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type.
            edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional): A
                dictionary holding one-dimensional edge weight information
                for individual edge types. (default: :obj:`None`)
            edge_attr_dict (Dict[Tuple[str, str, str], Tensor], optional): A
                dictionary holding multi-dimensional edge feature information
                for individual edge types. (default: :obj:`None`)
        """

        out_dict = defaultdict(list)
        for key in self.keys:
            src, _, dst = key
            conv = self.convs['__'.join(key)]

            kwargs = {}
            if edge_weight_dict is not None and key in edge_weight_dict:
                kwargs['edge_weight'] = edge_weight_dict[key]
            if edge_attr_dict is not None and key in edge_attr_dict:
                kwargs['edge_attr'] = edge_attr_dict[key]

            if src == dst:
                out = conv(x=x_dict[src],
                           edge_index=edge_index_dict[key],
                           **kwargs)
            else:
                out = conv(x=(x_dict[src], x_dict[dst]),
                           edge_index=edge_index_dict[key],
                           **kwargs)

            out_dict[dst].append(out)

        for key, values in out_dict.items():
            out_dict[key] = group(values, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'
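For this variant, a toy sketch of the explicit `edge_weight_dict` argument (invented sizes; assumes `GCNConv` from `torch_geometric`, which accepts per-edge scalar weights):

import torch
from torch_geometric.nn import GCNConv

key = ('paper', 'cites', 'paper')
x_dict = {'paper': torch.randn(4, 16)}
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_weight = torch.tensor([0.5, 1.0, 2.0])

conv = HeteroConv({key: GCNConv(16, 32)}, aggr='sum')
out_dict = conv(x_dict, {key: edge_index}, edge_weight_dict={key: edge_weight})
# out_dict['paper']: (4, 32)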