示例#1
0
    def datasets(self, params: SemiSupervisedParams):
        """Build and register the semi-supervised train/eval/test loaders.

        The training data is split by ``splits.semi_split`` into labelled,
        unlabelled and validation indices. The labelled builder uses one weak
        transform. The unlabelled builder toggles ids and adds a weak and a
        strong transform; it is kept on ``self.cl_set`` and also wrapped in
        ``self.unsup_dataloader``.
        """
        load = datasets[params.dataset]
        test_x, test_y = load(False)
        train_x, train_y = load(True)

        sup_ids, unsup_ids, val_ids = splits.semi_split(
            train_y,
            n_percls=params.n_percls,
            val_size=params.val_size,
            repeat_sup=False,
        )
        split_sizes = (len(sup_ids), len(unsup_ids), len(val_ids))
        self.logger.info('sup/unsup/val : {}'.format(split_sizes))

        mean, std = norm_val.get(params.dataset, [None, None])
        to_tensor = ToNormTensor(mean, std)
        weak_aug = Weak(mean, std)
        strong_aug = Strong(mean, std)

        sup_set = (DatasetBuilder(train_x, train_y)
                   .add_x(transform=weak_aug)
                   .add_y()
                   .subset(sup_ids))
        if len(sup_set) < params.batch_size:
            # Fewer labelled samples than one batch; virtual_sample presumably
            # pads the set up to batch_size (TODO confirm in thexp).
            sup_set.virtual_sample(params.batch_size)

        unsup_set = (DatasetBuilder(train_x, train_y)
                     .toggle_id()
                     .add_x(transform=weak_aug)
                     .add_x(transform=strong_aug)
                     .add_y()
                     .subset(unsup_ids))
        self.cl_set = unsup_set

        sup_loader = sup_set.DataLoader(batch_size=params.batch_size,
                                        num_workers=params.num_workers,
                                        shuffle=True)
        self.sup_dataloader = sup_loader

        # Unlabelled batches are `uratio` times larger than labelled ones.
        unsup_loader = unsup_set.DataLoader(
            batch_size=params.batch_size * params.uratio,
            num_workers=1,
            shuffle=True)
        self.unsup_dataloader = DataBundler().add(unsup_loader).to(self.device)

        val_builder = DatasetBuilder(train_x[val_ids], train_y[val_ids])
        val_loader = val_builder.add_x(transform=to_tensor).add_y().DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            shuffle=True)

        test_builder = DatasetBuilder(test_x, test_y)
        test_loader = test_builder.add_x(transform=to_tensor).add_y().DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            shuffle=True)

        train_bundle = (DataBundler()
                        .cycle(sup_loader)
                        .add(unsup_loader)
                        .zip_mode())
        self.regist_databundler(train=train_bundle,
                                eval=val_loader,
                                test=test_loader)
        self.to(self.device)
示例#2
0
    def datasets(self, params: FitEvalParams):
        """Build two half-sized noisy training loaders.

        The first ``params.train_size`` training samples are split in half.
        Each half gets its own symmetric-noise labels (hard-coded ratio 0.9).
        Each dataset carries the true labels plus a ``'noisy_y'`` label source.
        The loaders are stored on ``self.noisy_loader1`` /
        ``self.noisy_loader2``. ``self.toggle_dataset`` is then called with
        ``True`` when ``params.eval_mode == 'clean'`` and ``False`` otherwise.
        """
        from thexp import DatasetBuilder
        from data.constant import norm_val
        from data.transforms import Weak
        from data.noisy import symmetric_noisy

        mean, std = norm_val.get(params.dataset, [None, None])
        weak = Weak(mean, std)

        dataset_fn = datasets.datasets[params.dataset]
        train_x, train_y = dataset_fn(True)
        train_y = np.array(train_y)

        train_x, train_y = train_x[:params.train_size], train_y[:params.train_size]
        part_size = params.train_size // 2
        noisy_x1, noisy_true_y1 = train_x[:part_size], train_y[:part_size]
        noisy_x2, noisy_true_y2 = train_x[part_size:], train_y[part_size:]

        noisy_ratio = 0.9  # hard-coded; not read from params
        noisy_y1 = symmetric_noisy(noisy_true_y1, noisy_ratio, n_classes=params.n_classes)
        noisy_y2 = symmetric_noisy(noisy_true_y2, noisy_ratio, n_classes=params.n_classes)

        self.logger.info('noisy dataset ratio: ',
                         (noisy_true_y1 == noisy_y1).mean(),
                         (noisy_true_y2 == noisy_y2).mean())

        # Each half is paired with the noisy labels generated for that same
        # half; noisy_y1 is aligned index-by-index with noisy_x1.
        noisy_set1 = (
            DatasetBuilder(noisy_x1, noisy_true_y1)
                .add_labels(noisy_y1, source_name='noisy_y')
                .toggle_id()
                .add_x(transform=weak)
                .add_y()
                .add_y(source='noisy_y')
        )
        noisy_set2 = (
            DatasetBuilder(noisy_x2, noisy_true_y2)
                .add_labels(noisy_y2, source_name='noisy_y')
                .toggle_id()
                .add_x(transform=weak)
                .add_y()
                .add_y(source='noisy_y')
        )

        self.noisy_loader1 = noisy_set1.DataLoader(batch_size=params.batch_size,
                                                   num_workers=params.num_workers,
                                                   drop_last=True,
                                                   shuffle=True)
        self.noisy_loader2 = noisy_set2.DataLoader(batch_size=params.batch_size,
                                                   num_workers=params.num_workers,
                                                   drop_last=True,
                                                   shuffle=True)
        # eval_state is 0 for 'clean' mode and 1 otherwise, so toggle_dataset
        # receives True only in 'clean' mode.
        self.eval_state = 1 - int(params.eval_mode == 'clean')
        self.toggle_dataset((self.eval_state % 2) == 0)
        self.to(self.device)
示例#3
0
    def datasets(self, params: DivideMixParams):
        """Build the DivideMix noisy-label training and test loaders.

        Noisy labels are generated from the clean training labels according to
        ``params.noisy_type`` ('asymmetric' or 'symmetric') and
        ``params.noisy_ratio``. ``self.train_set_pack`` stores
        ``[train_x, train_y as ndarray, noisy_y]`` for later use.
        ``self.eval_train_dataloader`` iterates the training images paired with
        their noisy labels, without shuffling.

        Raises:
            ValueError: if ``params.noisy_type`` is not a supported noise type.
        """
        from data.dataxy import datasets
        from thexp import DataBundler

        dataset_fn = datasets[params.dataset]

        test_x, test_y = dataset_fn(False)
        train_x, train_y = dataset_fn(True)

        mean, std = norm_val.get(params.dataset, [None, None])
        toTensor = ToNormTensor(mean, std)
        weak = Weak(mean, std)
        strong = Strong(mean, std)

        if params.noisy_type == 'asymmetric':
            from data.noisy import asymmetric_noisy
            noisy_y = asymmetric_noisy(train_y,
                                       params.noisy_ratio,
                                       n_classes=params.n_classes)

        elif params.noisy_type == 'symmetric':
            from data.noisy import symmetric_noisy
            noisy_y = symmetric_noisy(train_y,
                                      params.noisy_ratio,
                                      n_classes=params.n_classes)

        else:
            # An explicit raise survives `python -O`, where an assert would be
            # stripped and the code would fail later on the unbound `noisy_y`.
            raise ValueError('unknown noisy_type: {}'.format(params.noisy_type))

        # Convert once, and compare as an ndarray, so the accuracy is computed
        # element-wise even if the dataset returns labels as a plain list.
        train_y_arr = np.array(train_y)
        self.train_set_pack = [train_x, train_y_arr, noisy_y]

        self.logger.info('noisy acc = {}'.format((train_y_arr == noisy_y).mean()))

        train_set = (DatasetBuilder(train_x, train_y)
                     .add_labels(noisy_y, 'noisy_y')
                     .toggle_id()
                     .add_x(transform=weak)
                     .add_x(transform=strong)
                     .add_y()
                     .add_y(source='noisy_y'))
        train_dataloader = train_set.DataLoader(batch_size=params.batch_size * 2,
                                                num_workers=params.num_workers,
                                                shuffle=True)

        # Fixed-order pass over the training set with noisy labels.
        self.eval_train_dataloader = (DataBundler().add(
            DatasetBuilder(train_x, noisy_y).toggle_id().add_x(
                transform=toTensor).add_y().DataLoader(
                    batch_size=params.batch_size,
                    num_workers=params.num_workers // 2,
                    shuffle=False)).to(self.device))

        test_dataloader = (DatasetBuilder(
            test_x, test_y).add_x(transform=toTensor).add_y().DataLoader(
                batch_size=params.batch_size,
                num_workers=params.num_workers // 2,
                shuffle=False))

        self.regist_databundler(train=train_dataloader, test=test_dataloader)
        self.to(self.device)
示例#4
0
    def on_train_epoch_begin(self, trainer: Trainer, func,
                             params: DivideMixParams, *args, **kwargs):
        """Re-split the training set into labelled / unlabelled loaders.

        Nothing happens while ``params.eidx < params.warm_up``. After warm-up,
        one network is evaluated with ``self.eval_train``: ``self.model2`` on
        even epochs and ``self.model`` on odd epochs. Samples whose returned
        probability exceeds ``params.p_threshold`` form the labelled loader and
        the rest the unlabelled loader. Both are bundled (the unlabelled one
        cycled) and registered as the new train bundler.
        """
        if params.eidx < params.warm_up:
            return

        # Alternate which network (and which loss history) divides the data.
        if params.eidx % 2 == 0:
            prob, self.all_loss[1] = self.eval_train(
                self.model2, self.all_loss[1])  # type: np.ndarray, list
        else:
            prob, self.all_loss[0] = self.eval_train(
                self.model, self.all_loss[0])  # type: np.ndarray, list

        # Boolean mask of samples treated as labelled this epoch.
        pred = (prob > params.p_threshold)
        pred_idx = pred.nonzero()[0]
        # `~` inverts the boolean mask directly instead of relying on
        # bool -> int arithmetic (`1 - pred`).
        unpred_idx = (~pred).nonzero()[0]

        train_x, train_y, noisy_y = self.train_set_pack
        clean = (noisy_y == train_y)
        # Fraction of the truly-clean samples that were selected as labelled.
        acc = (pred[clean]).mean()
        self.logger.info('Number of labeled samples', pred.sum(),
                         'clean ratio = {}'.format(acc))

        mean, std = norm_val.get(params.dataset, [None, None])
        weak = Weak(mean, std)

        labeled_dataloader = (DatasetBuilder(train_x, train_y)
                              .add_labels(noisy_y, source_name='nys')
                              .add_labels(prob, source_name='nprob')
                              .add_x(transform=weak)
                              .add_x(transform=weak)
                              .add_y()
                              .add_y(source='nys')
                              .add_y(source='nprob')
                              .subset(pred_idx)
                              .DataLoader(params.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          num_workers=params.num_workers))

        unlabeled_dataloader = (DatasetBuilder(train_x, train_y)
                                .add_labels(noisy_y, source_name='nys')
                                .add_x(transform=weak)
                                .add_x(transform=weak)
                                .add_y()
                                .add_y(source='nys')
                                .subset(unpred_idx)
                                .DataLoader(params.batch_size,
                                            shuffle=True,
                                            drop_last=True,
                                            num_workers=params.num_workers))
        bundler = DataBundler()
        bundler.add(labeled_dataloader).cycle(
            unlabeled_dataloader).zip_mode()
        self.logger.info('new training dataset', bundler)
        self.regist_databundler(train=bundler.to(self.device))
示例#5
0
    def datasets(self, params: SemiSupervisedParams):
        """Build and register semi-supervised loaders with K weak views.

        Uses a fixed validation split of 5000 samples. Each labelled sample
        gets one weak view. Each unlabelled sample gets ``params.K`` weak views;
        ``params.K`` is set via ``params.default(2, True)``.
        """
        load = datasets[params.dataset]
        test_x, test_y = load(False)
        train_x, train_y = load(True)

        sup_ids, unsup_ids, val_ids = splits.semi_split(
            train_y, n_percls=params.n_percls, val_size=5000, repeat_sup=False)

        mean, std = norm_val.get(params.dataset, [None, None])
        to_tensor = ToNormTensor(mean, std)
        weak_aug = Weak(mean, std)

        sup_builder = DatasetBuilder(train_x, train_y).add_x(transform=weak_aug)
        sup_builder = sup_builder.add_y().subset(sup_ids)

        params.K = params.default(2, True)
        unsup_builder = DatasetBuilder(train_x, train_y)
        # add_x is called for its in-place effect on the builder (the return
        # value is deliberately not rebound here).
        for _view in range(params.K):
            unsup_builder.add_x(transform=weak_aug)
        unsup_builder = unsup_builder.add_y().subset(unsup_ids)

        loader_kwargs = dict(batch_size=params.batch_size,
                             num_workers=params.num_workers,
                             shuffle=True)
        sup_loader = sup_builder.DataLoader(**loader_kwargs)
        unsup_loader = unsup_builder.DataLoader(**loader_kwargs)

        val_builder = DatasetBuilder(train_x[val_ids], train_y[val_ids])
        val_loader = val_builder.add_x(transform=to_tensor).add_y().DataLoader(
            **loader_kwargs)

        test_builder = DatasetBuilder(test_x, test_y)
        test_loader = test_builder.add_x(transform=to_tensor).add_y().DataLoader(
            **loader_kwargs)

        train_bundle = (DataBundler()
                        .cycle(sup_loader)
                        .add(unsup_loader)
                        .zip_mode())
        self.regist_databundler(train=train_bundle,
                                eval=val_loader,
                                test=test_loader)
        self.to(self.device)
示例#6
0
    def datasets(self, params: GlobalParams):
        """Build and register train/eval/test loaders, optionally distributed.

        The training split yields ids plus a weak and a strong view per sample.
        When ``params.distributed`` is set, a ``DistributedSampler`` replaces
        shuffling. ``self.train_size`` records the size of the training subset.
        Per-rank debug information is printed after registration.
        """
        load = datasets[params.dataset]
        test_x, test_y = load(False)
        train_x, train_y = load(True)

        train_ids, val_ids = splits.train_val_split(train_y,
                                                    val_size=params.val_size)

        mean, std = norm_val.get(params.dataset, [None, None])
        to_tensor = ToNormTensor(mean, std)
        weak_aug = Weak(mean, std)
        strong_aug = Strong(mean, std)

        test_loader = (DatasetBuilder(test_x, test_y)
                       .add_x(transform=to_tensor)
                       .add_y()
                       .DataLoader(batch_size=params.batch_size,
                                   num_workers=params.num_workers))

        train_set = (DatasetBuilder(train_x, train_y)
                     .toggle_id()
                     .add_x(transform=weak_aug)
                     .add_x(transform=strong_aug)
                     .add_y()
                     .subset(train_ids))

        sampler = None
        if params.distributed:
            from torch.utils.data.distributed import DistributedSampler
            sampler = DistributedSampler(train_set)
        self.train_size = len(train_set)
        # A sampler and shuffle=True are mutually exclusive, so shuffle only
        # in the non-distributed case.
        train_loader = train_set.DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            sampler=sampler,
            shuffle=not params.distributed)

        val_loader = (DatasetBuilder(train_x, train_y)
                      .add_x(transform=to_tensor)
                      .add_y()
                      .subset(val_ids)
                      .DataLoader(batch_size=params.batch_size,
                                  num_workers=params.num_workers))

        self.regist_databundler(train=train_loader,
                                eval=val_loader,
                                test=test_loader)
        # Per-rank debug output.
        rank = self.params.local_rank
        print('dataloader in rank {}'.format(rank))
        print(rank, self.train_dataloader)
        print(rank, self._databundler_dict)
        print(rank, train_loader)
        self.to(self.device)
    def datasets(self, params: NoisyParams):
        """Build two overlapping noisy training loaders and register the first.

        Symmetric label noise (``params.noisy_ratio``, default 0.2) is injected
        into the whole training set. Samples whose label stayed clean go into
        both sets. The corrupted samples are split in half: the first half goes
        to the "first" set and the second half to the "second" set. Both loaders
        are stored on ``self.first_dataloader`` / ``self.second_dataloader``,
        and ``self.first_dataloader`` is registered for training.
        """
        self.rnd.mark('kk')
        params.noisy_type = params.default('symmetric', True)
        params.noisy_ratio = params.default(0.2, True)

        from thexp import DatasetBuilder
        from data.constant import norm_val
        from data.transforms import Weak
        from data.noisy import symmetric_noisy

        mean, std = norm_val.get(params.dataset, [None, None])
        weak = Weak(mean, std)

        dataset_fn = datasets.datasets[params.dataset]
        train_x, train_y = dataset_fn(True)
        train_y = np.array(train_y)

        # NOTE(review): params.noisy_type is defaulted above but only symmetric
        # noise is generated here — confirm this is intended.
        noisy_y = symmetric_noisy(train_y,
                                  params.noisy_ratio,
                                  n_classes=params.n_classes)
        clean_mask = (train_y == noisy_y)
        noisy_ids = np.where(np.logical_not(clean_mask))[0]

        half = len(noisy_ids) // 2
        nmask_a, nmask_b = noisy_ids[:half], noisy_ids[half:]

        clean_x, clean_y = train_x[clean_mask], noisy_y[clean_mask]
        clean_true_y = train_y[clean_mask]

        raw_x, raw_true_y = train_x[nmask_a], train_y[nmask_a]
        raw_y = noisy_y[nmask_a]

        change_x, change_true_y = train_x[nmask_b], train_y[nmask_b]
        change_y = noisy_y[nmask_b]

        # train_x is boolean-mask indexed above, so it is an ndarray: `+` would
        # add pixel values element-wise (or fail on mismatched lengths) instead
        # of joining the two sample sets. Concatenate along the sample axis.
        first_x = np.concatenate([clean_x, raw_x])
        first_y = np.concatenate([clean_y, raw_y])
        first_true_y = np.concatenate([clean_true_y, raw_true_y])

        second_x = np.concatenate([clean_x, change_x])
        second_y = np.concatenate([clean_y, change_y])
        second_true_y = np.concatenate([clean_true_y, change_true_y])

        first_set = (DatasetBuilder(first_x, first_true_y)
                     .add_labels(first_y, 'noisy_y')
                     .toggle_id()
                     .add_x(transform=weak)
                     .add_y()
                     .add_y(source='noisy_y'))
        second_set = (DatasetBuilder(second_x, second_true_y)
                      .add_labels(second_y, 'noisy_y')
                      .toggle_id()
                      .add_x(transform=weak)
                      .add_y()
                      .add_y(source='noisy_y'))

        self.first_dataloader = first_set.DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            drop_last=True,
            shuffle=True)

        # Built once (the original constructed this identical loader twice).
        self.second_dataloader = second_set.DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            drop_last=True,
            shuffle=True)
        self.second = False
        self.regist_databundler(train=self.first_dataloader)
        self.cur_set = 0
        self.to(self.device)
示例#8
0
    def datasets(self, params: NoisyParams):
        """Build a noisy training loader zipped with a clean query loader.

        ``splits.semi_split`` divides the training data into three parts:
        - a query set, using ``params.query_size // params.n_classes`` as the
          per-class count;
        - the remaining training samples;
        - an evaluation split of ``params.val_size``.

        Label noise (``params.noisy_type`` / ``params.noisy_ratio``, default
        'symmetric' / 0.2) is injected into the training labels only. The
        training loader is zipped with a cycled query loader. In distributed
        mode each loader gets its own ``DistributedSampler``, stored on
        ``self.sampler_a`` / ``self.sampler_b``.

        Raises:
            ValueError: if ``params.noisy_type`` is not 'asymmetric' or
                'symmetric'.
        """
        params.noisy_type = params.default('symmetric', True)
        params.noisy_ratio = params.default(0.2, True)

        dataset_fn = datasets[params.dataset]
        test_x, test_y = dataset_fn(False)
        train_x, train_y = dataset_fn(True)

        query_ids, train_ids, eval_ids = splits.semi_split(
            train_y,
            params.query_size // params.n_classes,
            val_size=params.val_size,
            repeat_sup=False)

        self.train_size = len(train_ids)
        train_x, query_x, eval_x = (train_x[train_ids], train_x[query_ids],
                                    train_x[eval_ids])
        train_y, query_y, eval_y = (train_y[train_ids], train_y[query_ids],
                                    train_y[eval_ids])

        mean, std = norm_val.get(params.dataset, [None, None])
        toTensor = ToNormTensor(mean, std)
        weak = Weak(mean, std)
        strong = Strong(mean, std)

        if params.noisy_type == 'asymmetric':
            from data.noisy import asymmetric_noisy
            noisy_y = asymmetric_noisy(train_y,
                                       params.noisy_ratio,
                                       n_classes=params.n_classes)

        elif params.noisy_type == 'symmetric':
            from data.noisy import symmetric_noisy
            noisy_y = symmetric_noisy(train_y,
                                      params.noisy_ratio,
                                      n_classes=params.n_classes)

        else:
            # Explicit raise: an assert would be stripped under `python -O`.
            raise ValueError('unknown noisy_type: {}'.format(params.noisy_type))

        self.logger.info('noisy acc = {}'.format((train_y == noisy_y).mean()))
        self.rnd.shuffle()

        self.logger.info(len(train_y), len(train_x), len(noisy_y))
        train_builder = (DatasetBuilder(train_x, train_y)
                         .add_labels(noisy_y, 'noisy_y')
                         .toggle_id()
                         .add_x(transform=strong))

        # Optional extra weak views; add_x mutates the builder in place.
        params.K = params.default(0, True)
        for _ in range(params.K):
            train_builder.add_x(transform=weak)

        train_builder = train_builder.add_y().add_y(source='noisy_y')

        # NOTE(review): num_replicas is hard-coded to 4 for both samplers.
        train_sampler = None
        if params.distributed:
            from torch.utils.data.distributed import DistributedSampler
            train_sampler = DistributedSampler(train_builder, num_replicas=4)
            self.sampler_a = train_sampler
        train_loader = train_builder.DataLoader(batch_size=params.batch_size,
                                                num_workers=params.num_workers,
                                                sampler=train_sampler,
                                                shuffle=not params.distributed)

        query_builder = (DatasetBuilder(query_x, query_y)
                         .add_x(transform=strong)
                         .add_y())
        query_sampler = None
        if params.distributed:
            from torch.utils.data.distributed import DistributedSampler
            # Shard the query dataset itself. The previous code passed
            # `train_set`, which at this point was already the train
            # DataLoader, not the query set.
            query_sampler = DistributedSampler(query_builder, num_replicas=4)
            self.sampler_b = query_sampler
        query_loader = query_builder.DataLoader(batch_size=params.batch_size,
                                                num_workers=params.num_workers,
                                                sampler=query_sampler,
                                                shuffle=not params.distributed)

        val_dataloader = (
            DatasetBuilder(eval_x, eval_y).add_x(transform=toTensor).add_y().
            DataLoader(
                batch_size=params.batch_size,
                # Not shuffled for the probe set; possibly so each batch stays
                # class-balanced (unverified).
                shuffle=False,
                num_workers=params.num_workers))

        train_dataloader = DataBundler().add(train_loader).cycle(
            query_loader).zip_mode()

        test_dataloader = (DatasetBuilder(
            test_x, test_y).add_x(transform=toTensor).add_y().DataLoader(
                batch_size=params.batch_size,
                num_workers=params.num_workers,
                shuffle=True))

        self.regist_databundler(train=train_dataloader,
                                eval=val_dataloader,
                                test=test_dataloader)
        self.to(self.device)
示例#9
0
    def datasets(self, params: NoisyParams):
        """Build train/val/test loaders for training with synthetic label noise.

        ``params.noisy_type`` / ``params.noisy_ratio`` default to 'symmetric'
        and 0.2. 5000 training samples are held out for validation with their
        original labels; noise is injected into the remaining training labels.
        When ``params.order_sampler`` is set, a ``NoisySampler`` built from the
        clean mask is used instead of random shuffling.

        Raises:
            ValueError: if ``params.noisy_type`` is not 'asymmetric' or
                'symmetric'.
        """
        self.rnd.mark('fix_noisy')
        params.noisy_type = params.default('symmetric', True)
        params.noisy_ratio = params.default(0.2, True)

        dataset_fn = datasets[params.dataset]
        test_x, test_y = dataset_fn(False)
        train_x, train_y = dataset_fn(True)

        train_ids, val_ids = splits.train_val_split(train_y, val_size=5000)

        train_x, val_x = train_x[train_ids], train_x[val_ids]
        train_y, val_y = train_y[train_ids], train_y[val_ids]

        mean, std = norm_val.get(params.dataset, [None, None])
        toTensor = ToNormTensor(mean, std)
        weak = Weak(mean, std)
        strong = Strong(mean, std)

        if params.noisy_type == 'asymmetric':
            from data.noisy import asymmetric_noisy
            noisy_y = asymmetric_noisy(train_y,
                                       params.noisy_ratio,
                                       n_classes=params.n_classes)

        elif params.noisy_type == 'symmetric':
            from data.noisy import symmetric_noisy
            noisy_y = symmetric_noisy(train_y,
                                      params.noisy_ratio,
                                      n_classes=params.n_classes)

        else:
            # Explicit raise: an assert would be stripped under `python -O`.
            raise ValueError('unknown noisy_type: {}'.format(params.noisy_type))
        clean_mask = train_y == noisy_y
        self.logger.info('noisy acc = {}'.format(clean_mask.mean()))
        self.rnd.shuffle()

        train_set = (DatasetBuilder(train_x, train_y)
                     .add_labels(noisy_y, 'noisy_y')
                     .toggle_id()
                     .add_x(transform=weak)
                     .add_x(transform=strong)
                     .add_y()
                     .add_y(source='noisy_y'))

        sampler = None
        if params.order_sampler:
            # Imported only when needed, so thextra is an optional dependency.
            from thextra.noisy_sampler import NoisySampler
            sampler = NoisySampler(train_set, clean_mask)

        self.train_size = len(train_set)
        # Assuming DatasetBuilder.DataLoader forwards to torch's DataLoader,
        # that class raises ValueError when a sampler is combined with
        # shuffle=True, so only shuffle when no custom sampler is used.
        train_dataloader = train_set.DataLoader(batch_size=params.batch_size,
                                                num_workers=params.num_workers,
                                                drop_last=True,
                                                sampler=sampler,
                                                shuffle=sampler is None)

        val_dataloader = (DatasetBuilder(
            val_x, val_y).add_x(transform=toTensor).add_y().DataLoader(
                batch_size=params.batch_size,
                num_workers=params.num_workers,
                shuffle=True))

        test_dataloader = (DatasetBuilder(
            test_x, test_y).add_x(transform=toTensor).add_y().DataLoader(
                batch_size=params.batch_size,
                num_workers=params.num_workers,
                shuffle=True))

        self.regist_databundler(train=train_dataloader,
                                eval=val_dataloader,
                                test=test_dataloader)
        self.to(self.device)
示例#10
0
    def datasets(self, params: EvalParams):
        """Build the "first" training loader and a loader of corrupted samples.

        Symmetric noise (``params.noisy_ratio``, default 0.2) is applied to the
        training labels. The first set depends on ``params.eval_mode``:
        - 'full', 'same_epoch', 'same_acc': only the samples whose label stayed
          clean;
        - 'mix', 'raw', 'direct': the whole noisy training set.

        The second set, stored on ``self.second_dataloader``, always holds the
        samples whose label was corrupted. In 'direct' mode
        ``self.change_dataset()`` is called right after registration.

        Raises:
            ValueError: if ``params.eval_mode`` is not one of the modes above.
        """
        params.noisy_type = params.default('symmetric', True)
        params.noisy_ratio = params.default(0.2, True)

        from thexp import DatasetBuilder
        from data.constant import norm_val
        from data.transforms import Weak
        from data.noisy import symmetric_noisy

        mean, std = norm_val.get(params.dataset, [None, None])
        weak = Weak(mean, std)

        dataset_fn = datasets.datasets[params.dataset]
        train_x, train_y = dataset_fn(True)
        train_y = np.array(train_y)

        noisy_y = symmetric_noisy(train_y,
                                  params.noisy_ratio,
                                  n_classes=params.n_classes)
        clean_mask = (train_y == noisy_y)
        noisy_mask = np.logical_not(clean_mask)

        if params.eval_mode in ['full', 'same_epoch', 'same_acc']:
            first_x, first_y = train_x[clean_mask], noisy_y[clean_mask]
            first_true_y = train_y[clean_mask]
        elif params.eval_mode in ['mix', 'raw', 'direct']:
            first_x, first_y = train_x, noisy_y
            first_true_y = train_y
        else:
            # Explicit raise: an assert would be stripped under `python -O`
            # and the code would then fail on the unbound `first_x`.
            raise ValueError('unknown eval_mode: {}'.format(params.eval_mode))

        second_x, second_true_y = train_x[noisy_mask], train_y[noisy_mask]
        second_y = noisy_y[noisy_mask]

        # First line: label accuracy of the first set; second line: of the
        # corrupted set.
        self.logger.info('noisy acc = {}'.format(
            (first_true_y == first_y).mean()))
        self.logger.info('noisy acc = {}'.format(
            (second_true_y == second_y).mean()))
        self.rnd.shuffle()

        first_set = (DatasetBuilder(first_x, first_true_y)
                     .add_labels(first_y, 'noisy_y')
                     .toggle_id()
                     .add_x(transform=weak)
                     .add_y()
                     .add_y(source='noisy_y'))
        noisy_set = (DatasetBuilder(second_x, second_true_y)
                     .add_labels(second_y, 'noisy_y')
                     .toggle_id()
                     .add_x(transform=weak)
                     .add_y()
                     .add_y(source='noisy_y'))

        first_dataloader = first_set.DataLoader(batch_size=params.batch_size,
                                                num_workers=params.num_workers,
                                                drop_last=True,
                                                shuffle=True)

        self.second_dataloader = noisy_set.DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            drop_last=True,
            shuffle=True)
        self.second = False
        self.regist_databundler(train=first_dataloader)
        self.to(self.device)
        if params.eval_mode == 'direct':
            self.change_dataset()
示例#11
0
    def change_dataset(self):
        """Split the training set into supervised and unsupervised loaders.

        Samples whose ``self.filter_mem`` value is above 0.5 are treated as
        clean (supervised); the rest are unsupervised. The two loaders are
        zipped into ``self.ssl_dataloader`` (moved to ``self.device``), and a
        fresh iterator over it is stored in ``self.ssl_loaderiter``.

        Returns early, leaving ``self.ssl_dataloader`` unchanged, when every
        sample is clean. Sets ``self.ssl_dataloader`` to ``None`` when either
        loader is empty after ``drop_last``.
        """
        from data.constant import norm_val
        from data.transforms import Strong

        # `params` was previously referenced without being defined anywhere in
        # this method (NameError); read it from the trainer instead.
        params = self.params
        train_x, train_y, noisy_y = self.train_set

        filter_prob = self.filter_mem.cpu().numpy()
        clean_mask = filter_prob > 0.5
        self.logger.info('sup size', clean_mask.sum())
        # NOTE(review): the previous check `clean_mask.all() or not
        # np.logical_not(clean_mask).any()` tested the same condition twice.
        # It may have been meant to also bail out when nothing is clean;
        # confirm before changing behaviour.
        if clean_mask.all():
            return

        clean_ids = np.where(clean_mask)[0]
        noisy_ids = np.where(np.logical_not(clean_mask))[0]

        mean, std = norm_val.get(params.dataset, [None, None])
        strong = Strong(mean, std)

        supervised_dataloader = (
            DatasetBuilder(train_x, train_y)
                .add_labels(noisy_y, source_name='ny')
                .toggle_id()
                .add_x(strong)
                .add_y()
                .add_y(source='ny')
                .subset(clean_ids)
                .DataLoader(params.batch_size // 2,
                            shuffle=True,
                            num_workers=0,
                            drop_last=True)
        )

        unsupervised_dataloader = (
            DatasetBuilder(train_x, train_y)
                .add_labels(noisy_y, source_name='ny')
                .add_labels(filter_prob, source_name='nprob')
                .toggle_id()
                .add_x(strong)
                .add_x(strong)
                .add_y()
                .add_y(source='ny')
                .add_y(source='nprob')
                .subset(noisy_ids)
                .DataLoader(params.batch_size // 2,
                            shuffle=True,
                            num_workers=0,
                            drop_last=True)
        )

        # Bail out before bundling if either side yields no full batch.
        if len(unsupervised_dataloader) == 0 or len(supervised_dataloader) == 0:
            self.ssl_dataloader = None
            return

        # The shorter loader is cycled; presumably so the zipped bundle runs
        # for the length of the longer one (TODO confirm DataBundler.cycle).
        if len(supervised_dataloader) > len(unsupervised_dataloader):
            train_dataloader = (
                DataBundler()
                    .add(supervised_dataloader)
                    .cycle(unsupervised_dataloader)
            )
        else:
            train_dataloader = (
                DataBundler()
                    .cycle(supervised_dataloader)
                    .add(unsupervised_dataloader)
            )

        self.ssl_dataloader = train_dataloader.zip_mode().to(self.device)
        self.logger.info('ssl loader size', train_dataloader)
        self.ssl_loaderiter = iter(self.ssl_dataloader)