Example 1
    def init_policy_train_op(self, loss_policy, loss_policy_sampled, wd_dict):
        if self.config.use_adam:
            self.stepsize = tf.Variable(np.float32(np.array(1e-4)),
                                        dtype=tf.float32)
            self.updates = tf.train.AdamOptimizer(
                self.stepsize).minimize(loss_policy)
            self.queue_runner = None
        elif self.config.use_sgd:
            self.stepsize = tf.Variable(np.float32(np.array(self.config.lr)),
                                        dtype=tf.float32)
            self.updates = tf.train.MomentumOptimizer(
                self.stepsize * (1. - self.config.mom),
                self.config.mom).minimize(loss_policy)
            self.queue_runner = None
        else:
            self.stepsize = tf.Variable(np.float32(np.array(self.config.lr)),
                                        dtype=tf.float32)
            self.updates, self.queue_runner = kfac.KfacOptimizer(
                learning_rate=self.stepsize,
                cold_lr=self.stepsize / 3.,
                momentum=self.config.mom,
                clip_kl=self.config.kl_desired,
                upper_bound_kl=self.config.upper_bound_kl,
                epsilon=self.config.epsilon,
                stats_decay=self.config.stats_decay,
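                # NOTE: `async` became a reserved keyword in Python 3.7, so this
                # keyword argument only parses under Python 2 / Python <= 3.6
                # (or a kfac build that renames the parameter).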
                async=self.config.async_kfac,
                kfac_update=self.config.kfac_update,
                cold_iter=self.config.cold_iter,
                weight_decay_dict=wd_dict).minimize(loss_policy,
                                                    loss_policy_sampled,
                                                    self.policy_var_list)

        return self.updates, self.queue_runner
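When the K-FAC branch above is taken with asynchronous stats updates, the returned `queue_runner` has to be driven explicitly; the Adam and SGD branches return `None` for it. A minimal sketch of the calling pattern, mirroring Example 9 below (`train_op`, `queue_runner`, `sess`, `num_updates` and `feed` are assumed to exist and are not defined in this example):

import tensorflow as tf

coord = tf.train.Coordinator()
threads = []
if queue_runner is not None:
    # Start the K-FAC stats/inverse queue threads before training.
    threads = queue_runner.create_threads(sess, coord=coord, start=True)
try:
    for _ in range(num_updates):
        sess.run(train_op, feed_dict=feed)
finally:
    coord.request_stop()
    coord.join(threads)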
Example 2
def _make_distributed_train_op(task_id, num_worker_tasks, num_ps_tasks,
                               layer_collection):
    """Creates optimizer and distributed training op.

  Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
  the train op.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.

  Returns:
    sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
      optimizer.
    optimizer: Instance of `KfacOptimizer`.
    global_step: `tensor`, Global step.
  """
    tf.logging.info("Task id : %d", task_id)
    with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
        global_step = tf.train.get_or_create_global_step()
        optimizer = kfac.KfacOptimizer(learning_rate=0.0001,
                                       cov_ema_decay=0.95,
                                       damping=0.001,
                                       layer_collection=layer_collection,
                                       momentum=0.9)
        sync_optimizer = tf.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
            total_num_replicas=num_worker_tasks)
        return sync_optimizer, optimizer, global_step
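The returned `sync_optimizer` still needs to produce a train op and a session hook before training can start; a minimal sketch of that wiring, in the spirit of Example 10 below (the `loss` tensor, the `master` address and the chief-worker test are assumptions, not part of this example):

sync_optimizer, optimizer, global_step = _make_distributed_train_op(
    task_id, num_worker_tasks, num_ps_tasks, layer_collection)
train_op = sync_optimizer.minimize(loss, global_step=global_step)

is_chief = (task_id == 0)
hooks = [sync_optimizer.make_session_run_hook(is_chief)]
with tf.train.MonitoredTrainingSession(master=master, is_chief=is_chief,
                                       hooks=hooks) as sess:
    while not sess.should_stop():
        sess.run(train_op)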
Example 3
    def init_vf_train_op(self, loss_vf, loss_vf_sampled, wd_dict):
        if self.config.use_adam_vf:
            # 0.001
            self.update_op = tf.train.AdamOptimizer(
                learning_rate=0.001).minimize(loss_vf)
            self.queue_runner = None
        elif self.config.use_sgd_vf:
            # 0.001*(1.-0.9), 0.9
            self.update_op = tf.train.MomentumOptimizer(
                0.001 * (1. - 0.9), 0.9).minimize(loss_vf)
            self.queue_runner = None
        else:
            self.update_op, self.queue_runner = kfac.KfacOptimizer(
                learning_rate=self.config.lr_vf,
                cold_lr=self.config.lr_vf / 3.,
                momentum=self.config.mom_vf,
                clip_kl=self.config.kl_desired_vf,
                upper_bound_kl=False,
                epsilon=self.config.epsilon_vf,
                stats_decay=self.config.stats_decay_vf,
                async=self.config.async_kfac,
                kfac_update=self.config.kfac_update_vf,
                cold_iter=self.config.cold_iter_vf,
                weight_decay_dict=wd_dict).minimize(loss_vf, loss_vf_sampled,
                                                    self.var_list)

        with tf.control_dependencies([self.update_op]):
            self.train = tf.group(self.update_averages)

        return self.train, self.queue_runner
Example 4
    def __init__(self, ob_dim, ac_dim):  #pylint: disable=W0613
        X = tf.placeholder(tf.float32,
                           shape=[None, ob_dim * 2 + ac_dim * 2 + 2
                                  ])  # batch of observations
        vtarg_n = tf.placeholder(tf.float32, shape=[None], name='vtarg')
        wd_dict = {}
        h1 = tf.nn.elu(
            dense(X,
                  64,
                  "h1",
                  weight_init=U.normc_initializer(1.0),
                  bias_init=0,
                  weight_loss_dict=wd_dict))
        h2 = tf.nn.elu(
            dense(h1,
                  64,
                  "h2",
                  weight_init=U.normc_initializer(1.0),
                  bias_init=0,
                  weight_loss_dict=wd_dict))
        vpred_n = dense(h2,
                        1,
                        "hfinal",
                        weight_init=U.normc_initializer(1.0),
                        bias_init=0,
                        weight_loss_dict=wd_dict)[:, 0]
        sample_vpred_n = vpred_n + tf.random_normal(tf.shape(vpred_n))
        wd_loss = tf.get_collection("vf_losses", None)
        loss = U.mean(tf.square(vpred_n - vtarg_n)) + tf.add_n(wd_loss)
        loss_sampled = U.mean(
            tf.square(vpred_n - tf.stop_gradient(sample_vpred_n)))
        self._predict = U.function([X], vpred_n)
        optim = kfac.KfacOptimizer(
            learning_rate=0.001, cold_lr=0.001 * (1 - 0.9), momentum=0.9,
            clip_kl=0.3, epsilon=0.1, stats_decay=0.95, async=1, kfac_update=2,
            cold_iter=50, weight_decay_dict=wd_dict, max_grad_norm=None)
        vf_var_list = []
        for var in tf.trainable_variables():
            if "vf" in var.name:
                vf_var_list.append(var)

        update_op, self.q_runner = optim.minimize(loss,
                                                  loss_sampled,
                                                  var_list=vf_var_list)
        self.do_update = U.function([X, vtarg_n], update_op)  #pylint: disable=E1101
        U.initialize()  # Initialize uninitialized TF variables
Example 5
def minimize_loss_single_machine(loss,
                                 accuracy,
                                 layer_collection,
                                 session_config=None):
  """Minimize loss with K-FAC on a single machine.

  A single Session is responsible for running all of K-FAC's ops.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
  # Train with K-FAC.
  global_step = tf.train.get_or_create_global_step()
  optimizer = kfac.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      momentum=0.9)
  train_op = optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _, _ = sess.run(
          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])

      if global_step_ % 100 == 0:
        sess.run(optimizer.inv_update_op)
        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                        global_step_, loss_, accuracy_)

  return accuracy_
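The `layer_collection` argument is built by the caller; a minimal sketch of how one might be assembled for a tiny dense classifier before calling the function above. The toy inputs, layer names and the exact registration calls are assumptions about the kfac version in use, not part of the example:

import tensorflow as tf
import kfac

# Hypothetical toy inputs; a real input pipeline (e.g. a finite tf.data
# iterator) is what normally terminates the MonitoredTrainingSession loop.
images = tf.random_normal([64, 784])
labels = tf.random_uniform([64], maxval=10, dtype=tf.int64)

logits = tf.layers.dense(images, 10, name="logits")
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(logits, axis=1), labels), tf.float32))

layer_collection = kfac.LayerCollection()
layer_collection.register_categorical_predictive_distribution(logits)
layer_collection.auto_register_layers()  # scan the graph for supported layers

final_accuracy = minimize_loss_single_machine(loss, accuracy, layer_collection)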
Example 6
    def model_fn(features, labels, mode, params):
        """Model function for MLP trained with K-FAC.

    Args:
      features: Tensor of shape [batch_size, input_size]. Input features.
      labels: Tensor of shape [batch_size]. Target labels for training.
      mode: tf.estimator.ModeKeys. Must be TRAIN.
      params: ignored.

    Returns:
      EstimatorSpec for training.

    Raises:
      ValueError: If 'mode' is anything other than TRAIN.
    """
        del params

        if mode != tf.estimator.ModeKeys.TRAIN:
            raise ValueError("Only training is supported with this API.")

        # Build a ConvNet.
        layer_collection = kfac.LayerCollection()
        loss, accuracy = build_model(features,
                                     labels,
                                     num_labels=10,
                                     layer_collection=layer_collection,
                                     register_layers_manually=_USE_MANUAL_REG)
        if not _USE_MANUAL_REG:
            layer_collection.auto_register_layers()

        # Train with K-FAC.
        global_step = tf.train.get_or_create_global_step()
        optimizer = kfac.KfacOptimizer(
            learning_rate=tf.train.exponential_decay(0.00002,
                                                     global_step,
                                                     10000,
                                                     0.5,
                                                     staircase=True),
            cov_ema_decay=0.95,
            damping=0.001,
            layer_collection=layer_collection,
            momentum=0.9)

        (cov_update_thunks,
         inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

        def make_update_op(update_thunks):
            update_ops = [thunk() for thunk in update_thunks]
            return tf.group(*update_ops)

        def make_batch_executed_op(update_thunks, batch_size=1):
            return tf.group(*kfac.utils.batch_execute(
                global_step, update_thunks, batch_size=batch_size))

        # Run cov_update_op every step. Run 1 inv_update_ops per step.
        cov_update_op = make_update_op(cov_update_thunks)
        with tf.control_dependencies([cov_update_op]):
            # But make sure to execute all the inverse ops on the first step
            inverse_op = tf.cond(
                tf.equal(global_step,
                         0), lambda: make_update_op(inv_update_thunks),
                lambda: make_batch_executed_op(inv_update_thunks))
            with tf.control_dependencies([inverse_op]):
                train_op = optimizer.minimize(loss, global_step=global_step)

        # Print metrics every 5 sec.
        hooks = [
            tf.train.LoggingTensorHook({
                "loss": loss,
                "accuracy": accuracy
            },
                                       every_n_secs=5),
        ]
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=hooks)
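`model_fn` above is only a definition; a brief sketch of how it is typically handed to the Estimator API (the input pipeline, `model_dir` and step count are hypothetical placeholders, not part of the example):

def input_fn():
    # Hypothetical in-memory dataset; replace with the real input pipeline.
    dataset = tf.data.Dataset.from_tensor_slices((features_array, labels_array))
    return dataset.repeat().shuffle(10000).batch(64)

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/kfac_mlp")
estimator.train(input_fn=input_fn, max_steps=10000)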
Example 7
def minimize_loss_single_machine_manual(loss,
                                        accuracy,
                                        layer_collection,
                                        device=None,
                                        session_config=None):
    """Minimize loss with K-FAC on a single machine(Illustrative purpose only).

  This function does inverse and covariance computation manually
  for illustrative pupose. Check `minimize_loss_single_machine` for
  automatic inverse and covariance op placement and execution.
  A single Session is responsible for running all of K-FAC's ops. The covariance
  and inverse update ops are placed on `device`. All model variables are on CPU.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    device: string or None. The covariance and inverse update ops are run on
      this device. If empty or None, the default device will be used.
      (Default: None)
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
    device_list = [] if not device else [device]

    # Train with K-FAC.
    g_step = tf.train.get_or_create_global_step()
    optimizer = kfac.KfacOptimizer(learning_rate=0.0001,
                                   cov_ema_decay=0.95,
                                   damping=0.001,
                                   layer_collection=layer_collection,
                                   placement_strategy="round_robin",
                                   cov_devices=device_list,
                                   inv_devices=device_list,
                                   trans_devices=device_list,
                                   momentum=0.9)
    (cov_update_thunks,
     inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

    def make_update_op(update_thunks):
        update_ops = [thunk() for thunk in update_thunks]
        return tf.group(*update_ops)

    cov_update_op = make_update_op(cov_update_thunks)
    with tf.control_dependencies([cov_update_op]):
        inverse_op = tf.cond(tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
                             lambda: make_update_op(inv_update_thunks),
                             tf.no_op)
        with tf.control_dependencies([inverse_op]):
            with tf.device(device):
                train_op = optimizer.minimize(loss, global_step=g_step)

    tf.logging.info("Starting training.")
    with tf.train.MonitoredTrainingSession(config=session_config) as sess:
        while not sess.should_stop():
            global_step_, loss_, accuracy_, _ = sess.run(
                [g_step, loss, accuracy, train_op])

            if global_step_ % _REPORT_EVERY == 0:
                tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                                global_step_, loss_, accuracy_)

    return accuracy_
Example 8
def learn(env,
          policy,
          vf,
          gamma,
          lam,
          timesteps_per_batch,
          num_timesteps,
          animate=False,
          callback=None,
          desired_kl=0.002):

    obfilter = ZFilter(env.observation_space.shape)

    max_pathlength = env.spec.timestep_limit
    stepsize = tf.Variable(initial_value=np.float32(np.array(0.03)),
                           name='stepsize')

    X_v, vtarg_n_v, loss2, loss_sampled2 = vf.update_info
    optim2 = kfac.KfacOptimizer(
        learning_rate=0.001, cold_lr=0.001 * (1 - 0.9), momentum=0.9,
        clip_kl=0.3, epsilon=0.1, stats_decay=0.95, async=0, kfac_update=2,
        cold_iter=50, weight_decay_dict=vf.wd_dict, max_grad_norm=None)
    vf_var_list = []
    for var in tf.trainable_variables():
        if "vf" in var.name:
            vf_var_list.append(var)
    update_op2 = optim2.minimize(loss2, loss_sampled2, var_list=vf_var_list)

    ob_p, oldac_p, adv_p, loss, loss_sampled = policy.update_info
    optim = kfac.KfacOptimizer(
        learning_rate=stepsize, cold_lr=stepsize * (1 - 0.9), momentum=0.9,
        kfac_update=2, epsilon=1e-2, stats_decay=0.99, async=0, cold_iter=1,
        weight_decay_dict=policy.wd_dict, max_grad_norm=None)
    pi_var_list = []
    for var in tf.trainable_variables():
        if "pi" in var.name:
            pi_var_list.append(var)
    update_op = optim.minimize(loss, loss_sampled, var_list=pi_var_list)

    sess = tf.get_default_session()
    sess.run(tf.variables_initializer(set(tf.global_variables())))

    i = 0
    timesteps_so_far = 0
    while True:
        if timesteps_so_far > num_timesteps:
            break
        logger.log("********** Iteration %i ************" % i)

        # Collect paths until we have enough timesteps
        timesteps_this_batch = 0
        paths = []
        while True:
            path = rollout(env,
                           policy,
                           max_pathlength,
                           animate=(len(paths) == 0 and (i % 10 == 0)
                                    and animate),
                           obfilter=obfilter)
            paths.append(path)
            n = pathlength(path)
            timesteps_this_batch += n
            timesteps_so_far += n
            if timesteps_this_batch > timesteps_per_batch:
                break

        # Estimate advantage function
        vtargs = []
        advs = []
        for path in paths:
            rew_t = path["reward"]
            return_t = discount(rew_t, gamma)
            vtargs.append(return_t)
            vpred_t = vf.predict(path)
            vpred_t = np.append(vpred_t,
                                0.0 if path["terminated"] else vpred_t[-1])
            delta_t = rew_t + gamma * vpred_t[1:] - vpred_t[:-1]
            adv_t = discount(delta_t, gamma * lam)
            advs.append(adv_t)

        # Update value function
        paths_ = []
        for p in paths:
            l = pathlength(p)
            act = p["action_dist"].astype('float32')
            paths_.append(
                np.concatenate([p['observation'], act,
                                np.ones((l, 1))],
                               axis=1))
        X1 = np.concatenate(paths_)
        y = np.concatenate(vtargs)
        logger.record_tabular("EVBefore",
                              explained_variance(vf._predict(X1), y))
        for _ in range(20):
            sess.run(update_op2, {X_v: X1, vtarg_n_v: y})
        logger.record_tabular("EVAfter",
                              explained_variance(vf._predict(X1), y))

        # Build arrays for policy update
        ob_no = np.concatenate([path["observation"] for path in paths])
        action_na = np.concatenate([path["action"] for path in paths])
        oldac_dist = np.concatenate([path["action_dist"] for path in paths])
        adv_n = np.concatenate(advs)
        standardized_adv_n = (adv_n - adv_n.mean()) / (adv_n.std() + 1e-8)

        # Policy update
        sess.run(update_op, {
            ob_p: ob_no,
            oldac_p: action_na,
            adv_p: standardized_adv_n
        })

        min_stepsize = np.float32(1e-8)
        max_stepsize = np.float32(1e0)

        # Adjust stepsize
        kl = policy.compute_kl(ob_no, oldac_dist)
        if kl > desired_kl * 2:
            logger.log("kl too high")
            tf.assign(stepsize, tf.maximum(min_stepsize,
                                           stepsize / 1.5)).eval()
        elif kl < desired_kl / 2:
            logger.log("kl too low")
            tf.assign(stepsize, tf.minimum(max_stepsize,
                                           stepsize * 1.5)).eval()
        else:
            logger.log("kl just right!")

        logger.record_tabular(
            "EpRewMean", np.mean([path["reward"].sum() for path in paths]))
        logger.record_tabular(
            "EpRewSEM",
            np.std([
                path["reward"].sum() / np.sqrt(len(paths)) for path in paths
            ]))
        logger.record_tabular("EpLenMean",
                              np.mean([pathlength(path) for path in paths]))
        logger.record_tabular("KL", kl)
        if callback:
            callback()
        logger.dump_tabular()
        i += 1
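One caveat in the step-size adjustment above: calling `tf.assign(...).eval()` inside the training loop adds new assign ops to the graph on every iteration. A small sketch of an alternative that builds the assign op once, before the loop (the names here are illustrative, not taken from the example):

# Built once, next to the `stepsize` variable.
stepsize_ph = tf.placeholder(tf.float32, shape=[])
assign_stepsize_op = tf.assign(stepsize, stepsize_ph)

# Inside the loop, after computing `kl`:
if kl > desired_kl * 2:
    sess.run(assign_stepsize_op,
             {stepsize_ph: max(min_stepsize, sess.run(stepsize) / 1.5)})
elif kl < desired_kl / 2:
    sess.run(assign_stepsize_op,
             {stepsize_ph: min(max_stepsize, sess.run(stepsize) * 1.5)})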
Example 9
def learn(env, policy, vf, gamma, lam, timesteps_per_batch, num_timesteps,
    animate=False, callback=None, desired_kl=0.002):

    obfilter = ZFilter(env.observation_space.shape)

    max_pathlength = env.spec.timestep_limit
    stepsize = tf.Variable(initial_value=np.float32(np.array(0.03)), name='stepsize')
    inputs, loss, loss_sampled = policy.update_info
    optim = kfac.KfacOptimizer(
        learning_rate=stepsize, cold_lr=stepsize * (1 - 0.9), momentum=0.9,
        kfac_update=2, epsilon=1e-2, stats_decay=0.99, async=1, cold_iter=1,
        weight_decay_dict=policy.wd_dict, max_grad_norm=None)
    pi_var_list = []
    for var in tf.trainable_variables():
        if "pi" in var.name:
            pi_var_list.append(var)

    update_op, q_runner = optim.minimize(loss, loss_sampled, var_list=pi_var_list)
    do_update = U.function(inputs, update_op)
    U.initialize()

    # start queue runners
    enqueue_threads = []
    coord = tf.train.Coordinator()
    for qr in [q_runner, vf.q_runner]:
        assert qr is not None
        enqueue_threads.extend(qr.create_threads(U.get_session(), coord=coord, start=True))

    i = 0
    timesteps_so_far = 0
    while True:
        if timesteps_so_far > num_timesteps:
            break
        logger.log("********** Iteration %i ************"%i)

        # Collect paths until we have enough timesteps
        timesteps_this_batch = 0
        paths = []
        while True:
            path = rollout(env, policy, max_pathlength, animate=(len(paths)==0 and (i % 10 == 0) and animate), obfilter=obfilter)
            paths.append(path)
            n = pathlength(path)
            timesteps_this_batch += n
            #timesteps_so_far += n
            if timesteps_this_batch > timesteps_per_batch:
                break
        timesteps_so_far += policy.batch

        # Estimate advantage function
        vtargs = []
        advs = []
        for path in paths:
            rew_t = path["reward"]
            return_t = common.discount(rew_t, gamma)
            vtargs.append(return_t)
            vpred_t = vf.predict(path)
            vpred_t = np.append(vpred_t, 0.0 if path["terminated"] else vpred_t[-1])
            delta_t = rew_t + gamma*vpred_t[1:] - vpred_t[:-1]
            adv_t = common.discount(delta_t, gamma * lam)
            advs.append(adv_t)
        # Update value function
        vf.fit(paths, vtargs)

        # Build arrays for policy update, truncated to the policy batch size
        batch = policy.batch
        ob_no = np.concatenate([path["observation"] for path in paths])[:batch]
        action_na = np.concatenate([path["action"] for path in paths])[:batch]
        oldac_dist = np.concatenate([path["action_dist"] for path in paths])[:batch]
        adv_n = np.concatenate(advs)[:batch]

        standardized_adv_n = (adv_n - adv_n.mean()) / (adv_n.std() + 1e-8)

        # Policy update
        do_update(ob_no, action_na, standardized_adv_n)

        min_stepsize = np.float32(1e-8)
        max_stepsize = np.float32(1e0)
        # Adjust stepsize
        kl = policy.compute_kl(ob_no, action_na, oldac_dist)
        if kl > desired_kl * 2:
            logger.log("kl too high")
            U.eval(tf.assign(stepsize, tf.maximum(min_stepsize, stepsize / 1.5)))
        elif kl < desired_kl / 2:
            logger.log("kl too low")
            U.eval(tf.assign(stepsize, tf.minimum(max_stepsize, stepsize * 1.5)))            
        else:
            logger.log("kl just right!")

        logger.record_tabular("EpRewMean", np.mean([path["reward"].sum() for path in paths]))
        logger.record_tabular("EpRewSEM", np.std([path["reward"].sum()/np.sqrt(len(paths)) for path in paths]))
        logger.record_tabular("EpLenMean", np.mean([pathlength(path) for path in paths]))
        logger.record_tabular("KL", kl)
        eprewmean = np.mean([path["reward"].sum() for path in paths])
        if callback is not None:
            callback(locals(), globals())
        logger.dump_tabular()
        i += 1

    coord.request_stop()
    coord.join(enqueue_threads)
Example 10
def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
                              checkpoint_dir, loss, accuracy, layer_collection):
  """Minimize loss with an synchronous implementation of K-FAC.

  Different tasks are responsible for different parts of K-FAC's Ops. The first
  60% of tasks update weights; the next 20% accumulate covariance statistics;
  the last 20% invert the matrices used to precondition gradients.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    master: string. IP and port of TensorFlow runtime process. Set to empty
      string to run locally.
    checkpoint_dir: string or None. Path to store checkpoints under.
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
      run with each step.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.

  Returns:
    final value for 'accuracy'.

  Raises:
    ValueError: if task_id >= num_worker_tasks.
  """
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    global_step = tf.train.get_or_create_global_step()
    optimizer = kfac.KfacOptimizer(
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9)
    inv_update_queue = kfac.op_queue.OpQueue(optimizer.inv_update_ops)
    sync_optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
    train_op = sync_optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  is_chief = (task_id == 0)
  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=checkpoint_dir,
      hooks=hooks,
      stop_grace_period_secs=0) as sess:
    while not sess.should_stop():
      # Choose which op this task is responsible for running.
      if _is_gradient_task(task_id, num_worker_tasks):
        learning_op = train_op
      elif _is_cov_update_task(task_id, num_worker_tasks):
        learning_op = optimizer.cov_update_op
      elif _is_inv_update_task(task_id, num_worker_tasks):
        # TODO(duckworthd): Running this op before cov_update_op has been run a
        # few times can result in "InvalidArgumentError: Cholesky decomposition
        # was not successful." Delay running this op until cov_update_op has
        # been run a few times.
        learning_op = inv_update_queue.next_op(sess)
      else:
        raise ValueError("Which op should task %d do?" % task_id)

      global_step_, loss_, accuracy_, _ = sess.run(
          [global_step, loss, accuracy, learning_op])
      tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
                      loss_, accuracy_)

  return accuracy_
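The helpers `_num_gradient_tasks`, `_is_gradient_task`, `_is_cov_update_task` and `_is_inv_update_task` are not shown in this example. A hypothetical version consistent with the 60%/20%/20% split described in the docstring might look like the sketch below; the exact rounding behaviour is an assumption, not the library's actual implementation:

import math

def _num_gradient_tasks(num_tasks):
    # First 60% of tasks (at least one) compute and apply gradients.
    return max(1, int(math.ceil(0.6 * num_tasks)))

def _is_gradient_task(task_id, num_tasks):
    return task_id < _num_gradient_tasks(num_tasks)

def _is_cov_update_task(task_id, num_tasks):
    # Next 20% of tasks accumulate covariance statistics.
    return _num_gradient_tasks(num_tasks) <= task_id < int(math.ceil(0.8 * num_tasks))

def _is_inv_update_task(task_id, num_tasks):
    # Last 20% of tasks invert the preconditioning matrices.
    return task_id >= int(math.ceil(0.8 * num_tasks))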