Ejemplo n.º 1
0
def subsample_labels(labels,
                     num_samples,
                     fg_fraction,
                     bg_label=0,
                     use_random=True):
    """Pick foreground and background sample indices out of ``labels``.

    Entries equal to ``-1`` are left out entirely, entries equal to
    ``bg_label`` count as background and every other entry counts as
    foreground.

    Args:
        labels: Label tensor to sample from.
        num_samples: Total number of indices wanted (foreground plus
            background).
        fg_fraction: Share of ``num_samples`` reserved for foreground; the
            foreground count is capped by the number of foreground entries.
        bg_label: Label value that marks background. Defaults to 0.
        use_random: If true, the indices are picked through a random
            permutation; otherwise the leading candidates are taken.

    Returns:
        Tuple ``(fg_inds, bg_inds)`` of flattened int32 index tensors.
    """
    fg_mask = paddle.logical_and(labels != -1, labels != bg_label)
    fg_candidates = paddle.nonzero(fg_mask).cast('int32').flatten()
    bg_candidates = paddle.nonzero(labels == bg_label).cast('int32').flatten()

    fg_total = fg_candidates.numel()
    bg_total = bg_candidates.numel()

    # Foreground gets its share first; background fills what is left, and
    # neither can exceed the number of candidates actually available.
    fg_num = min(fg_total, int(num_samples * fg_fraction))
    bg_num = min(bg_total, num_samples - fg_num)

    # NOTE: both permutations are generated whether or not use_random is set.
    fg_order = paddle.slice(
        paddle.randperm(fg_total, dtype='int32'),
        axes=[0],
        starts=[0],
        ends=[fg_num])
    bg_order = paddle.slice(
        paddle.randperm(bg_total, dtype='int32'),
        axes=[0],
        starts=[0],
        ends=[bg_num])

    if not use_random:
        # Deterministic mode: simply keep the first candidates.
        first_fg = paddle.slice(
            fg_candidates, axes=[0], starts=[0], ends=[fg_num])
        first_bg = paddle.slice(
            bg_candidates, axes=[0], starts=[0], ends=[bg_num])
        return first_fg, first_bg

    return (paddle.gather(fg_candidates, fg_order),
            paddle.gather(bg_candidates, bg_order))
Ejemplo n.º 2
0
 def row_column_shuffle(embedding):
     """Return ``embedding`` with its columns, then its rows, randomly permuted.

     The transposes use ``perm=[1, 0]``, so a 2-D input is expected.
     """
     # Transpose so that indexing along axis 0 reorders the original columns.
     transposed = paddle.transpose(embedding, perm=[1, 0])
     column_order = paddle.randperm(paddle.shape(transposed)[0])
     column_shuffled = paddle.transpose(transposed[column_order], perm=[1, 0])

     # Second pass: reorder the rows of the column-shuffled result.
     row_order = paddle.randperm(paddle.shape(column_shuffled)[0])
     return column_shuffled[row_order]
Ejemplo n.º 3
0
    def test_generator_randperm_static(self):
        """Static-graph randperm must replay the same values after re-seeding."""
        fluid.disable_dygraph()

        seed = 123123143
        paddle.seed(seed)

        startup_prog = fluid.Program()
        main_prog = fluid.Program()
        with fluid.program_guard(main_prog, startup_prog):
            # Two separate randperm ops built into the same program.
            perm_a = paddle.randperm(10)
            perm_b = paddle.randperm(10)

            executor = fluid.Executor(fluid.CPUPlace())
            executor.run(startup_prog)
            first_run = executor.run(main_prog,
                                     feed={},
                                     fetch_list=[perm_a, perm_b])

            # Re-seed, then execute the very same program a second time.
            paddle.seed(seed)
            second_run = executor.run(main_prog,
                                      feed={},
                                      fetch_list=[perm_a, perm_b])

            first_a = np.array(first_run[0])
            first_b = np.array(first_run[1])
            second_a = np.array(second_run[0])
            second_b = np.array(second_run[1])

            # The comparisons are only made on builds without CUDA.
            if core.is_compiled_with_cuda():
                return

            print(">>>>>>> randperm static >>>>>>>")
            # Same seed -> same values for each op across the two runs ...
            self.assertTrue(np.allclose(first_a, second_a))
            self.assertTrue(np.allclose(first_b, second_b))
            # ... while the two ops within one run must differ.
            self.assertTrue(not np.allclose(first_b, first_a))
Ejemplo n.º 4
0
    def test_check_output(self):
        """Dygraph randperm results must pass ``check_randperm_out``."""
        with fluid.dygraph.guard():
            n = 10
            # Keyword arguments of each randperm call, in execution order.
            call_kwargs = (
                {"dtype": "int64"},
                {"dtype": "int32", "device": "cpu"},
            )
            for kwargs in call_kwargs:
                perm_np = paddle.randperm(n, **kwargs).numpy()
                self.assertTrue(check_randperm_out(n, perm_np),
                                msg=error_msg(perm_np))
Ejemplo n.º 5
0
    def test_out(self):
        """Run randperm in a static program on NPU and check dtype and content."""
        n = 10
        npu_place = paddle.NPUPlace(0)
        with program_guard(Program(), Program()):
            default_perm = paddle.randperm(n)
            float_perm = paddle.randperm(n, 'float32')

            executor = paddle.static.Executor(npu_place)
            fetched = executor.run(fetch_list=[default_perm, float_perm])
            default_out = fetched[0]
            float_out = fetched[1]

            # Without an explicit dtype the result is expected to be int64.
            self.assertEqual(default_out.dtype, np.int64)
            self.assertEqual(float_out.dtype, np.float32)
            self.assertTrue(check_randperm_out(n, default_out))
            self.assertTrue(check_randperm_out(n, float_out))
Ejemplo n.º 6
0
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Args:
            x: Local batch; its first axis is the batch axis.

        Returns:
            Tuple of (rows of the gathered batch selected for this rank,
            index tensor that restores the original order).
        """
        local_batch = x.shape[0]
        gathered = concat_all_gather(x)
        global_batch = gathered.shape[0]
        world = global_batch // local_batch

        # One random ordering over the whole gathered batch.
        shuffle_idx = paddle.randperm(global_batch).cuda()

        # Broadcast from src=0 when more than one process takes part.
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.broadcast(shuffle_idx, src=0)

        # Sorting the permutation yields the indices that undo it.
        restore_idx = paddle.argsort(shuffle_idx)

        # Row `rank` of the [world, -1] view is this process's share.
        rank = paddle.distributed.get_rank()
        local_idx = shuffle_idx.reshape([world, -1])[rank]
        shuffled = paddle.index_select(gathered, local_idx)
        return shuffled, restore_idx
Ejemplo n.º 7
0
def sparse_(tensor, sparsity, std=0.01):
    r"""Fill a 2-D tensor in place as a sparse matrix.

    Every element is first drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)`; afterwards, in each column, a
    randomly chosen ``ceil(sparsity * rows)`` of the rows are set to zero.
    The scheme follows `Deep learning via Hessian-free optimization` -
    Martens, J. (2010).

    Args:
        tensor: a 2-dimensional tensor; it is modified in place
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Returns:
        The same ``tensor`` object, after filling.

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions.
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    zeros_per_col = int(math.ceil(sparsity * n_rows))

    with paddle.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(n_cols):
            # A fresh permutation per column picks which rows get zeroed.
            rows_to_zero = paddle.randperm(n_rows)[:zeros_per_col]
            tensor[rows_to_zero, col_idx] = 0
    return tensor
Ejemplo n.º 8
0
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.

        The local batch is gathered from all processes, a random ordering
        of the gathered batch is drawn, and the part of that ordering that
        belongs to the current rank is used to pick this process's rows.

        Returns:
            Tuple of (selected rows, index tensor that restores the
            original order of the gathered batch).
        """
        n_local = x.shape[0]
        all_x = concat_all_gather(x)
        n_total = all_x.shape[0]
        n_procs = n_total // n_local

        order = paddle.randperm(n_total)
        # Broadcast from src=0 when running with several processes.
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.broadcast(order, src=0)

        # argsort of a permutation is its inverse
        inverse_order = paddle.argsort(order)

        # each rank reads its own row of the [n_procs, -1] view
        my_rank = paddle.distributed.get_rank()
        my_order = order.reshape([n_procs, -1])[my_rank]

        return paddle.gather(all_x, my_order, axis=0), inverse_order
Ejemplo n.º 9
0
    def node_batch_iter(self, batch_size, shuffle=True):
        """Node batch iterator

        Iterate over all node indices ``0 .. num_nodes - 1`` in batches.

        Args:
            batch_size: The batch size of each batch of nodes. Must be
                positive; the last batch may be smaller.

            shuffle: Whether shuffle the nodes.

        Return:
            Batch iterator (a generator). Each item is a slice of the index
            permutation: a paddle tensor when ``self.is_tensor()`` is true,
            a NumPy array otherwise.

        Raises:
            ValueError: If ``batch_size`` is not positive. Because this is a
                generator, the error surfaces when iteration starts.
        """
        # A non-positive step would never move `start` past num_nodes and the
        # loop below would spin forever, so reject it up front.
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be positive, got {}".format(batch_size))

        if self.is_tensor():
            if shuffle:
                perm = paddle.randperm(self.num_nodes)
            else:
                perm = paddle.arange(self.num_nodes)
        else:
            perm = np.arange(self.num_nodes)
            if shuffle:
                # In-place shuffle driven by NumPy's global random state.
                np.random.shuffle(perm)

        start = 0
        while start < self.num_nodes:
            yield perm[start:start + batch_size]
            start += batch_size
Ejemplo n.º 10
0
 def test_out(self):
     """Check dygraph randperm output for several dtypes.

     Dynamic-graph mode is switched on for the duration of the test and
     static mode is always switched back on afterwards.
     """
     paddle.disable_static()
     try:
         n = 10
         for dtype in ['int32', np.int64, 'float32', 'float64']:
             data_p = paddle.randperm(n, dtype)
             data_np = data_p.numpy()
             self.assertTrue(check_randperm_out(n, data_np),
                             msg=error_msg(data_np))
     finally:
         # Restore static mode even when an assertion fails, so the mode
         # switch cannot leak into tests that run afterwards.
         paddle.enable_static()
Ejemplo n.º 11
0
def random_split(dataset, lengths, generator=None):
    """
    Randomly partition a dataset into non-overlapping subsets of given lengths.

    A random permutation of ``range(sum(lengths))`` is drawn and cut into
    consecutive pieces, one piece per entry of ``lengths``.

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced; they must add up to ``len(dataset)``
        generator (Generator, optional): Accepted as part of the signature; this implementation does not read it. Default is None.

    Returns:
        Datasets: A list of subset Datasets, which are the non-overlapping subsets of the original Dataset.

    Raises:
        ValueError: If the lengths do not add up to ``len(dataset)``.

    Example code:

        .. code-block:: python

            import paddle
            from paddle.io import random_split

            a_list = paddle.io.random_split(range(10), [3, 7])
            print(len(a_list))
            # 2

            # The values below depend on the random permutation drawn.
            for idx, v in enumerate(a_list[0]):
                print(idx, v)

            # output of the first subset
            # 0 1
            # 1 3
            # 2 9

            for idx, v in enumerate(a_list[1]):
                print(idx, v)
            # output of the second subset
            # 0 5
            # 1 7
            # 2 8
            # 3 6
            # 4 0
            # 5 2
            # 6 4
    """
    total = sum(lengths)
    # Cannot verify that dataset is Sized
    if total != len(dataset):  # type: ignore
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )
    # TODO(@Joejiong): support Variable or Tensor type with .tolist class member function.
    # For example var.item() and var.tolist()
    shuffled = paddle.randperm(total).numpy().tolist()

    # Walk through the permutation, handing each subset the next `length`
    # indices.
    subsets = []
    begin = 0
    for length in lengths:
        end = begin + length
        subsets.append(Subset(dataset, shuffled[begin:end]))
        begin = end
    return subsets
Ejemplo n.º 12
0
    def test_attr_tensor_API(self):
        """Static-graph randperm written into an existing variable via ``out``."""
        startup_prog = fluid.Program()
        main_prog = fluid.Program()
        with fluid.program_guard(main_prog, startup_prog):
            n = 10
            # Pre-filled int64 variable that randperm is asked to write into.
            prefilled = fluid.layers.fill_constant([n], "int64", 3)
            paddle.randperm(n=n, out=prefilled)

            int32_perm = paddle.randperm(n=n, dtype="int32", device="cpu")

            # Use the GPU when the build supports it, the CPU otherwise.
            place = fluid.CPUPlace()
            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            executor = fluid.Executor(place)

            executor.run(startup_prog)
            fetched = executor.run(main_prog,
                                   fetch_list=[prefilled, int32_perm])

            # Only the first fetched result is validated.
            prefilled_np = np.array(fetched[0])
            self.assertTrue(check_randperm_out(n, prefilled_np),
                            msg=error_msg(prefilled_np))
Ejemplo n.º 13
0
def my_gt_argmax(overlaps):
    """Pick, per column of ``overlaps``, one row index where the column max sits.

    When several rows tie for the maximum of a column, one of them is
    chosen through a random permutation.

    NOTE(review): ``torch`` is called with ``axis=`` and tensors use
    ``.cast``, so it is presumably a compatibility layer rather than stock
    PyTorch - confirm against the module's imports.

    Returns:
        The per-column picks joined by ``cat``.
    """
    column_max = torch.max(overlaps, axis=0)
    is_column_max = overlaps == column_max

    picks = []
    num_columns = overlaps.shape[-1]
    for col in range(num_columns):
        # Row indices at which this column reaches its maximum.
        candidates = torch.nonzero(is_column_max.cast('int')[:, col],
                                   as_tuple=False).flatten()
        # First entry of a random permutation = uniformly random position.
        position = torch.randperm(candidates.numel())[0]
        picks.append(torch.gather(candidates, position))
    return cat(picks)
Ejemplo n.º 14
0
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda.

    ``lam`` is drawn from Beta(alpha, alpha) when ``alpha > 0`` and is 1
    otherwise. Each sample is blended with a partner picked by a random
    permutation of the batch (first axis of ``x``).
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Random partner for every sample in the batch.
    partner = paddle.randperm(x.shape[0])
    partner_x = paddle.index_select(x, partner)
    partner_y = paddle.index_select(y, partner)

    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, (y, partner_y, lam)
Ejemplo n.º 15
0
    def test_generator_randperm_dygraph(self):
        """Test Generator seed, state save/restore and re-seeding with randperm."""

        fluid.enable_dygraph()

        seed = 12312321111
        generator = paddle.seed(seed)

        first = paddle.randperm(10)
        # Snapshot the generator state after the first draw ...
        saved_state = generator.get_state()
        after_snapshot = paddle.randperm(10)
        # ... and rewind to it, so the next draw should repeat the previous one.
        generator.set_state(saved_state)
        replayed = paddle.randperm(10)
        # Re-seeding should reproduce the very first draw.
        generator.manual_seed(seed)
        reseeded = paddle.randperm(10)

        first_np = first.numpy()
        after_snapshot_np = after_snapshot.numpy()
        replayed_np = replayed.numpy()
        reseeded_np = reseeded.numpy()

        # The comparisons are only made on builds without CUDA.
        if core.is_compiled_with_cuda():
            return

        print(">>>>>>> randperm dygraph >>>>>>>")
        self.assertTrue(np.allclose(after_snapshot_np, replayed_np))
        self.assertTrue(np.allclose(first_np, reseeded_np))
Ejemplo n.º 16
0
def subsample_labels(labels, num_samples, positive_fraction):
    """Randomly choose positive and negative sample indices from ``labels``.

    Positives are entries that are neither ``config.ignore_label`` nor 0;
    negatives are entries equal to 0.

    Args:
        labels: Label tensor to sample from.
        num_samples: Total number of samples wanted.
        positive_fraction: Share of ``num_samples`` reserved for positives;
            capped by the number of positives available.

    Returns:
        Tuple ``(pos_idx, neg_idx)`` of index tensors.
    """
    # The "and" of the two conditions is built by multiplying int masks.
    not_ignored = (labels != config.ignore_label).cast('int')
    not_background = (labels != 0).cast('int')
    positive_mask = mul(not_ignored, not_background).cast('bool')

    positive = torch.nonzero(positive_mask, as_tuple=False).squeeze(1)
    negative = torch.nonzero(labels == 0, as_tuple=False).squeeze(1)

    num_pos = min(positive.numel(), int(num_samples * positive_fraction))
    num_neg = min(negative.numel(), num_samples - num_pos)

    # The counts may come back as tensors (presumably when numel() returns
    # one); unwrap them to plain Python numbers before slicing.
    if type(num_pos) == torch.Tensor:
        num_pos = num_pos.numpy().item()
    if type(num_neg) == torch.Tensor:
        num_neg = num_neg.numpy().item()

    # randomly select positive and negative examples
    pos_order = torch.randperm(positive.numel())[:num_pos]
    neg_order = torch.randperm(negative.numel())[:num_neg]

    return torch.gather(positive, pos_order), torch.gather(negative, neg_order)
Ejemplo n.º 17
0
def random_split(dataset, lengths, generator=None):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    A random permutation of ``range(sum(lengths))`` is drawn with
    ``paddle.randperm`` and cut into consecutive slices, one per entry of
    ``lengths``; each slice is handed to ``Subset`` as its indices.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): kept in the signature only; it is not used here.

    Returns:
        list of ``Subset`` objects, in the order of ``lengths``.

    Raises:
        ValueError: if the lengths do not add up to ``len(dataset)``.
    """
    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = paddle.randperm(total)

    subsets = []
    # `end` is the running total of the lengths, so each slice starts where
    # the previous one stopped.
    for end, length in zip(_accumulate(lengths), lengths):
        subsets.append(Subset(dataset, indices[end - length: end]))
    return subsets
Ejemplo n.º 18
0
def mixup_data(x, y, alpha=1.0):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    Reference:
    Zhang, Hongyi, et al. “Mixup: Beyond Empirical Risk Minimization.” International Conference on Learning Representations, 2017.

    Args:
        x: Input batch; its first axis is the batch axis.
        y: Targets, indexable along the same batch axis.
        alpha: Parameter of the Beta(alpha, alpha) distribution the mixing
            weight is drawn from. With ``alpha <= 0`` the weight is 1.

    Returns:
        Tuple ``(mixed_x, (y_a, y_b, lam))`` where ``y_a`` is ``y``,
        ``y_b`` holds the partners' targets and ``lam`` is the mixing weight.
    """
    lam = 1
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)

    shuffle_index = paddle.randperm(x.shape[0])
    x_partner = paddle.index_select(x, shuffle_index)
    mixed_x = lam * x + (1 - lam) * x_partner

    y_partner = paddle.index_select(y, shuffle_index)
    return mixed_x, (y, y_partner, lam)
Ejemplo n.º 19
0
def train(args):
    """Run the training loop: collect rollouts, then do clipped-ratio updates.

    Each episode gathers ``args.num_local_steps`` steps from every game
    process, computes GAE-style returns and advantages, and then performs
    ``args.num_epochs`` passes of shuffled mini-batch updates. The model is
    saved under ``args.saved_path`` once before training and after every
    epoch. The outer loop never terminates on its own.

    Args:
        args: Options object. The fields read here are ``game``,
            ``num_processes``, ``trained_model``, ``saved_path``, ``lr``,
            ``num_local_steps``, ``gamma``, ``tau``, ``num_epochs``,
            ``batch_size``, ``epsilon`` and ``beta``.
    """
    # Train on the GPU when paddle was built with CUDA
    if paddle.is_compiled_with_cuda():
        paddle.set_device("gpu:0")
    # Create the multi-process game environments
    envs = MultipleEnvironments(args.game, args.num_processes)
    # Fix the random seed so initialisation is repeatable
    paddle.seed(123)
    # Create the model
    model = Model(envs.num_states, envs.num_actions)
    # Load a pretrained model if one was given
    if args.trained_model is not None:
        model.load_dict(paddle.load(args.trained_model))
    # Create the folder the model is saved into
    if not os.path.isdir(args.saved_path):
        os.makedirs(args.saved_path)
    paddle.save(model.state_dict(),
                "{}/model_{}.pdparams".format(args.saved_path, args.game))
    # Start a separate process for evaluating the game
    # NOTE(review): `eval` here is presumably a project-level function that
    # shadows the builtin - confirm against the module's definitions.
    mp = _mp.get_context("spawn")
    process = mp.Process(target=eval,
                         args=(args, envs.num_states, envs.num_actions))
    process.start()
    # Create the optimizer (with gradient clipping by norm)
    clip_grad = paddle.nn.ClipGradByNorm(clip_norm=0.5)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=args.lr,
                                      grad_clip=clip_grad)
    # At the start, reset the game in every process
    [agent_conn.send(("reset", None)) for agent_conn in envs.agent_conns]
    # Fetch the initial game frames
    curr_states = [agent_conn.recv() for agent_conn in envs.agent_conns]
    curr_states = paddle.to_tensor(np.concatenate(curr_states, 0),
                                   dtype='float32')
    curr_episode = 0
    while True:
        curr_episode += 1
        old_log_policies, actions, values, states, rewards, dones = [], [], [], [], [], []
        for _ in range(args.num_local_steps):
            states.append(curr_states)
            # Run the prediction
            logits, value = model(curr_states)
            # Compute the probability of each action
            policy = F.softmax(logits)
            # Randomly sample an action according to those probabilities
            old_m = Categorical(policy)
            action = old_m.sample([1]).squeeze()
            # Record the prediction data
            actions.append(action)
            values.append(value.squeeze())
            # Compute the log-probability of the chosen action
            old_log_policy = old_m.log_prob(paddle.unsqueeze(action, axis=1))
            old_log_policy = paddle.squeeze(old_log_policy)
            old_log_policies.append(old_log_policy)
            # Send each process its action
            [
                agent_conn.send(("step", int(act[0])))
                for agent_conn, act in zip(envs.agent_conns, action)
            ]
            # Collect and regroup the replies from all processes
            state, reward, done, info = zip(
                *[agent_conn.recv() for agent_conn in envs.agent_conns])
            # Convert the data
            state = paddle.to_tensor(np.concatenate(state, 0), dtype='float32')
            # Convert to tensors
            reward = paddle.to_tensor(reward, dtype='float32')
            done = paddle.to_tensor(done, dtype='float32')
            # Record the step data
            rewards.append(reward)
            dones.append(done)
            curr_states = state
        # Predict the value of the last frames obtained above
        _, next_value, = model(curr_states)
        next_value = next_value.squeeze()
        old_log_policies = paddle.concat(old_log_policies).detach().squeeze()
        actions = paddle.concat(actions).squeeze()
        values = paddle.concat(values).squeeze().detach()
        states = paddle.concat(states).squeeze()

        # Walk the rollout backwards, accumulating the discounted advantage
        # (gae) and the return target R = gae + value for each step.
        # NOTE(review): `values` was concatenated just above while `rewards`
        # and `dones` are still per-step lists - confirm the zip pairs the
        # elements as intended.
        gae = 0.0
        R = []
        for value, reward, done in list(zip(values, rewards, dones))[::-1]:
            gae = gae * args.gamma * args.tau
            gae = gae + reward + args.gamma * next_value.detach() * (
                1.0 - done) - value.detach()
            next_value = value
            R.append(gae + value)
        R = R[::-1]
        R = paddle.concat(R).detach()
        advantages = R - values
        for i in range(args.num_epochs):
            # Fresh random ordering of all collected samples for this epoch
            indice = paddle.randperm(args.num_local_steps * args.num_processes)
            # NOTE(review): the loop runs args.batch_size times and each slice
            # holds total / args.batch_size samples, so args.batch_size acts
            # as the number of mini-batches here.
            for j in range(args.batch_size):
                batch_indices = indice[int(j * (
                    args.num_local_steps * args.num_processes / args.batch_size
                )):int((j + 1) * (args.num_local_steps * args.num_processes /
                                  args.batch_size))]
                # Run the prediction on the selected frames
                logits, value = model(paddle.gather(states, batch_indices))
                # Compute the probability of each action
                new_policy = F.softmax(logits)
                # Compute the log-probability of the recorded actions
                new_m = Categorical(new_policy)
                new_log_policy = new_m.log_prob(
                    paddle.unsqueeze(paddle.gather(actions, batch_indices),
                                     axis=1))
                new_log_policy = paddle.squeeze(new_log_policy)
                # Compute the actor loss: element-wise minimum of the
                # unclipped and clipped ratio-weighted advantage
                ratio = paddle.exp(
                    new_log_policy -
                    paddle.gather(old_log_policies, batch_indices))
                advantage = paddle.gather(advantages, batch_indices)
                actor_loss = paddle.clip(ratio, 1.0 - args.epsilon,
                                         1.0 + args.epsilon) * advantage
                actor_loss = paddle.concat([
                    paddle.unsqueeze(ratio * advantage, axis=0),
                    paddle.unsqueeze(actor_loss, axis=0)
                ])
                actor_loss = -paddle.mean(paddle.min(actor_loss, axis=0))
                # Compute the critic loss
                critic_loss = F.smooth_l1_loss(paddle.gather(R, batch_indices),
                                               value.squeeze())
                entropy_loss = paddle.mean(new_m.entropy())
                # Compute the total loss
                total_loss = actor_loss + critic_loss - args.beta * entropy_loss
                # Compute the gradients and apply the update
                total_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            # Save the model after every epoch
            paddle.save(
                model.state_dict(),
                "{}/model_{}.pdparams".format(args.saved_path, args.game))
        # Report the loss of the last mini-batch of this episode
        print("Episode: {}. Total loss: {:.4f}".format(curr_episode,
                                                       total_loss.numpy()[0]))
Ejemplo n.º 20
0
 def row_shuffle(embedding):
     """Return ``embedding`` with its entries along axis 0 in random order."""
     row_count = paddle.shape(embedding)[0]
     row_order = paddle.randperm(row_count)
     return embedding[row_order]
Ejemplo n.º 21
0
 def test_Variable():
     """Call randperm with a NumPy array passed as ``out``."""
     numpy_out = np.arange(10)
     paddle.randperm(n=10, out=numpy_out)
Ejemplo n.º 22
0
 def test_value():
     """Call randperm with a negative ``n``."""
     negative_n = -3
     paddle.randperm(n=negative_n)
Ejemplo n.º 23
0
def randperm(n):
    """Return ``convertTensor`` applied to an int32 ``paddle.randperm(n)``."""
    permutation = paddle.randperm(n, dtype='int32', name=None)
    return convertTensor(permutation)
Ejemplo n.º 24
0
    def test_fixed_random_number(self):
        """Check that GPU randperm yields known values for a fixed seed.

        With seed 2021 on the GPU, permutations of length 30000 are drawn
        for int32, int64, float32 and float64 (in that order), and three
        ten-element windows of each are compared with recorded values.
        The test is skipped on builds without CUDA. Static mode is always
        switched back on afterwards.
        """
        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
        if not paddle.is_compiled_with_cuda():
            return

        print("Test Fixed Random number on GPU------>")

        # (dtype, [(window start, ten expected values), ...]); the dtypes
        # must be drawn in exactly this order because they share one seeded
        # random stream.
        expected_by_dtype = [
            ('int32', [
                (0, [
                    24562, 8409, 9379, 10328, 20503, 18059, 9681, 21883,
                    11783, 27413
                ]),
                (10000, [
                    29477, 27100, 9643, 16637, 8605, 16892, 27767, 2724,
                    1612, 13096
                ]),
                (20000, [
                    298, 4104, 16479, 22714, 28684, 7510, 14667, 9950, 15940,
                    28343
                ]),
            ]),
            ('int64', [
                (0, [
                    6587, 1909, 5525, 23001, 6488, 14981, 14355, 3083, 29561,
                    8171
                ]),
                (10000, [
                    23460, 12394, 22501, 5427, 20185, 9100, 5127, 1651,
                    25806, 4818
                ]),
                (20000, [
                    5829, 4508, 16193, 24836, 8526, 242, 9984, 9243, 1977,
                    11839
                ]),
            ]),
            ('float32', [
                (0, [
                    5154., 10537., 14362., 29843., 27185., 28399., 27561.,
                    4144., 22906., 10705.
                ]),
                (10000, [
                    1958., 18414., 20090., 21910., 22746., 27346., 22347.,
                    3002., 4564., 26991.
                ]),
                (20000, [
                    25580., 12606., 553., 16387., 29536., 4241., 20946.,
                    16899., 16339., 4662.
                ]),
            ]),
            ('float64', [
                (0, [
                    19051., 2449., 21940., 11121., 282., 7330., 13747.,
                    24321., 21147., 9163.
                ]),
                (10000, [
                    15483., 1315., 5723., 20954., 13251., 25539., 5074.,
                    1823., 14945., 17624.
                ]),
                (20000, [
                    10516., 2552., 29970., 5941., 986., 8007., 24805.,
                    26753., 12202., 21404.
                ]),
            ]),
        ]

        paddle.disable_static()
        try:
            paddle.set_device('gpu')
            paddle.seed(2021)

            for dtype, windows in expected_by_dtype:
                x = paddle.randperm(30000, dtype=dtype).numpy()
                for start, expect in windows:
                    self.assertTrue(
                        np.array_equal(x[start:start + 10], expect))
        finally:
            # Restore static mode even when an assertion fails, so the mode
            # switch cannot leak into tests that run afterwards.
            paddle.enable_static()
    def _hard_anchor_sampling(self, X, y_hat, y):
        """
        Sample a fixed number of feature vectors per (image, class) pair,
        mixing "hard" pixels (label is the class, prediction is not) with
        "easy" pixels (label and prediction are both the class).

        Args:
            X (Tensor): reshaped feats, shape = [N, H * W, feat_channels]
            y_hat (Tensor): reshaped label, shape = [N, H * W]
            y (Tensor): reshaped predict, shape = [N, H * W]

        Returns:
            tuple: ``(X_, y_)`` where ``X_`` stacks the sampled features of
            every kept (image, class) pair along a new first axis and ``y_``
            is a float32 tensor holding the class id of each pair.

        Raises:
            UserWarning: if a kept class has neither hard nor easy pixels.
        """
        batch_size, feat_dim = paddle.shape(X)[0], paddle.shape(X)[-1]
        # Per image: the classes that will be sampled from.
        classes = []
        total_classes = 0
        for i in range(batch_size):
            current_y = y_hat[i]
            current_classes = paddle.unique(current_y)
            # Drop the ignore label ...
            current_classes = [
                x for x in current_classes if x != self.ignore_index
            ]
            # ... and keep only classes with more than max_views labelled
            # pixels in this image.
            current_classes = [
                x for x in current_classes
                if (current_y == x).nonzero().shape[0] > self.max_views
            ]

            classes.append(current_classes)
            total_classes += len(current_classes)

        # Samples per (image, class) pair: an even share of max_samples,
        # capped at max_views.
        # NOTE(review): this division raises ZeroDivisionError when no class
        # survived the filtering above (total_classes == 0) - confirm callers
        # never hit that case.
        n_view = self.max_samples // total_classes
        n_view = min(n_view, self.max_views)

        X_ = []
        y_ = paddle.zeros([total_classes], dtype='float32')

        # Position in y_ of the (image, class) pair being processed.
        X_ptr = 0
        for i in range(batch_size):
            this_y_hat = y_hat[i]
            current_y = y[i]
            current_classes = classes[i]

            for cls_id in current_classes:
                # hard: labelled cls_id but predicted as something else
                hard_indices = paddle.logical_and(
                    (this_y_hat == cls_id), (current_y != cls_id)).nonzero()
                # easy: labelled cls_id and predicted cls_id
                easy_indices = paddle.logical_and(
                    (this_y_hat == cls_id), (current_y == cls_id)).nonzero()

                num_hard = hard_indices.shape[0]
                num_easy = easy_indices.shape[0]

                # Split n_view between hard and easy pixels: half each when
                # both have enough, otherwise take everything from the
                # scarce kind and fill the rest from the other.
                if num_hard >= n_view / 2 and num_easy >= n_view / 2:
                    num_hard_keep = n_view // 2
                    num_easy_keep = n_view - num_hard_keep
                elif num_hard >= n_view / 2:
                    num_easy_keep = num_easy
                    num_hard_keep = n_view - num_easy_keep
                else:
                    num_hard_keep = num_hard
                    num_easy_keep = n_view - num_hard_keep

                indices = None
                if num_hard > 0:
                    # Random subset of the hard pixels.
                    perm = paddle.randperm(num_hard)
                    hard_indices = hard_indices[perm[:num_hard_keep]].reshape(
                        (-1, hard_indices.shape[-1]))
                    indices = hard_indices
                if num_easy > 0:
                    # Random subset of the easy pixels.
                    perm = paddle.randperm(num_easy)
                    easy_indices = easy_indices[perm[:num_easy_keep]].reshape(
                        (-1, easy_indices.shape[-1]))
                    if indices is None:
                        indices = easy_indices
                    else:
                        indices = paddle.concat((indices, easy_indices),
                                                axis=0)
                if indices is None:
                    raise UserWarning('hard sampling indice error')

                # Gather the feature vectors of the chosen pixels of image i.
                X_.append(paddle.index_select(X[i, :, :], indices.squeeze(1)))
                y_[X_ptr] = float(cls_id)
                X_ptr += 1
        X_ = paddle.stack(X_, axis=0)
        return X_, y_