def mfunc(edges):
    """Build the message ``{"m": op(lhs_data, rhs_data)}`` for a batch of edges.

    The operator is looked up on the backend module ``F`` under the name
    held in ``binary_op``.  Both operands are fetched through
    ``target_switch(edges, ...)``, indexed by ``lhs`` and ``rhs``
    respectively (all of these names come from the enclosing scope).

    Parameters
    ----------
    edges
        Edge batch object handed to ``target_switch``.

    Returns
    -------
    dict
        A single-entry dict mapping ``"m"`` to the operator result.
    """
    op = getattr(F, binary_op)
    left = target_switch(edges, lhs)[lhs]
    right = target_switch(edges, rhs)[rhs]

    # NOTE(zihao): we need to do batched broadcast,
    # e.g. (68, 3, 1) op (68, 5, 3, 4).
    # The operand with the lower rank receives new axes at position 1
    # until both operands have the same number of dimensions.
    rank_gap = F.ndim(right) - F.ndim(left)
    for _ in range(rank_gap):       # no-op unless left has the lower rank
        left = F.unsqueeze(left, 1)
    for _ in range(-rank_gap):      # no-op unless right has the lower rank
        right = F.unsqueeze(right, 1)

    return {"m": op(left, right)}
def __call__(self, edges):
    """Combine a source-node field with an edge field via ``self.mul_op``.

    Reads ``edges.src[self.src_field]`` and ``edges.data[self.edge_field]``,
    pads the lower-rank operand with trailing size-1 axes so both have the
    same rank, and applies ``self.mul_op`` to the two reshaped operands.

    Parameters
    ----------
    edges
        Edge batch exposing ``src`` and ``data`` mappings.

    Returns
    -------
    dict
        A single-entry dict mapping ``self.out_field`` to the result.
    """
    src_feat = edges.src[self.src_field]
    edge_feat = edges.data[self.edge_field]

    src_rank = F.ndim(src_feat)
    edge_rank = F.ndim(edge_feat)
    target_rank = max(src_rank, edge_rank)

    # Backends differ in their broadcasting semantics, so both operands
    # are explicitly brought to the same rank by appending trailing
    # size-1 axes before the operator is applied.
    padded_src_shape = F.shape(src_feat) + (1,) * (target_rank - src_rank)
    padded_edge_shape = F.shape(edge_feat) + (1,) * (target_rank - edge_rank)

    src_feat = F.reshape(src_feat, padded_src_shape)
    edge_feat = F.reshape(edge_feat, padded_edge_shape)

    return {self.out_field: self.mul_op(src_feat, edge_feat)}
def softmax(x):
    """Apply softmax to a 2D or 3D tensor.

    A 2D input is delegated to ``K.softmax``.  A 3D input is normalised
    over its last axis by hand: the per-slice maximum is subtracted
    before exponentiating, then the exponentials are divided by their
    sum along that axis.

    Parameters
    ----------
    x
        Tensor whose rank, as reported by ``K.ndim``, is 2 or 3.

    Returns
    -------
    The softmax of ``x``.

    Raises
    ------
    ValueError
        If ``K.ndim(x)`` is neither 2 nor 3.
    """
    ndim = K.ndim(x)

    if ndim == 2:
        return K.softmax(x)

    if ndim != 3:
        message = ('Cannot apply softmax to a tensor '
                   'that is not 2D or 3D. '
                   'Here, ndim=')
        raise ValueError(message + str(ndim))

    # Subtracting the maximum along the last axis before exponentiating
    # leaves the result unchanged mathematically.
    last_axis_max = K.max(x, axis=-1, keepdims=True)
    exponentials = K.exp(x - last_axis_max)
    normalizer = K.sum(exponentials, axis=-1, keepdims=True)
    return exponentials / normalizer
def reduce_func(nodes):
    """Sum the ``'m'`` mailbox messages over axis 1, checking their shape.

    The shape of the mailbox tensor is recorded in the enclosing-scope
    collection ``reduce_msg_shapes`` before it is checked, so the caller
    can inspect which message shapes were seen.

    Parameters
    ----------
    nodes
        Node batch exposing a ``mailbox`` mapping with key ``'m'``.

    Returns
    -------
    dict
        ``{'accum': <sum of the messages over axis 1>}``.
    """
    mailbox_msgs = nodes.mailbox['m']

    # Record the observed shape first, then verify it.
    observed_shape = tuple(mailbox_msgs.shape)
    reduce_msg_shapes.add(observed_shape)

    # The mailbox tensor must have rank 3 with a last axis of size D.
    assert F.ndim(mailbox_msgs) == 3
    assert F.shape(mailbox_msgs)[2] == D

    accumulated = F.sum(mailbox_msgs, 1)
    return {'accum': accumulated}
def message_func(edges):
    """Forward the source feature ``'h'`` as message ``'m'``, checking its shape.

    Parameters
    ----------
    edges
        Edge batch exposing a ``src`` mapping with key ``'h'``.

    Returns
    -------
    dict
        ``{'m': edges.src['h']}``, unchanged.
    """
    src_feat = edges.src['h']

    # The source feature must have rank 2 with a second axis of size D.
    assert F.ndim(src_feat) == 2
    assert F.shape(src_feat)[1] == D

    return {'m': src_feat}