コード例 #1
0
    def __init__(self, dim, fixed_neigh=False):
        """Build the operators this cell keeps for gathering neighbours.

        Args:
            dim: Length of the last axis of the constant int32 tensor of
                ones, shaped ``(1, 1, dim)``, stored as ``broad_ones``.
                NOTE(review): the name suggests it is used for broadcasting
                in the forward pass — confirm against ``construct``.
            fixed_neigh: Stored as ``self.fixed_neigh``. When truthy, no
                ``GatherD`` primitive is created and ``self.gatherd`` is
                ``None``; otherwise ``self.gatherd`` is a ``P.GatherD()``.
        """
        super().__init__()
        self.fixed_neigh = fixed_neigh

        ones_shape = (1, 1, dim)
        self.broad_ones = P.Ones()(ones_shape, ms.int32)

        # A GatherD primitive is only needed when neighbours are not fixed.
        self.gatherd = None if fixed_neigh else P.GatherD()
コード例 #2
0
ファイル: neighbors.py プロジェクト: helloyesterday/AirNet
    def __init__(self, fixed_atoms=False, dim=3):
        """Create the primitives and sub-cells used by this neighbour cell.

        Args:
            fixed_atoms: Stored as ``self.fixed_atoms`` and forwarded as the
                second positional argument of ``GatherNeighbors``.
            dim: Forwarded as the first positional argument of
                ``GatherNeighbors``.
        """
        super().__init__()
        self.fixed_atoms = fixed_atoms

        # Stateless primitives.
        self.reducesum = P.ReduceSum()
        self.pow = P.Pow()
        self.gatherd = P.GatherD()

        # Norm cell reducing over the last axis.
        self.norm = nn.Norm(axis=-1)

        # Sub-cell that gathers neighbour entries; it shares the fixed-atoms flag.
        self.gather_neighbors = GatherNeighbors(dim, fixed_atoms)
コード例 #3
0
ファイル: loss.py プロジェクト: chncwang/mindspore
    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
        """Configure the focal loss and create its primitives.

        Args:
            weight: Optional weights; must be ``None`` or a ``Tensor``.
            gamma: Focusing parameter, checked by
                ``validator.check_value_type`` to be a ``float``.
            reduction: Passed unchanged to the base-class constructor.

        Raises:
            TypeError: If ``weight`` is neither ``None`` nor a ``Tensor``.
        """
        super().__init__(reduction=reduction)

        self.gamma = validator.check_value_type("gamma", gamma, [float])

        weight_ok = weight is None or isinstance(weight, Tensor)
        if not weight_ok:
            raise TypeError(f"The type of weight should be Tensor, but got {type(weight)}.")
        self.weight = weight

        # Primitives used by the loss computation.
        self.expand_dims = P.ExpandDims()
        self.gather_d = P.GatherD()
        self.squeeze = P.Squeeze(axis=1)
        self.tile = P.Tile()
        self.cast = P.Cast()
コード例 #4
0
 def __init__(self, dim=1):
     """Store a ``GatherD`` primitive and an integer ``dim``.

     Args:
         dim: Converted with ``int()`` and stored as ``self.dim``.
             NOTE(review): presumably the gather axis used in
             ``construct`` — confirm.
     """
     super().__init__()
     self.gatherd = P.GatherD()
     self.dim = int(dim)
コード例 #5
0
 def __init__(self, dim=0):
     """Store a ``GatherD`` primitive and the given ``dim``.

     Args:
         dim: Stored unchanged as ``self.dim``.
             NOTE(review): presumably the gather axis used in
             ``construct`` — confirm.
     """
     super().__init__()
     self.gather_d = P.GatherD()
     self.dim = dim
コード例 #6
0
 def __init__(self, dim=0):
     """Store a ``GatherD`` primitive as ``self.op`` and the given ``dim``.

     Args:
         dim: Stored unchanged as ``self.dim``.
             NOTE(review): presumably the gather axis used in
             ``construct`` — confirm.
     """
     super().__init__()
     self.op = P.GatherD()
     self.dim = dim