def get_avg_loss(self, learner, inner_batches=None, init_state=None):
  """Compute the average meta-loss of a fixed learner state over the unroll.

  The learner state is NOT advanced: for each of `self.unroll_n_steps` loop
  iterations a batch is fetched and `learner.meta_loss` is evaluated on the
  same (copied) `init_state`. The per-iteration losses are summed and divided
  by `self.unroll_n_steps`.

  Args:
    learner: Learner instance providing `current_state()` and `meta_loss()`.
    inner_batches: optional nested structure of tensors. Each leaf is
      unstacked along its leading axis into a TensorArray of size
      `self.unroll_n_steps`. When None (the default, consistent with
      `get_phi_trajectory`), `self.inner_batches` is used.
    init_state: optional learner state to evaluate. When None (the default),
      `learner.current_state()` is used.

  Returns:
    A scalar float32 tensor: the mean of the per-iteration meta-losses. (The
    while_loop accumulator starts as the scalar `tf.constant(0.0)`, so each
    `meta_loss` must be a scalar float32 for the loop to be well-formed.)
  """
  if inner_batches is None:
    inner_batches = self.inner_batches
  else:
    # Convert the passed-in batches to TensorArrays indexable by step.
    def to_ta(t):
      return tf.TensorArray(dtype=t.dtype,
                            size=self.unroll_n_steps).unstack(t)

    inner_batches = nest.map_structure(to_ta, inner_batches)

  if init_state is None:
    init_state = learner.current_state()
  # NOTE(review): force_copy is assumed to snapshot the state so the loop
  # reads a fixed value — confirm against tf_utils.
  init_state = tf_utils.force_copy(init_state)

  def body(acc, i):
    # NOTE(review): batches are fetched at a stride of batches_per_step.
    # If get_batch indexes the TensorArrays built above directly, this stays
    # in range for passed-in batches only when batches_per_step == 1 —
    # confirm.
    batch = self.get_batch(i * self.batches_per_step, batches=inner_batches)
    loss = learner.meta_loss(init_state, loss_state=batch)
    return acc + loss, i + 1

  def cond(_, i):
    return tf.less(i, self.unroll_n_steps)

  total, _ = tf.while_loop(cond, body, [tf.constant(0.0), tf.constant(0)])
  return total / self.unroll_n_steps
def build_evaluation_graph(loss_module_fn=gin.REQUIRED,
                           learner_fn=gin.REQUIRED):
  """Build the evaluation graph for inner-training.

  Builds the loss module and learner from the gin-configured factories, then
  creates ops that advance the learner by one inner step and report losses
  on the current state.

  Args:
    loss_module_fn: gin-configured callable returning a loss module.
    learner_fn: gin-configured callable taking the loss module and returning
      a `(learner, theta_mod)` pair.

  Returns:
    dict with keys:
      "train_op": op (named "train_op") that assigns the learner's next state
        and increments the global step.
      "global_step": the global step variable.
      "init_op": op that assigns `learner.initial_state()` to the learner.
      "checkpoint_vars": list of `theta_mod`'s GLOBAL_VARIABLES.
      "meta_loss": `learner.meta_loss` of the current state (wrapped in
        tf.Print).
      "train_loss": `learner.inner_loss` of the current state (wrapped in
        tf.Print).
      "current_state": copy of the learner's current state.
      "next_state": state produced by `learner.loss_and_next_state`.
  """
  global_step = tf.train.get_or_create_global_step()
  loss_module = loss_module_fn()
  learner, theta_mod = learner_fn(loss_module)

  initial_state = learner.initial_state()
  reset_state_op = learner.assign_state(initial_state)

  state = learner.current_state()
  state = tf_utils.force_copy(state)

  # Compute the step only from the materialized copy of the state.
  with tf.control_dependencies(nest.flatten(state)):
    last_loss, new_state = learner.loss_and_next_state(state)

  # Write the new state and bump the step only after the loss and next state
  # have been computed.
  with tf.control_dependencies([last_loss] + nest.flatten(new_state)):
    assign_op = learner.assign_state(new_state)
    update_global_step = global_step.assign_add(1)
    # Group exactly once: re-wrapping this op in another tf.group only adds
    # a NoOp, and TF would uniquify that NoOp's name away from "train_op".
    train_op = tf.group(assign_op, update_global_step, name="train_op")

  load_vars = list(theta_mod.get_variables(tf.GraphKeys.GLOBAL_VARIABLES))

  meta_loss = learner.meta_loss(state)
  meta_loss = tf.Print(meta_loss, [meta_loss], "meta_loss")

  inner_loss = learner.inner_loss(state)
  inner_loss = tf.Print(inner_loss, [inner_loss], "inner_loss")

  # TODO(lmetz) this should only be scalars.
  return {
      "train_op": train_op,
      "global_step": global_step,
      "init_op": reset_state_op,
      "checkpoint_vars": load_vars,
      "meta_loss": meta_loss,
      "train_loss": inner_loss,
      "current_state": state,
      "next_state": new_state,
  }
def get_phi_trajectory(self, learner, inner_batches=None, init_state=None):
  """Unroll the learner and record every inner-parameter state it visits.

  Starting from `init_state`, applies `learner.loss_and_next_state` for
  `self.unroll_n_steps` iterations of a `tf.while_loop`, one batch per
  iteration, storing each state in TensorArrays.

  Args:
    learner: Learner instance.
    inner_batches: optional nested structure of tensors; each leaf is
      unstacked along its leading axis into a TensorArray of size
      `self.unroll_n_steps`. Defaults to `self.inner_batches`.
    init_state: optional starting learner state. Defaults to
      `learner.current_state()`.

  Returns:
    A nested structure matching the learner state in which every leaf is the
    stacked trajectory: entry 0 is the starting state and entry k is the state
    after k steps (`self.unroll_n_steps + 1` entries in total).
  """
  if inner_batches is not None:
    # Turn each passed-in batch tensor into a TensorArray indexed by step.
    inner_batches = nest.map_structure(
        lambda t: tf.TensorArray(
            dtype=t.dtype, size=self.unroll_n_steps).unstack(t),
        inner_batches)
  else:
    inner_batches = self.inner_batches

  start_state = learner.current_state() if init_state is None else init_state
  start_state = tf_utils.force_copy(start_state)

  def step_fn(state, trajectory, step):
    batch = self.get_batch(step, batches=inner_batches)
    _, next_state = learner.loss_and_next_state(state, loss_state=batch)
    # Slot 0 already holds the starting state, so the state reached after
    # this iteration belongs in slot `step + 1`.
    trajectory = self._nest_bimap(
        lambda arr, value: arr.write(step + 1, value), trajectory, next_state)
    return next_state, trajectory, step + 1

  def keep_going(state, trajectory, step):  # pylint: disable=unused-argument
    return tf.less(step, self.unroll_n_steps)

  def seed_array(x):
    arr = tf.TensorArray(dtype=x.dtype, size=self.unroll_n_steps + 1)
    return arr.write(0, x)

  loop_vars = [
      start_state,
      nest.map_structure(seed_array, start_state),
      tf.constant(0),
  ]
  _, trajectory, _ = tf.while_loop(keep_going, step_fn, loop_vars)
  return nest.map_structure(lambda arr: arr.stack(), trajectory)
def compute_meta_loss(learner, unroll_n_steps, init_state=None, extra_loss_eval=5):
  """Helper function to compute the training objective.

  This function is meant to unroll `unroll_n_steps` and accumulate the loss.
  Additionally, to lower variance, at each new state `extra_loss_eval` extra
  meta-losses are computed. Note that in `accumulate_fn`, when
  `extra_loss_eval > 0` the sum of those extra meta-losses is added to the
  accumulator INSTEAD of the step's own loss `l`; `l` is only used when
  `extra_loss_eval <= 0`.

  TODO(lmetz) a more rigorous analysis of variance of gradients to pick these
  parameters.

  NOTE(review): as visible here, the body only defines `current_state`,
  `loss_and_next_state_fn` and `accumulate_fn`; it never runs the unroll and
  has no `return` statement, so it returns None despite the Returns section
  below. It looks truncated — confirm against the original source.

  NOTE(review): `lambda (l, state): ...` and `def accumulate_fn((l, s), a)`
  use tuple-parameter unpacking, which is Python 2-only syntax (removed in
  Python 3 by PEP 3113).

  Args:
    learner: Learner instance
    unroll_n_steps: number of steps to unroll
    init_state: initial LearnerState; defaults to `learner.current_state()`.
    extra_loss_eval: int; number of extra `learner.meta_loss` evaluations
      summed per state (default 5).

  Returns:
    meta_loss, final LearnerState
  """
  if init_state is None:
    init_state = learner.current_state()
  init_state = tf_utils.force_copy(init_state)

  # (accumulated loss, learner state) pair: the loss accumulator starts at a
  # float32 zero.
  current_state = (tf.constant(0., dtype=tf.float32), init_state)

  # Maps an (loss, state) pair to learner.loss_and_next_state(state); the
  # incoming loss is ignored.
  loss_and_next_state_fn = lambda (l, state): learner.loss_and_next_state(
      state)

  def accumulate_fn((l, s), a):
    # Adds this step's contribution to the running accumulator `a`.
    if extra_loss_eval > 0:
      # Inner loop: sum `extra_loss_eval` evaluations of learner.meta_loss(s),
      # starting from 0. (The lambdas' own `a` shadows the outer accumulator.)
      cond = lambda i, a: tf.less(i, extra_loss_eval)
      body = lambda i, a: (i + 1, a + learner.meta_loss(s))
      _, extra_losses = tf.while_loop(cond, body, loop_vars=[0, 0.])
      return a + extra_losses
    else:
      return a + l
def __call__(self, learner, meta_batches=None, inner_batches=None, init_state=None, unroll_n_steps=None): if unroll_n_steps is None: unroll_n_steps = self.unroll_n_steps else: print("Using passed in unroll steps") if inner_batches is None: inner_batches = self.inner_batches else: # convert the batches object to a tensorarray. def to_ta(t): return tf.TensorArray(dtype=t.dtype, size=self.unroll_n_steps).unstack(t) inner_batches = nest.map_structure(to_ta, inner_batches) if meta_batches is None: meta_batches = self.meta_batches else: # convert the batches object to a tensorarray. def ml_to_ta(t): return tf.TensorArray(dtype=t.dtype, size=self.meta_loss_evals * self.unroll_n_steps).unstack(t) meta_batches = nest.map_structure(ml_to_ta, meta_batches) if init_state is None: init_state = learner.current_state() init_state = tf_utils.force_copy(init_state) current_state = (tf.constant(0, dtype=tf.int32), tf.constant(0., dtype=tf.float32), init_state) def loss_and_next_state_fn((idx, l, state)): batch = self.get_batch(idx, batches=inner_batches) l, s = learner.loss_and_next_state(state, loss_state=batch) return (idx + 1, l, s) def accumulate_fn((idx, _, s), (a_meta, a_inner)): """Accumulate loss for fold learning process.""" cond = lambda i, a: tf.less(i, self.meta_loss_evals) def body_meta(i, a): # minus 1 as this takes the following step. batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i, batches=meta_batches) return (i + 1, a + learner.meta_loss(s, loss_state=batch)) _, extra_losses = tf.while_loop(cond, body_meta, loop_vars=[0, 0.]) def body_inner(i, a): # minus 1 as this takes the following step. batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i, batches=meta_batches) return (i + 1, a + learner.inner_loss(s, loss_state=batch)) _, inner_losses = tf.while_loop(cond, body_inner, loop_vars=[0, 0.]) return a_meta + extra_losses, a_inner + inner_losses