예제 #1
0
 def construct(self, x):
     """Rescale the tensors in ``x`` jointly by their global L2 norm.

     The global norm is ``sqrt(sum(get_square_sum(t) for t in x))``. It is
     clamped from below at ``self.clip_norm`` before being handed, together
     with ``self.clip_norm``, to ``apply_global_norm`` for every tensor.

     Args:
         x: collection of tensors (e.g. gradients) mapped over by
             ``self.hyper_map``.

     Returns:
         The per-tensor results of ``apply_global_norm``, in the structure
         produced by ``self.hyper_map``.
     """
     total_norm = F.sqrt(F.addn(self.hyper_map(get_square_sum, x)))
     # Substitute clip_norm whenever the measured norm is smaller. Assuming
     # apply_global_norm scales by clip_norm / norm (not visible here), this
     # means tensors are only ever scaled down, never up.
     above_limit = self.greater_equal(total_norm, self.clip_norm)
     total_norm = F.select(above_limit, total_norm, self.clip_norm)
     rescale = F.partial(apply_global_norm, self.clip_norm, total_norm)
     return self.hyper_map(rescale, x)
예제 #2
0
 def construct(self, grads):
     """Rescale ``grads`` jointly by their global norm, clamped below at ``self.clip_norm``.

     Args:
         grads: collection of gradient tensors mapped over by
             ``self.hyper_map``.

     Returns:
         The per-tensor results of ``apply_global_norm``, in the structure
         produced by ``self.hyper_map``.
     """
     norm = self.global_norm(grads)
     # Use clip_norm whenever the measured norm is smaller. Assuming
     # apply_global_norm multiplies by clip_norm / norm (not visible here),
     # gradients are only ever scaled down.
     # NOTE(review): the GreaterEqual primitive is rebuilt on every call;
     # consider creating it once in __init__.
     norm = F.select(P.GreaterEqual()(norm, self.clip_norm), norm, self.clip_norm)
     rescale = F.partial(apply_global_norm, self.clip_norm, norm)
     return self.hyper_map(rescale, grads)
예제 #3
0
파일: loss.py 프로젝트: dongkcs/mindspore
    def construct(self, x1, x2, y):
        """Compute the cosine embedding loss between ``x1`` and ``x2``.

        Cosine similarity is taken over axis 1. Per sample:
            y == 1  -> 1 - cos(x1, x2)
            y == -1 -> max(0, cos(x1, x2) - margin)
            any other value of y contributes 0.
        The unreduced per-sample values are passed to ``self.get_loss``.

        Args:
            x1: input tensor; must match ``x2`` in dtype and shape
                (checked by ``F.same_type_shape``).
            x2: input tensor.
            y: target tensor; its shape is validated by
                ``_check_reduced_shape_valid`` against ``x1``'s shape reduced
                over axis 1.

        Returns:
            Whatever ``self.get_loss`` returns for the per-sample losses.

        Note:
            If a sample of ``x1`` or ``x2`` is all zeros, the cosine is 0 / 0
            and that sample's loss is NaN.
        """
        F.same_type_shape(x1, x2)
        _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
        # if target > 0, 1-cosine(x1, x2)
        # else, max(0, cosine(x1, x2)-margin)
        prod_sum = self.reduce_sum(x1 * x2, (1,))
        square1 = self.reduce_sum(F.square(x1), (1,))
        square2 = self.reduce_sum(F.square(x2), (1,))
        # Take the two roots separately. Multiplying the squared norms first
        # can overflow (e.g. in float16 once each norm exceeds ~256) even
        # though each norm, and their product, is representable.
        denom = F.sqrt(square1) * F.sqrt(square2)
        cosine = prod_sum / denom

        pos_value = 1.0 - cosine
        neg_value = self.maximum(cosine - self.margin, 0.0)
        zeros = F.zeros_like(cosine)
        pos_part = F.select(y == 1, pos_value, zeros)
        neg_part = F.select(y == -1, neg_value, zeros)
        output_unreduced = pos_part + neg_part

        return self.get_loss(output_unreduced)
예제 #4
0
    def construct(self, atom_types):
        """Build neighbor indices and a neighbor-validity mask from atom types.

        An atom is treated as present when its type is > 0. A neighbor slot is
        marked valid only when the central atom is present AND the neighbor
        flag gathered through ``self.ar0`` / ``self.ar1`` is positive.

        Args:
            atom_types: integer atom-type tensor. ``atom_types.shape[0]`` is
                used as the batch size B, and flags are gathered along its last
                axis (presumably shape [B, A] -- TODO confirm).

        Returns:
            (idx, mask): ``idx`` holds the ``self.nfc`` (full-connection) index
            where ``mask`` is True and the ``self.nnc`` (no-connection) index
            elsewhere. ``mask`` is a boolean tensor of the broadcast shape
            (presumably [B, A, N]).
        """
        # [B,1,1]
        exones = self.ones((atom_types.shape[0], 1, 1), ms.int32)
        # broadcast to [B*A*N]: [B,1,1] * [1,A,N]
        exnfc = exones * F.expand_dims(self.nfc, 0)
        exnnc = exones * F.expand_dims(self.nnc, 0)

        # +1 for atoms whose type is > 0, -1 otherwise
        tmask = F.select(atom_types > 0, F.ones_like(atom_types),
                         F.ones_like(atom_types) * -1)
        tmask = F.cast(tmask, ms.float32)
        # central-atom flag repeated across the neighbor axis via self.nones
        extmask = F.expand_dims(tmask, -1) * self.nones

        # Two candidate neighbor-flag tables gathered with ar0 / ar1 and
        # broadcast with eaones.
        # NOTE(review): ar0 / ar1 presumably index the neighbor atom for the
        # unshifted / shifted-by-one columns of nfc -- confirm where they are built.
        mask0 = F.gather(tmask, self.ar0, -1)
        mask0 = F.expand_dims(mask0, -2) * self.eaones
        mask1 = F.gather(tmask, self.ar1, -1)
        mask1 = F.expand_dims(mask1, -2) * self.eaones

        # take mask1 where nfc > nnc, mask0 otherwise
        mtmp = F.select(exnfc > exnnc, mask1, mask0)
        # force every slot of an absent central atom (flag <= 0) to -1
        mask = F.select(extmask > 0, mtmp, F.ones_like(mtmp) * -1)
        mask = mask > 0

        # valid slots take the full-connection index, invalid ones the
        # no-connection index
        idx = F.select(mask, exnfc, exnnc)

        return idx, mask
예제 #5
0
    def construct(self, num_atoms):
        """Return neighbor indices and validity mask from per-molecule atom counts.

        A slot (i, j) of molecule b is valid when ``self.mat_idx[i, j]`` is
        smaller than that molecule's atom count. Valid slots take the
        ``self.nfc`` (full-connection) index; invalid ones take ``self.nnc``.

        Args:
            num_atoms: atom count of each molecule; its first dimension is
                used as the batch size B, and it is broadcast against
                ``self.aones`` (presumably shape [B, 1] -- TODO confirm).

        Returns:
            (neighbors, mask): index tensor and boolean mask, broadcast to
            [B, A, N].
        """
        # atom count of each molecule repeated over every (atom, neighbor)
        # slot: [B, A, N]
        count = F.expand_dims(num_atoms * self.aones, -1) * self.nones

        # int32 ones [B, 1, 1] used to tile the [A, N] tables over the batch
        batch_ones = self.ones((num_atoms.shape[0], 1, 1), ms.int32)
        full_conn = batch_ones * F.expand_dims(self.nfc, 0)
        no_conn = batch_ones * F.expand_dims(self.nnc, 0)
        index_mat = batch_ones * F.expand_dims(self.mat_idx, 0)

        valid = index_mat < count
        return F.select(valid, full_conn, no_conn), valid
예제 #6
0
    def __init__(self, tot_atoms):
        """Precompute the [A, N] neighbor-index tables for A = ``tot_atoms``.

        Each atom has N = A - 1 neighbor slots. Example for A = 4 (N = 3):

            nnc     = [[0,0,0], [1,1,1], [2,2,2], [3,3,3]]
            nfc     = [[1,2,3], [0,2,3], [0,1,3], [0,1,2]]
            mat_idx = [[1,2,3], [1,2,3], [2,2,3], [3,3,3]]

        Args:
            tot_atoms: number of atoms A.
        """
        super().__init__()
        n_neigh = tot_atoms - 1
        atom_range = nn.Range(tot_atoms)
        neigh_range = nn.Range(n_neigh)

        self.ones = P.Ones()
        self.aones = self.ones((tot_atoms), ms.int32)
        self.nones = self.ones((n_neigh), ms.int32)

        # "no connection" table: row i is filled with the value i
        self.nnc = F.expand_dims(atom_range(), -1) * self.nones
        # every row is 0 .. N-1
        col_idx = self.ones((tot_atoms, 1), ms.int32) * neigh_range()
        # "full connection" table: entry (i, j) = j if j < i else j + 1, so
        # row i lists every atom index except i
        self.nfc = col_idx + F.cast(self.nnc <= col_idx, ms.int32)

        # entry (i, j) = max(j + 1, i)
        shifted = col_idx + 1
        self.mat_idx = F.select(shifted > self.nnc, shifted, self.nnc)
예제 #7
0
    def construct(self, distance):
        """Expand distances in a log-Gaussian radial basis.

        The distance is divided by ``self.d_max`` and optionally clamped from
        below at ``self.d_min`` (via ``self.max``). Its log is then compared
        with every entry of ``self.centers`` and the result is computed as
        ``exp(rescale * (log(d) - center) ** 2)``. With ``self.max_cutoff``,
        all basis values for scaled distances >= 1 are zeroed.

        Args:
            distance: distance tensor; a trailing basis axis is appended
                (broadcast via ``self.ones``, presumably one entry per centre).

        Returns:
            Basis values with the shape of ``distance`` plus a trailing basis
            axis.
        """
        scaled = distance / self.d_max
        if self.min_cutoff:
            # presumably an elementwise maximum, keeping the log below away
            # from log(0) -- confirm self.max / self.d_min in __init__
            scaled = self.max(scaled, self.d_min)

        scaled = F.expand_dims(scaled, -1)
        offset = self.log(scaled * self.ones) - self.centers
        basis = self.exp(self.rescale * F.square(offset))

        if not self.max_cutoff:
            return basis
        # 1 inside the cutoff (scaled distance < 1), 0 outside
        inside = F.select(scaled < 1.0, self.onesslike(scaled), self.zeroslike(scaled))
        return basis * inside
예제 #8
0
    def construct(self,
                  positions,
                  neighbors,
                  neighbor_mask=None,
                  cell=None,
                  cell_offsets=None):
        r"""Compute distance of every atom to its neighbors.

        Args:
            positions (ms.Tensor[float]): atomic Cartesian coordinates with
                (N_b x N_at x 3) shape.
            neighbors (ms.Tensor[int]): indices of neighboring atoms to consider
                with (N_b x N_at x N_nbh) or (N_at x N_nbh) shape.
            neighbor_mask (ms.Tensor[bool], optional): boolean mask for neighbor
                positions. Required for the stable computation of forces in
                molecules with different sizes.
            cell (ms.tensor[float], optional): accepted for interface
                compatibility but currently ignored (no periodic boundaries).
            cell_offsets (ms.Tensor[float], optional): accepted for interface
                compatibility but currently ignored.

        Returns:
            ms.Tensor[float]: layer output of (N_b x N_at x N_nbh) shape.
            Entries where ``neighbor_mask`` is False are set to 999.

        """

        pos_xyz = self.gather_neighbors(positions, neighbors)

        # Subtract positions of central atoms to get distance vectors
        dist_vec = pos_xyz - F.expand_dims(positions, -2)

        # Squared Euclidean norm over the coordinate axis
        sq_dist = self.reducesum(F.square(dist_vec), -1)

        if neighbor_mask is not None:
            # Masked slots may have zero distance (e.g. a padded neighbor
            # pointing back at the central atom). The derivative of x ** 0.5
            # is infinite at 0, and select() below only zeroes the upstream
            # gradient, so backprop would produce 0 * inf = NaN forces. Feed a
            # harmless 1 into the root for masked slots instead.
            sq_dist = F.select(neighbor_mask, sq_dist, F.ones_like(sq_dist))

        distances = self.pow(sq_dist, 0.5)

        if neighbor_mask is not None:
            distances = F.select(neighbor_mask, distances,
                                 F.ones_like(distances) * 999)

        return distances
예제 #9
0
    def construct(self, query, key, value, cutoff=None, mask=None):
        r"""Compute (multi-head) attention of each atom's query over its neighbors.

        Args:
            query  (Mindspore.Tensor [B, A, 1, V]): query vectors.
            key    (Mindspore.Tensor [B, A, N', V]): key vectors.
            value  (Mindspore.Tensor [B, A, N', V]): value vectors.
            cutoff (Mindspore.Tensor, optional): weights multiplied into the
                attention probabilities after the softmax. It is expanded with
                ``expand_dims(-2)`` (twice in the multi-head path) before use,
                so it is presumably shaped [B, A, N'] -- TODO confirm.
            mask (Mindspore.Tensor[bool], optional): neighbor-validity mask,
                expanded the same way as ``cutoff`` (presumably [B, A, N']).
                Scores of masked-out neighbors are replaced by -5e4 before the
                softmax.

        ``mask`` and ``cutoff`` are applied independently; either may be None.

        Returns:
            Mindspore.Tensor [B, A, V]: attention output passed through
            ``self.output``.

        """
        if self.n_heads > 1:
            q_reshape = query.shape[:-1] + self.reshape_tail
            k_reshape = key.shape[:-1] + self.reshape_tail
            v_reshape = value.shape[:-1] + self.reshape_tail

            # [B, A, 1, h, v]
            Q = F.reshape(query, q_reshape)
            # [B, A, h, 1, v]
            Q = self.transpose(Q, self.trans_shape)

            # [B, A, N', h, v]
            K = F.reshape(key, k_reshape)
            # [B, A, h, N', v]
            K = self.transpose(K, self.trans_shape)

            # [B, A, N', h, v]
            V = F.reshape(value, v_reshape)
            # [B, A, h, N', v]
            V = self.transpose(V, self.trans_shape)

            # [B, A, h, 1, v] x [B, A, h, N', v]^T / \sqrt(v)
            # [B, A, h, 1, v] x [B, A, h, v, N'] = [B, A, h, 1, N']
            attention_scores = self.bmmt(Q, K)
            attention_scores = self.mul(attention_scores, self.scores_mul)

            if mask is not None:
                # [B, A, 1, 1, N']
                exmask = F.expand_dims(F.expand_dims(mask, -2), -2)
                # [B, A, h, 1, N']
                mhmask = exmask * self.exones
                # finite large negative (presumably chosen to stay within
                # float16 range) so masked neighbors get ~0 probability
                large_neg = F.ones_like(attention_scores) * -5e4
                attention_scores = F.select(mhmask > 0, attention_scores,
                                            large_neg)

            attention_probs = self.softmax(attention_scores)

            if cutoff is not None:
                # [B, A, h, 1, N'] * [B, A, 1, 1, N']
                excut = F.expand_dims(F.expand_dims(cutoff, -2), -2)
                attention_probs = self.mul(attention_probs, excut)

            # [B, A, h, 1, N'] x [B, A, h, N', v] = [B, A, h, 1, v]
            context = self.bmm(attention_probs, V)
            # [B, A, 1, h, v]
            context = self.transpose(context, self.trans_shape)
            # [B, A, 1, V]
            context = F.reshape(context, query.shape)
        else:
            # [B, A, 1, V] x [B, A, N', V]^T / \sqrt(V)
            # [B, A, 1, V] x [B, A, V, N'] = [B, A, 1, N']
            attention_scores = self.bmmt(query, key) * self.scores_mul

            if mask is not None:
                # Expand like cutoff below so the condition matches the
                # [B, A, 1, N'] score shape required by select().
                exmask = F.expand_dims(mask, -2)
                large_neg = F.ones_like(attention_scores) * -5e4
                attention_scores = F.select(exmask, attention_scores, large_neg)

            attention_probs = self.softmax(attention_scores)

            if cutoff is not None:
                # [B, A, 1, N'] * [B, A, 1, N']
                attention_probs = attention_probs * F.expand_dims(cutoff, -2)

            # [B, A, 1, N'] x [B, A, N', V] = [B, A, 1, V]
            context = self.bmm(attention_probs, value)

        # [B, A, V]
        # NOTE(review): if self.squeeze removes every unit axis, a batch with
        # B == 1 or A == 1 loses that axis too -- confirm it targets axis -2.
        context = self.squeeze(context)

        return self.output(context)