def test_basic(self): folder = '/mnt/local/data/omniglot' omniglot = OmniglotDataset(folder, 'train') preprocessor = NormalizationPreprocessor() for bsize in [1, 2]: sampler = CRPSampler(0) sampler2 = SemiSupervisedEpisodeSampler(sampler, 0) it = SemiSupervisedEpisodeIterator( omniglot, sampler2, batch_size=bsize, nclasses=10, nquery=5, preprocessor=preprocessor, expand=True, fix_unknown=True, label_ratio=0.5, nd=5, sd=1, md=2, alpha=0.5, theta=1.0) for x in range(2): b = it.next() print(b) print('support', tf.reduce_max(b.train_images), tf.reduce_min(b.train_images), tf.shape(b.train_images)) print('support label', b.train_labels, tf.shape(b.train_labels)) print('support gt', b.train_groundtruth, tf.shape(b.train_groundtruth)) print('query', tf.reduce_max(b.test_images), tf.reduce_min(b.test_images), tf.shape(b.test_images)) print('query label', b.test_labels, tf.shape(b.test_labels))
def test_basic(self):
    """Draw one collection from a seeded CRPSampler and print both splits.

    The sampler is attached to a small synthetic ``Dataset(20, 10)``; the
    resulting 'support' and 'query' entries are printed in that order.
    """
    crp = CRPSampler(0)
    crp.set_dataset(Dataset(20, 10))
    # Positional (5, 2) are forwarded unchanged; their meaning is defined by
    # CRPSampler.sample_collection.
    result = crp.sample_collection(5, 2, alpha=0.3, theta=1.0)
    for split in ('support', 'query'):
        print(result[split])
def test_markov_hierarchy(self):
    """Sample once from a HierarchicalEpisodeSampler using a MarkovSwitchBlender.

    A seeded CRPSampler is wrapped by the hierarchical sampler, which is bound
    to ``Dataset(10, 100, 100)``; the resulting 'support' entry is printed.
    """
    per_cls_cap = 100  # passed both to Dataset and as max_num_per_cls
    inner = CRPSampler(0)
    # First blender argument is a uniform length-3 vector (each entry 1/3).
    mixer = MarkovSwitchBlender(np.ones([3]) / 3.0, 0.5, 0)
    hier = HierarchicalEpisodeSampler(inner, mixer, True, 0)
    hier.set_dataset(Dataset(10, 100, per_cls_cap))
    drawn = hier.sample_collection(
        30, 2,
        alpha=0.5,
        theta=1.0,
        nstage=3,
        max_num=60,
        max_num_per_cls=per_cls_cap)
    print(drawn['support'])
def test_basic(self):
    """Draw ten collections from a BlurBlender-based hierarchical sampler.

    A seeded CRPSampler is combined with a BlurBlender inside a
    HierarchicalEpisodeSampler bound to ``Dataset(10, 100, 100)``; the
    'support' entry of each of the ten draws is printed.
    """
    k, n, m = 10, 100, 100
    inner = CRPSampler(0)
    blur = BlurBlender(window_size=20, stride=5, nrun=10, seed=0)
    hier = HierarchicalEpisodeSampler(inner, blur, False, 0)
    hier.set_dataset(Dataset(k, n, m))
    draw_opts = dict(
        alpha=0.5, theta=1.0, nstage=5, max_num=100, max_num_per_cls=m)
    for _trial in range(10):
        drawn = hier.sample_collection(50, 2, **draw_opts)
        print(drawn['support'])
def test_basic(self):
    """Draw 100 semi-supervised collections and print each support/query pair.

    A seeded CRPSampler is wrapped in a SemiSupervisedEpisodeSampler bound to
    ``Dataset(20, 10)``; every draw uses ``label_ratio=0.5``.
    """
    n, m = 20, 10
    base = CRPSampler(0)
    semi = SemiSupervisedEpisodeSampler(base, 0)
    semi.set_dataset(Dataset(n, m))
    draw_opts = dict(
        alpha=0.3, theta=1.0, max_num_per_cls=m, max_num=40, label_ratio=0.5)
    for _ in range(100):
        drawn = semi.sample_collection(10, 2, **draw_opts)
        support, query = drawn['support'], drawn['query']
        print('Support', support)
        print('Query', query)
def test_distractor_basic(self):
    """Draw 100 semi-supervised collections with nd/sd/md options set.

    Prints a start marker, the sampler's ``cls_dict``, and the iteration index
    before each draw. The 'support' and 'query' entries are looked up but not
    otherwise checked.
    """
    n, m = 20, 10
    semi = SemiSupervisedEpisodeSampler(CRPSampler(0), 0)
    semi.set_dataset(Dataset(n, m))
    print('start')
    print(semi.cls_dict)
    for step in range(100):
        print(step)
        drawn = semi.sample_collection(
            10, 2,
            nd=5, sd=3, md=2,
            alpha=0.3, theta=1.0,
            max_num_per_cls=m, max_num=40, label_ratio=0.5)
        # Lookups only; a missing key would raise KeyError here.
        _support, _query = drawn['support'], drawn['query']
def __init__(self, seed):
    """Initialise the base sampler with ``seed``, then build an inner CRPSampler.

    Args:
      seed: Forwarded to the parent constructor and reused for the inner
        CRPSampler stored as ``self._crp_sampler``.
    """
    parent = super(SeqCRPSampler, self)
    parent.__init__(seed)
    inner = CRPSampler(seed)
    self._crp_sampler = inner