class DeterministicMetaLossEvaluator(object):
    """Computes the meta loss using the same batches of data for each call.

    Every call made within the same session run reuses the same pre-sampled
    batches. This is used for things like ES, where shared random numbers
    reduce gradient variance.
    """
    def __init__(self,
                 meta_loss_state_fn,
                 inner_loss_state_fn,
                 unroll_n_steps,
                 meta_loss_evals=5):
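        """Pre-samples all data batches used by this evaluator.

        Args (inferred from usage in the body below):
          meta_loss_state_fn: callable returning a nest of tensors holding one
            freshly sampled batch for meta-loss evaluation.
          inner_loss_state_fn: callable returning a nest of tensors holding one
            freshly sampled batch for the inner (training) loss.
          unroll_n_steps: number of inner-training steps per unroll.
          meta_loss_evals: number of meta-loss evaluations per unroll step.
        """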
        # Build one nested structure of TensorArrays for the inner-loop data
        # and additional ones for the extra (meta-loss) evaluations.

        # Each TensorArray is filled up front with all the batches requested;
        # later reads reuse these batches instead of resampling.

        self.meta_loss_evals = meta_loss_evals
        self.unroll_n_steps = unroll_n_steps

        def fill_batches(size, state_fn):
            """Fill a tensor array with batches of data."""
            dummy_batch = state_fn()
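            # One TensorArray per leaf tensor; clear_after_read=False so the
            # same batch can be read again on later meta-loss evaluations.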
            tas = nest.map_structure(
                lambda b: tf.TensorArray(
                    b.dtype, size=size, clear_after_read=False), dummy_batch)

            cond = lambda i, ta: tf.less(i, size)

            def body(i, tas):
                batch = state_fn()
                out_tas = []
                for ta, b in py_utils.eqzip(nest.flatten(tas),
                                            nest.flatten(batch)):
                    out_tas.append(ta.write(i, b))
                return (i + 1, nest.pack_sequence_as(dummy_batch, out_tas))

            _, batches = tf.while_loop(cond, body,
                                       [tf.constant(0, dtype=tf.int32), tas])
            return batches

        self.meta_batches = fill_batches(meta_loss_evals * unroll_n_steps,
                                         meta_loss_state_fn)

        self.inner_batches_eval = fill_batches(
            meta_loss_evals * unroll_n_steps, inner_loss_state_fn)

        self.inner_batches = fill_batches(unroll_n_steps, inner_loss_state_fn)

    def get_batch(self, idx, batches):
        return nest.map_structure(lambda ta: ta.read(idx), batches)

    def get_meta_batch_state(self):
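        """Returns the meta-loss batches, stacked into dense tensors."""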
        def _fn(x):
            return x.stack()

        return nest.map_structure(_fn, self.meta_batches)

    def get_inner_batch_state(self):
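        """Returns the inner-loss batches, stacked into dense tensors."""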
        def _fn(x):
            return x.stack()

        return nest.map_structure(_fn, self.inner_batches)

    def _nest_bimap(self, fn, data1, data2):
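        """Maps `fn` elementwise over two nests with identical structure."""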
        data = py_utils.eqzip(nest.flatten(data1), nest.flatten(data2))
        out = [fn(*a) for a in data]
        return nest.pack_sequence_as(data1, out)

    def get_phi_trajectory(self, learner, inner_batches=None, init_state=None):
        """Computes the inner-parameter (phi) trajectory over one unroll."""
        if inner_batches is None:
            inner_batches = self.inner_batches
        else:
            # convert the batches object to a tensorarray.
            def to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.unroll_n_steps).unstack(t)

            inner_batches = nest.map_structure(to_ta, inner_batches)

        if init_state is None:
            init_state = learner.current_state()
            init_state = tf_utils.force_copy(init_state)

        def body(learner_state, ta, i):
            batch = self.get_batch(i, batches=inner_batches)
            _, next_state = learner.loss_and_next_state(learner_state,
                                                        loss_state=batch)

            # shift because this is the next state.
            next_ta = self._nest_bimap(lambda t, v: t.write(i + 1, v), ta,
                                       next_state)
            return next_state, next_ta, i + 1

        def cond(learner_state, ta, i):  # pylint: disable=unused-argument
            return tf.less(i, self.unroll_n_steps)

        def make_ta(x):
            ta = tf.TensorArray(dtype=x.dtype, size=self.unroll_n_steps + 1)
            return ta.write(0, x)

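        # Unroll the learner for unroll_n_steps steps, recording each state;
        # slot 0 of every TensorArray already holds the initial state.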
        _, ta, _ = tf.while_loop(cond, body, [
            init_state,
            nest.map_structure(make_ta, init_state),
            tf.constant(0)
        ])
        return nest.map_structure(lambda x: x.stack(), ta)

    def get_avg_loss(self, learner, inner_batches=None, init_state=None):
        """Averages the meta loss of `init_state` over the unroll batches."""
        if inner_batches is None:
            inner_batches = self.inner_batches
        else:
            # convert the batches object to a tensorarray.
            def to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.unroll_n_steps).unstack(t)

            inner_batches = nest.map_structure(to_ta, inner_batches)

        if init_state is None:
            init_state = learner.current_state()
            init_state = tf_utils.force_copy(init_state)

        def body(a, i):
            # Assumes one pre-sampled batch per unroll step (batch i for step i).
            batch = self.get_batch(i, batches=inner_batches)
            l = learner.meta_loss(init_state, loss_state=batch)
            return a + l, i + 1

        def cond(_, i):
            return tf.less(i, self.unroll_n_steps)

        a, _ = tf.while_loop(cond, body, [tf.constant(0.0), tf.constant(0)])
        return a / self.unroll_n_steps

    def __call__(self,
                 learner,
                 meta_batches=None,
                 inner_batches=None,
                 init_state=None,
                 unroll_n_steps=None):
        if unroll_n_steps is None:
            unroll_n_steps = self.unroll_n_steps
        else:
            print("Using passed-in unroll_n_steps.")

        if inner_batches is None:
            inner_batches = self.inner_batches
        else:
            # convert the batches object to a tensorarray.
            def to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.unroll_n_steps).unstack(t)

            inner_batches = nest.map_structure(to_ta, inner_batches)

        if meta_batches is None:
            meta_batches = self.meta_batches
        else:
            # convert the batches object to a tensorarray.
            def ml_to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.meta_loss_evals *
                                      self.unroll_n_steps).unstack(t)

            meta_batches = nest.map_structure(ml_to_ta, meta_batches)

        if init_state is None:
            init_state = learner.current_state()
            init_state = tf_utils.force_copy(init_state)

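        # Loop state: (step index, most recent inner loss, learner state).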
        current_state = (tf.constant(0, dtype=tf.int32),
                         tf.constant(0., dtype=tf.float32), init_state)

        def loss_and_next_state_fn(idx_loss_state):
            idx, _, state = idx_loss_state
            batch = self.get_batch(idx, batches=inner_batches)
            l, s = learner.loss_and_next_state(state, loss_state=batch)
            return (idx + 1, l, s)

        def accumulate_fn(step_state, accumulators):
            """Accumulates meta and inner losses for the fold learning process."""
            idx, _, s = step_state
            a_meta, a_inner = accumulators
            cond = lambda i, a: tf.less(i, self.meta_loss_evals)

            def body_meta(i, a):
                # idx - 1: by the time accumulate_fn runs, idx has already been
                # advanced past the step that produced the state s.
                batch = self.get_batch((idx - 1) * self.meta_loss_evals + i,
                                       batches=meta_batches)
                return (i + 1, a + learner.meta_loss(s, loss_state=batch))

            _, extra_losses = tf.while_loop(cond, body_meta, loop_vars=[0, 0.])

            def body_inner(i, a):
                # idx - 1: same offset as in body_meta above.
                batch = self.get_batch((idx - 1) * self.meta_loss_evals + i,
                                       batches=meta_batches)
                return (i + 1, a + learner.inner_loss(s, loss_state=batch))

            _, inner_losses = tf.while_loop(cond,
                                            body_inner,
                                            loop_vars=[0, 0.])

            return a_meta + extra_losses, a_inner + inner_losses

        outs = learning_process.fold_learning_process(
            unroll_n_steps,  # Note: this is the argument, not self.unroll_n_steps.
            loss_and_next_state_fn,
            accumulate_fn=accumulate_fn,
            start_state=current_state,
            accumulator_start_state=(tf.constant(0., dtype=tf.float32),
                                     tf.constant(0., dtype=tf.float32)))
        (_, _, final_state), (meta_loss_sum, _) = outs

        # TODO(lmetz) this should be shifted to compute loss for all but shifted 1
        meta_loss = meta_loss_sum / tf.to_float(unroll_n_steps) / tf.to_float(
            self.meta_loss_evals)

        return meta_loss, final_state
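

# A minimal usage sketch (the `learner`, `sample_meta_batch`, and
# `sample_inner_batch` names below are hypothetical, not part of this module):
#
#   evaluator = DeterministicMetaLossEvaluator(
#       meta_loss_state_fn=sample_meta_batch,    # each call returns a batch nest
#       inner_loss_state_fn=sample_inner_batch,
#       unroll_n_steps=20,
#       meta_loss_evals=5)
#   meta_loss, final_state = evaluator(learner)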


# NOTE: the body below appears to come from a separate, batch-resampling
# meta-loss evaluator whose original header is missing; the function name and
# signature are reconstructed from the free variables used in the body.
def compute_meta_loss(learner, current_state, unroll_n_steps, extra_loss_eval):
    """Computes the meta loss, resampling a fresh batch at every step."""

    def loss_and_next_state_fn(loss_and_state):
        _, state = loss_and_state
        return learner.loss_and_next_state(state)

    def accumulate_fn(loss_and_state, a):
        l, s = loss_and_state
        if extra_loss_eval > 0:
            cond = lambda i, a: tf.less(i, extra_loss_eval)
            body = lambda i, a: (i + 1, a + learner.meta_loss(s))
            _, extra_losses = tf.while_loop(cond, body, loop_vars=[0, 0.])
            return a + extra_losses
        else:
            return a + l

    (_, final_state), training_loss = learning_process.fold_learning_process(
        unroll_n_steps,
        loss_and_next_state_fn,
        accumulate_fn=accumulate_fn,
        start_state=current_state,
        accumulator_start_state=tf.constant(0., dtype=tf.float32),
    )

    meta_loss = training_loss / (tf.to_float(unroll_n_steps) *
                                 tf.to_float(extra_loss_eval))

    return tf.identity(meta_loss), nest.map_structure(tf.identity, final_state)


def make_push_op(learner, ds, failed_push, should_push, to_push, final_state,
                 pre_step_index):
    """Makes the op that pushes gradients and assigns the next state."""
    # This is what pushes gradient tensors to a shared location.
    push = lambda: ds.push_tensors(to_push, pre_step_index)