def test_num_steps(self):
    """Exercises StopAtStepHook(num_steps=...) in a fresh and a restored session."""
    logdir = self._test_dir('test_num_steps')
    with tf.Graph().as_default():
      global_step = tf.contrib.framework.get_or_create_global_step()
      increment = tf.assign_add(global_step, 1)
      # Phase 1: the hook allows three steps, then a checkpoint is written.
      three_step_hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
      scaffold = monitored_session.Scaffold().finalize()
      with monitored_session.MonitoredSession(
          hooks=three_step_hooks) as session:
        for _ in range(2):
          session.run(increment)
          self.assertFalse(session.should_stop())
        session.run(increment)
        self.assertTrue(session.should_stop())
        # The saver is handed the underlying session object.
        raw_session = session._coordinated_creator.tf_sess
        save_path = scaffold.saver.save(raw_session,
                                        os.path.join(logdir, 'step-3'))

      # Phase 2: start from the step-3 checkpoint and allow four more steps.
      def load_ckpt(scaffold, sess):
        scaffold.saver.restore(sess, save_path)

      restoring_scaffold = monitored_session.Scaffold(init_fn=load_ckpt)
      session_creator = monitored_session.ChiefSessionCreator(
          scaffold=restoring_scaffold)
      four_step_hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=4)]
      with monitored_session.MonitoredSession(
          hooks=four_step_hooks, session_creator=session_creator) as session:
        # The restored counter was 3, so the first increment yields 4.
        self.assertEqual(4, session.run(increment))
        self.assertFalse(session.should_stop())
        for _ in range(2):
          session.run(increment)
          self.assertFalse(session.should_stop())
        session.run(increment)
        self.assertTrue(session.should_stop())
 def test_retry_on_aborted_error(self):
     """Checks that an AbortedError raised by a hook is retried silently.

     No CheckpointSaver is involved, so recovery is not covered here: the
     assertions expect the step counter to start over after the abort.
     """
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         increment = tf.assign_add(global_step, 1)
         scaffold = monitored_session.Scaffold()
         abort_error = tf.errors.AbortedError(None, None, 'Abort')
         hook = RaiseOnceAtCountN(4, abort_error)
         with monitored_session.MonitoredSession(
                 '', scaffold=scaffold, hooks=[hook]) as session:
             self.assertEqual(0, session.run(global_step))
             for expected_step in (1, 2):
                 self.assertEqual(expected_step, session.run(increment))
             self.assertFalse(session.should_stop())
             # The next run is where the hook raises AbortedError.  The
             # expected value of 1 reflects a retry on a freshly initialized
             # session: the counter is back at 0 and this increment makes it 1.
             self.assertEqual(1, session.run(increment))
             self.assertFalse(session.should_stop())
             self.assertTrue(hook.raised)
             self.assertEqual(2, session.run(increment))
             self.assertFalse(session.should_stop())
 def test_recover_and_retry_on_aborted_error(self):
     """Checks silent retry plus recovery from a checkpoint after an abort.

     A CheckpointSaverHook saving on every step gives the session something
     to recover from, so the step counter is expected to continue at 3.
     """
     logdir = self._test_dir('test_recover_and_retry_on_aborted_error')
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         increment = tf.assign_add(global_step, 1)
         scaffold = monitored_session.Scaffold()
         abort_error = tf.errors.AbortedError(None, None, 'Abort')
         abort_hook = RaiseOnceAtCountN(4, abort_error)
         # save_steps=1: a checkpoint is requested after each step.
         ckpt_hook = basic_session_run_hooks.CheckpointSaverHook(
             logdir, save_steps=1, scaffold=scaffold)
         with monitored_session.MonitoredSession(
                 '',
                 scaffold=scaffold,
                 checkpoint_dir=logdir,
                 hooks=[abort_hook, ckpt_hook]) as session:
             self.assertEqual(0, session.run(global_step))
             for expected_step in (1, 2):
                 self.assertEqual(expected_step, session.run(increment))
             self.assertFalse(session.should_stop())
             # The next run is where the hook raises AbortedError; the
             # session is expected to restore and retry, yielding step 3.
             self.assertEqual(3, session.run(increment))
             self.assertTrue(abort_hook.raised)
             self.assertFalse(session.should_stop())
             self.assertEqual(4, session.run(increment))
             self.assertFalse(session.should_stop())
def main(_):
    """Run FLAGS.num_epochs epochs of the meta-loss ops and print the totals."""
    unroll_count = FLAGS.num_steps

    if FLAGS.seed:
        tf.set_random_seed(FLAGS.seed)

    # Problem definition and network configuration.
    problem, net_config, net_assignments = util.get_config(FLAGS.problem,
                                                           FLAGS.path)

    meta_optimizer = meta.MetaOptimizer(**net_config)
    # Only the update, reset and cost ops of the returned 5-tuple are used.
    _, update, reset, cost_op, _ = meta_optimizer.meta_loss(
        problem, 1, net_assignments=net_assignments)

    with ms.MonitoredSession() as sess:
        # Freeze the graph so nothing below can add ops by accident.
        tf.get_default_graph().finalize()

        elapsed_total = 0
        cost_total = 0
        for _ in xrange(FLAGS.num_epochs):
            elapsed, cost = util.run_epoch(
                sess, cost_op, [update], reset, unroll_count)
            elapsed_total += elapsed
            cost_total += cost

        # Summary over all epochs.
        header = "Epoch {}".format(FLAGS.num_epochs)
        util.print_stats(header, cost_total, elapsed_total, FLAGS.num_epochs)
    def _create_session(self):
        """Build the session stack handed to the RecoverableSession.

        The chief obtains its session through ``prepare_session`` using the
        scaffold's saver and init settings; any other task calls
        ``wait_for_session``.  Queue runners are then started under a new
        coordinator.

        Returns:
          A CoordinatedSession wrapping a MonitoredSession around the raw
          session.
        """
        manager = self._session_manager
        if not self._is_chief:
            tf_sess = manager.wait_for_session(self._master,
                                               config=self._config)
        else:
            tf_sess = manager.prepare_session(
                self._master,
                saver=self._scaffold.saver,
                checkpoint_dir=self._checkpoint_dir,
                config=self._config,
                init_op=self._scaffold.init_op,
                init_feed_dict=self._scaffold.init_feed_dict,
                init_fn=self._scaffold.init_fn)

        # The raw session is kept for quick global-step reads.
        self._tf_sess = tf_sess
        self._coord = coordinator.Coordinator(
            clean_stop_exception_types=self._clean_stop_exception_types)
        self._coordinated_threads_to_join = queue_runner.start_queue_runners(
            sess=tf_sess, coord=self._coord)

        monitored = monitored_session.MonitoredSession(
            tf_sess, self._monitors, self._scaffold.global_step_tensor)
        return coordinated_session.CoordinatedSession(
            monitored, self._coord, self._coordinated_threads_to_join)
    def testCallsMonitorsWithLastStep(self):
        """Monitors should see begin/end steps of 1, 6 and 11 over three runs."""
        with tf.Graph().as_default(), tf.Session() as sess:
            global_step_tensor = tf.contrib.framework.create_global_step()
            first_monitor = FakeMonitor()
            second_monitor = FakeMonitor()
            mon_sess = monitored_session.MonitoredSession(
                sess=sess,
                monitors=[first_monitor, second_monitor],
                global_step_tensor=global_step_tensor)
            add_five = tf.assign_add(global_step_tensor, 5)
            # Running the initializer sets global_step_tensor to 0.
            sess.run(tf.initialize_all_variables())

            # Each run adds 5 to the global step; check what both monitors
            # were told after every run.
            for expected_step in (1, 6, 11):
                mon_sess.run(fetches=[add_five])
                for monitor in (first_monitor, second_monitor):
                    self.assertEqual(monitor.begin_step, expected_step)
                    self.assertEqual(monitor.end_step, expected_step)
def main(_):
    """Train the configured problem with plain Adam or the L2L optimizer.

    Both training ops are always built.  ``FLAGS.mode == 1`` selects the plain
    Adam op; any other value selects the L2L op.  The loop performs
    ``FLAGS.num_epochs`` training runs, prints the accumulated loss divided by
    100 whenever the fetched global step is a multiple of 100, and
    re-initializes the L2L optimizer's slot/optimizer variables whenever the
    step is a multiple of ``FLAGS.reset_interval``.
    """
    # Only the problem factory is used below; the other two values are
    # unpacked so the shape of get_config's result is still checked.
    problem, _net_config, _net_assignments = util.get_config(FLAGS.problem)

    loss = problem()
    global_step = tf.Variable(0, dtype=tf.int64)

    # Plain Adam baseline.
    adam_opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    adam_train_op = adam_opt.minimize(loss, global_step)

    # L2L optimizer wrapping the same Adam instance.
    optimizer = l2l_optimizer.L2LOptimizer(internal_optimizer=adam_opt,
                                           loss_func=problem,
                                           opt_last=FLAGS.opt_last,
                                           preprocessor=LogAndSign(10),
                                           co_opt=FLAGS.co_opt,
                                           rnn_layer_cnt=FLAGS.layer,
                                           delta_ratio=FLAGS.delta_ratio,
                                           update_ratio=FLAGS.update_ratio,
                                           dynamic_unroll=FLAGS.dynamic_unroll)
    opt = optimizer.minimize(loss,
                             global_step=global_step,
                             unroll_len=FLAGS.unroll_len)

    if FLAGS.mode == 1:
        print('use adam opt')
        opt = adam_train_op
    else:
        print('use l2l opt')

    # Built (and run) in both modes, exactly as before.
    slot_reset = tf.variables_initializer(optimizer._slot_vars +
                                          optimizer._opt_vars)

    with ms.MonitoredSession() as sess:
        # Prevent accidental changes to the graph.
        tf.get_default_graph().finalize()

        print('trainable variables')
        for v in tf.trainable_variables():
            print("parameter:", v.name, "device:", v.device, "shape:",
                  v.get_shape())

        accum_loss = 0.0
        for _epoch in xrange(FLAGS.num_epochs):
            # One training step; the global step is fetched in the same run.
            step, curr_loss, _ = sess.run([global_step, loss, opt])
            accum_loss += curr_loss
            if step % 100 == 0:
                print('step:%d,loss:%f' % (step, accum_loss / 100))
                accum_loss = 0

            if step % FLAGS.reset_interval == 0:
                sess.run(slot_reset)
Exemple #8
0
def run_tf_loop():
  """Evaluate the ops from create_tf_loop() once and return the loss value."""
  fetches = create_tf_loop()

  with ms.MonitoredSession() as sess:
    # No further graph changes are allowed from here on.
    tf.get_default_graph().finalize()

    # The run yields a (loss, size) pair; only the loss is returned.
    loss, _ = sess.run(fetches)
    return loss
Exemple #9
0
def main(_):
    """Evaluate the selected optimizer and pickle the per-epoch loss record.

    ``FLAGS.optimizer`` selects either a plain Adam update or the L2L
    meta-optimizer.  After ``FLAGS.num_epochs`` evaluation epochs the
    aggregate statistics are printed.  The loss record is pickled only when
    ``FLAGS.output_path`` is set; previously an unset path was formatted into
    the file name as the literal ``None`` directory.

    Raises:
      ValueError: if ``FLAGS.optimizer`` is neither "Adam" nor "L2L".
    """
    # Configuration.
    num_unrolls = FLAGS.num_steps
    if FLAGS.seed:
        tf.set_random_seed(FLAGS.seed)

    # Problem.
    problem, net_config, net_assignments = util.get_config(FLAGS.problem,
                                                           FLAGS.path)

    # Optimizer setup.
    if FLAGS.optimizer == "Adam":
        cost_op = problem()
        problem_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        problem_reset = tf.variables_initializer(problem_vars)

        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        optimizer_reset = tf.variables_initializer(optimizer.get_slot_names())
        update = optimizer.minimize(cost_op)
        reset = [problem_reset, optimizer_reset]
    elif FLAGS.optimizer == "L2L":
        if FLAGS.path is None:
            logging.warning("Evaluating untrained L2L optimizer")
        optimizer = meta.MetaOptimizer(**net_config)
        meta_loss = optimizer.meta_loss(problem, 1,
                                        net_assignments=net_assignments)
        _, update, reset, cost_op, _ = meta_loss
    else:
        raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer))

    # NOTE(review): this config is built but is not passed to the
    # MonitoredSession below -- confirm whether it was meant to be.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with ms.MonitoredSession() as sess:
        sess.run(reset)
        # Prevent accidental changes to the graph.
        tf.get_default_graph().finalize()

        total_time = 0
        total_cost = 0
        loss_record = []
        for _epoch in xrange(FLAGS.num_epochs):
            time, cost = util.run_eval_epoch(sess, cost_op, [update],
                                             num_unrolls)
            total_time += time
            # `cost` is summed and also concatenated onto loss_record, i.e.
            # it is used as a sequence of per-unroll values.
            total_cost += sum(cost) / num_unrolls
            loss_record += cost

        # Results.
        util.print_stats("Epoch {}".format(FLAGS.num_epochs), total_cost,
                         total_time, FLAGS.num_epochs)

    # Without an output directory there is nowhere to write the record.
    if FLAGS.output_path is None:
        logging.warning("No output_path given; loss record not saved")
        return

    if not os.path.exists(FLAGS.output_path):
        os.mkdir(FLAGS.output_path)
    output_file = '{}/{}_eval_loss_record.pickle-{}'.format(
        FLAGS.output_path, FLAGS.optimizer, FLAGS.problem)
    with open(output_file, 'wb') as l_record:
        pickle.dump(loss_record, l_record)
    print("Saving evaluate loss record {}".format(output_file))
Exemple #10
0
def main(_):
    """Evaluate Adam (with global-norm gradient clipping) or the L2L optimizer.

    Raises:
      ValueError: if ``FLAGS.optimizer`` is neither "Adam" nor "L2L".
    """
    unroll_count = FLAGS.num_steps

    if FLAGS.seed:
        tf.set_random_seed(FLAGS.seed)

    # Problem definition and network configuration.
    problem, net_config, net_assignments = util.get_config(
        FLAGS.problem,
        main_parade_path,
        first_batch_parade_path,
        path=FLAGS.path)

    optimizer_name = FLAGS.optimizer
    if optimizer_name == "L2L":
        if FLAGS.path is None:
            logging.warning("Evaluating untrained L2L optimizer")
        optimizer = meta.MetaOptimizer(**net_config)
        _, update, reset, cost_op, _ = optimizer.meta_loss(
            problem, 1, net_assignments=net_assignments)
    elif optimizer_name == "Adam":
        cost_op = problem()
        trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        problem_reset = tf.variables_initializer(trainable)

        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        optimizer_reset = tf.variables_initializer(optimizer.get_slot_names())
        # Gradients are clipped to a global norm of 1 before being applied.
        raw_grads, variables = zip(*optimizer.compute_gradients(cost_op))
        clipped_grads, _ = tf.clip_by_global_norm(raw_grads, 1.)
        update = optimizer.apply_gradients(zip(clipped_grads, variables))
        reset = [problem_reset, optimizer_reset]
    else:
        raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer))

    with ms.MonitoredSession() as sess:
        # Freeze the graph so nothing below can add ops by accident.
        tf.get_default_graph().finalize()

        elapsed_total = 0
        cost_total = 0
        for _epoch in xrange(FLAGS.num_epochs):
            elapsed, cost = util.run_epoch(
                sess, cost_op, [update], reset, unroll_count)
            elapsed_total += elapsed
            cost_total += cost

        # Summary over all epochs.
        header = "Epoch {}".format(FLAGS.num_epochs)
        util.print_stats(header, cost_total, elapsed_total, FLAGS.num_epochs)
 def test_recovery(self):
   """A second session on the same checkpoint_dir should resume at step 2."""
   logdir = self._test_dir('test_recovery')
   with tf.Graph().as_default():
     global_step = tf.contrib.framework.get_or_create_global_step()
     increment = tf.assign_add(global_step, 1)
     scaffold = monitored_session.Scaffold()
     # Checkpoint hook configured with save_steps=1.
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         logdir, save_steps=1, scaffold=scaffold)
     first_creator = monitored_session.ChiefSessionCreator(
         scaffold, checkpoint_dir=logdir)
     with monitored_session.MonitoredSession(
         session_creator=first_creator, hooks=[saver_hook]) as session:
       self.assertEqual(0, session.run(global_step))
       for expected_step in (1, 2):
         self.assertEqual(expected_step, session.run(increment))
     # A new session pointed at the same directory is expected to pick up
     # the checkpoint without any explicit restore call.
     second_creator = monitored_session.ChiefSessionCreator(
         scaffold, checkpoint_dir=logdir)
     with monitored_session.MonitoredSession(
         session_creator=second_creator) as session:
       self.assertEqual(2, session.run(global_step))
 def test_stop_cleanly_when_no_exception_in_with_body(self):
   """Leaving the "with" body normally should close the session."""
   with tf.Graph().as_default():
     global_step = tf.contrib.framework.get_or_create_global_step()
     increment = tf.assign_add(global_step, 1)
     session = monitored_session.MonitoredSession()
     with session:
       for expected_step in (1, 2):
         self.assertEqual(expected_step, session.run(increment))
       self.assertFalse(session.should_stop())
     # After the block the session reports stop and is closed.
     self.assertTrue(session.should_stop())
     self.assertTrue(session._is_closed())
Exemple #13
0
def main(_):
  """Evaluate the selected optimizer and pickle the evaluation record.

  ``FLAGS.optimizer`` selects either a plain Adam update or the L2L
  meta-optimizer.  After ``FLAGS.num_epochs`` epochs the collected costs and
  the evaluated ``constants`` are written to
  ``./<FLAGS.path>/evaluate_record.pickle`` and statistics are printed.

  Raises:
    ValueError: if ``FLAGS.optimizer`` is neither "Adam" nor "L2L".
  """
  # Configuration.
  num_unrolls = FLAGS.num_steps

  if FLAGS.seed:
    tf.set_random_seed(FLAGS.seed)

  # Problem.
  problem, net_config, net_assignments = util.get_config(FLAGS.problem,
                                                         FLAGS.path, mode='test')

  # Optimizer setup.
  if FLAGS.optimizer == "Adam":
    cost_op = problem()
    problem_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    problem_reset = tf.variables_initializer(problem_vars)

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    optimizer_reset = tf.variables_initializer(optimizer.get_slot_names())
    update = optimizer.minimize(cost_op)
    reset = [problem_reset, optimizer_reset]
    # NOTE(review): x_final and constant are only bound in the L2L branch
    # but are used unconditionally below -- confirm the Adam path is
    # actually supported by this script.
  elif FLAGS.optimizer == "L2L":
    if FLAGS.path is None:
      logging.warning("Evaluating untrained L2L optimizer")
    # Fixed: a stray double comma here made the whole file a SyntaxError.
    optimizer = meta.MetaOptimizer(FLAGS.problem, FLAGS.num_particle,
                                   **net_config)
    meta_loss = optimizer.meta_loss(problem, 1,
                                    net_assignments=net_assignments,
                                    model_path=FLAGS.path)
    loss, update, reset, cost_op, x_final, constant = meta_loss
  else:
    raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer))
  with ms.MonitoredSession() as sess:
    # Prevent accidental changes to the graph.
    tf.get_default_graph().finalize()
    all_time_loss_record = []
    total_time = 0
    # NOTE(review): total_cost is never accumulated, so print_stats below
    # always receives 0 -- confirm whether that is intended.
    total_cost = 0
    # Evaluated before the loop as in the original; the result is not used
    # afterwards.
    x_record = [[sess.run(item) for item in x_final]]
    for _ in xrange(FLAGS.num_epochs):
      time, cost, constants = util.eval_run_epoch(sess, cost_op, [update],
                                                  reset, num_unrolls,
                                                  x_final, constant)
      total_time += time
      all_time_loss_record.append(cost)
    # `cost` and `constants` come from the last epoch of the loop above.
    with open('./{}/evaluate_record.pickle'.format(FLAGS.path), 'wb') as l_record:
      record = {
          'all_time_loss_record': all_time_loss_record,
          'loss': cost,
          'constants': [sess.run(item) for item in constants],
      }
      pickle.dump(record, l_record)
    # Results.
    util.print_stats("Epoch {}".format(FLAGS.num_epochs), total_cost,
                     total_time, FLAGS.num_epochs)
    def test_last_step(self):
        """Exercises StopAtStepHook(last_step=...) before and after a restore."""
        logdir = self._test_dir('test_last_step')
        with tf.Graph().as_default():
            global_step = tf.contrib.framework.get_or_create_global_step()
            increment = tf.assign_add(global_step, 1)
            scaffold = monitored_session.Scaffold()
            # Phase 1: run until the global step reaches 3, then checkpoint.
            stop_at_three = [
                basic_session_run_hooks.StopAtStepHook(last_step=3)]
            with monitored_session.MonitoredSession(
                    '', scaffold=scaffold, hooks=stop_at_three) as session:
                self.assertEqual(0, session.run(global_step))
                self.assertFalse(session.should_stop())
                for expected_step in (1, 2):
                    self.assertEqual(expected_step, session.run(increment))
                    self.assertFalse(session.should_stop())
                self.assertEqual(3, session.run(increment))
                self.assertTrue(session.should_stop())
                checkpoint_target = os.path.join(logdir, 'step-3')
                save_path = scaffold.saver.save(session.session,
                                                checkpoint_target)

            # Phase 2: restore that checkpoint and run until step 5.
            def load_ckpt(scaffold, sess):
                scaffold.saver.restore(sess, save_path)

            scaffold = monitored_session.Scaffold(init_fn=load_ckpt)
            stop_at_five = [
                basic_session_run_hooks.StopAtStepHook(last_step=5)]
            with monitored_session.MonitoredSession(
                    '', scaffold=scaffold, hooks=stop_at_five) as session:
                self.assertEqual(3, session.run(global_step))
                self.assertFalse(session.should_stop())
                self.assertEqual(4, session.run(increment))
                self.assertFalse(session.should_stop())
                self.assertEqual(5, session.run(increment))
                self.assertTrue(session.should_stop())
 def test_regular_exception_reported_to_coord_pass_through_return(self):
   """An exception reported to the coordinator surfaces on leaving the block.

   The error is handed to the coordinator the way a thread would report it;
   the session should enter stop mode and the error should be raised when
   the "with MonitoredSession" block returns.
   """
   with tf.Graph().as_default():
     global_step = tf.contrib.framework.get_or_create_global_step()
     session = monitored_session.MonitoredSession()
     with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
       with session:
         self.assertEqual(0, session.run(global_step))
         # Raise and catch so the reported exception is a real, raised one.
         try:
           raise RuntimeError('a thread wants to stop')
         except RuntimeError as reported:
           coord = session._coordinated_creator.coord
           coord.request_stop(reported)
         self.assertTrue(session.should_stop())
def main(_):
  """Build the meta_minimize ops and run them for FLAGS.num_epochs steps.

  When ``FLAGS.save_path`` is set, the directory is created up front and an
  already existing directory is treated as an error.  During the loop the
  accumulated cost divided by 100 is printed every 100 steps, and the
  ``reset`` op is run every ``FLAGS.reset_interval`` steps.

  Raises:
    ValueError: if ``FLAGS.save_path`` already exists.
  """
  # Configuration.
  # NOTE(review): num_unrolls is computed but not used in the visible loop.
  num_unrolls = FLAGS.num_steps // FLAGS.unroll_length

  if FLAGS.save_path is not None:
    if os.path.exists(FLAGS.save_path):
      raise ValueError("Folder {} already exists".format(FLAGS.save_path))
    os.mkdir(FLAGS.save_path)

  # Problem.
  problem, net_config, net_assignments = util.get_config(FLAGS.problem)

  # Optimizer setup.
  optimizer = meta.MetaOptimizer(**net_config)
  minimize = optimizer.meta_minimize(
      problem, FLAGS.unroll_length,
      learning_rate=FLAGS.learning_rate,
      net_assignments=net_assignments,
      second_derivatives=FLAGS.second_derivatives)
  step, update, reset, cost_op, _ = minimize
  # Not used by the loop below; kept so the graph contents are unchanged.
  no_op = tf.no_op()
  with ms.MonitoredSession() as sess:
    # Prevent accidental changes to the graph.
    tf.get_default_graph().finalize()

    print('trainable variables')
    for v in tf.trainable_variables():
        print("parameter:", v.name, "device:", v.device, "shape:", v.get_shape())

    total_cost = 0
    curr_step = 0
    for _epoch in xrange(FLAGS.num_epochs):
        # Only the cost value of the fetched triple is needed.
        curr_loss = sess.run([cost_op, update, step])[0]
        total_cost += curr_loss
        curr_step += 1
        if curr_step % 100 == 0:
            print('step:%d,loss:%f' % (curr_step, total_cost / 100))
            total_cost = 0

        if curr_step % FLAGS.reset_interval == 0:
            print('reset states')
            sess.run(reset)
 def test_exit_cleanly_on_stop_iteration_exception(self):
   """A StopIteration raised by a hook should end the "with" block cleanly."""
   with tf.Graph().as_default():
     global_step = tf.contrib.framework.get_or_create_global_step()
     increment = tf.assign_add(global_step, 1)
     stop_hook = RaiseOnceAtCountN(2, StopIteration)
     session = monitored_session.MonitoredSession(hooks=[stop_hook])
     # No exception is expected to escape this context.
     with session:
       self.assertEqual(0, session.run(global_step))
       self.assertFalse(session.should_stop())
       # This run is where the hook raises StopIteration, so control is
       # expected to leave the block here and never reach the next line.
       session.run(increment)
       self.assertTrue(False)
     self.assertTrue(session.should_stop())
 def test_raises_regular_exceptions_in_with_body(self):
   """An exception raised inside the "with" body must propagate outside."""
   with tf.Graph().as_default():
     global_step = tf.contrib.framework.get_or_create_global_step()
     increment = tf.assign_add(global_step, 1)
     session = monitored_session.MonitoredSession()
     # The RuntimeError raised in the body is expected to reach this level.
     with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
       with session:
         for expected_step in (1, 2):
           self.assertEqual(expected_step, session.run(increment))
         self.assertFalse(session.should_stop())
         raise RuntimeError('regular exception')
     # The session is expected to be stopped and closed afterwards.
     self.assertTrue(session.should_stop())
     self.assertTrue(session._is_closed())
    def testShouldStop(self):
        """should_stop() flips to True once a monitor asks to stop."""
        with tf.Graph().as_default(), tf.Session() as sess:
            global_step_tensor = tf.contrib.framework.create_global_step()
            stopping_monitor = FakeMonitor()
            other_monitor = FakeMonitor()
            mon_sess = monitored_session.MonitoredSession(
                sess=sess,
                monitors=[stopping_monitor, other_monitor],
                global_step_tensor=global_step_tensor)
            # The constant is fetched by name below.
            tf.constant([0], name='a_tensor')
            sess.run(tf.initialize_all_variables())

            # No monitor has requested a stop yet.
            mon_sess.run(fetches='a_tensor')
            self.assertFalse(mon_sess.should_stop())

            # Flag one monitor and run again.
            stopping_monitor.should_stop = True
            mon_sess.run(fetches='a_tensor')
            self.assertTrue(mon_sess.should_stop())
    def testMonitorRequestWithColonZero(self):
        """Monitors may request a tensor by name with or without ':0'."""
        with tf.Graph().as_default(), tf.Session() as sess:
            global_step_tensor = tf.contrib.framework.create_global_step()
            bare_name_monitor = FakeMonitor()
            suffixed_name_monitor = FakeMonitor()
            mon_sess = monitored_session.MonitoredSession(
                sess=sess,
                monitors=[bare_name_monitor, suffixed_name_monitor],
                global_step_tensor=global_step_tensor)
            fetched = tf.constant([0], name='a_tensor')
            tf.constant([5], name='another_tensor')
            # Same tensor requested under two spellings of its name.
            bare_name_monitor.requested_tensors = ['another_tensor']
            suffixed_name_monitor.requested_tensors = ['another_tensor:0']
            sess.run(tf.initialize_all_variables())

            output = mon_sess.run(fetches=fetched)
            self.assertEqual(output, [0])
            self.assertEqual(bare_name_monitor.output['another_tensor'], [5])
            self.assertEqual(
                suffixed_name_monitor.output['another_tensor:0'], [5])
    def testCallsMonitorsBeginAndEnd(self):
        """Each monitor gets one step_begin and one step_end call per run."""
        with tf.Graph().as_default(), tf.Session() as sess:
            global_step_tensor = tf.contrib.framework.create_global_step()
            first_monitor = FakeMonitor()
            second_monitor = FakeMonitor()
            mon_sess = monitored_session.MonitoredSession(
                sess=sess,
                monitors=[first_monitor, second_monitor],
                global_step_tensor=global_step_tensor)
            fetched = tf.constant([0], name='a_tensor')
            sess.run(tf.initialize_all_variables())
            # Start from global step 10; the monitors should then see 11.
            sess.run(global_step_tensor.assign(10))
            mon_sess.run(fetches=fetched)

            for monitor in (first_monitor, second_monitor):
                self.assertEqual(monitor.output, {})
                self.assertEqual(monitor.begin_step, 11)
                self.assertEqual(monitor.end_step, 11)
                self.assertEqual(monitor.call_counter['step_end'], 1)
                self.assertEqual(monitor.call_counter['step_begin'], 1)
 def test_regular_exception_pass_through_run(self):
   """A regular exception raised by a hook passes through run() and the block.

   The session is also expected to be in stop mode afterwards.
   """
   with tf.Graph().as_default():
     global_step = tf.contrib.framework.get_or_create_global_step()
     increment = tf.assign_add(global_step, 1)
     failure = RuntimeError('regular exception')
     hook = RaiseOnceAtCountN(4, failure)
     session = monitored_session.MonitoredSession(hooks=[hook])
     with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
       with session:
         self.assertEqual(0, session.run(global_step))
         for expected_step in (1, 2):
           self.assertEqual(expected_step, session.run(increment))
         self.assertFalse(session.should_stop())
         # The hook fires on this run and raises the exception.
         session.run(increment)
         # Reaching this line means the exception was swallowed.
         self.assertFalse(True)
     self.assertTrue(hook.raised)
     self.assertTrue(session.should_stop())
 def testRunPassesAllArguments(self):
     """run() should forward feed_dict, options and run_metadata unchanged."""
     with tf.Graph().as_default(), tf.Session() as sess:
         global_step_tensor = tf.contrib.framework.create_global_step()
         recording_session = FakeSession(sess)
         mon_sess = monitored_session.MonitoredSession(
             sess=recording_session,
             monitors=[],
             global_step_tensor=global_step_tensor)
         fetched = tf.constant([0], name='a_tensor')
         sess.run(tf.initialize_all_variables())
         # Placeholder strings stand in for the real argument objects.
         output = mon_sess.run(
             fetches=fetched,
             feed_dict='a_feed',
             options='an_option',
             run_metadata='a_metadata')
         self.assertEqual(output, [0])
         expected_args = {
             'feed_dict': 'a_feed',
             'options': 'an_option',
             'run_metadata': 'a_metadata'
         }
         self.assertEqual(recording_session.args_called, expected_args)
 def test_regular_exception_reported_to_coord_pass_through_run(self):
     """An exception reported to the coordinator is raised by the next run().

     The error is handed to the coordinator the way a thread would report
     it; the following run() inside the "with MonitoredSession" block is
     expected to raise it.
     """
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
         scaffold = monitored_session.Scaffold()
         session = monitored_session.MonitoredSession('', scaffold=scaffold)
         with self.assertRaisesRegexp(
                 RuntimeError, 'a thread wants to stop'):
             with session:
                 self.assertEqual(0, session.run(global_step))
                 # Raise and catch so the reported exception is a raised one.
                 try:
                     raise RuntimeError('a thread wants to stop')
                 except RuntimeError as reported:
                     session.coord.request_stop(reported)
                 # This run() is expected to raise the reported exception.
                 self.assertEqual(0, session.run(global_step))
                 # Reaching this line means run() did not raise.
                 self.assertFalse(True)
Exemple #25
0
def main(_):
    """Train the configured problem with Adam or with the L2L optimizer.

    With `FLAGS.mode == 1` the Adam train op is run directly; for any other
    mode Adam is wrapped in `l2l_optimizer.L2LOptimizer` and that optimizer's
    train op is run instead. The loop runs `FLAGS.num_epochs` times, printing
    `accum_loss / 100` whenever the fetched global step is a multiple of 100
    and running the slot-reset op whenever it is a multiple of
    `FLAGS.reset_interval`.

    Args:
      _: Unused.
    """
    # Problem.
    with tf.variable_scope("problem",
                           partitioner=tf.min_max_variable_partitioner(
                               max_partitions=2, min_slice_size=10 << 10)):
        # Only the problem builder is used; the two other values returned by
        # get_config are not needed in this script.
        problem, _net_config, _net_assignments = util.get_config(
            FLAGS.problem)
        loss = problem()
    global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

    # Optimizer setup.
    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='problem'))
    adam_opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    # Built in both modes; in L2L mode `opt` is replaced just below.
    opt = adam_opt.minimize(loss, global_step)

    if FLAGS.mode != 1:
        optimizer = l2l_optimizer.L2LOptimizer(
            internal_optimizer=adam_opt,
            loss_func=problem,
            opt_last=FLAGS.opt_last,
            preprocessor=LogAndSign(10),
            co_opt=FLAGS.co_opt,
            rnn_layer_cnt=FLAGS.layer,
            delta_ratio=FLAGS.delta_ratio,
            update_ratio=FLAGS.update_ratio,
            dynamic_unroll=FLAGS.dynamic_unroll)
        opt = optimizer.minimize(loss,
                                 global_step=global_step,
                                 unroll_len=FLAGS.unroll_len)
        print('use l2l opt')
        # Re-initializes the L2L optimizer's slot and internal variables.
        slot_reset = tf.variables_initializer(optimizer._slot_vars +
                                              optimizer._opt_vars)
    else:
        print('use adam opt')
        # Nothing to reset for plain Adam.
        slot_reset = tf.no_op()

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

    with ms.MonitoredSession() as sess:
        # Prevent accidental changes to the graph.
        tf.get_default_graph().finalize()
        sess.run(init)
        print('trainable variables')
        for v in tf.trainable_variables():
            print("parameter:", v.name, "device:", v.device, "shape:",
                  v.get_shape())

        accum_loss = 0.0
        for _epoch in xrange(FLAGS.num_epochs):
            # Training.
            step, curr_loss, _ = sess.run([global_step, loss, opt])
            accum_loss += curr_loss
            if step % 100 == 0:
                print('step:%d,loss:%f' % (step, accum_loss / 100))
                accum_loss = 0.0

            if step % FLAGS.reset_interval == 0:
                sess.run(slot_reset)
 def test_defaults(self):
     """A MonitoredSession built with an empty master can read a variable."""
     with tf.Graph().as_default():
         zero_var = tf.Variable(0)
         with monitored_session.MonitoredSession('') as mon_sess:
             fetched = mon_sess.run(zero_var)
             self.assertEqual(0, fetched)
Exemple #27
0
def main(_):
  """Meta-train the optimizer, logging, evaluating and saving periodically.

  Runs `FLAGS.num_epochs` training epochs through `util.run_epoch`. Every
  `FLAGS.log_period` epochs the accumulated cost/time are printed and reset.
  Every `FLAGS.evaluation_period` epochs `FLAGS.evaluation_epochs` evaluation
  epochs are run (without the `step` op); if `FLAGS.save_path` is set and the
  summed evaluation cost is the lowest seen so far, the files in that
  directory are removed and the optimizer is saved there.

  Args:
    _: Unused.
  """
  # Configuration.
  num_unrolls = FLAGS.num_steps // FLAGS.unroll_length

  # Problem.
  problem, net_config, net_assignments = util.get_config(
      FLAGS.problem, main_parade_path, first_batch_parade_path)

  # Optimizer setup.
  optimizer = meta.MetaOptimizer(**net_config)
  minimize = optimizer.meta_minimize(
      problem, FLAGS.unroll_length,
      learning_rate=FLAGS.learning_rate,
      net_assignments=net_assignments,
      second_derivatives=FLAGS.second_derivatives)
  step, update, reset, cost_op, _ = minimize

  with ms.MonitoredSession() as sess:
    # Prevent accidental changes to the graph.
    tf.get_default_graph().finalize()

    # The writer is only used to dump the graph once, so close it right away;
    # this flushes the event file instead of leaving it open for the whole
    # run.
    writer = tf.summary.FileWriter('summary')
    try:
      writer.add_graph(tf.get_default_graph())
    finally:
      writer.close()

    best_evaluation = float("inf")
    total_time = 0
    total_cost = 0
    for e in xrange(FLAGS.num_epochs):
      # Training.
      epoch_time, epoch_cost = util.run_epoch(sess, cost_op, [update, step],
                                              reset, num_unrolls)
      total_time += epoch_time
      total_cost += epoch_cost

      # Logging.
      if (e + 1) % FLAGS.log_period == 0:
        util.print_stats("Epoch {}".format(e + 1), total_cost, total_time,
                         FLAGS.log_period)
        total_time = 0
        total_cost = 0

      # Evaluation.
      if (e + 1) % FLAGS.evaluation_period == 0:
        eval_cost = 0
        eval_time = 0
        for _ in xrange(FLAGS.evaluation_epochs):
          epoch_time, epoch_cost = util.run_epoch(sess, cost_op, [update],
                                                  reset, num_unrolls)
          eval_time += epoch_time
          eval_cost += epoch_cost

        util.print_stats("EVALUATION", eval_cost, eval_time,
                         FLAGS.evaluation_epochs)

        if FLAGS.save_path is not None and eval_cost < best_evaluation:
          if os.path.isdir(FLAGS.save_path):
            print("Removing previously saved meta-optimizer")
            for name in os.listdir(FLAGS.save_path):
              os.remove(os.path.join(FLAGS.save_path, name))
          else:
            # Nothing creates the directory earlier in this function, so make
            # it here rather than failing in os.listdir on the first save.
            os.makedirs(FLAGS.save_path)
          print("Saving meta-optimizer to {}".format(FLAGS.save_path))
          optimizer.save(sess, FLAGS.save_path)
          best_evaluation = eval_cost
Exemple #28
0
def _monitored_train(graph,
                     output_dir,
                     train_op,
                     loss_op,
                     global_step_tensor=None,
                     init_op=None,
                     init_feed_dict=None,
                     init_fn=None,
                     log_every_steps=10,
                     supervisor_is_chief=True,
                     supervisor_master='',
                     supervisor_save_model_secs=600,
                     keep_checkpoint_max=5,
                     supervisor_save_summaries_steps=100,
                     feed_fn=None,
                     steps=None,
                     fail_on_nan_loss=True,
                     hooks=None,
                     max_steps=None):
  """Train a model via monitored_session.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model. The `loss_op` represents the objective function of the training. It is
  expected to increment the `global_step_tensor`, a scalar integer tensor
  counting training steps. This function uses `Supervisor` to initialize the
  graph (from a checkpoint if one is available in `output_dir`), write summaries
  defined in the graph, and write regular checkpoints as defined by
  `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
  `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
  program is terminated with exit code 1.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save model every
      `supervisor_save_model_secs` seconds when training.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
      arg to tf.Saver constructor.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` seconds when training.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    hooks: List of `SessionRunHook` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train model. If `None`,
      train forever. Two calls fit(steps=100) means 200 training iterations.
      On the other hand two calls of fit(max_steps=100) means, second call
      will not do any iteration since first call did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
      is not provided. See `tf.contrib.framework.get_global_step` for how we
      look up the latter if not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
  """
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')
  if hooks is None:
    hooks = []
  if not isinstance(hooks, list):
    raise ValueError('Hooks should be a list.')
  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
  if global_step_tensor is None:
    raise ValueError('No "global_step" was provided or found in the graph.')

  if max_steps is not None:
    try:
      start_step = checkpoints.load_variable(output_dir,
                                             global_step_tensor.name)
      if max_steps <= start_step:
        logging.info('Skipping training since max_steps has already saved.')
        return None
    except:  # pylint: disable=bare-except
      pass

  # Adapted SessionRunHooks such as ExportMonitor depend on the
  # CheckpointSaverHook to be executed before they should be executed.
  # The `hooks` param comprises of deprecated monitor hooks
  # (such as ExportMonitor). Appending them after the basic_session_run_hooks.
  all_hooks = []
  with graph.as_default():
    all_hooks.extend([
        basic_session_run_hooks.NanTensorHook(
            loss_op, fail_on_nan_loss=fail_on_nan_loss),
        basic_session_run_hooks.LoggingTensorHook({
            'loss': loss_op.name,
            'step': global_step_tensor.name
        }, every_n_iter=log_every_steps),
    ])

    scaffold = monitored_session.Scaffold(
        init_op=init_op,
        init_feed_dict=init_feed_dict,
        init_fn=init_fn,
        saver=tf_saver.Saver(
            sharded=True, max_to_keep=keep_checkpoint_max, defer_build=True))

    if not supervisor_is_chief:
      session_creator = monitored_session.WorkerSessionCreator(
          scaffold=scaffold,
          master=supervisor_master)
    else:
      session_creator = monitored_session.ChiefSessionCreator(
          scaffold=scaffold,
          checkpoint_dir=output_dir,
          master=supervisor_master)
      summary_writer = summary_writer_cache.SummaryWriterCache.get(output_dir)
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              summary_writer=summary_writer))
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              save_steps=supervisor_save_summaries_steps,
              summary_writer=summary_writer,
              scaffold=scaffold))
      if supervisor_save_model_secs > 0:
        all_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                output_dir,
                save_secs=supervisor_save_model_secs,
                scaffold=scaffold))

    if steps is not None or max_steps is not None:
      all_hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
    all_hooks.extend(hooks)

    with monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks) as super_sess:
      loss = None
      while not super_sess.should_stop():
        _, loss = super_sess.run([train_op, loss_op], feed_fn() if feed_fn else
                                 None)
      return loss
Exemple #29
0
def main(_):
    """Evaluate an optimizer on a problem and pickle the recorded losses.

    Builds the problem from `util.get_config`, sets up either a plain Adam
    update (`FLAGS.optimizer == "Adam"`) or a `meta.MetaOptimizer` update
    (`FLAGS.optimizer == "L2L"`), runs `FLAGS.num_epochs` evaluation epochs via
    `util.run_eval_epoch`, prints summary statistics, and writes the
    concatenated per-epoch cost record to
    `<FLAGS.path>/<FLAGS.optimizer>_eval_loss_record.pickle`.

    Args:
      _: Unused.

    Raises:
      ValueError: If `FLAGS.optimizer` is neither "Adam" nor "L2L".
    """
    # Configuration.
    num_unrolls = FLAGS.num_steps

    if FLAGS.seed:
        tf.set_random_seed(FLAGS.seed)

    # Problem.
    # problem, net_config, net_assignments = util.get_config(FLAGS.problem,
    #                                                        FLAGS.path)
    # Problem-size parameters forwarded to get_config as `param`.
    param_dict = {}
    param_dict['bs'] = FLAGS.bs
    param_dict['m'] = FLAGS.m
    param_dict['n'] = FLAGS.n
    print(param_dict)
    problem, net_config, net_assignments = util.get_config(
        FLAGS.problem,
        net_name="RNNprop",
        mode=FLAGS.mode,  # pass the mode through
        num_linear_heads=1,
        init=FLAGS.init,
        path=FLAGS.path,
        param=param_dict)

    # Optimizer setup.
    if FLAGS.optimizer == "Adam":
        # NOTE(review): `seq_step` is assigned only in the "L2L" branch below
        # but is passed to run_eval_epoch unconditionally; unless it is also
        # defined at module level this branch ends in a NameError — confirm.
        cost_op = problem()
        problem_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        problem_reset = tf.variables_initializer(problem_vars)

        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        # NOTE(review): the result of get_slot_names() is used directly as the
        # variable list, and this runs before minimize() is called — verify
        # that this actually resets the optimizer state as intended.
        optimizer_reset = tf.variables_initializer(optimizer.get_slot_names())
        update = optimizer.minimize(cost_op)
        reset = [problem_reset, optimizer_reset]
    elif FLAGS.optimizer == "L2L":
        if FLAGS.path is None:
            logging.warning("Evaluating untrained L2L optimizer")
        optimizer = meta.MetaOptimizer(FLAGS.num_mt, FLAGS.beta1, FLAGS.beta2,
                                       **net_config)
        # meta_loss = optimizer.meta_loss(problem, 1, net_assignments=net_assignments)

        # Only the first and sixth of the twelve returned values are used.
        meta_loss, _, _, _, _, seq_step, \
        _, _, _, _, _, _ = optimizer.meta_loss(problem, 1, net_assignments=net_assignments)
        # These used to be individually named variables, but the "object never
        # used" warning seemed to refer to them, so they were all replaced
        # with underscore placeholders.

        _, update, reset, cost_op, _ = meta_loss

    else:
        raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer))

    # NOTE(review): `config` is built here but never passed to the
    # MonitoredSession below, so allow_growth has no effect as written.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with ms.MonitoredSession() as sess:
        # with tf.Session(config=config) as sess:
        sess.run(reset)
        # Prevent accidental changes to the graph.
        tf.get_default_graph().finalize()

        total_time = 0
        total_cost = 0
        # Concatenation of the cost sequences returned by every epoch.
        loss_record = []
        for ep in xrange(FLAGS.num_epochs):
            # Training.
            time, cost = util.run_eval_epoch(sess,
                                             cost_op, [update],
                                             num_unrolls,
                                             step=seq_step)
            total_time += time

            # Accumulate the epoch's cost averaged over num_unrolls.
            total_cost += sum(cost) / num_unrolls
            loss_record += cost
            print(ep, cost[-1])
        # Results.
        util.print_stats("Epoch {}".format(FLAGS.num_epochs), total_cost,
                         total_time, FLAGS.num_epochs)
    # NOTE(review): FLAGS.path may be None (see the warning above), in which
    # case this formats to "None/..." — confirm a path is always supplied.
    with open(
            '{}/{}_eval_loss_record.pickle'.format(FLAGS.path,
                                                   FLAGS.optimizer),
            'wb') as l_record:
        pickle.dump(loss_record, l_record)
    print("Saving evaluate loss record {}".format(FLAGS.path))
Exemple #30
0
def main(_):
    """Meta-train an optimizer, optionally multi-task and with a curriculum.

    Each of the `FLAGS.num_epochs` epochs runs one `util.run_epoch` call, on
    either the main meta-loss (`task_i == -1`) or, when `FLAGS.if_mt` is set
    and a random draw falls below the multi-task ratio, on one of the
    `FLAGS.num_mt` auxiliary tasks (chosen round-robin). Every
    `FLAGS.evaluation_period` epochs the optimizer is evaluated for
    `FLAGS.evaluation_epochs` epochs. Without a curriculum (`FLAGS.if_cl`
    false) the optimizer is saved whenever the evaluation cost improves. With
    a curriculum, improvement saves the model for the current stage; once at
    least `min_num_eval` evaluations have run, the stage either advances
    (if it improved at some point) or training stops (if it never improved).
    Loss records are pickled to `<FLAGS.save_path>/log.pickle` at the end.

    Args:
      _: Unused.
    """
    # Configuration.
    if FLAGS.if_cl:
        # Curriculum stages: the number of optimization steps per stage.
        num_steps = [100, 200, 500, 1000, 1500, 2000, 2500, 3000]
        num_unrolls = [int(ns / FLAGS.unroll_length) for ns in num_steps]
        # Evaluation uses the unroll count of the *next* stage, so this list
        # is one element shorter than num_unrolls.
        num_unrolls_eval = num_unrolls[1:]
        min_num_eval = 5
        curriculum_idx = 0
    else:
        num_unrolls = FLAGS.num_steps // FLAGS.unroll_length

    if FLAGS.save_path is not None:
        if not os.path.exists(FLAGS.save_path):
            os.mkdir(FLAGS.save_path)

    # Problem.
    # NOTE(review): `param_dict` is not assigned anywhere in this function;
    # presumably it is a module-level name — confirm.
    problem, net_config, net_assignments = util.get_config(
        FLAGS.problem,
        mode=FLAGS.mode,  # pass the mode through
        num_linear_heads=1,
        init=FLAGS.init,
        path=FLAGS.path,
        param=param_dict)

    # Optimizer setup.
    optimizer = meta.MetaOptimizer(FLAGS.num_mt, **net_config)
    minimize, scale, var_x, constants, subsets,\
    loss_mt, steps_mt, update_mt, reset_mt, mt_labels, mt_inputs, hess_norm_approx = optimizer.meta_minimize(
        problem, FLAGS.unroll_length,
        learning_rate=FLAGS.learning_rate,
        net_assignments=net_assignments,
        second_derivatives=FLAGS.second_derivatives)
    step, update, reset, cost_op, _ = minimize

    if FLAGS.if_mt:
        data_mt = data_loader(problem, var_x, constants, subsets, scale,
                              FLAGS.optimizers, FLAGS.unroll_length,
                              hess_norm_approx)
        if FLAGS.if_cl:
            # One multi-task ratio per curriculum stage, whitespace-separated.
            mt_ratios = [float(r) for r in FLAGS.mt_ratios.split()]

    # Placeholders and assign ops used to overwrite each variable in var_x
    # with externally supplied values (see assign_func below).
    p_val_x = []
    for k in var_x:
        p_val_x.append(tf.placeholder(tf.float32, shape=k.shape))
    assign_ops = [
        tf.assign(var_x[k_id], p_val_x[k_id]) for k_id in range(len(p_val_x))
    ]

    # NOTE(review): `config` is built here but never passed to the
    # MonitoredSession below, so allow_growth has no effect as written.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    start_time = timer()
    with ms.MonitoredSession() as sess:

        def assign_func(val_x):
            # Feed each value into its placeholder and run all assign ops.
            sess.run(assign_ops,
                     feed_dict={p: v
                                for p, v in zip(p_val_x, val_x)})

        tf.get_default_graph().finalize()

        best_evaluation = float("inf")
        train_loss_record = []
        eval_loss_record = []
        num_steps_train = []
        # Evaluations run since the current curriculum stage started.
        num_eval = 0
        # Whether the current curriculum stage has improved at least once.
        improved = False
        # Round-robin index of the last auxiliary task that was trained.
        mti = -1
        task_id_record = []
        for e in xrange(FLAGS.num_epochs):
            # choose task
            if FLAGS.if_mt:
                if FLAGS.if_cl:
                    mt_ratio = mt_ratios[curriculum_idx]
                else:
                    mt_ratio = FLAGS.mt_ratio
                if random.random() < mt_ratio:
                    mti = (mti + 1) % FLAGS.num_mt
                    task_i = mti
                else:
                    task_i = -1
                task_id_record.append(task_i)
            else:
                task_i = -1
            # Training.
            if FLAGS.if_cl:
                num_unrolls_cur = num_unrolls[curriculum_idx]
            else:
                num_unrolls_cur = num_unrolls
            if task_i == -1:
                # Main meta-training objective.
                time, cost = util.run_epoch(
                    sess,
                    cost_op, [update, step],
                    reset,
                    num_unrolls_cur,
                    scale=scale,
                    rd_scale=FLAGS.if_scale,
                    rd_scale_bound=FLAGS.rd_scale_bound,
                    assign_func=assign_func,
                    var_x=var_x,
                    if_hess_init=FLAGS.init == "hessian",
                    hess_norm=hess_norm_approx)
            else:
                # Auxiliary task: fetch its data, then train on that task's
                # own loss/update/step/reset ops.
                data_e = data_mt.get_data(task_i,
                                          sess,
                                          num_unrolls_cur,
                                          assign_func,
                                          FLAGS.rd_scale_bound,
                                          if_scale=FLAGS.if_scale,
                                          mt_k=FLAGS.k,
                                          if_hess_init=FLAGS.init == "hessian")
                time, cost = util.run_epoch(
                    sess,
                    loss_mt[task_i], [update_mt[task_i], steps_mt[task_i]],
                    reset_mt[task_i],
                    num_unrolls_cur,
                    scale=scale,
                    rd_scale=FLAGS.if_scale,
                    rd_scale_bound=FLAGS.rd_scale_bound,
                    assign_func=assign_func,
                    var_x=var_x,
                    task_i=task_i,
                    data=data_e,
                    label_pl=mt_labels[task_i],
                    input_pl=mt_inputs[task_i])
            train_loss_record.append(cost)

            # Evaluation.
            if (e + 1) % FLAGS.evaluation_period == 0:
                if FLAGS.if_cl:
                    # NOTE(review): num_unrolls_eval has one fewer entry than
                    # num_unrolls, while curriculum_idx is only wrapped once it
                    # reaches len(num_unrolls) — check the last stage cannot
                    # index past the end here.
                    num_unrolls_eval_cur = num_unrolls_eval[curriculum_idx]
                else:
                    num_unrolls_eval_cur = num_unrolls
                num_eval += 1

                eval_cost = 0
                for _ in xrange(FLAGS.evaluation_epochs):
                    time, cost = util.run_epoch(sess, cost_op, [update], reset,
                                                num_unrolls_eval_cur)
                    eval_cost += cost

                if FLAGS.if_cl:
                    num_steps_cur = num_steps[curriculum_idx]
                else:
                    num_steps_cur = FLAGS.num_steps
                print("epoch={}, num_steps={}, eval loss={}".format(
                    e, num_steps_cur, eval_cost / FLAGS.evaluation_epochs),
                      flush=True)
                eval_loss_record.append(eval_cost / FLAGS.evaluation_epochs)
                num_steps_train.append(num_steps_cur)

                if not FLAGS.if_cl:
                    # No curriculum: save on improvement (once under the epoch
                    # number and once under index 0), then skip the curriculum
                    # logic below.
                    if eval_cost < best_evaluation:
                        best_evaluation = eval_cost
                        optimizer.save(sess, FLAGS.save_path, e + 1)
                        optimizer.save(sess, FLAGS.save_path, 0)
                        print("Saving optimizer...")
                    continue

                # update curriculum
                if eval_cost < best_evaluation:
                    best_evaluation = eval_cost
                    improved = True
                    # save model
                    optimizer.save(sess, FLAGS.save_path, curriculum_idx)
                    optimizer.save(sess, FLAGS.save_path, 0)
                elif num_eval >= min_num_eval and improved:
                    # No further improvement, but this stage improved earlier:
                    # reload its best model and advance to the next stage.
                    # restore model
                    optimizer.restore(sess, FLAGS.save_path, curriculum_idx)
                    num_eval = 0
                    improved = False
                    curriculum_idx += 1
                    # Past the last stage the index becomes -1, which selects
                    # the final entry of each per-stage list.
                    if curriculum_idx >= len(num_unrolls):
                        curriculum_idx = -1
                    # new evaluation
                    # Re-evaluate under the new stage to reset the baseline.
                    eval_cost = 0
                    for _ in xrange(FLAGS.evaluation_epochs):
                        time, cost = util.run_epoch(
                            sess, cost_op, [update], reset,
                            num_unrolls_eval[curriculum_idx])
                        eval_cost += cost
                    best_evaluation = eval_cost
                    print("epoch={}, num_steps={}, eval loss={}".format(
                        e, num_steps[curriculum_idx],
                        eval_cost / FLAGS.evaluation_epochs),
                          flush=True)
                    eval_loss_record.append(eval_cost /
                                            FLAGS.evaluation_epochs)
                    num_steps_train.append(num_steps[curriculum_idx])
                elif num_eval >= min_num_eval and not improved:
                    # The stage never improved within min_num_eval
                    # evaluations: stop training.
                    print("no improve during curriculum {} --> stop".format(
                        curriculum_idx))
                    break

        print("total time = {}s...".format(timer() - start_time))
        # output
        with open('{}/log.pickle'.format(FLAGS.save_path), 'wb') as l_record:
            records = {
                "eval_loss": eval_loss_record,
                "train_loss": train_loss_record,
                "task_id": task_id_record,
                "num_steps": num_steps_train
            }
            pickle.dump(records, l_record)
            # NOTE(review): redundant — the enclosing `with` already closes
            # the file on exit.
            l_record.close()