def __set_size__(self, size: List[Optional[int]], dim: int, src: Var):
    """Record or validate the extent of ``src`` along ``self.node_dim``.

    If ``size[dim]`` is still ``None``, it is filled in place with
    ``src.size(self.node_dim)``. Otherwise the previously recorded value
    must equal that extent.

    Args:
        size: Sizes recorded so far; ``None`` marks a slot not yet known.
            Mutated in place.
        dim: Index of the slot in ``size`` to fill or check.
        src: Object whose ``size(self.node_dim)`` is recorded or compared.

    Raises:
        ValueError: If ``size[dim]`` is already set to a different value.
    """
    expected = size[dim]
    if expected is None:
        # First time this slot is seen: remember the observed extent.
        size[dim] = src.size(self.node_dim)
        return
    actual = src.size(self.node_dim)
    if expected != actual:
        raise ValueError(
            f'Encountered Var with size {actual} in '
            f'dimension {self.node_dim}, but expected size {expected}.')
def scatter(x: jt.Var, dim: int, index: jt.Var, src: jt.Var, reduce='void'):
    '''Write ``src`` into ``x`` at the positions given by ``index`` along ``dim``.

    For a 3-D ``x`` this is equivalent to::

        self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
        self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
        self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

    The work is delegated to ``x.setitem``. Every axis other than ``dim`` is
    addressed by a symbolic loop variable (``'i0'``, ``'i1'``, ...), and axis
    ``dim`` is addressed by ``index``.

    Parameters::

        * x (jt.Var) – input array
        * dim (int) – the axis along which to index
        * index (jt.Var) – the indices of the elements to scatter; expected to
          have the same dimensionality as ``src``
        * src (jt.Var) – the source element(s) to scatter. A non-scalar
          ``src`` whose shape differs from ``index.shape`` is first cropped
          to ``index.shape``.
        * reduce (str, optional) – forwarded unchanged to ``x.setitem``;
          defaults to 'void', and the examples below use 'add' and 'multiply'.

    Example::

        src = jt.arange(1, 11).reshape((2, 5))
        index = jt.array([[0, 1, 2, 0]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
        assert (x.data == [[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]]).all()
        index = jt.array([[0, 1, 2], [0, 1, 4]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
        assert (x.data == [[1, 2, 3, 0, 0], [6, 7, 0, 0, 8], [0, 0, 0, 0, 0]]).all()
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]), jt.array(1.23), reduce='multiply')
        assert np.allclose(x.data, [[2.0000, 2.0000, 2.4600, 2.0000], [2.0000, 2.0000, 2.0000, 2.4600]]), x
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]), jt.array(1.23), reduce='add')
        assert np.allclose(x.data, [[2.0000, 2.0000, 3.2300, 2.0000], [2.0000, 2.0000, 2.0000, 3.2300]])
    '''
    target_shape = index.shape
    needs_crop = src.shape != target_shape and src.numel() != 1
    if needs_crop:
        # Keep only the leading region of src that index actually addresses.
        src = src[tuple(slice(None, extent) for extent in target_shape)]
    # One symbolic loop variable per axis; the scatter axis uses index instead.
    subscripts = [f'i{axis}' for axis in range(len(target_shape))]
    subscripts[dim] = index
    return x.setitem(tuple(subscripts), src, reduce)
def execute(self, x: Var, edge_index: Adj, edge_weight: OptVar = None) -> Var:
    """Normalize edges (optionally), transform ``x``, propagate, then add bias.

    When ``self.normalize`` is truthy and ``edge_index`` is a ``Var``, edges
    and weights are normalized by ``gcn_norm`` (using ``self.improved`` and
    ``self.add_self_loops``). If ``self.cached`` is set, the normalized pair is
    stored in ``self._cached_edge_index``; once stored, it is reused instead
    of being recomputed.

    Returns:
        ``self.propagate(...)`` applied to ``x @ self.weight``, plus
        ``self.bias`` when a bias is present.
    """
    if self.normalize and isinstance(edge_index, Var):
        stored = self._cached_edge_index
        if stored is not None:
            edge_index, edge_weight = stored[0], stored[1]
        else:
            edge_index, edge_weight = gcn_norm(
                edge_index, edge_weight, x.size(self.node_dim),
                self.improved, self.add_self_loops)
            if self.cached:
                self._cached_edge_index = (edge_index, edge_weight)

    transformed = x @ self.weight
    out = self.propagate(edge_index, x=transformed, edge_weight=edge_weight,
                         size=None)
    if self.bias is not None:
        out += self.bias
    return out
def execute(self, input_var: jittor.Var):
    """Run ``self.net`` on a detached copy of ``input_var``.

    The input is cloned and then detached, so the network sees a copy that
    is cut from the caller's autograd graph.

    Returns:
        A dict mapping ``OUT_KEY`` to the network output and ``IN_KEY`` to
        the detached input copy that was fed to the network.
    """
    detached = input_var.clone().detach()
    result = self.net(detached)
    return {OUT_KEY: result, IN_KEY: detached}
def execute(self, x, edge_index, edge_weight: OptVar = None,
            batch: OptVar = None, lambda_max: OptVar = None):
    """Apply the Chebyshev graph convolution.

    Builds the terms T_0 x = x, T_1 x = P(x), T_k x = 2 * P(T_{k-1} x) - T_{k-2} x,
    where P is ``self.propagate`` with the edge normalization ``norm``
    returned by ``self.__norm__`` (presumably a rescaled Laplacian — confirm
    against ``__norm__``). Each term is multiplied by its own
    ``self.weight[k]`` and the products are summed. The number of terms is
    ``self.weight.size(0)``.

    Args:
        x: Node features.
        edge_index: Graph connectivity, passed to ``self.__norm__``.
        edge_weight: Optional edge weights, passed to ``self.__norm__``.
        batch: Optional batch vector, passed to ``self.__norm__``.
        lambda_max: Largest eigenvalue estimate. It is required unless
            ``self.normalization == 'sym'``, in which case it defaults to
            2.0. A non-``Var`` value is wrapped in ``Var([...])``.

    Returns:
        The summed Chebyshev terms, plus ``self.bias`` when present.

    Raises:
        ValueError: If ``self.normalization`` is not ``'sym'`` and
            ``lambda_max`` is ``None``.
    """
    if self.normalization != 'sym' and lambda_max is None:
        raise ValueError('You need to pass `lambda_max` to `execute()` in '
                         'case the normalization is non-symmetric.')

    if lambda_max is None:
        lambda_max = Var([2.0])
    elif not isinstance(lambda_max, Var):
        lambda_max = Var([lambda_max])

    edge_index, norm = self.__norm__(edge_index, x.size(self.node_dim),
                                     edge_weight, self.normalization,
                                     lambda_max, dtype=x.dtype,
                                     batch=batch)

    Tx_0 = x
    out = jt.matmul(Tx_0, self.weight[0])

    num_terms = self.weight.size(0)
    if num_terms > 1:
        Tx_1 = self.propagate(edge_index, x=x, norm=norm, size=None)
        out = out + jt.matmul(Tx_1, self.weight[1])

        # Chebyshev recurrence: T_k = 2 * P(T_{k-1}) - T_{k-2}.
        for k in range(2, num_terms):
            Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm, size=None)
            Tx_2 = 2. * Tx_2 - Tx_0
            out = out + jt.matmul(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

    if self.bias is not None:
        out += self.bias
    return out
def no_inf_mean(x: jt.Var):
    """Compute the mean of a vector, ignoring non-finite entries.

    Entries for which ``np.isfinite`` is False (``inf``, ``-inf`` and
    ``nan``) are dropped before averaging. If no finite entry remains, this
    falls back to the plain ``x.mean()`` (typically inf or nan in that case).

    The finiteness mask comes from a single ``x.numpy()`` conversion rather
    than one conversion per element. The kept entries are still elements of
    ``x``, so the result has the same type as before.
    """
    finite_mask = np.isfinite(x.numpy())
    kept = [elem for elem, is_finite in zip(x, finite_mask) if is_finite]
    if kept:
        return sum(kept) / len(kept)
    return x.mean()
def execute(self, x: Var, edge_index: Adj, edge_weight: OptVar = None) -> Var:
    """Propagate ``x`` ``self.K`` times, then apply ``self.lin``.

    When ``edge_index`` is a ``Var``, edges are first normalized by
    ``gcn_norm`` (no "improved" weighting; self-loops per
    ``self.add_self_loops``). If ``self.cached`` is set, the propagated
    features are stored in ``self._cached_x``. Once that is populated, it is
    used directly and no propagation is performed.
    """
    propagated = self._cached_x
    if propagated is None:
        if isinstance(edge_index, Var):
            edge_index, edge_weight = gcn_norm(
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        propagated = x
        for _ in range(self.K):
            propagated = self.propagate(edge_index, x=propagated,
                                        edge_weight=edge_weight, size=None)
        if self.cached:
            self._cached_x = propagated
    return self.lin(propagated)
def message(self, x_j: Var, edge_weight: Var) -> Var:
    """Scale each incoming neighbour feature row by its edge weight.

    ``edge_weight`` is reshaped to a column with ``view(-1, 1)`` so that it
    broadcasts across the feature dimension of ``x_j``. When ``edge_weight``
    is ``None`` (for example, a caller skipped normalization and supplied no
    weights), ``x_j`` is returned unchanged, which treats every edge as
    having weight 1.
    """
    if edge_weight is None:
        return x_j
    return edge_weight.view(-1, 1) * x_j