コード例 #1
0
ファイル: TFDataPipeline.py プロジェクト: TarrySingh/returnn
 def make_dequeue_op(self):
   """
   Build the dequeue op: pulls one entry from the train queue when the
   train flag is set, otherwise from the eval queue (branched via cond).

   :return: dequeued tensors as a dict, keyed like the queue names
   :rtype: dict[str,tf.Tensor]
   """
   from TFUtil import cond
   # Bind both branches to names first; cond() selects one at runtime.
   dequeue_train = lambda: self._as_list(self.train_queue.dequeue())
   dequeue_eval = lambda: self._as_list(self.eval_queue.dequeue())
   branched = cond(self.train_flag, dequeue_train, dequeue_eval, name="queue_dequeue")
   return self._as_dict(branched)
コード例 #2
0
ファイル: TFNetwork.py プロジェクト: ZhangAustin/returnn
  def cond_on_train(self, fn_train, fn_eval):
    """
    Evaluates fn_train() when self.train_flag is set, fn_eval() otherwise.
    The evaluation is branched: only the selected branch runs.

    :param ()->tf.Tensor fn_train:
    :param ()->tf.Tensor fn_eval:
    :return: fn_train() if self.train_flag else fn_eval()
    :rtype: tf.Tensor
    """
    from TFUtil import cond
    branched = cond(self.train_flag, fn_train, fn_eval)
    return branched
コード例 #3
0
ファイル: TFDataPipeline.py プロジェクト: TarrySingh/returnn
  def __init__(self, extern_data, capacity=100, seed=1, with_batch=False, enqueue_data=None):
    """
    Builds a two-queue input pipeline: a RandomShuffleQueue for training and a
    FIFOQueue for evaluation. Enqueue and size/dequeue ops switch between the
    two queues via cond() on the global train-flag placeholder.

    :param ExternData extern_data: this specifies the data keys
    :param int capacity: max number of entries each queue can hold
    :param int seed: seed for the RandomShuffleQueue shuffling
    :param bool with_batch: whether we have the batch-dim in input/output
    :param dict[str,tf.Tensor] enqueue_data: if provided, will be the input;
      otherwise placeholders are created (self.enqueue_placeholders)
    """
    self.extern_data = extern_data
    # Materialize as a list so the key order is fixed and independent of any
    # later mutation of the underlying dict (py3 dict views are live).
    self.data_keys = list(extern_data.data.keys())
    self.with_batch = with_batch
    self.enqueue_data = enqueue_data

    # http://stackoverflow.com/questions/41187745/tensorflow-how-can-i-evaluate-a-validation-data-queue-multiple-times-during-tra/44067467#44067467
    # I.e. we need two separate queues, one for training (RandomShuffleQueue) and one for eval (FIFOQueue),
    # and switch between the dequeue via tf.cond.
    from TFUtil import cond, get_global_train_flag_placeholder
    self.train_flag = get_global_train_flag_placeholder()
    self.names = list(self.data_keys)
    self.dtypes = [self.extern_data.data[key].dtype for key in self.data_keys]
    # Static shapes per queue component; size entries for dynamic axes are added below.
    self.shapes = {
      key: data.batch_shape if with_batch else data.shape
      for (key, data) in self.extern_data.data.items()}
    for key, data in self.extern_data.data.items():
      for axis in data.get_axes_with_size():
        # One scalar (or vector, with batch-dim) per dynamic axis, e.g. "data/size0".
        self.shapes["%s/size%i" % (key, axis)] = (None,) if with_batch else ()

    self.enqueue_placeholders = None
    if not self.enqueue_data:
      # No external input given: create feedable placeholders for all data keys
      # plus one per dynamic axis, and register the size entries in names/dtypes
      # so the queues carry them as components.
      # NOTE(review): when enqueue_data IS provided, the size keys are never added
      # to self.names/self.dtypes although self.shapes has them — confirm that the
      # external-input path intentionally omits size components.
      self.enqueue_placeholders = {
        key: tf.placeholder(**self.extern_data.data[key].get_placeholder_kwargs(with_batch=with_batch))
        for key in self.data_keys}
      for key in self.data_keys:
        for axis in self.extern_data.data[key].get_axes_with_size():
          name = "%s/size%i" % (key, axis)
          self.names += [name]
          self.dtypes += [self.extern_data.data[key].size_dtype]
          self.enqueue_placeholders[name] = tf.placeholder(
            **self.extern_data.data[key].get_size_placeholder_kwargs(axis, with_batch=with_batch))
      self.enqueue_data = self.enqueue_placeholders

    # TF recommendation: capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size
    self.capacity = capacity
    self.train_queue_min_after_dequeue = int(capacity * 0.8)
    self.train_queue = tf.RandomShuffleQueue(
      capacity=self.capacity, min_after_dequeue=self.train_queue_min_after_dequeue,
      names=self.names, dtypes=self.dtypes,
      seed=seed, name="train_queue")
    self.eval_queue = tf.FIFOQueue(
      capacity=self.capacity, names=self.names, dtypes=self.dtypes, name="eval_queue")
    self.train_queue_size = self.train_queue.size()
    self.eval_queue_size = self.eval_queue.size()
    # Number of entries available for dequeue under the current train flag.
    # For the shuffle queue, entries below min_after_dequeue cannot be dequeued,
    # so they are subtracted (this can go negative).
    self.dequeue_size_op = cond(
      self.train_flag,
      lambda: self.train_queue.size() - self.train_queue_min_after_dequeue,
      lambda: self.eval_queue.size())
    self.have_more_op = tf.greater(self.dequeue_size_op, 0, name="queue_have_more")
    # Fixed: this op had the same name as have_more_op ("queue_have_more"),
    # a copy-paste duplicate. True when one more enqueue makes a dequeue possible.
    self.one_more_enqueue_is_enough_op = tf.greater(
      self.dequeue_size_op, -1, name="queue_one_more_enqueue_is_enough")
    self.enqueue_op = cond(
      self.train_flag,
      lambda: self.train_queue.enqueue(self.enqueue_data),
      lambda: self.eval_queue.enqueue(self.enqueue_data),
      name="queue_enqueue")