def __init__(self, out_dim, node_embedding=False, hidden_channels=16,
             out_channels=None, n_update_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, dropout_ratio=0.5,
             concat_hidden=False, weight_tying=False,
             activation=functions.identity, n_edge_types=4):
    """Set up the atom embedding, GIN update layers and readout layers.

    Args:
        out_dim: output size handed to every ``GGNNReadout``.
        node_embedding: stored on the instance; not used while building
            the layers here.
        hidden_channels: width of the atom embedding and the
            ``hidden_channels`` of every ``GINUpdate``.
        out_channels: ``out_channels`` of the last update layer; falls
            back to ``hidden_channels`` when ``None``.
        n_update_layers: number of update layers built (only one is
            built when ``weight_tying``) and number of readouts built when
            ``concat_hidden``.
        n_atom_types: ``in_size`` of the atom-ID embedding.
        dropout_ratio: dropout ratio passed to every ``GINUpdate``.
        concat_hidden: when true, build ``n_update_layers`` readouts
            instead of one.
        weight_tying: when true, build a single update layer.
        activation: passed as both ``activation`` and ``activation_agg``
            to each readout.
        n_edge_types: stored on the instance; not used while building the
            layers here.
    """
    super(GIN, self).__init__()

    # With tied weights only one update layer is constructed.
    if weight_tying:
        n_message_layer = 1
    else:
        n_message_layer = n_update_layers
    # Several readouts are only needed when hidden states are concatenated.
    if concat_hidden:
        n_readout_layer = n_update_layers
    else:
        n_readout_layer = 1

    with self.init_scope():
        # Atom-ID embedding.
        self.embed = EmbedAtomID(out_size=hidden_channels,
                                 in_size=n_atom_types)
        # Only the ``graph_mlp`` attribute of this temporary GINUpdate is
        # kept on the model.
        self.first_mlp = GINUpdate(
            hidden_channels=hidden_channels,
            dropout_ratio=dropout_ratio,
            out_channels=hidden_channels,
        ).graph_mlp

        # Update layers: only the final one uses ``out_channels``.
        if out_channels is None:
            out_channels = hidden_channels
        last_index = n_message_layer - 1
        updaters = []
        for index in range(n_message_layer):
            if index == last_index:
                width = out_channels
            else:
                width = hidden_channels
            updaters.append(GINUpdate(hidden_channels=hidden_channels,
                                      dropout_ratio=dropout_ratio,
                                      out_channels=width))
        self.update_layers = chainer.ChainList(*updaters)

        # Readout layers.
        # NOTE(review): in_channels is hidden_channels * 2 regardless of
        # out_channels, and with weight_tying the single layer above maps
        # to out_channels — confirm both are consistent with forward() when
        # out_channels != hidden_channels.
        readouts = [GGNNReadout(out_dim=out_dim,
                                in_channels=hidden_channels * 2,
                                activation=activation,
                                activation_agg=activation)
                    for _ in range(n_readout_layer)]
        self.readout_layers = chainer.ChainList(*readouts)

    # Hyper-parameters remembered on the instance.
    self.node_embedding = node_embedding
    self.out_dim = out_dim
    self.hidden_channels = hidden_channels
    self.n_update_layers = n_update_layers
    self.n_message_layers = n_message_layer
    self.n_readout_layer = n_readout_layer
    self.dropout_ratio = dropout_ratio
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
    self.n_edge_types = n_edge_types
def __init__(self, out_dim, hidden_dim=16, n_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, dropout_ratio=0.5,
             concat_hidden=False, weight_tying=True,
             activation=functions.identity):
    """Set up the atom embedding, GIN update layers and readout layers.

    Args:
        out_dim: output size handed to every ``GGNNReadout``.
        hidden_dim: embedding width and ``hidden_dim`` of every update
            and readout layer.
        n_layers: number of update layers built (only one is built when
            ``weight_tying``) and number of readouts built when
            ``concat_hidden``.
        n_atom_types: ``in_size`` of the atom-ID embedding.
        dropout_ratio: dropout ratio passed to every ``GINUpdate``.
        concat_hidden: when true, build ``n_layers`` readouts instead of
            one.
        weight_tying: when true (the default), build a single update
            layer.
        activation: passed as both ``activation`` and ``activation_agg``
            to each readout.
    """
    super(GIN, self).__init__()

    n_message_layer = n_layers
    if weight_tying:
        # A single shared update layer.
        n_message_layer = 1
    n_readout_layer = 1
    if concat_hidden:
        # One readout per layer's hidden state.
        n_readout_layer = n_layers

    with self.init_scope():
        # Atom-ID embedding.
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

        # Update layers (the "two non-linear MLP" part).
        updaters = []
        for _ in range(n_message_layer):
            updaters.append(GINUpdate(hidden_dim=hidden_dim,
                                      dropout_ratio=dropout_ratio))
        self.update_layers = chainer.ChainList(*updaters)

        # Readout layers.
        readouts = []
        for _ in range(n_readout_layer):
            readouts.append(GGNNReadout(out_dim=out_dim,
                                        hidden_dim=hidden_dim,
                                        activation=activation,
                                        activation_agg=activation))
        self.readout_layers = chainer.ChainList(*readouts)

    # Hyper-parameters remembered on the instance.
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_message_layers = n_message_layer
    self.n_readout_layer = n_readout_layer
    self.dropout_ratio = dropout_ratio
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def __init__(self, out_dim, hidden_dim=16, hidden_dim_super=16,
             n_layers=4, n_heads=8, n_atom_types=MAX_ATOMIC_NUM,
             n_super_feature=2 + 2 + MAX_ATOMIC_NUM * 2,
             dropout_ratio=0.5, concat_hidden=False, weight_tying=True,
             activation=functions.identity):
    """Set up GIN layers together with a GWM super-node module.

    Args:
        out_dim: output size of every ``GINReadout`` and of
            ``linear_for_concat_super``.
        hidden_dim: embedding width and ``hidden_dim`` of the update,
            GWM and readout layers.
        hidden_dim_super: output size of ``embed_super`` and the
            ``hidden_dim_super`` of the GWM.
        n_layers: number of update layers built (only one is built when
            ``weight_tying``) and number of readouts built when
            ``concat_hidden``.
        n_heads: number of heads passed to the GWM.
        n_atom_types: ``in_size`` of the atom-ID embedding.
        n_super_feature: ``in_size`` of the ``embed_super`` linear layer.
        dropout_ratio: dropout ratio passed to the update layers and GWM.
        concat_hidden: when true, build ``n_layers`` readouts instead of
            one.
        weight_tying: build a single update layer when true; also passed
            to the GWM as ``tying_flag``.
        activation: passed as both ``activation`` and ``activation_agg``
            to each readout.
    """
    super(GIN_GWM, self).__init__()

    n_message_layer = n_layers
    if weight_tying:
        n_message_layer = 1
    n_readout_layer = 1
    if concat_hidden:
        n_readout_layer = n_layers

    with self.init_scope():
        # Atom-ID embedding.
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)

        # Update layers (the "two non-linear MLP" part).
        updaters = []
        for _ in range(n_message_layer):
            updaters.append(GINUpdate(hidden_dim=hidden_dim,
                                      dropout_ratio=dropout_ratio))
        self.update_layers = chainer.ChainList(*updaters)

        # GWM: super-node feature embedding followed by the GWM block.
        self.embed_super = links.Linear(in_size=n_super_feature,
                                        out_size=hidden_dim_super)
        # NOTE(review): gpu is hard-coded to -1 here — confirm this is
        # intended for GPU runs.
        self.gwm = GWM(hidden_dim=hidden_dim,
                       hidden_dim_super=hidden_dim_super,
                       n_layers=n_message_layer,
                       n_heads=n_heads,
                       dropout_ratio=dropout_ratio,
                       tying_flag=weight_tying,
                       gpu=-1)

        # Readout layers.
        readouts = []
        for _ in range(n_readout_layer):
            readouts.append(GINReadout(out_dim=out_dim,
                                       hidden_dim=hidden_dim,
                                       activation=activation,
                                       activation_agg=activation))
        self.readout_layers = chainer.ChainList(*readouts)
        # Input size left as None so it is inferred on first call.
        self.linear_for_concat_super = links.Linear(in_size=None,
                                                    out_size=out_dim)

    # Hyper-parameters remembered on the instance.
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.hidden_dim_super = hidden_dim_super
    self.n_message_layers = n_message_layer
    self.n_readout_layer = n_readout_layer
    self.dropout_ratio = dropout_ratio
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def update():
    # type: () -> GINUpdate
    """Build a ``GINUpdate`` from the module-level channel sizes, without dropout."""
    layer = GINUpdate(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        dropout_ratio=0,
    )
    return layer