Exemplo n.º 1
0
 def __init__(self,
              out_dim,
              hidden_dim=16,
              n_layers=4,
              n_atom_types=MAX_ATOMIC_NUM,
              concat_hidden=False,
              weight_tying=True,
              activation=functions.identity,
              num_edge_type=4):
     """Build a GGNN: an atom embedding, update layers and readout layers.

     Args:
         out_dim: Output dimension given to every readout layer.
         hidden_dim: Width of the atom embedding; also forwarded to the
             update and readout layers.
         n_layers: Layer count. It sets how many update layers are built
             when ``weight_tying`` is false and how many readout layers
             are built when ``concat_hidden`` is true.
         n_atom_types: Input vocabulary size of the atom-ID embedding.
         concat_hidden: If true, build ``n_layers`` readout layers;
             otherwise build exactly one.
         weight_tying: If true, build exactly one update layer; otherwise
             build ``n_layers`` of them.
         activation: Forwarded to each readout layer as both
             ``activation`` and ``activation_agg``.
         num_edge_type: Number of edge types forwarded to each update
             layer.
     """
     super(GGNN, self).__init__()

     num_updates = n_layers if not weight_tying else 1
     num_readouts = 1 if not concat_hidden else n_layers

     def make_update():
         return GGNNUpdate(hidden_dim=hidden_dim, num_edge_type=num_edge_type)

     def make_readout():
         return GGNNReadout(out_dim=out_dim,
                            hidden_dim=hidden_dim,
                            activation=activation,
                            activation_agg=activation)

     # Links assigned inside init_scope are registered as children of
     # this chain.
     with self.init_scope():
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         self.update_layers = chainer.ChainList(
             *[make_update() for _ in range(num_updates)])
         self.readout_layers = chainer.ChainList(
             *[make_readout() for _ in range(num_readouts)])

     # Record the configuration on the instance.
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.n_layers = n_layers
     self.num_edge_type = num_edge_type
     self.activation = activation
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
Exemplo n.º 2
0
    def __init__(
            self,
            out_dim,  # type: int
            hidden_channels=16,  # type: int
            n_update_layers=4,  # type: int
            n_atom_types=MAX_ATOMIC_NUM,  # type: int
            concat_hidden=False,  # type: bool
            weight_tying=True,  # type: bool
            n_edge_types=4,  # type: int
            nn=None,  # type: Optional[chainer.Link]
            message_func='edgenet',  # type: str
            readout_func='set2set',  # type: str
    ):
        # type: (...) -> None
        """Build an MPNN with a selectable message and readout function.

        Args:
            out_dim: Output dimension of each readout layer.
            hidden_channels: Width of the atom embedding and of the update
                layers. Readout layers take ``hidden_channels`` inputs
                (``2 * hidden_channels`` for the ``'ggnn'`` readout).
            n_update_layers: Layer count. It sets how many update layers
                are built when ``weight_tying`` is false and how many
                readout layers are built when ``concat_hidden`` is true.
            n_atom_types: Input vocabulary size of the atom-ID embedding.
            concat_hidden: If true, build ``n_update_layers`` readout
                layers; otherwise build exactly one.
            weight_tying: If true, build exactly one update layer;
                otherwise build ``n_update_layers`` of them.
            n_edge_types: Number of edge types; forwarded to the update
                layers only when ``message_func == 'ggnn'``.
            nn: Optional link forwarded to every ``MPNNUpdate``; used only
                when ``message_func == 'edgenet'``.
            message_func: ``'edgenet'`` (``MPNNUpdate``) or ``'ggnn'``
                (``GGNNUpdate``).
            readout_func: ``'set2set'`` (``MPNNReadout``) or ``'ggnn'``
                (``GGNNReadout``).

        Raises:
            ValueError: If ``message_func`` or ``readout_func`` is not one
                of the supported names. ``message_func`` is checked first.
        """
        super(MPNN, self).__init__()

        # (value, accepted names, error template), checked in this order.
        checks = (
            (message_func, ('edgenet', 'ggnn'),
             'Invalid message function: {}'),
            (readout_func, ('set2set', 'ggnn'),
             'Invalid readout function: {}'),
        )
        for value, accepted, template in checks:
            if value not in accepted:
                raise ValueError(template.format(value))

        num_readouts = 1 if not concat_hidden else n_update_layers
        num_updates = n_update_layers if not weight_tying else 1
        ggnn_update = message_func == 'ggnn'
        ggnn_readout = readout_func == 'ggnn'

        def make_update():
            if ggnn_update:
                return GGNNUpdate(hidden_channels=hidden_channels,
                                  n_edge_types=n_edge_types)
            return MPNNUpdate(hidden_channels=hidden_channels, nn=nn)

        def make_readout():
            if ggnn_readout:
                # NOTE(review): GGNNReadout is given 2 * hidden_channels
                # inputs, presumably because it reads current and initial
                # atom features concatenated. Confirm against GGNNReadout.
                return GGNNReadout(out_dim=out_dim,
                                   in_channels=hidden_channels * 2)
            return MPNNReadout(out_dim=out_dim,
                               in_channels=hidden_channels,
                               n_layers=1)

        # Links assigned inside init_scope are registered as children of
        # this chain.
        with self.init_scope():
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            self.update_layers = chainer.ChainList(
                *[make_update() for _ in range(num_updates)])
            self.readout_layers = chainer.ChainList(
                *[make_readout() for _ in range(num_readouts)])

        # Record the configuration on the instance.
        self.out_dim = out_dim
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.n_edge_types = n_edge_types
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
        self.message_func = message_func
        self.readout_func = readout_func
Exemplo n.º 3
0
 def __init__(self,
              out_dim,
              hidden_dim=16,
              hidden_dim_super=16,
              n_layers=4,
              n_heads=8,
              n_atom_types=MAX_ATOMIC_NUM,
              n_super_feature=2 + 2 * 4 + MAX_ATOMIC_NUM * 2,
              dropout_ratio=0.5,
              concat_hidden=False,
              weight_tying=True,
              activation=functions.identity,
              num_edge_type=4):
     """Build a GGNN extended with a ``GWM`` block and a super-feature input.

     Args:
         out_dim: Output dimension of the readout layers and of
             ``linear_for_concat_super``.
         hidden_dim: Width of the atom embedding; also forwarded to the
             update, readout and GWM layers.
         hidden_dim_super: Output width of ``embed_super``; also forwarded
             to the GWM.
         n_layers: Layer count. It sets how many update layers are built
             when ``weight_tying`` is false and how many readout layers
             are built when ``concat_hidden`` is true.
         n_heads: Forwarded to the GWM.
         n_atom_types: Input vocabulary size of the atom-ID embedding.
         n_super_feature: Input width of the ``embed_super`` linear layer.
         dropout_ratio: Forwarded to the GWM.
         concat_hidden: If true, build ``n_layers`` readout layers;
             otherwise build exactly one.
         weight_tying: If true, build exactly one update layer; otherwise
             build ``n_layers`` of them. Also forwarded to the GWM as
             ``tying_flag``.
         activation: Forwarded to each readout layer as both
             ``activation`` and ``activation_agg``.
         num_edge_type: Number of edge types forwarded to each update
             layer.
     """
     super(GGNN_GWM, self).__init__()

     num_updates = n_layers if not weight_tying else 1
     num_readouts = 1 if not concat_hidden else n_layers

     def make_update():
         return GGNNUpdate(hidden_dim=hidden_dim, num_edge_type=num_edge_type)

     def make_readout():
         return GGNNReadout(out_dim=out_dim,
                            hidden_dim=hidden_dim,
                            activation=activation,
                            activation_agg=activation)

     # Links assigned inside init_scope are registered as children of
     # this chain.
     with self.init_scope():
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         self.update_layers = chainer.ChainList(
             *[make_update() for _ in range(num_updates)])
         self.embed_super = links.Linear(in_size=n_super_feature,
                                         out_size=hidden_dim_super)
         # The GWM gets the number of update layers actually built (1 under
         # weight tying), not n_layers. gpu=-1 is hard-coded; by Chainer
         # convention this means CPU. Confirm against GWM.
         self.gwm = GWM(hidden_dim=hidden_dim,
                        hidden_dim_super=hidden_dim_super,
                        n_layers=num_updates,
                        n_heads=n_heads,
                        dropout_ratio=dropout_ratio,
                        tying_flag=weight_tying,
                        gpu=-1)
         self.readout_layers = chainer.ChainList(
             *[make_readout() for _ in range(num_readouts)])
         # in_size=None lets chainer infer the input width on first use.
         self.linear_for_concat_super = links.Linear(in_size=None,
                                                     out_size=out_dim)

     # Record the configuration on the instance.
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.hidden_dim_super = hidden_dim_super
     self.n_layers = n_layers
     self.n_heads = n_heads
     self.dropout_ratio = dropout_ratio
     self.num_edge_type = num_edge_type
     self.activation = activation
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
Exemplo n.º 4
0
def update():
    """Construct and return a new ``GGNNUpdate`` layer.

    ``hidden_dim`` and ``num_edge_type`` are free names. They are looked up
    in the enclosing scope at call time (presumably module-level test
    settings, which are not visible here).
    """
    layer = GGNNUpdate(
        hidden_dim=hidden_dim,
        num_edge_type=num_edge_type,
    )
    return layer
Exemplo n.º 5
0
def update():
    """Construct and return a new ``GGNNUpdate`` layer.

    ``in_channels``, ``hidden_channels`` and ``n_edge_types`` are free
    names. They are looked up in the enclosing scope at call time
    (presumably module-level test settings, which are not visible here).
    """
    layer = GGNNUpdate(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        n_edge_types=n_edge_types,
    )
    return layer