def make_dequeue_op(self):
  """
  Builds the dequeue op. Which underlying queue is read depends on
  self.train_flag: the train queue when training, the eval queue otherwise.
  The branch is selected at graph-execution time via a conditional.

  :return: dequeued tensors, keyed like the enqueue data
  :rtype: dict[str,tf.Tensor]
  """
  from TFUtil import cond

  def dequeue_train():
    # Branch taken when train_flag is true.
    return self._as_list(self.train_queue.dequeue())

  def dequeue_eval():
    # Branch taken when train_flag is false.
    return self._as_list(self.eval_queue.dequeue())

  dequeued = cond(self.train_flag, dequeue_train, dequeue_eval, name="queue_dequeue")
  return self._as_dict(dequeued)
def cond_on_train(self, fn_train, fn_eval):
  """
  Branched evaluation based on self.train_flag: exactly one of the two
  callables is evaluated (lazily, inside the conditional).

  :param ()->tf.Tensor fn_train: evaluated when the train flag is set
  :param ()->tf.Tensor fn_eval: evaluated otherwise
  :return: fn_train() if self.train_flag else fn_eval()
  :rtype: tf.Tensor
  """
  from TFUtil import cond
  on_train, on_eval = fn_train, fn_eval
  return cond(self.train_flag, on_train, on_eval)
def __init__(self, extern_data, capacity=100, seed=1, with_batch=False, enqueue_data=None):
  """
  Sets up a pair of queues (shuffled for training, FIFO for eval) over the
  data keys of `extern_data`, plus the enqueue/size ops to drive them.

  :param ExternData extern_data: this specifies the data keys
  :param int capacity: max number of elements held by each queue
  :param int seed: seed for the train RandomShuffleQueue
  :param bool with_batch: whether we have the batch-dim in input/output
  :param dict[str,tf.Tensor] enqueue_data: if provided, will be the input;
    otherwise placeholders are created (see self.enqueue_placeholders)
  """
  self.extern_data = extern_data
  # Materialize as a list: dict.keys() is a lazy view in Python 3 and we
  # iterate self.data_keys multiple times below.
  self.data_keys = list(extern_data.data.keys())
  self.with_batch = with_batch
  self.enqueue_data = enqueue_data

  # http://stackoverflow.com/questions/41187745/tensorflow-how-can-i-evaluate-a-validation-data-queue-multiple-times-during-tra/44067467#44067467
  # I.e. we need two separate queues, one for training (RandomShuffleQueue) and one for eval (FIFOQueue),
  # and switch between the dequeue via tf.cond.
  from TFUtil import cond, get_global_train_flag_placeholder
  self.train_flag = get_global_train_flag_placeholder()
  self.names = list(self.data_keys)
  self.dtypes = [self.extern_data.data[key].dtype for key in self.data_keys]
  self.shapes = {
    key: data.batch_shape if with_batch else data.shape
    for (key, data) in self.extern_data.data.items()}
  # Each dynamic axis of each data entry gets its own "<key>/size<axis>" entry.
  for key, data in self.extern_data.data.items():
    for axis in data.get_axes_with_size():
      self.shapes["%s/size%i" % (key, axis)] = (None,) if with_batch else ()
  self.enqueue_placeholders = None
  if not self.enqueue_data:
    # No input tensors given: create placeholders to feed, one per data key
    # plus one per dynamic-size entry.
    self.enqueue_placeholders = {
      key: tf.placeholder(**self.extern_data.data[key].get_placeholder_kwargs(with_batch=with_batch))
      for key in self.data_keys}
    for key in self.data_keys:
      for axis in self.extern_data.data[key].get_axes_with_size():
        name = "%s/size%i" % (key, axis)
        self.names += [name]
        self.dtypes += [self.extern_data.data[key].size_dtype]
        self.enqueue_placeholders[name] = tf.placeholder(
          **self.extern_data.data[key].get_size_placeholder_kwargs(axis, with_batch=with_batch))
    self.enqueue_data = self.enqueue_placeholders

  # TF recommendation: capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size
  self.capacity = capacity
  self.train_queue_min_after_dequeue = int(capacity * 0.8)
  self.train_queue = tf.RandomShuffleQueue(
    capacity=self.capacity, min_after_dequeue=self.train_queue_min_after_dequeue,
    names=self.names, dtypes=self.dtypes,
    seed=seed, name="train_queue")
  self.eval_queue = tf.FIFOQueue(
    capacity=self.capacity, names=self.names, dtypes=self.dtypes, name="eval_queue")
  self.train_queue_size = self.train_queue.size()
  self.eval_queue_size = self.eval_queue.size()
  # For training, only the elements above min_after_dequeue are dequeuable.
  self.dequeue_size_op = cond(
    self.train_flag,
    lambda: self.train_queue.size() - self.train_queue_min_after_dequeue,
    lambda: self.eval_queue.size())
  self.have_more_op = tf.greater(self.dequeue_size_op, 0, name="queue_have_more")
  # Bug fix: this op previously reused the name "queue_have_more" (copy-paste),
  # which TF silently uniquified to "queue_have_more_1" in the graph.
  self.one_more_enqueue_is_enough_op = tf.greater(
    self.dequeue_size_op, -1, name="queue_one_more_enqueue_is_enough")
  self.enqueue_op = cond(
    self.train_flag,
    lambda: self.train_queue.enqueue(self.enqueue_data),
    lambda: self.eval_queue.enqueue(self.enqueue_data),
    name="queue_enqueue")