Пример #1
0
    def __init__(self, out_channels=64, num_edge_type=4, ch_list=None,
                 n_atom_types=MAX_ATOMIC_NUM, input_type='int',
                 scale_adj=False):
        """Build the RelGCN layers.

        Three sub-links are registered, in this order:
          1. ``embed``: maps the input to ``ch_list[0]`` channels
             (``EmbedAtomID`` when ``input_type == 'int'``,
             ``GraphLinear`` with a lazily-inferred input size when
             ``input_type == 'float'``);
          2. ``rgcn_convs``: ``len(ch_list) - 1`` ``RelGCNUpdate`` layers,
             the i-th mapping ``ch_list[i]`` -> ``ch_list[i + 1]``
             channels over ``num_edge_type`` edge types;
          3. ``rgcn_readout``: a ``GGNNReadout`` producing ``out_channels``
             features from ``ch_list[-1]`` hidden channels, without bias
             and with ``tanh`` activation.

        Args:
            out_channels (int): Output dimension of the readout.
            num_edge_type (int): Number of edge types given to every
                ``RelGCNUpdate`` layer.
            ch_list (sequence of int or None): Channel sizes, see above.
                ``None`` means ``[16, 128, 64]``. A single-element list
                yields no update layers.
            n_atom_types (int): ``in_size`` of ``EmbedAtomID``; only used
                when ``input_type == 'int'``.
            input_type (str): ``'int'`` or ``'float'``.
            scale_adj (bool): Stored as ``self.scale_adj``; not used by
                the constructor itself (presumably read by the forward
                pass — confirm).

        Raises:
            ValueError: If ``input_type`` is not ``'int'``/``'float'`` or
                ``ch_list`` is empty.
        """
        super(RelGCN, self).__init__()
        if ch_list is None:
            ch_list = [16, 128, 64]
        # An empty list would otherwise fail later with an opaque IndexError
        # at ch_list[0]; len() is used so array-like inputs also work.
        if len(ch_list) == 0:
            raise ValueError('[ERROR] ch_list must contain at least one '
                             'channel size')
        # Validate before registering any link so no half-built model exists.
        if input_type not in ('int', 'float'):
            raise ValueError("[ERROR] Unexpected value input_type={}"
                             .format(input_type))
        with self.init_scope():
            if input_type == 'int':
                self.embed = EmbedAtomID(out_size=ch_list[0],
                                         in_size=n_atom_types)
            else:
                # None lets GraphLinear infer the input size on first call.
                self.embed = GraphLinear(None, ch_list[0])
            # One update layer per consecutive pair of channel sizes.
            self.rgcn_convs = chainer.ChainList(*[
                RelGCNUpdate(ch_list[i], ch_list[i + 1], num_edge_type)
                for i in range(len(ch_list) - 1)])
            self.rgcn_readout = GGNNReadout(
                out_dim=out_channels, hidden_dim=ch_list[-1],
                nobias=True, activation=functions.tanh)
        self.input_type = input_type
        self.scale_adj = scale_adj
Пример #2
0
 def __init__(self,
              out_dim=64,
              hidden_channels=None,
              n_update_layers=None,
              n_atom_types=MAX_ATOMIC_NUM,
              n_edge_types=4,
              input_type='int',
              scale_adj=False):
     """Build the RelGCN layers.

     Three sub-links are registered, in this order:
       1. ``embed``: maps the input to ``hidden_channels[0]`` channels
          (``EmbedAtomID`` when ``input_type == 'int'``, ``GraphLinear``
          with a lazily-inferred input size when ``input_type == 'float'``);
       2. ``rgcn_convs``: ``len(hidden_channels) - 1`` ``RelGCNUpdate``
          layers, the i-th mapping ``hidden_channels[i]`` ->
          ``hidden_channels[i + 1]`` channels over ``n_edge_types`` edge
          types;
       3. ``rgcn_readout``: a ``GGNNReadout`` producing ``out_dim``
          features from ``hidden_channels[-1]`` input channels, without
          bias and with ``tanh`` activation.

     Args:
         out_dim (int): Output dimension of the readout.
         hidden_channels (int, sequence of int, or None): Channel sizes.
             ``None`` means ``[16, 128, 64]``. An int is repeated
             ``n_update_layers`` times.
         n_update_layers (int or None): Required when ``hidden_channels``
             is an int and ignored otherwise. NOTE(review): the repeated
             list has ``n_update_layers`` entries, which yields
             ``n_update_layers - 1`` ``RelGCNUpdate`` layers. Kept as-is
             for backward compatibility; confirm this is intended.
         n_atom_types (int): ``in_size`` of ``EmbedAtomID``; only used when
             ``input_type == 'int'``.
         n_edge_types (int): Number of edge types given to every
             ``RelGCNUpdate`` layer.
         input_type (str): ``'int'`` or ``'float'``.
         scale_adj (bool): Stored as ``self.scale_adj``; not used by the
             constructor itself (presumably read by the forward pass —
             confirm).

     Raises:
         ValueError: If ``hidden_channels`` is an int and
             ``n_update_layers`` is not a positive int, if
             ``hidden_channels`` is empty, or if ``input_type`` is not
             ``'int'``/``'float'``.
     """
     super(RelGCN, self).__init__()
     if hidden_channels is None:
         hidden_channels = [16, 128, 64]
     elif isinstance(hidden_channels, int):
         if not isinstance(n_update_layers, int):
             raise ValueError(
                 'Must specify n_update_layers when hidden_channels is int')
         # [h] * 0 (or a negative count) is [], which would later fail with
         # an opaque IndexError at hidden_channels[0].
         if n_update_layers < 1:
             raise ValueError(
                 'n_update_layers must be a positive int, got {}'.format(
                     n_update_layers))
         hidden_channels = [hidden_channels] * n_update_layers
     if len(hidden_channels) == 0:
         raise ValueError('hidden_channels must not be empty')
     # Validate before registering any link so no half-built model exists.
     if input_type not in ('int', 'float'):
         raise ValueError(
             "[ERROR] Unexpected value input_type={}".format(
                 input_type))
     with self.init_scope():
         if input_type == 'int':
             self.embed = EmbedAtomID(out_size=hidden_channels[0],
                                      in_size=n_atom_types)
         else:
             # None lets GraphLinear infer the input size on first call.
             self.embed = GraphLinear(None, hidden_channels[0])
         # One update layer per consecutive pair of channel sizes.
         self.rgcn_convs = chainer.ChainList(*[
             RelGCNUpdate(hidden_channels[i], hidden_channels[i + 1],
                          n_edge_types)
             for i in range(len(hidden_channels) - 1)
         ])
         self.rgcn_readout = GGNNReadout(out_dim=out_dim,
                                         in_channels=hidden_channels[-1],
                                         nobias=True,
                                         activation=functions.tanh)
     self.input_type = input_type
     self.scale_adj = scale_adj
def update():
    """Create a single ``RelGCNUpdate`` layer.

    The layer maps ``in_channels`` -> ``hidden_dim`` channels over
    ``num_edge_type`` edge types.

    NOTE(review): ``in_channels``, ``hidden_dim`` and ``num_edge_type`` are
    neither parameters nor locals, so they are resolved as module globals
    at call time. They are not defined in the visible code. Calling this
    raises ``NameError`` unless they exist at module scope (the function
    may have been lifted out of an enclosing closure) — confirm.
    """
    # Resolve the layer class first, then its arguments, matching the
    # original's lookup order.
    make_layer = RelGCNUpdate
    layer_kwargs = {
        'in_channels': in_channels,
        'out_channels': hidden_dim,
        'num_edge_type': num_edge_type,
    }
    return make_layer(**layer_kwargs)