Example #1
    def construct(self, offset):
        """get target position"""
        offset_shape = self.shape(offset)  # (b, 2N, h, w)
        N, h, w = offset_shape[1] // 2, offset_shape[2], offset_shape[3]
        # get p_n
        range_pn = self.range()
        p_n_x, p_n_y = self.meshgrid((range_pn, range_pn))
        # (2N, 1)
        p_n = self.cat_a0((self.reshape(p_n_x, (N, 1)), self.reshape(p_n_y, (N, 1))))
        p_n = self.reshape(p_n, (1, 2 * N, 1, 1))

        # get p_0
        range_h = nn.Range(self.begin, h*self.stride + 1, self.stride)()
        range_w = nn.Range(self.begin, w*self.stride + 1, self.stride)()
        p_0_x, p_0_y = self.meshgrid((range_h, range_w))
        p_0_x = self.reshape(p_0_x, (1, 1, h, w))
        p_0_x = self.tile(p_0_x, (1, N, 1, 1))
        p_0_y = self.reshape(p_0_y, (1, 1, h, w))
        p_0_y = self.tile(p_0_y, (1, N, 1, 1))
        p_0 = self.cat_a1((p_0_x, p_0_y))

        # get p
        dtype = self.dtype(offset)
        p = self.cast(p_0, dtype) + self.cast(p_n, dtype) + offset
        return p
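With zero offsets, p reduces to the regular convolution sampling grid. Below is a minimal NumPy sketch of the same p = p_0 + p_n arithmetic, assuming begin = 1 and stride = 1 (a 3x3 kernel, so N = 9) and a toy 2x2 output map; ops.Meshgrid indexing details may differ:

import numpy as np

begin, stride, h, w = 1, 1, 2, 2                # assumed toy sizes
rng = np.arange(-begin, begin + 1)              # what self.range() returns, see Example #12
p_n_x, p_n_y = np.meshgrid(rng, rng, indexing='ij')
N = rng.size ** 2                               # 9 sampling points per output position
p_n = np.concatenate([p_n_x.reshape(N, 1), p_n_y.reshape(N, 1)]).reshape(1, 2 * N, 1, 1)

range_h = np.arange(begin, h * stride + 1, stride)
range_w = np.arange(begin, w * stride + 1, stride)
p_0_x, p_0_y = np.meshgrid(range_h, range_w, indexing='ij')
p_0_x = np.tile(p_0_x.reshape(1, 1, h, w), (1, N, 1, 1))
p_0_y = np.tile(p_0_y.reshape(1, 1, h, w), (1, N, 1, 1))
p_0 = np.concatenate([p_0_x, p_0_y], axis=1)

p = p_0 + p_n                                   # offset == 0 here; shape (1, 2N, h, w)
print(p.shape)                                  # (1, 18, 2, 2)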
Example #2
 def construct(self, feat, ind):
     """gather by index"""
     # feat: b, J, K, N
     # ind:  b, J, K
     b, J, K = self.shape(ind)
     feat = self.reshape(feat, (b, J, K, -1))
     _, _, _, N = self.shape(feat)
     if self.enable_cpu_gatherd:
         # (b, J, K, N)
         index = self.expand_dims(ind, -1)
         index = self.tile(index, (1, 1, 1, N))
         feat = self.gather_nd(feat, 2, index)  # presumably ops.GatherD here: (input, dim, index)
     else:
         ind = self.reshape(ind, (-1, 1))
         ind_b = nn.Range(0, b * J, 1)()
         ind_b = self.reshape(ind_b, (-1, 1))
         ind_b = self.tile(ind_b, (1, K))
         ind_b = self.reshape(ind_b, (-1, 1))
         index = self.concat((ind_b, ind))
         # (b*J, K, 2)
         index = self.reshape(index, (-1, K, 2))
         # (b*J, K, N)
         feat = self.reshape(feat, (-1, K, N))
         feat = self.gather_nd(feat, index)
         feat = self.reshape(feat, (b, J, K, -1))
     return feat
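The non-CPU branch emulates a batched gather with GatherNd, so it prepends a row id to every index, forming (row, ind) pairs over the flattened (b*J, K, N) view. A NumPy sketch of that construction under assumed toy shapes, checked against np.take_along_axis:

import numpy as np

b, J, K, N = 2, 3, 4, 5                      # assumed toy shapes
feat = np.arange(b * J * K * N, dtype=np.float32).reshape(b, J, K, N)
ind = np.random.randint(0, K, size=(b, J, K))

ind_flat = ind.reshape(-1, 1)                # (b*J*K, 1)
ind_b = np.arange(b * J).reshape(-1, 1)      # one row id per (b, J) slice
ind_b = np.tile(ind_b, (1, K)).reshape(-1, 1)
index = np.concatenate([ind_b, ind_flat], axis=1).reshape(-1, K, 2)

flat = feat.reshape(-1, K, N)                # (b*J, K, N)
out = flat[index[..., 0], index[..., 1]].reshape(b, J, K, N)
# same result as gathering directly along axis 2
assert np.array_equal(out, np.take_along_axis(feat, ind[..., None], axis=2))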
Example #3
    def construct(self, x, q_h, q_w):
        """gather feature by specified index"""
        b, c, _, w_p = self.shape(x)
        _, h, w, N = self.shape(q_h)
        hwn = h * w * N
        # flatten spatial dims: (b, h_p*w_p, c)
        x = self.transpose(x, self.perm_list)
        x = self.reshape(x, (b, -1, c))

        # linearize 2-D coords: q = q_h * w_p + q_w, shape (b, h, w, N)
        q = q_h * w_p + q_w
        # flatten to (b*h*w*N, 1)
        q = self.reshape(q, (-1, 1))
        ind_b = nn.Range(0, b, 1)()
        ind_b = self.reshape(ind_b, (-1, 1))
        ind_b = self.tile(ind_b, (1, hwn))
        ind_b = self.reshape(ind_b, (-1, 1))
        index = self.concat((ind_b, q))
        # (b, hwn, 2)
        index = self.reshape(index, (b, hwn, -1))
        # (b, hwn, c)
        x_offset = self.gather_nd(x, index)
        # (b, c, h, w, N)
        x_offset = self.reshape(x_offset, (b, h * w, N, c))
        x_offset = self.transpose(x_offset, self.order_list)
        x_offset = self.reshape(x_offset, (b, c, h, w, N))

        return x_offset
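The key step is q = q_h * w_p + q_w, the usual row-major linearization of 2-D coordinates into the flattened spatial axis. A quick check of that identity, with assumed sizes:

import numpy as np

h_p, w_p = 4, 5                            # assumed padded feature size
x = np.arange(h_p * w_p).reshape(h_p, w_p)
q_h, q_w = 2, 3
# the linear index into the flattened map recovers the same element
assert x.reshape(-1)[q_h * w_p + q_w] == x[q_h, q_w]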
Example #4
def test_float():
    op = nn.Range(10., 100., 20.)
    op_wrapper = OpNetWrapper(op)

    outputs = op_wrapper()
    print(outputs)
    assert outputs.shape == (5, )
    assert np.allclose(outputs.asnumpy(), [10., 30., 50., 70., 90.])
Example #5
def test_int():
    op = nn.Range(0, 100, 10)
    op_wrapper = OpNetWrapper(op)

    outputs = op_wrapper()
    print(outputs)
    assert outputs.shape == (10, )
    assert np.allclose(outputs.asnumpy(), range(0, 100, 10))
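Both tests rely on nn.Range following half-open range semantics: the limit is exclusive, so (10., 100., 20.) yields five values and (0, 100, 10) yields ten. The expected outputs match np.arange:

import numpy as np

assert np.allclose(np.arange(10., 100., 20.), [10., 30., 50., 70., 90.])
assert np.array_equal(np.arange(0, 100, 10), list(range(0, 100, 10)))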
Example #6
 def _mean(self, probs=None):
     r"""
     .. math::
         E[X] = \sum_{i=0}^{\text{num\_classes} - 1} i \cdot p_i
     """
     probs = self._check_param_type(probs)
     num_classes = self.shape(probs)[-1]
     index = nn.Range(0., num_classes, 1.)()
     return self.reduce_sum(index * probs, -1)
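A worked check of the formula: for probs = [0.1, 0.2, 0.7] the index vector is [0., 1., 2.], so E[X] = 0*0.1 + 1*0.2 + 2*0.7 = 1.6. The same reduction in NumPy:

import numpy as np

probs = np.array([0.1, 0.2, 0.7])
index = np.arange(0., probs.shape[-1], 1.)   # what nn.Range(0., num_classes, 1.)() returns
assert np.isclose((index * probs).sum(-1), 1.6)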
Example #7
    def _log_prob(self, value, probs=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        probs = self._check_param_type(probs)
        logits = self.log(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = logits * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        # broadcasting logits and value
        # logit_pmf shape (num of labels, C)
        logits = self.broadcast(logits, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        # clip value to be in range from 0 to num_classes -1 and cast into int32
        value = self.reshape(value, (-1, 1))
        out_of_bound = self.squeeze_last_axis(self.logicor(\
                        self.less(value, 0.0), self.less(num_classes-1, value)))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)
        # create index from 0 ... NumOfLabels
        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # index into logit_pmf, fill in out_of_bound places with -inf
        # reshape into label shape N
        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)),
                                 index)
        neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf),
                            -np.inf)
        logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
        ans = self.reshape(logits_pmf, label_shape)
        if drop_dim:
            return self.squeeze(ans)
        return ans
Example #8
 def _var(self, probs=None):
     r"""
     .. math::
         VAR(X) = E[X^{2}] - (E[X])^{2}
     """
     probs = self._check_param_type(probs)
     num_classes = self.shape(probs)[-1]
     index = nn.Range(0., num_classes, 1.)()
     return self.reduce_sum(self.square(index) * probs, -1) -\
            self.square(self.reduce_sum(index * probs, -1))
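Continuing the numbers from Example #6: E[X^2] = 0*0.1 + 1*0.2 + 4*0.7 = 3.0, so VAR(X) = 3.0 - 1.6^2 = 0.44. In NumPy:

import numpy as np

probs = np.array([0.1, 0.2, 0.7])
index = np.arange(0., probs.shape[-1], 1.)
var = (index ** 2 * probs).sum(-1) - (index * probs).sum(-1) ** 2
assert np.isclose(var, 0.44)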
Example #9
 def enumerate_support(self, expand=True):
     r"""
    Enumerate categories.
    """
     num_events = self._num_events
     values = nn.Range(0., num_events, 1)()
     values = self.reshape(values, (num_events, 1))
     if expand:
         values = P.BroadcastTo((num_events, self._batch_shape))(values)
     values = self.cast(values, mstype.int32)
     return values
Example #10
    def _cdf(self, value, probs=None):
        r"""
        Cumulative distribution function (cdf) of Categorical distributions.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        value = self.floor(value)
        probs = self._check_param_type(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = probs * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        probs = self.broadcast(probs, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        value = self.reshape(value, (-1, 1))

        # drop one dimension to match cdf
        # clip value to be in range from 0 to num_classes -1 and cast into int32
        less_than_zero = self.squeeze_last_axis(self.less(value, 0.0))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)

        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # reshape probs and fill less_than_zero places with 0
        probs = self.reshape(probs, (-1, num_classes))
        cdf = self.gather(self.cumsum(probs, 1), index)
        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
        cdf = self.select(less_than_zero, zeros, cdf)
        cdf = self.reshape(cdf, label_shape)

        if drop_dim:
            return self.squeeze(cdf)
        return cdf
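Since probs is cumsum-ed along the class axis before the gather, the CDF at a floored value k is p_0 + ... + p_k, with negative values mapped to 0 and values at or above num_classes clipped to the last class (CDF 1). A scalar NumPy sketch of that logic, assuming probs = [0.2, 0.3, 0.5]:

import numpy as np

probs = np.array([0.2, 0.3, 0.5])
cum = np.cumsum(probs)                    # [0.2, 0.5, 1.0]

def cdf(value, num_classes=probs.size):
    if value < 0:                         # the less_than_zero branch -> 0
        return 0.0
    k = int(np.clip(np.floor(value), 0, num_classes - 1))
    return cum[k]

assert np.isclose(cdf(1.7), 0.5) and cdf(-1.0) == 0.0 and np.isclose(cdf(5.0), 1.0)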
Example #11
    def __init__(self, tot_atoms):
        super().__init__()
        # tot_atoms: A
        # tot_neigh: N =  A - 1
        tot_neigh = tot_atoms - 1
        arange = nn.Range(tot_atoms)
        nrange = nn.Range(tot_neigh)

        self.ones = P.Ones()
        self.aones = self.ones((tot_atoms), ms.int32)
        self.nones = self.ones((tot_neigh), ms.int32)

        # neighbors for no connection (A*N)
        # [[0,0,...,0],
        #  [1,1,...,1],
        #  ...........,
        #  [N,N,...,N]]
        self.nnc = F.expand_dims(arange(), -1) * self.nones
        # copy of the index range (A*N)
        # [[0,1,...,N-1],
        #  [0,1,...,N-1],
        #  ...........,
        #  [0,1,...,N-1]]
        crange = self.ones((tot_atoms, 1), ms.int32) * nrange()
        # neighbors for full connection (A*N)
        # [[1,2,3,...,N],
        #  [0,2,3,...,N],
        #  [0,1,3,...,N],
        #  .............,
        #  [0,1,2,...,N-1]]
        self.nfc = crange + F.cast(self.nnc <= crange, ms.int32)

        crange1 = crange + 1
        # the matrix for index range (A*N)
        # [[1,2,3,...,N],
        #  [1,2,3,...,N],
        #  [2,2,3,...,N],
        #  [3,3,3,...,N],
        #  .............,
        #  [N,N,N,...,N]]
        self.mat_idx = F.select(crange1 > self.nnc, crange1, self.nnc)
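A NumPy sketch reproducing the three index matrices above for an assumed tot_atoms = 4 (so N = 3):

import numpy as np

A = 4
N = A - 1
nnc = np.arange(A).reshape(-1, 1) * np.ones(N, dtype=np.int64)   # rows of constant i
crange = np.ones((A, 1), dtype=np.int64) * np.arange(N)          # rows 0..N-1
nfc = crange + (nnc <= crange).astype(np.int64)                  # row i skips index i
crange1 = crange + 1
mat_idx = np.where(crange1 > nnc, crange1, nnc)

print(nfc)      # [[1 2 3] [0 2 3] [0 1 3] [0 1 2]]
print(mat_idx)  # [[1 2 3] [1 2 3] [2 2 3] [3 3 3]]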
Example #12
 def __init__(self, begin, stride):
     super(GetOffsetPosition, self).__init__()
     self.begin = begin
     self.stride = stride
     self.meshgrid = ops.Meshgrid()
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
     self.cat_a0 = ops.Concat(axis=0)
     self.cat_a1 = ops.Concat(axis=1)
     self.tile = ops.Tile()
     self.dtype = ops.DType()
     self.range = nn.Range(-self.begin, self.begin + 1)
     self.cast = ops.Cast()
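This is the initializer behind Example #1: self.range yields the per-axis kernel offsets that feed the meshgrid there. The NumPy equivalent, assuming begin = 1 (a 3x3 kernel):

import numpy as np

begin = 1                                   # assumed 3x3 kernel
offsets = np.arange(-begin, begin + 1)      # what nn.Range(-begin, begin + 1)() returns
print(offsets)                              # [-1  0  1]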
Example #13
 def __init__(self,
              weight,
              start,
              limit,
              delta,
              strategy1=None,
              strategy2=None,
              strategy3=None):
     super().__init__()
     self.mul = P.Mul().shard(strategy1)
     self.range = nn.Range(start, limit, delta)
     self.range.range_x.shard(strategy2)
     self.mul2 = P.Mul().shard(strategy3)
     self.weight = Parameter(weight, "w")
Example #14
    def __init__(self, tot_atoms):
        super().__init__()
        # tot_atoms: A
        # tot_neigh: N =  A - 1
        tot_neigh = tot_atoms - 1
        arange = nn.Range(tot_atoms)
        nrange = nn.Range(tot_neigh)

        self.ones = P.Ones()
        self.aones = self.ones((tot_atoms), ms.int32)
        self.nones = self.ones((tot_neigh), ms.int32)
        self.eaones = F.expand_dims(self.aones, -1)

        # neighbors for no connection (A*N)
        # [[0,0,...,0],
        #  [1,1,...,1],
        #  ...........,
        #  [N,N,...,N]]
        self.nnc = F.expand_dims(arange(), -1) * self.nones

        # copy of the index range (A*N)
        # [[0,1,...,N-1],
        #  [0,1,...,N-1],
        #  ...........,
        #  [0,1,...,N-1]]
        exrange = self.ones((tot_atoms, 1), ms.int32) * nrange()

        # neighbors for full connection (A*N)
        # [[1,2,3,...,N],
        #  [0,2,3,...,N],
        #  [0,1,3,...,N],
        #  .............,
        #  [0,1,2,...,N-1]]
        self.nfc = exrange + F.cast(self.nnc <= exrange, ms.int32)

        self.ar0 = nn.Range(0, tot_neigh)()
        self.ar1 = nn.Range(1, tot_atoms)()
Example #15
    def _log_prob(self, value):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
        """
        value = self._check_value(value, 'value')
        value = self.expandim(self.cast(value, mstype.float32), -1)
        broad_shape = self.shape(value + self._logits)
        broad = P.BroadcastTo(broad_shape)
        logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
        value = self.reshape(broad(value)[..., :1], (-1, 1))
        index = nn.Range(0., self.shape(value)[0], 1)()
        index = self.reshape(index, (-1, 1))
        value = self.concat((index, value))
        value = self.cast(value, mstype.int32)
        return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
Example #16
 def __init__(self, batch_size, temperature=1, world_size=1):
     super(NT_Xent_Loss, self).__init__()
     # Parameters.
     self.LARGE_NUM = 1e9
     self.batch_size = batch_size
     self.temperature = temperature
     self.world_size = world_size
     self.N = 2 * self.batch_size * self.world_size
     # Tail_Loss.
     self.criterion = CrossEntropyLoss(reduction="mean")
     self.norm = P.L2Normalize(axis=1)
     self.one_hot = P.OneHot()
     self.range = nn.Range(0, self.batch_size)
     self.one = Tensor(1.0, mstype.float32)
     self.zero = Tensor(0.0, mstype.float32)
     self.transpose = P.Transpose()
     self.matmul = nn.MatMul()
     # Operations.
     self.ones = P.Ones()
     self.zeros = P.Zeros()
     self.cat1 = P.Concat(axis=1)
Example #17
 def construct(self, feat, ind):
     """gather by specified index"""
     if self.enable_cpu_gather:
         _, _, c = self.shape(feat)
         # (b, N, c)
         index = self.expand_dims(ind, -1)
         index = self.tile(index, (1, 1, c))
         feat = self.gather_nd(feat, 1, index)  # presumably ops.GatherD here: (input, dim, index)
     else:
         # (b, N)->(b*N, 1)
         b, N = self.shape(ind)
         ind = self.reshape(ind, (-1, 1))
         ind_b = nn.Range(0, b, 1)()
         ind_b = self.reshape(ind_b, (-1, 1))
         ind_b = self.tile(ind_b, (1, N))
         ind_b = self.reshape(ind_b, (-1, 1))
         index = self.concat((ind_b, ind))
         # (b, N, 2)
         index = self.reshape(index, (b, N, -1))
         # (b, N, c)
         feat = self.gather_nd(feat, index)
     return feat
Example #18
    def _log_prob(self, value):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): value to be evaluated. The dtype could be mstype.float32, bool, mstype.int32.
        """
        if value is not None:
            check_tensor_type("value", value,
                              [mstype.float32, bool, mstype.int32])
            value = self.expandim(self.cast(value, mstype.float32), -1)
            broad_shape = self._broad_cast_shape(value, self._logits)
            broad = P.BroadcastTo(broad_shape)
            logits_pmf = self.reshape(broad(self._logits),
                                      (-1, broad_shape[-1]))
            value = self.reshape(broad(value)[..., :1], (-1, 1))
            index = nn.Range(0., self.shape(value)[0], 1)()
            index = self.reshape(index, (-1, 1))
            value = self.concat((index, value))
            value = self.cast(value, mstype.int32)
            return self.reshape(self.gather(logits_pmf, value),
                                broad_shape[:-1])
        return None
Example #19
    def construct(
        self,
        input_ids: ms.Tensor,
        token_type_ids: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
    ) -> ms.Tensor:
        input_shape = input_ids.shape
        seq_length = input_shape[1]

        if token_type_ids is None:
            token_type_ids = ops.zeros_like(input_ids)
        if position_ids is None:
            position_ids = ops.ExpandDims()(nn.Range(0, seq_length)(), 0)

        input_embeddings = self.token_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = input_embeddings + position_embeddings + \
            token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
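When position_ids is omitted, the layer synthesizes [0, 1, ..., seq_length-1] with a leading batch axis. A standalone sketch of that default path, assuming a MindSpore version that still ships nn.Range (newer releases offer mindspore.ops.arange instead):

import mindspore.nn as nn
import mindspore.ops as ops

seq_length = 8                                       # assumed toy sequence length
position_ids = ops.ExpandDims()(nn.Range(0, seq_length)(), 0)
print(position_ids.shape)                            # (1, 8)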
Example #20
 def __init__(self, dim):
     super().__init__()
     self.range = nn.Range(dim)
     ones = P.Ones()
     self.ones = ones((dim), ms.int32)
Example #21
    def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals):
        super(BboxAssignSampleForRcnn, self).__init__()
        cfg = config
        self.batch_size = batch_size
        self.neg_iou_thr = cfg.neg_iou_thr_stage2
        self.pos_iou_thr = cfg.pos_iou_thr_stage2
        self.min_pos_iou = cfg.min_pos_iou_stage2
        self.num_gts = cfg.num_gts
        self.num_bboxes = num_bboxes
        self.num_expected_pos = cfg.num_expected_pos_stage2
        self.num_expected_neg = cfg.num_expected_neg_stage2
        self.num_expected_total = cfg.num_expected_total_stage2

        self.add_gt_as_proposals = add_gt_as_proposals
        self.label_inds = Tensor(
            np.arange(1, self.num_gts + 1).astype(np.int32))
        self.add_gt_as_proposals_valid = Tensor(
            np.array(self.add_gt_as_proposals * np.ones(self.num_gts),
                     dtype=np.int32))

        self.concat = P.Concat(axis=0)
        self.max_gt = P.ArgMaxWithValue(axis=0)
        self.max_anchor = P.ArgMaxWithValue(axis=1)
        self.sum_inds = P.ReduceSum()
        self.iou = P.IOU()
        self.greaterequal = P.GreaterEqual()
        self.greater = P.Greater()
        self.select = P.Select()
        self.gatherND = P.GatherNd()
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.logicaland = P.LogicalAnd()
        self.less = P.Less()
        self.random_choice_with_mask_pos = P.RandomChoiceWithMask(
            self.num_expected_pos)
        self.random_choice_with_mask_neg = P.RandomChoiceWithMask(
            self.num_expected_neg)
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0,
                                                              0.0),
                                                       stds=(0.1, 0.1, 0.2,
                                                             0.2))
        self.concat_axis1 = P.Concat(axis=1)
        self.logicalnot = P.LogicalNot()
        self.tile = P.Tile()

        # Check
        self.check_gt_one = Tensor(
            np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
        self.check_anchor_two = Tensor(
            np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))

        # Init tensor
        self.assigned_gt_inds = Tensor(
            np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_zeros = Tensor(
            np.array(np.zeros(num_bboxes), dtype=np.int32))
        self.assigned_gt_ones = Tensor(
            np.array(np.ones(num_bboxes), dtype=np.int32))
        self.assigned_gt_ignores = Tensor(
            np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
        self.assigned_pos_ones = Tensor(
            np.array(np.ones(self.num_expected_pos), dtype=np.int32))

        self.gt_ignores = Tensor(
            np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
        self.range_pos_size = Tensor(
            np.arange(self.num_expected_pos).astype(np.float16))
        self.check_neg_mask = Tensor(
            np.array(np.ones(self.num_expected_neg - self.num_expected_pos),
                     dtype=np.bool_))
        self.bboxs_neg_mask = Tensor(
            np.zeros((self.num_expected_neg, 4), dtype=np.float16))
        self.labels_neg_mask = Tensor(
            np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))

        self.reshape_shape_pos = (self.num_expected_pos, 1)
        self.reshape_shape_neg = (self.num_expected_neg, 1)

        self.scalar_zero = Tensor(0.0, dtype=mstype.float16)
        self.scalar_neg_iou_thr = Tensor(self.neg_iou_thr,
                                         dtype=mstype.float16)
        self.scalar_pos_iou_thr = Tensor(self.pos_iou_thr,
                                         dtype=mstype.float16)
        self.scalar_min_pos_iou = Tensor(self.min_pos_iou,
                                         dtype=mstype.float16)

        self.expand_dims = P.ExpandDims()
        self.split = P.Split(axis=1, output_num=4)
        self.concat_last_axis = P.Concat(axis=-1)
        self.round = P.Round()
        self.image_h_w = Tensor(
            [cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width],
            dtype=mstype.float16)
        self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2)
        self.crop_and_resize = P.CropAndResize(method="bilinear_v2")
        self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1])
        self.squeeze_mask_last = P.Squeeze(axis=-1)
Example #22
    def _log_prob(self, value, probs=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')

        probs = self._check_param_type(probs)
        logits = self.log(probs)

        # find the right integer to compute index
        # here we simulate casting to int but still keeping float dtype
        value = self.cast(value, self.dtypeop(probs))

        zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
        neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
        value = self.select(self.is_nan(value), neg_one, value)
        between_zero_neone = self.logicand(self.less(value, 0.),
                                           self.greater(value, -1.))
        value = self.select(between_zero_neone, zeros, P.Floor()(value))

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = logits * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        # broadcasting logits and value
        # logit_pmf shape (num of labels, C)
        logits = self.broadcast(logits, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        # clip value to be in range from 0 to num_classes -1 and cast into int32
        value = self.reshape(value, (-1, 1))
        out_of_bound = self.squeeze_last_axis(self.logicor(\
                        self.less(value, 0.0), self.less(num_classes-1, value)))
        # deal with the case where there is only one class.
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)
        # create index from 0 ... NumOfLabels
        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # index into logit_pmf, fill in out_of_bound places with -inf
        # reshape into label shape N
        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)),
                                 index)
        nan = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf),
                        self.nan)
        logits_pmf = self.select(out_of_bound, nan, logits_pmf)
        ans = self.reshape(logits_pmf, label_shape)
        if drop_dim:
            return self.squeeze(ans)
        return ans