def __init__(self):
    """Create the operator instances and constants held by this cell.

    Stores ``GatherNd``, ``Transpose``, ``Shape`` and ``Reshape`` operators
    as attributes, together with the axis permutation ``(1, 0, 2, 3)``.
    """
    super(GatherFlipFeature, self).__init__()

    # Shape inspection / manipulation operators.
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()

    # Four-axis permutation that exchanges axes 0 and 1 and keeps axes 2 and 3
    # in place; where it is applied is outside this method.
    self.perm_list = (1, 0, 2, 3)
    self.transpose = ops.Transpose()

    # Index-based gather operator.
    self.gather_nd = ops.GatherNd()
def construct(self, x, seq_lengths):
    """Reverse the leading ``seq_lengths[i]`` steps of every sequence in ``x``.

    For each batch entry the positions ``0 .. seq_lengths - 1`` along
    ``self.seq_dim`` are mirrored; positions whose mirrored index would be
    negative keep their original index, so they are gathered unchanged.

    Args:
        x: Input tensor. ``self.batch_dim`` and ``self.seq_dim`` select its
            batch and sequence axes. NOTE(review): the transpose branch uses
            the permutation ``(1, 0, 2)``, so ``x`` is presumably rank 3 with
            batch/sequence on axes 0 and 1 -- confirm against callers.
        seq_lengths: Per-batch lengths; its dtype is reused for the index
            grids.

    Returns:
        The result of ``GatherNd`` applied to ``x`` with the computed
        (batch, position) index pairs.
    """
    n_batch = x.shape[self.batch_dim]
    n_steps = x.shape[self.seq_dim]
    index_dtype = seq_lengths.dtype
    grid_shape = (n_batch, n_steps)

    # NOTE(review): make_shape is defined outside this view; presumably it
    # builds an index grid of `grid_shape` counting along the given axis
    # (0 -> batch index, 1 -> sequence position) -- confirm.
    row_ids = self.make_shape(grid_shape, index_dtype, 0)
    col_ids = self.make_shape(grid_shape, index_dtype, 1)

    # Index of the last valid step per sequence (length - 1), reshaped to a
    # column so it is combined with every position of the grid.
    last_valid = ops.Sub()(seq_lengths, ops.OnesLike()(seq_lengths))
    last_valid = last_valid.view(-1, 1)

    # Mirrored position = last valid index - current position.
    mirrored = ops.Sub()(last_valid, col_ids)

    # A negative mirrored index falls back to the original position.
    is_negative = ops.Less()(mirrored, ops.ZerosLike()(mirrored))
    mirrored = ops.Select()(is_negative, col_ids, mirrored)

    # Add a trailing axis so the two index grids can be concatenated into pairs.
    mirrored = ops.ExpandDims()(mirrored, 2)
    row_ids = ops.ExpandDims()(row_ids, 2)

    if self.batch_dim > self.seq_dim:
        # Swap the first two axes of the index grids and of the input.
        row_ids = ops.Transpose()(row_ids, (1, 0, 2))
        mirrored = ops.Transpose()(mirrored, (1, 0, 2))
        x = ops.Transpose()(x, (1, 0, 2))

    gather_indices = ops.Concat(2)((row_ids, mirrored))
    return ops.GatherNd()(x, gather_indices)
def __init__(self):
    """Create the sub-cell, operators and index constant held by this cell.

    Stores a ``GatherFlipFeature`` sub-cell, a two-way ``Split`` along axis 0,
    a ``ReverseV2`` over axis 3, a ``GatherNd`` operator and an int32 index
    tensor of length 17.
    """
    super(FlipLR, self).__init__()

    # Index order that keeps entry 0 in place and swaps every following
    # adjacent pair: (1, 2), (3, 4), ..., (15, 16).
    # NOTE(review): presumably a left/right channel pairing -- confirm.
    swapped_pair_order = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    self.flip_index = Tensor(np.array(swapped_pair_order, np.int32))

    # Splits its input into two pieces along axis 0.
    self.half = ops.Split(axis=0, output_num=2)
    # Reverses its input along axis 3.
    self.flip = ops.ReverseV2(axis=[3])

    self.gather_nd = ops.GatherNd()
    self.gather_flip_feat = GatherFlipFeature()
def __init__(self):
    """Create the operator instances and axis permutations held by this cell.

    Stores ``Shape``, ``Reshape``, ``Tile``, ``ExpandDims``, ``Transpose``
    and ``GatherNd`` operators, a ``Concat`` along axis 1, a two-way ``Split``
    along the last axis, and two four-axis permutations.
    """
    super(GetSurroundFeature, self).__init__()

    # Shape inspection / manipulation operators.
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
    self.expand_dims = ops.ExpandDims()
    self.tile = ops.Tile()

    # Joining and splitting: concatenate along axis 1, halve along the last axis.
    self.concat = ops.Concat(axis=1)
    self.half = ops.Split(axis=-1, output_num=2)

    # The two permutations are inverses of each other: perm_list moves axis 1
    # to the end, order_list moves the last axis back to position 1.
    self.transpose = ops.Transpose()
    self.perm_list = (0, 2, 3, 1)
    self.order_list = (0, 3, 1, 2)

    # Index-based gather operator.
    self.gather_nd = ops.GatherNd()
def __init__(self, enable_cpu_gatherd=True):
    """Create the operator instances held by this cell.

    Args:
        enable_cpu_gatherd: When true, ``self.gather_nd`` is a ``GatherD``
            operator and an ``ExpandDims`` operator is stored as well;
            otherwise ``self.gather_nd`` is a ``GatherNd`` operator and no
            ``expand_dims`` attribute is set. Defaults to ``True``.
    """
    super(GatherFeatureByInd, self).__init__()

    # Operators needed regardless of the gather implementation.
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
    self.tile = ops.Tile()
    self.concat = ops.Concat(axis=1)

    self.enable_cpu_gatherd = enable_cpu_gatherd
    if not self.enable_cpu_gatherd:
        self.gather_nd = ops.GatherNd()
    else:
        # The attribute keeps the name `gather_nd` even though it holds GatherD.
        self.gather_nd = ops.GatherD()
        self.expand_dims = ops.ExpandDims()
def __init__(self):
    """Create the operator instances held by this cell.

    Stores a two-way ``Split`` along axis 0, a ``ReverseV2`` over axis 3 and
    a ``GatherNd`` operator.
    """
    super(FlipTensor, self).__init__()

    # Reverses its input along axis 3.
    self.flip = ops.ReverseV2(axis=[3])
    # Splits its input into two pieces along axis 0.
    self.half = ops.Split(axis=0, output_num=2)
    # Index-based gather operator.
    self.gather_nd = ops.GatherNd()