Example #1
 def __init__(self,
              out_dim,
              hidden_dim=16,
              n_layers=4,
              n_atom_types=MAX_ATOMIC_NUM,
              concat_hidden=False,
              weight_tying=True):
     super(GGNN, self).__init__()
     n_readout_layer = n_layers if concat_hidden else 1
     n_message_layer = 1 if weight_tying else n_layers
     with self.init_scope():
         # Update
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         self.message_layers = chainer.ChainList(*[
             GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
             for _ in range(n_message_layer)
         ])
         self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
         # Readout
         self.i_layers = chainer.ChainList(*[
             GraphLinear(2 * hidden_dim, out_dim)
             for _ in range(n_readout_layer)
         ])
         self.j_layers = chainer.ChainList(*[
             GraphLinear(hidden_dim, out_dim)
             for _ in range(n_readout_layer)
         ])
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.n_layers = n_layers
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
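For reference, a minimal sketch of what the embedding configured above does, mirroring the data() helpers shown in the later examples: EmbedAtomID maps an integer atom-ID array of shape (batch_size, atom_size) to embeddings of shape (batch_size, atom_size, hidden_dim). The import paths and sizes below are illustrative assumptions and may differ between chainer-chemistry versions.

import numpy

# Assumed import paths; adjust to the installed chainer-chemistry version.
from chainer_chemistry.config import MAX_ATOMIC_NUM
from chainer_chemistry.links import EmbedAtomID

batch_size, atom_size, hidden_dim = 3, 5, 16

# Integer atom IDs in [0, MAX_ATOMIC_NUM), as in the data() helpers below.
atom_data = numpy.random.randint(
    0, high=MAX_ATOMIC_NUM, size=(batch_size, atom_size)).astype('i')

embed = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_dim)
h = embed(atom_data)  # chainer.Variable
print(h.shape)        # (3, 5, 16) == (batch_size, atom_size, hidden_dim)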
Example #2
 def __init__(self,
              out_dim,
              hidden_channels=16,
              n_update_layers=4,
              max_degree=6,
              n_atom_types=MAX_ATOMIC_NUM,
              concat_hidden=False):
     super(NFP, self).__init__()
     n_degree_types = max_degree + 1
     with self.init_scope():
         self.embed = EmbedAtomID(in_size=n_atom_types,
                                  out_size=hidden_channels)
         self.layers = chainer.ChainList(*[
             NFPUpdate(
                 hidden_channels, hidden_channels, max_degree=max_degree)
             for _ in range(n_update_layers)
         ])
         self.readout_layers = chainer.ChainList(*[
             NFPReadout(out_dim=out_dim, in_channels=hidden_channels)
             for _ in range(n_update_layers)
         ])
     self.out_dim = out_dim
     self.hidden_channels = hidden_channels
     self.max_degree = max_degree
     self.n_degree_types = n_degree_types
     self.n_update_layers = n_update_layers
     self.concat_hidden = concat_hidden
Example #3
 def __init__(self, out_dim, hidden_dim=16, n_layers=4,
              n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
              weight_tying=True, activation=functions.identity,
              num_edge_type=4):
     super(GGNN, self).__init__()
     n_readout_layer = n_layers if concat_hidden else 1
     n_message_layer = 1 if weight_tying else n_layers
     with self.init_scope():
         # Update
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         self.update_layers = chainer.ChainList(*[GGNNUpdate(
             hidden_dim=hidden_dim, num_edge_type=num_edge_type)
             for _ in range(n_message_layer)])
         # Readout
         self.readout_layers = chainer.ChainList(*[GGNNReadout(
             out_dim=out_dim, hidden_dim=hidden_dim,
             activation=activation, activation_agg=activation)
             for _ in range(n_readout_layer)])
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.n_layers = n_layers
     self.num_edge_type = num_edge_type
     self.activation = activation
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
Example #4
    def __init__(
            self,
            out_dim,  # type: int
            hidden_channels=16,  # type: int
            n_update_layers=4,  # type: int
            n_atom_types=MAX_ATOMIC_NUM,  # type: int
            concat_hidden=False,  # type: bool
            weight_tying=True,  # type: bool
            n_edge_types=4,  # type: int
            nn=None,  # type: Optional[chainer.Link]
            message_func='edgenet',  # type: str
            readout_func='set2set',  # type: str
    ):
        # type: (...) -> None
        super(MPNN, self).__init__()
        if message_func not in ('edgenet', 'ggnn'):
            raise ValueError(
                'Invalid message function: {}'.format(message_func))
        if readout_func not in ('set2set', 'ggnn'):
            raise ValueError(
                'Invalid readout function: {}'.format(readout_func))
        n_readout_layer = n_update_layers if concat_hidden else 1
        n_message_layer = 1 if weight_tying else n_update_layers
        with self.init_scope():
            # Update
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            if message_func == 'ggnn':
                self.update_layers = chainer.ChainList(*[
                    GGNNUpdate(hidden_channels=hidden_channels,
                               n_edge_types=n_edge_types)
                    for _ in range(n_message_layer)
                ])
            else:
                self.update_layers = chainer.ChainList(*[
                    MPNNUpdate(hidden_channels=hidden_channels, nn=nn)
                    for _ in range(n_message_layer)
                ])

            # Readout
            if readout_func == 'ggnn':
                self.readout_layers = chainer.ChainList(*[
                    GGNNReadout(out_dim=out_dim,
                                in_channels=hidden_channels * 2)
                    for _ in range(n_readout_layer)
                ])
            else:
                self.readout_layers = chainer.ChainList(*[
                    MPNNReadout(out_dim=out_dim,
                                in_channels=hidden_channels,
                                n_layers=1) for _ in range(n_readout_layer)
                ])
        self.out_dim = out_dim
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.n_edge_types = n_edge_types
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
        self.message_func = message_func
        self.readout_func = readout_func
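As a quick illustration of the two switches validated above, hypothetical instantiations (out_dim and the import path are placeholders, not part of the original snippet):

# Assumed import path; it may differ between chainer-chemistry versions.
from chainer_chemistry.models import MPNN

mpnn_default = MPNN(out_dim=8)  # 'edgenet' messages + 'set2set' readout (defaults)
mpnn_ggnn = MPNN(out_dim=8, message_func='ggnn', readout_func='ggnn')
# Any other value triggers the checks in the constructor, e.g.:
# MPNN(out_dim=8, message_func='foo')  # -> ValueError: Invalid message function: foo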
Example #5
    def __init__(self, out_dim, hidden_dim=16,
                 n_layers=4, n_atom_types=MAX_ATOMIC_NUM,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=True,
                 activation=functions.identity):
        super(GIN, self).__init__()

        n_message_layer = 1 if weight_tying else n_layers
        n_readout_layer = n_layers if concat_hidden else 1
        with self.init_scope():
            # embedding
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            # two non-linear MLP part
            self.update_layers = chainer.ChainList(*[GINUpdate(
                hidden_dim=hidden_dim, dropout_ratio=dropout_ratio)
                for _ in range(n_message_layer)])

            # Readout
            self.readout_layers = chainer.ChainList(*[GGNNReadout(
                out_dim=out_dim, hidden_dim=hidden_dim,
                activation=activation, activation_agg=activation)
                for _ in range(n_readout_layer)])
        # end with

        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
Example #6
    def __init__(self,
                 out_dim,
                 hidden_channels=16,
                 n_update_layers=4,
                 n_atom_types=MAX_ATOMIC_NUM,
                 concat_hidden=False,
                 dropout_ratio=-1.,
                 weight_tying=False,
                 activation=functions.identity,
                 n_edge_types=4,
                 n_heads=3,
                 negative_slope=0.2,
                 softmax_mode='across',
                 concat_heads=False):
        super(RelGAT, self).__init__()
        n_readout_layer = n_update_layers if concat_hidden else 1
        n_message_layer = n_update_layers
        with self.init_scope():
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            update_layers = []
            for i in range(n_message_layer):
                if i > 0 and concat_heads:
                    input_dim = hidden_channels * n_heads
                else:
                    input_dim = hidden_channels
                update_layers.append(
                    RelGATUpdate(input_dim,
                                 hidden_channels,
                                 n_heads=n_heads,
                                 n_edge_types=n_edge_types,
                                 dropout_ratio=dropout_ratio,
                                 negative_slope=negative_slope,
                                 softmax_mode=softmax_mode,
                                 concat_heads=concat_heads))
            self.update_layers = chainer.ChainList(*update_layers)
            if concat_heads:
                in_channels = hidden_channels * (n_heads + 1)
            else:
                in_channels = hidden_channels * 2
            self.readout_layers = chainer.ChainList(*[
                GGNNReadout(out_dim=out_dim,
                            in_channels=in_channels,
                            activation=activation,
                            activation_agg=activation)
                for _ in range(n_readout_layer)
            ])

        self.out_dim = out_dim
        self.n_heads = n_heads
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.concat_hidden = concat_hidden
        self.concat_heads = concat_heads
        self.weight_tying = weight_tying
        self.negative_slope = negative_slope
        self.n_edge_types = n_edge_types
        self.dropout_ratio = dropout_ratio
Example #7
    def __init__(self,
                 out_dim,
                 hidden_dim=16,
                 hidden_dim_super=16,
                 n_layers=4,
                 n_heads=8,
                 n_atom_types=MAX_ATOMIC_NUM,
                 n_super_feature=2 + 2 + MAX_ATOMIC_NUM * 2,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=True,
                 activation=functions.identity):
        super(GIN_GWM, self).__init__()

        n_message_layer = 1 if weight_tying else n_layers
        n_readout_layer = n_layers if concat_hidden else 1
        with self.init_scope():
            # embedding
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            # two non-linear MLP part
            self.update_layers = chainer.ChainList(*[
                GINUpdate(hidden_dim=hidden_dim, dropout_ratio=dropout_ratio)
                for _ in range(n_message_layer)
            ])

            # GWM
            self.embed_super = links.Linear(in_size=n_super_feature,
                                            out_size=hidden_dim_super)
            self.gwm = GWM(hidden_dim=hidden_dim,
                           hidden_dim_super=hidden_dim_super,
                           n_layers=n_message_layer,
                           n_heads=n_heads,
                           dropout_ratio=dropout_ratio,
                           tying_flag=weight_tying,
                           gpu=-1)

            # Readout
            self.readout_layers = chainer.ChainList(*[
                GINReadout(out_dim=out_dim,
                           hidden_dim=hidden_dim,
                           activation=activation,
                           activation_agg=activation)
                for _ in range(n_readout_layer)
            ])
            self.linear_for_concat_super = links.Linear(in_size=None,
                                                        out_size=out_dim)
        # end with

        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.hidden_dim_super = hidden_dim_super
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
Example #8
    def __init__(
        self,
        out_dim,
        hidden_dim=16,
        n_layers=4,
        n_atom_types=MAX_ATOMIC_NUM,
        concat_hidden=False,
        dropout_rate=0.0,
        layer_aggr=None,
        batch_normalization=False,
        weight_tying=True,
        update_tying=True,
    ):
        super(GGNN, self).__init__()
        n_readout_layer = n_layers if concat_hidden else 1
        n_message_layer = 1 if weight_tying else n_layers
        n_update_layer = 1 if update_tying else n_layers
        self.n_readout_layer = n_readout_layer
        self.n_message_layer = n_message_layer
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.concat_hidden = concat_hidden
        self.dropout_rate = dropout_rate
        self.batch_normalization = batch_normalization
        self.weight_tying = weight_tying
        self.update_tying = update_tying
        self.layer_aggr = layer_aggr

        with self.init_scope():
            # Update
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            self.message_layers = chainer.ChainList(*[
                GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
                for _ in range(n_message_layer)
            ])

            self.update_layer = chainer.ChainList(*[
                links.Linear(2 * hidden_dim, hidden_dim)
                for _ in range(n_update_layer)
            ])
            # self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)

            # Layer Aggregation
            self.aggr = select_aggr(layer_aggr, 1, hidden_dim, hidden_dim)

            # Readout
            self.i_layers = chainer.ChainList(*[
                GraphLinear(2 * hidden_dim, out_dim)
                for _ in range(n_readout_layer)
            ])
            self.j_layers = chainer.ChainList(*[
                GraphLinear(hidden_dim, out_dim)
                for _ in range(n_readout_layer)
            ])
Example #9
 def __init__(self,
              word_size,
              num_atom_type=MAX_ATOMIC_NUM,
              id_trans_fn=None):
     super(AtomEmbed, self).__init__()
     with self.init_scope():
         self.embed = EmbedAtomID(out_size=word_size, in_size=num_atom_type)
     self.word_size = word_size
     self.num_atom_type = num_atom_type
     self.id_trans_fn = id_trans_fn
Example #10
    def __init__(self,
                 out_dim,
                 node_embedding=False,
                 hidden_channels=16,
                 out_channels=None,
                 n_update_layers=4,
                 n_atom_types=MAX_ATOMIC_NUM,
                 dropout_ratio=0.5,
                 concat_hidden=False,
                 weight_tying=False,
                 activation=functions.identity,
                 n_edge_types=4):
        super(GIN, self).__init__()
        n_message_layer = 1 if weight_tying else n_update_layers
        n_readout_layer = n_update_layers if concat_hidden else 1
        with self.init_scope():
            # embedding
            self.embed = EmbedAtomID(out_size=hidden_channels,
                                     in_size=n_atom_types)
            self.first_mlp = GINUpdate(hidden_channels=hidden_channels,
                                       dropout_ratio=dropout_ratio,
                                       out_channels=hidden_channels).graph_mlp

            # two non-linear MLP part
            if out_channels is None:
                out_channels = hidden_channels
            self.update_layers = chainer.ChainList(*[
                GINUpdate(hidden_channels=hidden_channels,
                          dropout_ratio=dropout_ratio,
                          out_channels=(out_channels if i == n_message_layer -
                                        1 else hidden_channels))
                for i in range(n_message_layer)
            ])

            # Readout
            self.readout_layers = chainer.ChainList(*[
                GGNNReadout(out_dim=out_dim,
                            in_channels=hidden_channels * 2,
                            activation=activation,
                            activation_agg=activation)
                for _ in range(n_readout_layer)
            ])
        # end with

        self.node_embedding = node_embedding
        self.out_dim = out_dim
        self.hidden_channels = hidden_channels
        self.n_update_layers = n_update_layers
        self.n_message_layers = n_message_layer
        self.n_readout_layer = n_readout_layer
        self.dropout_ratio = dropout_ratio
        self.concat_hidden = concat_hidden
        self.weight_tying = weight_tying
        self.n_edge_types = n_edge_types
Example #11
    def __init__(
        self,
        out_dim,
        hidden_dim=16,
        n_layers=4,
        n_atom_types=MAX_ATOMIC_NUM,
        concat_hidden=False,
        dropout_rate=0.0,
        batch_normalization=False,
        weight_tying=True,
        output_atoms=True,
    ):
        super(GGNN, self).__init__()
        n_readout_layer = n_layers if concat_hidden else 1
        n_message_layer = 1 if weight_tying else n_layers
        self.n_readout_layer = n_readout_layer
        self.n_message_layer = n_message_layer
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.concat_hidden = concat_hidden
        self.dropout_rate = dropout_rate
        self.batch_normalization = batch_normalization
        self.weight_tying = weight_tying
        self.output_atoms = output_atoms

        with self.init_scope():
            # Update
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            self.message_layers = chainer.ChainList(*[
                GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
                for _ in range(n_message_layer)
            ])

            self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
            # Readout
            self.i_layers = chainer.ChainList(*[
                GraphLinear(2 * hidden_dim, out_dim)
                for _ in range(n_readout_layer)
            ])
            self.j_layers = chainer.ChainList(*[
                GraphLinear(hidden_dim, out_dim)
                for _ in range(n_readout_layer)
            ])

        if self.output_atoms:
            self.atoms_list = []

        self.g_vec_list = []
Example #12
 def __init__(self, out_dim=1, hidden_dim=64, n_layers=3,
              readout_hidden_dim=32, n_atom_types=MAX_ATOMIC_NUM,
              concat_hidden=False):
     super(SchNet, self).__init__()
     with self.init_scope():
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         self.update_layers = chainer.ChainList(
             *[SchNetUpdate(hidden_dim) for _ in range(n_layers)])
         self.readout_layer = SchNetReadout(out_dim, readout_hidden_dim)
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.readout_hidden_dim = readout_hidden_dim
     self.n_layers = n_layers
     self.concat_hidden = concat_hidden
Example #13
def data():
    numpy.random.seed(0)
    atom_data = numpy.random.randint(0,
                                     high=MAX_ATOMIC_NUM,
                                     size=(batch_size, atom_size)).astype('i')
    adj_data = numpy.random.uniform(0,
                                    high=2,
                                    size=(batch_size, num_edge_type, atom_size,
                                          atom_size)).astype('f')
    y_grad = numpy.random.uniform(
        -1, 1, (batch_size, atom_size, hidden_dim)).astype('f')

    embed = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_dim)
    embed_atom_data = embed(atom_data).data
    return embed_atom_data, adj_data, y_grad
Example #14
def data():
    numpy.random.seed(0)
    atom_data = numpy.random.randint(0,
                                     high=MAX_ATOMIC_NUM,
                                     size=(batch_size, atom_size)).astype('i')
    # symmetric matrix
    dist_data = numpy.random.uniform(0,
                                     high=30,
                                     size=(batch_size, atom_size,
                                           atom_size)).astype('f')
    dist_data = (dist_data + dist_data.swapaxes(-1, -2)) / 2.

    y_grad = numpy.random.uniform(
        -1, 1, (batch_size, atom_size, hidden_dim)).astype('f')
    embed = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_dim)
    embed_atom_data = embed(atom_data).data
    return embed_atom_data, dist_data, y_grad
Example #15
    def __init__(self,
                 weave_channels=None,
                 hidden_dim=16,
                 n_atom=WEAVE_DEFAULT_NUM_MAX_ATOMS,
                 n_sub_layer=1,
                 n_atom_types=MAX_ATOMIC_NUM,
                 readout_mode='sum'):
        weave_channels = weave_channels or WEAVENET_DEFAULT_WEAVE_CHANNELS
        weave_module = [
            WeaveModule(n_atom, c, n_sub_layer, readout_mode=readout_mode)
            for c in weave_channels
        ]

        super(WeaveNet, self).__init__()
        with self.init_scope():
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
            self.weave_module = chainer.ChainList(*weave_module)
            self.readout = GeneralReadout(mode=readout_mode)
        self.readout_mode = readout_mode
Example #16
 def __init__(self, out_dim=1, hidden_channels=64, n_update_layers=3,
              readout_hidden_dim=32, n_atom_types=MAX_ATOMIC_NUM,
              concat_hidden=False, num_rbf=300, radius_resolution=0.1,
              gamma=10.0):
     super(SchNet, self).__init__()
     with self.init_scope():
         self.embed = EmbedAtomID(out_size=hidden_channels,
                                  in_size=n_atom_types)
         self.update_layers = chainer.ChainList(
             *[SchNetUpdate(
                 hidden_channels,
                 num_rbf=num_rbf, radius_resolution=radius_resolution,
                 gamma=gamma) for _ in range(n_update_layers)])
         self.readout_layer = SchNetReadout(
             out_dim, in_channels=None, hidden_channels=readout_hidden_dim)
     self.out_dim = out_dim
     self.hidden_channels = hidden_channels
     self.readout_hidden_dim = readout_hidden_dim
     self.n_update_layers = n_update_layers
     self.concat_hidden = concat_hidden
Example #17
def data():
    numpy.random.seed(0)
    atom_data = numpy.random.randint(0,
                                     high=MAX_ATOMIC_NUM,
                                     size=(batch_size, atom_size)).astype('i')
    adj_data = numpy.random.randint(0,
                                    high=2,
                                    size=(batch_size, atom_size,
                                          atom_size)).astype('f')
    y_grad = numpy.random.uniform(
        -1, 1, (batch_size, atom_size, hidden_channels)).astype('f')

    embed = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_channels)
    embed_atom_data = embed(atom_data).data
    degree_mat = numpy.sum(adj_data, axis=1)
    deg_conds = numpy.array([
        numpy.broadcast_to(((degree_mat - degree) == 0)[:, :, None],
                           embed_atom_data.shape)
        for degree in range(1, num_degree_type + 1)
    ])
    return embed_atom_data, adj_data, deg_conds, y_grad
Example #18
    def __init__(self, out_channels=64, num_edge_type=4, ch_list=None,
                 n_atom_types=MAX_ATOMIC_NUM, input_type='int',
                 scale_adj=False, activation=F.tanh):

        super(RelGCN, self).__init__()
        ch_list = ch_list or [16, 128, 64]
        # ch_list = [in_channels] + ch_list

        with self.init_scope():
            if input_type == 'int':
                self.embed = EmbedAtomID(out_size=ch_list[0], in_size=n_atom_types)
            elif input_type == 'float':
                self.embed = GraphLinear(None, ch_list[0])
            else:
                raise ValueError("[ERROR] Unexpected value input_type={}".format(input_type))
            self.rgcn_convs = chainer.ChainList(*[
                RelGCNUpdate(ch_list[i], ch_list[i + 1], num_edge_type)
                for i in range(len(ch_list) - 1)])
            self.rgcn_readout = RelGCNReadout(ch_list[-1], out_channels)
        # self.num_relations = num_edge_type
        self.input_type = input_type
        self.scale_adj = scale_adj
        self.activation = activation
Example #19
 def __init__(self, out_dim, hidden_dim=16,
              n_layers=4, n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
              weight_tying=True):
     super(GGNN, self).__init__()
     n_readout_layer = n_layers if concat_hidden else 1
     n_message_layer = 1 if weight_tying else n_layers
     with self.init_scope():
         # Update
         self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
         self.update_layer = GGNNUpdate(
             hidden_dim=hidden_dim, n_layers=n_message_layer,
             n_atom_types=self.NUM_EDGE_TYPE, weight_tying=weight_tying)
         # Readout
         self.readout_layer = GGNNReadout(
             out_dim=out_dim, hidden_dim=hidden_dim,
             n_layers=n_readout_layer, concat_hidden=concat_hidden,
             activation=functions.identity)
     self.out_dim = out_dim
     self.hidden_dim = hidden_dim
     self.n_layers = n_layers
     self.concat_hidden = concat_hidden
     self.weight_tying = weight_tying
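A rough end-to-end usage sketch for a GGNN built as above. The input format follows the data() helper in Example #13 (an integer atom-ID array plus an edge-typed adjacency tensor of shape (batch_size, n_edge_types, atom_size, atom_size)); the import paths and the forward call signature model(atom_array, adj) are assumptions and may differ between chainer-chemistry versions.

import numpy
import chainer

# Assumed import paths; adjust to the installed chainer-chemistry version.
from chainer_chemistry.config import MAX_ATOMIC_NUM
from chainer_chemistry.models import GGNN

batch_size, atom_size, n_edge_types = 3, 5, 4

atom_data = numpy.random.randint(
    0, high=MAX_ATOMIC_NUM, size=(batch_size, atom_size)).astype('i')
adj_data = numpy.random.uniform(
    0, high=2,
    size=(batch_size, n_edge_types, atom_size, atom_size)).astype('f')

model = GGNN(out_dim=4)
with chainer.using_config('train', False):
    y = model(atom_data, adj_data)  # assumed graph-level forward pass
print(y.shape)  # expected (batch_size, out_dim) == (3, 4)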
Example #20
 def __init__(self,
              out_dim=64,
              hidden_channels=None,
              n_update_layers=None,
              n_atom_types=MAX_ATOMIC_NUM,
              n_edge_types=4,
              input_type='int',
              scale_adj=False):
     super(RelGCNSparse, self).__init__()
     if hidden_channels is None:
         hidden_channels = [16, 128, 64]
     elif isinstance(hidden_channels, int):
         if not isinstance(n_update_layers, int):
             raise ValueError(
                 'Must specify n_update_layers when hidden_channels is int')
         hidden_channels = [hidden_channels] * n_update_layers
     with self.init_scope():
         if input_type == 'int':
             self.embed = EmbedAtomID(out_size=hidden_channels[0],
                                      in_size=n_atom_types)
         elif input_type == 'float':
             self.embed = Linear(None, hidden_channels[0])
         else:
             raise ValueError(
                 "[ERROR] Unexpected value input_type={}".format(
                     input_type))
         self.rgcn_convs = chainer.ChainList(*[
             RelGCNSparseUpdate(hidden_channels[i], hidden_channels[i + 1],
                                n_edge_types)
             for i in range(len(hidden_channels) - 1)
         ])
         self.rgcn_readout = ScatterGGNNReadout(
             out_dim=out_dim,
             in_channels=hidden_channels[-1],
             nobias=True,
             activation=functions.tanh)
     # self.num_relations = num_edge_type
     self.input_type = input_type
     self.scale_adj = scale_adj
Example #21
atom_size = 5
out_dim = 4
batch_size = 3
heads = 2
hidden_dim = 16

atom_data = numpy.random.randint(0, high=110,
                                 size=(batch_size,
                                       atom_size)).astype(numpy.int32)
adj_data = numpy.random.randint(0,
                                high=2,
                                size=(batch_size, atom_size,
                                      atom_size)).astype(numpy.float32)

embed = EmbedAtomID(out_size=hidden_dim, in_size=110)
weight = GraphLinear(hidden_dim, heads * hidden_dim)
att_weight = GraphLinear(hidden_dim * 2, 1)


def test(atom_array, adj_data):
    # Embed integer atom IDs: (mb, atom) -> (mb, atom, ch)
    x = embed(atom_array)
    mb, atom, ch = x.shape
    print(x.shape)
    # Linear projection to multi-head features: (mb, atom, heads * ch)
    test = weight(x)
    print(test.shape)
    # Add an axis and broadcast so the projected features are laid out
    # for every atom pair: (mb, atom, atom, heads * ch)
    x = functions.expand_dims(test, axis=1)
    print(x.shape)
    x = functions.broadcast_to(x, (mb, atom, atom, heads * ch))
    print(x.shape)
    # Copy the result to the CPU (device -1)
    y = functions.copy(x, -1)
Example #22
    def __init__(
        self,
        out_dim,
        hidden_dim=16,
        n_layers=4,
        n_atom_types=MAX_ATOMIC_NUM,
        concat_hidden=False,
        layer_aggregator=None,
        dropout_rate=0.0,
        batch_normalization=False,
        weight_tying=True,
        use_attention=False,
        update_attention=False,
        attention_tying=True,
        context=False,
        context_layers=1,
        context_dropout=0.,
        message_function='matrix_multiply',
        edge_hidden_dim=16,
        readout_function='graph_level',
        num_timesteps=3,
        num_output_hidden_layers=0,
        output_hidden_dim=16,
        output_activation=functions.relu,
        output_atoms=False,
    ):
        super(GGNN, self).__init__()
        n_readout_layer = n_layers if concat_hidden else 1
        n_message_layer = 1 if weight_tying else n_layers
        n_attention_layer = 1 if attention_tying else n_layers
        self.n_readout_layer = n_readout_layer
        self.n_message_layer = n_message_layer
        self.n_attention_layer = n_attention_layer
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.concat_hidden = concat_hidden
        self.layer_aggregator = layer_aggregator
        self.dropout_rate = dropout_rate
        self.batch_normalization = batch_normalization
        self.weight_tying = weight_tying
        self.use_attention = use_attention
        self.update_attention = update_attention
        self.attention_tying = attention_tying
        self.context = context
        self.context_layers = context_layers
        self.context_dropout = context_dropout
        self.message_function = message_function
        self.edge_hidden_dim = edge_hidden_dim
        self.readout_function = readout_function
        self.num_timesteps = num_timesteps
        self.num_output_hidden_layers = num_output_hidden_layers
        self.output_hidden_dim = output_hidden_dim
        self.output_activation = output_activation
        self.output_atoms = output_atoms

        with self.init_scope():
            # Update
            self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

            self.message_layers = chainer.ChainList(*[
                GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
                for _ in range(n_message_layer)
            ])

            if self.message_function == 'edge_network':
                del self.message_layers
                self.message_layers = chainer.ChainList(*[
                    EdgeNetwork(in_dim=self.NUM_EDGE_TYPE,
                                hidden_dim=self.edge_hidden_dim,
                                node_dim=self.hidden_dim)
                    for _ in range(n_message_layer)
                ])

            if self.context:
                self.context_bilstm = links.NStepBiLSTM(
                    n_layers=self.context_layers,
                    in_size=self.hidden_dim,
                    out_size=self.hidden_dim // 2,
                    dropout=context_dropout)

            # self-attention layer
            if use_attention or update_attention:
                # these commented layers are written for GAT implemented in TensorFlow.
                # self.linear_transform_layer = chainer.ChainList(
                #     *[links.ConvolutionND(1, in_channels=hidden_dim, out_channels=hidden_dim, ksize=1, nobias=True)
                #         for _ in range(n_attention_layer)]
                # )
                # self.conv1d_layer_1 = chainer.ChainList(
                #     *[links.ConvolutionND(1, in_channels=hidden_dim, out_channels=1, ksize=1)
                #         for _ in range(n_attention_layer)]
                # )
                # self.conv1d_layer_2 = chainer.ChainList(
                #     *[links.ConvolutionND(1, in_channels=hidden_dim, out_channels=1, ksize=1)
                #       for _ in range(n_attention_layer)]
                # )
                self.linear_transform_layer = chainer.ChainList(*[
                    links.Linear(
                        in_size=hidden_dim, out_size=hidden_dim, nobias=True)
                    for _ in range(n_attention_layer)
                ])
                self.neural_network_layer = chainer.ChainList(*[
                    links.Linear(
                        in_size=2 * self.hidden_dim, out_size=1, nobias=True)
                    for _ in range(n_attention_layer)
                ])

            # batch normalization
            if batch_normalization:
                self.batch_normalization_layer = links.BatchNormalization(
                    size=hidden_dim)

            self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
            # Readout
            self.i_layers = chainer.ChainList(*[
                GraphLinear(2 * hidden_dim, out_dim)
                for _ in range(n_readout_layer)
            ])
            self.j_layers = chainer.ChainList(*[
                GraphLinear(hidden_dim, out_dim)
                for _ in range(n_readout_layer)
            ])

            if self.readout_function == 'set2vec':
                del self.i_layers, self.j_layers
                # def __init__(self, node_dim, output_dim, num_timesteps=3, inner_prod='default',
                #   num_output_hidden_layers=0, output_hidden_dim=16, activation=chainer.functions.relu):
                self.readout_layer = chainer.ChainList(*[
                    Set2Vec(node_dim=self.hidden_dim * 2,
                            output_dim=out_dim,
                            num_timesteps=num_timesteps,
                            num_output_hidden_layers=num_output_hidden_layers,
                            output_hidden_dim=output_hidden_dim,
                            activation=output_activation)
                    for _ in range(n_readout_layer)
                ])

            if self.layer_aggregator:
                self.construct_layer_aggregator()

                if self.layer_aggregator in ('gru-attn', 'gru'):
                    self.bigru_layer = links.NStepBiGRU(
                        n_layers=1,
                        in_size=self.hidden_dim,
                        out_size=self.hidden_dim,
                        dropout=0.)
                if self.layer_aggregator in ('lstm-attn', 'lstm'):
                    self.bilstm_layer = links.NStepBiLSTM(
                        n_layers=1,
                        in_size=self.hidden_dim,
                        out_size=self.hidden_dim,
                        dropout=0.)
                if self.layer_aggregator in ('gru-attn', 'lstm-attn', 'attn'):
                    self.attn_dense_layer = links.Linear(
                        in_size=self.n_layers, out_size=self.n_layers)
                if self.layer_aggregator == 'self-attn':
                    self.attn_linear_layer = links.Linear(
                        in_size=self.n_layers, out_size=self.n_layers)

        if self.output_atoms:
            self.atoms = None