Пример #1
0
 def test_basic(self):
     """Smoke test: sample one collection from a CRPSampler and print it.

     Nothing is asserted; the test passes as long as sampling and the
     'support' / 'query' lookups do not raise.
     """
     crp = CRPSampler(0)
     # Dataset takes its arguments positionally; presumably (number of
     # classes, images per class) -- TODO confirm against Dataset.
     crp.set_dataset(Dataset(20, 10))
     episode = crp.sample_collection(5, 2, alpha=0.3, theta=1.0)
     for split in ('support', 'query'):
         print(episode[split])
 def test_basic(self):
   """Smoke test for SemiSupervisedEpisodeIterator on Omniglot.

   For batch sizes 1 and 2, builds a fresh sampler and iterator over the
   'train' split, pulls two batches and prints each batch together with the
   value range, shape and labels of its support and query parts. Nothing is
   asserted. The dataset is read from a hard-coded local path.
   """
   data_root = '/mnt/local/data/omniglot'
   dataset = OmniglotDataset(data_root, 'train')
   normalizer = NormalizationPreprocessor()
   for batch_size in (1, 2):
     # Samplers are rebuilt (with seed 0) for every batch size.
     crp = CRPSampler(0)
     semi_sampler = SemiSupervisedEpisodeSampler(crp, 0)
     episode_iter = SemiSupervisedEpisodeIterator(
         dataset,
         semi_sampler,
         batch_size=batch_size,
         nclasses=10,
         nquery=5,
         preprocessor=normalizer,
         expand=True,
         fix_unknown=True,
         label_ratio=0.5,
         nd=5,
         sd=1,
         md=2,
         alpha=0.5,
         theta=1.0)
     for _ in range(2):
       batch = episode_iter.next()
       print(batch)

       # Support part of the episode.
       support_images = batch.train_images
       print('support', tf.reduce_max(support_images),
             tf.reduce_min(support_images), tf.shape(support_images))
       support_labels = batch.train_labels
       print('support label', support_labels, tf.shape(support_labels))
       support_gt = batch.train_groundtruth
       print('support gt', support_gt, tf.shape(support_gt))

       # Query part of the episode.
       query_images = batch.test_images
       print('query', tf.reduce_max(query_images),
             tf.reduce_min(query_images), tf.shape(query_images))
       query_labels = batch.test_labels
       print('query label', query_labels, tf.shape(query_labels))
Пример #3
0
 def test_markov_hierarchy(self):
     """Smoke test: hierarchical sampling with a MarkovSwitchBlender.

     Wraps a CRPSampler and a MarkovSwitchBlender (built with three equal
     weights) in a HierarchicalEpisodeSampler, draws one collection and
     prints its 'support' entry. Nothing is asserted.
     """
     per_cls_cap = 100
     crp = CRPSampler(0)
     uniform_weights = np.ones([3]) / 3.0
     markov_blender = MarkovSwitchBlender(uniform_weights, 0.5, 0)
     hier_sampler = HierarchicalEpisodeSampler(crp, markov_blender, True, 0)
     # The last Dataset argument is reused below as the per-class cap.
     hier_sampler.set_dataset(Dataset(10, 100, per_cls_cap))
     sampling_opts = dict(
         alpha=0.5,
         theta=1.0,
         nstage=3,
         max_num=60,
         max_num_per_cls=per_cls_cap)
     episode = hier_sampler.sample_collection(30, 2, **sampling_opts)
     print(episode['support'])
Пример #4
0
 def test_basic(self):
     """Smoke test: hierarchical sampling with a BlurBlender.

     Samples ten collections in a row from the same
     HierarchicalEpisodeSampler and prints the 'support' entry of the last
     one only. Nothing is asserted.
     """
     per_cls_cap = 100
     crp = CRPSampler(0)
     blur_blender = BlurBlender(window_size=20, stride=5, nrun=10, seed=0)
     hier_sampler = HierarchicalEpisodeSampler(crp, blur_blender, False, 0)
     # The last Dataset argument is reused below as the per-class cap.
     hier_sampler.set_dataset(Dataset(10, 100, per_cls_cap))
     sampling_opts = dict(
         alpha=0.5,
         theta=1.0,
         nstage=5,
         max_num=100,
         max_num_per_cls=per_cls_cap)
     for _ in range(10):
         episode = hier_sampler.sample_collection(50, 2, **sampling_opts)
     # Only the final draw is printed.
     print(episode['support'])
Пример #5
0
class SeqCRPSampler(EpisodeSampler):
  """Episode sampler that concatenates several CRPSampler draws.

  Each stage is sampled independently from one shared CRPSampler, and the
  ids of every stage are shifted so that they start after the largest id
  produced by the stages before it.
  """

  def __init__(self, seed):
    """Set up the base sampler and the wrapped CRPSampler with `seed`."""
    super(SeqCRPSampler, self).__init__(seed)
    self._crp_sampler = CRPSampler(seed)

  def sample_episode_classes(self,
                             n,
                             stages=2,
                             alpha=0.5,
                             theta=1.0,
                             max_num=-1,
                             max_num_per_cls=20):
    """See EpisodeSampler class for documentation.

    Args:
      n: Int. Total size requested over all stages. Must be divisible by
        `stages`; each stage asks the wrapped sampler for `n // stages`.
      stages: Int. Number of sequential stages.
      alpha: Float. Discount parameter, forwarded to the wrapped sampler.
      theta: Float. Strength parameter, forwarded to the wrapped sampler.
      max_num: Int. Maximum number of images, forwarded unchanged.
      max_num_per_cls: Int. Maximum number of images per class, forwarded
        unchanged.

    Returns:
      A flat list with the ids of all stages in order. The ids of each
      stage are offset by the sum of `max + 1` over the preceding stages
      (stages do not overlap provided the wrapped sampler returns
      non-negative ids -- TODO confirm).
    """
    assert n % stages == 0
    per_stage = n // stages
    combined = []
    offset = 0
    for _ in range(stages):
      stage_ids = np.array(
          self._crp_sampler.sample_episode_classes(
              per_stage,
              alpha=alpha,
              theta=theta,
              max_num=max_num,
              max_num_per_cls=max_num_per_cls))
      combined.extend(stage_ids + offset)
      # The next stage starts right after the largest id seen so far.
      offset += stage_ids.max() + 1
    return combined
 def test_basic(self):
   """Smoke test: 100 draws from a SemiSupervisedEpisodeSampler.

   Wraps a CRPSampler, samples 100 collections with label_ratio=0.5 and
   prints the 'support' and 'query' entries of each. Nothing is asserted.
   """
   per_cls_cap = 10
   semi_sampler = SemiSupervisedEpisodeSampler(CRPSampler(0), 0)
   # The second Dataset argument is reused below as the per-class cap.
   semi_sampler.set_dataset(Dataset(20, per_cls_cap))
   for _ in range(100):
     episode = semi_sampler.sample_collection(
         10,
         2,
         alpha=0.3,
         theta=1.0,
         max_num_per_cls=per_cls_cap,
         max_num=40,
         label_ratio=0.5)
     support = episode['support']
     query = episode['query']
     print('Support', support)
     print('Query', query)
 def test_distractor_basic(self):
   """Smoke test: semi-supervised sampling with the nd/sd/md options set.

   Prints the sampler's `cls_dict` once, then samples 100 collections,
   printing the iteration index before each draw. Nothing is asserted.
   """
   per_cls_cap = 10
   semi_sampler = SemiSupervisedEpisodeSampler(CRPSampler(0), 0)
   # The second Dataset argument is reused below as the per-class cap.
   semi_sampler.set_dataset(Dataset(20, per_cls_cap))
   print('start')
   print(semi_sampler.cls_dict)
   for step in range(100):
     print(step)
     episode = semi_sampler.sample_collection(
         10,
         2,
         nd=5,
         sd=3,
         md=2,
         alpha=0.3,
         theta=1.0,
         max_num_per_cls=per_cls_cap,
         max_num=40,
         label_ratio=0.5)
     # The values are not used further; the lookups only verify that both
     # keys are present in the returned collection.
     support = episode['support']
     query = episode['query']
Пример #8
0
 def __init__(self, seed):
   """Initialize the base sampler and create the wrapped CRPSampler.

   Args:
     seed: Passed unchanged to both the parent constructor and the
       internal CRPSampler.
   """
   super(SeqCRPSampler, self).__init__(seed)
   # Internal sampler kept on the instance as `_crp_sampler`.
   self._crp_sampler = CRPSampler(seed)