Example #1
    def construct(self, inputs, mask=None):
        r"""Compute layer output.

        Args:
            inputs (ms.Tensor): input data.
            mask (torch.Tensor, optional): mask to be applied; e.g. neighbors mask.

        Returns:
            ms.Tensor: layer output.

        """
        # mask input
        if mask is not None:
            inputs = inputs * F.expand_dims(mask, -1)
        # compute sum of input along axis
        y = self.reduce_sum(inputs, self.axis)
        # compute average of input along axis
        if self.average:
            # get the number of items along axis
            if mask is not None:
                N = self.reduce_sum(mask, self.axis)
                # clamp to at least 1 to avoid division by zero
                N = self.maximum(N, F.ones_like(N))
            else:
                N = inputs.shape[self.axis]

            y = y / N
        return y
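To make the masking and averaging concrete, here is a minimal NumPy rendition of the same pooling step; the shapes B=2, A=3, F=4 and the pooling axis are made up for illustration:

import numpy as np

inputs = np.random.rand(2, 3, 4).astype(np.float32)    # [B, A, F]
mask = np.array([[1, 1, 0],
                 [1, 1, 1]], dtype=np.float32)         # first molecule has a padded atom

masked = inputs * np.expand_dims(mask, -1)             # zero out padded atoms
y = masked.sum(axis=-2)                                # sum over the atom axis -> [B, F]
n = np.maximum(mask.sum(axis=-1, keepdims=True), 1.0)  # item count, clamped as above
print((y / n).shape)  # (2, 4)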
Example #2
    def construct(self, distances):
        """Compute smeared-gaussian distance values.

        Args:
            distances (ms.Tensor): interatomic distance values of
                (N_b x N_at x N_nbh) shape.

        Returns:
            ms.Tensor: layer output of (N_b x N_at x N_nbh x N_g) shape.

        """
        ex_dis = F.expand_dims(distances, -1)
        if not self.centered:
            # compute width of Gaussian functions (using an overlap of 1 STDDEV)
            coeff = -0.5 / F.square(self.width)
            # broadcast to compute the individual components; equivalent to:
            # ~ diff = distances[:, :, :, None] - offset[None, None, None, :]
            ex_offset = F.reshape(self.offset, (1, 1, 1, -1))
            diff = ex_dis - ex_offset
        else:
            # if Gaussian functions are centered, use offsets to compute widths
            coeff = -0.5 / F.square(self.offset)
            # if Gaussian functions are centered, no offset is subtracted
            diff = ex_dis
        # compute smear distance values
        exp = P.Exp()
        gauss = exp(coeff * F.square(diff))
        return gauss
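A standalone NumPy sketch of the non-centered branch; the offsets and widths (6 Gaussians on [0, 5] with one-spacing stddev) are assumptions for illustration:

import numpy as np

offset = np.linspace(0.0, 5.0, 6, dtype=np.float32)           # [N_g] assumed centers
width = np.full_like(offset, offset[1] - offset[0])           # assumed widths
coeff = -0.5 / np.square(width)

distances = np.random.rand(2, 3, 4).astype(np.float32) * 5.0  # [N_b, N_at, N_nbh]
diff = distances[..., None] - offset                          # [N_b, N_at, N_nbh, N_g]
gauss = np.exp(coeff * np.square(diff))
print(gauss.shape)  # (2, 3, 4, 6)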
Example #3
    def construct(self, xi, g_ii, xij, g_ij, t=0, c_ij=None):
        r"""Get query, key and query from atom types and positions

        Args:
            xi   (Mindspore.Tensor [B, A, V]):
            g_ii (Mindspore.Tensor [B, A, V]):
            xij  (Mindspore.Tensor [B, A, N, V]):
            g_ij (Mindspore.Tensor [B, A, N, V]):
            t    (Mindspore.Tensor [V]):
            c_ij (Mindspore.Tensor [B, A, N'], optional):

        Notation:
            B:  Batch size
            A:  Number of atoms
            N:  Number of neighbor atoms
            N': Number of neighbor atoms and itself (N' = N + 1)
            V:  Dimensions of atom embedding (V = v * h)

        Returns:
            query  (Mindspore.Tensor [B, A, 1, V]):
            key    (Mindspore.Tensor [B, A, N', V]):
            value  (Mindspore.Tensor [B, A, N', V]):

        """
        # [B, A, V] * [B, A, V] = [B, A, V]
        xgii = self.mul(xi, g_ii)
        # [B, A, N, V] * [B, A, N, V] = [B, A, N, V]
        xgij = self.mul(xij, g_ij)

        # [B, A, 1, V]
        xgii = F.expand_dims(xgii, -2)
        # [B, A, N', V]
        xgij = self.concat((xgii, xgij))
        if c_ij is not None:
            # [B, A, N', V] * [B, A, N', 1]
            xgij = xgij * F.expand_dims(c_ij, -1)

        xgii = self.layer_norm(xgii + t)
        xgij = self.layer_norm(xgij + t)

        # [B, A, 1, V]
        query = self.xg2q(xgii)
        # [B, A, N', V]
        key = self.xg2k(xgij)
        # [B, A, N', V]
        value = self.xg2v(xgij)

        return query, key, value
Example #4
    def construct(self, num_atoms):
        # broadcast atom numbers to [B*A*N]
        # a_i: number of atoms in molecule i
        # [[a_0]*A*N, [a_1]*A*N, ..., [a_{B-1}]*A*N]
        exnum = num_atoms * self.aones
        exnum = F.expand_dims(exnum, -1) * self.nones

        # [B,1,1]
        exones = self.ones((num_atoms.shape[0], 1, 1), ms.int32)
        # broadcast to [B*A*N]: [B,1,1] * [1,A,N]
        exnfc = exones * F.expand_dims(self.nfc, 0)
        exnnc = exones * F.expand_dims(self.nnc, 0)
        exmat = exones * F.expand_dims(self.mat_idx, 0)

        mask = exmat < exnum

        neighbors = F.select(mask, exnfc, exnnc)

        return neighbors, mask
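The tables nfc, nnc and mat_idx are assumed to be precomputed in __init__ (Example #9 below builds them); a NumPy sketch of the whole lookup for two toy molecules of different sizes:

import numpy as np

A, N = 4, 3
nnc = np.arange(A)[:, None] * np.ones(N, dtype=np.int64)   # self indices
crange = np.ones((A, 1), dtype=np.int64) * np.arange(N)
nfc = crange + (nnc <= crange).astype(np.int64)            # full-connection table
mat_idx = np.maximum(crange + 1, nnc)                      # mat_idx[i, j] = max(j + 1, i)

num_atoms = np.array([3, 4])                               # B=2 molecules
exnum = num_atoms[:, None, None]                           # broadcast against [A, N]
mask = mat_idx < exnum                                     # slot valid iff max(j+1, i) < a
neighbors = np.where(mask, nfc, nnc)                       # masked slots point to self
print(neighbors[0])
print(mask[0])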
Example #5
    def _get_rbf(self, dis):
        # expand interatomic distances (for example, Gaussian smearing)
        if self.distance_expansion is None:
            rbf = F.expand_dims(dis, -1)
        else:
            rbf = self.distance_expansion(dis)

        if self.rescale_rbf:
            rbf = rbf * 2.0 - 1.0

        if self.filter is not None:
            return self.filter(rbf)
        else:
            return rbf
Example #6
    def __init__(self, tot_atoms):
        super().__init__()
        # tot_atoms: A
        # tot_neigh: N =  A - 1
        tot_neigh = tot_atoms - 1
        arange = nn.Range(tot_atoms)
        nrange = nn.Range(tot_neigh)

        self.ones = P.Ones()
        self.aones = self.ones((tot_atoms), ms.int32)
        self.nones = self.ones((tot_neigh), ms.int32)
        self.eaones = F.expand_dims(self.aones, -1)

        # neighbors for no connection (A*N)
        # [[0,0,...,0],
        #  [1,1,...,1],
        #  ...........,
        #  [N,N,...,N]]
        self.nnc = F.expand_dims(arange(), -1) * self.nones

        # copy of the index range (A*N)
        # [[0,1,...,N-1],
        #  [0,1,...,N-1],
        #  ...........,
        #  [0,1,...,N-1]]
        exrange = self.ones((tot_atoms, 1), ms.int32) * nrange()

        # neighbors for full connection (A*N)
        # [[1,2,3,...,N],
        #  [0,2,3,...,N],
        #  [0,1,3,...,N],
        #  .............,
        #  [0,1,2,...,N-1]]
        self.nfc = exrange + F.cast(self.nnc <= exrange, ms.int32)

        self.ar0 = nn.Range(0, tot_neigh)()
        self.ar1 = nn.Range(1, tot_atoms)()
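A quick NumPy check that the construction matches the matrices in the comments (A=4 assumed):

import numpy as np

A = 4                                                     # tot_atoms
N = A - 1                                                 # tot_neigh
nnc = np.arange(A)[:, None] * np.ones(N, dtype=np.int32)  # row i is all i
exrange = np.ones((A, 1), dtype=np.int32) * np.arange(N)  # each row 0..N-1
nfc = exrange + (nnc <= exrange).astype(np.int32)         # row i skips index i
print(nfc)
# [[1 2 3]
#  [0 2 3]
#  [0 1 3]
#  [0 1 2]]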
Example #7
    def construct(self, atom_types):
        # [B,1,1]
        exones = self.ones((atom_types.shape[0], 1, 1), ms.int32)
        # broadcast to [B*A*N]: [B,1,1] * [1,A,N]
        exnfc = exones * F.expand_dims(self.nfc, 0)
        exnnc = exones * F.expand_dims(self.nnc, 0)

        # +1 for real atoms (atom_types > 0), -1 for padding
        tmask = F.select(atom_types > 0, F.ones_like(atom_types),
                         F.ones_like(atom_types) * -1)
        tmask = F.cast(tmask, ms.float32)
        extmask = F.expand_dims(tmask, -1) * self.nones

        mask0 = F.gather(tmask, self.ar0, -1)
        mask0 = F.expand_dims(mask0, -2) * self.eaones
        mask1 = F.gather(tmask, self.ar1, -1)
        mask1 = F.expand_dims(mask1, -2) * self.eaones

        mtmp = F.select(exnfc > exnnc, mask1, mask0)
        mask = F.select(extmask > 0, mtmp, F.ones_like(mtmp) * -1)
        mask = mask > 0

        idx = F.select(mask, exnfc, exnnc)

        return idx, mask
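The gather/select cascade reduces to: neighbor slot (i, j) is valid iff atom i and atom nfc[i, j] are both real. A NumPy sketch with one padded atom (toy values assumed):

import numpy as np

A, N = 4, 3
nnc = np.arange(A)[:, None] * np.ones(N, dtype=np.int64)
crange = np.ones((A, 1), dtype=np.int64) * np.arange(N)
nfc = crange + (nnc <= crange).astype(np.int64)

atom_types = np.array([[6, 1, 1, 0]])           # B=1; atom 3 is padding
real = atom_types > 0                           # [B, A]

mask = real[:, :, None] & real[:, nfc]          # central and neighbor atoms both real
idx = np.where(mask, nfc, nnc)                  # masked slots fall back to self index
print(idx[0])
print(mask[0])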
Example #8
    def construct(self, prob, halting_prob, n_updates):
        # zeros = self.zeros_like(halting_prob)
        # ones = self.ones_like(halting_prob)

        # Mask for inputs which have not halted in the last cycle
        running = F.cast(halting_prob < 1.0, ms.float32)
        # running = self.select(halting_prob < 1.0,ones,zeros)

        # Add the halting probability for this step to the halting
        # probabilities of those inputs which haven't halted yet
        add_prob = prob * running
        new_prob = halting_prob + add_prob
        mask_run = F.cast(new_prob <= self.threshold, ms.float32)
        mask_halt = F.cast(new_prob > self.threshold, ms.float32)
        # mask_run = self.select(new_prob <= self.threshold,ones,zeros)
        # mask_halt = self.select(new_prob > self.threshold,ones,zeros)

        # Mask of inputs which haven't halted, and didn't halt this step
        still_running = mask_run * running
        running_prob = halting_prob + prob * still_running

        # Mask of inputs which halted at this step
        new_halted = mask_halt * running

        # Compute remainders for the inputs which halted at this step
        remainders = new_halted * (1.0 - running_prob)

        # Add the remainders to those inputs which halted at this step
        # halting_prob = new_prob + remainders
        dp = add_prob + remainders

        # Increment n_updates for all inputs which are still running
        # n_updates = n_updates + running
        dn = running

        # Compute the weight to be applied to the new state and output
        # 0 when the input has already halted
        # prob when the input hasn't halted yet
        # the remainders when it halted this step
        update_weights = prob * still_running + new_halted * remainders
        w = F.expand_dims(update_weights, -1)

        return w, dp, dn
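One halting step traced in NumPy (the threshold and probabilities are made-up values): sample 0 keeps running, sample 1 halts this step and gets its remainder as weight, sample 2 had already halted:

import numpy as np

threshold = 0.99
halting_prob = np.array([0.2, 0.95, 1.0], dtype=np.float32)
prob = np.array([0.5, 0.5, 0.5], dtype=np.float32)

running = (halting_prob < 1.0).astype(np.float32)
add_prob = prob * running
new_prob = halting_prob + add_prob
still_running = (new_prob <= threshold).astype(np.float32) * running
new_halted = (new_prob > threshold).astype(np.float32) * running

running_prob = halting_prob + prob * still_running
remainders = new_halted * (1.0 - running_prob)
update_weights = prob * still_running + new_halted * remainders
print(update_weights)  # [0.5  0.05 0.  ]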
Example #9
    def __init__(self, tot_atoms):
        super().__init__()
        # tot_atoms: A
        # tot_neigh: N =  A - 1
        tot_neigh = tot_atoms - 1
        arange = nn.Range(tot_atoms)
        nrange = nn.Range(tot_neigh)

        self.ones = P.Ones()
        self.aones = self.ones((tot_atoms), ms.int32)
        self.nones = self.ones((tot_neigh), ms.int32)

        # neighbors for no connection (A*N)
        # [[0,0,...,0],
        #  [1,1,...,1],
        #  ...........,
        #  [N,N,...,N]]
        self.nnc = F.expand_dims(arange(), -1) * self.nones
        # copy of the index range (A*N)
        # [[0,1,...,N-1],
        #  [0,1,...,N-1],
        #  ...........,
        #  [0,1,...,N-1]]
        crange = self.ones((tot_atoms, 1), ms.int32) * nrange()
        # neighbors for full connection (A*N)
        # [[1,2,3,...,N],
        #  [0,2,3,...,N],
        #  [0,1,3,...,N],
        #  .............,
        #  [0,1,2,...,N-1]]
        self.nfc = crange + F.cast(self.nnc <= crange, ms.int32)

        crange1 = crange + 1
        # the matrix for index range (A*N)
        # [[1,2,3,...,N],
        #  [1,2,3,...,N],
        #  [2,2,3,...,N],
        #  [3,3,3,...,N],
        #  .............,
        #  [N,N,N,...,N]]
        self.mat_idx = F.select(crange1 > self.nnc, crange1, self.nnc)
Example #10
    def construct(self, distance):
        dis = distance / self.d_max

        if self.min_cutoff:
            dis = self.max(dis, self.d_min)

        exdis = F.expand_dims(dis, -1)
        rbfdis = exdis * self.ones

        log_dis = self.log(rbfdis)
        log_diff = log_dis - self.centers
        log_diff2 = F.square(log_diff)
        log_gauss = self.exp(self.rescale * log_diff2)

        if self.max_cutoff:
            ones = F.ones_like(exdis)
            zeros = F.zeros_like(exdis)
            cuts = F.select(exdis < 1.0, ones, zeros)
            log_gauss = log_gauss * cuts

        return log_gauss
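A NumPy sketch of the same log-Gaussian basis; d_max, d_min, the number of centers, and the rescale factor are all assumptions for illustration:

import numpy as np

d_max, d_min, K = 10.0, 0.1, 8
centers = np.linspace(np.log(d_min), 0.0, K)       # assumed log-space centers
sigma = centers[1] - centers[0]
rescale = -0.5 / sigma**2                          # assumed (negative) rescale

distance = np.random.rand(4, 6) * d_max            # toy [B, N] distances
dis = np.maximum(distance / d_max, d_min)          # normalize, apply min cutoff
log_diff2 = np.square(np.log(dis)[..., None] - centers)
log_gauss = np.exp(rescale * log_diff2)            # [B, N, K]
log_gauss = log_gauss * (dis[..., None] < 1.0)     # max cutoff
print(log_gauss.shape)  # (4, 6, 8)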
Example #11
    def construct(self,
                  positions,
                  neighbors,
                  neighbor_mask=None,
                  cell=None,
                  cell_offsets=None):
        r"""Compute distance of every atom to its neighbors.

        Args:
            positions (ms.Tensor[float]): atomic Cartesian coordinates with
                (N_b x N_at x 3) shape.
            neighbors (ms.Tensor[int]): indices of neighboring atoms to consider
                with (N_b x N_at x N_nbh) or (N_at x N_nbh) shape.
            neighbor_mask (ms.Tensor[bool], optional): boolean mask for neighbor
                positions. Required for the stable computation of forces in
                molecules with different sizes.
            cell (ms.Tensor[float], optional): periodic cell of (N_b x 3 x 3) shape.
            cell_offsets (ms.Tensor[float], optional): offsets of atoms in cell coordinates
                with (N_b x N_at x N_nbh x 3) shape.

        Returns:
            ms.Tensor[float]: layer output of (N_b x N_at x N_nbh) shape.

        """

        pos_xyz = self.gather_neighbors(positions, neighbors)

        # Subtract positions of central atoms to get distance vectors
        dist_vec = pos_xyz - F.expand_dims(positions, -2)

        # distances = self.norm(dist_vec)
        # Euclidean distance: square, sum over xyz, then take the square root
        distances = F.square(dist_vec)
        distances = self.reducesum(distances, -1)
        distances = self.pow(distances, 0.5)

        if neighbor_mask is not None:
            # pad masked (non-existent) neighbors with a large distance
            distances = F.select(neighbor_mask, distances,
                                 F.ones_like(distances) * 999)

        return distances
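The gather-subtract-norm pipeline in NumPy for a toy system (positions and the neighbor table are made up):

import numpy as np

positions = np.array([[[0., 0., 0.],
                       [1., 0., 0.],
                       [0., 2., 0.]]])             # [N_b, N_at, 3]
neighbors = np.array([[[1, 2], [0, 2], [0, 1]]])   # [N_b, N_at, N_nbh]

pos_xyz = positions[np.arange(1)[:, None, None], neighbors]  # gather neighbor coordinates
dist_vec = pos_xyz - positions[:, :, None, :]                # subtract central atoms
distances = np.sqrt(np.square(dist_vec).sum(-1))             # [N_b, N_at, N_nbh]
print(distances[0])
# [[1.         2.        ]
#  [1.         2.23606798]
#  [2.         2.23606798]]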
Example #12
def batch_dot(x1, x2, axes=None):
    """
    Computation of batch dot product between samples in two tensors containing batch dims.

    Inputs:
        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype must be
          the same as x1's.
        - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with the
          dimensions to contract for `x1` and `x2`. If a single value `N` is passed, the last N dims of
          `x1` and the last N dims of `x2` are used as the axes for each input respectively.

    Outputs:
        Tensor, batch dot product of x1 and x2. The shape of the output for input shapes
        (batch, d1, axes, d2) and (batch, d3, axes, d4) is (batch, d1, d2, d3, d4).

    .. math::
        output = x1[batch, :] * x2[batch, :]

    Raises:
        TypeError: If shapes of x1 and x2 are not the same.
        ValueError: If rank of x1 or x2 is less than 2.
        ValueError: If a batch dim is used in axes.
        ValueError: If dtype of x1 or x2 is not float32.
        ValueError: If len(axes) is less than 2.
        ValueError: If axes is not one of: None, int, (int, int).
        ValueError: If an axes value is too high for the dimensions of the input arrays.
        ValueError: If the batch sizes of x1 and x2 are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = C.batch_dot(input_x1, input_x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
    """

    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)

    _typecheck_input_batch_dot(x1_type, x2_type)
    _check_batch_size(x1_batch_size, x2_batch_size)
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)

    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(
        x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(
        x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # if the original dims are expanded, restore them from 3 to 2
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result
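For the docstring example, batch_dot with axes (-1, -2) is just a batched matmul; a NumPy cross-check:

import numpy as np

x1 = np.ones((2, 2, 3), dtype=np.float32)
x2 = np.ones((2, 3, 2), dtype=np.float32)
out = np.einsum('bij,bjk->bik', x1, x2)   # contract x1's last axis with x2's middle axis
print(out)
# [[[3. 3.]
#   [3. 3.]]
#
#  [[3. 3.]
#   [3. 3.]]]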
Example #13
def _expand(x, ndim):
    """Expand x to `ndim` dimensions by prepending size-1 axes."""
    while F.rank(x) < ndim:
        x = F.expand_dims(x, 0)
    return x
Example #14
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm
Example #15
def _compute_norm(grad):
    norm = nn.Norm()
    norm = norm(F.cast(grad, mstype.float32))
    ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return ret
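Both helpers look like the per-gradient pieces of a global-norm computation (as used for gradient clipping); a NumPy rendition of how the [1]-shaped results combine, with made-up gradient shapes:

import numpy as np

grads = [np.ones((3, 3), dtype=np.float32), np.ones(5, dtype=np.float32) * 2.0]

square_sums = [np.expand_dims(np.square(g).sum(), 0) for g in grads]  # _get_square_sum
global_norm = np.sqrt(np.concatenate(square_sums).sum())
norms = [np.expand_dims(np.linalg.norm(g), 0) for g in grads]         # _compute_norm
print(global_norm)  # sqrt(9 + 20) ~= 5.385
print(norms)        # per-gradient norms: ~[3.0] and ~[4.472]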
Example #16
    def construct(self, query, key, value, cutoff=None, mask=None):
        r"""Compute multi-head attention.

        Args:
            query  (Mindspore.Tensor [B, A, 1, V]): query vectors.
            key    (Mindspore.Tensor [B, A, N', V]): key vectors.
            value  (Mindspore.Tensor [B, A, N', V]): value vectors.
            cutoff (Mindspore.Tensor [B, A, 1, N'] or [B, A, 1, 1, N'], optional):
                distance cutoff coefficients applied to the attention probabilities.
            mask   (Mindspore.Tensor, optional): boolean mask for valid neighbors;
                required when cutoff is given.

        Returns:
            Mindspore.Tensor [B, A, V]: multi-head attention output.

        """
        if self.n_heads > 1:
            q_reshape = query.shape[:-1] + self.reshape_tail
            k_reshape = key.shape[:-1] + self.reshape_tail
            v_reshape = value.shape[:-1] + self.reshape_tail

            # [B, A, 1, h, v]
            Q = F.reshape(query, q_reshape)
            # [B, A, h, 1, v]
            Q = self.transpose(Q, self.trans_shape)

            # [B, A, N', h, v]
            K = F.reshape(key, k_reshape)
            # [B, A, h, N', v]
            K = self.transpose(K, self.trans_shape)

            # [B, A, N', h, v]
            V = F.reshape(value, v_reshape)
            # [B, A, h, N', v]
            V = self.transpose(V, self.trans_shape)

            # [B, A, h, 1, v] x [B, A, h, N', v]^T / \sqrt(v)
            # [B, A, h, 1, v] x [B, A, h, v, N'] = [B, A, h, 1, N']
            attention_scores = self.bmmt(Q, K)
            attention_scores = self.mul(attention_scores, self.scores_mul)

            if cutoff is None:
                attention_probs = self.softmax(attention_scores)
            else:
                # [B, A, 1, 1, N']
                exmask = F.expand_dims(F.expand_dims(mask, -2), -2)
                # [B, A, h, 1, N']
                mhmask = exmask * self.exones
                large_neg = F.ones_like(attention_scores) * -5e4
                attention_scores = F.select(mhmask > 0, attention_scores,
                                            large_neg)
                attention_probs = self.softmax(attention_scores)
                excut = F.expand_dims(F.expand_dims(cutoff, -2), -2)
                # [B, A, h, 1, N'] * [B, A, 1, 1, N']
                attention_probs = self.mul(attention_probs, excut)

            # [B, A, h, 1, N'] x [B, A, h, N', v] = [B, A, h, 1, v]
            context = self.bmm(attention_probs, V)
            # [B, A, 1, h, v]
            context = self.transpose(context, self.trans_shape)
            # [B, A, 1, V]
            context = F.reshape(context, query.shape)
        else:
            # [B, A, 1, V] x [B, A, N', V]^T / \sqrt(V)
            # [B, A, 1, V] x [B, A, V, N'] = [B, A, 1, N']
            attention_scores = self.bmmt(query, key) * self.scores_mul

            if cutoff is None:
                attention_probs = self.softmax(attention_scores)
            else:
                large_neg = F.ones_like(attention_scores) * -5e4
                attention_scores = F.select(mask, attention_scores, large_neg)
                attention_probs = self.softmax(attention_scores)
                # [B, A, 1, N'] * [B, A, 1, N']
                attention_probs = attention_probs * F.expand_dims(cutoff, -2)

            # [B, A, 1, N'] x [B, A, N', V] = [B, A, 1, V]
            context = self.bmm(attention_probs, value)

        # [B, A, V]
        context = self.squeeze(context)

        return self.output(context)
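A NumPy sketch of the single-head branch (toy shapes; the mask here is taken as [B, A, 1, N'] so it can be applied directly to the scores):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

B, A, Np, V = 1, 2, 3, 4
rng = np.random.default_rng(0)
query = rng.standard_normal((B, A, 1, V))
key = rng.standard_normal((B, A, Np, V))
value = rng.standard_normal((B, A, Np, V))
mask = np.ones((B, A, 1, Np), dtype=bool)
cutoff = np.ones((B, A, Np))

scores = query @ key.transpose(0, 1, 3, 2) / np.sqrt(V)  # [B, A, 1, N']
scores = np.where(mask, scores, -5e4)                    # mask invalid neighbors
probs = softmax(scores) * np.expand_dims(cutoff, -2)     # apply distance cutoff
context = (probs @ value).squeeze(-2)                    # [B, A, V]
print(context.shape)  # (1, 2, 4)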
Example #17
    def get_full_neighbors(self):
        return F.expand_dims(self.nfc, 0)
Example #18
def pairwise_displacement(R: Tensor):
    """Compute pairwise displacement vectors dR[i, j] = R[i] - R[j]."""
    dR = F.expand_dims(R, 1) - F.expand_dims(R, 0)
    # periodic (minimum-image) variant:
    # np.mod(dR + box_size * 0.5, box_size) - box_size * 0.5
    return dR
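A NumPy demonstration, including the periodic (minimum-image) variant sketched in the comment; the box size is assumed:

import numpy as np

box_size = 10.0
R = np.random.rand(5, 3) * box_size                 # 5 particles in a cubic box

dR = np.expand_dims(R, 1) - np.expand_dims(R, 0)    # dR[i, j] = R[i] - R[j]
dR_pbc = np.mod(dR + box_size * 0.5, box_size) - box_size * 0.5  # minimum image
print(dR.shape, np.abs(dR_pbc).max() <= box_size * 0.5)  # (5, 5, 3) True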