Ejemplo n.º 1
0
    def __init__(self, out_channels=64, num_edge_type=4, ch_list=None,
                 n_atom_types=MAX_ATOMIC_NUM, input_type='int',
                 scale_adj=False, readout=True):
        """Relational GCN: input embedding, stacked R-GCN updates and an
        optional GGNN-style readout.

        ch_list gives the per-layer channel widths; defaults to [16, 128, 64].
        """
        super(RelGCN, self).__init__()
        ch_list = [16, 128, 64] if ch_list is None else ch_list
        with self.init_scope():
            # Input embedding: integer atom IDs or float feature vectors.
            if input_type == 'int':
                self.embed = EmbedAtomID(out_size=ch_list[0],
                                         in_size=n_atom_types)
            elif input_type == 'float':
                self.embed = GraphLinear(None, ch_list[0])
            else:
                raise ValueError("[ERROR] Unexpected value input_type={}"
                                 .format(input_type))
            # One relational conv per consecutive pair of channel widths.
            convs = [RelGCNUpdate(w_in, w_out, num_edge_type)
                     for w_in, w_out in zip(ch_list, ch_list[1:])]
            self.rgcn_convs = chainer.ChainList(*convs)
            if readout:
                # chainer-chemistry 0.7.0 renamed the readout size kwarg
                # from hidden_dim to in_channels; pick the matching name.
                size_key = ('in_channels'
                            if chainer_chemistry.__version__ == '0.7.0'
                            else 'hidden_dim')
                self.rgcn_readout = GGNNReadout(
                    out_dim=out_channels, nobias=True,
                    activation=functions.tanh,
                    **{size_key: ch_list[-1]})
        self.input_type = input_type
        self.scale_adj = scale_adj
        self.readout = readout
Ejemplo n.º 2
0
 def __init__(self,
              out_dim,
              hidden_dim=16,
              n_layers=4,
              n_atom_types=MAX_ATOMIC_NUM,
              concat_hidden=False,
              weight_tying=True,
              activation=functions.identity,
              num_edge_type=4):
     """Gated Graph NN: atom embedding, message-passing updates, readout(s)."""
     super(GGNN, self).__init__()
     # One readout per layer when hidden states are concatenated.
     n_readout_layer = n_layers if concat_hidden else 1
     # A single shared update link when weights are tied.
     n_message_layer = 1 if weight_tying else n_layers
     with self.init_scope():
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         update_links = [GGNNUpdate(hidden_dim=hidden_dim,
                                    num_edge_type=num_edge_type)
                         for _ in range(n_message_layer)]
         self.update_layers = chainer.ChainList(*update_links)
         readout_links = [GGNNReadout(out_dim=out_dim,
                                      hidden_dim=hidden_dim,
                                      activation=activation,
                                      activation_agg=activation)
                          for _ in range(n_readout_layer)]
         self.readout_layers = chainer.ChainList(*readout_links)
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.n_layers = n_layers
     self.num_edge_type = num_edge_type
     self.activation = activation
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
Ejemplo n.º 3
0
    def __init__(self, out_dim, hidden_dim=16,
                 n_layers=4, n_atom_types=MAX_ATOMIC_NUM,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=True,
                 activation=functions.identity):
        """Graph Isomorphism Network: MLP updates + GGNN-style readout(s)."""
        super(GIN, self).__init__()

        # Tied weights -> a single shared update link; concat_hidden ->
        # one readout per message-passing step.
        n_message_layer = 1 if weight_tying else n_layers
        n_readout_layer = n_layers if concat_hidden else 1
        with self.init_scope():
            # Atom-ID embedding.
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            # Message-passing MLP layers.
            updates = [GINUpdate(hidden_dim=hidden_dim,
                                 dropout_ratio=dropout_ratio)
                       for _ in range(n_message_layer)]
            self.update_layers = chainer.ChainList(*updates)

            # Readout(s).
            readouts = [GGNNReadout(out_dim=out_dim, hidden_dim=hidden_dim,
                                    activation=activation,
                                    activation_agg=activation)
                        for _ in range(n_readout_layer)]
            self.readout_layers = chainer.ChainList(*readouts)

        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
Ejemplo n.º 4
0
    def __init__(
            self,
            out_dim,  # type: int
            hidden_channels=16,  # type: int
            n_update_layers=4,  # type: int
            n_atom_types=MAX_ATOMIC_NUM,  # type: int
            concat_hidden=False,  # type: bool
            weight_tying=True,  # type: bool
            n_edge_types=4,  # type: int
            nn=None,  # type: Optional[chainer.Link]
            message_func='edgenet',  # type: str
            readout_func='set2set',  # type: str
    ):
        # type: (...) -> None
        """MPNN with selectable message and readout functions.

        message_func chooses 'edgenet' (MPNNUpdate) or 'ggnn' (GGNNUpdate);
        readout_func chooses 'set2set' (MPNNReadout) or 'ggnn' (GGNNReadout).
        """
        super(MPNN, self).__init__()
        # Validate the selectors up front so bad configs fail fast.
        if message_func not in ('edgenet', 'ggnn'):
            raise ValueError(
                'Invalid message function: {}'.format(message_func))
        if readout_func not in ('set2set', 'ggnn'):
            raise ValueError(
                'Invalid readout function: {}'.format(readout_func))
        n_readout_layer = n_update_layers if concat_hidden else 1
        n_message_layer = 1 if weight_tying else n_update_layers
        with self.init_scope():
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            # Message-passing layers (a single link when weights are tied).
            updates = []
            for _ in range(n_message_layer):
                if message_func == 'ggnn':
                    updates.append(
                        GGNNUpdate(hidden_channels=hidden_channels,
                                   n_edge_types=n_edge_types))
                else:
                    updates.append(
                        MPNNUpdate(hidden_channels=hidden_channels, nn=nn))
            self.update_layers = chainer.ChainList(*updates)

            # Readout layer(s). The GGNN readout takes 2x hidden width —
            # presumably it concatenates current and initial node states;
            # confirm against GGNNReadout's implementation.
            readouts = []
            for _ in range(n_readout_layer):
                if readout_func == 'ggnn':
                    readouts.append(
                        GGNNReadout(out_dim=out_dim,
                                    in_channels=hidden_channels * 2))
                else:
                    readouts.append(
                        MPNNReadout(out_dim=out_dim,
                                    in_channels=hidden_channels,
                                    n_layers=1))
            self.readout_layers = chainer.ChainList(*readouts)
        self.out_dim = out_dim
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.n_edge_types = n_edge_types
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
        self.message_func = message_func
        self.readout_func = readout_func
Ejemplo n.º 5
0
    def __init__(self,
                 out_dim,
                 hidden_channels=16,
                 n_update_layers=4,
                 n_atom_types=MAX_ATOMIC_NUM,
                 concat_hidden=False,
                 dropout_ratio=-1.,
                 weight_tying=False,
                 activation=functions.identity,
                 n_edge_types=4,
                 n_heads=3,
                 negative_slope=0.2,
                 softmax_mode='across',
                 concat_heads=False):
        """Relational GAT: attention-based updates + GGNN readout(s)."""
        super(RelGAT, self).__init__()
        n_readout_layer = n_update_layers if concat_hidden else 1
        # NOTE(review): weight_tying is stored below but the update layers
        # are never shared here — confirm this is intentional.
        n_message_layer = n_update_layers
        with self.init_scope():
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            # Layers after the first receive concatenated heads when
            # concat_heads is enabled, hence the widened input size.
            self.update_layers = chainer.ChainList(*[
                RelGATUpdate(hidden_channels * n_heads
                             if (i > 0 and concat_heads)
                             else hidden_channels,
                             hidden_channels,
                             n_heads=n_heads,
                             n_edge_types=n_edge_types,
                             dropout_ratio=dropout_ratio,
                             negative_slope=negative_slope,
                             softmax_mode=softmax_mode,
                             concat_heads=concat_heads)
                for i in range(n_message_layer)
            ])
            # Readout input width depends on whether heads are concatenated.
            readout_in = (hidden_channels * (n_heads + 1) if concat_heads
                          else hidden_channels * 2)
            self.readout_layers = chainer.ChainList(*[
                GGNNReadout(out_dim=out_dim,
                            in_channels=readout_in,
                            activation=activation,
                            activation_agg=activation)
                for _ in range(n_readout_layer)
            ])

        self.out_dim = out_dim
        self.n_heads = n_heads
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.concat_hidden = concat_hidden
        self.concat_heads = concat_heads
        self.weight_tying = weight_tying
        self.negative_slope = negative_slope
        self.n_edge_types = n_edge_types
        self.dropout_ratio = dropout_ratio
Ejemplo n.º 6
0
    def __init__(self,
                 out_dim,
                 node_embedding=False,
                 hidden_channels=16,
                 out_channels=None,
                 n_update_layers=4,
                 n_atom_types=MAX_ATOMIC_NUM,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=False,
                 activation=functions.identity,
                 n_edge_types=4):
        """GIN with configurable final update width and node-embedding mode."""
        super(GIN, self).__init__()
        n_message_layer = 1 if weight_tying else n_update_layers
        n_readout_layer = n_update_layers if concat_hidden else 1
        with self.init_scope():
            # Atom-ID embedding.
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            # Standalone first MLP, borrowed from a throwaway GINUpdate.
            self.first_mlp = GINUpdate(hidden_channels=hidden_channels,
                                       dropout_ratio=dropout_ratio,
                                       out_channels=hidden_channels).graph_mlp

            # Update stack: every layer keeps hidden_channels except the
            # last, which emits out_channels (defaults to hidden_channels).
            if out_channels is None:
                out_channels = hidden_channels
            updates = []
            for i in range(n_message_layer):
                last = (i == n_message_layer - 1)
                updates.append(GINUpdate(
                    hidden_channels=hidden_channels,
                    dropout_ratio=dropout_ratio,
                    out_channels=out_channels if last else hidden_channels))
            self.update_layers = chainer.ChainList(*updates)

            # Readout(s); 2x width — presumably the readout concatenates
            # current and initial node states; confirm in GGNNReadout.
            readouts = [GGNNReadout(out_dim=out_dim,
                                    in_channels=hidden_channels * 2,
                                    activation=activation,
                                    activation_agg=activation)
                        for _ in range(n_readout_layer)]
            self.readout_layers = chainer.ChainList(*readouts)

        self.node_embedding = node_embedding
        self.out_dim = out_dim
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
        self.n_edge_types = n_edge_types
Ejemplo n.º 7
0
 def __init__(self,
              out_dim,
              hidden_dim=16,
              hidden_dim_super=16,
              n_layers=4,
              n_heads=8,
              n_atom_types=MAX_ATOMIC_NUM,
              n_super_feature=2 + 2 * 4 + MAX_ATOMIC_NUM * 2,
              dropout_ratio=0.5,
              concat_hidden=False,
              weight_tying=True,
              activation=functions.identity,
              num_edge_type=4):
     """GGNN augmented with a Graph Warp Module (super-node pathway)."""
     super(GGNN_GWM, self).__init__()
     n_readout_layer = n_layers if concat_hidden else 1
     n_message_layer = 1 if weight_tying else n_layers
     with self.init_scope():
         # Atom embedding and message-passing updates.
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         updates = [GGNNUpdate(hidden_dim=hidden_dim,
                               num_edge_type=num_edge_type)
                    for _ in range(n_message_layer)]
         self.update_layers = chainer.ChainList(*updates)
         # GWM: embeds the super-node features and warps information
         # between the graph and the super node each step.
         self.embed_super = links.Linear(in_size=n_super_feature,
                                         out_size=hidden_dim_super)
         self.gwm = GWM(hidden_dim=hidden_dim,
                        hidden_dim_super=hidden_dim_super,
                        n_layers=n_message_layer,
                        n_heads=n_heads,
                        dropout_ratio=dropout_ratio,
                        tying_flag=weight_tying,
                        gpu=-1)
         # Readout(s).
         readouts = [GGNNReadout(out_dim=out_dim,
                                 hidden_dim=hidden_dim,
                                 activation=activation,
                                 activation_agg=activation)
                     for _ in range(n_readout_layer)]
         self.readout_layers = chainer.ChainList(*readouts)
         # in_size=None: width inferred on first call — presumably fed the
         # readout concatenated with the super node; confirm in __call__.
         self.linear_for_concat_super = links.Linear(in_size=None,
                                                     out_size=out_dim)
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.hidden_dim_super = hidden_dim_super
     self.n_layers = n_layers
     self.n_heads = n_heads
     self.dropout_ratio = dropout_ratio
     self.num_edge_type = num_edge_type
     self.activation = activation
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
Ejemplo n.º 8
0
 def __init__(self,
              out_dim,
              hidden_dim,
              n_heads=3,
              negative_slope=0.2,
              n_edge_types=4,
              n_layers=4,
              dropout_ratio=-1,
              activation=F.tanh,
              softmax_mode="across",
              concat_hidden=False,
              concat_heads=True,
              weight_tying=False):
     """Relational GAT: stacked attention updates + GGNN readout(s)."""
     super(RelationalGAT, self).__init__()
     n_readout_layer = n_layers if concat_hidden else 1
     # NOTE(review): weight_tying is stored but never shares layers here.
     n_message_layer = n_layers
     with self.init_scope():
         layers = []
         for idx in range(n_message_layer):
             # Layers after the first see concatenated heads when enabled.
             if idx and concat_heads:
                 in_dim = hidden_dim * n_heads
             else:
                 in_dim = hidden_dim
             layers.append(RelGATUpdate(in_dim,
                                        hidden_dim,
                                        n_heads=n_heads,
                                        n_edge_types=n_edge_types,
                                        dropout_ratio=dropout_ratio,
                                        negative_slope=negative_slope,
                                        softmax_mode=softmax_mode,
                                        concat_heads=concat_heads))
         self.update_layers = chainer.ChainList(*layers)
         # NOTE(review): in_dim deliberately leaks from the loop above —
         # the readout width equals the LAST update layer's input width,
         # matching the original code; verify this is the intended size.
         self.readout_layers = chainer.ChainList(*[
             GGNNReadout(out_dim=out_dim,
                         in_channels=in_dim,
                         activation=activation,
                         activation_agg=activation)
             for _ in range(n_readout_layer)
         ])
     self.out_dim = out_dim
     self.n_heads = n_heads
     self.hidden_dim = hidden_dim
     self.n_layers = n_layers
     self.concat_hidden = concat_hidden
     self.concat_heads = concat_heads
     self.weight_tying = weight_tying
     self.negative_slope = negative_slope
     self.n_edge_types = n_edge_types
     self.dropout_ratio = dropout_ratio
Ejemplo n.º 9
0
 def __init__(self,
              out_dim=64,
              hidden_channels=None,
              n_update_layers=None,
              n_atom_types=MAX_ATOMIC_NUM,
              n_edge_types=4,
              input_type='int',
              scale_adj=False):
     """Relational GCN: hidden_channels may be a width list, a single int
     (repeated n_update_layers times) or None (defaults to [16, 128, 64])."""
     super(RelGCN, self).__init__()
     # Normalize hidden_channels into an explicit per-layer width list.
     if hidden_channels is None:
         hidden_channels = [16, 128, 64]
     elif isinstance(hidden_channels, int):
         if not isinstance(n_update_layers, int):
             raise ValueError(
                 'Must specify n_update_layers when hidden_channels is int')
         hidden_channels = [hidden_channels] * n_update_layers
     with self.init_scope():
         # Input embedding: integer atom IDs or float feature vectors.
         if input_type == 'int':
             self.embed = EmbedAtomID(out_size=hidden_channels[0],
                                      in_size=n_atom_types)
         elif input_type == 'float':
             self.embed = GraphLinear(None, hidden_channels[0])
         else:
             raise ValueError(
                 "[ERROR] Unexpected value input_type={}".format(
                     input_type))
         # One relational conv per consecutive pair of widths.
         convs = [RelGCNUpdate(w_in, w_out, n_edge_types)
                  for w_in, w_out in zip(hidden_channels,
                                         hidden_channels[1:])]
         self.rgcn_convs = chainer.ChainList(*convs)
         self.rgcn_readout = GGNNReadout(out_dim=out_dim,
                                         in_channels=hidden_channels[-1],
                                         nobias=True,
                                         activation=functions.tanh)
     self.input_type = input_type
     self.scale_adj = scale_adj
Ejemplo n.º 10
0
def readout():
    # Factory for a GGNN-style readout with inferred input width
    # (in_channels=None). NOTE(review): `out_dim` is a free variable —
    # it must be defined in the enclosing/global scope at call time;
    # confirm where the caller provides it.
    return GGNNReadout(out_dim=out_dim, in_channels=None)
Ejemplo n.º 11
0
def readout():
    # Factory for a GGNN-style readout using the older `hidden_dim` kwarg.
    # NOTE(review): `out_dim` and `hidden_dim` are free variables resolved
    # from the enclosing/global scope at call time — confirm the caller
    # defines them.
    return GGNNReadout(out_dim=out_dim, hidden_dim=hidden_dim)