def map_reduce(last_elem: tf.Tensor, data: tf.data.Dataset, map_fn, reduce_fn=tf.add, **kwargs):
    """Map `map_fn` over a dataset and fold the results with `reduce_fn`.

    Batches are consumed until `last_elem` appears at `idx[-1]`.
    TODO: remove 'last_elem' as soon as tensorflow iterators support some
    `has_next` functionality

    :param last_elem: sentinel index value marking the final batch
    :param data: tf.data.Dataset containing `(idx, val)` with idx as a vector
        of shape `(batch_size,)`
    :param map_fn: function taking arguments `(idx, val)`
    :param reduce_fn: function taking two return values of `map_fn` and
        reducing them into one return value
    :param kwargs: additional arguments passed to the `tf.while loop`
    :return: the value accumulated over all batches
    """
    it = data.make_initializable_iterator()

    def _not_done(idx, _val):
        # Keep looping until the sentinel shows up in the last slot of idx.
        return tf.not_equal(tf.gather(idx, tf.size(idx) - 1), last_elem)

    def _step(_prev_idx, acc):
        nxt_idx, nxt_val = it.get_next()
        return nxt_idx, reduce_fn(acc, map_fn(nxt_idx, nxt_val))

    def _first_batch():
        idx0, val0 = it.get_next()
        return idx0, map_fn(idx0, val0)

    # Seed the loop with the first batch; the iterator must be initialized
    # before any get_next op runs.
    with tf.control_dependencies([it.initializer]):
        _, accumulated = tf.while_loop(_not_done, _step, _first_batch(), **kwargs)
    return accumulated
def get_input_nodes(dataset: tf.data.Dataset) -> Dict[str, str]:
    """Map each input feature name to the name of its graph node.

    Args:
        dataset: A `tf.data.Dataset` instance whose elements are (or start
            with) a dict of named feature tensors.

    Returns:
        Dictionary from input feature-name to the corresponding node in the
        graph.

    Raises:
        TypeError: If the dataset's input element is not a dict of tensors.
    """
    dataset_iter = dataset.make_initializable_iterator()
    features = dataset_iter.get_next()
    if isinstance(features, tuple):  # thus including input, target, etc.
        features = features[0]  # input only, assuming that input is the first.
    # BUG FIX: the original condition was inverted — it raised when `features`
    # WAS a dict (the expected case) and fell through to `.items()` on
    # non-dict inputs, which would fail with AttributeError instead.
    if not isinstance(features, dict):
        raise TypeError(
            'Expected the dataset input element to be a dict of named '
            'feature tensors, got %s' % type(features))
    input_nodes = {}
    for input_name, tensor in features.items():
        input_nodes[input_name] = tensor.name
    return input_nodes
def fit_dataset(self, dataset: tf.data.Dataset):
    """Train each stacked layer greedily on batches drawn from `dataset`.

    For each of the 5 layers, the dataset is re-initialized and iterated to
    exhaustion; every batch is trained for `self.epochs` steps, progress is
    logged, summaries are written every 25 batches, and a checkpoint is saved
    after the layer finishes.

    :param dataset: unbatched tf.data.Dataset of input examples; batching and
        prefetching with `self.default_batch_size` are applied here.
    """
    dataset = dataset.batch(self.default_batch_size).prefetch(
        self.default_batch_size)
    iterator = dataset.make_initializable_iterator()
    # BUG FIX: build the fetch ops ONCE. The original called
    # iterator.get_next() and tf.stack(batch) inside the while-loop, which in
    # graph mode adds new ops to the graph on every batch (unbounded graph
    # growth / memory leak). Worse, get_next() itself never raises
    # OutOfRangeError — only sess.run of the op does — so the original loop
    # could never terminate via the except clause as written.
    next_batch = iterator.get_next()
    stack_batch_op = tf.stack(next_batch)
    with tf.Session() as self._sess:
        for i in range(5):
            self._sess.run(iterator.initializer)
            logging.info('Fitting layer %d', i)
            batch_n = 0
            self._load_or_init_session()
            while True:
                try:
                    # Running the pre-built op raises OutOfRangeError once
                    # the dataset is exhausted for this layer.
                    stacked_batch = self._sess.run(stack_batch_op)
                except tf.errors.OutOfRangeError:
                    break
                for step in range(self.epochs):
                    feed_dict = {
                        self._x0: stacked_batch,
                        self._corruption_level: self.corruption_level
                    }
                    self._sess.run(self.train_steps[i], feed_dict=feed_dict)
                    self._log_progress(batch_n, step, stacked_batch, i)
                batch_n += 1
                if (batch_n + 1) % 25 == 0:
                    self._write_summaries(stacked_batch)
            logging.info(
                'Saving trained params to %s with global_step %s'
                % (self.checkpoint_file, self.global_step))
            self._saver.save(self._sess, self.checkpoint_file,
                             global_step=self.global_step)