Example 1
    def __init__(self,
                 out_dim,
                 node_embedding=False,
                 hidden_channels=16,
                 out_channels=None,
                 n_update_layers=4,
                 n_atom_types=MAX_ATOMIC_NUM,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=False,
                 activation=functions.identity,
                 n_edge_types=4):
        super(GIN, self).__init__()
        n_message_layer = 1 if weight_tying else n_update_layers
        n_readout_layer = n_update_layers if concat_hidden else 1
        with self.init_scope():
            # embedding
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            self.first_mlp = GINUpdate(hidden_channels=hidden_channels,
                                       dropout_ratio=dropout_ratio,
                                       out_channels=hidden_channels).graph_mlp

            # message passing part: each update layer is a two-layer non-linear MLP
            if out_channels is None:
                out_channels = hidden_channels
            self.update_layers = chainer.ChainList(*[
                GINUpdate(hidden_channels=hidden_channels,
                          dropout_ratio=dropout_ratio,
                          out_channels=(out_channels if i == n_message_layer -
                                        1 else hidden_channels))
                for i in range(n_message_layer)
            ])

            # Readout
            self.readout_layers = chainer.ChainList(*[
                GGNNReadout(out_dim=out_dim,
                            in_channels=hidden_channels * 2,
                            activation=activation,
                            activation_agg=activation)
                for _ in range(n_readout_layer)
            ])
        # end with

        self.node_embedding = node_embedding
        self.out_dim = out_dim
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
        self.n_edge_types = n_edge_types
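
For reference, a minimal instantiation sketch of the constructor above. The hyperparameter values are illustrative assumptions, not values taken from the source; only keyword arguments that appear in the signature are used.

# Minimal sketch (assumed values): building the model with the constructor shown above.
model = GIN(out_dim=1,               # e.g. one regression target per graph
            hidden_channels=16,
            n_update_layers=4,
            dropout_ratio=0.5,
            concat_hidden=False,
            weight_tying=False)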
Example 2
    def __init__(self, out_dim, hidden_dim=16,
                 n_layers=4, n_atom_types=MAX_ATOMIC_NUM,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=True,
                 activation=functions.identity):
        super(GIN, self).__init__()

        n_message_layer = 1 if weight_tying else n_layers
        n_readout_layer = n_layers if concat_hidden else 1
        with self.init_scope():
            # embedding
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            # message passing part: each update layer is a two-layer non-linear MLP
            self.update_layers = chainer.ChainList(*[GINUpdate(
                hidden_dim=hidden_dim, dropout_ratio=dropout_ratio)
                for _ in range(n_message_layer)])

            # Readout
            self.readout_layers = chainer.ChainList(*[GGNNReadout(
                out_dim=out_dim, hidden_dim=hidden_dim,
                activation=activation, activation_agg=activation)
                for _ in range(n_readout_layer)])
        # end with

        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
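
This older variant of the constructor uses the hidden_dim naming. A minimal instantiation sketch with assumed values:

# Minimal sketch (assumed values) for the hidden_dim-style constructor above.
model = GIN(out_dim=1,
            hidden_dim=16,
            n_layers=4,
            concat_hidden=False,
            weight_tying=True)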
Example 3
    def __init__(self,
                 out_dim,
                 hidden_dim=16,
                 hidden_dim_super=16,
                 n_layers=4,
                 n_heads=8,
                 n_atom_types=MAX_ATOMIC_NUM,
                 n_super_feature=2 + 2 + MAX_ATOMIC_NUM * 2,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=True,
                 activation=functions.identity):
        super(GIN_GWM, self).__init__()

        n_message_layer = 1 if weight_tying else n_layers
        n_readout_layer = n_layers if concat_hidden else 1
        with self.init_scope():
            # embedding
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            # message passing part: each update layer is a two-layer non-linear MLP
            self.update_layers = chainer.ChainList(*[
                GINUpdate(hidden_dim=hidden_dim, dropout_ratio=dropout_ratio)
                for _ in range(n_message_layer)
            ])

            # GWM
            self.embed_super = links.Linear(in_size=n_super_feature,
                                            out_size=hidden_dim_super)
            self.gwm = GWM(hidden_dim=hidden_dim,
                           hidden_dim_super=hidden_dim_super,
                           n_layers=n_message_layer,
                           n_heads=n_heads,
                           dropout_ratio=dropout_ratio,
                           tying_flag=weight_tying,
                           gpu=-1)

            # Readout
            self.readout_layers = chainer.ChainList(*[
                GINReadout(out_dim=out_dim,
                           hidden_dim=hidden_dim,
                           activation=activation,
                           activation_agg=activation)
                for _ in range(n_readout_layer)
            ])
            self.linear_for_concat_super = links.Linear(in_size=None,
                                                        out_size=out_dim)
        # end with

        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.hidden_dim_super = hidden_dim_super
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
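
A minimal instantiation sketch for the GIN_GWM constructor above; the values are assumptions, and all other parameters (such as n_super_feature) keep the defaults from the signature.

# Minimal sketch (assumed values) for the GIN_GWM constructor above.
model = GIN_GWM(out_dim=1,
                hidden_dim=16,
                hidden_dim_super=16,
                n_layers=4,
                n_heads=8,
                weight_tying=True)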
Example 4
def update():
    # type: () -> GINUpdate
    # in_channels and hidden_channels are free variables here, so they must be
    # defined at module level (e.g. as test constants) in the enclosing module.
    # dropout_ratio=0 disables dropout for deterministic output.
    return GINUpdate(in_channels=in_channels,
                     hidden_channels=hidden_channels,
                     dropout_ratio=0)
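
A short usage sketch for the helper above. Since in_channels and hidden_channels are free variables in the snippet, they are assumed here to be module-level constants; the values are illustrative only.

# Sketch only: assumed module-level constants used by update() above.
in_channels = 6
hidden_channels = 16

gin_update = update()   # returns a GINUpdate link with dropout disabled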