예제 #1
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Apply one normalised message-passing step.

     The steps are: compute an edge norm with ``self.norm``, gather ``x``
     and that norm along ``edge_index`` with ``self.gather_feature``, build
     per-edge messages with ``self.apply_edge``, aggregate them with
     ``mp_ops.scatter_`` keyed by ``edge_index[0]``, and post-process the
     aggregate with ``self.apply_node``.

     Args:
         x: Node features handed to ``self.gather_feature``
             (presumably a two-element pair -- TODO confirm with caller).
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size. It is
             indexed unconditionally, so the default ``None`` fails here.
         **kwargs: Accepted and ignored.

     Returns:
         The result of ``self.apply_node`` on the aggregated messages.
     """
     edge_norm = self.norm(edge_index, size)
     gathered = self.gather_feature([x, edge_norm], edge_index)
     # Exactly two gathered entries are expected: features, then norms.
     gathered_x, gathered_norm = gathered
     messages = self.apply_edge(gathered_x[1],
                                gathered_norm[0],
                                gathered_norm[1])
     aggregated = mp_ops.scatter_(self.aggr,
                                  messages,
                                  edge_index[0],
                                  size=size[0])
     return self.apply_node(aggregated)
예제 #2
0
def to_dense_batch(x, batch):
    """Scatter the rows of ``x`` into a padded per-batch dense tensor.

    Each row ``i`` of ``x`` is written to slot
    ``i - offset[batch[i]] + batch[i] * max_count`` of a flat buffer that is
    then reshaped to ``[batch_size, max_count, feature_dim]``; unused slots
    stay zero.

    NOTE(review): the slot formula looks like it assumes rows sharing a
    batch id are contiguous and ordered by id -- confirm against callers.

    Args:
        x: Row-wise features; the last dimension is kept as the feature
            dimension of the output.
        batch: Integer batch id per row of ``x``. Must not be ``None``.

    Returns:
        A tuple ``(out, out_size, mask)`` where ``out`` is the reshaped
        dense tensor, ``out_size`` is the list
        ``[batch_size, max_count, feature_dim]`` used for the reshape, and
        ``mask`` is a flat tensor of length ``batch_size * max_count``
        holding ones at the written slots.
    """
    assert batch is not None
    num_graphs = tf.reduce_max(batch) + 1

    # Count rows per batch id by scatter-adding a column of ones.
    ones_column = tf.ones([tf.shape(batch)[0], 1])
    counts = mp_ops.scatter_('add', ones_column, batch, num_graphs)
    counts = tf.reshape(counts, [-1])
    counts = tf.cast(counts, dtype=tf.int32)

    # Start offset of every batch id: a leading zero, then running totals.
    leading_zero = tf.zeros(1, dtype=tf.int32)
    offsets = tf.concat([leading_zero, tf.cumsum(counts, axis=0)], axis=0)
    widest = tf.reduce_max(counts)

    positions = tf.range(tf.reduce_sum(counts))

    row_start = tf.gather(offsets, batch)
    # Position inside its own batch entry, shifted to that entry's slab.
    flat_slot = positions - row_start + batch * widest
    flat_slot = tf.reshape(flat_slot, [-1, 1])

    flat_shape = [num_graphs * widest, tf.shape(x)[-1]]
    dense = tf.scatter_nd(flat_slot, x, shape=flat_shape)

    out_size = [num_graphs, widest, tf.shape(x)[-1]]
    dense = tf.reshape(dense, out_size)

    written = tf.ones(tf.shape(batch)[0])
    mask = tf.scatter_nd(flat_slot, written, shape=[num_graphs * widest])
    return dense, out_size, mask
예제 #3
0
    def __call__(self, inputs, index, size=None):
        """Pool ``inputs`` per ``index`` using softmax-normalised gates.

        Args:
            inputs: Features fed to ``self.gate_nn`` and, when ``self.nn``
                is set, transformed by ``self.nn`` before pooling.
            index: Group id of each row, used by both scatter operations.
            size: Number of groups; defaults to ``max(index) + 1``.

        Returns:
            The ``self.aggr`` aggregation of ``gate * inputs`` per group.
        """
        if size is None:
            size = tf.reduce_max(index) + 1

        # Gates are computed from the untransformed inputs.
        gate = self.gate_nn(inputs)
        if self.nn is not None:
            inputs = self.nn(inputs)

        weights = mp_ops.scatter_softmax(gate, index, size=size)
        weighted = weights * inputs
        return mp_ops.scatter_(self.aggr, weighted, index, size=size)
예제 #4
0
 def __call__(self, x, edge_index, size=None, edge_attr=None, **kwarg):
     """Run a message-passing step that consumes edge attributes.

     Calls ``self.build()`` first when ``self.built`` is falsy.

     Args:
         x: Node features; ``x[0]`` is passed to ``self.apply_node`` and
             the whole value is gathered along ``edge_index``.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size, so
             the default ``None`` is not usable.
         edge_attr: Edge attributes passed to ``self.apply_edge``.
             Required despite the ``None`` default.
         **kwarg: Accepted and ignored.

     Returns:
         The result of ``self.apply_node(x[0], aggregated)``.
     """
     assert edge_attr is not None
     if not self.built:
         self.build()
     # gather_feature must return exactly one entry for the one input.
     (gathered_x,) = self.gather_feature([x], edge_index)
     messages = self.apply_edge(gathered_x[1], edge_attr)
     aggregated = mp_ops.scatter_(self.aggr,
                                  messages,
                                  edge_index[0],
                                  size=size[0])
     return self.apply_node(x[0], aggregated)
예제 #5
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Repeat a normalised propagation step ``self.K`` times.

     Every iteration gathers the current features and the edge norm,
     builds messages, aggregates them by ``edge_index[0]`` and combines
     the aggregate with the initial ``x[0]`` through ``self.apply_node``.
     The next iteration works on ``[out, x[1]]`` built from the initial
     ``x``.

     Args:
         x: Indexable node features; entries ``[0]`` and ``[1]`` are read.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         The output of the last iteration. ``out`` is only assigned
         inside the loop, so ``self.K`` must be at least 1.
     """
     edge_norm = self.norm(edge_index, size)
     initial = x
     current = x
     for _ in range(self.K):
         gathered, gathered_norm = self.gather_feature([current, edge_norm],
                                                       edge_index)
         messages = self.apply_edge(gathered[1],
                                    gathered_norm[0],
                                    gathered_norm[1])
         aggregated = mp_ops.scatter_(self.aggr,
                                      messages,
                                      edge_index[0],
                                      size=size[0])
         out = self.apply_node(aggregated, initial[0])
         # Only the first slot is refreshed; the second stays as given.
         current = [out, initial[1]]
     return out
예제 #6
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Aggregate ``self.fc``-transformed features over the edges.

     Args:
         x: Indexable node features; ``x[0]`` and ``x[1]`` may each be
             ``None``, in which case ``self.fc`` is skipped for that entry.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         The result of ``self.apply_node(aggregated, x[0])``.
     """
     # Transform entry 0 first, then entry 1, leaving None entries as-is.
     transformed = [
         None if feature is None else self.fc(feature)
         for feature in (x[0], x[1])
     ]
     # The gathered raw features are not needed, only the transformed ones.
     _, gathered_h = self.gather_feature([x, transformed], edge_index)
     messages = self.apply_edge(gathered_h[1])
     aggregated = mp_ops.scatter_(self.aggr,
                                  messages,
                                  edge_index[0],
                                  size=size[0])
     return self.apply_node(aggregated, x[0])
예제 #7
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Run ``self.T`` step-indexed propagation rounds and average.

     Each round passes its step index ``t`` to both ``self.apply_edge``
     and ``self.apply_node``. After the last round the output is reshaped
     to ``[-1, self.K, self.dim]`` and averaged over axis 1.

     Args:
         x: Indexable node features; entries ``[0]`` and ``[1]`` are read.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         The mean over axis 1 of the reshaped final output. ``out`` is
         only assigned inside the loop, so ``self.T`` must be at least 1.
     """
     edge_norm = self.norm(edge_index, size)
     origin = x
     current = x
     for step in range(self.T):
         gathered, gathered_norm = self.gather_feature([current, edge_norm],
                                                       edge_index)
         messages = self.apply_edge(gathered[1],
                                    gathered_norm[0],
                                    gathered_norm[1],
                                    step)
         aggregated = mp_ops.scatter_(self.aggr,
                                      messages,
                                      edge_index[0],
                                      size=size[0])
         out = self.apply_node(aggregated, origin[0], step)
         current = [out, origin[1]]
     stacked = tf.reshape(out, [-1, self.K, self.dim])
     return tf.reduce_mean(stacked, axis=1)
예제 #8
0
    def __call__(self, x, edge_index, size=None, **kwargs):
        """Transform features with ``self.fc`` and run one edge step.

        Args:
            x: Either a single ``tf.Tensor`` or an indexable pair whose
                entries may be ``None``; ``self.fc`` is applied to the
                tensor, or to each non-``None`` entry of the pair.
            edge_index: Edge indices; ``edge_index[0]`` is passed to
                ``self.apply_edge`` and keys the aggregation.
            size: Indexable sizes; ``size[0]`` is passed to
                ``self.apply_edge`` and used as the aggregation size.
            **kwargs: Accepted and ignored.

        Returns:
            The result of ``self.apply_node(aggregated, x[0])`` computed
            on the transformed ``x``.
        """
        if not isinstance(x, tf.Tensor):
            first = None if x[0] is None else self.fc(x[0])
            second = None if x[1] is None else self.fc(x[1])
            x = (first, second)
        else:
            x = self.fc(x)

        (gathered_x,) = self.gather_feature([x], edge_index)
        messages = self.apply_edge(gathered_x[0],
                                   gathered_x[1],
                                   edge_index[0],
                                   size[0])
        aggregated = mp_ops.scatter_(self.aggr,
                                     messages,
                                     edge_index[0],
                                     size=size[0])
        return self.apply_node(aggregated, x[0])
예제 #9
0
파일: gin_conv.py 프로젝트: zonghua94/euler
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Run one gather/aggregate/apply step, creating ``eps`` on first use.

     On the first call (``self.build`` falsy) the flag is set and
     ``self.eps`` becomes a float32 ``tf.Variable`` named ``'eps'`` when
     ``self.train_eps`` is truthy, otherwise the plain ``self.eps_value``.

     Args:
         x: Indexable node features; ``x[0]`` goes to ``self.apply_node``.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         The result of ``self.apply_node(aggregated, x[0])``.
     """
     if not self.build:
         # Flip the flag before creating anything, as the first step.
         self.build = True
         self.eps = (tf.Variable([self.eps_value],
                                 name='eps',
                                 dtype=tf.float32)
                     if self.train_eps else self.eps_value)
     (gathered_x,) = self.gather_feature([x], edge_index)
     messages = self.apply_edge(gathered_x[1])
     aggregated = mp_ops.scatter_(self.aggr,
                                  messages,
                                  edge_index[0],
                                  size=size[0])
     return self.apply_node(aggregated, x[0])
예제 #10
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Apply ``self.K`` normalised propagation rounds, then ``self.fc``.

     Starting from ``x[0]``, each round gathers ``[current, x[1]]`` with
     the edge norm, builds messages, aggregates them by ``edge_index[0]``
     and passes the aggregate through ``self.apply_node``.

     Args:
         x: Indexable node features; entries ``[0]`` and ``[1]`` are read.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         ``self.fc`` applied to the propagated features (to ``x[0]``
         itself when ``self.K`` is 0).
     """
     edge_norm = self.norm(edge_index, size)
     propagated = x[0]
     for _ in range(self.K):
         gathered_x, gathered_norm = self.gather_feature(
             [[propagated, x[1]], edge_norm], edge_index)
         messages = self.apply_edge(gathered_x[1],
                                    gathered_norm[0],
                                    gathered_norm[1])
         aggregated = mp_ops.scatter_(self.aggr,
                                      messages,
                                      edge_index[0],
                                      size=size[0])
         propagated = self.apply_node(aggregated)
     return self.fc(propagated)
예제 #11
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Run one edge step driven by l2-normalised features.

     On the first call (``self.build`` falsy) the flag is set and
     ``self.beta`` is created as a float32 ``tf.Variable`` named
     ``'beta'`` initialised to ``[1.]``.

     Args:
         x: Indexable node features; ``x[0]`` and ``x[1]`` may each be
             ``None``, in which case the matching normalised entry is
             ``None`` too.
         edge_index: Edge indices; ``edge_index[0]`` is passed to
             ``self.apply_edge`` and keys the aggregation.
         size: Indexable sizes; ``size[0]`` is passed to
             ``self.apply_edge`` and used as the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         The result of ``self.apply_node`` on the aggregated messages.
     """
     if not self.build:
         self.build = True
         self.beta = tf.Variable([1.], name='beta', dtype=tf.float32)
     # Normalise along the last axis, entry 0 first, skipping None.
     normalised = [
         None if feature is None else tf.nn.l2_normalize(feature, axis=-1)
         for feature in (x[0], x[1])
     ]
     gathered_x, gathered_norm = self.gather_feature([x, normalised],
                                                     edge_index)
     messages = self.apply_edge(edge_index[0],
                                gathered_x[1],
                                gathered_norm[0],
                                gathered_norm[1],
                                size[0])
     aggregated = mp_ops.scatter_(self.aggr,
                                  messages,
                                  edge_index[0],
                                  size=size[0])
     return self.apply_node(aggregated)
예제 #12
0
 def __call__(self, inputs, index, size=None):
     """Pool ``inputs`` per ``index`` with an LSTM-driven attention loop.

     For ``self.processing_steps`` rounds the LSTM is fed the previous
     ``[q, r]`` concatenation (zeros at the start), its output ``q`` is
     scored against every input row of the same group, the scores are
     softmax-normalised per group, and the weighted inputs are aggregated
     into ``r``.

     Args:
         inputs: Row-wise features; multiplied elementwise with the
             gathered query (presumably ``[rows, self.dim]`` -- TODO
             confirm).
         index: Group id of each row of ``inputs``.
         size: Number of groups; defaults to ``max(index) + 1``.

     Returns:
         The final ``[q, r]`` concatenation reshaped to
         ``[-1, self.dim * 2]`` (all zeros when ``self.processing_steps``
         is 0).
     """
     size = tf.reduce_max(index) + 1 if size is None else size
     cell_in = tf.zeros([size, self.dim * 2], dtype=tf.float32)
     hidden_state = self.lstm.zero_state(tf.shape(cell_in)[0],
                                         dtype=tf.float32)
     # The step index itself is unused; every round refines q and r.
     for _ in range(self.processing_steps):
         # dynamic_rnn wants a time axis; feed a single time step.
         q = tf.expand_dims(cell_in, axis=1)
         q, hidden_state = tf.nn.dynamic_rnn(self.lstm,
                                             q,
                                             initial_state=hidden_state,
                                             dtype=tf.float32)
         q = tf.reshape(q, [-1, self.dim])
         # `keepdims` replaces the deprecated `keep_dims` alias, which
         # emits a deprecation warning on TF 1.x and is gone in TF 2.
         e = tf.reduce_sum(inputs * tf.gather(q, index),
                           axis=-1,
                           keepdims=True)
         a = mp_ops.scatter_softmax(e, index, size=size)
         r = mp_ops.scatter_(self.aggr, a * inputs, index, size=size)
         cell_in = tf.reshape(tf.concat([q, r], axis=-1),
                              [-1, self.dim * 2])
     return cell_in
예제 #13
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Alternate message passing and an RNN update for several steps.

     In step ``i`` the features are transformed by ``self.fc[i]``,
     gathered, turned into messages, aggregated by ``edge_index[0]`` and
     passed through ``self.apply_node``. The result is fed as a one-step
     sequence to ``self.rnn`` whose initial state is the current ``h[0]``
     repeated ``self.lstm_layers`` times. The RNN output, reshaped to
     ``[-1, self.dim]``, becomes the new ``h[0]``.

     Args:
         x: Indexable node features; ``x[0]`` and ``x[1]`` may each be
             ``None``, in which case ``self.fc[i]`` is skipped for it.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size.
         **kwargs: Accepted and ignored.

     Returns:
         The reshaped RNN output of the last step. ``out`` is only
         assigned inside the loop, so ``self.processing_steps`` must be
         at least 1.
     """
     state = x
     for step in range(self.processing_steps):
         transformed = [
             None if state[0] is None else self.fc[step](state[0]),
             None if state[1] is None else self.fc[step](state[1])
         ]
         (gathered,) = self.gather_feature([transformed], edge_index)
         messages = self.apply_edge(gathered[1])
         aggregated = mp_ops.scatter_(self.aggr,
                                      messages,
                                      edge_index[0],
                                      size=size[0])
         node_out = self.apply_node(aggregated)
         # Add a length-1 time axis for dynamic_rnn.
         sequence = tf.expand_dims(node_out, axis=1)
         initial_state = tuple(state[0] for _ in range(self.lstm_layers))
         # AUTO_REUSE shares the 'rnn' variables across all steps.
         with tf.variable_scope('rnn', reuse=tf.AUTO_REUSE):
             rnn_out, _ = tf.nn.dynamic_rnn(self.rnn,
                                            sequence,
                                            initial_state=initial_state,
                                            dtype=tf.float32)
         out = tf.reshape(rnn_out, [-1, self.dim])
         state = [out, state[1]]
     return out
예제 #14
0
 def __call__(self, x, edge_index, size=None, **kwargs):
     """Gather features along edges, aggregate them and update nodes.

     Args:
         x: Indexable node features; ``x[0]`` goes to ``self.apply_node``.
         edge_index: Edge indices; ``edge_index[0]`` keys the aggregation.
         size: Indexable sizes; ``size[0]`` is the aggregation size, so
             the default ``None`` is not usable.
         **kwargs: Accepted and ignored.

     Returns:
         The result of ``self.apply_node(aggregated, x[0])``.
     """
     (gathered_x,) = self.gather_feature([x], edge_index)
     messages = self.apply_edge(gathered_x[1])
     aggregated = mp_ops.scatter_(self.aggr,
                                  messages,
                                  edge_index[0],
                                  size=size[0])
     return self.apply_node(aggregated, x[0])
예제 #15
0
 def __call__(self, inputs, index, size=None):
     """Aggregate ``inputs`` per ``index`` with the ``self.aggr`` reducer.

     Args:
         inputs: Values to aggregate.
         index: Group id of each entry of ``inputs``.
         size: Number of groups; defaults to ``max(index) + 1``.

     Returns:
         The result of ``scatter_(self.aggr, inputs, index, size)``.
     """
     if size is None:
         size = tf.reduce_max(index) + 1
     return scatter_(self.aggr, inputs, index, size)