def __call__(self, x, edge_index, size=None, **kwargs):
    """Normalized message passing over ``edge_index``.

    Coefficients from ``self.norm`` are gathered per edge together with the
    features in ``x``. ``self.apply_edge`` combines them into one message per
    edge, the messages are reduced onto ``edge_index[0]`` with ``self.aggr``,
    and the result is post-processed by ``self.apply_node``.

    Args:
        x: features handed to ``self.gather_feature`` (presumably a
            (target, source) pair — confirm); only the gathered position 1
            feeds the messages.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None — it is indexed below without a fallback.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The output of ``self.apply_node`` on the aggregated messages.
    """
    edge_norm = self.norm(edge_index, size)
    per_edge_x, per_edge_norm = self.gather_feature([x, edge_norm], edge_index)
    messages = self.apply_edge(per_edge_x[1], per_edge_norm[0], per_edge_norm[1])
    aggregated = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
    return self.apply_node(aggregated)
def to_dense_batch(x, batch):
    """Scatter the node features of a graph mini-batch into a padded dense tensor.

    Args:
        x: node-feature matrix; row ``i`` belongs to graph ``batch[i]``.
            Assumes the rows of each graph are contiguous and graphs appear
            in increasing ``batch`` order, because the in-graph position of
            row ``i`` is computed as ``i - cum_nodes[batch[i]]``.
        batch: integer graph-assignment vector, one entry per row of ``x``.
            Any integer dtype is accepted; it is cast to ``int32``.

    Returns:
        A tuple ``(out, out_size, mask)``:

        * ``out`` — ``[batch_size, max_num_nodes, num_features]`` tensor.
          Slots past a graph's last node are zero (``tf.scatter_nd`` fills
          unset positions with zeros).
        * ``out_size`` — the list ``[batch_size, max_num_nodes,
          num_features]`` used for the reshape.
        * ``mask`` — flat float32 vector of length
          ``batch_size * max_num_nodes``: 1.0 at occupied slots, 0.0 at
          padding. It is not reshaped to 2-D.

    Raises:
        AssertionError: if ``batch`` is None.
    """
    assert batch is not None
    # Every index computation below is int32 (cum_nodes, idx and the cast of
    # num_nodes). Without this cast an int64 `batch` fails on the
    # mixed-dtype products `batch * max_num_nodes` and
    # `batch_size * max_num_nodes`. For int32 input the cast is a no-op.
    batch = tf.cast(batch, tf.int32)
    num_rows = tf.shape(batch)[0]
    num_features = tf.shape(x)[-1]

    batch_size = tf.reduce_max(batch) + 1
    # Nodes per graph: scatter-add a column of ones over `batch`.
    num_nodes = mp_ops.scatter_('add', tf.ones([num_rows, 1]), batch, batch_size)
    num_nodes = tf.cast(tf.reshape(num_nodes, [-1]), dtype=tf.int32)
    # Exclusive prefix sum: cum_nodes[g] is the first row index of graph g.
    cum_nodes = tf.concat(
        [tf.zeros(1, dtype=tf.int32), tf.cumsum(num_nodes, axis=0)], axis=0)
    max_num_nodes = tf.reduce_max(num_nodes)

    # Flat destination slot of each row:
    # (position within its graph) + graph_id * max_num_nodes.
    idx = tf.range(tf.reduce_sum(num_nodes))
    idx = idx - tf.gather(cum_nodes, batch) + batch * max_num_nodes
    idx = tf.reshape(idx, [-1, 1])

    out = tf.scatter_nd(idx, x, shape=[batch_size * max_num_nodes, num_features])
    out_size = [batch_size, max_num_nodes, num_features]
    out = tf.reshape(out, out_size)
    mask = tf.scatter_nd(idx, tf.ones(num_rows), shape=[batch_size * max_num_nodes])
    return out, out_size, mask
def __call__(self, inputs, index, size=None):
    """Gate-weighted aggregation of ``inputs`` into ``size`` segments.

    Gate scores are computed from the raw ``inputs`` by ``self.gate_nn`` and
    normalized per segment with ``mp_ops.scatter_softmax``. The (optionally
    ``self.nn``-transformed) inputs are scaled by the gates and reduced per
    segment with ``self.aggr``.

    Args:
        inputs: rows to pool.
        index: segment id of each row of ``inputs``.
        size: number of segments; defaults to ``max(index) + 1``.

    Returns:
        The per-segment aggregation produced by ``mp_ops.scatter_``.
    """
    if size is None:
        size = tf.reduce_max(index) + 1
    # Gates come from the untransformed inputs, so compute them before self.nn.
    gate = self.gate_nn(inputs)
    if self.nn is not None:
        inputs = self.nn(inputs)
    gate = mp_ops.scatter_softmax(gate, index, size=size)
    return mp_ops.scatter_(self.aggr, gate * inputs, index, size=size)
def __call__(self, x, edge_index, size=None, edge_attr=None, **kwarg):
    """Message passing conditioned on per-edge attributes.

    On the first call (while ``self.built`` is falsy) ``self.build()`` is
    invoked. Gathered features at position 1 are combined with ``edge_attr``
    by ``self.apply_edge``, reduced onto ``edge_index[0]`` with
    ``self.aggr``, and passed to ``self.apply_node`` together with ``x[0]``.

    Args:
        x: features handed to ``self.gather_feature``; ``x[0]`` is also
            passed to ``self.apply_node``.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        edge_attr: per-edge attributes; required.
        **kwarg: accepted for interface compatibility and ignored.

    Raises:
        AssertionError: if ``edge_attr`` is None.
    """
    assert edge_attr is not None
    if not self.built:
        self.build()
    (per_edge_x,) = self.gather_feature([x], edge_index)
    messages = self.apply_edge(per_edge_x[1], edge_attr)
    reduced = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
    return self.apply_node(x[0], reduced)
def __call__(self, x, edge_index, size=None, **kwargs):
    """``self.K`` rounds of normalized propagation with a skip to the input.

    Each round gathers the current pair and the ``self.norm`` coefficients
    per edge, builds messages with ``self.apply_edge``, and reduces them onto
    ``edge_index[0]``. ``self.apply_node`` then combines the result with the
    original ``x[0]``, and the output replaces position 0 of the pair for
    the next round.

    ``self.K`` must be >= 1: with ``K == 0`` the loop never runs and ``out``
    is unbound at the return.

    NOTE(review): only position 0 of the pair is updated between rounds;
    position 1 stays the original ``x[1]``. If ``gather_feature`` reads the
    message source from position 1 (as ``apply_edge(gathered[1], ...)``
    suggests), later rounds re-propagate the original features rather than
    the previous round's output — confirm against ``gather_feature``.

    Args:
        x: feature pair; ``x[0]`` and ``x[1]`` are both indexed below.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    edge_norm = self.norm(edge_index, size)
    hidden = x
    for _ in range(self.K):
        gathered, per_edge_norm = self.gather_feature([x, edge_norm], edge_index)
        messages = self.apply_edge(gathered[1], per_edge_norm[0], per_edge_norm[1])
        out = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
        out = self.apply_node(out, hidden[0])
        x = [out, hidden[1]]
    return out
def __call__(self, x, edge_index, size=None, **kwargs):
    """Message passing over features projected by ``self.fc``.

    Both sides of the pair ``x`` are projected with the shared ``self.fc``
    (a missing side stays None). The projected features are gathered per
    edge, position 1 is turned into messages by ``self.apply_edge``, the
    messages are reduced onto ``edge_index[0]`` with ``self.aggr``, and
    ``self.apply_node`` combines the result with the unprojected ``x[0]``.

    Args:
        x: feature pair; either side may be None.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    h = [None if side is None else self.fc(side) for side in (x[0], x[1])]
    # Only the projected features are needed per edge. The raw `x` used to be
    # gathered as well, but that result was never read, so it was pure
    # wasted work.
    (gather_h,) = self.gather_feature([h], edge_index)
    out = self.apply_edge(gather_h[1])
    out = mp_ops.scatter_(self.aggr, out, edge_index[0], size=size[0])
    return self.apply_node(out, x[0])
def __call__(self, x, edge_index, size=None, **kwargs):
    """``self.T`` rounds of step-dependent propagation, averaged over ``self.K``.

    Each round ``t`` gathers the current pair and the ``self.norm``
    coefficients per edge, calls ``self.apply_edge(..., t)``, reduces onto
    ``edge_index[0]``, and calls ``self.apply_node(..., origin[0], t)``. The
    result replaces position 0 of the pair. The final output is reshaped to
    ``[-1, self.K, self.dim]`` and averaged over axis 1 (presumably ``K``
    parallel stacks — confirm).

    ``self.T`` must be >= 1: with ``T == 0`` the loop never runs and ``out``
    is unbound at the reshape.

    NOTE(review): position 1 of the pair stays ``origin[1]`` across rounds;
    confirm this matches how ``gather_feature`` chooses the message source.

    Args:
        x: feature pair; ``x[0]`` and ``x[1]`` are both indexed below.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        Tensor of shape ``[-1, self.dim]``.
    """
    edge_norm = self.norm(edge_index, size)
    origin = x
    for step in range(self.T):
        gathered, per_edge_norm = self.gather_feature([x, edge_norm], edge_index)
        messages = self.apply_edge(gathered[1], per_edge_norm[0], per_edge_norm[1], step)
        out = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
        out = self.apply_node(out, origin[0], step)
        x = [out, origin[1]]
    stacks = tf.reshape(out, [-1, self.K, self.dim])
    return tf.reduce_mean(stacks, axis=1)
def __call__(self, x, edge_index, size=None, **kwargs):
    """Message passing over ``self.fc``-projected features.

    ``x`` may be a single tensor (the same node set on both sides) or a pair
    whose sides may be None. After projection with ``self.fc``, both
    gathered sides, the target index ``edge_index[0]`` and ``size[0]`` go to
    ``self.apply_edge``. The messages are reduced onto ``edge_index[0]`` with
    ``self.aggr`` and combined with the projected target features by
    ``self.apply_node``.

    Args:
        x: a ``tf.Tensor`` or a (target, source)-style pair (either side may
            be None).
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    if isinstance(x, tf.Tensor):
        # Expand a single tensor into a pair so that `x[0]` below is the
        # projected feature matrix. Previously `x` stayed a bare tensor, so
        # `x[0]` picked out only its first row and apply_node received that
        # row instead of the matrix.
        projected = self.fc(x)
        x = (projected, projected)
    else:
        x = (None if x[0] is None else self.fc(x[0]),
             None if x[1] is None else self.fc(x[1]))
    (gather_x,) = self.gather_feature([x], edge_index)
    out = self.apply_edge(gather_x[0], gather_x[1], edge_index[0], size[0])
    out = mp_ops.scatter_(self.aggr, out, edge_index[0], size=size[0])
    return self.apply_node(out, x[0])
def __call__(self, x, edge_index, size=None, **kwargs):
    """Message passing with a lazily created ``eps`` term.

    On the first call (while ``self.build`` is falsy) ``self.eps`` is set,
    either to a trainable ``tf.Variable`` initialized from
    ``self.eps_value`` (when ``self.train_eps``) or to the plain
    ``self.eps_value``. ``self.eps`` is presumably read by
    ``self.apply_node`` — not visible here.

    NOTE(review): ``self.build`` is used as a boolean flag and overwritten
    with True. If the class inherits a ``build`` method, the method is truthy
    and this branch would never run — confirm ``__init__`` sets
    ``self.build = False``.

    Args:
        x: features handed to ``self.gather_feature``; ``x[0]`` is also
            passed to ``self.apply_node``.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    if not self.build:
        self.build = True
        self.eps = (
            tf.Variable([self.eps_value], name='eps', dtype=tf.float32)
            if self.train_eps
            else self.eps_value
        )
    (per_edge_x,) = self.gather_feature([x], edge_index)
    messages = self.apply_edge(per_edge_x[1])
    aggregated = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
    return self.apply_node(aggregated, x[0])
def __call__(self, x, edge_index, size=None, **kwargs):
    """``self.K`` rounds of normalized filtering followed by ``self.fc``.

    Starting from ``x[0]``, each round gathers the pair
    ``[filter_out, x[1]]`` and the ``self.norm`` coefficients per edge,
    builds messages with ``self.apply_edge``, reduces them onto
    ``edge_index[0]``, and post-processes with ``self.apply_node``. With
    ``K == 0`` the result is ``self.fc(x[0])``.

    NOTE(review): messages are built from gathered position 1, which is
    always the original ``x[1]``; ``filter_out`` only occupies position 0.
    Confirm that ``gather_feature`` semantics make the rounds actually
    compound.

    Args:
        x: feature pair; ``x[0]`` and ``x[1]`` are both indexed below.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    edge_norm = self.norm(edge_index, size)
    filter_out = x[0]
    for _ in range(self.K):
        current_pair = [filter_out, x[1]]
        per_edge_x, per_edge_norm = self.gather_feature(
            [current_pair, edge_norm], edge_index)
        messages = self.apply_edge(per_edge_x[1], per_edge_norm[0], per_edge_norm[1])
        reduced = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
        filter_out = self.apply_node(reduced)
    return self.fc(filter_out)
def __call__(self, x, edge_index, size=None, **kwargs):
    """Message passing over L2-normalized features with a lazily created ``beta``.

    On the first call (while ``self.build`` is falsy) a trainable
    ``self.beta`` initialized to ``[1.]`` is created. Each non-None side of
    ``x`` is L2-normalized along the last axis. Raw and normalized features
    are gathered per edge and passed to ``self.apply_edge`` together with
    ``edge_index[0]`` and ``size[0]``; the messages are reduced onto
    ``edge_index[0]`` and post-processed by ``self.apply_node``.

    NOTE(review): ``self.build`` is used as a boolean flag and overwritten
    with True — confirm ``__init__`` initializes it to False.

    Args:
        x: feature pair; either side may be None.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    if not self.build:
        self.build = True
        self.beta = tf.Variable([1.], name='beta', dtype=tf.float32)
    unit_x = [
        None if side is None else tf.nn.l2_normalize(side, axis=-1)
        for side in (x[0], x[1])
    ]
    per_edge_x, per_edge_unit = self.gather_feature([x, unit_x], edge_index)
    messages = self.apply_edge(
        edge_index[0], per_edge_x[1], per_edge_unit[0], per_edge_unit[1], size[0])
    reduced = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
    return self.apply_node(reduced)
def __call__(self, inputs, index, size=None):
    """Iterative attention readout driven by ``self.lstm`` (Set2Set-like).

    For ``self.processing_steps`` iterations, the LSTM consumes the previous
    ``[query, readout]`` concatenation and emits a new query per segment.
    Rows of ``inputs`` are scored against their segment's query, the scores
    are softmax-normalized per segment, and the weighted rows are reduced
    with ``self.aggr``. Query and readout are concatenated for the next
    step.

    Args:
        inputs: rows to pool; their width is presumably ``self.dim`` —
            confirm.
        index: segment id of each row of ``inputs``.
        size: number of segments; defaults to ``max(index) + 1``.

    Returns:
        Tensor of shape ``[size, 2 * self.dim]`` (all zeros if
        ``processing_steps == 0``).
    """
    if size is None:
        size = tf.reduce_max(index) + 1
    cell_in = tf.zeros([size, self.dim * 2], dtype=tf.float32)
    hidden_state = self.lstm.zero_state(tf.shape(cell_in)[0], dtype=tf.float32)
    for _ in range(self.processing_steps):
        # One LSTM time step per iteration: [size, 1, 2*dim] -> [size, 1, dim].
        lstm_input = tf.expand_dims(cell_in, axis=1)
        query, hidden_state = tf.nn.dynamic_rnn(
            self.lstm, lstm_input, initial_state=hidden_state, dtype=tf.float32)
        query = tf.reshape(query, [-1, self.dim])
        scores = tf.reduce_sum(
            inputs * tf.gather(query, index), axis=-1, keep_dims=True)
        attention = mp_ops.scatter_softmax(scores, index, size=size)
        readout = mp_ops.scatter_(self.aggr, attention * inputs, index, size=size)
        cell_in = tf.reshape(tf.concat([query, readout], axis=-1), [-1, self.dim * 2])
    return cell_in
def __call__(self, x, edge_index, size=None, **kwargs):
    """Multi-step message passing whose state is updated by ``self.rnn``.

    At step ``i`` both sides of the current pair are transformed by
    ``self.fc[i]``. Position 1 is gathered per edge, turned into messages,
    reduced onto ``edge_index[0]``, and post-processed by
    ``self.apply_node``. The result is fed as one time step to
    ``self.rnn``, whose initial state is ``h[0]`` repeated
    ``self.lstm_layers`` times. The RNN output replaces position 0 of the
    pair.

    ``self.processing_steps`` must be >= 1: otherwise ``out`` is unbound at
    the return.

    NOTE(review): position 1 of the pair stays the original ``x[1]`` across
    steps; confirm this matches ``gather_feature`` semantics.

    Args:
        x: feature pair; either side may be None.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        Tensor of shape ``[-1, self.dim]``.
    """
    h = x
    for step in range(self.processing_steps):
        step_fc = self.fc[step]
        m = [None if side is None else step_fc(side) for side in (h[0], h[1])]
        (per_edge_m,) = self.gather_feature([m], edge_index)
        out = self.apply_edge(per_edge_m[1])
        out = mp_ops.scatter_(self.aggr, out, edge_index[0], size=size[0])
        out = self.apply_node(out)
        out = tf.expand_dims(out, axis=1)
        initial_state = tuple(h[0] for _ in range(self.lstm_layers))
        # AUTO_REUSE so that every step shares the same RNN weights.
        with tf.variable_scope('rnn', reuse=tf.AUTO_REUSE):
            out, _ = tf.nn.dynamic_rnn(
                self.rnn, out, initial_state=initial_state, dtype=tf.float32)
        out = tf.reshape(out, [-1, self.dim])
        h = [out, h[1]]
    return out
def __call__(self, x, edge_index, size=None, **kwargs):
    """Plain message passing combined with the target-side features.

    Gathered position 1 is turned into messages by ``self.apply_edge``. The
    messages are reduced onto ``edge_index[0]`` with ``self.aggr``, and
    ``self.apply_node`` combines them with ``x[0]``.

    Args:
        x: features handed to ``self.gather_feature``; ``x[0]`` is also
            passed to ``self.apply_node``.
        edge_index: edge list; row 0 holds the aggregation targets.
        size: indexable; ``size[0]`` is the number of aggregated rows. Must
            not be None.
        **kwargs: accepted for interface compatibility and ignored.
    """
    (per_edge_x,) = self.gather_feature([x], edge_index)
    messages = self.apply_edge(per_edge_x[1])
    aggregated = mp_ops.scatter_(self.aggr, messages, edge_index[0], size=size[0])
    return self.apply_node(aggregated, x[0])
def __call__(self, inputs, index, size=None):
    """Reduce the rows of ``inputs`` into segments given by ``index``.

    Args:
        inputs: rows to pool.
        index: segment id of each row of ``inputs``.
        size: number of segments; defaults to ``max(index) + 1``.

    Returns:
        The per-segment ``self.aggr`` reduction from ``mp_ops.scatter_``.
    """
    if size is None:
        size = tf.reduce_max(index) + 1
    # Go through mp_ops.scatter_, the entry point every other layer here uses.
    # The bare `scatter_` name previously called here is not referenced
    # anywhere else in this source.
    return mp_ops.scatter_(self.aggr, inputs, index, size=size)