def test_pairs_two_sets_no_filter(self):
        """Pair training outputs from two detector runs without a detector filter.

        Two run directories are created, each with a ``snapshot``
        subdirectory and a single ``001cell_resp.png`` response image. Calling
        ``pair_detector_data`` with no ``detectors`` argument should discover
        both runs and pair their images under the timepoint-1 key.
        """
        rootdir = self.tempdir

        # The empty 'snapshot' dir is presumably what marks a directory as a
        # training run for pair_detector_data — TODO confirm against training.py
        d1 = rootdir / 'ai-upsample-peaks-residual_unet-run004' / 'ai-upsample-peaks-n25000'
        (d1 / 'snapshot').mkdir(exist_ok=True, parents=True)

        f1 = d1 / '001cell_resp.png'
        f1.touch()

        d2 = rootdir / 'ai-upsample-peaks-countception-run002' / 'ai-upsample-peaks-n50000'
        (d2 / 'snapshot').mkdir(exist_ok=True, parents=True)

        f2 = d2 / '001cell_resp.png'
        f2.touch()

        res_pairs, num_detectors = training.pair_detector_data(
            rootdir, data_type='train')
        self.assertEqual(num_detectors, 2)

        # Channel, tile, timepoint, but only timepoint applies to training data
        exp_pairs = {
            (None, None, 1): [f1, f2],
        }
        self.assertEqual(set(res_pairs), set(exp_pairs))
        for key, paths in res_pairs.items():
            self.assertEqual(len(paths), num_detectors)
            # Detector discovery order is not part of the contract, so compare
            # order-insensitively; sorted() already returns a list.
            self.assertEqual(sorted(paths), sorted(exp_pairs[key]))

    def test_pairs_two_sets_any_data(self):
        """Pair "any"-type outputs from two explicitly named detectors.

        Both detector directories receive the same three response images: two
        inside an ``s01`` subdirectory and one at the top level. Each image
        should be paired with its counterpart from the other detector, keyed
        by its relative path with the ``_resp.png`` suffix removed. Only the
        channel slot of the key is populated.
        """
        rootdir = self.tempdir

        # Relative layout shared by every detector directory, in creation order
        layout = [
            pathlib.Path('s01') / 's01t001cell_resp.png',
            pathlib.Path('s01') / 's01t002cell_resp.png',
            pathlib.Path('s01t002cell_resp.png'),
        ]

        def _populate(detector_dir):
            # Touch each file of the layout under detector_dir, creating
            # parent directories as needed; return the created paths in order.
            made = []
            for rel in layout:
                target = detector_dir / rel
                target.parent.mkdir(parents=True, exist_ok=True)
                target.touch()
                made.append(target)
            return made

        f1, f2, f3 = _populate(rootdir / 'SingleCell-countception')
        f4, f5, f6 = _populate(rootdir / 'SingleCell-unet')

        res_pairs, num_detectors = training.pair_detector_data(
            rootdir, data_type='any', detectors=('countception', 'unet'))
        self.assertEqual(num_detectors, 2)

        # Channel, tile, timepoint, but only channel applies to "any" data
        exp_pairs = {
            (pathlib.Path('s01/s01t001cell'), None, None): [f1, f4],
            (pathlib.Path('s01/s01t002cell'), None, None): [f2, f5],
            (pathlib.Path('s01t002cell'), None, None): [f3, f6],
        }
        self.assertEqual(set(res_pairs), set(exp_pairs))
        for key, found in res_pairs.items():
            expected = exp_pairs[key]
            self.assertEqual(len(found), num_detectors)
            # Sanity-check the fixture itself as well as the result
            self.assertEqual(len(expected), num_detectors)
            self.assertEqual(set(found), set(expected))