Exemplo n.º 1
0
 def __init__(self, q, batch_iter):
     """Prepare a daemon producer thread.

     Args:
         q: destination handed in by the caller, stored as ``self.q``
             (presumably a ``queue.Queue`` shared with a consumer; confirm).
         batch_iter: batch source, stored as ``self.batch_iter``.
     """
     # NOTE(review): if threading.Thread is the only base class, the explicit
     # Thread.__init__ call repeats what super().__init__() already did;
     # confirm against the class header before dropping either call.
     super(BatchProducer, self).__init__()
     threading.Thread.__init__(self)
     # Event stored as the stop signal; whoever sets/checks it is not shown here.
     self._stoper = threading.Event()
     self.log = logger.get()
     self.q, self.batch_iter = q, batch_iter
     # Set only after Thread.__init__ has run: the daemon flag may only be
     # assigned on an initialized thread that has not been started yet.
     self.daemon = True
Exemplo n.º 2
0
    def __init__(self,
                 num,
                 batch_size=1,
                 progress_bar=False,
                 log_epoch=10,
                 get_fn=None,
                 cycle=False,
                 shuffle=True,
                 stagnant=False,
                 seed=2,
                 num_batches=-1):
        """Construct a batch iterator over example indices ``0 .. num - 1``.

        Args:
            num: int, total number of examples.
            batch_size: int, number of examples per batch.
            progress_bar: bool, if True create a progress bar sized to the
                number of steps per epoch.
            log_epoch: int, stored as ``self._log_epoch``; presumably a
                logging interval in epochs (TODO confirm against its users).
            get_fn: callable or None, stored as both ``self._get_fn`` and
                ``self.get_fn``; presumably maps a batch of indices to data
                (TODO confirm).
            cycle: bool, stored as ``self._cycle``; presumably whether
                iteration restarts after the last step (TODO confirm).
            shuffle: bool, if True the index order is shuffled once here,
                using the seeded RNG.
            stagnant: bool, stored as ``self._stagnant``; its effect is
                implemented outside this constructor.
            seed: int, seed of the ``numpy.random.RandomState`` used for
                shuffling.
            num_batches: int, if positive, caps the number of steps per
                epoch; non-positive values mean no cap.
        """
        self._num = num
        self._batch_size = batch_size
        self._step = 0
        # Steps per epoch: ceil(num / batch_size), so a trailing partial
        # batch counts as one step.
        self._num_steps = int(np.ceil(self._num / float(batch_size)))
        if num_batches > 0:
            self._num_steps = min(self._num_steps, num_batches)
        self._pb = None
        self._variables = None
        # Stored under both the private and the public name, as before;
        # presumably external code reads the public one.
        self._get_fn = get_fn
        self.get_fn = get_fn
        self._cycle = cycle
        self._shuffle_idx = np.arange(self._num)
        self._shuffle = shuffle
        self._random = np.random.RandomState(seed)
        if shuffle:
            self._random.shuffle(self._shuffle_idx)
        self._shuffle_flag = False
        self._stagnant = stagnant
        self._log_epoch = log_epoch
        self._log = logger.get()
        self._epoch = 0
        if progress_bar:
            self._pb = pb.get(self._num_steps)
        # Lock for iteration state; presumably guards against concurrent
        # access from fetcher threads (confirm against its users).
        self._mutex = threading.Lock()
Exemplo n.º 3
0
 def __init__(self,
              batch_iter,
              max_queue_size=10,
              num_threads=5,
              log_queue=20,
              name=None):
     """Wrap a batch iterator so that batches are fetched concurrently.

     Args:
         batch_iter: underlying batch iterator, stored as ``self.batch_iter``.
         max_queue_size: int, capacity of the bounded ``queue.Queue``
             created here.
         num_threads: int, stored as ``self.num_threads``; presumably the
             number of fetchers ``init_fetchers`` creates (confirm).
         log_queue: int, stored as ``self.log_queue``; used outside this
             constructor.
         name: optional name, stored as ``self.name``.
     """
     super(ConcurrentBatchIterator, self).__init__()
     # Everything in this group exists before init_fetchers() runs.
     self.log = logger.get()
     self.batch_iter = batch_iter
     self.fetchers = []
     self.max_queue_size, self.num_threads = max_queue_size, num_threads
     self.q = queue.Queue(maxsize=max_queue_size)
     self.init_fetchers()
     # NOTE(review): the attributes below are assigned only after
     # init_fetchers(); unless a base class defines them, any code it runs
     # (or any thread it starts) cannot read them yet. Confirm this ordering
     # is intentional.
     self.name = name
     self.log_queue = log_queue
     self._stopped = False
     self.relaunch = True
     self.counter = 0
Exemplo n.º 4
0
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from fewshot.utils import logger

# Module-level logger handle from fewshot.utils.logger; get_model() logs through it.
log = logger.get()

# Global lookup table: registered model name -> class, filled by @RegisterModel.
MODEL_REGISTRY = dict()


def RegisterModel(model_name):
    """Build a decorator that records its target under ``model_name``.

    The decorated object is stored in ``MODEL_REGISTRY`` and returned
    unchanged. Registering the same name again replaces the earlier entry.
    """
    def _register(cls):
        MODEL_REGISTRY.update({model_name: cls})
        return cls
    return _register


def get_model(model_name, *args, **kwargs):
    """Instantiate the class registered under ``model_name``.

    Args:
        model_name: key previously passed to ``RegisterModel``.
        *args, **kwargs: forwarded unchanged to the registered class.

    Returns:
        Whatever the registered callable returns (normally a model instance).

    Raises:
        ValueError: if nothing is registered under ``model_name``.
    """
    log.info("Model {}".format(model_name))
    if model_name not in MODEL_REGISTRY:
        raise ValueError("Model class does not exist {}".format(model_name))
    model_cls = MODEL_REGISTRY[model_name]
    return model_cls(*args, **kwargs)