def __init__(self, out_dim, hidden_dim=16, n_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
             weight_tying=True):
    """Build the embedding, message, GRU update and readout links.

    Args:
        out_dim (int): Output width of each readout ``GraphLinear``.
        hidden_dim (int): Width of the atom embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): If True, one ``i_layers``/``j_layers`` readout
            pair is created per step (one per hidden state to be
            concatenated); otherwise a single readout pair is created.
        weight_tying (bool): If True, one message layer is shared by all
            steps; otherwise ``n_layers`` message layers are created.
    """
    super(GGNN, self).__init__()
    # Fixed: this condition was inverted (``1 if concat_hidden else
    # n_layers``), which allocated a single readout exactly when a
    # per-step readout is needed and ``n_layers`` readouts otherwise.
    n_readout_layer = n_layers if concat_hidden else 1
    n_message_layer = 1 if weight_tying else n_layers
    with self.init_scope():
        # Update
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        # Each message layer emits NUM_EDGE_TYPE * hidden_dim channels,
        # i.e. one hidden_dim-wide slice per edge type.
        self.message_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
            for _ in range(n_message_layer)
        ])
        self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
        # Readout
        self.i_layers = chainer.ChainList(*[
            GraphLinear(2 * hidden_dim, out_dim)
            for _ in range(n_readout_layer)
        ])
        self.j_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, out_dim)
            for _ in range(n_readout_layer)
        ])
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def __init__(self, out_dim, hidden_channels=16, n_update_layers=4,
             max_degree=6, n_atom_types=MAX_ATOMIC_NUM,
             concat_hidden=False):
    """Create the NFP atom embedding plus one update and one readout per step.

    Args:
        out_dim (int): Output width of every ``NFPReadout``.
        hidden_channels (int): Width of the embedding and of each update.
        n_update_layers (int): Number of ``NFPUpdate`` layers; the same
            number of ``NFPReadout`` layers is created.
        max_degree (int): Largest atom degree handled by ``NFPUpdate``.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): Stored on the instance.
    """
    super(NFP, self).__init__()
    with self.init_scope():
        self.embed = EmbedAtomID(in_size=n_atom_types,
                                 out_size=hidden_channels)
        update_links = [
            NFPUpdate(hidden_channels, hidden_channels, max_degree=max_degree)
            for _ in range(n_update_layers)
        ]
        self.layers = chainer.ChainList(*update_links)
        readout_links = [
            NFPReadout(out_dim=out_dim, in_channels=hidden_channels)
            for _ in range(n_update_layers)
        ]
        self.readout_layers = chainer.ChainList(*readout_links)
    self.out_dim = out_dim
    self.hidden_channels = hidden_channels
    self.max_degree = max_degree
    # Degrees 0..max_degree inclusive.
    self.n_degree_types = max_degree + 1
    self.n_update_layers = n_update_layers
    self.concat_hidden = concat_hidden
def __init__(self, out_dim, hidden_dim=16, n_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
             weight_tying=True, activation=functions.identity,
             num_edge_type=4):
    """Assemble a GGNN from ``GGNNUpdate`` and ``GGNNReadout`` links.

    Args:
        out_dim (int): Output width of each readout.
        hidden_dim (int): Width of the embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): If True, build one readout per step.
        weight_tying (bool): If True, build a single shared update link.
        activation (callable): Passed to each readout as both
            ``activation`` and ``activation_agg``.
        num_edge_type (int): Number of edge types for ``GGNNUpdate``.
    """
    super(GGNN, self).__init__()
    readout_count = n_layers if concat_hidden else 1
    update_count = n_layers if not weight_tying else 1
    with self.init_scope():
        # Update
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        self.update_layers = chainer.ChainList(*(
            GGNNUpdate(hidden_dim=hidden_dim, num_edge_type=num_edge_type)
            for _ in range(update_count)
        ))
        # Readout
        self.readout_layers = chainer.ChainList(*(
            GGNNReadout(out_dim=out_dim, hidden_dim=hidden_dim,
                        activation=activation, activation_agg=activation)
            for _ in range(readout_count)
        ))
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.num_edge_type = num_edge_type
    self.activation = activation
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def __init__(
        self,
        out_dim,  # type: int
        hidden_channels=16,  # type: int
        n_update_layers=4,  # type: int
        n_atom_types=MAX_ATOMIC_NUM,  # type: int
        concat_hidden=False,  # type: bool
        weight_tying=True,  # type: bool
        n_edge_types=4,  # type: int
        nn=None,  # type: Optional[chainer.Link]
        message_func='edgenet',  # type: str
        readout_func='set2set',  # type: str
):
    # type: (...) -> None
    """Build an MPNN with a selectable message and readout function.

    Args:
        out_dim: Output width of each readout link.
        hidden_channels: Width of the embedding / hidden state.
        n_update_layers: Number of message-passing steps.
        n_atom_types: Vocabulary size of the atom-ID embedding.
        concat_hidden: If True, build one readout per step.
        weight_tying: If True, build a single shared update link.
        n_edge_types: Edge-type count for the ``'ggnn'`` message function.
        nn: Optional link passed to every ``MPNNUpdate``.
        message_func: ``'edgenet'`` (MPNNUpdate) or ``'ggnn'`` (GGNNUpdate).
        readout_func: ``'set2set'`` (MPNNReadout) or ``'ggnn'``
            (GGNNReadout).

    Raises:
        ValueError: If ``message_func`` or ``readout_func`` is not one of
            the supported names.
    """
    super(MPNN, self).__init__()
    if message_func not in ('edgenet', 'ggnn'):
        raise ValueError(
            'Invalid message function: {}'.format(message_func))
    if readout_func not in ('set2set', 'ggnn'):
        raise ValueError(
            'Invalid readout function: {}'.format(readout_func))
    readout_count = n_update_layers if concat_hidden else 1
    update_count = n_update_layers if not weight_tying else 1
    use_ggnn_message = message_func == 'ggnn'
    use_ggnn_readout = readout_func == 'ggnn'
    with self.init_scope():
        # Update
        self.embed = EmbedAtomID(out_size=hidden_channels,
                                 in_size=n_atom_types)
        self.update_layers = chainer.ChainList(*[
            GGNNUpdate(hidden_channels=hidden_channels,
                       n_edge_types=n_edge_types)
            if use_ggnn_message else
            MPNNUpdate(hidden_channels=hidden_channels, nn=nn)
            for _ in range(update_count)
        ])
        # Readout: the GGNN readout takes twice the hidden width.
        self.readout_layers = chainer.ChainList(*[
            GGNNReadout(out_dim=out_dim, in_channels=hidden_channels * 2)
            if use_ggnn_readout else
            MPNNReadout(out_dim=out_dim, in_channels=hidden_channels,
                        n_layers=1)
            for _ in range(readout_count)
        ])
    self.out_dim = out_dim
    self.hidden_channels = hidden_channels
    self.n_update_layers = n_update_layers
    self.n_edge_types = n_edge_types
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
    self.message_func = message_func
    self.readout_func = readout_func
def __init__(self, out_dim, hidden_dim=16, n_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, dropout_ratio=0.5,
             concat_hidden=False, weight_tying=True,
             activation=functions.identity):
    """Create a GIN: atom embedding, ``GINUpdate`` MLPs and GGNN readouts.

    Args:
        out_dim (int): Output width of each readout.
        hidden_dim (int): Width of the embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        dropout_ratio (float): Dropout ratio passed to each ``GINUpdate``.
        concat_hidden (bool): If True, build one readout per step.
        weight_tying (bool): If True, build a single shared update link.
        activation (callable): Readout ``activation`` and
            ``activation_agg``.
    """
    super(GIN, self).__init__()
    message_count = n_layers if not weight_tying else 1
    readout_count = n_layers if concat_hidden else 1
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        # Two-layer non-linear MLP per update step.
        mlps = [GINUpdate(hidden_dim=hidden_dim, dropout_ratio=dropout_ratio)
                for _ in range(message_count)]
        self.update_layers = chainer.ChainList(*mlps)
        readouts = [GGNNReadout(out_dim=out_dim, hidden_dim=hidden_dim,
                                activation=activation,
                                activation_agg=activation)
                    for _ in range(readout_count)]
        self.readout_layers = chainer.ChainList(*readouts)
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_message_layers = message_count
    self.n_readout_layer = readout_count
    self.dropout_ratio = dropout_ratio
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def __init__(self, out_dim, hidden_channels=16, n_update_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
             dropout_ratio=-1., weight_tying=False,
             activation=functions.identity, n_edge_types=4, n_heads=3,
             negative_slope=0.2, softmax_mode='across', concat_heads=False):
    """Build a relational graph attention network (RelGAT).

    Args:
        out_dim (int): Output width of each readout.
        hidden_channels (int): Per-head output width of each update.
        n_update_layers (int): Number of ``RelGATUpdate`` layers built.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): If True, build one readout per step.
        dropout_ratio (float): Passed to each ``RelGATUpdate``.
        weight_tying (bool): Stored on the instance.
            NOTE(review): it does not change how many update layers are
            built here (always ``n_update_layers``); confirm how the
            forward pass uses it.
        activation (callable): Readout ``activation``/``activation_agg``.
        n_edge_types (int): Number of relation types.
        n_heads (int): Number of attention heads.
        negative_slope (float): Passed to each ``RelGATUpdate``.
        softmax_mode (str): Passed to each ``RelGATUpdate``.
        concat_heads (bool): If True, layers after the first take
            ``hidden_channels * n_heads`` inputs and the readout input
            width becomes ``hidden_channels * (n_heads + 1)``.
    """
    super(RelGAT, self).__init__()
    readout_count = n_update_layers if concat_hidden else 1
    concatenated_width = hidden_channels * n_heads
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_channels,
                                 in_size=n_atom_types)
        # Only layers after the first see concatenated multi-head input.
        self.update_layers = chainer.ChainList(*[
            RelGATUpdate(
                concatenated_width if (concat_heads and step > 0)
                else hidden_channels,
                hidden_channels,
                n_heads=n_heads,
                n_edge_types=n_edge_types,
                dropout_ratio=dropout_ratio,
                negative_slope=negative_slope,
                softmax_mode=softmax_mode,
                concat_heads=concat_heads)
            for step in range(n_update_layers)
        ])
        readout_in = (hidden_channels * (n_heads + 1) if concat_heads
                      else hidden_channels * 2)
        self.readout_layers = chainer.ChainList(*[
            GGNNReadout(out_dim=out_dim, in_channels=readout_in,
                        activation=activation, activation_agg=activation)
            for _ in range(readout_count)
        ])
    self.out_dim = out_dim
    self.n_heads = n_heads
    self.hidden_channels = hidden_channels
    self.n_update_layers = n_update_layers
    self.concat_hidden = concat_hidden
    self.concat_heads = concat_heads
    self.weight_tying = weight_tying
    self.negative_slope = negative_slope
    self.n_edge_types = n_edge_types
    self.dropout_ratio = dropout_ratio
def __init__(self, out_dim, hidden_dim=16, hidden_dim_super=16, n_layers=4,
             n_heads=8, n_atom_types=MAX_ATOMIC_NUM,
             n_super_feature=2 + 2 + MAX_ATOMIC_NUM * 2, dropout_ratio=0.5,
             concat_hidden=False, weight_tying=True,
             activation=functions.identity):
    """Create a GIN augmented with a graph warp module (GWM) super node.

    Args:
        out_dim (int): Output width of the readouts and of
            ``linear_for_concat_super``.
        hidden_dim (int): Width of the atom embedding / hidden state.
        hidden_dim_super (int): Width of the super-node embedding.
        n_layers (int): Number of message-passing steps.
        n_heads (int): Attention heads used by the GWM.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        n_super_feature (int): Input width of the super-node features.
        dropout_ratio (float): Dropout for GINUpdate and GWM.
        concat_hidden (bool): If True, build one readout per step.
        weight_tying (bool): If True, build a single shared update link;
            also passed to GWM as ``tying_flag``.
        activation (callable): Readout ``activation``/``activation_agg``.
    """
    super(GIN_GWM, self).__init__()
    message_count = n_layers if not weight_tying else 1
    readout_count = n_layers if concat_hidden else 1
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        # Two-layer non-linear MLP per update step.
        self.update_layers = chainer.ChainList(*[
            GINUpdate(hidden_dim=hidden_dim, dropout_ratio=dropout_ratio)
            for _ in range(message_count)
        ])
        # Graph warp module and its super-node input projection.
        self.embed_super = links.Linear(in_size=n_super_feature,
                                        out_size=hidden_dim_super)
        self.gwm = GWM(hidden_dim=hidden_dim,
                       hidden_dim_super=hidden_dim_super,
                       n_layers=message_count, n_heads=n_heads,
                       dropout_ratio=dropout_ratio,
                       tying_flag=weight_tying, gpu=-1)
        self.readout_layers = chainer.ChainList(*[
            GINReadout(out_dim=out_dim, hidden_dim=hidden_dim,
                       activation=activation, activation_agg=activation)
            for _ in range(readout_count)
        ])
        # Input width is inferred lazily (in_size=None).
        self.linear_for_concat_super = links.Linear(in_size=None,
                                                    out_size=out_dim)
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.hidden_dim_super = hidden_dim_super
    self.n_message_layers = message_count
    self.n_readout_layer = readout_count
    self.dropout_ratio = dropout_ratio
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def __init__(
        self,
        out_dim,
        hidden_dim=16,
        n_layers=4,
        n_atom_types=MAX_ATOMIC_NUM,
        concat_hidden=False,
        dropout_rate=0.0,
        layer_aggr=None,
        batch_normalization=False,
        weight_tying=True,
        update_tying=True,
):
    """Build a GGNN variant with Linear updates and layer aggregation.

    Args:
        out_dim (int): Output width of the readout ``GraphLinear`` links.
        hidden_dim (int): Width of the embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): If True, build one readout pair per step.
        dropout_rate (float): Stored on the instance.
        layer_aggr: Aggregator spec passed to ``select_aggr``.
        batch_normalization (bool): Stored on the instance.
        weight_tying (bool): If True, build a single shared message layer.
        update_tying (bool): If True, build a single shared update layer.
    """
    super(GGNN, self).__init__()
    self.n_readout_layer = n_layers if concat_hidden else 1
    self.n_message_layer = n_layers if not weight_tying else 1
    update_count = n_layers if not update_tying else 1
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.concat_hidden = concat_hidden
    self.dropout_rate = dropout_rate
    self.batch_normalization = batch_normalization
    self.weight_tying = weight_tying
    self.update_tying = update_tying
    self.layer_aggr = layer_aggr
    with self.init_scope():
        # Update
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        self.message_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
            for _ in range(self.n_message_layer)
        ])
        # The update is a plain Linear over [h, message] rather than a GRU.
        self.update_layer = chainer.ChainList(*[
            links.Linear(2 * hidden_dim, hidden_dim)
            for _ in range(update_count)
        ])
        # Layer aggregation
        self.aggr = select_aggr(layer_aggr, 1, hidden_dim, hidden_dim)
        # Readout
        self.i_layers = chainer.ChainList(*[
            GraphLinear(2 * hidden_dim, out_dim)
            for _ in range(self.n_readout_layer)
        ])
        self.j_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, out_dim)
            for _ in range(self.n_readout_layer)
        ])
def __init__(self, word_size, num_atom_type=MAX_ATOMIC_NUM, id_trans_fn=None):
    """Wrap an ``EmbedAtomID`` of width ``word_size``.

    Args:
        word_size (int): Embedding width.
        num_atom_type (int): Vocabulary size of the embedding.
        id_trans_fn: Stored on the instance as ``id_trans_fn``.
    """
    super(AtomEmbed, self).__init__()
    with self.init_scope():
        self.embed = EmbedAtomID(in_size=num_atom_type, out_size=word_size)
    self.word_size, self.num_atom_type, self.id_trans_fn = (
        word_size, num_atom_type, id_trans_fn)
def __init__(self, out_dim, node_embedding=False, hidden_channels=16,
             out_channels=None, n_update_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, dropout_ratio=0.5,
             concat_hidden=False, weight_tying=False,
             activation=functions.identity, n_edge_types=4):
    """Create a GIN whose last update may have a different output width.

    Args:
        out_dim (int): Output width of each readout.
        node_embedding (bool): Stored on the instance.
        hidden_channels (int): Width of the embedding / hidden state.
        out_channels (int or None): Output width of the last update layer;
            defaults to ``hidden_channels``.
        n_update_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        dropout_ratio (float): Passed to every ``GINUpdate``.
        concat_hidden (bool): If True, build one readout per step.
        weight_tying (bool): If True, build a single update link.
        activation (callable): Readout ``activation``/``activation_agg``.
        n_edge_types (int): Stored on the instance.
    """
    super(GIN, self).__init__()
    message_count = n_update_layers if not weight_tying else 1
    readout_count = n_update_layers if concat_hidden else 1
    if out_channels is None:
        out_channels = hidden_channels
    # Every update layer outputs hidden_channels except the last one.
    # NOTE(review): the readout input width below is fixed at
    # hidden_channels * 2 even if out_channels differs -- confirm.
    layer_widths = [hidden_channels] * message_count
    if layer_widths:
        layer_widths[-1] = out_channels
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_channels,
                                 in_size=n_atom_types)
        # Only the MLP of this GINUpdate is kept.
        self.first_mlp = GINUpdate(hidden_channels=hidden_channels,
                                   dropout_ratio=dropout_ratio,
                                   out_channels=hidden_channels).graph_mlp
        # Two-layer non-linear MLP per update step.
        self.update_layers = chainer.ChainList(*[
            GINUpdate(hidden_channels=hidden_channels,
                      dropout_ratio=dropout_ratio,
                      out_channels=width)
            for width in layer_widths
        ])
        self.readout_layers = chainer.ChainList(*[
            GGNNReadout(out_dim=out_dim, in_channels=hidden_channels * 2,
                        activation=activation, activation_agg=activation)
            for _ in range(readout_count)
        ])
    self.node_embedding = node_embedding
    self.out_dim = out_dim
    self.hidden_channels = hidden_channels
    self.n_update_layers = n_update_layers
    self.n_message_layers = message_count
    self.n_readout_layer = readout_count
    self.dropout_ratio = dropout_ratio
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
    self.n_edge_types = n_edge_types
def __init__(
        self,
        out_dim,
        hidden_dim=16,
        n_layers=4,
        n_atom_types=MAX_ATOMIC_NUM,
        concat_hidden=False,
        dropout_rate=0.0,
        batch_normalization=False,
        weight_tying=True,
        output_atoms=True,
):
    """Build a GRU-based GGNN that can record per-atom outputs.

    Args:
        out_dim (int): Output width of the readout ``GraphLinear`` links.
        hidden_dim (int): Width of the embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): If True, build one readout pair per step.
        dropout_rate (float): Stored on the instance.
        batch_normalization (bool): Stored on the instance.
        weight_tying (bool): If True, build a single shared message layer.
        output_atoms (bool): If True, initialise empty ``atoms_list`` and
            ``g_vec_list`` attributes.
    """
    super(GGNN, self).__init__()
    self.n_readout_layer = n_layers if concat_hidden else 1
    self.n_message_layer = n_layers if not weight_tying else 1
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.concat_hidden = concat_hidden
    self.dropout_rate = dropout_rate
    self.batch_normalization = batch_normalization
    self.weight_tying = weight_tying
    self.output_atoms = output_atoms
    with self.init_scope():
        # Update
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        self.message_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
            for _ in range(self.n_message_layer)
        ])
        self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
        # Readout
        self.i_layers = chainer.ChainList(*[
            GraphLinear(2 * hidden_dim, out_dim)
            for _ in range(self.n_readout_layer)
        ])
        self.j_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, out_dim)
            for _ in range(self.n_readout_layer)
        ])
    if output_atoms:
        self.atoms_list = []
        self.g_vec_list = []
def __init__(self, out_dim=1, hidden_dim=64, n_layers=3,
             readout_hidden_dim=32, n_atom_types=MAX_ATOMIC_NUM,
             concat_hidden=False):
    """Create a SchNet: embedding, ``n_layers`` updates and one readout.

    Args:
        out_dim (int): Readout output width.
        hidden_dim (int): Width of the embedding and each update.
        n_layers (int): Number of ``SchNetUpdate`` layers.
        readout_hidden_dim (int): Hidden width of ``SchNetReadout``.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): Stored on the instance.
    """
    super(SchNet, self).__init__()
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        updates = [SchNetUpdate(hidden_dim) for _ in range(n_layers)]
        self.update_layers = chainer.ChainList(*updates)
        self.readout_layer = SchNetReadout(out_dim, readout_hidden_dim)
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.readout_hidden_dim = readout_hidden_dim
    self.n_layers = n_layers
    self.concat_hidden = concat_hidden
def data():
    """Return seeded (embedded atoms, adjacency, upstream grad) test inputs.

    Relies on module-level ``batch_size``, ``atom_size``, ``num_edge_type``
    and ``hidden_dim``.
    """
    rng = numpy.random
    rng.seed(0)
    atom_ids = rng.randint(0, high=MAX_ATOMIC_NUM,
                           size=(batch_size, atom_size)).astype('i')
    # NOTE(review): values are continuous in [0, 2), not a binary
    # adjacency -- confirm this is intended.
    adjacency_shape = (batch_size, num_edge_type, atom_size, atom_size)
    adjacency = rng.uniform(0, high=2, size=adjacency_shape).astype('f')
    grad_shape = (batch_size, atom_size, hidden_dim)
    upstream_grad = rng.uniform(-1, 1, grad_shape).astype('f')
    embedder = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_dim)
    embedded = embedder(atom_ids).data
    return embedded, adjacency, upstream_grad
def data():
    """Return seeded (embedded atoms, symmetric distances, grad) inputs.

    Relies on module-level ``batch_size``, ``atom_size`` and ``hidden_dim``.
    """
    rng = numpy.random
    rng.seed(0)
    atom_ids = rng.randint(0, high=MAX_ATOMIC_NUM,
                           size=(batch_size, atom_size)).astype('i')
    raw = rng.uniform(0, high=30,
                      size=(batch_size, atom_size, atom_size)).astype('f')
    # Average with the transpose of the last two axes to make it symmetric.
    distances = (raw + raw.swapaxes(-1, -2)) / 2.
    upstream_grad = rng.uniform(
        -1, 1, (batch_size, atom_size, hidden_dim)).astype('f')
    embedder = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_dim)
    embedded = embedder(atom_ids).data
    return embedded, distances, upstream_grad
def __init__(self, weave_channels=None, hidden_dim=16,
             n_atom=WEAVE_DEFAULT_NUM_MAX_ATOMS, n_sub_layer=1,
             n_atom_types=MAX_ATOMIC_NUM, readout_mode='sum'):
    """Build a WeaveNet with one ``WeaveModule`` per entry of weave_channels.

    Args:
        weave_channels (sequence or None): Channel sizes of the weave
            modules; an empty/None value selects
            ``WEAVENET_DEFAULT_WEAVE_CHANNELS``.
        hidden_dim (int): Width of the atom-ID embedding.
        n_atom (int): Atom count passed to each ``WeaveModule``.
        n_sub_layer (int): Sub-layer count passed to each ``WeaveModule``.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        readout_mode (str): Mode for the weave modules and GeneralReadout.
    """
    if not weave_channels:
        weave_channels = WEAVENET_DEFAULT_WEAVE_CHANNELS
    # Weave modules are constructed before the embedding (original order).
    modules = [WeaveModule(n_atom, channels, n_sub_layer,
                           readout_mode=readout_mode)
               for channels in weave_channels]
    super(WeaveNet, self).__init__()
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        self.weave_module = chainer.ChainList(*modules)
        self.readout = GeneralReadout(mode=readout_mode)
    self.readout_mode = readout_mode
def __init__(self, out_dim=1, hidden_channels=64, n_update_layers=3,
             readout_hidden_dim=32, n_atom_types=MAX_ATOMIC_NUM,
             concat_hidden=False, num_rbf=300, radius_resolution=0.1,
             gamma=10.0):
    """Create a SchNet whose updates use configurable RBF expansion.

    Args:
        out_dim (int): Readout output width.
        hidden_channels (int): Width of the embedding and each update.
        n_update_layers (int): Number of ``SchNetUpdate`` layers.
        readout_hidden_dim (int): Hidden width of ``SchNetReadout``.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): Stored on the instance.
        num_rbf (int): Passed to every ``SchNetUpdate``.
        radius_resolution (float): Passed to every ``SchNetUpdate``.
        gamma (float): Passed to every ``SchNetUpdate``.
    """
    super(SchNet, self).__init__()
    rbf_options = dict(num_rbf=num_rbf,
                       radius_resolution=radius_resolution,
                       gamma=gamma)
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_channels,
                                 in_size=n_atom_types)
        self.update_layers = chainer.ChainList(*[
            SchNetUpdate(hidden_channels, **rbf_options)
            for _ in range(n_update_layers)
        ])
        # in_channels=None lets the readout infer its input width.
        self.readout_layer = SchNetReadout(
            out_dim, in_channels=None, hidden_channels=readout_hidden_dim)
    self.out_dim = out_dim
    self.hidden_channels = hidden_channels
    self.readout_hidden_dim = readout_hidden_dim
    self.n_update_layers = n_update_layers
    self.concat_hidden = concat_hidden
def data():
    """Return seeded NFP test inputs.

    Returns a tuple ``(embedded atoms, adjacency, degree masks, upstream
    grad)``. ``deg_conds[d - 1]`` is a boolean mask, broadcast to the
    embedding shape, marking atoms whose adjacency column sum (axis 1)
    equals ``d`` for ``d`` in ``1..num_degree_type``. Relies on
    module-level ``batch_size``, ``atom_size``, ``hidden_channels`` and
    ``num_degree_type``.
    """
    rng = numpy.random
    rng.seed(0)
    atom_ids = rng.randint(0, high=MAX_ATOMIC_NUM,
                           size=(batch_size, atom_size)).astype('i')
    adjacency = rng.randint(
        0, high=2, size=(batch_size, atom_size, atom_size)).astype('f')
    upstream_grad = rng.uniform(
        -1, 1, (batch_size, atom_size, hidden_channels)).astype('f')
    embedder = EmbedAtomID(in_size=MAX_ATOMIC_NUM, out_size=hidden_channels)
    embedded = embedder(atom_ids).data
    degrees = numpy.sum(adjacency, axis=1)
    masks = []
    for degree in range(1, num_degree_type + 1):
        hit = ((degrees - degree) == 0)[:, :, None]
        masks.append(numpy.broadcast_to(hit, embedded.shape))
    return embedded, adjacency, numpy.array(masks), upstream_grad
def __init__(self, out_channels=64, num_edge_type=4, ch_list=None,
             n_atom_types=MAX_ATOMIC_NUM, input_type='int',
             scale_adj=False, activation=F.tanh):
    """Build a relational GCN from a list of channel widths.

    Args:
        out_channels (int): Output width of ``RelGCNReadout``.
        num_edge_type (int): Number of relation types per update.
        ch_list (list of int or None): Widths ``[embed, conv1, ...]``;
            an empty/None value selects ``[16, 128, 64]``. One
            ``RelGCNUpdate`` is built per consecutive pair.
        n_atom_types (int): Vocabulary size for ``input_type == 'int'``.
        input_type (str): ``'int'`` uses ``EmbedAtomID``; ``'float'`` uses
            a lazily-sized ``GraphLinear``.
        scale_adj (bool): Stored on the instance.
        activation (callable): Stored on the instance.

    Raises:
        ValueError: If ``input_type`` is neither ``'int'`` nor ``'float'``.
    """
    super(RelGCN, self).__init__()
    if not ch_list:
        ch_list = [16, 128, 64]
    first_width = ch_list[0]
    with self.init_scope():
        if input_type == 'float':
            self.embed = GraphLinear(None, first_width)
        elif input_type == 'int':
            self.embed = EmbedAtomID(out_size=first_width,
                                     in_size=n_atom_types)
        else:
            raise ValueError(
                "[ERROR] Unexpected value input_type={}".format(input_type))
        self.rgcn_convs = chainer.ChainList(*[
            RelGCNUpdate(in_ch, out_ch, num_edge_type)
            for in_ch, out_ch in zip(ch_list[:-1], ch_list[1:])
        ])
        self.rgcn_readout = RelGCNReadout(ch_list[-1], out_channels)
    self.input_type = input_type
    self.scale_adj = scale_adj
    self.activation = activation
def __init__(self, out_dim, hidden_dim=16, n_layers=4,
             n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
             weight_tying=True):
    """Build a GGNN from a single multi-step update and readout link.

    Args:
        out_dim (int): Readout output width.
        hidden_dim (int): Width of the embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): Passed to the readout; also selects
            ``n_layers`` readout layers instead of one.
        weight_tying (bool): Passed to the update; also selects one
            message layer instead of ``n_layers``.
    """
    super(GGNN, self).__init__()
    readout_count = n_layers if concat_hidden else 1
    message_count = n_layers if not weight_tying else 1
    with self.init_scope():
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        # NOTE(review): NUM_EDGE_TYPE is passed as ``n_atom_types`` --
        # confirm this matches GGNNUpdate's signature.
        self.update_layer = GGNNUpdate(
            hidden_dim=hidden_dim, n_layers=message_count,
            n_atom_types=self.NUM_EDGE_TYPE, weight_tying=weight_tying)
        self.readout_layer = GGNNReadout(
            out_dim=out_dim, hidden_dim=hidden_dim, n_layers=readout_count,
            concat_hidden=concat_hidden, activation=functions.identity)
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.concat_hidden = concat_hidden
    self.weight_tying = weight_tying
def __init__(self, out_dim=64, hidden_channels=None, n_update_layers=None,
             n_atom_types=MAX_ATOMIC_NUM, n_edge_types=4, input_type='int',
             scale_adj=False):
    """Build a sparse relational GCN.

    Args:
        out_dim (int): Output width of ``ScatterGGNNReadout``.
        hidden_channels (None, int, or list of int): Widths
            ``[embed, conv1, ...]``. None selects ``[16, 128, 64]``; an int
            is repeated ``n_update_layers`` times. One
            ``RelGCNSparseUpdate`` is built per consecutive pair, so an int
            yields ``n_update_layers - 1`` update links.
            NOTE(review): confirm that count is intended.
        n_update_layers (int or None): Required when hidden_channels is an
            int.
        n_atom_types (int): Vocabulary size for ``input_type == 'int'``.
        n_edge_types (int): Number of relation types per update.
        input_type (str): ``'int'`` uses ``EmbedAtomID``; ``'float'`` uses
            a lazily-sized ``Linear``.
        scale_adj (bool): Stored on the instance.

    Raises:
        ValueError: On an int hidden_channels without an int
            n_update_layers, or on an unknown input_type.
    """
    super(RelGCNSparse, self).__init__()
    if hidden_channels is None:
        widths = [16, 128, 64]
    elif isinstance(hidden_channels, int):
        if not isinstance(n_update_layers, int):
            raise ValueError(
                'Must specify n_update_layers when hidden_channels is int')
        widths = [hidden_channels] * n_update_layers
    else:
        widths = hidden_channels
    with self.init_scope():
        if input_type == 'float':
            self.embed = Linear(None, widths[0])
        elif input_type == 'int':
            self.embed = EmbedAtomID(out_size=widths[0],
                                     in_size=n_atom_types)
        else:
            raise ValueError(
                "[ERROR] Unexpected value input_type={}".format(
                    input_type))
        self.rgcn_convs = chainer.ChainList(*[
            RelGCNSparseUpdate(in_ch, out_ch, n_edge_types)
            for in_ch, out_ch in zip(widths[:-1], widths[1:])
        ])
        self.rgcn_readout = ScatterGGNNReadout(
            out_dim=out_dim, in_channels=widths[-1], nobias=True,
            activation=functions.tanh)
    self.input_type = input_type
    self.scale_adj = scale_adj
# Scratch script: inspect tensor shapes for a multi-head attention setup.
atom_size = 5
out_dim = 4
batch_size = 3
heads = 2
hidden_dim = 16

# Random inputs (unseeded); the construction order below is preserved.
atom_data = numpy.random.randint(
    0, high=110, size=(batch_size, atom_size)).astype(numpy.int32)
adj_data = numpy.random.randint(
    0, high=2, size=(batch_size, atom_size, atom_size)).astype(numpy.float32)
embed = EmbedAtomID(out_size=hidden_dim, in_size=110)
weight = GraphLinear(hidden_dim, heads * hidden_dim)
att_weight = GraphLinear(hidden_dim * 2, 1)


def test(atom_array, adj_data):
    """Print shapes while projecting embeddings to heads and tiling pairwise.

    ``adj_data`` is accepted but not used. Returns None.
    """
    embedded = embed(atom_array)
    n_mol, n_atom, n_ch = embedded.shape
    print(embedded.shape)
    projected = weight(embedded)
    print(projected.shape)
    expanded = functions.expand_dims(projected, axis=1)
    print(expanded.shape)
    tiled = functions.broadcast_to(
        expanded, (n_mol, n_atom, n_atom, heads * n_ch))
    print(tiled.shape)
    # NOTE(review): the copied result is discarded.
    functions.copy(tiled, -1)
def __init__(
        self,
        out_dim,
        hidden_dim=16,
        n_layers=4,
        n_atom_types=MAX_ATOMIC_NUM,
        concat_hidden=False,
        layer_aggregator=None,
        dropout_rate=0.0,
        batch_normalization=False,
        weight_tying=True,
        use_attention=False,
        update_attention=False,
        attention_tying=True,
        context=False,
        context_layers=1,
        context_dropout=0.,
        message_function='matrix_multiply',
        edge_hidden_dim=16,
        readout_function='graph_level',
        num_timesteps=3,
        num_output_hidden_layers=0,
        output_hidden_dim=16,
        output_activation=functions.relu,
        output_atoms=False,
):
    """Build a configurable GGNN (messages, attention, context, aggregation).

    Args:
        out_dim (int): Readout output width.
        hidden_dim (int): Width of the embedding / hidden state.
        n_layers (int): Number of message-passing steps.
        n_atom_types (int): Vocabulary size of the atom-ID embedding.
        concat_hidden (bool): If True, build one readout per step.
        layer_aggregator (str or None): One of ``'gru'``, ``'gru-attn'``,
            ``'lstm'``, ``'lstm-attn'``, ``'attn'``, ``'self-attn'``, or
            falsy to disable layer aggregation.
        dropout_rate (float): Stored on the instance.
        batch_normalization (bool): If True, add a BatchNormalization link.
        weight_tying (bool): If True, build a single shared message layer.
        use_attention, update_attention (bool): If either is True, build
            the attention Linear links.
        attention_tying (bool): If True, build one attention link set.
        context (bool): If True, add a bidirectional context LSTM.
        context_layers (int): Layer count of the context LSTM.
        context_dropout (float): Dropout of the context LSTM.
        message_function (str): ``'edge_network'`` replaces the
            GraphLinear message layers with EdgeNetwork links.
        edge_hidden_dim (int): Hidden width of each EdgeNetwork.
        readout_function (str): ``'set2vec'`` replaces the i/j readout
            links with Set2Vec links.
        num_timesteps, num_output_hidden_layers, output_hidden_dim,
            output_activation: Passed to Set2Vec.
        output_atoms (bool): If True, initialise ``self.atoms`` to None.
    """
    super(GGNN, self).__init__()
    n_readout_layer = n_layers if concat_hidden else 1
    n_message_layer = 1 if weight_tying else n_layers
    n_attention_layer = 1 if attention_tying else n_layers
    self.n_readout_layer = n_readout_layer
    self.n_message_layer = n_message_layer
    self.n_attention_layer = n_attention_layer
    self.out_dim = out_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.concat_hidden = concat_hidden
    self.layer_aggregator = layer_aggregator
    self.dropout_rate = dropout_rate
    self.batch_normalization = batch_normalization
    self.weight_tying = weight_tying
    self.use_attention = use_attention
    self.update_attention = update_attention
    self.attention_tying = attention_tying
    self.context = context
    self.context_layers = context_layers
    self.context_dropout = context_dropout
    # NOTE: the attribute name is misspelled; it is kept as-is because other
    # methods of this class may read it.
    self.message_functinon = message_function
    self.edge_hidden_dim = edge_hidden_dim
    self.readout_function = readout_function
    self.num_timesteps = num_timesteps
    self.num_output_hidden_layers = num_output_hidden_layers
    self.output_hidden_dim = output_hidden_dim
    self.output_activation = output_activation
    self.output_atoms = output_atoms
    with self.init_scope():
        # Update
        self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
        self.message_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
            for _ in range(n_message_layer)
        ])
        if self.message_functinon == 'edge_network':
            del self.message_layers
            self.message_layers = chainer.ChainList(*[
                EdgeNetwork(in_dim=self.NUM_EDGE_TYPE,
                            hidden_dim=self.edge_hidden_dim,
                            node_dim=self.hidden_dim)
                for _ in range(n_message_layer)
            ])
        if self.context:
            # Fixed: ``hidden_dim / 2`` is a float under Python 3, which is
            # not a valid link size. Each direction presumably gets half
            # the width so the concatenated bi-directional output is
            # hidden_dim wide.
            self.context_bilstm = links.NStepBiLSTM(
                n_layers=self.context_layers,
                in_size=self.hidden_dim,
                out_size=self.hidden_dim // 2,
                dropout=context_dropout)
        # Self-attention layers (Linear-based formulation).
        if use_attention or update_attention:
            self.linear_transform_layer = chainer.ChainList(*[
                links.Linear(
                    in_size=hidden_dim, out_size=hidden_dim, nobias=True)
                for _ in range(n_attention_layer)
            ])
            self.neural_network_layer = chainer.ChainList(*[
                links.Linear(
                    in_size=2 * self.hidden_dim, out_size=1, nobias=True)
                for _ in range(n_attention_layer)
            ])
        if batch_normalization:
            self.batch_normalization_layer = links.BatchNormalization(
                size=hidden_dim)
        self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
        # Readout
        self.i_layers = chainer.ChainList(*[
            GraphLinear(2 * hidden_dim, out_dim)
            for _ in range(n_readout_layer)
        ])
        self.j_layers = chainer.ChainList(*[
            GraphLinear(hidden_dim, out_dim)
            for _ in range(n_readout_layer)
        ])
        if self.readout_function == 'set2vec':
            del self.i_layers, self.j_layers
            self.readout_layer = chainer.ChainList(*[
                Set2Vec(node_dim=self.hidden_dim * 2,
                        output_dim=out_dim,
                        num_timesteps=num_timesteps,
                        num_output_hidden_layers=num_output_hidden_layers,
                        output_hidden_dim=output_hidden_dim,
                        activation=output_activation)
                for _ in range(n_readout_layer)
            ])
        if self.layer_aggregator:
            self.construct_layer_aggregator()
            # Fixed: these tests were written as ``x == 'a' or 'b'``, which
            # is always truthy, so every aggregator link was built no
            # matter which aggregator was selected. NOTE(review): snapshots
            # saved with the old code may contain extra link parameters.
            if self.layer_aggregator in ('gru-attn', 'gru'):
                self.bigru_layer = links.NStepBiGRU(
                    n_layers=1, in_size=self.hidden_dim,
                    out_size=self.hidden_dim, dropout=0.)
            if self.layer_aggregator in ('lstm-attn', 'lstm'):
                self.bilstm_layer = links.NStepBiLSTM(
                    n_layers=1, in_size=self.hidden_dim,
                    out_size=self.hidden_dim, dropout=0.)
            if self.layer_aggregator in ('gru-attn', 'lstm-attn', 'attn'):
                self.attn_dense_layer = links.Linear(
                    in_size=self.n_layers, out_size=self.n_layers)
            if self.layer_aggregator == 'self-attn':
                self.attn_linear_layer = links.Linear(
                    in_size=self.n_layers, out_size=self.n_layers)
    if self.output_atoms:
        self.atoms = None