Example #1
import tensorflow as tf


def map_reduce(last_elem: tf.Tensor, data: tf.data.Dataset, map_fn, reduce_fn=tf.add, **kwargs):
    """
    Iterate over elements in a tf.data.Dataset.
    Fetches new elements until "last_elem" appears at `idx[-1]`.

    TODO: remove 'last_elem' as soon as tensorflow iterators support some `has_next` functionality

    :param last_elem: the last element
    :param data: tf.data.Dataset containing `(idx, val)` with idx as a vector of shape `(batch_size,)`
    :param map_fn: function taking arguments `(idx, val)`
    :param reduce_fn: function taking two return values of `map_fn` and reducing them into one return value
    :param kwargs: additional arguments passed to the `tf.while loop`
    :return:
    """
    iterator = data.make_initializable_iterator()

    def cond(idx, val):
        # Keep looping as long as the last index of the current batch is not
        # `last_elem`.
        return tf.not_equal(tf.gather(idx, tf.size(idx) - 1), last_elem)

    def body_fn(old_idx, old_val):
        # Fetch the next batch, map it, and fold it into the running value.
        idx, val = iterator.get_next()

        return idx, reduce_fn(old_val, map_fn(idx, val))

    def init_vals():
        # Seed the loop variables with the first batch.
        idx, val = iterator.get_next()
        return idx, map_fn(idx, val)

    # Make sure the iterator is initialized before the loop starts fetching.
    with tf.control_dependencies([iterator.initializer]):
        _, reduced = tf.while_loop(cond, body_fn, init_vals(), **kwargs)

    return reduced
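
As a quick usage sketch (not part of the original snippet): summing the squares of all values in a toy two-batch dataset, under TF 1.x graph-mode assumptions:

# Toy data; every name and value below is an assumption for illustration.
idx = tf.constant([[0, 1], [2, 3]], dtype=tf.int64)
val = tf.constant([[1.0, 2.0], [3.0, 4.0]])
data = tf.data.Dataset.from_tensor_slices((idx, val))

total = map_reduce(
    last_elem=tf.constant(3, dtype=tf.int64),  # index of the final element
    data=data,
    map_fn=lambda i, v: tf.reduce_sum(tf.square(v)),
)

with tf.Session() as sess:
    print(sess.run(total))  # 1 + 4 + 9 + 16 = 30.0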
Example #2
from typing import Dict

import tensorflow as tf


def get_input_nodes(dataset: tf.data.Dataset) -> Dict[str, str]:
    """
  Args:
    dataset: A `tf.data.Dataset` instance.

  Returns:
    Dictionary from input feature-name to the corresponding node in the graph.
  """
    dataset_iter = dataset.make_initializable_iterator()
    features = dataset_iter.get_next()

    if isinstance(features, tuple):  # e.g. an (input, target, ...) tuple.
        features = features[0]  # input only, assuming that input comes first.

    if not isinstance(features, dict):
        raise TypeError(
            'Expected the dataset to yield a dict of input features, '
            'got %s' % type(features))

    input_nodes = {}
    for input_name, tensor in features.items():
        input_nodes[input_name] = tensor.name
    return input_nodes
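
For illustration, a minimal call in TF 1.x graph mode; the feature names and values below are assumptions, not from the original source:

# Hypothetical input; the keys 'age' and 'income' are made up.
dataset = tf.data.Dataset.from_tensor_slices({
    'age': [21.0, 35.0],
    'income': [50000.0, 82000.0],
})
input_nodes = get_input_nodes(dataset)
# e.g. {'age': 'IteratorGetNext:0', 'income': 'IteratorGetNext:1'}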
Example #3
    def fit_dataset(self, dataset: tf.data.Dataset):
        dataset = dataset.batch(self.default_batch_size).prefetch(
            self.default_batch_size)
        iterator = dataset.make_initializable_iterator()
        # Build the fetch ops once, outside the loops; calling `get_next()`
        # and `tf.stack()` per batch would add new nodes to the graph on
        # every iteration and leak memory.
        next_batch = iterator.get_next()
        stack_batch_op = tf.stack(next_batch)
        with tf.Session() as self._sess:
            for i in range(5):  # one pass per stacked layer
                self._sess.run(iterator.initializer)
                logging.info('Fitting layer %d', i)
                batch_n = 0
                self._load_or_init_session()
                while True:
                    try:
                        stacked_batch = self._sess.run(stack_batch_op)

                        for step in range(self.epochs):
                            feed_dict = {
                                self._x0: stacked_batch,
                                self._corruption_level: self.corruption_level
                            }
                            self._sess.run(self.train_steps[i],
                                           feed_dict=feed_dict)
                            self._log_progress(batch_n, step, stacked_batch, i)

                        batch_n += 1

                        if (batch_n + 1) % 25 == 0:
                            self._write_summaries(stacked_batch)

                    except tf.errors.OutOfRangeError:
                        break

                logging.info(
                    'Saving trained params to %s with global_step %s',
                    self.checkpoint_file, self.global_step)
                self._saver.save(self._sess,
                                 self.checkpoint_file,
                                 global_step=self.global_step)
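
A hypothetical driver for the method above; the `StackedDAE` class name and its constructor are assumptions, since the snippet shows only `fit_dataset`:

# Hypothetical usage; `StackedDAE` is assumed, only `fit_dataset` above
# comes from the snippet.
data = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([1000, 784]))  # e.g. 1000 flattened 28x28 images
model = StackedDAE()
model.fit_dataset(data)  # batches the data and fits one layer at a time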