コード例 #1
0
ファイル: decode.py プロジェクト: louis100/mindspore
 def __init__(self):
     """Create the primitives used by this cell.

     NOTE(review): the owning class's ``construct`` is not visible here, so
     how each primitive is combined is not documented.
     """
     super().__init__()
     self.gather_nd = ops.GatherNd()
     self.transpose = ops.Transpose()
     # Rank-4 permutation: swaps axes 0 and 1, leaves axes 2 and 3 in place.
     self.perm_list = (1, 0, 2, 3)
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
コード例 #2
0
    def construct(self, x, seq_lengths):
        """Reverse the valid prefix of every sequence in ``x``.

        For batch entry ``b``, the step ``t`` is read from position
        ``seq_lengths[b] - 1 - t`` while that position is non-negative.
        Steps at or beyond ``seq_lengths[b]`` keep their own position ``t``.
        The resulting (batch, step) index pairs are applied with ``GatherNd``.

        Args:
            x: Input tensor. When ``self.batch_dim > self.seq_dim`` it is
                transposed with the fixed permutation ``(1, 0, 2)``, so that
                path only accepts rank-3 input.
            seq_lengths: Valid length of each batch entry. Its dtype is reused
                for the generated index grids.

        Returns:
            The gathered tensor. Assuming ``make_shape`` builds
            (batch, seq)-shaped index grids, the result has the same layout
            as ``x``.

        NOTE(review): the gather index puts the batch index in channel 0 and
        the step index in channel 1. This only lines up with ``x`` when
        {batch_dim, seq_dim} == {0, 1}; confirm callers respect that.
        """
        index_dtype = seq_lengths.dtype
        grid_shape = (x.shape[self.batch_dim], x.shape[self.seq_dim])

        # NOTE(review): assumes make_shape(shape, dtype, axis) returns a grid
        # whose value at each cell is that cell's index along ``axis``
        # (row_ids[b, t] == b, step_ids[b, t] == t). Confirm against make_shape.
        row_ids = self.make_shape(grid_shape, index_dtype, 0)
        step_ids = self.make_shape(grid_shape, index_dtype, 1)

        # Last valid position of each sequence, shaped as a column so it
        # broadcasts across the step axis.
        last_valid = ops.Sub()(seq_lengths, ops.OnesLike()(seq_lengths))
        last_valid = last_valid.view(-1, 1)

        mirrored = ops.Sub()(last_valid, step_ids)
        # A negative mirrored index means the step lies past the sequence's
        # valid length. Such steps are gathered from their own position.
        beyond_end = ops.Less()(mirrored, ops.ZerosLike()(mirrored))
        mirrored = ops.Select()(beyond_end, step_ids, mirrored)

        mirrored = ops.ExpandDims()(mirrored, 2)
        row_ids = ops.ExpandDims()(row_ids, 2)

        if self.batch_dim > self.seq_dim:
            # Seq-major input: gather from a batch-major copy of ``x``, using
            # index grids laid out (seq, batch, 1). GatherNd then emits the
            # result back in the caller's seq-major layout.
            axes_swap = (1, 0, 2)
            row_ids = ops.Transpose()(row_ids, axes_swap)
            mirrored = ops.Transpose()(mirrored, axes_swap)
            x = ops.Transpose()(x, axes_swap)

        gather_indices = ops.Concat(2)((row_ids, mirrored))
        return ops.GatherNd()(x, gather_indices)
コード例 #3
0
ファイル: decode.py プロジェクト: louis100/mindspore
 def __init__(self):
     """Create the sub-cell, primitives and index table used by this cell."""
     super().__init__()
     self.gather_flip_feat = GatherFlipFeature()
     # Split along axis 0 into two outputs.
     self.half = ops.Split(axis=0, output_num=2)
     # Reverse along axis 3.
     self.flip = ops.ReverseV2(axis=[3])
     # Index permutation: 0 stays put and each pair (2k-1, 2k) is swapped.
     # NOTE(review): presumably maps left/right keypoints onto each other
     # after a horizontal flip -- confirm with the dataset definition.
     swapped_order = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
     self.flip_index = Tensor(np.array(swapped_order, np.int32))
     self.gather_nd = ops.GatherNd()
コード例 #4
0
 def __init__(self):
     """Create the primitives and axis permutations used by this cell."""
     super().__init__()
     self.shape = ops.Shape()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
     # Split the last axis into two outputs.
     self.half = ops.Split(axis=-1, output_num=2)
     self.tile = ops.Tile()
     self.gather_nd = ops.GatherNd()
     self.transpose = ops.Transpose()
     # For a rank-4 tensor, ``perm_list`` moves axis 1 to the end and
     # ``order_list`` (its inverse permutation) moves it back to axis 1.
     self.perm_list = (0, 2, 3, 1)
     self.order_list = (0, 3, 1, 2)
     self.expand_dims = ops.ExpandDims()
コード例 #5
0
ファイル: decode.py プロジェクト: louis100/mindspore
 def __init__(self, enable_cpu_gatherd=True):
     """Create the primitives used to gather features by index.

     Args:
         enable_cpu_gatherd: When truthy, ``self.gather_nd`` holds a
             ``GatherD`` primitive and ``self.expand_dims`` is also created.
             Otherwise ``self.gather_nd`` holds a ``GatherNd`` primitive and
             no ``expand_dims`` attribute is defined.
     """
     super().__init__()
     self.tile = ops.Tile()
     self.shape = ops.Shape()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
     self.enable_cpu_gatherd = enable_cpu_gatherd
     if not self.enable_cpu_gatherd:
         self.gather_nd = ops.GatherNd()
     else:
         # The attribute keeps the name ``gather_nd`` even though it holds a
         # GatherD primitive on this path.
         self.gather_nd = ops.GatherD()
         self.expand_dims = ops.ExpandDims()
コード例 #6
0
ファイル: decode.py プロジェクト: louis100/mindspore
 def __init__(self):
     """Create the primitives used by this cell."""
     super().__init__()
     # Split along axis 0 into two outputs.
     self.half = ops.Split(axis=0, output_num=2)
     # Reverse along axis 3.
     self.flip = ops.ReverseV2(axis=[3])
     self.gather_nd = ops.GatherNd()