Example 1
 def __set_size__(self, size: List[Optional[int]], dim: int, src: Var):
     # Infer the size of dimension `dim` from `src` if it is still unknown,
     # otherwise check that `src` agrees with the size recorded so far.
     the_size = size[dim]
     if the_size is None:
         size[dim] = src.size(self.node_dim)
     elif the_size != src.size(self.node_dim):
         raise ValueError(
             f'Encountered Var with size {src.size(self.node_dim)} in '
             f'dimension {self.node_dim}, but expected size {the_size}.')
Example 2
import jittor as jt
import numpy as np


def scatter(x: jt.Var, dim: int, index: jt.Var, src: jt.Var, reduce='void'):
    ''' If x is a 3-D array, rewrite x as:

    x[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    x[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    x[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

Parameters::

    * x (jt.Var) – input array
    * dim (int) – the axis along which to index
    * index (jt.Var) – the indices of elements to scatter; can be either empty or of the same dimensionality as src. When empty, the operation returns x unchanged.
    * src (jt.Var) – the source element(s) to scatter.
    * reduce (str, optional) – reduction operation to apply, either 'add' or 'multiply'; the default 'void' performs a plain assignment.

Example::

    src = jt.arange(1, 11).reshape((2, 5))
    index = jt.array([[0, 1, 2, 0]])
    x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
    assert (x.data ==
            [[1, 0, 0, 4, 0],
             [0, 2, 0, 0, 0],
             [0, 0, 3, 0, 0]]).all()
    index = jt.array([[0, 1, 2], [0, 1, 4]])
    x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
    assert (x.data ==
            [[1, 2, 3, 0, 0],
             [6, 7, 0, 0, 8],
             [0, 0, 0, 0, 0]]).all()
    x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
            jt.array(1.23), reduce='multiply')
    assert np.allclose(x.data,
                       [[2.0000, 2.0000, 2.4600, 2.0000],
                        [2.0000, 2.0000, 2.0000, 2.4600]]), x
    x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
            jt.array(1.23), reduce='add')
    assert np.allclose(x.data,
                       [[2.0000, 2.0000, 3.2300, 2.0000],
                        [2.0000, 2.0000, 2.0000, 3.2300]])

    '''
    shape = index.shape
    # Trim src to the index shape (a scalar src is left as-is and broadcast).
    if src.shape != shape and src.numel() != 1:
        src = src[tuple(slice(None, s) for s in shape)]
    # Index every axis symbolically ('i0', 'i1', ...) except `dim`,
    # where the given index Var selects the target positions.
    indexes = [f'i{i}' for i in range(len(shape))]
    indexes[dim] = index
    return x.setitem(tuple(indexes), src, reduce)
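A minimal usage sketch of the functional form above (assuming jittor is installed); it reproduces the first in-place scatter_ example from the docstring, but via scatter, which returns a new Var instead of modifying x:

    src = jt.arange(1, 11).reshape((2, 5))
    index = jt.array([[0, 1, 2, 0]])
    # Default reduce='void' performs a plain assignment into the destination.
    out = scatter(jt.zeros((3, 5), dtype=src.dtype), 0, index, src)
    assert (out.data == [[1, 0, 0, 4, 0],
                         [0, 2, 0, 0, 0],
                         [0, 0, 3, 0, 0]]).all()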
Example 3
    def execute(self,
                x: Var,
                edge_index: Adj,
                edge_weight: OptVar = None) -> Var:
        """"""

        if self.normalize:
            if isinstance(edge_index, Var):
                cache = self._cached_edge_index
                if cache is None:
                    # Compute the GCN normalization coefficients (cached below if enabled).
                    edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                                       x.size(self.node_dim),
                                                       self.improved,
                                                       self.add_self_loops)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]
        # Apply the linear transform, then aggregate messages from neighbors.
        x = x @ self.weight
        out = self.propagate(edge_index,
                             x=x,
                             edge_weight=edge_weight,
                             size=None)
        if self.bias is not None:
            out += self.bias

        return out
Example 4
 def execute(self, input_var: jittor.Var):
     # Work on a detached copy so gradients do not flow back through the input.
     input_var = input_var.clone().detach()
     out = self.net(input_var)
     return {
         OUT_KEY: out,
         IN_KEY: input_var
     }
Example 5
    def execute(self,
                x,
                edge_index,
                edge_weight: OptVar = None,
                batch: OptVar = None,
                lambda_max: OptVar = None):
        """"""
        if self.normalization != 'sym' and lambda_max is None:
            raise ValueError('You need to pass `lambda_max` to `execute()` in '
                             'case the normalization is non-symmetric.')

        if lambda_max is None:
            lambda_max = Var([2.0])
        if not isinstance(lambda_max, Var):
            lambda_max = Var([lambda_max])
        assert lambda_max is not None

        edge_index, norm = self.__norm__(edge_index,
                                         x.size(self.node_dim),
                                         edge_weight,
                                         self.normalization,
                                         lambda_max,
                                         dtype=x.dtype,
                                         batch=batch)

        # Chebyshev recurrence on the rescaled Laplacian L_hat:
        # Tx_0 = x, Tx_1 = L_hat x, Tx_k = 2 * L_hat * Tx_{k-1} - Tx_{k-2}.
        Tx_0 = x
        out = jt.matmul(Tx_0, self.weight[0])
        if self.weight.size(0) > 1:
            Tx_1 = self.propagate(edge_index, x=x, norm=norm, size=None)
            out = out + jt.matmul(Tx_1, self.weight[1])

        for k in range(2, self.weight.size(0)):
            Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm, size=None)
            Tx_2 = 2. * Tx_2 - Tx_0
            out = out + jt.matmul(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out += self.bias
        return out
Example 6
import jittor as jt
import numpy as np


def no_inf_mean(x: jt.Var):
    """
    Computes the mean of a vector, throwing out all non-finite (inf/NaN) values.
    If there are no finite values, this simply returns the ordinary mean (i.e., inf).
    """

    no_inf = [a for a in x if np.isfinite(a.numpy())]

    if len(no_inf) > 0:
        return sum(no_inf) / len(no_inf)
    else:
        return x.mean()
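A minimal usage sketch (assuming jittor is installed and no_inf_mean is defined as above):

    x = jt.array([1.0, float('inf'), 3.0])
    y = no_inf_mean(x)  # averages only the finite entries -> 2.0
    z = no_inf_mean(jt.array([float('inf'), float('inf')]))  # falls back to x.mean() -> inf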
Example 7
 def execute(self,
             x: Var,
             edge_index: Adj,
             edge_weight: OptVar = None) -> Var:
     """"""
     cache = self._cached_x
     if cache is None:
         if isinstance(edge_index, Var):
             edge_index, edge_weight = gcn_norm(edge_index,
                                                edge_weight,
                                                x.size(self.node_dim),
                                                False,
                                                self.add_self_loops,
                                                dtype=x.dtype)
         # Propagate K times (simple graph convolution); cache the result if requested.
         for k in range(self.K):
             x = self.propagate(edge_index,
                                x=x,
                                edge_weight=edge_weight,
                                size=None)
             if self.cached:
                 self._cached_x = x
     else:
         x = cache
     return self.lin(x)
Example 8
 def message(self, x_j: Var, edge_weight: Var) -> Var:
     # Scale each neighbor feature x_j by its edge weight (broadcast over the feature dim).
     return edge_weight.view(-1, 1) * x_j